# Notebook to train, fine-tune, evaluate, and export (to ONNX) the classifier

## Python imports

In [1]:
%load_ext autoreload
%autoreload 2
from CustomDataset import CustomDataset
from EarlyStopper import EarlyStopper
from train_utils import initialize_weights, train
from test_utils import get_test_loader, test_model
from Classifier import BOURRICOT

import copy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import os
import pandas as pd
from pathlib import Path
from PIL import Image
from sklearn.metrics import accuracy_score, fbeta_score, confusion_matrix, ConfusionMatrixDisplay
from timeit import default_timer as timer
from tqdm.notebook import tqdm
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torchsummary import summary
from torchvision import transforms, utils
from torchvision.ops import sigmoid_focal_loss
from torch.utils.data import ConcatDataset, Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
# tensorflow
import tensorflow as tf

  from .autonotebook import tqdm as notebook_tqdm


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3164075/2260974858.py", line 26, in <module>
    from torchsummary import summary
ModuleNotFoundError: No module named 'torchsummary'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2057, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/ultratb.py", line 1118, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/ultratb.py", line 1012, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/ultratb.py", l

In [None]:
# Fix the random seeds of every RNG the pipeline may draw from (torch for weight
# init, shuffling and torchvision augmentations; Python `random`/NumPy in case the
# project modules use them) so that runs are more reproducible.
# Note: full determinism on GPU additionally requires deterministic cuDNN settings.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr

## Data pipeline

### Specify paths to your data

In [None]:
# Root folder that contains every dataset below. Set the ECO4AI_DATASETS
# environment variable to point at your own copy instead of editing the notebook;
# the fallback is the original author's mount point.
datasets_path = os.environ.get('ECO4AI_DATASETS', '/media/mlreds/ntfs_partition/ECO4AI/')

# paths to the (pre-)training datasets
paths_training = [
    Path(datasets_path, 'col_dataset/training'),
    Path(datasets_path, 'HDIN/training'),
]
paths_validation = [
    Path(datasets_path, 'col_dataset/validation'),
    Path(datasets_path, 'HDIN/validation'),
]
# paths to the fine-tuning datasets
ft_paths_training = [
    Path(datasets_path, 'Himax_Dataset-master/training'),
    Path(datasets_path, 'students_ds_new2/training'),
]
ft_paths_validation = [
    Path(datasets_path, 'students_ds_new2/validation'),
]
# paths to the testing datasets
# A: test splits of the pre-training datasets
paths_test_A = [
    Path(datasets_path, 'col_dataset/testing'),
    Path(datasets_path, 'HDIN/testing'),
]
# B: held-out data of the fine-tuning datasets (the Himax 'validation' split is
# used as test data, since it is not used during fine-tuning above)
paths_test_B = [
    Path(datasets_path, 'Himax_Dataset-master/validation'),
    Path(datasets_path, 'students_ds_new2/testing'),
]

### Transforms

In [None]:
# Network input resolution in pixels, and the mini-batch size shared by all loaders
img_width = 200
img_height = 200
batchsize = 64

In [None]:
# Preprocessing pipelines.
# Each pipeline is split into a PIL-space head (grayscale + resize) and a
# tensor-space tail (to tensor + normalize), so PIL-based augmentations
# (RandAugment, AugMix) can be inserted between the two for the training set.
# NOTE: transforms.Resize expects its size as (height, width).

# legacy pipeline, not used anywhere else in this notebook (kept for reference)
general_transforms_old = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Grayscale(),
    transforms.Resize((img_height, img_width)),
])
general_transforms_beg = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((img_height, img_width)),
])
general_transforms_end = transforms.Compose([
    transforms.ToTensor(),          # pixel values scaled to [0, 1]
    transforms.Normalize(.5, .5),   # (x - 0.5) / 0.5 -> values in [-1, 1]
])
# general transforms (applied to train, validation, and test sets)
general_transforms = transforms.Compose([
    *general_transforms_beg.transforms,
    *general_transforms_end.transforms,
])

# transformations specific to the training set: random augmentations applied
# on the resized grayscale PIL image, before tensor conversion
train_transforms = transforms.Compose([
    *general_transforms_beg.transforms,
    transforms.RandAugment(),
    transforms.AugMix(),
    *general_transforms_end.transforms,
])

### Datasets

In [None]:
def _make_concat_dataset(folders, transform):
    """Chain one CustomDataset per folder, all sharing the same transform."""
    return ConcatDataset([CustomDataset(folder, img_width, img_height, transform)
                          for folder in folders])


# augmented pipeline for training, deterministic preprocessing for validation
training_dataset = _make_concat_dataset(paths_training, train_transforms)
validation_dataset = _make_concat_dataset(paths_validation, general_transforms)

