# Installations and imports


In [None]:
!pip install lightning wandb

In [None]:
import os
import shutil
import csv

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import wandb

import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

import torchmetrics

from torchvision import transforms, models

from torchsummary import summary

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)

Set random seed


In [None]:
# Global seed for Python, NumPy and torch RNGs; workers=True also asks Lightning
# to seed DataLoader worker processes.
RANDOM_SEED = 42
pl.seed_everything(RANDOM_SEED, workers=True)

Login to wandb


In [None]:
# Authenticate with Weights & Biases (presumably prompts for an API key if none is configured).
wandb.login()

# Dataset


In [None]:
# Uncomment on Colab to mount Google Drive before reading data from it.
# from google.colab import drive

# drive.mount('/content/drive')

# CHANGE PATH TO CURRENT DIRECTORY
# Keep the trailing "/": later paths are built as DRIVE_DIR + "data/...", etc.
DRIVE_DIR = "/content/drive/MyDrive/Colab Notebooks/Computer Vision/Assignment 2/"

In [None]:
# Emotion names indexed by the integer class label (LABELS[i] is class i).
LABELS = ["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"]
NUM_LABELS = len(LABELS)

# Height and width of every FER2013 image, in pixels.
IM_H, IM_W = 48, 48

Load the data and split it into train/validation/test sets (CSV usage `Training` → train, `PrivateTest` → validation, `PublicTest` → test)


In [None]:
data_path = DRIVE_DIR + "data/fer2013.csv"

# Map each value of the CSV "Usage" column to the split it feeds.
# NOTE: "PrivateTest" rows are used for validation and "PublicTest" rows for testing.
_USAGE_TO_SPLIT = {"Training": "train", "PrivateTest": "val", "PublicTest": "test"}

_split_images = {split: [] for split in _USAGE_TO_SPLIT.values()}
_split_labels = {split: [] for split in _USAGE_TO_SPLIT.values()}

# newline="" is what the csv module documents for files passed to csv.reader.
with open(data_path, "r", newline="") as file:
    csv_reader = csv.reader(file)
    next(csv_reader)  # Ignore header

    for emotion, pixels, usage in csv_reader:
        split = _USAGE_TO_SPLIT.get(usage)
        if split is None:
            continue  # Rows with any other usage value are ignored
        _split_images[split].append([int(p) for p in pixels.split()])
        _split_labels[split].append(int(emotion))


def _to_arrays(images, labels):
    """Convert flat pixel lists into (N, IM_H, IM_W, 1) uint8 images and (N,) uint8 labels."""
    images = np.array(images, dtype=np.uint8).reshape((len(images), IM_H, IM_W, 1))
    return images, np.array(labels, dtype=np.uint8)


# Training
train_images, train_labels = _to_arrays(_split_images["train"], _split_labels["train"])

# Validation
val_images, val_labels = _to_arrays(_split_images["val"], _split_labels["val"])

# Test
test_images, test_labels = _to_arrays(_split_images["test"], _split_labels["test"])

Dataset distribution


In [None]:
def dataset_distribution(labels, verbose=True):
    """Return the fraction of samples belonging to each class in ``LABELS``.

    Args:
        labels: Array of integer class indices (indices into ``LABELS``).
        verbose: If True, print the total count and a per-class breakdown.

    Returns:
        A list with one ratio per entry of ``LABELS``, in ``LABELS`` order.
        For an empty ``labels`` array every ratio is 0.0.
    """
    labels = np.asarray(labels)
    # .size counts every element; len() only counts the first axis, which is
    # wrong for arrays with a leading singleton axis (e.g. train_labels[None]).
    labels_len = labels.size
    if verbose:
        print(f"Total images: {labels_len}")

    ratios = []
    for i, label in enumerate(LABELS):
        n = int(np.count_nonzero(labels == i))
        # Guard against ZeroDivisionError on an empty label array.
        r = n / labels_len if labels_len else 0.0
        ratios.append(r)

        if verbose:
            print(f"- {label}: {n} ({r*100:0.2f}%)")

    return ratios

In [None]:
# Print the class distribution of every split, separated by blank lines.
_splits_to_report = (
    ("Train", train_labels),
    ("Validation", val_labels),
    ("Test", test_labels),
)
for _report_pos, (_report_name, _report_labels) in enumerate(_splits_to_report):
    if _report_pos:
        print()
    print(f"------ {_report_name} Dataset ------")
    _ = dataset_distribution(_report_labels)

Visualize the images


In [None]:
def plot_image(image, label, save_as=None):
    """Show one image in grayscale, titled with its emotion name.

    Args:
        image: Image array passed straight to ``plt.imshow``.
        label: Integer class index into ``LABELS``.
        save_as: Optional output path; when truthy, the figure is also saved
            there (tight bounding box) before being shown.
    """
    plt.imshow(image, cmap="gray")
    plt.title(LABELS[label])
    plt.axis("off")

    if not save_as:
        plt.show()
        return

    plt.savefig(save_as, bbox_inches="tight")
    plt.show()

