In [1]:
import sys
import numpy as np
import timm
import torch
from torch import tensor
import torch.nn as nn
from torchvision.transforms import InterpolationMode, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from PIL import Image
import os
from tqdm import tqdm
import getpass
import socket
from datetime import datetime
from sklearn.metrics import precision_score, recall_score

# Choose the PyTorch device: CUDA when a GPU is visible, otherwise fall back to the CPU.
if not torch.cuda.is_available():
    dev = "cpu"
    device = torch.device(dev)
    print("No GPU available.")
else:
    dev = "cuda"
    device = torch.device(dev)

    gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
    print(f"GPU name: {gpu_name} ({torch.cuda.device_count()} available)")

    # The host name tells us whether we are on the shared JASMIN GPU node.
    host_name = socket.gethostname()
    print("Host name: ", host_name)
    print("User name: ", getpass.getuser())

    if host_name != "gpuhost001.jc.rl.ac.uk":
        # Single-user machine: just report the device's total memory.
        _, max_memory = torch.cuda.mem_get_info()
        max_memory /= 1024 ** 3
        print(f"GPU memory: {max_memory} GiB")
    else:
        def select_gpu_with_most_free_memory():
            """Return the index of the CUDA device reporting the most free memory.

            Prints each device's free memory (GiB, 2 d.p.). Ties go to the
            highest index, which steers away from GPU 0 that most JASMIN users
            default to. Leaves the last-probed device selected as current.
            """
            chosen_id, chosen_free = 0, 0
            for gpu_id in range(torch.cuda.device_count()):
                torch.cuda.set_device(gpu_id)
                free_gib = torch.cuda.mem_get_info(gpu_id)[0] / 1024 ** 3
                print(f"GPU {gpu_id} free memory: {round(free_gib, 2)} GiB")
                if free_gib >= chosen_free:
                    chosen_id, chosen_free = gpu_id, free_gib
            return chosen_id

        best_gpu = select_gpu_with_most_free_memory()
        torch.cuda.set_device(best_gpu)
        print(f"Using GPU: {best_gpu}")

# Manual escape hatch: flip to True to pin training to GPU 3 regardless of the choice above.
gpu_override = False
if gpu_override:
    torch.cuda.set_device(3)
    print(f"OVERRIDE: Using GPU: {3}")

# Model input resolution plus the DeepFaune backbone name and checkpoint read by Classifier.
CROP_SIZE = 182
BACKBONE = "vit_large_patch14_dinov2"
weight_path = "../models/deepfaune-vit_large_patch14_dinov2.lvd142m.pt"

# Where the train/val/test splits live: repo-relative on JASMIN, otherwise the local external drive.
jasmin = True
_split_root = (
    "../data/split_data"
    if jasmin
    else "/media/tom-ratsakatika/CRUCIAL 4TB/FCC Camera Trap Data/split_data"
)
train_path, val_path, test_path = (f"{_split_root}/{split}" for split in ("train", "val", "test"))

# Class names in label order: label i (dataset targets and classifier logit i) is ANIMAL_CLASSES[i].
ANIMAL_CLASSES = [
    "badger", "ibex", "red deer", "chamois", "cat", "goat", "roe deer",
    "dog", "squirrel", "equid", "genet", "hedgehog", "lagomorph", "wolf",
    "lynx", "marmot", "micromammal", "mouflon", "sheep", "mustelid", "bird",
    "bear", "nutria", "fox", "wild boar", "cow",
]

