## Imports

In [None]:
%pip install --upgrade setuptools wheel
%pip install torch-summary
%pip install torchmetrics
%pip install pytorch-gradcam

In [None]:
import numpy as np
import os
import gc
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as T
from scipy import stats
import cv2 as cv
import matplotlib
import matplotlib.pyplot as plt
from torchsummary import summary
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F
# for evaluation 
from torchmetrics import ConfusionMatrix
import pandas as pd
import seaborn as sn
import pickle
import re
import time
import json
import shutil
# from pytorch_grad_cam import GradCAM
# from pytorch_grad_cam.utils.image import show_cam_on_image
default_matplotlib_backend = matplotlib.get_backend()
print('imported')
print('default_matplotlib_backend: {}'.format(default_matplotlib_backend))

In [None]:
# Fetch the helper package only when no 'pytorch_utils' entry exists in the working directory.
if not os.path.exists('pytorch_utils'):
    !git clone https://github.com/rishiswethan/pytorch_utils.git
    print("waiting for pytorch_utils to be downloaded...")
    # Poll once per second until the cloned folder is visible on disk.
    while not os.path.exists('pytorch_utils'):
        time.sleep(1)
    # Check out revision v1.0.5 of the helper package (only done right after a fresh clone).
    !cd pytorch_utils && git checkout v1.0.5 && cd ../..

## Constants

In [None]:
# Execution mode: index 0 selects 'DEV', index 1 selects 'LIVE' (currently 'LIVE').
RUN_MODE = ['DEV','LIVE'][1]
SIMPLE_PATH = True  # Set this to False if you want to use the custom paths

In [None]:
# Seed NumPy's global RNG so random draws are identical on every machine
np.random.seed(42)
# True when running on Windows
windows = os.name == 'nt'

def create_folder(new_path):
    """Create the directory ``new_path`` (including parents) if it is missing.

    Uses ``exist_ok=True`` instead of a separate existence check so that two
    processes creating the same folder at the same time cannot race.

    Raises:
        FileExistsError: if ``new_path`` exists but is not a directory.
    """
    os.makedirs(new_path, exist_ok=True)

if SIMPLE_PATH or windows:
    # Project-relative layout rooted at the current working directory.
    # The literals are written with backslashes and converted to the platform separator.
    root_path = os.getcwd() # path of the root project folder
    extracted_data_path = root_path + "\\datasets\\dataset_4a_n_crop\\".replace("\\", os.sep) # this one should exist before executing
    weights_path =  root_path + "\\weights\\".replace("\\", os.sep)
    stats_path = root_path + "\\stats\\".replace("\\", os.sep)
    npy_data_path = root_path + "\\npy_data\\".replace("\\", os.sep)
    model_save_path = root_path + "\\model\\".replace("\\", os.sep)
    pretrained_models_folder = root_path + "\\pretrained_models\\".replace("\\", os.sep)
else:
    # Custom absolute paths (Linux only)
    files_path_name = "files_seq1"
    root_path = '/home/rsaric/Desktop/plant_classification/hyper_par1/' # path of the root project folder
    extracted_data_path = '/media/medaghub_data/Rick/Classification_ex01/TRK-3a/dataset_4a_n_crop/' # this one should exist before executing
    weights_path =  '/home/rsaric/Desktop/plant_classification/hyper_par1/weights/'
    stats_path = '/home/rsaric/Desktop/plant_classification/hyper_par1/stats/'
    npy_data_path = '/home/rsaric/Desktop/plant_classification/hyper_par1/npy_data/'
    # Must be a FOLDER ending with a separator: the rest of the notebook builds file
    # names as `model_save_path + 'model.pth'` and wipes/recreates this path as a directory.
    model_save_path = '/home/rsaric/Desktop/plant_classification/hyper_par1/model/'
    pretrained_models_folder = '/home/rsaric/Desktop/plant_classification/hyper_par1/pretrained_models/'

# Create the output folders; exist_ok makes this safe to re-run.
os.makedirs(npy_data_path, exist_ok=True)
os.makedirs(weights_path, exist_ok=True)
os.makedirs(stats_path, exist_ok=True)
os.makedirs(model_save_path, exist_ok=True)

train_valid_test_split_json_name = 'NEW25split.json'

# Each supported split file is tied to the dataset folder it was generated for
_DATASET_PATH_BY_SPLIT = {
    'NEW8split.json': "/mnt/medaghub-ws/Rick/CR_datasets/dataset_3n_a_auto_cropped/",
    'NEW25split.json': "/mnt/medaghub-ws/Rick/CR_datasets/dataset_4a_n_auto_cropped/",
}
if train_valid_test_split_json_name not in _DATASET_PATH_BY_SPLIT:
    raise ValueError(f"Unknown JSON file: {train_valid_test_split_json_name}")
dataset_path = _DATASET_PATH_BY_SPLIT[train_valid_test_split_json_name]

# The dataset folder selected above overrides the path configured earlier
extracted_data_path = dataset_path

# One sub-folder per class; sorted so class indices are stable between runs
class_list = sorted(os.listdir(extracted_data_path))
print('Number of classes: {}'.format(len(class_list)))

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Selected device: {}'.format(device))
if device.type != 'cpu':
    print('Device name: {}'.format(torch.cuda.get_device_name(device)))

# other constants
# you can modify from here on, but only change what is meant to be changed (e.g. backbone and optimizer)
BATCH_SIZE = 8 if RUN_MODE == 'DEV' else 16
DAY_TO_USE = 7
IMAGES_IN_A_DAY = 2
# number of frames in one input sequence
SEQ_LEN = DAY_TO_USE * IMAGES_IN_A_DAY
EPOCHS = 60
IMG_SIZE = (350, 350) # base size! if you want to change it you have to do it by T.Resize(*target_size). .npy data already on (350, 350)
# spatial size the dataset class resizes every frame to
RESIZE_SIZE = (128, 128)
NUM_CHANNELS = 3
BACKBONE = 'EfficientNetB2'
OPTIMIZER = 'Adam'

# backbone names accepted for BACKBONE
available_backbones = [
    'AlexNet', 'ResNet18', 'ResNet34', 'EfficientNetB0', 'EfficientNetB1', 'EfficientNetB2',
    'EfficientNetV2_S', 'ConvNext_T', 'MobileNet_V3_Small', 'MobileNet_V3_Large',
    'ViT_B_16'
]

# optimizer names accepted for OPTIMIZER
avaible_optimizers = ['SGD', 'Adam', 'AdamW', 'RMSprop']

## Data loading

In [None]:
# Which class sub-group to use; None keeps every class. Could be: None, 0, 1, 2
DATASE_GROUP_IDX = None

# Available sub-groups, given as indices into the sorted class list
indep1 = [18, 9, 26, 29, 34]

dataset_groups = [
    [1, 7, 10, 14, 11, 32, 13, 4, 8, 17, 20, 25, 28, 31, 36],
    [0, 15, 21, 22, 23, 5] + indep1,
    [37, 33, 16, 2, 3, 6, 12, 19, 24, 27, 30, 35],
]

# Lookup from class name to class index
class_dict = {name: idx for idx, name in enumerate(class_list)}

# Collect every image path class by class, repetition by repetition.
# Everything is visited IN SORTED ORDER, which keeps the file list reproducible.
image_files = []
targets = []
for class_name, class_idx in class_dict.items():
    if class_name.startswith('.'):
        continue  # hidden entry, not a class folder
    class_dir = extracted_data_path + class_name
    for repetition in sorted(os.listdir(class_dir)):
        if repetition.startswith('.'):
            continue
        repetition_dir = class_dir + os.sep + repetition
        image_list = sorted(os.listdir(repetition_dir))
        image_files.extend(repetition_dir + os.sep + img for img in image_list)
        targets.extend([class_idx] * len(image_list))

targets = np.array(targets)

if DATASE_GROUP_IDX is not None:
    dataset_groups = [sorted(el) for el in dataset_groups]
    assert 0 <= DATASE_GROUP_IDX <len(dataset_groups), '...'
    selected_classes = dataset_groups[DATASE_GROUP_IDX]

    # Re-index the selected classes as 0..k-1 (must happen before class_list is replaced)
    mapper_allcls_to_subcls = {old_idx: new_idx for new_idx, old_idx in enumerate(selected_classes)}
    class_dict = {class_list[old_idx]: new_idx for new_idx, old_idx in enumerate(selected_classes)}
    class_list = [class_list[old_idx] for old_idx in selected_classes]

    # Keep only the samples of the selected classes, with remapped targets
    kept_pairs = [
        (mapper_allcls_to_subcls[t], im)
        for t, im in zip(targets, image_files)
        if t in mapper_allcls_to_subcls
    ]
    new_targets = np.array([t for t, _ in kept_pairs])
    new_image_files = [im for _, im in kept_pairs]
    targets = new_targets
    image_files = new_image_files

In [None]:
from copy import deepcopy

def _fix_images_paths(images_paths, targets):
    """Drop an unpaired image at the start and/or end of the path list.

    Each file is identified by fields 1..3 of its underscore-separated base
    name (presumably the capture date -- TODO confirm against the file naming).
    The first image is dropped when its key differs from the second image's
    key, and the last image is dropped when its key differs from that of the
    second-to-last image.

    The reference keys are taken from the ``images_paths`` argument itself, not
    from the module-level ``image_files``. This gives the same result when the
    function is called with ``image_files`` and makes it usable on any list.

    Args:
        images_paths: sequence of image file paths.
        targets: sequence of labels aligned with ``images_paths``.

    Returns:
        ``(images_paths, targets)`` as trimmed copies; the inputs are not modified.
    """
    images_paths = deepcopy(images_paths)
    targets = deepcopy(targets)

    # Nothing to compare against with fewer than two images
    if len(images_paths) < 2:
        return images_paths, targets

    def _name_key(path):
        return os.path.basename(path).split('_')[1:4]

    # Capture both references BEFORE trimming so they refer to the original list
    head_reference = _name_key(images_paths[1])
    tail_reference = _name_key(images_paths[-2])

    if _name_key(images_paths[0]) != head_reference:
        images_paths = images_paths[1:]
        targets = targets[1:]
    if len(images_paths) > 0 and _name_key(images_paths[-1]) != tail_reference:
        images_paths = images_paths[:-1]
        targets = targets[:-1]
    return images_paths, targets

def prepare_dataset(image_files, targets, train_valid_test_split, fix_len=10):
    """Cut every repetition into fixed-length sequences and group them by split.

    Args:
        image_files: image paths; the last two folder components of each path
            are read as ``<class>/<repetition>``.
        targets: labels aligned with ``image_files``.
        train_valid_test_split: nested mapping ``{class: {repetition: split name}}``.
        fix_len: number of images in every produced sequence.

    Returns:
        dict mapping each split name to a list of sequences, where a sequence
        is a list of ``fix_len`` ``(image_path, target)`` tuples.
    """
    trimmed_paths, trimmed_targets = _fix_images_paths(image_files, targets)

    # Bucket the (path, target) pairs per (class, repetition)
    per_repetition = {}
    for path, target in zip(trimmed_paths, trimmed_targets):
        cls_name, repetition = path.split(os.sep)[-3:-1]
        per_repetition.setdefault((cls_name, repetition), []).append((path, target))

    # Repeat the last image until the repetition length is a multiple of fix_len
    for items in per_repetition.values():
        missing = -len(items) % fix_len
        items.extend([items[-1]] * missing)

    datas = {}
    for cls_name, repetitions in train_valid_test_split.items():
        for rep_name, data_type in repetitions.items():
            key = (cls_name, rep_name)
            if key not in per_repetition:
                continue
            bucket = datas.setdefault(data_type, [])
            items = per_repetition[key]
            # Sliding window over the repetition, advancing two images at a time
            for start_idx in range(0, len(items) - fix_len + 1, 2):
                seq = items[start_idx:start_idx + fix_len]
                assert len(seq) == fix_len, '...'
                bucket.append(seq)
    return datas

In [None]:
# Split definition file (alternative: 'train_valid_test_split_DS1a.json')
train_valid_test_split_json_name = 'NEW25split.json'
with open(train_valid_test_split_json_name, 'r') as split_file:
    train_valid_test_split = json.load(split_file)

# Sequences are never shorter than 10 images, even when SEQ_LEN is smaller
datas = prepare_dataset(image_files, targets, train_valid_test_split, fix_len=max(SEQ_LEN, 10))

In [None]:
# Switch for the random augmentations of the dataset class (passed to the training dataset)
USE_AUGMENTATION = False

def rotate_image(image, angle):
    """Rotate ``image`` around its centre by ``angle`` without changing its size.

    Args:
        image: array whose first two axes are (height, width).
        angle: rotation angle handed to ``cv.getRotationMatrix2D``.

    Returns:
        The rotated image, resampled with linear interpolation.
    """
    width_height = image.shape[1::-1]
    center = tuple(np.array(width_height) / 2)
    rotation = cv.getRotationMatrix2D(center, angle, 1.0)
    return cv.warpAffine(image, rotation, width_height, flags=cv.INTER_LINEAR)

def zoom_at(img, zoom=1, angle=0, coord=None):
    """Zoom (and optionally rotate) ``img`` around a point, keeping its size.

    Args:
        img: image array whose first two axes are (height, width); both 2-D
            (grayscale) and 3-D (with channels) arrays are accepted.
        zoom: scale factor handed to ``cv.getRotationMatrix2D``.
        angle: rotation angle handed to ``cv.getRotationMatrix2D``.
        coord: ``(x, y)`` centre of the transform; defaults to the image centre.

    Returns:
        The transformed image, resampled with linear interpolation.
    """
    if coord is None:
        # shape[:2] (not shape[:-1]) so that 2-D images without a channel axis work too
        height, width = img.shape[:2]
        cy, cx = height / 2, width / 2
    else:
        cy, cx = coord[::-1]
    rot_mat = cv.getRotationMatrix2D((cx, cy), angle, zoom)
    result = cv.warpAffine(img, rot_mat, img.shape[1::-1], flags=cv.INTER_LINEAR)
    return result

import albumentations as A

class ClassificationPlantSequenceDataset(torch.utils.data.Dataset):
    """Dataset yielding one image sequence and its class per item.

    ``data`` is a list of sequences; every sequence is a list of
    ``(image_path, class_index)`` tuples that all share the same class.
    """

    def __init__(self, data, use_augmentation=False):
        """
        Args:
            data: list of sequences of ``(image_path, class_index)`` tuples.
            use_augmentation: when True, random rotation / flip / shift-scale
                transforms are applied after resizing; otherwise only the resize.
        """
        self.use_augmentation = use_augmentation
        self.data = data

        if self.use_augmentation:
            self.transform = A.Compose([
                A.Resize(RESIZE_SIZE[0], RESIZE_SIZE[1]),
                A.Rotate(limit=180, p=0.7),
                A.Flip(p=0.9),
                A.ShiftScaleRotate(shift_limit=0.2, scale_limit=(0.2, 0.5)),
            ])
        else:
            self.transform = A.Compose([
                A.Resize(RESIZE_SIZE[0], RESIZE_SIZE[1]),
            ])

    def __getitem__(self, index):
        """Load sequence ``index``.

        Returns:
            ``(images, seq_cls)`` where ``images`` is a float32 tensor laid out
            as (channels, frames, height, width) with values scaled to [0, 1],
            and ``seq_cls`` is the class shared by every frame.

        Raises:
            FileNotFoundError: if one of the image files cannot be read.
        """
        seq, classes = list(zip(*self.data[index]))
        assert len(set(classes)) == 1, 'wrong seq clses'
        seq_cls = classes[0]

        images = []
        for im_path in seq:
            img = cv.imread(im_path)
            # cv.imread signals failure by returning None; fail with the offending path
            # instead of an opaque error inside cvtColor
            if img is None:
                raise FileNotFoundError('Could not read image: {}'.format(im_path))
            img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
            img = cv.resize(img, IMG_SIZE)
            img = self.transform(image=img)['image']
            img = img.astype(np.float32) / 255.
            # HWC -> CHW
            img = torch.from_numpy(img.transpose((2, 0, 1)))
            images.append(img)

        # (frames, C, H, W) -> (C, frames, H, W)
        images = torch.stack(images)
        images = torch.permute(images, (1, 0, 2, 3))
        return images, seq_cls

    def __len__(self):
        """Number of sequences."""
        return len(self.data)

# loading data
train_dataset = ClassificationPlantSequenceDataset(datas['train'], use_augmentation=USE_AUGMENTATION)
val_dataset = ClassificationPlantSequenceDataset(datas['valid'])
test_dataset = ClassificationPlantSequenceDataset(datas['test'])

# The validation and test splits are merged and the result is used for both roles
val_dataset = torch.utils.data.ConcatDataset([val_dataset, test_dataset])
test_dataset = val_dataset

# data loaders
# Training: shuffled, incomplete last batch dropped.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation loaders are deterministic and cover every sample: shuffling combined with
# drop_last would score a different random subset each epoch, which makes the monitored
# validation accuracy (checkpointing / early stopping) noisy.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, drop_last=False)

print('Data loaders created')

In [None]:
# Preview: one figure per entry of idxs_in_batch, each showing the first
# `sample_size` frames of one randomly chosen training sequence.
dataset = ClassificationPlantSequenceDataset(datas['train'], use_augmentation=USE_AUGMENTATION)
idxs_in_batch = [0,1,2]
sample_size = 10
for idx_in_batch in idxs_in_batch:
    img_idx = np.random.randint(0, len(dataset), size=sample_size)
    # Only one of the drawn sequences is displayed, so load just that one
    # instead of reading all `sample_size` sequences from disk.
    imgs, seq_cls = dataset[img_idx[idx_in_batch]]
    imgs = imgs.cpu()
    plt.figure(figsize=(20, 10))
    for i in range(sample_size):
        # frame i: (C, H, W) -> (H, W, C) for imshow
        img = imgs[:,i,:,:].permute(1, 2, 0)
        plt.subplot(1, sample_size, i + 1)
        plt.imshow((img * 255).type(torch.uint8))
        plt.title('Class: {}'.format(class_list[seq_cls]))
        plt.axis('off')
    plt.show()

#### Dataset class and data loaders

### Data sample (Run this only if you want to see an example of the data)

### LR Scheduler and utility functions

In [None]:
class PolyScheduler(_LRScheduler):
    """Linear warm-up followed by polynomial (power 2) decay of the learning rate.

    * ``last_epoch == -1``: ``warmup_lr_init`` (0.0001).
    * ``last_epoch < warmup_steps``: ``base_lr * last_epoch / warmup_steps``.
    * otherwise: ``base_lr * (1 - progress) ** 2`` where ``progress`` is the
      fraction of the decay phase already done, clamped to 1 so the learning
      rate stays at 0 after ``max_steps`` instead of rising again.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        """
        Args:
            optimizer: wrapped optimizer.
            base_lr: learning rate reached at the end of the warm-up.
            max_steps: step at which the decay reaches zero.
            warmup_steps: number of warm-up steps.
            last_epoch: value assigned to ``self.last_epoch`` after initialisation.
        """
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # The third positional argument (`verbose`) is deprecated in recent PyTorch
        # releases, so it is no longer passed; its old default was already False.
        super().__init__(optimizer, -1)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Learning rates during the warm-up phase (one per parameter group)."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Learning rates for the current ``last_epoch`` (one per parameter group)."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        # Guard against a zero-length decay phase (max_steps == warmup_steps)
        decay_steps = max(self.max_steps - self.warmup_steps, 1)
        # Clamp so that (1 - progress) never becomes negative past max_steps
        progress = min(float(self.last_epoch - self.warmup_steps) / float(decay_steps), 1.0)
        alpha = pow(1 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

def plot_model_stats(model_name, train_loss, train_acc, test_loss, test_acc):
    """Plot the loss and accuracy curves of a training run side by side.

    Args:
        model_name: name shown in both subplot titles.
        train_loss, train_acc: per-epoch training loss / accuracy (fractions).
        test_loss, test_acc: per-epoch validation loss / accuracy (fractions);
            the last entries are reported in the titles.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: losses
    loss_ax.plot(np.arange(len(train_loss)), train_loss, label = 'Train loss')
    loss_ax.plot(np.arange(len(test_loss)), test_loss, label = 'Val loss')
    loss_ax.set_title('Model: {} - Validation loss: {:.4f}'.format(model_name, test_loss[-1]))
    loss_ax.set_ylabel("Loss")
    loss_ax.set_xlabel("Epochs")
    loss_ax.legend(loc='upper right')

    # Right panel: accuracies, shown in percent
    acc_ax.plot(np.arange(len(train_acc)), np.array(train_acc) * 100, color='green', label = 'Train accuracy')
    acc_ax.plot(np.arange(len(test_acc)), np.array(test_acc) * 100, color='red', label = 'Val accuracy')
    acc_ax.set_title('Model: {} - Validation accuracy: {:.2f} %'.format(model_name, test_acc[-1] * 100))
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_xlabel("Epochs")
    acc_ax.legend(loc='lower right')

    plt.show()

def load_model_weights(BACKBONE=BACKBONE, OPTIMIZER=OPTIMIZER):
    """Load a whole pickled model for the given backbone / optimizer pair.

    The file is expected at ``<weights_path>/backbone_<BACKBONE>_<OPTIMIZER>.pth``.

    Returns:
        The loaded model, placed on ``device``.

    Raises:
        AssertionError: if the file does not exist.
    """
    pth = os.path.join(weights_path, 'backbone_{}_{}.pth'.format(BACKBONE, OPTIMIZER))
    assert os.path.exists(pth), 'Configuration not found'
    # map_location lets a model saved on a GPU be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load files from a trusted source.
    model = torch.load(pth, map_location=device).to(device)
    print('Backbone weights loaded: {}'.format(BACKBONE))
    return model

def load_model_stats(BACKBONE=BACKBONE, OPTIMIZER=OPTIMIZER):
    """Load the pickled training statistics (train/val loss and accuracy).

    The file is expected at ``<stats_path>/stats_<BACKBONE>_<OPTIMIZER>.pkl``.

    Returns:
        The unpickled statistics object.

    Raises:
        AssertionError: if the file does not exist.
    """
    file_name = 'stats_{}_{}.pkl'.format(BACKBONE, OPTIMIZER)
    stats_file = os.path.join(stats_path, file_name)
    assert os.path.exists(stats_file), 'Configuration not found'
    with open(stats_file, 'rb') as handle:
        loaded_stats = pickle.load(handle)
    print('Train/Validation stats loaded for the model: {}'.format(BACKBONE))
    return loaded_stats

## Training and evaluation functions

In [None]:
def _handlezero_division_np(a,b):
    """Element-wise ``a / b`` that yields 0 wherever ``b`` is 0.

    Args:
        a: numerator array (or array-like).
        b: denominator array (or array-like) of the same shape as ``a``.

    Returns:
        Array shaped like ``a``. Floating/complex numerators keep their dtype;
        integer numerators produce a floating result, so quotients are not
        truncated to integers.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    mask = (b != 0)
    # Divide only where it is defined
    quotient = a[mask] / b[mask]
    # zeros_like(a) would inherit an integer dtype and silently truncate the quotients
    out_dtype = a.dtype if np.issubdtype(a.dtype, np.inexact) else quotient.dtype
    c = np.zeros(a.shape, dtype=out_dtype)
    c[mask] = quotient
    return c

def mathews_correlation_coefficient_np(tp, fp, fn, tn):
    """Matthews correlation coefficient from (arrays of) confusion counts.

    Each argument is summed and converted to float64 first, so arrays of
    per-class counts are pooled. A small constant in the denominator keeps
    the result finite (0.0) when the denominator would be zero.

    Args:
        tp, fp, fn, tn: NumPy arrays or scalars with the true-positive,
            false-positive, false-negative and true-negative counts.

    Returns:
        The coefficient as a float64 scalar.
    """
    tp, tn, fp, fn = (counts.sum().astype(np.float64) for counts in (tp, tn, fp, fn))
    covariance = tp * tn - fp * fn
    scale = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return covariance / (scale + 1e-11)

## Define callbacks

In [None]:
import pytorch_utils.callbacks as pt_callbacks
def get_callbacks(
        optimiser,
        result,
        model,
        defined_callbacks=None,
        continue_training=False,
        other_stats=None
):
    """Run the per-epoch callbacks, all monitoring ``result["val_acc"]`` (mode 'max').

    On the first call (``defined_callbacks`` is None) the 'val' and 'train'
    callback objects are created; afterwards the passed-in ones are reused.
    The 'val' callbacks are then invoked in this order: learning-rate reduction
    on plateau (factor 0.5, patience 5), model checkpoint, early stopping
    (patience 25) and memory clearing.

    Args:
        optimiser: optimizer handed to the callback objects.
        result: dict of the epoch's metrics; must contain ``"val_acc"``.
        model: model handed to the checkpoint callback.
        defined_callbacks: dict with 'val' and 'train' callback objects, or None.
        continue_training: forwarded to the callback constructors.
        other_stats: unused.

    Returns:
        ``(defined_callbacks, stop_flag)`` where ``stop_flag`` is the value
        returned by the early-stopping callback.
    """
    if defined_callbacks is None:
        val_callbacks = pt_callbacks.Callbacks(
            optimizer=optimiser,
            model_save_path=model_save_path + 'model.pth',
            training_stats_path=model_save_path + 'training_stats_val',
            continue_training=continue_training,
        )
        train_callbacks = pt_callbacks.Callbacks(
            optimizer=optimiser,
            training_stats_path=model_save_path + 'training_stats_train',
            continue_training=continue_training,
        )
        defined_callbacks = {'val': val_callbacks, 'train': train_callbacks}

    val_cb = defined_callbacks['val']
    monitored = result["val_acc"]

    val_cb.reduce_lr_on_plateau(
        monitor_value=monitored, mode='max', factor=0.5, patience=5,
        indicator_text="Val LR scheduler: ",
    )
    val_cb.model_checkpoint(
        model=model, monitor_value=monitored, mode='max',
        indicator_text="Val checkpoint: ",
    )
    stop_flag = val_cb.early_stopping(
        monitor_value=monitored, mode='max', patience=25,
        indicator_text="Early stopping: ",
    )
    val_cb.clear_memory()
    print("_________")
    return defined_callbacks, stop_flag

## Define training loop and evaluation

In [None]:
import pytorch_utils.training_utils as pt_train

def train_loop(
        model,
        optimizer,
        epochs,
        train_loader,
        val_loader,
        model_save_folder=model_save_path,
        initial_lr=0.001,
        weight_decay=None,
        running_hyperopt=False,
        verbose=False,
        continue_training=False,
):
    """Train ``model`` with ``pt_train.fit`` and return the best checkpoint plus histories.

    Args:
        model: model to train.
        optimizer: optimizer factory passed to ``pt_train.fit`` as ``opt_func``.
        epochs: number of epochs.
        train_loader, val_loader: data loaders for training and validation.
        model_save_folder: folder prepared for the checkpoint files.
        initial_lr: initial learning rate.
        weight_decay: weight decay forwarded to ``pt_train.fit``.
        running_hyperopt: unused.
        verbose: unused.
        continue_training: resume from existing files; when True the save
            folder is NOT wiped, so the previous checkpoint and statistics
            remain available to the callbacks.

    Returns:
        ``(model, train_loss_history, train_acc_history, val_loss_history,
        val_acc_history)`` where ``model`` is loaded from
        ``model_save_path + "model.pth"``.
    """
    def get_result_list(history, metric):
        return [history[i][metric] for i in range(len(history))]

    # prep the model save path: start from a clean folder unless we are resuming,
    # in which case deleting it would destroy the state we want to continue from
    if not continue_training:
        shutil.rmtree(model_save_folder, ignore_errors=True)
    os.makedirs(model_save_folder, exist_ok=True)

    # Train the model using torch
    history = pt_train.fit(
        epochs=epochs,
        lr=initial_lr,
        weight_decay=weight_decay,
        model=model,
        continue_training=continue_training,
        callbacks_function=get_callbacks,
        train_loader=train_loader,
        val_loader=val_loader,
        opt_func=optimizer,
    )

    del model
    # load the best model from checkpoint
    model = torch.load(model_save_path + "model.pth")
    train_loss_history = get_result_list(history, "train_loss")
    train_acc_history = get_result_list(history, "train_acc")
    val_loss_history = get_result_list(history, "val_loss")
    val_acc_history = get_result_list(history, "val_acc")
    return model, train_loss_history, train_acc_history, val_loss_history, val_acc_history


def evaluate_model(model, test_loader, verbose=True, eps=1e-10):
    """Evaluate ``model`` on ``test_loader`` and compute per-class metrics.

    The confusion matrix follows the torchmetrics layout (rows = true class,
    columns = predicted class). Per class ``i``:
    TP = ``m[i, i]``, FP = column sum - TP, FN = row sum - TP,
    TN = total - TP - FP - FN.

    Args:
        model: classifier returning one score per class.
        test_loader: iterable of ``(images, targets)`` batches.
        verbose: print the metric report and draw the confusion-matrix heat map.
        eps: small constant added to denominators to avoid division by zero.

    Returns:
        ``(accuracy, precision, recall, f1_score, matrix_df)``: overall accuracy,
        per-class precision / recall / F1 arrays and the confusion matrix as a
        DataFrame indexed by class name (built regardless of ``verbose``).
    """
    if verbose:
        print('--------------------------------------------')
        print('Test metrics (on test set)')

    model.eval()
    num_classes = len(class_list)

    # Recent torchmetrics releases require the `task` argument, older ones reject it
    try:
        confusion_matrix = ConfusionMatrix(task="multiclass", num_classes=num_classes)
    except (TypeError, ValueError):
        confusion_matrix = ConfusionMatrix(num_classes=num_classes)

    eval_preds = list()
    eval_targs = list()

    # computing predictions; no gradients are needed for evaluation
    with torch.no_grad():
        for images, targets in tqdm(test_loader, position=0, leave=True):
            images = images.to(device, dtype=torch.float)
            preds = model(images).argmax(dim=1)
            eval_preds.extend(preds.cpu().tolist())
            eval_targs.extend(torch.as_tensor(targets).cpu().tolist())

    # class indices are kept as integers (not floats) for the confusion matrix
    matrix = confusion_matrix(
        torch.tensor(eval_preds, dtype=torch.long),
        torch.tensor(eval_targs, dtype=torch.long),
    )
    accuracy = matrix.trace() / (matrix.sum() + eps)

    # per-class confusion counts
    counts = matrix.cpu().numpy().astype(np.float64)
    tp = np.diag(counts)
    predicted_totals = counts.sum(axis=0)
    actual_totals = counts.sum(axis=1)
    fp = predicted_totals - tp
    fn = actual_totals - tp
    tn = counts.sum() - tp - fp - fn

    # main metrics (precision, recall and f1 score)
    precision = tp / (predicted_totals + eps)
    recall = tp / (actual_totals + eps)
    f1_score = 2 * precision * recall / (precision + recall + eps)

    # false positive rate, false negative rate, false discovery rate, false omission rate
    fp_rate = fp / (fp + tn + eps)
    fn_rate = 1 - recall
    fd_rate = 1 - precision
    specificity = 1 - fp_rate
    fo_rate = fn / (fn + tn + eps)

    missclassification_rate = 1 - accuracy
    npv = 1 - fo_rate

    mcc_per_class = [
        mathews_correlation_coefficient_np(tp[idx], fp[idx], fn[idx], tn[idx])
        for idx in range(num_classes)
    ]

    # Built unconditionally because it is part of the return value
    matrix_df = pd.DataFrame(matrix.cpu().numpy(), index=class_list, columns=class_list)

    if verbose:
        print('--------------------------------------------')
        print('Accuracy: {:.3f}%'.format(accuracy * 100))
        print('Average precision: {:.3f}'.format(precision.mean()))
        print('Average recall: {:.3f}'.format(recall.mean()))
        print('Average F1 score: {:.3f}'.format(f1_score.mean()))
        print('Average specificity: {:.3f}'.format(specificity.mean()))
        print('Average false positive rate: {:.3f}'.format(fp_rate.mean()))
        print('Average false negative rate: {:.3f}'.format(fn_rate.mean()))
        print('Average false discovery rate: {:.3f}'.format(fd_rate.mean()))
        print('Average false omission rate: {:.3f}'.format(fo_rate.mean()))
        print('Missclassification rate: {:.2f}%'.format(missclassification_rate * 100))
        print('Mathews Correlation Coefficient: {:.2f}'.format(np.mean(mcc_per_class)))
        print('--------------------------------------------')
        print('Results by class :')
        print('--------------------------------------------')
        print('{:<15}{:<12}{:<12}{:<12}{:<12}{:<12}{:<12}{:<12}{:<12}{:<12}{:<12}'.format('', 'Precision', 'Recall', 'F1 score', 'Specificity', 'FPR', 'FNR', 'FDR', 'FOR', 'NPV', 'MCC'))
        for idx, class_name in enumerate(class_list):
            print('{:<15}{:<12.2f}{:<12.2f}{:<12.2f}{:<12.3f}{:<12.3f}{:<12.3f}{:<12.3f}{:<12.3f}{:<12.3f}{:<12.3f}'.format(
                class_name, precision[idx], recall[idx], f1_score[idx], specificity[idx], fp_rate[idx], fn_rate[idx], fd_rate[idx], fo_rate[idx], npv[idx], mcc_per_class[idx]
            ))
        print('--------------------------------------------')
        print()

        # plotting confusion matrix
        plt.figure(figsize=(12, 8))
        sn.heatmap(matrix_df, annot=True, fmt='d', cmap='Blues')

    return accuracy, precision, recall, f1_score, matrix_df

In [None]:
def accuracy(
        outputs: torch.Tensor,
        labels: torch.Tensor
):
    """
    Custom accuracy function to override the default one in pt_train.

    The predicted class of each row of ``outputs`` is its arg-max along
    dimension 1; the result is the fraction of rows whose prediction equals
    the matching entry of ``labels``, returned as a 0-d tensor.
    """
    predicted = outputs.argmax(dim=1)
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))


class CustomModelBase(pt_train.CustomModelBase):
    """
    Base class of the sequence models: overrides the training and validation
    steps with a (optionally class-weighted) cross-entropy loss and the custom
    ``accuracy`` function.
    """

    def __init__(self, class_weights=None):
        """
        Args:
            class_weights: optional per-class weights passed to ``F.cross_entropy``.
        """
        super().__init__()
        self.class_weights = class_weights
        self.accuracy_function = accuracy

    def _loss_and_accuracy(self, batch):
        """Forward one ``(images, labels)`` batch and return ``(loss, accuracy)``."""
        images, labels = batch
        predictions = self(images)
        loss = F.cross_entropy(predictions, labels, weight=self.class_weights)
        return loss, accuracy(predictions, labels)

    def training_step(self, batch):
        """Return the ``(loss, accuracy)`` pair of a training batch."""
        return self._loss_and_accuracy(batch)

    def validation_step(self, batch):
        """Return ``{'val_loss', 'val_acc'}`` for a validation batch (loss detached)."""
        loss, acc = self._loss_and_accuracy(batch)
        return {'val_loss': loss.detach(), 'val_acc': acc}

## Definition of the Sequence Models

In [None]:
class R3D_18(CustomModelBase):
    """Pretrained torchvision ``r3d_18`` video network with a ``num_classes`` head.

    The pretrained network's output is mapped to ``num_classes`` scores by an
    extra linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = models.video.r3d_18(weights=models.video.R3D_18_Weights.DEFAULT)
        self.linear = nn.Linear(self.model.fc.out_features, num_classes, bias=True)

    def forward(self, x):
        # Apply the classification head; without it the network would return the
        # pretrained model's own outputs instead of `num_classes` scores.
        return self.linear(self.model(x))

# mc3_18
# 4M parameters
class MC3_18(CustomModelBase):
    """Pretrained torchvision ``mc3_18`` video network with a ``num_classes`` head.

    The pretrained network's output is mapped to ``num_classes`` scores by an
    extra linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = models.video.mc3_18(weights=models.video.MC3_18_Weights.DEFAULT)
        self.linear = nn.Linear(self.model.fc.out_features, num_classes, bias=True)

    def forward(self, x):
        # Apply the classification head; without it the network would return the
        # pretrained model's own outputs instead of `num_classes` scores.
        return self.linear(self.model(x))

# r2plus1d_18
# 4M parameters
class r2plus1d_18(CustomModelBase):
    """Pretrained torchvision ``r2plus1d_18`` video network with a ``num_classes`` head.

    The pretrained network's output is mapped to ``num_classes`` scores by an
    extra linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = models.video.r2plus1d_18(weights=models.video.R2Plus1D_18_Weights.DEFAULT)
        self.linear = nn.Linear(self.model.fc.out_features, num_classes, bias=True)

    def forward(self, x):
        # Apply the classification head; without it the network would return the
        # pretrained model's own outputs instead of `num_classes` scores.
        return self.linear(self.model(x))

# 4M parameters
# no init weights
class Eff_LSTM(CustomModelBase):
    """EfficientNet-B0 frame encoder followed by an LSTM over the sequence.

    Every frame ``x[:, :, idx]`` is encoded separately, the per-frame
    embeddings are stacked along dimension 1 and fed to the LSTM, and the
    flattened LSTM outputs of all ``SEQ_LEN`` steps go through two linear layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers  = num_layers

        self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        # batch_first=True: forward() builds the embeddings as (batch, sequence, features).
        # With batch_first=False the recurrence would run across the batch dimension.
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True)
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        # One embedding per frame, stacked to (batch, SEQ_LEN, features)
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # Omitting the initial states makes the LSTM start from zeros on the
        # same device and dtype as its input.
        out, _ = self.lstm(embeddings)
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

# 4M parameters
class Eff_GRU(CustomModelBase):
    """EfficientNet-B0 frame encoder followed by a GRU over the sequence.

    Every frame ``x[:, :, idx]`` is encoded separately, the per-frame
    embeddings are stacked along dimension 1 and fed to the GRU, and the
    flattened GRU outputs of all ``SEQ_LEN`` steps go through two linear layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers  = num_layers

        self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True)
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        # One embedding per frame, stacked to (batch, SEQ_LEN, features)
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # Omitting the initial state makes the GRU start from zeros on the input's
        # own device, instead of depending on the module-level `device`.
        out, _ = self.gru(embeddings)
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

# 4M parameters
class Eff_BiGRU(CustomModelBase):
    """EfficientNet-B0 frame encoder followed by an (optionally bidirectional) GRU.

    Every frame ``x[:, :, idx]`` is encoded separately, the per-frame
    embeddings are stacked along dimension 1 and fed to the GRU, and the
    flattened GRU outputs of all ``SEQ_LEN`` steps (both directions when
    bidirectional) go through two linear layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers  = num_layers

        self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        # One embedding per frame, stacked to (batch, SEQ_LEN, features)
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # Omitting the initial state makes the GRU start from zeros on the input's
        # own device, instead of depending on the module-level `device`.
        out, _ = self.gru(embeddings)
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

# 4M parameters
class Eff_BiLSTM(CustomModelBase):
    """EfficientNet-B0 frame encoder followed by an (optionally bidirectional) LSTM.

    Every frame ``x[:, :, idx]`` is encoded separately, the per-frame
    embeddings are stacked along dimension 1 and fed to the LSTM, and the
    flattened LSTM outputs of all ``SEQ_LEN`` steps (both directions when
    bidirectional) go through two linear layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        # super() must refer to this class: `super(Eff_BiGRU, self)` raised a
        # TypeError because an Eff_BiLSTM is not an Eff_BiGRU instance.
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers  = num_layers

        self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        # One embedding per frame, stacked to (batch, SEQ_LEN, features)
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # An LSTM takes an (h0, c0) tuple, not a single tensor; omitting it starts
        # both states from zeros on the input's own device.
        out, _ = self.lstm(embeddings)
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

# 4M parameters
class Eff2_LSTM(CustomModelBase):
    """EfficientNet-B2 frame encoder followed by an LSTM over the sequence.

    Every frame ``x[:, :, idx]`` is encoded separately, the per-frame
    embeddings are stacked along dimension 1 and fed to the LSTM, and the
    flattened LSTM outputs of all ``SEQ_LEN`` steps go through two linear layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers  = num_layers

        self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True)
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        # One embedding per frame, stacked to (batch, SEQ_LEN, features)
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # Omitting the initial states makes the LSTM start from zeros on the input's
        # own device, instead of depending on the module-level `device`.
        out, _ = self.lstm(embeddings)
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

# 4M parameters
class Eff2_GRU(CustomModelBase):
    """Per-frame EfficientNet-B2 encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The GRU outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        # Unlike the sibling GRU models, this one enables dropout (0.5) inside the GRU.
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=0.5)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

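# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists EfficientNet-B2 alone at ~9.1M parameters.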
class Eff2_BiLSTM(CustomModelBase):
    """Per-frame EfficientNet-B2 encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The LSTM outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

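# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists EfficientNet-B2 alone at ~9.1M parameters.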
class Eff2_BiGRU(CustomModelBase):
    """Per-frame EfficientNet-B2 encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The GRU outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class MNV3S_LSTM(CustomModelBase):
    """Per-frame MobileNetV3-Small encoder, stacked LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The LSTM outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class MNV3S_GRU(CustomModelBase):
    """Per-frame MobileNetV3-Small encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The GRU outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class MNV3S_BiLSTM(CustomModelBase):
    """Per-frame MobileNetV3-Small encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The LSTM outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class MNV3S_BiGRU(CustomModelBase):
    """Per-frame MobileNetV3-Small encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its last classifier layer. The GRU outputs of
    all ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.classifier[-1].out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

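# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-18 alone at ~11.7M parameters.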
class R18_LSTM(CustomModelBase):
    """Per-frame ResNet-18 encoder, stacked LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

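# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-18 alone at ~11.7M parameters.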
class R18_GRU(CustomModelBase):
    """Per-frame ResNet-18 encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

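# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-18 alone at ~11.7M parameters.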
class R18_BiLSTM(CustomModelBase):
    """Per-frame ResNet-18 encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

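# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-18 alone at ~11.7M parameters.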
class R18_BiGRU(CustomModelBase):
    """Per-frame ResNet-18 encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

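# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-34 alone at ~21.8M parameters.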
class R34_LSTM(CustomModelBase):
    """Per-frame ResNet-34 encoder, stacked LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

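# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-34 alone at ~21.8M parameters.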
class R34_GRU(CustomModelBase):
    """Per-frame ResNet-34 encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

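# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-34 alone at ~21.8M parameters.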
class R34_BiLSTM(CustomModelBase):
    """Per-frame ResNet-34 encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

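# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-34 alone at ~21.8M parameters.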
class R34_BiGRU(CustomModelBase):
    """Per-frame ResNet-34 encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)
# ---

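# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-50 alone at ~25.6M parameters.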
class R50_LSTM(CustomModelBase):
    """Per-frame ResNet-50 encoder, stacked LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

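# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-50 alone at ~25.6M parameters.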
class R50_GRU(CustomModelBase):
    """Per-frame ResNet-50 encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

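# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-50 alone at ~25.6M parameters.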
class R50_BiLSTM(CustomModelBase):
    """Per-frame ResNet-50 encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

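# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-50 alone at ~25.6M parameters.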
class R50_BiGRU(CustomModelBase):
    """Per-frame ResNet-50 encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)
#-----------------------------#


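# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-101 alone at ~44.5M parameters.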
class R101_LSTM(CustomModelBase):
    """Per-frame ResNet-101 encoder, stacked LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

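# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-101 alone at ~44.5M parameters.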
class R101_GRU(CustomModelBase):
    """Per-frame ResNet-101 encoder, stacked GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

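# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-101 alone at ~44.5M parameters.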
class R101_BiLSTM(CustomModelBase):
    """Per-frame ResNet-101 encoder, (bi)directional LSTM, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The LSTM outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run the LSTM in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.lstm = nn.LSTM(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

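# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists ResNet-101 alone at ~44.5M parameters.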
class R101_BiGRU(CustomModelBase):
    """Per-frame ResNet-101 encoder, (bi)directional GRU, two-layer linear head.

    The ImageNet-pretrained backbone is used unmodified, so each frame
    embedding is the output of its ``fc`` layer. The GRU outputs of all
    ``SEQ_LEN`` time steps are flattened and mapped to ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU (per direction).
        num_layers: Number of stacked GRU layers.
        bidirectional: Run the GRU in both directions; doubles the head input width.
    """

    def __init__(self, num_classes, hidden_size=200, num_layers=2, bidirectional=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
        self.gru = nn.GRU(self.model.fc.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # The head consumes every time step's output from every direction, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


In [None]:
from vivit_modules import ViViT
# ViViT Small: configuration with heads=3, depth=4, dim=192 (argument
# meanings are defined by ViViT in vivit_modules).
# Note: unlike the model classes in this notebook, this name is bound to an
# already-constructed instance, so running the cell builds the model.
ViViT_Small = ViViT(heads=3, depth=4, dim=192)

# ViViT Large: configuration with heads=12, depth=24, dim=768. Also an
# instance constructed when the cell runs.
ViViT_Large = ViViT(heads=12, depth=24, dim=768)

## Define pretrained models

In [None]:
# Dropout value passed to the recurrent layer of the unidirectional
# *_pretrained GRU/LSTM models below.
DROPOUT = 0.0
# Default recurrent hidden size of the *_pretrained sequence models below.
HIDDEN_SIZE = 256
# Default number of outputs of the EfficientNetB1 classifier defined below.
PRETRAINED_NUM_CLASSES = len(class_list)  # default was len(class_list)
#--------------------------------
# EfficientNetB1 pretrained
#--------------------------------

class EfficientNetB1(CustomModelBase):
    """ImageNet-pretrained EfficientNet-B1 with a fresh linear classifier.

    The backbone's whole ``classifier`` module (not only its last layer) is
    replaced by one ``nn.Linear`` producing ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the new classifier.
    """

    def __init__(self, num_classes=PRETRAINED_NUM_CLASSES):
        super(EfficientNetB1, self).__init__()
        backbone = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.DEFAULT)
        # Width of the features that fed the original last classifier layer.
        feature_width = backbone.classifier[-1].in_features
        backbone.classifier = nn.Linear(feature_width, num_classes, bias=True)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and the new classifier."""
        scores = self.model(x)
        return scores


class Eff1_pretrained(EfficientNetB1):
    """
    EfficientNet-B1 feature extractor initialised from a saved checkpoint.

    The file at ``pretrained_model_path`` must contain a whole pickled model
    (its ``state_dict()`` is read) whose parameters match ``EfficientNetB1``
    built with its default ``num_classes``. After the weights are copied in,
    the classifier is replaced by ``nn.Identity`` so ``forward`` returns the
    features that used to feed it; their width is stored in ``out_features``.

    Args:
        num_classes: Accepted for signature compatibility but not used.
        pretrained_model_path: Path of the ``.pth`` file to load.
    """

    def __init__(
            self,
            num_classes=None,
            pretrained_model_path=f"{pretrained_models_folder}eff_b1.pth"
    ):
        super().__init__()

        # map_location="cpu" lets a checkpoint saved on a GPU load on a
        # CPU-only machine and avoids holding a second copy on the GPU;
        # load_state_dict copies the values into this module's own parameters.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        loaded_model = torch.load(pretrained_model_path, map_location="cpu")
        self.load_state_dict(loaded_model.state_dict())

        # The parent class installs a single nn.Linear as classifier, so its
        # input width is the feature size. Reading it here avoids building a
        # throwaway efficientnet_b1() just to look up one number.
        self.out_features = self.model.classifier.in_features

        # Remove the last layer: it is not useful when this model serves as a
        # feature extractor.
        self.model.classifier = nn.Identity()

    def forward(self, x):
        """Return the backbone features for ``x`` (classifier is an identity)."""
        return self.model(x)


class Eff1_GRU_pretrained(CustomModelBase):
    """Per-frame ``Eff1_pretrained`` feature extractor, stacked GRU, two-layer linear head.

    Each frame is embedded by ``Eff1_pretrained`` (width ``out_features``).
    The GRU outputs of all ``SEQ_LEN`` time steps are flattened and mapped to
    ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the GRU.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = Eff1_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the GRU output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h0: nn.GRU defaults to a zero state allocated on the
        # input's device, so this does not depend on the global `device`.
        out, _ = self.gru(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


class Eff1_LSTM_pretrained(CustomModelBase):
    """Per-frame ``Eff1_pretrained`` feature extractor, stacked LSTM, two-layer linear head.

    Each frame is embedded by ``Eff1_pretrained`` (width ``out_features``).
    The LSTM outputs of all ``SEQ_LEN`` time steps are flattened and mapped to
    ``num_classes``.

    Args:
        num_classes: Number of output scores.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.model = Eff1_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the LSTM output of every time step, flattened.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)`` for a batch of clips.

        Frame ``t`` is read as ``x[:, :, t]``, i.e. the frame axis is dimension 2
        (presumably ``x`` is ``(batch, channels, SEQ_LEN, H, W)`` — TODO confirm).
        """
        # Encode each frame, then stack on a new time axis: (batch, SEQ_LEN, features).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h0, c0): nn.LSTM defaults to a zero state allocated on
        # the input's device, so this does not depend on the global `device`.
        out, _ = self.lstm(embeddings)
        # Flatten the outputs of all time steps into one vector per sample.
        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


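# NOTE(review): the "4M parameters" figure looks copy-pasted; torchvision lists EfficientNet-B1 at ~7.8M parameters including its ImageNet head.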
class Eff1_BiLSTM_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff1_pretrained`` embeddings fed to a (bi)directional LSTM.

    The LSTM outputs of all ``SEQ_LEN`` steps (both directions when
    ``bidirectional`` is true) are flattened and passed through two linear
    layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2, bidirectional=True):
        super(Eff1_BiLSTM_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor; it must expose `out_features`.
        self.model = Eff1_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # Each time step contributes hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h_0, c_0): nn.LSTM defaults to correctly shaped zero states on
        # the input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.lstm(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class Eff1_BiGRU_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff1_pretrained`` embeddings fed to a (bi)directional GRU.

    The GRU outputs of all ``SEQ_LEN`` steps (both directions when
    ``bidirectional`` is true) are flattened and passed through two linear
    layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2, bidirectional=True):
        super(Eff1_BiGRU_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor; it must expose `out_features`.
        self.model = Eff1_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # Each time step contributes hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h_0: nn.GRU defaults to a correctly shaped zero state on the
        # input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.gru(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

#--------------------------------
# EfficientNetB2 pretrained
#--------------------------------

class EfficientNetB2(CustomModelBase):
    """torchvision EfficientNet-B2 (default pretrained weights) with a fresh linear head.

    The stock classifier is replaced by a single ``nn.Linear`` that maps the
    backbone features to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=PRETRAINED_NUM_CLASSES):
        super().__init__()
        backbone = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.DEFAULT)
        # Width of the features that enter the last layer of the stock classifier.
        feature_width = backbone.classifier[-1].in_features
        backbone.classifier = nn.Linear(feature_width, num_classes, bias=True)
        self.model = backbone

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)


class Eff2_pretrained(EfficientNetB2):
    """
    Load the pretrained model from the .pth file and remove the last layer, so that the model can be used as a feature extractor.

    ``num_classes`` is accepted for signature compatibility but is not used: the
    parent is always built with its default head so that the checkpoint's
    state dict fits. ``out_features`` holds the width of the returned embeddings.
    """

    def __init__(
            self,
            num_classes=None,
            pretrained_model_path=f"{pretrained_models_folder}eff_b2.pth"
    ):
        super(Eff2_pretrained, self).__init__()

        # The checkpoint is a fully pickled model object. map_location="cpu" lets a
        # file saved from a GPU session load on a CPU-only machine and avoids a
        # second copy on the accelerator; load_state_dict copies the values into
        # this module's own parameters wherever they live.
        # NOTE(review): unpickling executes code from the file — only load trusted
        # checkpoints. Recent PyTorch releases default torch.load to
        # weights_only=True, which rejects pickled model objects; confirm the
        # installed version or re-save the checkpoint as a state dict.
        loaded_model = torch.load(pretrained_model_path, map_location="cpu")
        self.load_state_dict(loaded_model.state_dict())

        # The parent replaced the classifier with one nn.Linear, so its in_features
        # is the embedding width; no need to build another EfficientNet to read it.
        self.out_features = self.model.classifier.in_features

        self.model.classifier = nn.Identity()  # Remove last layer. Final layer is not useful when using this model as feature extractor

    def forward(self, x):
        """Return the backbone embeddings for ``x`` (the classifier is an identity)."""
        return self.model(x)


class Eff2_GRU_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff2_pretrained`` embeddings fed to a stacked GRU.

    The GRU outputs of all ``SEQ_LEN`` steps are flattened and passed through
    two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(Eff2_GRU_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = Eff2_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the concatenated outputs of every time step.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h_0: nn.GRU defaults to a zero initial state on the input's
        # own device/dtype, so the global `device` is no longer needed.
        out, _ = self.gru(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


class Eff2_LSTM_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff2_pretrained`` embeddings fed to a stacked LSTM.

    The LSTM outputs of all ``SEQ_LEN`` steps are flattened and passed through
    two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(Eff2_LSTM_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = Eff2_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the concatenated outputs of every time step.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h_0, c_0): nn.LSTM defaults to zero states on the input's own
        # device/dtype. This replaces zeros built on the global `device` (whose
        # requires_grad_() was immediately undone by detach()).
        out, _ = self.lstm(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


# 4M parameters
class Eff2_BiLSTM_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff2_pretrained`` embeddings fed to a (bi)directional LSTM.

    The LSTM outputs of all ``SEQ_LEN`` steps (both directions when
    ``bidirectional`` is true) are flattened and passed through two linear
    layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2, bidirectional=True):
        super(Eff2_BiLSTM_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = Eff2_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # Each time step contributes hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h_0, c_0): nn.LSTM defaults to correctly shaped zero states on
        # the input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.lstm(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)

# 4M parameters
class Eff2_BiGRU_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``Eff2_pretrained`` embeddings fed to a (bi)directional GRU.

    The GRU outputs of all ``SEQ_LEN`` steps (both directions when
    ``bidirectional`` is true) are flattened and passed through two linear
    layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2, bidirectional=True):
        super(Eff2_BiGRU_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = Eff2_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, bidirectional=bidirectional)
        self.num_directions = 2 if bidirectional else 1
        # Each time step contributes hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN * self.num_directions, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h_0: nn.GRU defaults to a correctly shaped zero state on the
        # input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.gru(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)
#---------------------------------
# EfficientNetV2
#---------------------------------

class EfficientNetV2_S(nn.Module):
    """torchvision EfficientNetV2-S (default pretrained weights) with a fresh linear head.

    The stock classifier is replaced by a single ``nn.Linear`` that maps the
    backbone features to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=PRETRAINED_NUM_CLASSES):
        super().__init__()
        backbone = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
        # Width of the features that enter the last layer of the stock classifier.
        feature_width = backbone.classifier[-1].in_features
        backbone.classifier = nn.Linear(feature_width, num_classes, bias=True)
        self.model = backbone

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)


class EffV2_S_pretrained(EfficientNetV2_S):
    """
    Load the pretrained model from the .pth file and remove the last layer, so that the model can be used as a feature extractor.

    ``num_classes`` is accepted for signature compatibility but is not used: the
    parent is always built with its default head so that the checkpoint's
    state dict fits. ``out_features`` holds the width of the returned embeddings.
    """

    def __init__(
            self,
            num_classes=None,
            pretrained_model_path=f"{pretrained_models_folder}effv2_s.pth"
    ):
        super(EffV2_S_pretrained, self).__init__()

        # The checkpoint is a fully pickled model object. map_location="cpu" lets a
        # file saved from a GPU session load on a CPU-only machine and avoids a
        # second copy on the accelerator; load_state_dict copies the values into
        # this module's own parameters wherever they live.
        # NOTE(review): unpickling executes code from the file — only load trusted
        # checkpoints. Recent PyTorch releases default torch.load to
        # weights_only=True, which rejects pickled model objects; confirm the
        # installed version or re-save the checkpoint as a state dict.
        loaded_model = torch.load(pretrained_model_path, map_location="cpu")
        self.load_state_dict(loaded_model.state_dict())

        # The parent replaced the classifier with one nn.Linear, so its in_features
        # is the embedding width; no need to build another EfficientNet to read it.
        self.out_features = self.model.classifier.in_features

        self.model.classifier = nn.Identity()  # Remove last layer. Final layer is not useful when using this model as feature extractor

    def forward(self, x):
        """Return the backbone embeddings for ``x`` (the classifier is an identity)."""
        return self.model(x)


class EffV2_S_GRU_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``EffV2_S_pretrained`` embeddings fed to a stacked GRU.

    The GRU outputs of all ``SEQ_LEN`` steps are flattened and passed through
    two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(EffV2_S_GRU_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = EffV2_S_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the concatenated outputs of every time step.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h_0: nn.GRU defaults to a zero initial state on the input's
        # own device/dtype, so the global `device` is no longer needed.
        out, _ = self.gru(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


class EffV2_S_LSTM_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``EffV2_S_pretrained`` embeddings fed to a stacked LSTM.

    The LSTM outputs of all ``SEQ_LEN`` steps are flattened and passed through
    two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(EffV2_S_LSTM_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = EffV2_S_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT)
        # The head consumes the concatenated outputs of every time step.
        self.linear1 = nn.Linear(self.hidden_size * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h_0, c_0): nn.LSTM defaults to zero states on the input's own
        # device/dtype. This replaces zeros built on the global `device` (whose
        # requires_grad_() was immediately undone by detach()).
        out, _ = self.lstm(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


class EffV2_S_BiGRU_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``EffV2_S_pretrained`` embeddings fed to a bidirectional GRU.

    The GRU outputs of all ``SEQ_LEN`` steps (both directions) are flattened and
    passed through two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(EffV2_S_BiGRU_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = EffV2_S_pretrained()

        self.gru = nn.GRU(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT, bidirectional=True)
        # "* 2": each time step yields hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * 2 * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit h_0: nn.GRU defaults to a correctly shaped zero state on the
        # input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.gru(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


class EffV2_S_BiLSTM_pretrained(CustomModelBase):
    """Sequence classifier: per-frame ``EffV2_S_pretrained`` embeddings fed to a bidirectional LSTM.

    The LSTM outputs of all ``SEQ_LEN`` steps (both directions) are flattened and
    passed through two linear layers; the result is returned without a softmax.
    """

    def __init__(self, num_classes, hidden_size=HIDDEN_SIZE, num_layers=2):
        super(EffV2_S_BiLSTM_pretrained, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Frame-level feature extractor (its classification head is removed).
        self.model = EffV2_S_pretrained()

        self.lstm = nn.LSTM(self.model.out_features, self.hidden_size, self.num_layers, batch_first=True, dropout=DROPOUT, bidirectional=True)
        # "* 2": each time step yields hidden_size features per direction.
        self.linear1 = nn.Linear(self.hidden_size * 2 * SEQ_LEN, num_classes*2, bias=True)
        self.linear2 = nn.Linear(num_classes*2, num_classes, bias=True)

    def forward(self, x):
        """Classify a batch of sequences; frames are read as ``x[:, :, idx]`` (sequence axis is dim 2)."""
        # Embed every frame separately and stack along a new time axis (dim 1).
        embeddings = torch.stack([self.model(x[:, :, idx]) for idx in range(SEQ_LEN)], dim=1)

        # No explicit (h_0, c_0): nn.LSTM defaults to correctly shaped zero states on
        # the input's own device/dtype, so the global `device` is no longer needed.
        out, _ = self.lstm(embeddings)

        out = out.reshape(out.shape[0], -1)
        out = self.linear1(out)
        return self.linear2(out)


In [None]:
def get_model_by_name(name):
    """Return the model class registered under ``name``.

    The lookup is exact and case-sensitive (e.g. ``"ViViT_Small"``, not
    ``"ViViT_small"``).

    Raises:
        KeyError: if ``name`` is not registered; the message lists the valid names.
    """
    model_directory = {
        "R3D_18": R3D_18,
        "MC3_18": MC3_18,
        "r2plus1d_18": r2plus1d_18,
        "Eff_LSTM": Eff_LSTM,
        "Eff_GRU": Eff_GRU,
        "Eff_BiGRU": Eff_BiGRU,
        "Eff_BiLSTM": Eff_BiLSTM,
        "Eff2_LSTM": Eff2_LSTM,
        "Eff2_GRU": Eff2_GRU,
        "Eff2_BiLSTM": Eff2_BiLSTM,
        "Eff2_BiGRU": Eff2_BiGRU,
        "MNV3S_LSTM": MNV3S_LSTM,
        "MNV3S_GRU": MNV3S_GRU,
        "MNV3S_BiLSTM": MNV3S_BiLSTM,
        "MNV3S_BiGRU": MNV3S_BiGRU,
        "R18_LSTM": R18_LSTM,
        "R18_GRU": R18_GRU,
        "R18_BiLSTM": R18_BiLSTM,
        "R18_BiGRU": R18_BiGRU,
        "R34_LSTM": R34_LSTM,
        "R34_GRU": R34_GRU,
        "R34_BiLSTM": R34_BiLSTM,
        "R34_BiGRU": R34_BiGRU,
        "R50_LSTM": R50_LSTM,
        "R50_GRU": R50_GRU,
        "R50_BiLSTM": R50_BiLSTM,
        "R50_BiGRU": R50_BiGRU,
        "R101_LSTM": R101_LSTM,
        "R101_GRU": R101_GRU,
        "R101_BiLSTM": R101_BiLSTM,
        "R101_BiGRU": R101_BiGRU,
        "Eff1_GRU_pretrained": Eff1_GRU_pretrained,
        "Eff1_LSTM_pretrained": Eff1_LSTM_pretrained,
        "Eff1_BiGRU_pretrained": Eff1_BiGRU_pretrained,
        "Eff1_BiLSTM_pretrained": Eff1_BiLSTM_pretrained,
        "Eff2_GRU_pretrained": Eff2_GRU_pretrained,
        "Eff2_LSTM_pretrained": Eff2_LSTM_pretrained,
        "Eff2_BiGRU_pretrained": Eff2_BiGRU_pretrained,
        "Eff2_BiLSTM_pretrained": Eff2_BiLSTM_pretrained,
        "EffV2_S_GRU_pretrained": EffV2_S_GRU_pretrained,
        "EffV2_S_LSTM_pretrained": EffV2_S_LSTM_pretrained,
        "EffV2_S_BiGRU_pretrained": EffV2_S_BiGRU_pretrained,
        "EffV2_S_BiLSTM_pretrained": EffV2_S_BiLSTM_pretrained,
        "ViViT_Small": ViViT_Small,
        "ViViT_Large": ViViT_Large
    }
    try:
        return model_directory[name]
    except KeyError:
        # Same exception type as before, but say what went wrong and what is valid.
        valid_names = ", ".join(sorted(model_directory))
        raise KeyError(f"Unknown model name {name!r}; expected one of: {valid_names}") from None

def get_optimizer_by_name(optim):
    """Return the ``torch.optim`` optimizer class for ``optim``.

    The class is returned uninstantiated; the caller supplies the parameters
    and hyperparameters. Supported names: ``'SGD'``, ``'Adam'``, ``'RMSProp'``,
    ``'AdamW'``.

    Raises:
        ValueError: if ``optim`` is not a supported name (previously this fell
            through every branch and failed with ``UnboundLocalError``).
    """
    optimizers = {
        'SGD': torch.optim.SGD,
        'Adam': torch.optim.Adam,
        'RMSProp': torch.optim.RMSprop,
        'AdamW': torch.optim.AdamW,
    }
    try:
        return optimizers[optim]
    except KeyError:
        valid_names = ", ".join(optimizers)
        raise ValueError(f"Unknown optimizer {optim!r}; expected one of: {valid_names}") from None

In [None]:
from datetime import datetime
from hyperopt import hp, STATUS_OK, fmin, tpe, space_eval, Trials
from hyperopt.pyll import scope

def train_model(kwargs):
    """Build the model described by ``kwargs`` and train it with ``train_loop``.

    ``kwargs`` keys read here:
        - ``"model"``: name understood by ``get_model_by_name``
        - ``"optim"``: dict with ``"name"`` and ``"params"`` (``"lr"``, ``"weight_decay"``)
        - ``"epochs"`` (optional): falls back to ``EPOCHS``

    Returns the trained model followed by the train loss/accuracy and
    validation loss/accuracy histories, in that order.
    """
    print(kwargs)
    n_epochs = kwargs.get("epochs", EPOCHS)

    # Resolve the model class by name and instantiate it for the known classes.
    network_cls = get_model_by_name(kwargs.get("model"))
    network = network_cls(num_classes=len(class_list)).to(device)

    # Resolve the optimizer class; train_loop receives the class, not an instance.
    optim_cfg = kwargs.get("optim")
    hyperparams = optim_cfg["params"]
    print(hyperparams)
    optim_cls = get_optimizer_by_name(optim_cfg.get("name"))

    network, train_losses, train_accs, val_losses, val_accs = train_loop(
        network,
        optim_cls,
        n_epochs,
        train_loader,
        val_loader,
        initial_lr=hyperparams["lr"],
        weight_decay=hyperparams["weight_decay"],
        verbose=True,
        running_hyperopt=True,
        continue_training=False
    )

    return network, train_losses, train_accs, val_losses, val_accs

def train_model_hyperopt(kwargs):
    """Hyperopt objective: train with ``kwargs`` and report the mean validation loss.

    The returned ``"loss"`` is the average of the validation-loss history that
    ``train_model`` returns.
    """
    _, _, _, val_losses, _ = train_model(kwargs)
    mean_val_loss = np.mean(val_losses)
    return {"loss": mean_val_loss, "status": STATUS_OK}

def unpack_values(trial):
    """Flatten ``trial["misc"]["vals"]`` into a plain ``{name: value}`` dict.

    Each entry is a list: the first element of a non-empty list becomes the
    value, and empty lists are left out.
    """
    return {
        name: picked[0]
        for name, picked in trial["misc"]["vals"].items()
        if picked
    }

def export_hyperopt_log(trials):
    """Write one CSV row per hyperopt trial and return the table.

    Each trial's sampled values are mapped back to real search-space values with
    ``space_eval`` and stored with the trial's loss as ``val_loss``. The nested
    ``optim`` dict is flattened into its own columns. The file is written to
    ``../data/log/hyperopt_result_<timestamp>.csv``.

    Returns:
        The ``pandas.DataFrame`` that was written.
    """
    result_list = []
    for trial in trials.trials:
        trial_result = space_eval(search_space, unpack_values(trial))
        trial_result["val_loss"] = trial['result']['loss']
        result_list.append(trial_result)

    df_result = pd.DataFrame(result_list)
    # Expand the nested optimizer dict ("name", "params.lr", ...) into flat columns.
    optim_columns = pd.json_normalize(df_result.optim)
    df_result = pd.concat([df_result.drop("optim", axis=1), optim_columns], axis=1)

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = "../data/log"
    output_path = os.path.join(output_dir, f"hyperopt_result_{ts}.csv")

    # exist_ok=True already covers an existing directory; a separate exists()
    # check is redundant and race-prone.
    os.makedirs(output_dir, exist_ok=True)

    df_result.to_csv(output_path, index=False)
    print(f"Exported hyperopt log to {output_path}")
    return df_result


# Hyperopt search space. Every "model" entry must be an exact, case-sensitive key
# of the registry in get_model_by_name. The commented entries are experiment
# toggles: uncomment a line to add that option to the search.
search_space = {
    "epochs": scope.int(hp.choice("epochs", [7, 10, 13, 15, 17, 20, 22, 25, 27, 30, 33, 35])),
    "model":  hp.choice("model_name", [
        "R3D_18",
        "MC3_18",
        "r2plus1d_18",
        # "Eff_LSTM",
        "Eff_GRU",
        # "Eff_BiGRU",
        # "Eff_BiLSTM",
        # "Eff2_LSTM",
        "Eff2_GRU",
        # "Eff2_BiLSTM",
        # "Eff2_BiGRU",
        # "MNV3S_LSTM",
        "MNV3S_GRU",
        # "MNV3S_BiLSTM",
        # "MNV3S_BiGRU",
        # "R18_LSTM",
        # "R18_GRU",
        # "R18_BiLSTM",
        # "R18_BiGRU",
        # "R34_LSTM",
        # "R34_GRU",
        # "R34_BiLSTM",
        # "R34_BiGRU",
        # "R50_LSTM",
        "R50_GRU",
        # "R50_BiLSTM",
        # "R50_BiGRU",
        # "R101_LSTM",
        "R101_GRU",
        # "R101_BiLSTM",
        # "R101_BiGRU",
        # Fixed: these were "ViViT_small"/"ViViT_large", which are not registry
        # keys and made any trial that sampled them fail with KeyError.
        "ViViT_Small",
        "ViViT_Large"
    ]),
    "optim": hp.choice("optim",[
        # {
        #     "name":"SGD",
        #     "params": {
        #         "lr": hp.loguniform("lr-1", np.log(1e-6), np.log(1e-2)),
        #         "momentum": hp.uniform("momentum-1", 0.1, 0.95),
        #         "weight_decay": hp.loguniform("weight_decay-1", np.log(1e-7), np.log(1e-2))
        #     }
        # },
        # {
        #     "name":"RMSProp",
        #     "params": {
        #         "lr": hp.loguniform("lr-2", np.log(1e-6), np.log(1e-2)),
        #         "momentum": hp.uniform("momentum-2", 0.1, 0.95),
        #         "weight_decay": hp.loguniform("weight_decay-2", np.log(1e-7), np.log(1e-2))
        #     }
        # },
        {
            "name":"Adam",
            "params": {
                "lr": hp.choice("lr-3", [1e-3, 1e-4]),
                "weight_decay": hp.choice("weight_decay-3", [3.310305423548208e-05])
            }
        },
        # {
        #     "name":"AdamW",
        #     "params": {
        #         "lr": hp.loguniform("lr-4", np.log(1e-6), np.log(1e-4)),
        #         "weight_decay": hp.loguniform("weight_decay-4", np.log(1e-7), np.log(1e-2))
        #     }
        # },
    ])
}


In [None]:
# Run the hyperparameter search: minimise the objective returned by
# train_model_hyperopt (the mean validation loss, not the train loss) over search_space
trials = Trials()
# a single evaluation in DEV mode, 192 evaluations otherwise
max_evals = 1 if RUN_MODE == "DEV" else 192
best = fmin(train_model_hyperopt, search_space, algo=tpe.suggest, max_evals=max_evals, trials=trials, verbose=False)
# translate the raw fmin result back into actual search-space values
print(space_eval(search_space, best))
# export the per-trial results to CSV
export_hyperopt_log(trials)

# Train the model with the best combination

In [None]:
# Train the model with a fixed parameter set (hard-coded here: one epoch of
# Eff1_BiGRU_pretrained with Adam)
best_params = {'epochs': 1, 'model': 'Eff1_BiGRU_pretrained', 'optim': {'name': 'Adam', 'params': {
     'lr': 0.001, 'weight_decay': 3.310305423548208e-05}}}
# best_params = space_eval(search_space, best)  # uncomment to use the result of the hyperopt search instead
best_model, train_loss_history, train_acc_history, val_loss_history, val_acc_history = train_model(best_params)
# the whole model object is pickled (not only its state dict)
torch.save(best_model, "seq_images.pth")

# plot the training curves, then evaluate on the test loader
plot_model_stats(type(best_model).__name__, train_loss_history, train_acc_history, val_loss_history, val_acc_history)
evaluate_model(best_model, test_loader, verbose=True, eps=1e-10)

In [None]:
# Reload a saved model object and evaluate it on the test loader.
# NOTE(review): this reads model_save_path + "model.pth", while the training cell
# saves to "seq_images.pth" — confirm which checkpoint is meant to be evaluated.
best_model = torch.load(model_save_path + "model.pth")
evaluate_model(best_model, test_loader, verbose=True, eps=1e-10)

In [None]:
import random

# Use the non-interactive Agg backend: the figures created below are written to files with savefig
matplotlib.use('Agg')

def get_all_conv_layers(model, modules_list=None, conv_layers=None, depth=0, grad_cam=False, feature_map=False):
    """
    Get the ``torch.nn.Conv2d`` layers of a given model.

    Args:
        model: module whose ``modules()`` are scanned when ``modules_list`` is None.
        modules_list: optional iterable of modules to scan instead.
        conv_layers: optional list to append to; a fresh list is created when
            omitted (the previous mutable default ``[]`` was shared between
            calls, so every call kept adding to the results of earlier ones).
        depth: nesting depth of ``modules_list``; only used in feature-map mode.
        grad_cam: collect every Conv2d in ``modules_list`` (the caller uses the
            last one for Grad-CAM).
        feature_map: collect Conv2d layers only when ``depth > 0``.

    Returns:
        ``conv_layers`` with the matching layers appended in traversal order.
        Nothing is collected when both flags or neither flag is set.

    NOTE(review): in feature-map mode a top-level call (``depth=0``) collects
    nothing, exactly as before. That looks unintended — confirm the desired
    depth semantics.
    """
    if conv_layers is None:
        conv_layers = []
    if modules_list is None:
        # modules() already walks the whole hierarchy, nested containers included.
        modules_list = list(model.modules())

    # The former recursive call into nn.Sequential containers did not forward the
    # mode flags, so it never collected anything; it is dropped with no change in results.
    if grad_cam and (not feature_map):
        collect = True
    elif feature_map and (not grad_cam):
        collect = depth > 0
    else:
        collect = False

    if collect:
        conv_layers.extend(layer for layer in modules_list if isinstance(layer, torch.nn.Conv2d))

    return conv_layers


def visualise_feature_maps(feature_map, feature_map_name):
    """
    Save a grid image with every feature map of the first sample of a layer output.

    Args:
        feature_map: tensor indexed as ``[sample, channel, ...]``; only sample 0 is drawn.
        feature_map_name: path of the image file to write.

    Returns:
        List with the array of each channel of sample 0, in channel order.
    """
    feature_map = feature_map.cpu().numpy()

    # Get the number of feature maps
    num_feature_maps = feature_map.shape[1]

    # Calculate the number of rows and columns for the plot (ceil division)
    num_cols = 8
    num_rows = num_feature_maps // num_cols + int(num_feature_maps % num_cols > 0)

    # squeeze=False always returns a 2-D axes array. Without it a single-row grid
    # (8 or fewer channels) came back 1-D and the 2-D indexing raised IndexError.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 2, num_rows * 2), squeeze=False)
    flat_axes = axes.ravel()  # row-major, same order as (i // num_cols, i % num_cols)

    plots = []
    for ax, channel in zip(flat_axes, feature_map[0]):
        ax.imshow(channel, cmap="viridis")
        ax.axis("off")
        plots.append(channel)

    # Hide empty subplots
    for ax in flat_axes[num_feature_maps:]:
        ax.axis("off")

    plt.savefig(feature_map_name)
    plt.close('all')

    # return the plotted channel arrays
    return plots


def normalize_feature_map(feature_map):
    """Min-max scale ``feature_map`` to the range [0, 1].

    A constant map (max == min) has no range to scale by; it is returned as
    all zeros instead of the NaNs a division by zero would produce.
    """
    min_val, max_val = np.min(feature_map), np.max(feature_map)
    shifted = feature_map - min_val
    value_range = max_val - min_val
    if value_range == 0:
        return np.zeros_like(shifted, dtype=float)
    return shifted / value_range


def visualise_gradcam(images_numpy, cam, ith_image, seq_idx, i, grad_img_i_folder, max_gradcam_images=3, show_gradcam=True, class_name=""):
    """
    Save a three-panel Grad-CAM figure for one image.

    The panels are, left to right: the input image, the raw CAM heat map, and
    the heat map overlaid on the image. Nothing is done when ``show_gradcam``
    is false or ``i`` has reached ``max_gradcam_images``.
    """
    if not show_gradcam or not (i < max_gradcam_images):
        return

    print(f"Extracting grad cam for image {i + 1}_{seq_idx}/{max_gradcam_images}")

    def draw_panel(position, picture):
        # one cell of the 1x3 grid, titled with the class name below the image
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.gca().set_title(class_name, fontsize=40, pad=20, y=-0.2)
        plt.axis('off')

    plt.figure(figsize=(30, 10))
    draw_panel(1, images_numpy)

    # targets=None is passed through to the CAM object; keep the map of the first sample
    heat_map = cam(input_tensor=ith_image, targets=None)[0, :]
    draw_panel(2, heat_map)

    overlay = show_cam_on_image(images_numpy, heat_map, use_rgb=True, image_weight=0.8)
    draw_panel(3, overlay)

    plt.savefig(grad_img_i_folder + f"image_{seq_idx + 1}_{class_name}.png")
    plt.close('all')


def show_all_feature_maps(target_layers, model, ith_image, i, seq_index, layers_folder):
    """
    Save the feature maps of every target layer of a model for one image.

    For each layer a temporary forward hook captures its output during one
    forward pass of ``model(ith_image)``. Layers that produce no output, or
    whose output has a last dimension of size <= 1, are skipped.

    Returns:
        (list_of_plots, valid_target_layers): ``list_of_plots`` maps the layer's
        index in ``target_layers`` to the arrays returned by
        ``visualise_feature_maps``; ``valid_target_layers`` is a list of
        ``(layer, index)`` tuples for the layers that were saved.
    """
    list_of_plots = {}
    valid_target_layers = []

    saved_count = 0  # numbering of the saved layer images (valid layers only)
    for j, layer in enumerate(target_layers):
        feature_maps = []

        def hook_fn(module, input, output):
            feature_maps.append(output.detach())

        # Remove only the hook registered here, and do it even if the forward pass
        # fails. Clearing layer._forward_hooks (private) would also wipe hooks that
        # other code attached to the same layer.
        handle = layer.register_forward_hook(hook_fn)
        try:
            model(ith_image)
        finally:
            handle.remove()

        # skip layers that were not reached or that have no spatial extent to plot
        if len(feature_maps) == 0 or feature_maps[0].shape[-1] <= 1:
            continue

        print(f"Extracting feature maps for seq {i + 1}, image {seq_index}, from layer {j + 1}/{len(target_layers)}")

        # save feature maps using this function
        plots = visualise_feature_maps(feature_maps[0], f"{layers_folder}layer_{saved_count + 1}.png")
        saved_count += 1

        # update the list of valid target layers and the list of plots
        list_of_plots[j] = plots
        valid_target_layers.append((layer, j))  # save the layer and its index

    plt.close('all')

    return list_of_plots, valid_target_layers


def create_chart(num_single_chart_layers, num_single_chart_conv_imgs, valid_target_layers, list_of_plots, layers_folder_prev, i):
    """
    Create a chart with num_single_chart_layers layers and num_single_chart_conv_imgs feature maps per layer.

    Args:
        num_single_chart_layers: maximum number of layers (rows); ``None`` keeps all.
        num_single_chart_conv_imgs: number of feature maps drawn per layer, picked
            at random (repeats possible) and shown in index order. ``None`` draws
            every feature map of every layer; this can produce a very large figure.
        valid_target_layers: list of ``(layer, layer_index)`` tuples.
        list_of_plots: mapping from ``layer_index`` to that layer's feature-map arrays.
        layers_folder_prev: path prefix; the chart is saved as ``<prefix>Chart.png``.
        i: unused; kept for call compatibility.
    """
    if not valid_target_layers:
        # An empty grid cannot be drawn; skip instead of failing inside matplotlib.
        print("No valid layers to chart, skipping Chart.png")
        return

    # select the smallest of chosen number of rows in the chart and number of valid target layers
    if num_single_chart_layers is None:
        num_single_chart_layers = len(valid_target_layers)
    num_single_chart_layers = min(num_single_chart_layers, len(valid_target_layers))

    # pick num_single_chart_layers random layers without changing the order
    remove_layers_numbers = random.sample(valid_target_layers, max(len(valid_target_layers) - num_single_chart_layers, 0))
    valid_target_layers = [layer for layer in valid_target_layers if layer not in remove_layers_numbers]

    # Fix the number of columns before the figure is created. The old code computed
    # 10 * None (TypeError) when no count was given, and overwrote the parameter
    # with the first layer's size, which then applied random sampling to later layers.
    if num_single_chart_conv_imgs is None:
        num_cols = max(len(list_of_plots[layer_index]) for _, layer_index in valid_target_layers)
    else:
        num_cols = num_single_chart_conv_imgs
    num_rows = len(valid_target_layers)

    # put all layers in one image
    print(f"Merging all layers in one image...")
    plt.figure(figsize=(10 * num_cols, 10 * num_single_chart_layers))

    # merge all feature map plots in one image to form a chart
    for row, (layer, layer_index) in enumerate(valid_target_layers):
        plots = list_of_plots[layer_index]

        # pick num_single_chart_conv_imgs random features per layer without changing the order
        if num_single_chart_conv_imgs is not None:
            plots_numbers = sorted(random.randint(0, len(plots) - 1) for _ in range(num_single_chart_conv_imgs))
            plots = [plots[number] for number in plots_numbers]

        # all plots in the selected layer
        for col, plot in enumerate(plots):
            plot = normalize_feature_map(plot)

            # the explicit grid position keeps rows aligned when layers differ in size
            subplot = plt.subplot(num_rows, num_cols, row * num_cols + col + 1)  # (*nrows*, *ncols*, *index*)

            plt.imshow(plot, cmap="viridis")
            plt.axis('off')

            # Add a "Layer_<n>" label to the first subplot of each row
            if col == 0:
                label_axis = subplot.twinx()
                label_axis.set_ylabel(f"Layer_{layer_index + 1}", fontsize=40, rotation=0, labelpad=160)
                label_axis.yaxis.set_label_position("left")
                label_axis.yaxis.tick_left()
                label_axis.yaxis.set_ticks([])
                label_axis.xaxis.set_ticks([])

    plt.savefig(layers_folder_prev + f"Chart.png")
    plt.close('all')


def get_gradcam_feature_maps(
    model,
    test_loader,
    show_gradcam=False,
    max_gradcam_images=5,
    show_feature_map=False,
    max_feature_map_images=3,
    max_feature_map_classes=2,
    num_feature_map_seqs_per_class=2,
    num_single_chart_layers=None,
    num_single_chart_conv_imgs=None,
    class_list=class_list,
):
    """
    Compute GradCAM and feature maps for a given model and a given test_loader

    The model is switched to eval mode. When `show_gradcam` is set, the "gradcams"
    folder is wiped and recreated; when `show_feature_map` is set, the same happens
    to the "feature_maps" folder. Both are relative to the current working directory.

    Args:
        model: model to visualise; GradCAM targets the last layer returned by
            `get_all_conv_layers`.
        test_loader: iterable yielding `(images, targets)` pairs. A batch size of 1
            is assumed (only `images[0]` / `targets[0]` are used) and axis 2 of
            `images` is treated as the sequence (frame) axis.
        show_gradcam: save GradCAMs for loader items whose index is below
            `max_gradcam_images`.
        max_gradcam_images: number of leading loader items that get GradCAMs. When
            `show_gradcam` is set, the loop also stops once this index is reached.
        show_feature_map: save per-layer feature maps and a summary chart.
        max_feature_map_images: number of consecutive frames visualised per
            selected sequence, starting at a random offset.
        max_feature_map_classes: number of distinct classes to collect sequences for.
        num_feature_map_seqs_per_class: number of sequences collected per class.
        num_single_chart_layers: forwarded to `create_chart`; defaults to the number
            of layers returned by `get_all_conv_layers`.
        num_single_chart_conv_imgs: forwarded to `create_chart` unchanged.
        class_list: indexable mapping a target index to its class name.

    Returns:
        None. All results are written to disk by the helper functions.
    """

    model.eval()

    # Get all conv layers from the given model and get the last layer to visualize in GradCAM
    layers = get_all_conv_layers(model, feature_map=show_feature_map, grad_cam=show_gradcam)
    target_layers = layers.copy()
    layer = layers[-1]

    if num_single_chart_layers is None:
        num_single_chart_layers = len(target_layers)

    # Request CUDA only when it is available so the function does not fail on CPU-only hosts
    cam = GradCAM(model=model, target_layers=[layer], use_cuda=torch.cuda.is_available())

    # start from clean output folders
    if show_gradcam:
        shutil.rmtree("gradcams", ignore_errors=True)
        os.makedirs("gradcams")

    if show_feature_map:
        shutil.rmtree("feature_maps", ignore_errors=True)
        os.makedirs("feature_maps")

    # class name -> loader indices of the sequences selected for feature maps
    class_seq_pairs = {}
    for i, (images, targets) in enumerate(tqdm(test_loader, position=0, leave=True)):
        # GradCAM-only run: nothing is saved past the limit, so do not run the model on this sequence
        if show_gradcam and not show_feature_map and i >= max_gradcam_images:
            break

        # convert torch target to numpy and get the class name
        targets = targets.numpy()
        class_name = class_list[targets[0]]
        print("Class name:", class_name)

        if show_feature_map:
            # number of classes that already have all the sequences they need
            max_reached_cnt = sum(
                len(seqs) >= num_feature_map_seqs_per_class for seqs in class_seq_pairs.values()
            )

            # stop execution when the maximum number of classes needed is reached and all classes are filled
            if max_reached_cnt >= max_feature_map_classes:
                print(f"Max number of classes reached and filled: {max_feature_map_classes}")
                break

            # skip iteration when the maximum number of classes needed is reached, but current classes are not yet filled
            if (len(class_seq_pairs) >= max_feature_map_classes) and (class_name not in class_seq_pairs):
                continue

            # assign the class name and sequence number used within it
            if class_name not in class_seq_pairs:
                class_seq_pairs[class_name] = [i]
            else:
                # reached max number of sequences needed for this class
                if len(class_seq_pairs[class_name]) >= num_feature_map_seqs_per_class:
                    print(f"Max number of sequences reached for class {class_name}: {num_feature_map_seqs_per_class}")
                    continue
                class_seq_pairs[class_name].append(i)

        # for grad cams create a separate folder for each sequence
        if show_gradcam and i < max_gradcam_images:
            grad_img_i_folder = f"gradcams{os.sep}seq_{i + 1}{os.sep}"
            shutil.rmtree(grad_img_i_folder, ignore_errors=True)
            os.makedirs(grad_img_i_folder)

        print("images shape", images.shape)

        images = images.to(device, dtype=torch.float)
        num_frames = images.shape[2]

        # we assume that batch size is 1, since it is designed to run on test loader
        # frame axis first, e.g. ([3, 32, 128, 128]) to ([32, 3, 128, 128]); done once per sequence
        frames = images[0].permute(1, 0, 2, 3)

        # random first frame of the consecutive window used for feature maps;
        # clamped at 0 so short sequences do not make randint raise ValueError
        random_seq_idx = random.randint(0, max(num_frames - 1 - max_feature_map_images, 0))
        first_seq_idx = random_seq_idx if show_feature_map else 0

        num_feature_map_image_cnt = 0
        for seq_idx in range(first_seq_idx, num_frames):
            # clear memory
            gc.collect()
            torch.cuda.empty_cache()

            # Single-frame batch; clone so downstream code cannot modify the original sequence
            ith_image = frames[seq_idx:seq_idx + 1].clone()

            # NOTE(review): the prediction itself is not used below. The forward pass is kept
            # in case hooks registered on the model rely on it - confirm and drop if not.
            model(ith_image)

            # show GradCAMs for the max given images
            if show_gradcam and (i < max_gradcam_images):
                # channels-last numpy copy of the frame for plotting
                images_numpy = np.squeeze(np.transpose(ith_image.cpu().numpy(), (0, 2, 3, 1)))

                visualise_gradcam(
                    images_numpy=images_numpy,
                    cam=cam,
                    ith_image=ith_image,
                    seq_idx=seq_idx,
                    i=i,
                    max_gradcam_images=max_gradcam_images,
                    show_gradcam=show_gradcam,
                    grad_img_i_folder=grad_img_i_folder,
                    class_name=class_name,
                )

            # show feature maps for the max given images
            if show_feature_map:
                # create folder for each image/sequence and delete the previous one.
                # The image folder is built first so a class name containing "layers"
                # cannot be mangled by string replacement.
                layers_folder_prev = (
                    f"feature_maps{os.sep}class_{class_name}{os.sep}"
                    f"seq_{len(class_seq_pairs[class_name])}{os.sep}image_{seq_idx + 1}{os.sep}"
                )
                layers_folder = f"{layers_folder_prev}layers{os.sep}"
                shutil.rmtree(layers_folder_prev, ignore_errors=True)
                os.makedirs(layers_folder, exist_ok=True)

                # get feature maps for all layers
                list_of_plots, valid_target_layers = show_all_feature_maps(
                    target_layers=target_layers,
                    model=model,
                    ith_image=ith_image,
                    i=i,
                    seq_index=seq_idx + 1,
                    layers_folder=layers_folder
                )

                # create chart for all layers of the current image
                create_chart(
                    num_single_chart_layers=num_single_chart_layers,
                    num_single_chart_conv_imgs=num_single_chart_conv_imgs,
                    valid_target_layers=valid_target_layers,
                    list_of_plots=list_of_plots,
                    layers_folder_prev=layers_folder_prev,
                    i=i
                )

                num_feature_map_image_cnt += 1

                print(f"\nFeature maps and chart for sequence {i + 1}, image {seq_idx + 1} saved successfully!\n\n")
                if num_feature_map_image_cnt >= max_feature_map_images:
                    print(f"Max number of feature map images reached: {max_feature_map_images}")
                    break

        # stop after max images (reached only when feature maps are also enabled;
        # the GradCAM-only case already stopped at the top of the loop)
        if show_gradcam and i >= max_gradcam_images:
            break

    print("\nMaximum selected sequences completed!\n")