In [None]:
# Choose the first test image of each label.
# Indices are stored as np.intp: the previous uint8 array silently wrapped any
# index above 255, which can select the wrong image (the test split is far
# larger than 256 rows).
ex_indices = np.array(
    [np.flatnonzero(test_labels == i)[0] for i in range(NUM_LABELS)], dtype=np.intp
)

# Collect images and labels
ex_images = test_images[ex_indices]
ex_labels = test_labels[ex_indices]

for image, label in zip(ex_images, ex_labels):
    plot_image(image, label)
    print()

## Defining the Dataset class


In [None]:
def grayscale_to_rgb(x):
    """Repeat a single-channel (1, H, W) image tensor into three channels (3, H, W).

    A named ``def`` instead of an assigned lambda (PEP 8, E731): it has a real
    name in tracebacks and can be pickled by reference, which lambdas cannot.
    """
    return x.repeat(3, 1, 1)


# Per-channel normalization statistics (the standard torchvision ImageNet values).
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

In [None]:
class FER_Dataset_KFold(Dataset):
    """FER2013 dataset view over the module-level train/val/test arrays.

    Args:
        mode: One of ``"train"``, ``"val"`` or ``"test"``.
        indices: Optional index array selecting a subset of rows. For
            ``"train"`` and ``"val"`` the indices refer to the *training*
            arrays (this is how K-fold splits are carved out of the training
            set); for ``"test"`` they refer to the test arrays. With ``None``,
            ``"val"`` uses the separate validation split.
        transform: Callable applied to each stored image. Defaults to
            ToTensor -> grayscale-to-RGB -> Normalize(NORM_MEAN, NORM_STD).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _MODES = ("train", "val", "test")

    def __init__(self, mode, indices=None, transform=None):
        if mode not in self._MODES:
            # Previously an unknown mode left self.dataset unset and only
            # failed later with an AttributeError.
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

        # Pick the source arrays for this mode
        if mode == "train":
            images, labels = train_images, train_labels
        elif mode == "val":
            # K-fold validation folds are taken from the training arrays;
            # only the index-free case uses the separate validation split.
            if indices is None:
                images, labels = val_images, val_labels
            else:
                images, labels = train_images, train_labels
        else:
            images, labels = test_images, test_labels

        # Optionally restrict to a subset of rows
        if indices is None:
            self.dataset, self.labels = images, labels
        else:
            self.dataset, self.labels = images[indices], labels[indices]

        # Transform
        if transform is None:
            transform = transforms.Compose(
                [
                    # Convert to RGB Tensor
                    transforms.ToTensor(),
                    transforms.Lambda(grayscale_to_rgb),
                    # Normalization
                    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
                ]
            )
        self.transform = transform

    def __getitem__(self, idx):
        image = self.transform(self.dataset[idx])
        # Return a Python int so the default collate builds an int64 tensor,
        # the standard dtype for class-index targets of nn.CrossEntropyLoss.
        label = int(self.labels[idx])

        return image, label

    def __len__(self):
        return len(self.dataset)

    def weights(self):
        """Return one inverse-class-frequency weight per sample (e.g. for WeightedRandomSampler)."""
        dist = dataset_distribution(self.labels, verbose=False)
        label_weights = 1 / np.array(dist)
        return label_weights[self.labels]

# Models


## Defining the Lightning module


In [None]:
NUM_WORKERS = 2  # Default number of DataLoader worker processes

_SEP_LEN = 80  # Width of the separator lines printed around results

In [None]:
class CNN_Classifier_KFold(pl.LightningModule):
    """Lightning module that trains, validates and tests a FER2013 classifier.

    The wrapped ``model`` is expected to map a (B, 3, IM_H, IM_W) batch to
    (B, NUM_LABELS) logits. Dataloaders are built from the module-level
    FER2013 arrays via ``FER_Dataset_KFold``.

    Args:
        model: Base ``nn.Module`` producing class logits.
        train_indices: Indices into the training arrays used for training
            (one K-fold split), or ``None`` to use the whole training split.
        val_indices: Indices into the training arrays used for validation, or
            ``None`` to validate on the separate validation split.
        lr: Learning rate for AdamW.
        batch_size: Batch size of every dataloader.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(
        self,
        model,
        train_indices=None,
        val_indices=None,
        lr=1e-4,
        batch_size=128,
        num_workers=NUM_WORKERS,
    ):
        super().__init__()
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers

        # K-fold cross-validation split
        self.train_indices = train_indices
        self.val_indices = val_indices

        # Loss: Cross-Entropy weighted by the inverse class frequency of the
        # training labels. Index only when a split is given, because
        # train_labels[None] would add an axis instead of selecting all rows.
        loss_labels = (
            train_labels if train_indices is None else train_labels[train_indices]
        )
        weights = 1 / torch.FloatTensor(
            dataset_distribution(loss_labels, verbose=False)
        )
        self.loss = nn.CrossEntropyLoss(weights)

        # Accuracy (shared by the train/val/test steps)
        self.accuracy = torchmetrics.Accuracy("multiclass", num_classes=NUM_LABELS)

        # Save hyper-parameters; the base model is excluded and must be passed
        # again when loading from a checkpoint.
        self.save_hyperparameters(ignore=["model"])

        # Test step labels and predictions, gathered for on_test_epoch_end
        self.test_step_labels = []
        self.test_step_preds = []

    def forward(self, x):
        return self.model(x)

    def _get_preds_loss_acc(self, images, labels):
        """Return (predictions, loss, accuracy) for one batch.

        Predictions are the argmax of the logits over the class axis (dim=1).
        """
        # Forward pass
        logits = self(images)

        # Softmax is monotonic within each row, so argmax of the logits equals
        # argmax of softmax(dim=1). The previous softmax over dim=0 normalized
        # across the batch and changed the per-sample argmax (a batch of one
        # always predicted class 0).
        preds = torch.argmax(logits, dim=1)

        # Loss
        loss = self.loss(logits, labels)

        # Accuracy
        acc = self.accuracy(preds, labels)

        return preds, loss, acc

    def _log_epoch_metrics(self, metrics, batch_size):
        """Log epoch-level averages of ``metrics`` to the progress bar and logger."""
        self.log_dict(
            metrics,
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
            prog_bar=True,
            logger=True,
        )

    def configure_optimizers(self):
        # Optimizer
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-3)

        # Scheduler: divide the LR by 10 after 10 epochs without val_loss improvement
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=10, min_lr=1e-7, verbose=True
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

    # Training

    def train_dataloader(self):
        # Transforms
        train_transform = transforms.Compose(
            [
                # Convert to RGB Tensor
                transforms.ToTensor(),
                transforms.Lambda(grayscale_to_rgb),
                # Data augmentation
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=1, contrast=1, saturation=1),
                # NOTE: the sharpness factor is drawn once, when this dataloader
                # is built, and reused for every image afterwards.
                transforms.RandomAdjustSharpness(np.random.uniform(1, 6)),
                # Normalization
                transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
            ]
        )

        # Dataset
        train_dataset = FER_Dataset_KFold("train", self.train_indices, train_transform)

        # Weighted sampler
        # weights = torch.FloatTensor(train_dataset.weights())
        # sampler = WeightedRandomSampler(weights, len(train_dataset))

        # Dataloader
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            # sampler=sampler,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return train_dataloader

    def training_step(self, batch, _):
        images, labels = batch
        _, loss, acc = self._get_preds_loss_acc(images, labels)

        # Use the actual batch length so a smaller final batch is weighted
        # correctly in the epoch average.
        self._log_epoch_metrics({"train_loss": loss, "train_acc": acc}, len(labels))

        return loss

    # Validation

    def val_dataloader(self):
        # Dataset
        val_dataset = FER_Dataset_KFold("val", self.val_indices)

        # Dataloader
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return val_dataloader

    def validation_step(self, batch, _):
        images, labels = batch
        _, loss, acc = self._get_preds_loss_acc(images, labels)

        self._log_epoch_metrics({"val_loss": loss, "val_acc": acc}, len(labels))

    # Testing

    def test_dataloader(self):
        # Dataset
        test_dataset = FER_Dataset_KFold("test")

        # Dataloader
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return test_dataloader

    def test_step(self, batch, _):
        images, labels = batch
        preds, loss, acc = self._get_preds_loss_acc(images, labels)

        # Keep labels/predictions for the per-class report in on_test_epoch_end
        self.test_step_labels.append(labels)
        self.test_step_preds.append(preds)

        self._log_epoch_metrics({"test_loss": loss, "test_acc": acc}, len(labels))

    def on_test_epoch_end(self):
        labels = torch.hstack(self.test_step_labels).cpu()
        preds = torch.hstack(self.test_step_preds).cpu()

        # Always report every class in LABELS order, so the per-class arrays
        # match the NUM_LABELS bar positions and the LABELS-indexed confusion
        # matrix even when a class is absent from the labels or predictions.
        class_ids = list(range(NUM_LABELS))

        def _wandb_log_bar_plot(metric, title):
            fig, ax = plt.subplots()
            x = np.arange(NUM_LABELS)
            ax.bar(x, metric)
            ax.set_xticks(x, LABELS)
            ax.set_title(title)
            fig.tight_layout()
            wandb.log({title.lower(): wandb.Image(fig)})
            plt.close(fig)

        # Precision per label (0 for classes that are never predicted)
        precision = precision_score(
            labels, preds, labels=class_ids, average=None, zero_division=0
        )
        _wandb_log_bar_plot(precision, "Precision")

        # Recall per label
        recall = recall_score(
            labels, preds, labels=class_ids, average=None, zero_division=0
        )
        _wandb_log_bar_plot(recall, "Recall")

        # F1 score per label
        f1 = f1_score(labels, preds, labels=class_ids, average=None, zero_division=0)
        _wandb_log_bar_plot(f1, "F1")

        # Confusion matrix, each row normalized by the number of true samples
        cm = confusion_matrix(labels, preds, labels=class_ids, normalize="true")
        df_cm = pd.DataFrame(cm, LABELS, LABELS)
        fig, ax = plt.subplots()
        sns.heatmap(df_cm, annot=True)
        ax.set_xlabel("Predictions")
        ax.set_ylabel("Labels")
        fig.tight_layout()
        wandb.log({"confusion_matrix": wandb.Image(fig)})
        plt.close(fig)

        # Generate and print test report
        report = classification_report(
            labels, preds, labels=class_ids, target_names=LABELS, zero_division=0
        )
        print()
        print("-" * _SEP_LEN)
        print("Test report")
        print("-" * _SEP_LEN)
        print()
        print(report)
        print()
        print("-" * _SEP_LEN)
        print()

        # Clear lists
        self.test_step_labels.clear()
        self.test_step_preds.clear()

