In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import copy


import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as ds
from torchvision import models, transforms, utils, datasets
from torch.utils import data
from torchsummary import summary
from torch.optim import lr_scheduler

from PIL import ImageFile
# Allow PIL to load truncated image files instead of raising while decoding them,
# so a damaged image does not abort a DataLoader pass.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Data Read-In and Loading

In [2]:
# Species metadata tables.
fish_df = pd.read_csv("data/fish_df.csv")

# Drop the first column of is_fish.csv (presumably a saved index — verify) and add a
# "<Species>/<Filename>" relative path per row.
is_fish_df = (
    pd.read_csv("data/is_fish.csv")
    .iloc[:, 1:]
    .assign(local_paths=lambda frame: frame["Species"].astype(str) + "/" + frame["Filename"])
)
# Set form of those paths for O(1) membership checks.
path_set = set(is_fish_df["local_paths"])

In [3]:
# Previously recorded accuracies, if any; start with an empty table on a first run.
# Only "file missing" and "file empty" count as a fresh start — any other error
# (malformed CSV, permissions, Ctrl-C) should surface instead of being swallowed.
try:
    acc_df = pd.read_csv("data/accuracies.csv")
except (FileNotFoundError, pd.errors.EmptyDataError):
    acc_df = pd.DataFrame()

In [4]:
# Number of output classes for the full species classifier. Hard-coded; the earlier
# derivation was len(is_fish_df["Species"].unique()).
species_count = 286

In [5]:
# Normalize with the ImageNet channel mean/std, matching the ImageNet-pretrained ConvNeXt
# used below (https://stackoverflow.com/questions/58151507/why-pytorch-officially-use-mean-0-485-0-456-0-406-and-std-0-229-0-224-0-2)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Train: random horizontal flip + random resized crop as augmentation.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=(224, 224)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}
# Val/test: identical deterministic pipelines (no augmentation, fixed 224x224 resize).
data_transforms.update({
    split: transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        transforms.Resize([224, 224]),
    ])
    for split in ('val', 'test')
})

In [6]:
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
data_dirs = [
    "/home/shivaram/DS/Projects/FishID/data/expanded_model_data/is_fish_scientific/",
    "/home/shivaram/DS/Projects/FishID/data/expanded_model_data/is_fish_common/",
    "/home/shivaram/DS/Projects/FishID/data/expanded_model_data/is_fish_mixed/",
]

# One ImageFolder per (variant directory, split), keyed first by directory, then by split.
diff_image_datasets = {}
for data_dir in data_dirs:
    diff_image_datasets[data_dir] = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
        for split in ('train', 'val', 'test')
    }

In [7]:
# Everything below trains/evaluates on the is_fish_scientific variant.
selected_data_dir = "/home/shivaram/DS/Projects/FishID/data/expanded_model_data/is_fish_scientific/"
image_datasets = diff_image_datasets[selected_data_dir]

In [8]:
# Default DataLoader batch size (build_from_path overrides it per experiment).
batch_size = 64

In [9]:
def _inverse_frequency_sampler(targets):
    """Return a WeightedRandomSampler drawing len(targets) indices, each sample weighted
    by 1 / (number of samples with the same label), so classes are drawn about equally."""
    targets = np.asarray(targets, dtype=np.int64)
    # bincount is indexed by class id, so weights stay aligned with the labels even if
    # some class id has no samples (its weight is inf but is never looked up).
    with np.errstate(divide="ignore"):
        class_weights = 1.0 / np.bincount(targets)
    sample_weights = torch.from_numpy(class_weights[targets])
    return data.WeightedRandomSampler(sample_weights, len(sample_weights))


# Class-balanced sampling for train and val; test is read once, in order.
# NOTE(review): WeightedRandomSampler samples with replacement by default, so val metrics
# (and the early stopping / best-weight selection based on them) come from a random
# resample that varies between epochs.
weighted_samplers = {
    "train": _inverse_frequency_sampler(image_datasets["train"].targets),
    "val": _inverse_frequency_sampler(image_datasets["val"].targets),
    "test": data.SequentialSampler(image_datasets["test"]),
}

