In [1]:
import sys
import numpy as np
import timm
import torch
from torch import tensor
import torch.nn as nn
from torchvision.transforms import InterpolationMode, transforms
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from PIL import Image
import os
from tqdm import tqdm

# Side length (in pixels) of the square that every image is resized to.
CROP_SIZE = 182
# Architecture name handed to timm.create_model().
BACKBONE = "vit_large_patch14_dinov2"
# Checkpoint file from which the classifier weights are restored.
weight_path = "../models/fine-tuned-deepfaune-vit_large_patch14_dinov2.lvd142m.pt"

# The three dataset splits live side by side under one root folder.
_SPLIT_ROOT = "/media/tom-ratsakatika/CRUCIAL 4TB/FCC Camera Trap Data/split_data"
train_path = f"{_SPLIT_ROOT}/train"
val_path = f"{_SPLIT_ROOT}/val"
test_path = f"{_SPLIT_ROOT}/test"

# Class names recognised by the model. The position of a name in this list is
# its integer label, so the order must not be changed.
ANIMAL_CLASSES = [
    "badger", "ibex", "red deer", "chamois", "cat", "goat", "roe deer",
    "dog", "squirrel", "equid", "genet", "hedgehog", "lagomorph", "wolf",
    "lynx", "marmot", "micromammal", "mouflon", "sheep", "mustelid", "bird",
    "bear", "nutria", "fox", "wild boar", "cow",
]

class AnimalDataset(Dataset):
    """Dataset of images stored in one sub-directory per class.

    Every sub-directory of ``directory`` is treated as a class; its name must
    be an entry of ``ANIMAL_CLASSES`` and the position of that entry is used
    as the integer label of all files inside it.

    Args:
        directory: Root folder containing the per-class sub-directories.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If a non-empty sub-directory is not named after an entry
            of ``ANIMAL_CLASSES``.
    """

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.images = []  # absolute/relative file paths, one per sample
        self.labels = []  # integer class index, parallel to self.images
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order reproducible between runs and machines.
        for label in sorted(os.listdir(directory)):
            label_dir = os.path.join(directory, label)
            if not os.path.isdir(label_dir):
                continue
            image_names = sorted(os.listdir(label_dir))
            if not image_names:
                continue
            # Resolve the class index once per folder instead of once per file.
            try:
                class_index = ANIMAL_CLASSES.index(label)
            except ValueError:
                raise ValueError(
                    f"Folder {label_dir!r} is not named after a known class; "
                    f"expected one of {ANIMAL_CLASSES}"
                ) from None
            for image in image_names:
                self.images.append(os.path.join(label_dir, image))
                self.labels.append(class_index)

    def __len__(self):
        """Return the number of samples found under ``directory``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and, if a transform was given, passed
        through it; the label is the integer class index.
        """
        image_path = self.images[idx]
        label = self.labels[idx]
        # Image.open() is lazy and keeps the file open; the context manager
        # guarantees the handle is released once the pixel data was copied.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