class AnimalDataset(Dataset):
    """Image-folder dataset with one sub-directory per class.

    Each sub-directory of ``directory`` must be named after an entry of
    ``ANIMAL_CLASSES``; every file inside it is treated as an image of that
    class, labelled with the class's index in ``ANIMAL_CLASSES``.

    Args:
        directory: root folder containing the per-class sub-folders.
        transform: optional callable applied to each RGB PIL image.
        preload_to_gpu: if True, decode and transform every image once at
            construction and keep the results and the labels on the
            module-level ``device``. This needs ``transform`` to return
            something with ``.to`` (e.g. a tensor).

    Raises:
        ValueError: a non-empty class sub-directory is not in ``ANIMAL_CLASSES``.
    """

    def __init__(self, directory, transform=None, preload_to_gpu=False):
        self.directory = directory
        self.transform = transform
        self.images = []
        self.labels = []
        self.preload_to_gpu = preload_to_gpu

        # sorted(): os.listdir order is arbitrary, which made sample order
        # (and hence preload / evaluation order) non-reproducible.
        for label in sorted(os.listdir(directory)):
            label_dir = os.path.join(directory, label)
            if not os.path.isdir(label_dir):
                continue
            image_names = sorted(os.listdir(label_dir))
            if not image_names:
                continue
            try:
                class_index = ANIMAL_CLASSES.index(label)
            except ValueError:
                raise ValueError(
                    f"Unknown class directory {label!r} in {directory!r}; "
                    f"expected a name from ANIMAL_CLASSES"
                ) from None
            for image in image_names:
                self.images.append(os.path.join(label_dir, image))
                self.labels.append(class_index)

        if self.preload_to_gpu:
            self.preload_images()

    def _load_image(self, image_path):
        """Open one image as RGB, apply the transform, and close the file handle."""
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def preload_images(self):
        """Decode/transform every image up front and move images and labels to ``device``."""
        self.loaded_images = [
            self._load_image(image_path).to(device)
            for image_path in tqdm(self.images, desc="Preloading images to GPU")
        ]
        self.labels = torch.tensor(self.labels, device=device)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; served from the preloaded cache when enabled."""
        if self.preload_to_gpu:
            return self.loaded_images[idx], self.labels[idx]
        return self._load_image(self.images[idx]), self.labels[idx]

class Classifier(nn.Module):
    """DeepFaune ViT classifier built with timm, with optional freezing of early blocks.

    Args:
        freeze_up_to_layer: index of the last transformer block to freeze.
            Parameters named ``blocks.<i>.*`` with ``i <= freeze_up_to_layer``
            (inclusive) get ``requires_grad=False``. ``None`` freezes nothing.
    """

    def __init__(self, freeze_up_to_layer=16):
        super(Classifier, self).__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = timm.create_model(BACKBONE, pretrained=False, num_classes=len(ANIMAL_CLASSES), dynamic_img_size=True)
        # The DeepFaune checkpoint stores its weights under a 'state_dict' entry, with keys
        # prefixed by 'base_model.' that the bare timm model does not use.
        state_dict = torch.load(weight_path, map_location=torch.device(device))['state_dict']
        self.model.load_state_dict({k.replace('base_model.', ''): v for k, v in state_dict.items()})

        # Freeze layers up to the specified layer.
        # NOTE(review): only 'blocks.<i>.*' parameters are frozen; everything else in the timm
        # model (presumably patch/position embeddings, cls token, final norm, head) stays
        # trainable. Confirm that is intended.
        if freeze_up_to_layer is not None:
            for name, param in self.model.named_parameters():
                if self._should_freeze_layer(name, freeze_up_to_layer):
                    param.requires_grad = False

        # Preprocessing used by predict(): bicubic resize to CROP_SIZE, then ImageNet mean/std normalisation.
        self.transforms = transforms.Compose([
            transforms.Resize(size=(CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=None),
            transforms.ToTensor(),
            transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
        ])

    def _should_freeze_layer(self, name, freeze_up_to_layer):
        """Return True for parameters of transformer blocks ``0..freeze_up_to_layer``."""
        # startswith() instead of a substring test: only names of the form
        # 'blocks.<i>.<...>' carry a block index in their second component.
        if not name.startswith('blocks.'):
            return False
        return int(name.split('.')[1]) <= freeze_up_to_layer

    def forward(self, x):
        return self.model(x)

    def predict(self, image):
        """Classify one image and return ``(class_name, probability)`` for the top-1 class.

        The input is moved to the device holding the model's parameters, so this
        also works after ``.to('cuda')``. The model is put in eval mode for the
        call and its previous train/eval mode is restored afterwards.
        """
        param_device = next(self.parameters()).device
        img_tensor = self.transforms(image).unsqueeze(0).to(param_device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                probabilities = torch.nn.functional.softmax(self.forward(img_tensor), dim=1)
                top_p, top_class = probabilities.topk(1, dim=1)
        finally:
            self.train(was_training)
        return ANIMAL_CLASSES[top_class.item()], top_p.item()

# Custom loss function with a higher penalty for misclassifying wild boar
class CustomLoss(nn.Module):
    """Cross-entropy plus an extra weighted cross-entropy term for one target class.

    ``loss = CE(batch) + penalty_weight * CE(samples whose target == class_index)``

    Args:
        penalty_weight: multiplier for the extra term; 0 gives plain cross-entropy.
        class_index: label index to up-weight. Defaults to the index of
            "wild boar" in ANIMAL_CLASSES.
    """

    def __init__(self, penalty_weight=0.0, class_index=None):
        super(CustomLoss, self).__init__()
        self.penalty_weight = penalty_weight
        self.ce_loss = nn.CrossEntropyLoss()
        # Resolve the target class once here instead of on every forward pass.
        self.class_index = ANIMAL_CLASSES.index("wild boar") if class_index is None else class_index

    def forward(self, outputs, targets):
        """Return the penalised loss for logits ``outputs`` and integer labels ``targets``."""
        loss = self.ce_loss(outputs, targets)
        if self.penalty_weight == 0:
            # The extra term would be multiplied by zero; skip computing it.
            return loss
        mask = targets == self.class_index
        if mask.any():
            # Out-of-place add: leave the base CE tensor unmodified.
            loss = loss + self.penalty_weight * self.ce_loss(outputs[mask], targets[mask])
        return loss

def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean of the per-batch losses.

    Puts ``model`` in train mode and moves each batch to ``device``. Each batch
    then gets a zero_grad / backward / step update.
    """
    model.train()
    batch_losses = []
    for batch_images, batch_labels in tqdm(dataloader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns a 6-tuple:
        (mean per-batch loss, accuracy in percent,
         support-weighted precision, support-weighted recall,
         wild-boar precision, wild-boar recall)
    Precision/recall come from sklearn with ``zero_division=0``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Validation"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()
            batch_preds = logits.data.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()
            y_true.extend(batch_labels.cpu().numpy())
            y_pred.extend(batch_preds.cpu().numpy())

    accuracy = 100 * n_correct / n_seen
    overall_precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    overall_recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)

    # Restrict the per-class scores to the wild-boar label only.
    boar_only = [ANIMAL_CLASSES.index("wild boar")]
    wild_boar_precision = precision_score(y_true, y_pred, labels=boar_only, average='macro', zero_division=0)
    wild_boar_recall = recall_score(y_true, y_pred, labels=boar_only, average='macro', zero_division=0)

    mean_loss = loss_sum / len(dataloader)
    return mean_loss, accuracy, overall_precision, overall_recall, wild_boar_precision, wild_boar_recall