dataloaders = {x: data.DataLoader(image_datasets[x], sampler=weighted_samplers[x],
                                  batch_size=batch_size, num_workers=4)
               for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}
class_names = image_datasets['train'].classes

# Each split is a separate ImageFolder, so label ids only agree across splits when the
# class folder lists are identical.
for split in ('val', 'test'):
    assert image_datasets[split].classes == class_names, (
        f"{split} classes differ from train classes; label indices would not line up")

# Model Training and Evaluation Methods

In [10]:
def _run_epoch(model, loader, phase, criterion, optimizer, device):
    """Make one pass over `loader`; return (mean loss, accuracy) over the samples seen.

    Gradients are tracked and `optimizer` is stepped only when phase == 'train'; any other
    phase runs the model in eval mode without gradients.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is eval mode

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        n_batch = inputs.size(0)
        # Weight the batch loss by batch size before averaging over the epoch
        # (assumes criterion returns a per-batch mean, as CrossEntropyLoss does by default).
        total_loss += loss.item() * n_batch
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += n_batch

    if total_seen == 0:
        raise ValueError(f"'{phase}' dataloader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


def train_model(dataloaders, model, criterion, optimizer, scheduler, num_epochs=25, verbose = True, plateau = False, early_stop = 15, device = None):
    """Train `model` with a train/val loop and keep the weights with the best val accuracy.

    Args:
        dataloaders: dict with 'train' and 'val' iterables of (inputs, labels) batches.
        model: torch module to train (modified in place).
        criterion: loss, called as criterion(outputs, labels).
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler. Stepped once per epoch after the train phase, or, when
            `plateau` is True, after the val phase with the val loss.
        num_epochs: maximum number of epochs.
        verbose: print per-epoch metrics and a final summary.
        plateau: see `scheduler`.
        early_stop: stop after this many consecutive epochs without a new best val accuracy.
        device: where batches are moved; defaults to the device of the model's first
            parameter (CPU if the model has none).

    Returns:
        (model, best_acc, epoch_accs): the model with its best-val weights loaded, the best
        val accuracy as a float, and a DataFrame with one row per epoch run
        (columns: Epoch, train Loss, train Accuracy, val Loss, val Accuracy).
        Loss/accuracy are averaged over the samples each loader actually yielded.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    stop_count = 0  # consecutive epochs without a new best val accuracy
    epoch_accs = []

    for epoch in range(num_epochs):
        epa = {"Epoch": epoch}
        if verbose:
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print('-' * 10)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_epoch(model, dataloaders[phase], phase,
                                               criterion, optimizer, device)

            if not plateau and phase == 'train':
                scheduler.step()
            elif plateau and phase == 'val':
                scheduler.step(epoch_loss)

            epa[f'{phase} Loss'] = float(epoch_loss)
            epa[f'{phase} Accuracy'] = float(epoch_acc)
            if verbose:
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Track the best val accuracy: snapshot weights on improvement, otherwise count
        # towards early stopping.
        val_acc = epa['val Accuracy']
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            stop_count = 0
        else:
            stop_count += 1

        if verbose:
            print()

        epoch_accs.append(epa)

        if stop_count >= early_stop:
            print(f"No Validation Accuracy improvement over the last {early_stop} epochs. Stopping training...")
            break

    time_elapsed = time.time() - since
    if verbose:
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_acc, pd.DataFrame(epoch_accs)


In [11]:
def test_model(model, criterion, phase = "test", verbose = True):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    ts = 0
    
    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        ts += len(preds)

    epoch_loss = running_loss / ts
    epoch_acc = running_corrects.double() / ts

    if verbose:
        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
    return epoch_acc