## Creating the models


Device


In [None]:
# Device used by the model-summary, smoke-test and inference helpers (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Model architecture summary


In [None]:
def model_summary(make_model):
    """Build a fresh model with ``make_model`` and print its torchsummary table.

    The summary is computed for a (3, IM_H, IM_W) input on ``device``.
    """
    model = make_model()
    model.to(device)
    # torchsummary's summary() defaults to device="cuda"; pass the actual
    # device type so the summary also works on CPU-only runtimes.
    summary(model, input_size=(3, IM_H, IM_W), device=device.type)

Verifying that the model works


In [None]:
def verify_model(make_model, batch_size=1):
    """Smoke-test a model factory on the first batch of the test split.

    Args:
        make_model: Factory returning an untrained base ``nn.Module``.
        batch_size: Number of test images to run through the model.

    Returns:
        ``(preds, labels)``: predicted class indices (on ``device``) and the
        true labels of the batch.
    """
    # Input data
    dataset = FER_Dataset_KFold("test")
    data_loader = DataLoader(dataset, batch_size=batch_size)
    images, labels = next(iter(data_loader))

    # Model
    model = make_model()
    model.to(device)
    model.eval()

    # Forward pass (inference only, no autograd graph needed)
    with torch.no_grad():
        logits = model(images.to(device))

    # Argmax over the class axis. The previous softmax over dim=0 normalized
    # across the batch; with batch_size=1 it always predicted class 0.
    preds = torch.argmax(logits, dim=1)

    return preds, labels

