In [None]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights


from src.model.model_classes import EfficientNet_ContextualData

In [None]:
# Load the base (image-only) model: an EfficientNet-B0 backbone with a custom
# classifier head, then restore its trained weights from the checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_classes = 1486  # output size of the classifier head; must match the checkpoint

# map_location="cpu" lets a checkpoint saved on GPU be loaded on a CPU-only machine;
# the model is moved to `device` after its weights are restored.
base_checkpoint = torch.load("runs/model_pretrained_efficientnet_b0_final.pth", map_location="cpu")

# weights=None builds the architecture only; the trained weights come from the checkpoint.
trained_model = efficientnet_b0(weights=None)

# Replace the stock head with the one used in training (must match the checkpoint).
trained_model.classifier = nn.Sequential(
            nn.Linear(trained_model.classifier[1].in_features, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, n_classes)  # Final output layer
)

# NOTE(review): the checkpoint layout is not visible here. This accepts either a raw
# state_dict or a dict wrapping one under "model_state_dict" — confirm against the
# training script that wrote it.
base_state_dict = base_checkpoint.get("model_state_dict", base_checkpoint)
trained_model.load_state_dict(base_state_dict)

# eval() puts the BatchNorm1d/Dropout layers of the head into inference mode.
trained_model = trained_model.to(device).eval()

In [None]:
# Load the geographical model (EfficientNet with contextual data) and restore its weights.
# A distinct checkpoint name avoids clobbering the base model's checkpoint variable.
geo_checkpoint = torch.load(
    "runs/model_pretrained_efficientnet_geographical_data_b0_final.pth",
    map_location="cpu",  # allow loading a GPU-saved checkpoint on a CPU-only machine
)

# NOTE(review): the meaning of the second constructor argument (3) is not visible here —
# presumably the number of contextual input features; confirm in src/model/model_classes.py.
geographical_model = EfficientNet_ContextualData(n_classes, 3)

# NOTE(review): accepts either a raw state_dict or one wrapped under "model_state_dict";
# confirm against the training script that wrote the checkpoint.
geographical_model.load_state_dict(geo_checkpoint.get("model_state_dict", geo_checkpoint))
geographical_model.eval()

In [None]:
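# TODO: load the data — the cells below expect `train_data`, `test_loader_birds` and `remap_labels` to be defined here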

In [None]:
# Goal: find species whose accuracy changed between model runs.
# First restrict the analysis to birds: keep every category name containing the
# Aves taxon path, preserving the order of train_data.all_categories.
BIRD_TAXON = "Animalia_Chordata_Aves"
bird_classes = list(filter(lambda category: BIRD_TAXON in category, train_data.all_categories))

In [None]:
# Per-class accuracy of the base model on the bird-only test loader.
# Counters are keyed by class name, one entry per bird class.
correct_pred = {classname: 0 for classname in bird_classes}
total_pred = {classname: 0 for classname in bird_classes}

# Without eval(), the BatchNorm1d/Dropout layers in the classifier head run in training
# mode (batch statistics, random dropout) and the measured accuracy is wrong.
# The model is also moved to the same device as the inputs.
trained_model = trained_model.to(device).eval()

# Inference only, so no gradients are needed.
with torch.no_grad():
    for images, labels in test_loader_birds:
        images = images.to(device)
        labels = remap_labels(labels.to(device)).to(device)
        outputs = trained_model(images)
        predictions = outputs.argmax(dim=1)
        # NOTE(review): `labels` are remapped (presumably into the bird_classes index
        # space), while `predictions` index all n_classes outputs of the model. Unless
        # remap_labels preserves indices, this comparison mixes two label spaces —
        # confirm against remap_labels.
        # .tolist() converts once per batch instead of comparing tensors element by element.
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            classname = bird_classes[label]
            total_pred[classname] += 1
            if label == prediction:
                correct_pred[classname] += 1

# TODO: turn into a DataFrame sorted by lowest accuracy, then build a confusion
# matrix over all classes to find the biggest misses.

# Print accuracy for each class; classes absent from the test set have no accuracy
# (dividing by their zero count would raise ZeroDivisionError).
for classname, correct_count in correct_pred.items():
    total = total_pred[classname]
    if total == 0:
        print(f'Accuracy for class: {classname:5s} is n/a (no test samples)')
        continue
    accuracy = 100 * float(correct_count) / total
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')