# Imports and Drive Access

In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive so that the files under
# /content/drive/MyDrive used later in this notebook are reachable.
drive.mount('/content/drive')

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.io import read_image
from torchvision.transforms import v2, Lambda

# Data

In [None]:
# Extract the image tar files (train / val / test) from Drive into /content/
# (shell commands run through the notebook's `!` escape)
!tar -xf "/content/drive/MyDrive/Graduation Project/AffectNet/train_images.tar" -C "/content/"
!tar -xf "/content/drive/MyDrive/Graduation Project/AffectNet/val_images.tar" -C "/content/"
!tar -xf "/content/drive/MyDrive/Graduation Project/AffectNet/test_images.tar" -C "/content/"

In [None]:
# Define train data path (annotations CSV on Drive, image directory under /content/)
train_annotations = "/content/drive/MyDrive/Graduation Project/AffectNet/train_annotations.csv"
train_images = "/content/train_images"

# Define validation data path
val_annotations = "/content/drive/MyDrive/Graduation Project/AffectNet/val_annotations.csv"
val_images = "/content/val_images"

# Define test data path
test_annotations = "/content/drive/MyDrive/Graduation Project/AffectNet/test_annotations.csv"
test_images = "/content/test_images"

In [None]:
class AffectNet(Dataset):
    """
    A Dataset subclass for handling the AffectNet dataset.

    The first CSV column is used as the image file name (without the ``.jpg``
    extension) and the last CSV column as the label returned by
    ``__getitem__``. The class-balancing helpers read the ``expression``
    column.

    Attributes:
        annotations (DataFrame): The annotations for the images.
        root_dir (str): The root directory where the images are stored.
        transform (callable, optional): Optional transform to be applied on an image.
    """

    def __init__(self, annotations_file, img_root_dir, transform=None):
        """
        Initializes the AffectNet dataset.

        Args:
            annotations_file (str): The path to the CSV file containing the annotations.
            img_root_dir (str): The root directory where the images are stored.
            transform (callable, optional): Optional transform to be applied on an image.

        Raises:
            ValueError: If the number of directory entries differs from the
                number of annotation rows.
        """

        self.annotations = pd.read_csv(annotations_file)
        self.root_dir = img_root_dir
        self.transform = transform

        # Check if number of images and annotations match. The directory is
        # listed only once, since the listing can be expensive.
        num_annotations = len(self.annotations)
        num_images = len(os.listdir(self.root_dir))
        if num_annotations != num_images:
            raise ValueError(
                "Number of images and annotations do not match: "
                f"{num_annotations} != {num_images}"
            )

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
            int: The number of annotation rows.
        """

        return len(self.annotations)

    def sample_dist(self):
        """
        Returns the inverse-frequency weight of every expression class.

        Returns:
            list: ``1 / count`` for each class present in the annotations,
            ordered by ascending class label.
        """

        class_counts = self.annotations.expression.value_counts().to_dict()
        return [1 / class_counts[label] for label in sorted(class_counts)]

    def sample_weights(self):
        """
        Returns one sampling weight per annotation row.

        Each row gets the inverse frequency of its own class. The weight is
        looked up by label value rather than by list position, so the result
        is also correct when the labels are not the contiguous range
        ``0..K-1`` (e.g. when a class is missing from the split).

        Returns:
            list: ``1 / count(class of row)`` for every row, in row order.
        """

        class_counts = self.annotations.expression.value_counts().to_dict()
        return [1 / class_counts[label] for label in self.annotations.expression.values]

    def __getitem__(self, idx):
        """
        Returns the image and its labels at the given index.

        Args:
            idx (int): The index of the image.

        Returns:
            tuple: A tuple containing the image, and its labels.
        """

        # Get image name and create path
        img_name = f"{self.annotations.iloc[idx, 0]}.jpg"
        img_path = os.path.join(self.root_dir, img_name)

        # Read image
        image = read_image(img_path)

        # Get labels (last CSV column) and convert to tensor
        labels = torch.tensor(self.annotations.iloc[idx, -1])

        # Apply input transforms
        if self.transform is not None:
            image = self.transform(image)

        # Return image and labels
        return image, labels

# Model

## MFN

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        leading = input.size(0)
        return input.view(leading, -1)

def l2_norm(input, axis=1, eps=1e-12):
    """Scale ``input`` to unit L2 norm along ``axis``.

    Args:
        input (Tensor): Tensor to normalise.
        axis (int): Dimension along which the norm is computed.
        eps (float): Lower bound applied to the norm before dividing, so an
            all-zero vector maps to zeros instead of NaN.

    Returns:
        Tensor: ``input`` divided by its (floored) L2 norm, same shape as ``input``.
    """
    # keepdim=True so the norm broadcasts against ``input``
    norm = torch.norm(input, 2, axis, True)
    return torch.div(input, norm.clamp_min(eps))

class Conv_block(nn.Module):
    """Bias-free 2-D convolution followed by batch norm and a per-channel PReLU."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.prelu = nn.PReLU(out_c)

    def forward(self, x):
        # conv -> batch norm -> PReLU
        return self.prelu(self.bn(self.conv(x)))

