In [1]:
import sys
import numpy as np
import timm
import torch
from torch import tensor
import torch.nn as nn
from torchvision.transforms import InterpolationMode, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
import os
from tqdm import tqdm
import getpass
import socket
from datetime import datetime

# Choose the PyTorch device: CUDA when a GPU is visible, otherwise the CPU.
if not torch.cuda.is_available():
    dev = "cpu"
    device = torch.device(dev)
    print("No GPU available.")

else:
    dev = "cuda"
    device = torch.device(dev)

    gpu_name = torch.cuda.get_device_name(torch.device("cuda"))
    print(f"GPU name: {gpu_name} ({torch.cuda.device_count()} available)")

    # Report where the notebook is running; the host name decides the GPU policy below.
    print("Host name: ", socket.gethostname())
    print("User name: ", getpass.getuser())

    if socket.gethostname() != "gpuhost001.jc.rl.ac.uk":
        # Any other machine: keep the current CUDA device and report the
        # second value returned by mem_get_info(), converted to GiB.
        _, max_memory = torch.cuda.mem_get_info()
        max_memory = max_memory / (1024 ** 3)
        print(f"GPU memory: {max_memory} GiB")

    else:
        # Shared JASMIN GPU host: pick the GPU that currently has the most free memory.
        def select_gpu_with_most_free_memory():
            """Print the free memory of every visible GPU and return the index of the freest one."""
            best_index = 0
            best_free_gib = 0
            for index in range(torch.cuda.device_count()):
                torch.cuda.set_device(index)
                free_bytes, _total_bytes = torch.cuda.mem_get_info(index)
                free_gib = free_bytes / (1024 ** 3)
                print(f"GPU {index} free memory: {round(free_gib, 2)} GiB")
                # ">=" lets a later GPU win ties, biasing away from GPU 0,
                # which most JASMIN users default to.
                if free_gib >= best_free_gib:
                    best_free_gib, best_index = free_gib, index
            return best_index

        best_gpu = select_gpu_with_most_free_memory()

        torch.cuda.set_device(best_gpu)
        print(f"Using GPU: {best_gpu}")

# Manual escape hatch: flip to True to force a specific GPU regardless of the selection above.
gpu_override = False
if gpu_override:
    torch.cuda.set_device(3)
    print(f"OVERRIDE: Using GPU: {3}")

# Side length (pixels) that images are resized to, the timm backbone name,
# and the location of the pretrained DeepFaune checkpoint.
CROP_SIZE = 182
BACKBONE = "vit_large_patch14_dinov2"
weight_path = "../models/deepfaune-vit_large_patch14_dinov2.lvd142m.pt"

# True selects the repository-relative data directory; False selects the
# external-drive location.
jasmin = True

_split_root = (
    "../data/split_data"
    if jasmin
    else "/media/tom-ratsakatika/CRUCIAL 4TB/FCC Camera Trap Data/split_data"
)
train_path = f"{_split_root}/train"
val_path = f"{_split_root}/val"
test_path = f"{_split_root}/test"

# Class names. A name's position in this list is its integer class index,
# so the order must not be changed.
ANIMAL_CLASSES = [
    "badger", "ibex", "red deer", "chamois", "cat", "goat", "roe deer",
    "dog", "squirrel", "equid", "genet", "hedgehog", "lagomorph", "wolf",
    "lynx", "marmot", "micromammal", "mouflon", "sheep", "mustelid", "bird",
    "bear", "nutria", "fox", "wild boar", "cow",
]