def test(model, dataloader, device):
    """Return the top-1 accuracy (percent) of ``model`` over ``dataloader``, without gradients."""
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Testing"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predictions = model(batch_images).data.max(dim=1).indices
            seen += batch_labels.size(0)
            hits += (predictions == batch_labels).sum().item()
    return 100 * hits / seen

def save_model(model, total_epochs, learning_rate, now, penalty_weight, save_dir="../models"):
    """Save ``model.state_dict()`` to a file name that records the run's hyper-parameters.

    The file is ``<save_dir>/<now>-deepfaune-finetuned-epochs-<total_epochs>-lr-<learning_rate>-wbpenalty-<penalty_weight>.pt``.
    ``save_dir`` is created if it does not exist. Previously a missing directory
    raised FileNotFoundError at the first checkpoint.

    Returns:
        The path that was written.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{now}-deepfaune-finetuned-epochs-{total_epochs}-lr-{learning_rate}-wbpenalty-{penalty_weight}.pt"
    model_save_path = os.path.join(save_dir, file_name)
    torch.save(model.state_dict(), model_save_path)
    print(f'Model saved to {model_save_path}')
    return model_save_path

def _ask(prompt, cast, default=None):
    """Prompt on stdin until the reply parses with ``cast``.

    An empty reply returns ``default`` when one is given. Unparsable replies,
    and empty replies when there is no default, re-prompt instead of raising
    ValueError (which previously ended the interactive session).
    """
    while True:
        reply = input(prompt).strip()
        if not reply and default is not None:
            return default
        try:
            return cast(reply)
        except ValueError:
            print(f"Could not parse {reply!r}, please try again.")


def _print_metrics(val_precision, val_recall, wb_precision, wb_recall):
    """Print the overall and wild-boar precision/recall lines shown after every evaluation."""
    print(f'Overall Precision: {val_precision}, Overall Recall: {val_recall}')
    print(f'Wild Boar Precision: {wb_precision}, Wild Boar Recall: {wb_recall}')


def _run_epochs(model, train_loader, val_loader, criterion, optimizer, num_epochs, first_epoch,
                best_val_loss, total_epochs, learning_rate, penalty_weight):
    """Train for ``num_epochs`` epochs, checkpointing whenever validation loss improves.

    Epochs are printed as ``first_epoch``, ``first_epoch + 1``, ... Each improvement
    overwrites one timestamped "best" file. If the final epoch is not the best,
    its weights are also saved under a fresh timestamp.

    Returns:
        The updated best validation loss.
    """
    now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    val_loss = None
    for offset in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy, val_precision, val_recall, wb_precision, wb_recall = validate(model, val_loader, criterion, device)
        print(f'Epoch {first_epoch + offset}, Train Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%')
        _print_metrics(val_precision, val_recall, wb_precision, wb_recall)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("Saving best epoch...")
            save_model(model, total_epochs, learning_rate, now, penalty_weight)

    if val_loss is not None and val_loss != best_val_loss:
        now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        print("Saving model's current state...")
        save_model(model, total_epochs, learning_rate, now, penalty_weight)
    return best_val_loss


def main():
    """Fine-tune the DeepFaune classifier, optionally continue interactively, then test.

    Runs an initial fine-tuning session, then repeatedly asks on stdin for more
    epochs, a learning rate and a wild-boar penalty weight. It finally reports
    accuracy on the test split.

    Returns:
        (model, train_loader, val_loader, test_loader, criterion, optimizer, total_epochs)
        so later notebook cells can keep experimenting.
    """
    initial_epochs = 5  # Set the number of epochs
    batch_size = 32  # Set the batch size
    learning_rate = 1e-5  # Reduced learning rate for fine-tuning
    total_epochs = initial_epochs
    penalty_weight = 0.0  # Initial penalty weight for wild boar class

    transform = transforms.Compose([
        transforms.Resize((CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
    ])

    print('Loading training data...')
    train_dataset = AnimalDataset(train_path, transform=transform, preload_to_gpu=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    print('Loading validation data...')
    val_dataset = AnimalDataset(val_path, transform=transform, preload_to_gpu=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = Classifier(freeze_up_to_layer=16).to(device)  # Freeze up to the 16th layer

    criterion = CustomLoss(penalty_weight=penalty_weight)  # Custom loss with initial penalty weight
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')

    # Baseline: evaluate the pretrained weights before any fine-tuning.
    print('Initial validation evaluation...')
    val_loss, val_accuracy, val_precision, val_recall, wb_precision, wb_recall = validate(model, val_loader, criterion, device)
    print(f'Initial Validation Loss: {val_loss}, Initial Validation Accuracy: {val_accuracy}%')
    _print_metrics(val_precision, val_recall, wb_precision, wb_recall)

    print('Training started...')
    best_val_loss = _run_epochs(model, train_loader, val_loader, criterion, optimizer, initial_epochs, 1,
                                best_val_loss, total_epochs, learning_rate, penalty_weight)

    # Option to continue training
    while True:
        more_epochs = _ask("Enter the number of additional epochs to continue training (0 to stop): ", int)
        if more_epochs <= 0:
            break
        while True:
            learning_rate = _ask("Enter the learning rate for the additional epochs (default 1e-5): ", float, default=1e-5)
            if 0 < learning_rate <= 1e-4:
                break
            print("Learning rate too high" if learning_rate > 1e-4 else "Learning rate must be a positive number")
        penalty_weight = _ask("Enter the penalty weight for wild boar class (default 0): ", float, default=0.0)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Fresh optimizer for the new learning rate
        criterion = CustomLoss(penalty_weight=penalty_weight)  # Update criterion with new penalty weight
        # NOTE(review): best_val_loss carries over even when penalty_weight changes, although the
        # penalty changes the loss scale; epochs trained with a different penalty are compared on unequal terms.
        first_epoch = total_epochs + 1
        total_epochs += more_epochs
        best_val_loss = _run_epochs(model, train_loader, val_loader, criterion, optimizer, more_epochs, first_epoch,
                                    best_val_loss, total_epochs, learning_rate, penalty_weight)

    print('Loading test data...')
    test_dataset = AnimalDataset(test_path, transform=transform, preload_to_gpu=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print('Testing the model...')
    test_accuracy = test(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy}%')

    # Return critical variables for further experimentation
    return model, train_loader, val_loader, test_loader, criterion, optimizer, total_epochs

if __name__ == '__main__':
    # Keep everything main() hands back as notebook globals for the follow-up cells.
    (model, train_loader, val_loader, test_loader,
     criterion, optimizer, total_epochs) = main()

GPU name: NVIDIA A100-SXM4-40GB (4 available)
Host name:  gpuhost001.jc.rl.ac.uk
User name:  trr26
GPU 0 free memory: 37.36 GiB
GPU 1 free memory: 38.97 GiB
GPU 2 free memory: 38.97 GiB
GPU 3 free memory: 38.97 GiB
Using GPU: 3
Loading training data...


Preloading images to GPU: 100%|██████████| 17336/17336 [21:45<00:00, 13.28it/s]


Loading validation data...


Preloading images to GPU: 100%|██████████| 3669/3669 [05:29<00:00, 11.13it/s]


Initial validation evaluation...


Validation: 100%|██████████| 115/115 [00:33<00:00,  3.48it/s]


Initial Validation Loss: 0.35707510428695494, Initial Validation Accuracy: 91.16925592804579%
Overall Precision: 0.9354509494445699, Overall Recall: 0.9116925592804579
Wild Boar Precision: 0.9791666666666666, Wild Boar Recall: 0.9359886201991465
Training started...


Training: 100%|██████████| 542/542 [04:36<00:00,  1.96it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 1, Train Loss: 0.22082815971912823, Validation Loss: 0.29100183038769856, Validation Accuracy: 91.95966203325156%
Overall Precision: 0.931418682861641, Overall Recall: 0.9195966203325157
Wild Boar Precision: 0.9936406995230525, Wild Boar Recall: 0.8890469416785206
Saving best epoch...
Model saved to ../models/2024-05-28-10-05-53-deepfaune-finetuned-epochs-5-lr-1e-05-wbpenalty-0.0.pt


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 2, Train Loss: 0.0773325218484871, Validation Loss: 0.26516345080801884, Validation Accuracy: 93.43145271191061%
Overall Precision: 0.9373844818227234, Overall Recall: 0.934314527119106
Wild Boar Precision: 0.9579242636746143, Wild Boar Recall: 0.9715504978662873
Saving best epoch...
Model saved to ../models/2024-05-28-10-05-53-deepfaune-finetuned-epochs-5-lr-1e-05-wbpenalty-0.0.pt


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 3, Train Loss: 0.051064202051637246, Validation Loss: 0.22646075499145335, Validation Accuracy: 94.73971109294085%
Overall Precision: 0.9481746201757243, Overall Recall: 0.9473971109294086
Wild Boar Precision: 0.9481582537517054, Wild Boar Recall: 0.9886201991465149
Saving best epoch...
Model saved to ../models/2024-05-28-10-05-53-deepfaune-finetuned-epochs-5-lr-1e-05-wbpenalty-0.0.pt


Training: 100%|██████████| 542/542 [04:32<00:00,  1.99it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 4, Train Loss: 0.08146484264907619, Validation Loss: 0.3137011447921395, Validation Accuracy: 91.93240665031344%
Overall Precision: 0.9229899634406458, Overall Recall: 0.9193240665031344
Wild Boar Precision: 0.9495677233429395, Wild Boar Recall: 0.9374110953058321


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Epoch 5, Train Loss: 0.03130040349739705, Validation Loss: 0.23681606195019886, Validation Accuracy: 94.76696647587899%
Overall Precision: 0.9497424150148942, Overall Recall: 0.9476696647587899
Wild Boar Precision: 0.947945205479452, Wild Boar Recall: 0.984352773826458
Saving model's current state...
Model saved to ../models/2024-05-28-10-30-54-deepfaune-finetuned-epochs-5-lr-1e-05-wbpenalty-0.0.pt


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 6, Train Loss: 0.19262274690351502, Validation Loss: 0.33920544223823845, Validation Accuracy: 93.75851730716816%
Overall Precision: 0.9398390046212476, Overall Recall: 0.9375851730716817
Wild Boar Precision: 0.9216467463479415, Wild Boar Recall: 0.9871977240398293


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Epoch 7, Train Loss: 0.11287002908341918, Validation Loss: 0.36552460282615834, Validation Accuracy: 93.0771327337149%
Overall Precision: 0.9325494353481069, Overall Recall: 0.9307713273371491
Wild Boar Precision: 0.8965071151358344, Wild Boar Recall: 0.9857752489331437


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 8, Train Loss: 0.6450039715760876, Validation Loss: 0.744901849621016, Validation Accuracy: 83.892068683565%
Overall Precision: 0.8886792148054047, Overall Recall: 0.83892068683565
Wild Boar Precision: 0.5937234944868532, Wild Boar Recall: 0.9957325746799431


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 9, Train Loss: 0.3850227367513937, Validation Loss: 0.39250519302238346, Validation Accuracy: 92.99536658490052%
Overall Precision: 0.931363648883839, Overall Recall: 0.9299536658490052
Wild Boar Precision: 0.910761154855643, Wild Boar Recall: 0.9871977240398293


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 10, Train Loss: 0.3985151276782815, Validation Loss: 0.4294517515753598, Validation Accuracy: 90.56963750340692%
Overall Precision: 0.9173725255967452, Overall Recall: 0.9056963750340692
Wild Boar Precision: 0.8158508158508159, Wild Boar Recall: 0.9957325746799431
Saving model's current state...
Model saved to ../models/2024-05-28-10-59-31-deepfaune-finetuned-epochs-10-lr-1e-05-wbpenalty-10.0.pt


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.74it/s]


Epoch 11, Train Loss: 0.05076333156727148, Validation Loss: 0.2504712592354577, Validation Accuracy: 94.03107113654947%
Overall Precision: 0.9410591609837562, Overall Recall: 0.9403107113654947
Wild Boar Precision: 0.9732016925246827, Wild Boar Recall: 0.9815078236130867


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 12, Train Loss: 0.030898158538091733, Validation Loss: 0.33231496744742994, Validation Accuracy: 92.5320250749523%
Overall Precision: 0.9316510794116908, Overall Recall: 0.9253202507495231
Wild Boar Precision: 0.9398907103825137, Wild Boar Recall: 0.9786628733997155


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 13, Train Loss: 0.020518672279441687, Validation Loss: 0.2627720144185253, Validation Accuracy: 94.0583265194876%
Overall Precision: 0.9417679705185843, Overall Recall: 0.940583265194876
Wild Boar Precision: 0.9529085872576177, Wild Boar Recall: 0.9786628733997155


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 14, Train Loss: 0.021918598952709298, Validation Loss: 0.3124650197434399, Validation Accuracy: 93.32243118015808%
Overall Precision: 0.9366314370535437, Overall Recall: 0.9332243118015808
Wild Boar Precision: 0.9726224783861671, Wild Boar Recall: 0.9601706970128022


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 15, Train Loss: 0.02305713367909498, Validation Loss: 0.281918503213503, Validation Accuracy: 94.27636958299264%
Overall Precision: 0.9443447118513216, Overall Recall: 0.9427636958299264
Wild Boar Precision: 0.9679218967921897, Wild Boar Recall: 0.9871977240398293
Saving model's current state...
Model saved to ../models/2024-05-28-11-48-53-deepfaune-finetuned-epochs-15-lr-1e-05-wbpenalty-0.0.pt
Loading test data...


Preloading images to GPU: 100%|██████████| 3557/3557 [03:38<00:00, 16.28it/s]


Testing the model...


Testing: 100%|██████████| 112/112 [00:23<00:00,  4.74it/s]

Test Accuracy: 93.89935338768625%





## Further testing

In [11]:
# Reload a previously fine-tuned checkpoint, re-validate it, and optionally keep training it.
# Relies on globals left by the main() run above: device, train_loader, val_loader, criterion, total_epochs.
new_weight_path = "../models/Boar Balanced PrecisionRecall - 96.8-98.7-deepfaune-finetuned-epochs-15-lr-1e-05-wbpenalty-0for5,10for5,0for5.pt"

if 'new_model' in locals():
    del new_model  # release the previous instance before building another

new_model = Classifier(freeze_up_to_layer=16).to(device)  # Freeze up to the 16th layer
# Classifier.__init__ only loads the base DeepFaune weights (it expects a 'state_dict' entry).
# A fine-tuned checkpoint written by save_model() is a plain Classifier.state_dict(), so load it here.
new_model.load_state_dict(torch.load(new_weight_path, map_location=device))

print('Initial validation evaluation...')
val_loss, val_accuracy, val_precision, val_recall, wb_precision, wb_recall = validate(new_model, val_loader, criterion, device)
print(f'Initial Validation Loss: {val_loss}, Initial Validation Accuracy: {val_accuracy}%')
print(f'Overall Precision: {val_precision}, Overall Recall: {val_recall}')
print(f'Wild Boar Precision: {wb_precision}, Wild Boar Recall: {wb_recall}')

# best_val_loss is local to main(), so define it here: only improve on the loaded checkpoint.
best_val_loss = val_loss

# NOTE(review): total_epochs continues from main()'s value; set it to the loaded checkpoint's
# epoch count if they differ, so saved file names stay accurate.
while True:
    more_epochs = int(input("Enter the number of additional epochs to continue training (0 to stop): "))
    if more_epochs <= 0:
        break
    while True:
        lr_text = input("Enter the learning rate for the additional epochs (default 1e-5): ").strip()
        learning_rate = float(lr_text) if lr_text else 1e-5  # honour the advertised default
        if 0 < learning_rate <= 1e-4:
            break
        print("Learning rate too high" if learning_rate > 1e-4 else "Learning rate must be a positive number")
    penalty_text = input("Enter the penalty weight for wild boar class (default 0): ").strip()
    penalty_weight = float(penalty_text) if penalty_text else 0.0
    optimizer = optim.Adam(new_model.parameters(), lr=learning_rate)  # Update optimizer with new learning rate
    criterion = CustomLoss(penalty_weight=penalty_weight)  # Update criterion with new penalty weight
    total_epochs += more_epochs
    now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    for epoch in range(more_epochs):
        train_loss = train(new_model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy, val_precision, val_recall, wb_precision, wb_recall = validate(new_model, val_loader, criterion, device)
        print(f'Epoch {total_epochs - more_epochs + epoch + 1}, Train Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%')
        print(f'Overall Precision: {val_precision}, Overall Recall: {val_recall}')
        print(f'Wild Boar Precision: {wb_precision}, Wild Boar Recall: {wb_recall}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("Saving best epoch...")
            save_model(new_model, total_epochs, learning_rate, now, penalty_weight)

    # Also keep the final weights when the last epoch was not the best one.
    if val_loss != best_val_loss:
        now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        print("Saving model's current state...")
        save_model(new_model, total_epochs, learning_rate, now, penalty_weight)


KeyError: 'state_dict'

: 

## More thoughts
- Change the code to also save the model with the best wild-boar performance
- Still consider freezing a different number of layers

## Thoughts
- Decreasing the learning rate below 1e-6 doesn't help
- Next step is to look into the number of frozen layers
- Perhaps unfreeze all layers, then gradually increase the number of frozen layers? Look into best practice
- Otherwise augment the dataset - but is that the issue? What tests are there for this?
- Before augmenting the dataset, look at the loss function for wild boar instead - likely better results - i.e. fine-tune for wild boar.