class Linear_block(nn.Module):
    """Bias-free 2-D convolution followed by batch norm, with no activation."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        # conv -> batch norm
        return self.bn(self.conv(x))

class Depth_Wise(nn.Module):
    """1x1 expansion, depthwise convolution and 1x1 linear projection.

    ``groups`` is used as the width of the intermediate feature map (and as
    the group count of the depthwise convolution). When ``residual`` is true
    the block input is added to the projected output.
    """

    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super().__init__()
        # 1x1 conv: in_c -> groups channels
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        # depthwise conv: one filter group per channel
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        # 1x1 projection without activation: groups -> out_c channels
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        projected = self.project(self.conv_dw(self.conv(x)))
        if not self.residual:
            return projected
        return x + projected


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate


# Lookup table from activation name to a ready-made module instance
NON_LINEARITY = dict(
    ReLU=nn.ReLU(inplace=True),
    Swish=Swish(),
)


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``."""

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        # ``x + 3`` is a fresh tensor, so the (possibly in-place) ReLU6 never touches ``x``
        shifted = x + 3
        return self.relu(shifted) / 6

class h_swish(nn.Module):
    """Hard swish: the input multiplied by its hard sigmoid."""

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate

class swish(nn.Module):
    """Stateless swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return torch.mul(x, torch.sigmoid(x))

class CoordAtt(nn.Module):
    """Attention block that gates the input with per-row and per-column weights.

    The input is average-pooled along width and along height, both pooled
    maps go through a shared 1x1 conv / batch norm / hard swish, and two
    separate 1x1 convs followed by a sigmoid turn them into the gates that
    multiply the input.
    """

    def __init__(self, inp, oup, groups=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # keep height, pool width to 1
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # keep width, pool height to 1

        # Width of the shared bottleneck, never below 8 channels
        mip = max(8, inp // groups)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.relu = h_swish()

    def forward(self, x):
        _, _, height, width = x.size()

        # Pool each direction; the width branch is permuted so both stack along dim 2
        pooled_h = self.pool_h(x)
        pooled_w = self.pool_w(x).permute(0, 1, 3, 2)

        joint = torch.cat([pooled_h, pooled_w], dim=2)
        joint = self.relu(self.bn1(self.conv1(joint)))

        # Undo the stacking, then restore the width branch layout
        gate_h, gate_w = torch.split(joint, [height, width], dim=2)
        gate_w = gate_w.permute(0, 1, 3, 2)

        gate_h = self.conv2(gate_h).sigmoid().expand(-1, -1, height, width)
        gate_w = self.conv3(gate_w).sigmoid().expand(-1, -1, height, width)

        return x * gate_w * gate_h


class MDConv(nn.Module):
    """Mixed depthwise convolution: one depthwise conv per channel split.

    The input channels are split according to ``split_out_channels``; split
    ``i`` is convolved depthwise with kernel size ``kernel_size[i]`` and the
    results are concatenated back along the channel dimension.

    ``bn`` and ``prelu`` are registered (so they appear in the state dict)
    but are not applied in ``forward``.
    """

    def __init__(self, channels, kernel_size, split_out_channels, stride):
        super().__init__()
        self.num_groups = len(kernel_size)
        self.split_channels = split_out_channels

        branches = []
        for idx in range(self.num_groups):
            width = self.split_channels[idx]
            k = kernel_size[idx]
            branches.append(
                nn.Conv2d(width, width, k, stride=stride, padding=k // 2, groups=width, bias=False)
            )
        self.mixed_depthwise_conv = nn.ModuleList(branches)

        self.bn = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU(channels)

    def forward(self, x):
        # A single kernel size needs no split / concat
        if self.num_groups == 1:
            return self.mixed_depthwise_conv[0](x)

        pieces = torch.split(x, self.split_channels, dim=1)
        convolved = [branch(piece) for branch, piece in zip(self.mixed_depthwise_conv, pieces)]
        return torch.cat(convolved, dim=1)


class Mix_Depth_Wise(nn.Module):
    """1x1 expansion, mixed depthwise conv, coordinate attention, 1x1 projection.

    ``groups`` is the width of the intermediate feature map. ``kernel`` and
    ``padding`` are accepted but not used here: the depthwise stage takes its
    kernel sizes from ``kernel_size``. When ``residual`` is true the block
    input is added to the projected output.
    """

    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1, kernel_size=[3,5,7], split_out_channels=[64,32,32]):
        super().__init__()
        # 1x1 conv: in_c -> groups channels
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        # per-split depthwise convolutions
        self.conv_dw = MDConv(channels=groups, kernel_size=kernel_size, split_out_channels=split_out_channels, stride=stride)
        self.CA = CoordAtt(groups, groups)
        # 1x1 projection without activation: groups -> out_c channels
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        out = self.conv(x)
        for stage in (self.conv_dw, self.CA, self.project):
            out = stage(out)
        if not self.residual:
            return out
        return x + out

class Residual(nn.Module):
    """A sequence of ``num_block`` residual ``Depth_Wise`` blocks with ``c`` channels in and out."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super().__init__()
        self.model = nn.Sequential(*(
            Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)
            for _ in range(num_block)
        ))

    def forward(self, x):
        return self.model(x)