### DataLoaders

In [None]:
# DataLoaders: the training loader shuffles and keeps a deeper prefetch queue
# (prefetch_factor is the number of batches loaded in advance per worker).
_train_loader_opts = dict(batch_size=batchsize, shuffle=True, num_workers=8, prefetch_factor=12)
_val_loader_opts = dict(batch_size=batchsize, shuffle=False, num_workers=8, prefetch_factor=4)

training_loader = DataLoader(training_dataset, **_train_loader_opts)
validation_loader = DataLoader(validation_dataset, **_val_loader_opts)

In [None]:
# Report how many training examples each class has.
train_ds = training_loader.dataset
try:
    labels = train_ds.labels
except AttributeError:
    # a ConcatDataset has no `labels` attribute: gather them from its member datasets
    labels = np.concatenate([ds.labels for ds in train_ds.datasets])
classes, counts = np.unique(labels, return_counts=True)
# pair every count with its class value so the output is unambiguous
print('Number of examples per class:', dict(zip(classes.tolist(), counts.tolist())))

### Double-check the transforms

In [None]:
def _iter_display_samples(dataloader, first_only=False):
    """Yield (image, label) pairs from `dataloader`, batch after batch."""
    for images, labels in dataloader:
        for j in range(1 if first_only else len(images)):
            yield images[j], labels[j]


def _to_display_array(image):
    """Convert a channel-first image tensor into an (H, W) or (H, W, C) array in [0, 1]."""
    # (C, H, W) -> (H, W, C); drop the channel axis for single-channel images
    array = image.detach().cpu().float().permute(1, 2, 0).squeeze(-1).numpy()
    lo, hi = array.min(), array.max()
    # min-max rescale so normalized tensors (e.g. [-1, 1]) display correctly
    if hi > lo:
        return (array - lo) / (hi - lo)
    return np.zeros_like(array)


def plot_dataset_images(dataloader, square_size=3, figsize=10, is_batchsize_1=False):
    """Plots a grid of images, titled with their labels, from the given data loader.
    Partial credit to O. D'Ancona.

    Images are drawn from consecutive batches until the grid is full, so the
    loader's batch size may be smaller than ``square_size**2``. Each image is
    min-max rescaled for display, so normalized tensors are shown correctly.

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        DataLoader yielding ``(images, labels)`` batches; each image is assumed
        to be a channel-first ``(C, H, W)`` tensor.
    square_size : int
        number of rows and cols in the Matplotlib plot
    figsize : int
        Figure width and height, in inches
    is_batchsize_1 : bool
        If True, only the first sample of every batch is shown (one image per
        batch, as for a batch size of 1); otherwise consecutive samples of each
        batch are shown.
    """
    # squeeze=False always returns a 2-D array of axes, even for square_size == 1
    _, axes = plt.subplots(square_size, square_size, figsize=(figsize, figsize), squeeze=False)

    samples = _iter_display_samples(dataloader, first_only=is_batchsize_1)
    for i, axis in enumerate(axes.ravel()):
        axis.axis('off')
        sample = next(samples, None)
        if sample is None:  # the loader ran out before the grid was filled
            continue
        image, label = sample
        axis.imshow(_to_display_array(image), cmap='gray', interpolation=None)
        axis.set_title(f"Sample #{i}, label: {label}")
    samples.close()  # release the underlying DataLoader iterator (and its workers)

    # Show the plot
    plt.tight_layout()
    plt.show()

In [None]:
# Visual check of the augmented training pipeline: a 3x3 grid from one training batch
plot_dataset_images(dataloader=training_loader, square_size=3)

## Model

In [None]:
# Train on the first CUDA GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# initialize model: build the architecture, apply the project's custom weight
# initialization, move it to the training device, and print a layer summary
model_cls = BOURRICOT
model = model_cls()
model.apply(initialize_weights)
model = model.to(device)
# torchsummary expects input_size as (C, H, W) and defaults to device="cuda";
# pass the actual device type so the summary also works on CPU-only machines
summary(model, (1, img_height, img_width), device=device.type)

## Training

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic
%load_ext tensorboard

In [None]:
# Display TensorBoard inline, reading logs from ./runs (presumably where train()
# writes its logs — TODO confirm in train_utils). --bind_all serves it on all
# network interfaces (e.g. reachable when the notebook runs on a remote server);
# --port 6008 fixes the port to use.
%tensorboard --logdir=runs --bind_all --port 6008