### Model 1


In [None]:
def make_model_1():
    """Build Model 1: two 5x5 conv blocks followed by a two-layer MLP head.

    Expects (B, 3, 48, 48) inputs (IM_H = IM_W = 48) and returns
    (B, NUM_LABELS) logits. Layer order fixes the nn.Sequential indices used
    as state_dict keys, so reordering layers would break saved checkpoints.
    """
    return nn.Sequential(
        # Features (MaxPool before ReLU is equivalent to ReLU before MaxPool)
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=1),  # 48 -> 46
        # nn.BatchNorm2d(num_features=64),
        nn.MaxPool2d(kernel_size=2),  # 46 -> 23
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1),  # 23 -> 21
        # nn.BatchNorm2d(num_features=128),
        nn.MaxPool2d(kernel_size=2),  # 21 -> 10 (floor)
        nn.ReLU(),
        # Classifier
        nn.Flatten(),
        nn.Linear(in_features=128 * 10 * 10, out_features=128),  # 128 channels x 10 x 10
        nn.ReLU(),
        nn.Linear(in_features=128, out_features=NUM_LABELS),
    )


# Print the layer-by-layer summary for a (3, IM_H, IM_W) input
model_summary(make_model_1)

In [None]:
# Smoke test: (predicted classes, true labels) for the first test image
verify_model(make_model_1)

### Model 2


In [None]:
def make_model_2():
    """Build Model 2: three 3x3 conv/ReLU/max-pool blocks and a dropout MLP head.

    Expects (B, 3, 48, 48) inputs and returns (B, NUM_LABELS) logits. Layer
    order fixes the nn.Sequential indices used as state_dict keys, so
    reordering layers would break saved checkpoints.
    """
    return nn.Sequential(
        # Features (spatial size 48 -> 24 -> 12 -> 6; padding=1 keeps size through each conv)
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
        # nn.BatchNorm2d(num_features=64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
        # nn.BatchNorm2d(num_features=128),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
        # nn.BatchNorm2d(num_features=256),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        # Classifier
        nn.Flatten(),
        nn.Linear(in_features=256 * 6 * 6, out_features=512),  # 256 channels x 6 x 6
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=512, out_features=256),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=256, out_features=NUM_LABELS),
    )


# Print the layer-by-layer summary for a (3, IM_H, IM_W) input
model_summary(make_model_2)

In [None]:
# Smoke test: (predicted classes, true labels) for the first test image
verify_model(make_model_2)

### Model 3