class AnimalDataset(Dataset):
    """Image-classification dataset read from one sub-directory per class.

    ``directory`` is expected to contain one folder per class, named exactly
    like an entry of ``ANIMAL_CLASSES``; every regular file inside such a
    folder becomes one sample labelled with that class's index.

    Args:
        directory: Root folder holding the per-class sub-directories.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        preload_to_gpu: When True, every image is loaded and transformed once
            in the constructor and kept on the module-level ``device``;
            ``__getitem__`` then returns the cached tensors. This requires a
            ``transform`` that returns a tensor.

    Raises:
        ValueError: If a non-empty class folder is not listed in
            ``ANIMAL_CLASSES``, or if preloading is requested without a
            ``transform``.
    """

    def __init__(self, directory, transform=None, preload_to_gpu=False):
        self.directory = directory
        self.transform = transform
        self.images = []
        self.labels = []
        self.preload_to_gpu = preload_to_gpu

        # One O(1) lookup table instead of a list.index() scan per image.
        class_to_index = {name: idx for idx, name in enumerate(ANIMAL_CLASSES)}

        # sorted() makes the sample order reproducible; os.listdir() order is arbitrary.
        for label in sorted(os.listdir(directory)):
            label_dir = os.path.join(directory, label)
            # Skip stray files and hidden folders such as ".ipynb_checkpoints".
            if label.startswith('.') or not os.path.isdir(label_dir):
                continue

            image_paths = [
                os.path.join(label_dir, image)
                for image in sorted(os.listdir(label_dir))
                if not image.startswith('.')
            ]
            image_paths = [path for path in image_paths if os.path.isfile(path)]
            if not image_paths:
                continue

            if label not in class_to_index:
                raise ValueError(
                    f"Directory {label_dir!r} does not match any entry of ANIMAL_CLASSES"
                )

            self.images.extend(image_paths)
            self.labels.extend([class_to_index[label]] * len(image_paths))

        if self.preload_to_gpu:
            self.preload_images()

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(image_path) as handle:
            return handle.convert('RGB')

    def preload_images(self):
        """Load, transform and move every image (and the labels) onto ``device``."""
        if self.images and self.transform is None:
            # A raw PIL image has no .to(); fail early with a clear message.
            raise ValueError(
                "preload_to_gpu=True requires a transform that converts images to tensors"
            )

        self.loaded_images = []
        for image_path in tqdm(self.images, desc="Preloading images to GPU"):
            image = self._load_rgb(image_path)
            if self.transform:
                image = self.transform(image)
            self.loaded_images.append(image.to(device))
        # After preloading, labels are a single tensor living on the same device.
        self.labels = torch.tensor(self.labels, device=device)

    def __len__(self):
        """Return the number of samples found on disk."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Cached device tensors are returned when the dataset was preloaded;
        otherwise the image is read from disk and transformed on demand.
        """
        if self.preload_to_gpu:
            return self.loaded_images[idx], self.labels[idx]

        image = self._load_rgb(self.images[idx])
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