In [None]:
# training parameters
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Focal loss (torchvision's sigmoid_focal_loss takes raw logits and applies the sigmoid itself).
# gamma: balance easy vs hard examples; higher gamma gives more weight to incorrect predictions
# alpha: weighs the loss of class 1 by α, and class 0 by (1 - α)
# alpha tackles class imbalance while gamma helps the model learn better representation for harder examples
criterion = partial(sigmoid_focal_loss, reduction='mean', alpha=.85, gamma=4.)
model_path = 'models/clf.pth'
# make sure the checkpoint directory exists before training tries to save into it
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
# stop when the monitored metric has not improved by more than 1e-3 for 3 epochs
# (exact criterion defined in EarlyStopper — TODO confirm)
early_stopper = EarlyStopper(patience=3, min_delta=1e-3)
n_epochs = 100

In [None]:
# Train the model. `history` holds, per split, the targets and predictions that
# train() presumably records during training (TODO confirm in train_utils).
history = {split: {'targets': [], 'preds': []} for split in ('train', 'validation')}
for setting in (optimizer, criterion):
    print(setting)
train(model, optimizer, criterion,
      training_loader, validation_loader,
      n_epochs, history, early_stopper, model_path, device)

## Fine-tuning

### Prepare model

In [None]:
# load the pre-trained weights saved by the training cell
model = model_cls()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
checkpoint = torch.load(model_path, map_location=device)
# a checkpoint is either a raw state_dict or a dict wrapping it under 'model_state_dict'
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    checkpoint = checkpoint['model_state_dict']
model.load_state_dict(checkpoint)
model = model.to(device)

In [None]:
# Freeze the feature extractor for fine-tuning, except its last `n_convs_to_train`
# convolution layers; the classifier head always stays trainable.
n_convs_to_train = 0  # 0 = freeze every parameter under model.features
conv_layers = [(name, param) for name, param in model.named_parameters() if 'features' in name]
# Each conv layer is assumed to contribute two tensors (weight + bias) — TODO confirm
# for layers with bias=False or normalization layers.
# Split at an explicit index: a negative slice such as `[:-0]` is empty (and
# `[-0:]` is the whole list), which would freeze nothing when n_convs_to_train == 0.
n_frozen = max(len(conv_layers) - 2 * n_convs_to_train, 0)
for name, param in conv_layers[:n_frozen]:
    param.requires_grad = False
    print(name)
# re-enable the layers kept trainable, so re-running with a larger value works
for _, param in conv_layers[n_frozen:]:
    param.requires_grad = True
ft_trainable_model_params = [param for _, param in conv_layers[n_frozen:]] + \
                            [param for param in model.classifier.parameters()]

### Prepare Datasets & DataLoaders

In [None]:
# Datasets for fine-tuning

# Unused legacy pipeline, kept for reference. NOTE(review): it applies
# RandAugment/AugMix after ToTensor/Normalize, i.e. on normalized float tensors,
# which these augmentations likely do not expect.
ft_train_transforms_old = transforms.Compose(
    general_transforms.transforms
    + [transforms.RandAugment(), transforms.AugMix()]
)
# Fine-tuning augmentations: same recipe as the pre-training pipeline
# (augment the resized grayscale PIL image, then convert + normalize).
ft_train_transforms = transforms.Compose(
    general_transforms_beg.transforms
    + [transforms.RandAugment(), transforms.AugMix()]
    + general_transforms_end.transforms
)

ft_training_dataset = ConcatDataset(
    [CustomDataset(folder, img_width, img_height, ft_train_transforms) for folder in ft_paths_training]
)
ft_validation_dataset = ConcatDataset(
    [CustomDataset(folder, img_width, img_height, general_transforms) for folder in ft_paths_validation]
)

In [None]:
# DataLoaders for fine-tuning: the training loader uses more workers and a deeper
# prefetch queue (batches loaded in advance per worker) than the validation one.
_ft_train_loader_opts = dict(batch_size=batchsize, shuffle=True, num_workers=12, prefetch_factor=24)
_ft_val_loader_opts = dict(batch_size=batchsize, shuffle=False, num_workers=8, prefetch_factor=4)

ft_training_loader = DataLoader(ft_training_dataset, **_ft_train_loader_opts)
ft_validation_loader = DataLoader(ft_validation_dataset, **_ft_val_loader_opts)

In [None]:
# Report how many fine-tuning training examples each class has.
ft_train_ds = ft_training_loader.dataset
try:
    labels_ft = ft_train_ds.labels
except AttributeError:
    # a ConcatDataset has no `labels` attribute: gather them from its member datasets
    labels_ft = np.concatenate([ds.labels for ds in ft_train_ds.datasets])
classes_ft, counts_ft = np.unique(labels_ft, return_counts=True)
# pair every count with its class value so the output is unambiguous
print('Number of examples per class:', dict(zip(classes_ft.tolist(), counts_ft.tolist())))

### Training

In [None]:
# fine-tuning parameters: only the parameters left trainable above are optimized
ft_optimizer = optim.Adam(ft_trainable_model_params, lr=1e-4)
# Optional LR schedulers (currently disabled):
#ft_lr_scheduler = optim.lr_scheduler.StepLR(ft_optimizer, step_size=10, gamma=1e-2) # decay LR by a factor of 0.01 every 10 epochs
#ft_lr_scheduler = optim.lr_scheduler.ExponentialLR(ft_optimizer, gamma=.02)

ft_criterion = partial(sigmoid_focal_loss, reduction='mean', alpha=.55, gamma=6.)
# gamma: balance easy vs hard examples; higher gamma gives more weight to incorrect predictions
# alpha: weighs the loss of class 1 by α, and class 0 by (1 - α)
# alpha tackles class imbalance while gamma helps the model learn better representation for harder examples

# Derive the fine-tuned checkpoint name from the pre-trained one (clf.pth -> clf_ft.pth).
# Build it from the file stem so the '_ft' suffix is always added, whatever the extension;
# a plain str.replace('.pth', ...) would silently reuse (and overwrite) model_path otherwise.
_pretrained_path = Path(model_path)
ft_model_path = str(_pretrained_path.with_name(f'{_pretrained_path.stem}_ft{_pretrained_path.suffix}'))
ft_early_stopper = EarlyStopper(patience=100, min_delta=1e-3)
ft_n_epochs = 1500

In [None]:
# Fine-tune the model. `ft_history` holds, per split, the targets and predictions
# that train() presumably records (TODO confirm in train_utils).
ft_history = {split: {'targets': [], 'preds': []} for split in ('train', 'validation')}
train(model, ft_optimizer, ft_criterion,
      ft_training_loader, ft_validation_loader,
      ft_n_epochs, ft_history, ft_early_stopper, ft_model_path, device)

## Testing

In [None]:
# Test loaders: A = pre-training datasets, B = fine-tuning datasets, plus both combined.
# All of them use the deterministic preprocessing pipeline.
_test_loader_args = (img_width, img_height, general_transforms)
test_loader_A = get_test_loader(paths_test_A, *_test_loader_args)
test_loader_B = get_test_loader(paths_test_B, *_test_loader_args)
test_loader_all = get_test_loader([*paths_test_A, *paths_test_B], *_test_loader_args)

In [None]:
# Evaluate the pre-trained (not fine-tuned) checkpoint on every test split.
# use_sigmoid=True presumably because the model outputs logits (as expected by
# sigmoid_focal_loss) — TODO confirm in test_utils.
eval_loaders = [test_loader_A, test_loader_B, test_loader_all]
eval_names = ['Collision & HDIN', 'Himax & Students', 'All']
test_model(model_path, model_cls, eval_loaders, eval_names, use_sigmoid=True)

In [None]:
# Evaluate the fine-tuned checkpoint on the same test splits, for comparison.
ft_eval_loaders = [test_loader_A, test_loader_B, test_loader_all]
ft_eval_names = ['Collision & HDIN', 'Himax & Students', 'All']
test_model(ft_model_path, model_cls, ft_eval_loaders, ft_eval_names, use_sigmoid=True)

## Export to ONNX

In [None]:
# Checkpoint to export, and the ONNX file written next to it (same name + '.onnx')
model_to_load_path = ft_model_path
onnx_model_path = f'{model_to_load_path}.onnx'

In [None]:
# Load the model to export. binarize_output=True is a constructor flag of the
# project's model class — presumably making it emit binary decisions; TODO confirm.
model = model_cls(binarize_output=True)
# the export runs on CPU (CPU dummy input), so map GPU-saved weights to the CPU
checkpoint = torch.load(model_to_load_path, map_location='cpu')
# a checkpoint is either a raw state_dict or a dict wrapping it under 'model_state_dict'
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    checkpoint = checkpoint['model_state_dict']
model.load_state_dict(checkpoint)
# inference mode, so any dropout / batch-norm layers behave deterministically
model.eval();

In [None]:
# Export to ONNX by tracing the model with a single-image dummy input of shape
# (N=1, C=1, H, W), matching the configured input resolution.
dummy_input = torch.rand(1, 1, img_height, img_width)
torch.onnx.export(model, dummy_input, onnx_model_path, input_names=['input'], output_names=['output'])

In [None]:
# Sanity check: one forward pass on a random image of the configured size.
# Run in eval mode without autograd, so it neither tracks gradients nor
# updates any batch-norm statistics.
model.eval()
with torch.no_grad():
    sample_output = model(torch.rand(1, 1, img_height, img_width))
sample_output