In [None]:
def make_model_3():
    """Build Model 3: four conv/BatchNorm/ReLU/max-pool blocks and a deep dropout MLP head.

    Expects (B, 3, 48, 48) inputs and returns (B, NUM_LABELS) logits. Layer
    order fixes the nn.Sequential indices used as state_dict keys, so
    reordering layers would break saved checkpoints.
    """
    return nn.Sequential(
        # Features (convs use bias=False because a BatchNorm follows each one)
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, bias=False),  # 48 -> 48
        nn.BatchNorm2d(num_features=64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),  # 48 -> 24
        nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=5, padding=3, bias=False
        ),  # 24 -> 26
        nn.BatchNorm2d(num_features=128),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),  # 26 -> 13
        nn.Conv2d(
            in_channels=128, out_channels=256, kernel_size=7, padding=5, bias=False
        ),  # 13 -> 17
        nn.BatchNorm2d(num_features=256),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),  # 17 -> 8 (floor)
        nn.Conv2d(
            in_channels=256, out_channels=512, kernel_size=5, padding=3, bias=False
        ),  # 8 -> 10
        nn.BatchNorm2d(num_features=512),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),  # 10 -> 5
        # Classifier
        nn.Flatten(),
        nn.Linear(in_features=512 * 5 * 5, out_features=256),  # 512 channels x 5 x 5
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=256, out_features=128),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=128, out_features=64),
        nn.ReLU(),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=64, out_features=NUM_LABELS),
    )


# Print the layer-by-layer summary for a (3, IM_H, IM_W) input
model_summary(make_model_3)

In [None]:
# Smoke test: (predicted classes, true labels) for the first test image
verify_model(make_model_3)

### ResNet 50


In [None]:
def make_resnet50():
    """Build a pretrained torchvision ResNet50 with an extra linear classifier.

    ``weights="DEFAULT"`` loads torchvision's default pretrained weights. The
    network's own 1000-way ``fc`` head is kept, and a ``Linear(1000, NUM_LABELS)``
    layer is stacked on its output rather than replacing ``fc``.
    """
    return nn.Sequential(
        models.resnet50(weights="DEFAULT"), nn.Linear(1000, NUM_LABELS)
    )


# Print the layer-by-layer summary for a (3, IM_H, IM_W) input
model_summary(make_resnet50)

In [None]:
# Smoke test: (predicted classes, true labels) for the first test image
verify_model(make_resnet50)

# Training


In [None]:
def train_model(
    make_base_model,
    project,
    run,
    train_indices=None,
    val_indices=None,
    lr=1e-4,
    max_epochs=100,
    min_epochs=20,
    save_on_drive=False,
):
    """Train a freshly built model and archive its checkpoints.

    Args:
        make_base_model: Factory returning an untrained base ``nn.Module``.
        project: WandB project name; also the local checkpoint root directory.
        run: WandB run name; also the checkpoint sub-directory.
        train_indices, val_indices: Optional K-fold index arrays into the
            training data (see ``CNN_Classifier_KFold``).
        lr: AdamW learning rate.
        max_epochs, min_epochs: Epoch bounds passed to ``pl.Trainer``.
        save_on_drive: If True, also copy the checkpoints to
            ``DRIVE_DIR/checkpoints/{project}/{run}/``, replacing its contents.

    Returns:
        The fitted ``pl.Trainer``.
    """
    # Create an untrained version of the base model and wrap it
    base_model = make_base_model()
    model = CNN_Classifier_KFold(base_model, train_indices, val_indices, lr=lr)

    # Create a new WandB logger
    wandb_logger = WandbLogger(project=project, name=run)

    # Local checkpoint directory, shared by ModelCheckpoint and the copy steps below
    ckpt_dir = f"{project}/checkpoints/{run}/"

    # Save model with lowest val_loss
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val_loss",
        mode="min",
        dirpath=ckpt_dir,
        filename=project + "_{epoch:02d}_{val_loss:0.3f}",
    )

    # Stop early after 30 validation checks without val_loss improvement
    early_stop_callback = EarlyStopping(
        monitor="val_loss", mode="min", min_delta=1e-6, patience=30
    )

    # Create the trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        min_epochs=min_epochs,
        logger=wandb_logger,
        log_every_n_steps=1,
        accelerator="gpu",
        devices=-1,
        callbacks=[checkpoint_callback, early_stop_callback],
        deterministic=True,
    )

    # Train
    trainer.fit(model)

    # NOTE(review): ckpt_dir is reused by runs with the same project/run, so it
    # may also hold checkpoints from earlier runs; all of them are copied.
    checkpoint_files = sorted(os.listdir(ckpt_dir))

    # Save checkpoints on WandB (exist_ok: do not fail if the folder already exists)
    wandb_ckpt_dir = os.path.join(wandb.run.dir, "checkpoints")
    os.makedirs(wandb_ckpt_dir, exist_ok=True)
    for checkpoint in checkpoint_files:
        shutil.copy(os.path.join(ckpt_dir, checkpoint), wandb_ckpt_dir)

    # Save checkpoints on drive
    if save_on_drive:
        drive_save_path = DRIVE_DIR + f"checkpoints/{project}/{run}/"

        # Create/clear directory
        if os.path.exists(drive_save_path):
            shutil.rmtree(drive_save_path)
        os.makedirs(drive_save_path)

        # Save
        for checkpoint in checkpoint_files:
            shutil.copy(
                os.path.join(ckpt_dir, checkpoint),
                os.path.join(drive_save_path, checkpoint),
            )

    # Close logger
    wandb.finish()

    return trainer


