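# Train CNN

Train a small convolutional network (PyTorch Lightning) on the grayscale CW_Dataset images (100×100, 7 classes) and evaluate it on the held-out test split.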

In [2]:
import cv2
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import seaborn as sns
import pickle as pkl
# from utils import load_data

import torch as T
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import Sequential 
from torch.utils.data import DataLoader, TensorDataset, random_split
# import torchmetrics accuracy
import torchmetrics

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from torchvision.datasets import MNIST
from torchvision import transforms

from sklearn.metrics import classification_report, confusion_matrix

### 1) Load Data

In [None]:
# `load_data` lives in the project's local `utils` module. Its import in the
# imports cell is commented out, so without this line the cell raises
# NameError on a fresh kernel (Restart & Run All).
from utils import load_data

# Dataset locations, relative to this notebook.
PATH_TRAIN_IMG = '../CW_Dataset/train/'
PATH_TEST_IMG = '../CW_Dataset/test/'
PATH_TRAIN_LABEL = '../CW_Dataset/labels/list_label_train.txt'
PATH_TEST_LABEL = '../CW_Dataset/labels/list_label_test.txt'

# load data (gray=True -> single-channel images; the channel axis is added in the next cell)
X_train_list, y_train_list = load_data(PATH_TRAIN_IMG, PATH_TRAIN_LABEL, gray=True)
X_test_list, y_test_list = load_data(PATH_TEST_IMG, PATH_TEST_LABEL, gray=True)

In [None]:
# convert to tensors
X_train = T.tensor(np.array(X_train_list), dtype=T.float)  # image should be float
y_train = T.tensor(np.array(y_train_list), dtype=T.long) - 1  # target should be long | -1 to make 0-6 range
X_test = T.tensor(np.array(X_test_list), dtype=T.float)
y_test = T.tensor(np.array(y_test_list), dtype=T.long) - 1

# Add channel dimension | image size: (channel, height, width)
X_train = X_train.unsqueeze(1)
X_test = X_test.unsqueeze(1)

# convert to tensor dataset (train, val, test)
train_dataset = TensorDataset(X_train, y_train)
val_split = int(0.25 * len(train_dataset))
train_dataset, val_dataset = random_split(train_dataset, [len(train_dataset) - val_split, val_split])
test_dataset = TensorDataset(X_test, y_test)

# create data loaders (train, val, test) | image size: (batch, channel, height, width)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

In [None]:
# Sanity check: split sizes, one sample, and the shape of one loader batch.
split_sizes = (len(train_dataset), len(val_dataset), len(test_dataset))
print('Train obs: {} - Val obs: {} - Test obs: {}'.format(*split_sizes))

first_image, first_label = train_dataset[0]
print('\nLabel:', first_label)
print('\nImage shape:', first_image.shape)  # (channel, height, width) of a single item

loader_iter = iter(train_loader)
batch_x, batch_y = next(loader_iter)  # (batch, channel, height, width) of the first batch
print('\nBatch shape:', batch_x.shape)

### 2) Train CNN

In [None]:
class LitCNN(LightningModule):
    """Three-block CNN classifier for single-channel images.

    The fully connected head is sized for 100x100 inputs. Each block
    keeps the spatial size (3x3 conv, padding 1) and then halves it with
    a 2x2 max-pool: 100 -> 50 -> 25 -> 12. The result is a 128 x 12 x 12
    feature map.

    Args:
        input_size: number of pixels per image (H * W). It is stored on
            the module but not used to size any layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # torchmetrics metrics are stateful, so train and validation each get
        # their own instance rather than sharing one. Newer torchmetrics
        # (>= 0.11) also requires `task` and `num_classes` explicitly.
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.input_size = input_size

        # CNN block - image size 100*100
        self.CNN_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # FC block - CNN output per batch: (batch, 128, 12, 12)
        self.FC_layers = nn.Sequential(
            nn.Linear(in_features=128 * 12 * 12, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.CNN_layers(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.FC_layers(x)  # no softmax: F.cross_entropy expects logits
        return x

    def configure_optimizers(self):
        optimizer = T.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Calling the metric returns the accuracy of this batch.
        self.log('train_loss', loss, prog_bar=True, on_step=True)
        self.log('train_acc', self.accuracy(y_hat, y), prog_bar=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Accumulate predictions over the whole validation epoch. The Metric
        # object itself is logged, so Lightning computes and resets it at
        # epoch end. on_step=False keeps the keys as plain 'val_loss' /
        # 'val_acc' instead of '*_step' / '*_epoch' pairs.
        self.val_accuracy.update(y_hat, y)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True, on_step=False, on_epoch=True)
        return loss

In [None]:
# model
# Seed the global RNGs before building the model so that weight
# initialisation (and training shuffling) is reproducible across runs.
pl.seed_everything(42)
model = LitCNN(input_size=100*100, num_classes=7)  # 100x100 grayscale images, 7 classes

In [None]:
# training setup: TensorBoard logs are written under ../tensorboard_logs/my_model
logger = TensorBoardLogger(save_dir='../tensorboard_logs', name='my_model')
trainer = Trainer(
    max_epochs=2,
    logger=logger,
    fast_dev_run=False,  # True would run a single debug batch instead of full epochs
)

In [None]:
# Train `model` on train_loader, using val_loader for the validation steps.
trainer.fit(model, train_loader, val_loader)

In [None]:
# Open TensorBoard inline on the directory the TensorBoardLogger writes to.
%load_ext tensorboard
%tensorboard --logdir ../tensorboard_logs

In [None]:
# test: predict the whole test split in mini-batches.
# eval() + no_grad() run pure inference, and batching avoids pushing every
# test image through the network (with an autograd graph) in one call.
model.eval()
device = next(model.parameters()).device
batch_preds, batch_labels = [], []
with T.no_grad():
    for x_batch, y_batch in test_loader:
        logits = model(x_batch.to(device))
        batch_preds.append(logits.argmax(dim=1).cpu())
        batch_labels.append(y_batch)

# numpy arrays for sklearn. Labels are rebuilt from test_loader (shuffle=False,
# so same order as y_test), which keeps this cell safe to re-run.
y_hat = T.cat(batch_preds).numpy()
y_test = T.cat(batch_labels).numpy()

In [None]:
# classification report (per-class precision / recall / F1 on the test split)
print(classification_report(y_test, y_hat))

# Confusion matrix: rows are true labels, columns are predicted labels (0-based class ids)
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(confusion_matrix(y_test, y_hat), annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set(title='Test confusion matrix', xlabel='Predicted label', ylabel='True label')
plt.show()