# Importing Libraries

In [2]:
import sys
import pyrootutils

root = pyrootutils.setup_root(sys.path[0], pythonpath=True, cwd=True)

import timm
import torch
import shutil
import numpy as np
import torchvision
import seaborn as sns
import torch.nn as nn
import albumentations as A
import torch.optim as optim
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms


from PIL import Image
from omegaconf import OmegaConf
from torchview import draw_graph
from torchvision import datasets
from torchmetrics import Accuracy
from hydra import compose, initialize
from torch import nn, optim, utils, Tensor
from albumentations.pytorch import ToTensorV2
from torch.utils.data import random_split, DataLoader, TensorDataset


# shutil.copy("configs/config.yaml", "notebooks/config.yaml")
# with initialize(version_base=None, config_path=""):
#     config = compose(config_name="config.yaml")


In [14]:
# Creating Data-Loader for MNIST

mnist_mean = (0.1307,)
mnist_std = (0.3081,)


class custom_mnist_dataset(torchvision.datasets.MNIST):
    def __init__(self, root="data/", train=True, transform=None, target_transform=None, download=True):
        super().__init__(root, train, transform, target_transform, download)
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        # print("Type of img: ", type(img))
        img = np.array(img)
        # img = img.reshape(1, 28, 28)
        # convert image to PIL image
        img = Image.fromarray(img)
        # Add a channel dimension to the image using torch
        # img = torch.unsqueeze(img, 0)
        # # convert to float
        # img = img.float()

        if self.transform is not None:
            
            transformed = self.transform(img)
            img = transformed


        # img = torch.from_numpy(img)
        ano = np.random.randint(10)
        return (img, ano), target