def train_kfold(make_base_model, k, project, lr=1e-4, max_epochs=100, min_epochs=20):
    """Run stratified K-fold cross-validation on the training split.

    One model is trained per fold with ``train_model`` (runs ``split_1`` ...
    ``split_k``). After each fold, the trainer's final ``callback_metrics`` for
    train/val loss and accuracy are printed and accumulated.

    Returns:
        Dict mapping ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc`` to their mean over the ``k`` folds. The averages are also
        logged to WandB as a one-row table in a run named ``average_metrics``.
    """

    def _banner(char, title):
        # Title framed by two separator lines, followed by a blank line
        rule = char * _SEP_LEN
        print(rule)
        print(title)
        print(rule)
        print()

    metric_names = ("train_loss", "train_acc", "val_loss", "val_acc")
    totals = dict.fromkeys(metric_names, 0)

    folds = StratifiedKFold(n_splits=k).split(train_images, train_labels)
    for fold_no, (fold_train_idx, fold_val_idx) in enumerate(folds, start=1):
        # Fold header with the class distribution of both halves
        _banner("=", f"Split {fold_no}/{k}")
        for heading, idx in (("Train", fold_train_idx), ("Validation", fold_val_idx)):
            print(f"------ {heading} Dataset ------")
            dataset_distribution(train_labels[idx])
            print()
        print("-" * _SEP_LEN)
        print()

        trainer = train_model(
            make_base_model,
            project=project,
            run=f"split_{fold_no}",
            train_indices=fold_train_idx,
            val_indices=fold_val_idx,
            lr=lr,
            max_epochs=max_epochs,
            min_epochs=min_epochs,
        )
        print()

        # Accumulate this fold's final metrics
        _banner("-", "Callback metrics")
        for name in metric_names:
            fold_value = trainer.callback_metrics[name].item()
            totals[name] += fold_value
            print(f"{name}: {fold_value}")
        print()
        print()

    avg_metrics = {name: total / k for name, total in totals.items()}

    # Log the averages to a dedicated WandB run
    _banner("=", f"Average metrics over {k} splits")
    wandb.init(project=project, name="average_metrics", dir=f"{project}/")
    avg_table = wandb.Table(
        data=[list(avg_metrics.values())], columns=list(avg_metrics.keys())
    )
    wandb.log({"avg_metrics": avg_table})
    wandb.finish()
    print()

    for name, value in avg_metrics.items():
        print(f"{name}: {value}")

    return avg_metrics

## K-fold cross validation


In [None]:
# K-fold cross-validation settings shared by the three candidate models
K = 5  # number of folds

LR = 1e-5  # AdamW learning rate
MAX_EPOCHS = 60
MIN_EPOCHS = 40

Model 1


In [None]:
# K-fold cross-validation of Model 1; returns the fold-averaged train/val metrics
avg_metrics_1 = train_kfold(
    make_model_1,
    k=K,
    project=f"model_1_k={K}",
    lr=LR,
    max_epochs=MAX_EPOCHS,
    min_epochs=MIN_EPOCHS,
)

Model 2


In [None]:
# K-fold cross-validation of Model 2; returns the fold-averaged train/val metrics
avg_metrics_2 = train_kfold(
    make_model_2,
    k=K,
    project=f"model_2_k={K}",
    lr=LR,
    max_epochs=MAX_EPOCHS,
    min_epochs=MIN_EPOCHS,
)

Model 3


In [None]:
# K-fold cross-validation of Model 3; returns the fold-averaged train/val metrics
avg_metrics_3 = train_kfold(
    make_model_3,
    k=K,
    project=f"model_3_k={K}",
    lr=LR,
    max_epochs=MAX_EPOCHS,
    min_epochs=MIN_EPOCHS,
)

### Comparison of average final-epoch metrics across folds


In [None]:
# One row per candidate model, one column per fold-averaged metric.
# NOTE: the averages come from trainer.callback_metrics, i.e. the last logged
# epoch of each fold rather than the saved best checkpoint.
all_models = ["model_1", "model_2", "model_3"]
all_metrics = [avg_metrics_1, avg_metrics_2, avg_metrics_3]

all_metrics_df = pd.DataFrame.from_records(all_metrics, index=all_models)

print(f"Average metrics for each model over {K} splits", end="\n\n")
print(all_metrics_df)

## Training the best model


In [None]:
# Retrain the selected architecture (Model 2) on the whole training split.
# No K-fold indices are passed, so validation uses the separate validation split.
make_best_model = make_model_2
PROJECT_BEST = "best_model_2"
RUN_BEST = "train"

LR_BEST = 1e-5
MAX_EPOCHS_BEST = 400
MIN_EPOCHS_BEST = 0  # no minimum epoch count before early stopping may trigger


_ = train_model(
    make_best_model,
    project=PROJECT_BEST,
    run=RUN_BEST,
    lr=LR_BEST,
    max_epochs=MAX_EPOCHS_BEST,
    min_epochs=MIN_EPOCHS_BEST,
    save_on_drive=True,  # copy checkpoints to Drive for the testing cells below
)

# Testing


## Testing best model from checkpoint