In [12]:
def train_test_save_model(save_path, dataloaders, model, criterion, optimizer, scheduler, num_epochs = 25, verbose = True, plateau = False, test = False, early_stop = 10):
    """Train `model`, save its best-val weights to `save_path`, and optionally test it.

    Args:
        save_path: file the trained state_dict is written to; its parent directory is
            created before training starts.
        dataloaders, model, criterion, optimizer, scheduler, num_epochs, verbose, plateau,
            early_stop: forwarded to train_model.
        test: when True, evaluate the trained model with test_model.

    Returns:
        (trained_model, val_acc, test_acc, epoch_accs); test_acc is -1 when `test` is False.

    NOTE(review): test_model is called without `dataloaders`, so it evaluates on its own
    default loaders rather than the ones passed here.
    """
    # Create the target directory up front so a long training run is not lost to a
    # missing folder at save time.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    trained_model, val_acc, epoch_accs = train_model(
        dataloaders, model, criterion, optimizer, scheduler,
        num_epochs=num_epochs, verbose=verbose, plateau=plateau, early_stop=early_stop,
    )
    # train_model has already loaded the best-val weights into the model.
    torch.save(trained_model.state_dict(), save_path)

    test_acc = test_model(trained_model, criterion) if test else -1

    if verbose:
        print(f"Val Accuracy: {val_acc}")
        print(f"Test Accuracy: {test_acc}")

    return trained_model, val_acc, test_acc, epoch_accs

In [13]:
def build_from_path(model_ft, path, image_datasets=image_datasets, weighted_samplers = weighted_samplers, decay = False, train_batch=64, test_batch=64, epochs = 30, early_stop = 10):
    """Load and evaluate the checkpoint at `path`, or train a model and save it there.

    Reads notebook-level globals: `criterion`, `optimizer_ft` and `exp_lr_scheduler` (set
    in the calling cell), plus test_model / train_test_save_model. The `image_datasets` and
    `weighted_samplers` defaults are bound when this cell runs.

    Args:
        model_ft: model that receives the loaded weights or is trained (in place).
        path: checkpoint file; whether it exists decides load+evaluate vs. train+save.
        decay: when True, training runs in plateau mode (scheduler stepped on val loss).
        train_batch / test_batch: DataLoader batch size on the train / load path.
        epochs, early_stop: forwarded to training.

    NOTE(review): on the load path test_model(model_ft, criterion) does not receive the
    loaders built here, so `test_batch` only affects loaders that go unused.
    """
    checkpoint_exists = os.path.exists(path)
    loader_batch = test_batch if checkpoint_exists else train_batch
    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split], sampler=weighted_samplers[split],
                                           batch_size=loader_batch, num_workers=4)
        for split in ['train', 'val', 'test']
    }

    if checkpoint_exists:
        model_ft.load_state_dict(torch.load(path))
        test_model(model_ft, criterion)
        return

    # Positional `True` is verbose; plateau mode mirrors `decay`.
    train_test_save_model(path, dataloaders, model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                          epochs, True, bool(decay), early_stop=early_stop)

In [14]:
def test_model_k(model, criterion, phase = "test", verbose = True, k = 5):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    ts = 0
    if k > 1:
        print(f"Evaluating Top {k} Accuracy...")
    
    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        item_ind = 0
        
        for ind_preds in outputs:
            item_class_rankings = sorted(range(species_count), key=lambda k: ind_preds[k], reverse = True)
            correct_label = labels.data[item_ind]
            rank = item_class_rankings.index(correct_label)
            item_ind += 1
            if rank < k:
                running_corrects += 1
            
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        
        running_loss += loss.item() * inputs.size(0)
        #running_corrects += torch.sum(preds == labels.data)
        ts += len(preds)

    epoch_loss = running_loss / ts
    epoch_acc = running_corrects / ts

    if verbose:
        print(f'{k} Acc: {epoch_acc:.4f}')
        
    return epoch_acc

In [15]:
# Use the CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# GPU name and memory diagnostics only apply on CUDA.
if device.type == 'cuda':
    gib = 1024 ** 3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / gib, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0) / gib, 1), 'GB')

Using device: cuda

NVIDIA GeForce RTX 3060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


# Transfer Learning from the 92-Species Classifier

In [16]:
# Start from the 92-species classifier: rebuild its ConvNeXt-Tiny architecture, load its
# checkpoint, then replace the head with a species_count-way layer.
# No pretrained download here: the strict load_state_dict below overwrites every parameter
# and buffer, so ImageNet weights would be discarded immediately (and `pretrained=` is a
# deprecated argument).
model_ft = models.convnext_tiny()

