In [None]:
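# Uncomment when running on Google Colab: the project paths used below live
# under /content/drive, so Google Drive has to be mounted there first.
# from google.colab import drive
# drive.mount('/content/drive')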

In [None]:
# Optional sanity check: run `nvidia-smi` through the IPython shell escape to
# show the GPU(s) visible to this runtime. Nothing below depends on its output.
!nvidia-smi

In [None]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as tt
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

# Location of the project code on the mounted Google Drive.
ROOT_PATH = '/content/drive/MyDrive/QMUL/NN/code/nn_group5'
# Make the project's `my_module` package importable. The membership guard keeps
# sys.path free of duplicate entries when this cell is re-run.
if ROOT_PATH not in sys.path:
    sys.path.append(ROOT_PATH)
from my_module import functions as myf
from my_module import make_dataset as mds
# Directory for saved model weights. The trailing slash matters: file names are
# appended to it by plain string concatenation further down.
model_path = '/content/drive/MyDrive/QMUL/NN/code/model_para/'

# Project helper; presumably seeds the random number generators for
# reproducibility -- confirm in my_module.functions.
myf.torch_fix_seed()

In [None]:
# Build the transforms and the train / validation / test data loaders.

# Training pipeline: convert to tensor first, then augment on the tensor
# (random horizontal flip, random rotation of up to 10 degrees), and finally
# normalise with mean 0.5 and std 0.5.
train_transform = tt.Compose([
    tt.ToTensor(),
    tt.RandomHorizontalFlip(),
    tt.RandomRotation(10),
    tt.Normalize(0.5, 0.5),
])
# Evaluation pipeline: same conversion and normalisation, no augmentation.
validation_transform = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(0.5, 0.5),
])

# Only the training loader shuffles; validation and test share the evaluation
# transform and keep a fixed order.
train_dl = mds.get_dl(data='training', bs=64, shuffle=True, transform=train_transform)
validation_dl = mds.get_dl(data='validation', bs=64, shuffle=False, transform=validation_transform)
test_dl = mds.get_dl(data='test', bs=64, shuffle=False, transform=validation_transform)

# Device string used when the model is created: GPU if available, else CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Preview a few randomly chosen training samples with their class names.

# Integer class label -> human-readable emotion name.
dict_classes = {
    0: 'Angry',
    1: 'Disgust',
    2: 'Fear',
    3: 'Happy',
    4: 'Sad',
    5: 'Surprised',
    6: 'Neutral',
}

n_preview = 5
# Draw indices from the real dataset size instead of a hard-coded 10000, so the
# preview neither raises IndexError on a smaller dataset nor ignores the tail
# of a larger one.
n_train = len(train_dl.dataset)
fig, axes = plt.subplots(1, n_preview, tight_layout=True)
for ax in axes:
    img, label = train_dl.dataset[np.random.randint(0, n_train)]
    # squeeze() drops the singleton channel dimension so imshow gets a 2-D image.
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
    # int() makes the lookup work whether the dataset returns a Python int,
    # a NumPy integer or a 0-d tensor as the label.
    ax.set_title(dict_classes[int(label)])


In [None]:
class Block(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a skip connection.

    Both convolutions use stride 1 and padding 1 with ``channel`` input and
    output channels, so the residual branch keeps the input's shape and can be
    added to the shortcut directly.

    Args:
        channel: number of input (and output) channels.
    """

    def __init__(self, channel):
        super().__init__()

        self.conv1 = nn.Conv2d(channel, channel,
                               kernel_size=(3, 3),
                               stride=1,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(channel)
        self.gelu = nn.GELU()

        self.conv2 = nn.Conv2d(channel, channel,
                               kernel_size=(3, 3),
                               stride=1,
                               padding=1)
        self.bn2 = nn.BatchNorm2d(channel)

        self.shortcut = self._shortcut(channel, channel)

    def forward(self, x):
        """Return ``gelu(residual(x) + shortcut(x))``."""
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.gelu(h)
        h = self.conv2(h)
        h = self.bn2(h)

        shortcut = self.shortcut(x)
        y = self.gelu(h + shortcut)  # skip connection
        return y

    def _shortcut(self, channel_in, channel_out):
        """Build the shortcut path: a 1x1 projection if channels differ, else identity."""
        if channel_in != channel_out:
            return self._projection(channel_in, channel_out)
        # nn.Identity instead of a bare lambda: it is a registered submodule, so
        # the block stays picklable / deep-copyable. It has no parameters or
        # buffers, so state_dict keys are unchanged.
        return nn.Identity()

    def _projection(self, channel_in, channel_out):
        """Return a 1x1 convolution mapping ``channel_in`` to ``channel_out`` channels."""
        return nn.Conv2d(channel_in, channel_out,
                         kernel_size=(1, 1),
                         padding=0)

In [None]:
class GlobalAvgPool2d(nn.Module):
    """Global average pooling: average every channel over its spatial extent.

    Args:
        device: accepted for backward compatibility and ignored. The module
            has no parameters or buffers, so there is nothing to place on a
            device. Defaulting to ``None`` (rather than a notebook global)
            lets the class be defined regardless of cell execution order.
    """

    def __init__(self,
                 device=None):
        super().__init__()

    def forward(self, x):
        """Pool with a kernel spanning dims 2 onward and return ``(batch, channels)``."""
        pooled = F.avg_pool2d(x, kernel_size=x.size()[2:])
        # The pooled spatial dims are singleton, so flattening from dim 1 leaves
        # one value per (sample, channel).
        return pooled.flatten(1)

In [None]:
class ResNet34(nn.Module):
    """ResNet-34-style classifier for single-channel images.

    Layout:
      * stem: 3x3 convolution (1 -> 64 channels) followed by batch norm;
      * four stages of residual blocks with 3, 4, 6 and 3 blocks at
        64, 128, 256 and 512 channels;
      * between stages, a 1x1 convolution with stride 2 that doubles the
        channel count and subsamples the feature map;
      * head: global average pooling -> Linear(512, 1000) -> dropout(0.5)
        -> GELU -> Linear(1000, output_dim).

    ``forward`` returns the output of the last linear layer; no softmax is
    applied.

    Args:
        output_dim: size of the final linear layer (number of classes).
    """

    def __init__(self, output_dim):
        super().__init__()

        # Attribute names and creation order are kept stable on purpose: they
        # determine the state_dict keys and the order of parameter creation.

        # Stem
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        # Stage 1: 3 blocks @ 64 channels
        self.block1 = self._stage(64, 3)

        # Stage 2: 4 blocks @ 128 channels
        self.conv2 = self._downsample(64, 128)
        self.block2 = self._stage(128, 4)

        # Stage 3: 6 blocks @ 256 channels
        self.conv3 = self._downsample(128, 256)
        self.block3 = self._stage(256, 6)

        # Stage 4: 3 blocks @ 512 channels
        self.conv4 = self._downsample(256, 512)
        self.block4 = self._stage(512, 3)

        # Classification head
        self.avg_pool = GlobalAvgPool2d()
        self.fc = nn.Linear(512, 1000)
        self.gelu = nn.GELU()
        self.out = nn.Linear(1000, output_dim)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Run the stem, the four stages and the head on ``x``."""
        h = self.bn1(self.conv1(x))

        # (transition, blocks) per stage; the first stage has no transition.
        stages = (
            (None, self.block1),
            (self.conv2, self.block2),
            (self.conv3, self.block3),
            (self.conv4, self.block4),
        )
        for transition, blocks in stages:
            if transition is not None:
                h = transition(h)
            for block in blocks:
                h = block(h)

        h = self.avg_pool(h)
        # Dropout is applied to the fc output before the GELU.
        h = self.gelu(self.dropout(self.fc(h)))
        return self.out(h)

    def _stage(self, channel, depth):
        """Return a ModuleList of ``depth`` residual blocks with ``channel`` channels."""
        return nn.ModuleList([self._building_block(channel) for _ in range(depth)])

    @staticmethod
    def _downsample(channel_in, channel_out):
        """Return the strided 1x1 convolution used between stages."""
        return nn.Conv2d(channel_in, channel_out,
                         kernel_size=(1, 1),
                         stride=(2, 2))

    def _building_block(self,
                        channel):
        """Return one residual ``Block`` with ``channel`` channels."""
        return Block(channel)

In [None]:
# Train the model, then persist the weights and the accuracy histories.

# Directory for the saved accuracy arrays (trailing slash: names are appended
# by string concatenation, same convention as `model_path`).
result_dir = '/content/drive/MyDrive/QMUL/NN/code/result/'
# Create both output directories up front, so a missing folder cannot make
# torch.save / np.save fail only after the whole training run has finished.
os.makedirs(model_path, exist_ok=True)
os.makedirs(result_dir, exist_ok=True)

model = ResNet34(7).to(dev)  # 7 output classes
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4)
# Loss and training loop come from the project's helper module.
train_loss = myf.LabelSmoothingCrossEntropy()
train_func = myf.mixup_train_loop
epochs = 100

# Presumably returns the training / validation accuracy history -- confirm
# against my_module.functions.fit.
train_acc, val_acc = myf.fit(
    model,
    optimizer,
    epochs,
    train_dl,
    validation_dl,
    train_func=train_func,
    train_loss=train_loss,
    print_loss=True
)
# Only the state_dict (parameters and buffers) is saved, not the full model object.
torch.save(model.state_dict(), model_path + 'ResNet34.pth')
# np.save appends the `.npy` extension to these file names.
np.save(result_dir + 'ResNet34_train', train_acc)
np.save(result_dir + 'ResNet34_val', val_acc)

In [None]:
# Evaluate the trained model on the test loader with the project's test loop.
# A plain cross-entropy loss is passed here, not the label-smoothing loss used
# for training.
test_acc = myf.test_loop(test_dl, model, nn.CrossEntropyLoss())
print(test_acc)

In [None]:
# Plot the confusion matrix with the project's helper. Note that it is computed
# on the validation loader, not the test loader. `savefig=True` and `name` are
# passed through to the helper (presumably they control saving the figure --
# confirm in my_module.functions.plot_cfmat).
myf.plot_cfmat(model, validation_dl, dict_classes, savefig=True, name='ResNet34_cfmat')