class Mix_Residual(nn.Module):
    """A sequence of ``num_block`` residual ``Mix_Depth_Wise`` blocks with ``c`` channels in and out."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1), kernel_size=[3,5], split_out_channels=[64,64]):
        super().__init__()
        self.model = nn.Sequential(*(
            Mix_Depth_Wise(
                c,
                c,
                residual=True,
                kernel=kernel,
                padding=padding,
                stride=stride,
                groups=groups,
                kernel_size=kernel_size,
                split_out_channels=split_out_channels,
            )
            for _ in range(num_block)
        ))

    def forward(self, x):
        return self.model(x)


class MixedFeatureNet(nn.Module):
    """Convolutional backbone built from mixed depthwise blocks.

    The stages are applied in the order they are declared and the result is
    L2-normalised along the channel dimension (via ``l2_norm``'s default
    axis). ``embedding_size``, ``out_h`` and ``out_w`` are accepted but not
    used by this implementation. The size comments below are the author's
    annotations for a 224x224 input.
    """

    def __init__(self, embedding_size=256, out_h=7, out_w=7):
        super().__init__()
        # 224x224
        self.conv0 = Conv_block(3, 3, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        # 112x112
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        # 56x56
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Mix_Depth_Wise(
            64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128,
            kernel_size=[3,5,7], split_out_channels=[64,32,32],
        )

        # 28x28
        self.conv_3 = Mix_Residual(
            64, num_block=9, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1),
            kernel_size=[3,5], split_out_channels=[96,32],
        )
        self.conv_34 = Mix_Depth_Wise(
            64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256,
            kernel_size=[3,5,7], split_out_channels=[128,64,64],
        )

        # 14x14
        self.conv_4 = Mix_Residual(
            128, num_block=16, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1),
            kernel_size=[3,5], split_out_channels=[192,64],
        )
        self.conv_5 = Mix_Depth_Wise(
            128, 512, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=512*2,
            kernel_size=[3,5,7,9], split_out_channels=[128*2,128*2,128*2,128*2],
        )

    def forward(self, x):
        stages = (
            self.conv0,
            self.conv1,
            self.conv2_dw,
            self.conv_23,
            self.conv_3,
            self.conv_34,
            self.conv_4,
            self.conv_5,
        )
        out = x
        for stage in stages:
            out = stage(out)

        return l2_norm(out)

## VGG19

In [None]:
class ConvBlock(nn.Module):
    """3x3 same-padding convolution followed by batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.convblock = nn.Sequential(*layers)

    def forward(self, x):
        return self.convblock(x)