In [None]:
def load_all_checkpoints(make_base_model, project, run):
    """Load every Lightning checkpoint saved on Drive for ``project``/``run``.

    Args:
        make_base_model: Factory returning an untrained base ``nn.Module`` of
            the same architecture the checkpoints were trained with.
        project, run: Locate ``DRIVE_DIR/checkpoints/{project}/{run}/``.

    Returns:
        Dict mapping checkpoint file name -> loaded ``CNN_Classifier_KFold``,
        in sorted file-name order. Callers take the first entry, so the order
        must not depend on os.listdir's arbitrary ordering.
    """
    checkpoint_dir = DRIVE_DIR + f"checkpoints/{project}/{run}/"
    checkpoints = {}
    for checkpoint in sorted(os.listdir(checkpoint_dir)):
        # Skip anything that is not a Lightning checkpoint (e.g. stray Drive files)
        if not checkpoint.endswith(".ckpt"):
            continue
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint)

        # Create an untrained version of the base model
        base_model = make_base_model()

        # Load model parameters from checkpoint. The base model is excluded
        # from the saved hyper-parameters, so it must be supplied here.
        model = CNN_Classifier_KFold.load_from_checkpoint(
            checkpoint_path, model=base_model
        )

        checkpoints[checkpoint] = model

    return checkpoints


def test_checkpoints(make_base_model, project, run):
    """Evaluate each Drive checkpoint of ``project``/``run`` on the test split.

    Every checkpoint is tested in its own WandB run named
    ``test_<checkpoint file>``; the per-class report and plots come from the
    module's test hooks.
    """
    rule_eq = "=" * _SEP_LEN
    rule_dash = "-" * _SEP_LEN

    loaded = load_all_checkpoints(make_base_model, project, run)
    for ckpt_name, ckpt_model in loaded.items():
        # Header with the test-set class distribution
        print(rule_eq)
        print(f"Testing {ckpt_name}")
        print(rule_eq)
        print()
        print("------ Test Dataset ------")
        dataset_distribution(test_labels)
        print()
        print(rule_dash)
        print()

        ckpt_logger = WandbLogger(project=project, name=f"test_{ckpt_name}")
        tester = pl.Trainer(logger=ckpt_logger, accelerator="gpu", devices=-1)
        tester.test(ckpt_model)
        wandb.finish()

        print()
        print()

In [None]:
# Evaluate every saved checkpoint of the best model on the test split
test_checkpoints(make_best_model, PROJECT_BEST, RUN_BEST)

## Testing on real-life videos


Load the videos


In [None]:
videos_dir = DRIVE_DIR + "videos/"

# Load each ".npy" video as a uint8 array, keyed by its file name without extension
videos = {}
for video in os.listdir(videos_dir):
    if not video.endswith(".npy"):
        continue

    label = video[: -len(".npy")]
    video_array = np.load(videos_dir + video).astype(np.uint8)
    videos[label] = video_array
    print(f"{label} --- {video_array.shape}")

Test