class Classifier(nn.Module):
    """Wrapper around the timm model named by ``BACKBONE`` with DeepFaune weights.

    The constructor builds the backbone with ``len(ANIMAL_CLASSES)`` outputs,
    loads the checkpoint at ``weight_path`` and optionally freezes the first
    transformer blocks so that only the later blocks are fine-tuned.

    Args:
        freeze_up_to_layer: Highest block index to freeze (inclusive), i.e.
            blocks ``0..freeze_up_to_layer`` stop receiving gradients.
            ``None`` leaves every parameter trainable. Parameters that do not
            belong to a ``blocks.<n>`` module are never frozen.
    """

    def __init__(self, freeze_up_to_layer=16):
        super().__init__()
        self.model = timm.create_model(BACKBONE, pretrained=False, num_classes=len(ANIMAL_CLASSES), dynamic_img_size=True)
        # Load onto the CPU: the parameters are copied into the freshly built
        # (CPU) model anyway, and the caller moves the module to its device.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(weight_path, map_location=torch.device('cpu'))['state_dict']
        # Checkpoint keys carry a 'base_model.' prefix that the bare timm model lacks.
        self.model.load_state_dict({k.replace('base_model.', ''): v for k, v in state_dict.items()})

        # Freeze layers up to the specified layer
        if freeze_up_to_layer is not None:
            for name, param in self.model.named_parameters():
                if self._should_freeze_layer(name, freeze_up_to_layer):
                    param.requires_grad = False

        # Preprocessing applied by predict() to a single PIL image.
        self.transforms = transforms.Compose([
            transforms.Resize(size=(CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=None),
            transforms.ToTensor(),
            transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
        ])

    def _should_freeze_layer(self, name, freeze_up_to_layer):
        """Return True when ``name`` belongs to a block with index <= ``freeze_up_to_layer``.

        ``name`` is a dotted parameter name such as ``blocks.3.attn.qkv.weight``;
        the component following ``blocks`` is taken as the block index.
        """
        parts = name.split('.')
        if 'blocks' not in parts:
            return False
        index_pos = parts.index('blocks') + 1
        if index_pos >= len(parts) or not parts[index_pos].isdigit():
            return False
        return int(parts[index_pos]) <= freeze_up_to_layer

    def forward(self, x):
        """Return the backbone's class logits for the image batch ``x``."""
        return self.model(x)

    def predict(self, image):
        """Classify one PIL image.

        Returns:
            ``(class_name, probability)`` for the most likely class.
        """
        # The input must live on the same device as the weights, otherwise
        # the forward pass fails once the module has been moved to the GPU.
        target_device = next(self.parameters()).device
        img_tensor = self.transforms(image).unsqueeze(0).to(target_device)

        # Predict in eval mode, then restore whatever mode the caller had set.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(img_tensor)
                probabilities = torch.nn.functional.softmax(output, dim=1)
                top_p, top_class = probabilities.topk(1, dim=1)
        finally:
            self.train(was_training)
        return ANIMAL_CLASSES[top_class.item()], top_p.item()

def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns the sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    loss_sum = 0.0
    for batch_images, batch_labels in tqdm(dataloader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
    return loss_sum / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Returns ``(mean_batch_loss, accuracy_percent)``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Validation"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()
            # The predicted class is the index of the largest logit in each row.
            predicted = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    accuracy = 100 * n_correct / n_seen
    return loss_sum / len(dataloader), accuracy

def test(model, dataloader, device):
    """Return the accuracy (in percent) of ``model`` over ``dataloader``."""
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Testing"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # The predicted class is the index of the largest logit in each row.
            predicted = model(batch_images).max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return 100 * n_correct / n_seen

def save_model(model, total_epochs, learning_rate):
    """Write ``model.state_dict()`` to a timestamped file under ``../models``.

    Args:
        model: Module whose state dict is saved.
        total_epochs: Epoch count embedded in the file name.
        learning_rate: Learning rate embedded in the file name.

    The target directory is created when missing, so a checkpoint is not
    lost because the folder does not exist yet.
    """
    now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    model_save_path = f"../models/{now}-deepfaune-finetuned-epochs{total_epochs}-lr{learning_rate}.pt"
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    torch.save(model.state_dict(), model_save_path)
    print(f'Model saved to {model_save_path}')

def _ask_additional_epochs():
    """Prompt until the user enters a non-negative whole number of epochs."""
    prompt = "Enter the number of additional epochs to continue training (0 to stop): "
    while True:
        try:
            answer = input(prompt)
        except EOFError:
            # No more input available: treat it as "stop" so the trained
            # model still gets tested and returned.
            return 0
        try:
            requested = int(answer)
        except ValueError:
            # A typo must not throw away hours of training held in memory.
            print(f"Invalid input {answer!r}; please enter a whole number.")
            continue
        if requested < 0:
            print("The number of epochs cannot be negative.")
            continue
        return requested


def main():
    """Fine-tune the classifier, optionally continue interactively, then test it.

    Steps: load the train/validation sets, build the model, report the
    validation metrics before training, train for ``initial_epochs`` with
    early stopping, repeatedly offer further epochs, and finally evaluate on
    the test set.

    A checkpoint is written whenever the validation loss improves. The epoch
    count in its file name is the number of epochs completed at that moment;
    the learning rate in the name is the initial one, even if the scheduler
    has since lowered it.

    Returns:
        ``(model, train_loader, val_loader, test_loader, criterion, optimizer)``
        for further experimentation.
    """
    initial_epochs = 5  # Set the number of epochs
    batch_size = 32  # Set the batch size
    learning_rate = 1e-5  # Reduced learning rate for fine-tuning
    patience = 10  # Early stopping patience
    best_val_loss = float('inf')
    patience_counter = 0
    epochs_completed = 0  # Epochs actually trained so far (used to label checkpoints)

    transform = transforms.Compose([
        transforms.Resize((CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
    ])

    print('Loading training data...')
    train_dataset = AnimalDataset(train_path, transform=transform, preload_to_gpu=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    print('Loading validation data...')
    val_dataset = AnimalDataset(val_path, transform=transform, preload_to_gpu=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = Classifier(freeze_up_to_layer=16).to(device)  # Freeze blocks 0..16 (inclusive)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
    # releases - confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    def run_epochs(num_epochs, label):
        """Train for up to ``num_epochs`` epochs, checkpointing on improvement."""
        nonlocal best_val_loss, patience_counter, epochs_completed
        for epoch in range(num_epochs):
            train_loss = train(model, train_loader, criterion, optimizer, device)
            val_loss, val_accuracy = validate(model, val_loader, criterion, device)
            epochs_completed += 1
            print(f'{label} {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%')

            # Update the learning rate based on validation loss
            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                save_model(model, epochs_completed, learning_rate)
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered")
                break

    # Evaluate validation set before training
    print('Initial validation evaluation...')
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)
    print(f'Initial Validation Loss: {val_loss}, Initial Validation Accuracy: {val_accuracy}%')

    print('Training started...')
    run_epochs(initial_epochs, 'Epoch')

    # Option to continue training
    while True:
        more_epochs = _ask_additional_epochs()
        if more_epochs == 0:
            break
        # An explicit request to continue restarts the patience window;
        # otherwise a counter left over from an earlier early stop would end
        # the new run after its first non-improving epoch.
        patience_counter = 0
        run_epochs(more_epochs, 'Additional Epoch')

    # Load test data
    print('Loading test data...')
    test_dataset = AnimalDataset(test_path, transform=transform, preload_to_gpu=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Test the model
    print('Testing the model...')
    test_accuracy = test(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy}%')

    # Return critical variables for further experimentation
    return model, train_loader, val_loader, test_loader, criterion, optimizer

if __name__ == '__main__':
    model, train_loader, val_loader, test_loader, criterion, optimizer = main()


GPU name: NVIDIA A100-SXM4-40GB (4 available)
Host name:  gpuhost001.jc.rl.ac.uk
User name:  trr26
GPU 0 free memory: 31.97 GiB
GPU 1 free memory: 38.56 GiB
GPU 2 free memory: 38.56 GiB
GPU 3 free memory: 38.56 GiB
Using GPU: 3
Loading training data...


Preloading images to GPU: 100%|██████████| 17336/17336 [36:26<00:00,  7.93it/s]  


Loading validation data...


Preloading images to GPU: 100%|██████████| 3669/3669 [11:32<00:00,  5.30it/s]


Initial validation evaluation...


Validation: 100%|██████████| 115/115 [00:24<00:00,  4.66it/s]


Initial Validation Loss: 0.35707510428695494, Initial Validation Accuracy: 91.16925592804579%
Training started...


Training: 100%|██████████| 542/542 [04:33<00:00,  1.98it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Epoch 1, Train Loss: 0.22024896902090618, Validation Loss: 0.25267313871097385, Validation Accuracy: 92.12319433088035%


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.37it/s]


Epoch 2, Train Loss: 0.07666246296074998, Validation Loss: 0.2770420901980327, Validation Accuracy: 92.5320250749523%


Training: 100%|██████████| 542/542 [04:43<00:00,  1.92it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 3, Train Loss: 0.053731832943463224, Validation Loss: 0.30345335408536533, Validation Accuracy: 92.259471245571%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Epoch 4, Train Loss: 0.04143187188444791, Validation Loss: 0.2832480655672755, Validation Accuracy: 93.7857726901063%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Epoch 5, Train Loss: 0.02588538914914926, Validation Loss: 0.2582714545701848, Validation Accuracy: 94.19460343417825%
Model saved to ../models/2024-05-25-14-57-46-deepfaune-finetuned-epochs5-lr1e-05.pt


Training: 100%|██████████| 542/542 [05:41<00:00,  1.59it/s]
Validation: 100%|██████████| 115/115 [00:30<00:00,  3.81it/s]


Additional Epoch 1, Train Loss: 0.04615950358545492, Validation Loss: 0.3169180580244005, Validation Accuracy: 92.83183428727173%


Training: 100%|██████████| 542/542 [05:27<00:00,  1.65it/s]
Validation: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Additional Epoch 2, Train Loss: 0.023872527125842674, Validation Loss: 0.2771166227165573, Validation Accuracy: 93.54047424366313%


Training: 100%|██████████| 542/542 [05:11<00:00,  1.74it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.32it/s]


Additional Epoch 3, Train Loss: 0.014429134706919004, Validation Loss: 0.2460505320283435, Validation Accuracy: 94.84873262469338%


Training: 100%|██████████| 542/542 [05:18<00:00,  1.70it/s]
Validation: 100%|██████████| 115/115 [00:28<00:00,  3.98it/s]


Additional Epoch 4, Train Loss: 0.07224217320746627, Validation Loss: 0.3013495525927283, Validation Accuracy: 93.21340964840556%


Training: 100%|██████████| 542/542 [05:10<00:00,  1.74it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.39it/s]


Additional Epoch 5, Train Loss: 0.017780617791924663, Validation Loss: 0.2904463102285777, Validation Accuracy: 94.00381575361133%


Training: 100%|██████████| 542/542 [04:52<00:00,  1.85it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.41it/s]


Additional Epoch 6, Train Loss: 0.01167543907671965, Validation Loss: 0.29568716043812976, Validation Accuracy: 94.2491142000545%


Training: 100%|██████████| 542/542 [05:02<00:00,  1.79it/s]
Validation: 100%|██████████| 115/115 [00:28<00:00,  3.99it/s]


Additional Epoch 7, Train Loss: 0.029326659573286475, Validation Loss: 0.333870319723697, Validation Accuracy: 93.10438811665304%


Training: 100%|██████████| 542/542 [05:21<00:00,  1.68it/s]
Validation: 100%|██████████| 115/115 [00:29<00:00,  3.95it/s]


Additional Epoch 8, Train Loss: 0.05275589163387533, Validation Loss: 0.34272687848302746, Validation Accuracy: 92.47751430907604%


Training: 100%|██████████| 542/542 [05:21<00:00,  1.68it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.29it/s]


Additional Epoch 9, Train Loss: 0.014514613297154018, Validation Loss: 0.3318232206764072, Validation Accuracy: 93.40419732897247%


Training: 100%|██████████| 542/542 [04:31<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Additional Epoch 10, Train Loss: 0.02071174676099144, Validation Loss: 0.25540221354754267, Validation Accuracy: 94.38539111474516%


Training: 100%|██████████| 542/542 [04:31<00:00,  1.99it/s]
Validation: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Additional Epoch 11, Train Loss: 0.008840903803002387, Validation Loss: 0.2679339782932758, Validation Accuracy: 94.57617879531207%


Training: 100%|██████████| 542/542 [06:04<00:00,  1.49it/s]
Validation: 100%|██████████| 115/115 [00:29<00:00,  3.87it/s]


Additional Epoch 12, Train Loss: 0.015736363778317747, Validation Loss: 0.3600545950599474, Validation Accuracy: 91.9051512673753%


Training: 100%|██████████| 542/542 [04:56<00:00,  1.83it/s]
Validation: 100%|██████████| 115/115 [00:26<00:00,  4.26it/s]


Additional Epoch 13, Train Loss: 0.0239962086149727, Validation Loss: 0.40338166666207026, Validation Accuracy: 91.76887435268466%


Training: 100%|██████████| 542/542 [04:47<00:00,  1.89it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 14, Train Loss: 0.023773322177895583, Validation Loss: 0.28857912298272703, Validation Accuracy: 94.4126464976833%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 15, Train Loss: 0.007732784138767626, Validation Loss: 0.2927628387937543, Validation Accuracy: 94.4126464976833%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 16, Train Loss: 0.04051261291414116, Validation Loss: 0.46329555944719925, Validation Accuracy: 89.72472063232489%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 17, Train Loss: 0.010644173093068547, Validation Loss: 0.27159691222971205, Validation Accuracy: 94.54892341237394%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 18, Train Loss: 0.005830461322671586, Validation Loss: 0.31345982527288035, Validation Accuracy: 94.35813573180704%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 19, Train Loss: 0.017597864305531147, Validation Loss: 0.3351692418559295, Validation Accuracy: 93.75851730716816%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 20, Train Loss: 0.015025084544745989, Validation Loss: 0.4068621974920184, Validation Accuracy: 91.85064050149904%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 21, Train Loss: 0.010660814999920649, Validation Loss: 0.34861497379231127, Validation Accuracy: 94.08558190242573%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 22, Train Loss: 0.008634482588686666, Validation Loss: 0.38019896332335795, Validation Accuracy: 92.31398201144727%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 23, Train Loss: 0.025416586718128572, Validation Loss: 0.3532532045321521, Validation Accuracy: 93.21340964840556%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.74it/s]


Additional Epoch 24, Train Loss: 0.0061646138516099165, Validation Loss: 0.3217790649293499, Validation Accuracy: 94.16734805124013%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 25, Train Loss: 0.020124297714177897, Validation Loss: 0.30434746934665585, Validation Accuracy: 93.86753883892068%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.74it/s]


Additional Epoch 26, Train Loss: 0.005272617719545055, Validation Loss: 0.33750800008298504, Validation Accuracy: 94.00381575361133%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 27, Train Loss: 0.0042250561296632745, Validation Loss: 0.34130839235650534, Validation Accuracy: 93.84028345598256%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 28, Train Loss: 0.004459403390391314, Validation Loss: 0.34314549614945683, Validation Accuracy: 93.81302807304442%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.74it/s]


Additional Epoch 29, Train Loss: 0.00419014133395008, Validation Loss: 0.3429097923608701, Validation Accuracy: 94.03107113654947%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Additional Epoch 30, Train Loss: 0.004581312163062014, Validation Loss: 0.34553859252380237, Validation Accuracy: 94.0583265194876%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 31, Train Loss: 0.004750180773645697, Validation Loss: 0.33408303040815635, Validation Accuracy: 94.49441264649768%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 32, Train Loss: 0.004364845083445812, Validation Loss: 0.3279630865796603, Validation Accuracy: 94.19460343417825%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Additional Epoch 33, Train Loss: 0.052622352593341616, Validation Loss: 0.36640974492187245, Validation Accuracy: 92.94085581902425%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Additional Epoch 34, Train Loss: 0.013897899404628166, Validation Loss: 0.3726689679841713, Validation Accuracy: 92.80457890433361%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 35, Train Loss: 0.015035595828142527, Validation Loss: 0.3276033321850738, Validation Accuracy: 93.1588988825293%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.71it/s]


Additional Epoch 36, Train Loss: 0.01110592156583555, Validation Loss: 0.41784513546747054, Validation Accuracy: 92.09593894794222%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 37, Train Loss: 0.011971177908658672, Validation Loss: 0.34758663390661954, Validation Accuracy: 93.34968656309621%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.71it/s]


Additional Epoch 38, Train Loss: 0.004212810537413761, Validation Loss: 0.36092654784477457, Validation Accuracy: 93.64949577541564%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 39, Train Loss: 0.020475973197602214, Validation Loss: 0.3305898218898293, Validation Accuracy: 93.64949577541564%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.71it/s]


Additional Epoch 40, Train Loss: 0.007571281973714845, Validation Loss: 0.37186514586110797, Validation Accuracy: 93.62224039247751%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 41, Train Loss: 0.00824257102365704, Validation Loss: 0.28628237961881553, Validation Accuracy: 94.52166802943582%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.72it/s]


Additional Epoch 42, Train Loss: 0.013496069107807292, Validation Loss: 0.3614103731447378, Validation Accuracy: 93.34968656309621%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.71it/s]


Additional Epoch 43, Train Loss: 0.01079075410964538, Validation Loss: 0.42578056366349737, Validation Accuracy: 92.42300354319978%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.74it/s]