class VGG19(nn.Module):
    """Four VGG-style stages, each a run of ``ConvBlock`` layers ending in a 2x2 max pool.

    Channel widths per stage are 64, 128, 256 and 512, with 2, 2, 4 and 4
    convolution blocks respectively.
    """

    def __init__(self):
        super().__init__()
        self.block1 = self._stage(3, 64, 2)
        self.block2 = self._stage(64, 128, 2)
        self.block3 = self._stage(128, 256, 4)
        self.block4 = self._stage(256, 512, 4)

    @staticmethod
    def _stage(in_channels, out_channels, num_convs):
        """Build ``num_convs`` ConvBlocks (the first changes the width) plus a 2x2 max pool."""
        layers = []
        width = in_channels
        for _ in range(num_convs):
            layers.append(ConvBlock(width, out_channels))
            width = out_channels
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        return x

## CCT

In [None]:
class DepthWiseSeperableConv(nn.Module):
    """Depthwise convolution and 1x1 pointwise convolution, each followed by batch norm and ReLU."""

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, bias=False):
        super().__init__()
        # Depthwise stage: one filter group per input channel
        self.depthwise = nn.Conv2d(
            in_channel,
            in_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channel,
            bias=bias,
        )
        self.bn1 = nn.BatchNorm2d(in_channel)
        self.relu = nn.ReLU()
        # Pointwise stage: 1x1 conv mixing channels
        self.pointwise = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        depth = self.relu(self.bn1(self.depthwise(x)))
        return self.relu2(self.bn2(self.pointwise(depth)))


class PatchExtraction(nn.Module):
    """ Patch extraction block:
            - Depthwise separable convolutional layer (512 -> 256 channels, kernel 4, stride 4, padding 2)
            - Depthwise separable convolutional layer (256 -> 256 channels, kernel 2, stride 2)
            - Pointwise convolutional layer (256 -> 49 channels)

        - Expects a 512-channel feature map as input
          (the original note gives the input size as (N, 512, 16, 16))

        - ``gap`` is registered but not applied in ``forward``
    """
    def __init__(self):
        super().__init__()
        self.conv1 = DepthWiseSeperableConv(
            in_channel=512, out_channel=256, kernel_size=4, stride=4, padding=2,
        )
        self.conv2 = DepthWiseSeperableConv(
            in_channel=256, out_channel=256, kernel_size=2, stride=2, padding=0,
        )
        self.conv3 = nn.Conv2d(
            in_channels=256, out_channels=49, kernel_size=1, stride=1, padding=0,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3):
            x = layer(x)
        return x

In [None]:
class MultiheadedSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, N, C) sequence.

    Args:
        embed_dim (int): Channel size ``C`` of the input; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        attn_dropout (float): Dropout probability applied to the attention weights.
        proj_dropout (float): Dropout probability applied after the output projection.

    Note:
        Earlier revisions split q/k/v with ``torch.chunk``, which keeps a
        leading size-1 axis. The later ``transpose(1, 2)`` then swapped the
        batch and head axes, so the final reshape mixed values across samples,
        heads and tokens. Checkpoints trained with that behaviour will produce
        different activations with this corrected layer.
    """

    def __init__(self,
                 embed_dim,
                 num_heads=8,
                 attn_dropout=0.5,
                 proj_dropout=0.5,
                 ):
        super().__init__()
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by number of heads."
        head_dim = embed_dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.projection = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x (Tensor): Input of shape (B, N, C).

        Returns:
            Tensor: Output of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = (
            self.qkv(x) # B, N, (3*C)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads) # B, N, 3(qkv), H(eads), head_dim
            .permute(2, 0, 3, 1, 4) # 3, B, H(eads), N, head_dim
        )
        # unbind removes the leading axis, so q, k and v are each (B, H, N, head_dim)
        q, k, v = qkv.unbind(0)
        # B,H,N,dim x B,H,dim,N -> B,H,N,N
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # <q,k> / sqrt(d)
        attn = attn.softmax(dim=-1) # Softmax over the key positions
        attn = self.attn_dropout(attn)

        x = (
            torch.matmul(attn, v) # B,H,N,N x B,H,N,dim -> B, H, N, dim
            .transpose(1, 2) # B, N, H, dim
            .reshape(B, N, C) # B, N, (H*dim)
        )
        x = self.projection(x)
        x = self.proj_dropout(x)

        return x

In [None]:
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention then a GELU feed-forward network.

    The attention branch is computed on ``norm_1(x)`` and added to ``x``.
    The result is then normalised by ``norm_2`` and the feed-forward output
    is added to that normalised tensor.
    """

    def __init__(self,
                 embed_dim=192,
                 num_heads=8,
                 attn_dropout=0.5,
                 proj_dropout=0.5,
                 mlp_dropout=0.1,
                 feedforward_dim=768,
            ):
        super().__init__()
        self.norm_1 = nn.LayerNorm(embed_dim)
        self.norm_2 = nn.LayerNorm(embed_dim)
        self.MHA = MultiheadedSelfAttention(embed_dim, num_heads, attn_dropout, proj_dropout)
        feedforward = [
            nn.Linear(embed_dim, feedforward_dim),
            nn.GELU(),
            nn.Dropout(mlp_dropout),
            nn.Linear(feedforward_dim, embed_dim),
            nn.Dropout(mlp_dropout),
        ]
        self.ff = nn.Sequential(*feedforward)

    def forward(self, x):
        # Attention branch with residual connection (Add)
        x = x + self.MHA(self.norm_1(x))

        # Feed-forward branch; the residual is taken from the normalised tensor
        x = self.norm_2(x)
        return x + self.ff(x)