In [None]:
def test_on_videos(make_base_model, project, run, save_on_drive=False):
    """Plot, for every video in ``videos``, how often each emotion is predicted.

    The first checkpoint returned by ``load_all_checkpoints`` for
    ``project``/``run`` is applied to all frames of a video in one batch.

    Args:
        make_base_model: Factory for the checkpoint's untrained base network.
        project, run: Identify the checkpoint directory on Drive.
        save_on_drive: If True, save each plot as
            ``videos_dir + "preds/<checkpoint>_preds_<video>.jpg"``.
    """
    # Load the pre-trained model
    checkpoints = load_all_checkpoints(make_base_model, project, run)
    model_name, model = list(checkpoints.items())[0]
    model = model.model.to(device)
    model.eval()

    # NOTE(review): unlike FER_Dataset_KFold there is no grayscale->RGB step or
    # resize here, so frames are presumably already 3-channel and sized for the
    # model; confirm against how the .npy videos were produced.
    frame_trans = transforms.Compose(
        [
            # Convert to Tensor
            transforms.ToTensor(),
            # Normalization
            transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    preds_dir = videos_dir + "preds/"
    if save_on_drive:
        # savefig does not create missing directories
        os.makedirs(preds_dir, exist_ok=True)

    for label, video_array in videos.items():
        frames = torch.stack([frame_trans(frame) for frame in video_array]).to(device)

        # Forward pass (inference only, so no autograd graph is kept)
        with torch.no_grad():
            logits = model(frames)

        # Per-frame predictions: argmax over the class axis. The previous
        # softmax over dim=0 normalized across frames and distorted the argmax.
        preds_array = torch.argmax(logits, dim=1).cpu().numpy()

        # Fraction of frames predicted as each class
        preds_dist = np.bincount(preds_array, minlength=NUM_LABELS) / len(preds_array)

        # Plot predictions distribution
        fig, ax = plt.subplots()
        ax.bar(np.arange(NUM_LABELS), preds_dist)
        ax.set_xticks(np.arange(NUM_LABELS), LABELS)
        ax.set_title(label.capitalize())
        plt.show()

        # Save on drive
        if save_on_drive:
            drive_save_path = preds_dir + f"{model_name}_preds_{label}.jpg"
            fig.savefig(drive_save_path, bbox_inches="tight")

        print()

In [None]:
# Prediction distributions of the best model on each video, saved to Drive
test_on_videos(make_best_model, PROJECT_BEST, RUN_BEST, save_on_drive=True)

## Comparison with ResNet50


In [None]:
# ResNet50 baseline: trained on the whole training split (no K-fold indices),
# validated on the separate validation split, checkpoints copied to Drive.
PROJECT_RN = "resnet50"
RUN_RN = "train"

LR_RN = 1e-5
MAX_EPOCHS_RN = 400
MIN_EPOCHS_RN = 0  # no minimum epoch count before early stopping may trigger


_ = train_model(
    make_resnet50,
    project=PROJECT_RN,
    run=RUN_RN,
    lr=LR_RN,
    max_epochs=MAX_EPOCHS_RN,
    min_epochs=MIN_EPOCHS_RN,
    save_on_drive=True,
)

# Evaluate every saved ResNet50 checkpoint on the test split
test_checkpoints(make_resnet50, PROJECT_RN, RUN_RN)

In [None]:
# Prediction distributions of the ResNet50 baseline on each video, saved to Drive
test_on_videos(make_resnet50, PROJECT_RN, RUN_RN, save_on_drive=True)

# Activations visualization


In [None]:
def visualize_activations(
    model, image, save_on_drive=False, save_dir=None, base_save_as=None
):
    """Plot the output channels of every ``nn.Conv2d`` layer for one image.

    Args:
        model: A ``CNN_Classifier_KFold``; its ``.model`` base network is used.
        image: Input batch tensor; only the first sample (index 0) is plotted.
        save_on_drive: If True, save one figure per conv layer under
            ``DRIVE_DIR + "visualizations/" + save_dir``, clearing it first.
        save_dir: Sub-directory for the saved figures (required when saving).
        base_save_as: Currently unused; kept for backward compatibility.

    Raises:
        ValueError: If ``save_on_drive`` is True but ``save_dir`` is None.
    """
    if save_on_drive and save_dir is None:
        raise ValueError("save_dir is required when save_on_drive=True")

    # Get base model and put it on the same device as the input (a model loaded
    # with load_all_checkpoints may still be on the CPU).
    model = model.model.to(device)

    # Save on drive: create/clear the output directory
    if save_on_drive:
        drive_save_dir = DRIVE_DIR + "visualizations/" + save_dir
        if os.path.exists(drive_save_dir):
            shutil.rmtree(drive_save_dir)
        os.makedirs(drive_save_dir)

    # Collect the activations, keyed by conv-layer position in model.modules()
    activations = {}

    def get_activations(i):
        def hook(module, input, output):
            activations[i] = output.detach().cpu()

        return hook

    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    hook_handles = [
        layer.register_forward_hook(get_activations(i))
        for i, layer in enumerate(conv_layers)
    ]

    # Forward pass; the hooks are always removed, even if it raises
    try:
        model.eval()
        with torch.no_grad():
            model(image.to(device))
    finally:
        for handle in hook_handles:
            handle.remove()

    # Visualize the activations
    for i, activation in activations.items():
        num_channels = activation.size(1)
        num_rows = (num_channels + 7) // 8
        num_cols = min(8, num_channels)

        # squeeze=False keeps `axes` a 2-D array even for a single subplot
        fig, axes = plt.subplots(
            num_rows, num_cols, figsize=(2 * num_cols, 2 * num_rows), squeeze=False
        )
        flat_axes = axes.flatten()

        for channel_idx, ax in enumerate(flat_axes[:num_channels]):
            ax.imshow(activation[0, channel_idx, :, :], cmap="gray")
            ax.axis("off")
            ax.set_title(f"Channel {channel_idx+1}")

        # Remove any unused subplots
        for unused_ax in flat_axes[num_channels:]:
            unused_ax.remove()

        plt.suptitle(
            f"Convolutional layer {i+1} ({num_channels} channels)", fontsize=16
        )
        plt.tight_layout(rect=[0, 0, 1, 0.98])  # Move the suptitle above the subplots
        plt.show()

        # Save on drive
        if save_on_drive:
            drive_save_path = os.path.join(drive_save_dir, f"Conv_Layer_{i+1}.jpg")
            fig.savefig(drive_save_path, bbox_inches="tight")

        plt.close(fig)

In [None]:
# Load the pre-trained model (first checkpoint of the best model)
checkpoints = load_all_checkpoints(make_best_model, PROJECT_BEST, RUN_BEST)
model_name, model = list(checkpoints.items())[0]

# One example test image per emotion (ex_indices), fed one at a time
vis_dataset = FER_Dataset_KFold("test", ex_indices)
vis_dataloader = DataLoader(vis_dataset, batch_size=1, shuffle=False)

# Visualize and save activations, one Drive folder per emotion
for image, label in vis_dataloader:
    emotion_name = LABELS[label]
    print(emotion_name)
    visualize_activations(
        model,
        image,
        save_on_drive=True,
        save_dir=f"{model_name}/{emotion_name}/",
        base_save_as=model_name,
    )
    print()