Additional Epoch 44, Train Loss: 0.006333805954974226, Validation Loss: 0.3312187950769533, Validation Accuracy: 94.14009266830199%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 45, Train Loss: 0.0041809465111500715, Validation Loss: 0.31436469634934383, Validation Accuracy: 94.52166802943582%
Model saved to ../models/2024-05-25-18-53-26-deepfaune-finetuned-epochs50-lr1e-05.pt
Loading test data...


Preloading images to GPU: 100%|██████████| 3557/3557 [03:33<00:00, 16.64it/s]


Testing the model...


Testing: 100%|██████████| 112/112 [00:23<00:00,  4.71it/s]

Test Accuracy: 94.20860275513073%





In [4]:
# Continue fine-tuning the model returned by main() with a much smaller learning rate.
total_epochs = 55  # Epoch count used to label the next checkpoint
learning_rate = 1e-7

while True:
    raw_epochs = input("Enter the number of additional epochs to continue training (0 to stop): ")
    try:
        more_epochs = int(raw_epochs)
    except ValueError:
        # A typo should re-prompt rather than abort the cell.
        print(f"Invalid input {raw_epochs!r}; please enter a whole number.")
        continue
    if more_epochs < 0:
        print("The number of epochs cannot be negative.")
        continue
    if more_epochs == 0:
        # Stop before touching the epoch counter or rebuilding the optimizer.
        break

    total_epochs += more_epochs
    # A fresh optimizer is created for every round, so Adam's running
    # statistics restart and the new learning rate takes effect.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(more_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate(model, val_loader, criterion, device)
        print(f'Additional Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%')
    # One checkpoint per round, saved unconditionally after the last epoch.
    save_model(model, total_epochs, learning_rate)


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 1, Train Loss: 0.002243481088951251, Validation Loss: 0.32722881228284145, Validation Accuracy: 94.82147724175525%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 2, Train Loss: 0.0022394274977986365, Validation Loss: 0.32746408745008665, Validation Accuracy: 94.82147724175525%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 3, Train Loss: 0.002229393843951057, Validation Loss: 0.32786337191390263, Validation Accuracy: 94.82147724175525%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 4, Train Loss: 0.0022240874828223216, Validation Loss: 0.3283195710445808, Validation Accuracy: 94.8759880076315%


Training: 100%|██████████| 542/542 [04:30<00:00,  2.00it/s]
Validation: 100%|██████████| 115/115 [00:24<00:00,  4.73it/s]


Additional Epoch 5, Train Loss: 0.0022193364509803777, Validation Loss: 0.3285740491295947, Validation Accuracy: 94.8759880076315%
Model saved to ../models/2024-05-25-21-00-25-deepfaune-finetuned-epochs60-lr1e-07.pt


## Thoughts
- Decreasing the learning rate below 1e-6 doesn't help
- Next step is to look into number of layers frozen
- Perhaps unfreeze all, then slowly increase number of frozen layers? Look into best practice
- Otherwise augment dataset - but is that the issue? What tests are there for this?
- Before augmenting dataset, look at loss function for wild boar instead - likely better results - i.e. fine-tune for wild boar.