In [None]:
class VCCT(nn.Module):
    """Classifier: convolutional backbone, patch extraction, transformer encoders, attention pooling.

    Pipeline in ``forward``: ``MixedFeatureNet`` features -> ``PatchExtraction``
    -> global average pool per channel -> linear embedding of each pooled
    value into ``embed_dim`` -> ``num_encoders`` encoder layers -> softmax
    attention pooling over the tokens -> linear classifier.

    Args:
        num_encoders (int): Number of stacked ``EncoderLayer`` modules.
        num_classes (int): Size of the output logits.
        embed_dim (int): Token embedding size.
        num_heads (int): Attention heads per encoder layer.
        attn_dropout (float): Dropout on attention weights.
        proj_dropout (float): Dropout after the attention projection.
        mlp_dropout (float): Dropout inside the feed-forward network.
        feedforward_dim (int): Hidden size of the feed-forward network.

    Note:
        The backbone attribute is named ``vgg`` although it holds a
        ``MixedFeatureNet``; the name is kept so state-dict keys stay stable.
        ``norm`` is registered but not applied in ``forward``.
    """

    def __init__(self,
                 num_encoders=1,
                 num_classes=8,
                 embed_dim=192,
                 num_heads=8,
                 attn_dropout=0.5,
                 proj_dropout=0.5,
                 mlp_dropout=0.1,
                 feedforward_dim=768,
            ):
        # The previous call was ``super(VGGT, self)``; no class named VGGT is
        # defined here, so constructing the model raised NameError.
        super().__init__()
        self.vgg = MixedFeatureNet()
        self.patcher = PatchExtraction()
        self.gap = nn.AdaptiveAvgPool2d(1)
        # Embeds the single pooled value of each patch channel into embed_dim
        self.fc1 = nn.Linear(1, embed_dim)
        self.transformer = self.create_encoders(embed_dim, num_heads,
                                                attn_dropout, proj_dropout,
                                                mlp_dropout, feedforward_dim,
                                                num_encoders)

        # Produces one attention score per token for the pooling step
        self.attention_pool = nn.Linear(embed_dim, 1)
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

        # To freeze the backbone, uncomment:
        #for param in self.vgg.parameters():
        #    param.requires_grad = False

    def create_encoders(self, embed_dim=192,
                        num_heads=8,
                        attn_dropout=0.5,
                        proj_dropout=0.5,
                        mlp_dropout=0.1,
                        feedforward_dim=768,
                        num_layers=2,
                       ):
        """Return an ``nn.Sequential`` of ``num_layers`` identically configured ``EncoderLayer`` modules."""
        return nn.Sequential(*[EncoderLayer(embed_dim, num_heads, attn_dropout, proj_dropout, mlp_dropout, feedforward_dim) for _ in range(num_layers)])

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x (Tensor): Image batch fed to the backbone.

        Returns:
            Tensor: Logits with ``num_classes`` entries in the last dimension.
        """
        x = self.vgg(x)
        x = self.patcher(x)
        x = self.gap(x)  # spatial dims pooled to 1x1
        x = self.fc1(x).squeeze(dim=2)  # embed the pooled value, then drop the size-1 dim
        x = self.transformer(x)
        # Attention pooling: softmax over dim 1 (tokens), then weighted sum of the tokens
        x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x)
        x = self.fc(x).squeeze(1)
        return x

# Train

In [None]:
# Hyperparameters
learning_rate = 1e-4
batch_size = 128
epochs = 10

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Fix the global torch seed before the datasets and loaders are created
torch.manual_seed(42)

# Training pipeline: random augmentation, float conversion, per-channel normalisation.
# NOTE(review): ToDtype is called without scale=True, so pixel values presumably
# stay in the 0-255 range while the Normalize mean/std look like the usual 0-1
# ImageNet statistics - confirm this is intended. Changing it would have to be
# done for the train, validation and test transforms together, and would not
# match checkpoints trained with the current pipeline.
train_transforms = v2.Compose([
    v2.RandAugment(num_ops=5),
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation pipeline: same conversion and normalisation, no augmentation
val_transforms = v2.Compose([
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_data = AffectNet(train_annotations, train_images, train_transforms)
val_data = AffectNet(val_annotations, val_images, val_transforms)

# One weight per training row, used to oversample rare classes
sample_weights = train_data.sample_weights()

# The sampler draws 2 * len(train_data) indices per epoch, with replacement,
# so one "epoch" of this loader is twice the size of the dataset.
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          num_workers=2,
                          sampler=WeightedRandomSampler(weights=sample_weights, num_samples=2*len(train_data), replacement=True)
                          )
# NOTE(review): shuffle=True does not change the aggregated validation metrics.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Checkpoint to resume from
path = "/content/drive/MyDrive/Graduation Project/Logs/MCCT-6_OS/MCCT-6_OS_ckpt_3.pt"
model = VCCT(num_encoders=7)

# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime
state_dict = torch.load(path, map_location=device)

# strict=False tolerates key mismatches; report them so that a checkpoint from
# a different architecture does not silently leave layers randomly initialised
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")
model.to(device)

# Parameter counts: all parameters vs. those that receive gradients
total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params}, Total trainable params: {total_trainable_params}")

In [None]:
# Loss function and optimizer
# CrossEntropyLoss with default arguments returns the mean loss over the batch
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Uncomment to restore a previously saved optimizer state:
#optimizer.load_state_dict(torch.load("/content/drive/MyDrive/Graduation Project/Logs/VGGT-1/optimizer_state_dict.pt"))

In [None]:
# Set up logging list (one dict of metrics per epoch)
logs = []

for epoch in range(0, epochs):
    # Training phase
    total_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for (inputs, targets) in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', leave=False, unit="batch"):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)

        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size;
        # dividing the sum by the sample count below then gives the true
        # per-sample average (previously the result was ~1/batch_size too small).
        batch_count = targets.shape[0]
        total_loss += loss.item() * batch_count
        predicted = outputs.argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += batch_count

    # Validation phase
    model.eval()
    val_total_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for (inputs, targets) in tqdm(val_loader, desc=f'Epoch {epoch + 1}/{epochs} - Validation', leave=False, unit="batch"):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_count = targets.shape[0]
            val_total_loss += loss.item() * batch_count
            val_predicted = outputs.argmax(dim=1)
            val_correct += (val_predicted == targets).sum().item()
            val_total += batch_count

    # Per-sample averages; max(..., 1) avoids a ZeroDivisionError on an empty loader
    train_loss = total_loss / max(total, 1)
    train_acc = correct / max(total, 1)

    val_loss = val_total_loss / max(val_total, 1)
    val_acc = val_correct / max(val_total, 1)

    logs.append({'Epoch': epoch+1,
                 'Loss': train_loss,
                 'Accuracy' : train_acc,
                 'VAL_Loss': val_loss,
                 'VAL_Accuracy' : val_acc,
            })

    print(f'Epoch {epoch + 1}/{epochs} - Train loss: {train_loss:.4f} - Train acc: {train_acc:.4f} - Val loss: {val_loss:.4f} - Val acc: {val_acc:.4f}')

    # Checkpoint every third epoch, and always after the last one so the final
    # weights and logs are not lost when `epochs` is not a multiple of 3.
    if (epoch + 1) % 3 == 0 or (epoch + 1) == epochs:
        ckpt_path = f"/content/drive/MyDrive/Graduation Project/Logs/MCCT-6_OS/MCCT-6_OS_ckpt_{epoch+1}.pt"
        torch.save(model.state_dict(), ckpt_path)
        log_df = pd.DataFrame(logs)
        log_df.to_csv(f'/content/drive/MyDrive/Graduation Project/Logs/MCCT-6_OS/MCCT-6_OS_log_{epoch+1}.csv', index=False)
        # The optimizer state is overwritten in place at every checkpoint
        torch.save(optimizer.state_dict(), "/content/drive/MyDrive/Graduation Project/Logs/MCCT-6_OS/optimizer_state_dict.pt")


print('Finished Training')

# Test

In [None]:
# Checkpoint to evaluate
path = "/content/drive/MyDrive/Graduation Project/Logs/VGGT-1/VGGT-1_ckpt_77.pt"

model = VCCT()

# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime
state_dict = torch.load(path, map_location=device)

# strict=False tolerates key mismatches; report them so that evaluating with
# partly uninitialised weights does not go unnoticed
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")
model.to(device)

# Parameter counts: all parameters vs. those that receive gradients
total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params}, Total trainable params: {total_trainable_params}")

In [None]:
# Loss function and optimizer
# CrossEntropyLoss with default arguments returns the mean loss over the batch
criterion = nn.CrossEntropyLoss()
# NOTE(review): the evaluation loop in this section never steps the optimizer;
# it is presumably only kept for parity with the training section.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Fix the global torch seed (affects the shuffling order of the loader below)
torch.manual_seed(42)

# Test pipeline: float conversion and per-channel normalisation, no augmentation.
# NOTE(review): ToDtype is called without scale=True, so pixel values presumably
# stay in the 0-255 range before Normalize - confirm this matches the pipeline
# the evaluated checkpoint was trained with.
test_transforms = v2.Compose([
    v2.ToDtype(torch.float32),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_data = AffectNet(test_annotations, test_images, test_transforms)

# NOTE(review): shuffle=True does not change the aggregated test metrics.
test_loader = DataLoader(test_data,
                          batch_size=32,
                          shuffle=True
                          )

In [None]:
# Set up logging list
test_logs = []

# Test phase
test_total_loss = 0.0
test_correct = 0
test_total = 0
model.eval()
y_pred = []  # predicted class ids, in loader order
y_true = []  # ground-truth class ids, in loader order
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # criterion returns the batch mean, so weight it by the batch size;
        # dividing by the sample count below then gives the per-sample average
        # (previously the reported loss was ~1/batch_size too small).
        batch_count = targets.shape[0]
        test_total_loss += loss.item() * batch_count
        test_predicted = outputs.argmax(dim=1)
        y_pred.extend(test_predicted.cpu().tolist())
        y_true.extend(targets.cpu().tolist())
        test_correct += (test_predicted == targets).sum().item()
        test_total += batch_count


# max(..., 1) avoids a ZeroDivisionError on an empty loader
test_loss = test_total_loss / max(test_total, 1)
test_acc = test_correct / max(test_total, 1)

print(f'Test loss: {test_loss:.4f} - Test acc: {test_acc:.4f}')

print('Finished Evaluating')

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sn

# Display names, indexed by class id
# (assumes the dataset's class ids are 0..7 in this order - TODO confirm)
labels = ["Neutral", "Happiness", "Sadness", "Surprise", "Fear", "Disgust", "Anger",
"Contempt"]

# Pin the class ids so the matrix is always len(labels) x len(labels), even
# when a class never occurs in y_true / y_pred; otherwise the DataFrame
# construction below fails on a shape mismatch.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

# Row-normalise (each row sums to 1); rows without any true sample stay at 0
# instead of triggering a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

df_cm = pd.DataFrame(cm_normalized, index=labels, columns=labels)
plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True)
plt.savefig('output.png')