class Classifier(nn.Module):
    """timm backbone with a classification head over ``ANIMAL_CLASSES``.

    The network is created with ``timm.create_model(BACKBONE, ...)`` and its
    weights are restored from the checkpoint at ``weight_path``.

    Args:
        freeze_up_to_layer: Transformer blocks whose index is less than or
            equal to this value get ``requires_grad = False`` (so ``16``
            freezes blocks 0 through 16 inclusive). Parameters outside
            ``blocks`` are never frozen. ``None`` freezes nothing.
    """

    def __init__(self, freeze_up_to_layer=16):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = timm.create_model(BACKBONE, pretrained=False, num_classes=len(ANIMAL_CLASSES), dynamic_img_size=True)
        # NOTE(security): torch.load() unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        state_dict = torch.load(weight_path, map_location=torch.device(device))['state_dict']
        # Checkpoint keys carry a 'base_model.' prefix that the bare timm
        # model does not use, so it is stripped before loading.
        self.model.load_state_dict({k.replace('base_model.', ''): v for k, v in state_dict.items()})

        # Freeze layers up to the specified layer
        if freeze_up_to_layer is not None:
            for name, param in self.model.named_parameters():
                if self._should_freeze_layer(name, freeze_up_to_layer):
                    param.requires_grad = False

        # Preprocessing used by predict(): resize, tensor conversion, normalisation.
        self.transforms = transforms.Compose([
            transforms.Resize(size=(CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=None),
            transforms.ToTensor(),
            transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
        ])

    def _should_freeze_layer(self, name, freeze_up_to_layer):
        """Return True if parameter ``name`` belongs to a block to be frozen.

        A parameter qualifies when its dotted name contains a ``blocks``
        component directly followed by a numeric block index that is less
        than or equal to ``freeze_up_to_layer``. The ``blocks`` component may
        appear anywhere in the name, so prefixed names such as
        ``model.blocks.3.norm1.weight`` are handled as well.
        """
        parts = name.split('.')
        for position, part in enumerate(parts[:-1]):
            following = parts[position + 1]
            if part == 'blocks' and following.isdigit():
                return int(following) <= freeze_up_to_layer
        return False

    def forward(self, x):
        """Run the wrapped timm model on ``x`` and return its output."""
        return self.model(x)

    def predict(self, image):
        """Classify a single image.

        Args:
            image: Input accepted by ``self.transforms`` (a PIL image in the
                visible usage).

        Returns:
            Tuple ``(class_name, probability)`` for the most likely class.
        """
        # The module may have been moved with .to(device); the input has to
        # live on the same device as the weights.
        device = next(self.parameters()).device
        img_tensor = self.transforms(image).unsqueeze(0).to(device)
        # Inference must run in eval mode; restore the caller's mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(img_tensor)
                probabilities = torch.nn.functional.softmax(output, dim=1)
                top_p, top_class = probabilities.topk(1, dim=1)
        finally:
            self.train(was_training)
        return ANIMAL_CLASSES[top_class.item()], top_p.item()

def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` into training mode and, for every batch, moves the data to
    ``device``, computes ``criterion`` on the model output and takes one
    ``optimizer`` step.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []
    for batch_images, batch_labels in tqdm(dataloader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Gradients accumulate by default, so clear them before each batch.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``.

    Raises:
        ValueError: If the dataloader yields no samples, since neither the
            loss nor the accuracy is defined in that case.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            # Predicted class = index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("validate() received a dataloader that yielded no samples")
    accuracy = 100 * correct / total
    return running_loss / len(dataloader), accuracy

def test(model, dataloader, device):
    """Measure the classification accuracy of ``model`` on ``dataloader``.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a percentage in the range 0-100.

    Raises:
        ValueError: If the dataloader yields no samples, since the accuracy
            is undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test() received a dataloader that yielded no samples")
    accuracy = 100 * correct / total
    return accuracy

def main():
    """Fine-tune the classifier, then report validation and test metrics.

    Trains for a fixed number of epochs on the images under ``train_path``,
    printing the mean training loss after each epoch. Once training is done
    the validation split is evaluated (loss and accuracy) followed by the
    test split (accuracy). The validation and test datasets are only built
    when they are needed.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_epochs = 10  # Set the number of epochs
    batch_size = 4  # Set the batch size
    learning_rate = 1e-5  # Reduced learning rate for fine-tuning

    # Same preprocessing for all three splits: resize, tensor conversion, normalisation.
    transform = transforms.Compose([
        transforms.Resize((CROP_SIZE, CROP_SIZE), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
    ])

    print('Loading training data...')
    train_dataset = AnimalDataset(train_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = Classifier(freeze_up_to_layer=16).to(device)  # Freeze blocks 0-16 (inclusive)

    criterion = nn.CrossEntropyLoss()
    # Only hand the optimizer the parameters that are actually trained; the
    # frozen ones never receive gradients and need no optimizer state.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    print('Training started...')
    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer, device)
        print(f'Epoch {epoch+1}, Train Loss: {train_loss}')

    # Load validation data only when needed
    print('Calculating validation loss...')
    val_dataset = AnimalDataset(val_path, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)
    print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%')

    # Load test data only when needed
    print('Testing the model...')
    test_dataset = AnimalDataset(test_path, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_accuracy = test(model, test_loader, device)
    print(f'Test Accuracy: {test_accuracy}%')

# Start the training/evaluation run only when executed as a script (or
# notebook cell), not when this module is imported.
if __name__ == '__main__':
    main()


Loading training data...
Training started...


Training:  28%|██▊       | 1234/4334 [12:51<32:12,  1.60it/s]