# Freeze everything, then unfreeze the classifier, avgpool and features[6]/features[7].
for param in model_ft.parameters():
    param.requires_grad = False
for module in (model_ft.classifier, model_ft.avgpool, model_ft.features[7], model_ft.features[6]):
    for param in module.parameters():
        param.requires_grad = True

# Recreate the 92-class head so the checkpoint's keys and shapes match.
num_ftrs = model_ft.classifier[2].in_features
model_ft.classifier[2] = nn.Linear(num_ftrs, 92)

criterion = nn.CrossEntropyLoss()

# Load the earlier transfer-learned weights onto the CPU first (the model has not been
# moved yet), so this also works when the checkpoint was saved from GPU tensors.
best_path = "models/92_classifier/ct_fp2ad_mixed.pt"
model_ft.load_state_dict(torch.load(best_path, map_location="cpu"))

# New head for the full species set; freshly created layers are trainable.
model_ft.classifier[2] = nn.Linear(num_ftrs, species_count)
model_ft = model_ft.to(device)


In [42]:
# lr=0.001 transfer run: SGD with momentum; decay=True steps ReduceLROnPlateau on val loss.
mod_path = "models/286_classifier/transfer_model_001.pt"
optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9)
exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, factor=0.2, patience=5)
build_from_path(model_ft, mod_path, decay=True, epochs=30, early_stop=10, train_batch=8, test_batch=32)

test Loss: 1.5204 Acc: 0.6517


In [43]:
# lr=0.0001 transfer run, same schedule as the lr=0.001 run.
# NOTE(review): if this checkpoint does not exist, training continues from model_ft's
# current weights (whatever the previous cell loaded or trained), not from the freshly
# transferred 92-species weights — the two runs are not independent.
mod_path = "models/286_classifier/transfer_model_0001.pt"
optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-4, momentum=0.9)
exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, factor=0.2, patience=5)
build_from_path(model_ft, mod_path, decay=True, epochs=30, early_stop=10, train_batch=8, test_batch=32)

test Loss: 1.5356 Acc: 0.6494


# Learning a New Model from ImageNet Weights (No 92-Species Transfer)

In [46]:
# ImageNet-pretrained ConvNeXt-Tiny with a fresh species_count-way head.
model_ft = models.convnext_tiny(pretrained=True)

# Freeze the whole network, then re-enable gradients for the classifier, avgpool and
# features[6]/features[7].
model_ft.requires_grad_(False)
for trainable in (model_ft.classifier, model_ft.avgpool, model_ft.features[7], model_ft.features[6]):
    trainable.requires_grad_(True)

# New output layer: one logit per species.
num_ftrs = model_ft.classifier[2].in_features
model_ft.classifier[2] = nn.Linear(num_ftrs, species_count)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

In [41]:
# lr=0.001 run for the ImageNet-initialized model; decay=True steps ReduceLROnPlateau on val loss.
mod_path = "models/286_classifier/from_scratch_model_001.pt"
optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9)
exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer_ft, factor=0.2, patience=5)
build_from_path(model_ft, mod_path, decay=True, epochs=30, early_stop=10, train_batch=8, test_batch=32)

test Loss: 1.4563 Acc: 0.6345


## Top-k Accuracies

In [18]:
# Reload the lr=0.001 transfer-learning checkpoint (highest test accuracy in the runs above).
best_path = "models/286_classifier/transfer_model_001.pt"
model_ft.load_state_dict(torch.load(best_path), strict=True)

<All keys matched successfully>

In [19]:
# Top-3 and top-5 test accuracy; the top-5 value is the cell's displayed output.
criterion = nn.CrossEntropyLoss()
test_model_k(model_ft, criterion, k=3)
test_model_k(model_ft, criterion, k=5)

Evaluating Top 3 Accuracy...
3 Acc: 0.8226
Evaluating Top 5 Accuracy...
5 Acc: 0.8740


0.8739789964994166