class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_dir:str = 'data/', batch_size:int = 32):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size

        self.transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor(),  transforms.Normalize(mnist_mean ,mnist_std)])

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
    def prepare_data(self):
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    
    def setup(self, stage=None):

        if stage == 'fit' or stage is None:
            mnist_full = custom_mnist_dataset(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        if stage == 'test' or stage is None:
            self.mnist_test = custom_mnist_dataset(self.data_dir, train=False, transform=self.transform)

    
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

class custom_view(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class MNISTTrainingModule(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # Model 1 takes image as the input and outputs 10 classes
        # Input size of image is 28x28x1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) # after conv1 output size is 28x28x32
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) # after conv2 output size is 28x28x64
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # after max_pool1 output size is 14x14x64

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3) # after conv3 output size is 12x12x128
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3) # after conv4 output size is 10x10x256
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # after max_pool2 output size is 5x5x256
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3) # after conv5 output size is 3x3x512
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3) # after conv6 output size is 1x1x1024

        self.view = custom_view() # after view output size is 1024

        self.fc1 = nn.Linear(in_features=1024, out_features=256) # after fc1 output size is 10
        self.fc2 = nn.Linear(in_features=256, out_features=128) # after fc2 output size is 1
        self.fc3 = nn.Linear(in_features=128, out_features=10) # after fc3 output size is 10
        self.cv_model = nn.Sequential(self.conv1, self.conv2, self.max_pool1, self.conv3, self.conv4, self.max_pool2, self.conv5, self.conv6, self.view, self.fc1, self.fc2, self.fc3)

        self.cv_pass = nn.Sequential(self.conv1, self.conv2, self.max_pool1, self.conv3, self.conv4, self.max_pool2, self.conv5, self.conv6, self.view, self.fc1, self.fc2)

        # Model 2 takes one-hot encoded number (in the range (0 to 9)) as the input and outputs the sum of the label from the image and one-hot encoded number

        self.model2_fc1 = nn.Linear(in_features=1, out_features=128)
        # self.model2_fc2 = nn.Linear(in_features=128, out_features=256)
        self.final_layer = nn.Linear(in_features=256, out_features=1)
        self.accuracy = Accuracy()

        # self.num_model = nn.Sequential(self.model2_fc1, self.model2_fc2, self.model2_fc3)

        # stack the layers of self.fc2 and self.model2_fc4

        # self.final_layer = nn.Linear(in_features=256, out_features=2)

   # Calcualate custom loss, sparse categorical cross entropy for sth1 and mean squared error for sth4

    def custom_loss(self, sth1, sth4, x, y):
        loss1 = F.cross_entropy(sth1, y)
        loss2 = F.mse_loss(sth4, y.unsqueeze(1) + x)

        return loss1 + loss2
 
    def training_step(self, batch, batch_idx):
        x, y = batch
        inp1 = x[0].to(self.device)
        inp2 = x[1].unsqueeze(1).type(torch.FloatTensor).to(self.device)

        y = y.to(self.device)

        sth1 = self.cv_model(inp1)
        sth2 = self.model2_fc1(inp2)

        tmp = self.cv_pass(inp1)
        sth3 = torch.cat((tmp, sth2), 1)
        sth4 = self.final_layer(sth3)

        # calculate loss
        loss = self.custom_loss(sth1, sth4, inp2, y)



        # print("Shape of inp2 + y: ", (inp2 + y.unsqueeze(1)).shape, " shape of inp2: ", inp2.shape, " shape of y: ", y.unsqueeze(1).shape)
        # sys.exit()

        # log loss
        self.log('train_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        inp1 = x[0].to(self.device)
        inp2 = x[1].unsqueeze(1).type(torch.FloatTensor).to(self.device)

        y = y.to(self.device)

        sth1 = self.cv_model(inp1)
        sth2 = self.model2_fc1(inp2)

        tmp = self.cv_pass(inp1)
        sth3 = torch.cat((tmp, sth2), 1)
        sth4 = self.final_layer(sth3)

        # calculate loss
        loss = self.custom_loss(sth1, sth4, inp2, y)

        # log loss
        self.log('val_loss', loss)

        return loss

    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    
    def forward(self, x):
        inp1 = x[0].to(self.device)
        inp2 = x[1].unsqueeze(1).type(torch.FloatTensor).to(self.device)


        sth1 = self.cv_model(inp1)
        sth1 = F.log_softmax(sth1, dim=1)
        sth2 = self.model2_fc1(inp2)

        tmp = self.cv_pass(inp1)
        sth3 = torch.cat((tmp, sth2), 1)
        sth4 = self.final_layer(sth3)

        return sth1, sth4


    def test_step(self, batch, batch_idx):
        x, y = batch
        inp1 = x[0].to(self.device)
        inp2 = x[1].unsqueeze(1).type(torch.FloatTensor).to(self.device)

        y = y.to(self.device)

        sth1 = self.cv_model(inp1)
        sth1 = F.log_softmax(sth1, dim=1)
        sth2 = self.model2_fc1(inp2)

        tmp = self.cv_pass(inp1)
        sth3 = torch.cat((tmp, sth2), 1)
        sth4 = self.final_layer(sth3)

        # calculate acc
        acc1 = self.accuracy(sth1, y)

        # Cast sth4 and y.unsqueeze(1) + inp2 to int
        sth4 = sth4.squeeze().type(torch.IntTensor)
        ano = y.unsqueeze(1) + inp2
        ano = ano.squeeze().type(torch.IntTensor)

        # convert all negative values to 0
        sth4[sth4 < 0] = 0
        ano[ano < 0] = 0

        print("Shape of sth4: ", sth4.shape, " shape of ano: ", ano.shape)

        acc2 = self.accuracy(sth4, ano)


        # log acc
        self.log('test_acc1', acc1)
        self.log('test_acc2', acc2)

data_module = MnistDataModule(batch_size=32)
model = MNISTTrainingModule()

trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(model, data_module)


trainer.test(model, datamodule=data_module)






GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name        | Type        | Params
---------------------------------------------
0  | conv1       | Conv2d      | 320   
1  | conv2       | Conv2d      | 18.5 K
2  | max_pool1   | MaxPool2d   | 0     
3  | conv3       | Conv2d      | 73.9 K
4  | conv4       | Conv2d      | 295 K 
5  | max_pool2   | MaxPool2d   | 0     
6  | conv5       | Conv2d      | 1.2 M 
7  | conv6       | Conv2d      | 4.7 M 
8  | view        | custom_view | 0     
9  | fc1         | Linear      | 262 K 
10 | fc2         | Linear      | 32.9 K
11 | fc3         | Linear      | 1.3 K 
12 | cv_model    | Sequential  | 6.6 M 
13 | cv_pass     | Sequential  | 6.6 M 
14 | model2_fc1  | Linear      | 256   
15 | final_layer | Linear      | 257   
16 | accuracy    | Accuracy    | 0     
---------------------------------------

Epoch 0: 100%|██████████| 1876/1876 [00:36<00:00, 50.93it/s, loss=1.4, v_num=14] 

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1876/1876 [00:37<00:00, 50.64it/s, loss=1.4, v_num=14]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0:   0%|          | 0/313 [00:00<?, ?it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   0%|          | 1/313 [00:00<00:03, 89.25it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   1%|          | 2/313 [00:00<00:03, 83.25it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   1%|          | 3/313 [00:00<00:03, 80.58it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   1%|▏         | 4/313 [00:00<00:03, 79.31it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   2%|▏         | 5/313 [00:00<00:03, 78.59it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   2%|▏         | 6/313 [00:00<00:03, 78.63it/s]Shape of sth4:  torch.Size([32])  shape of ano:  torch.Size([32])
Testing DataLoader 0:   2%|▏         | 7/313 [00:00<00:03, 78.0

[{'test_acc1': 0.9646000266075134, 'test_acc2': 0.35670000314712524}]

In [15]:
model(sample_input)

(tensor([[-2.0322e+01, -1.8461e+01, -1.4367e+01, -1.3818e+01, -2.1222e+01,
          -2.1028e+01, -3.0182e+01, -3.9391e-04, -1.3978e+01, -7.8458e+00],
         [-7.0925e+00, -6.4781e+00, -2.8281e-03, -8.9611e+00, -1.5458e+01,
          -1.5939e+01, -8.0538e+00, -1.5765e+01, -1.1576e+01, -1.9908e+01],
         [-1.4896e+01, -2.6902e-04, -1.1967e+01, -1.2769e+01, -1.0093e+01,
          -1.6410e+01, -1.7167e+01, -8.4313e+00, -1.7632e+01, -1.7814e+01],
         [-1.7689e-04, -1.2115e+01, -1.0364e+01, -1.2090e+01, -1.4672e+01,
          -1.0391e+01, -9.1887e+00, -1.4387e+01, -1.6410e+01, -1.6033e+01],
         [-1.5253e+01, -1.2870e+01, -1.4658e+01, -1.5654e+01, -3.5818e-03,
          -1.8401e+01, -1.5776e+01, -8.8162e+00, -1.5213e+01, -5.6772e+00],
         [-1.5912e+01, -9.6197e-05, -1.2836e+01, -1.3269e+01, -1.0801e+01,
          -1.8322e+01, -1.8830e+01, -9.5480e+00, -1.9266e+01, -1.8337e+01],
         [-1.5841e+01, -1.0577e+01, -1.2496e+01, -1.3498e+01, -2.0000e-02,
          -1.2645e+