<a href="https://colab.research.google.com/github/pandian-raja/EVA8/blob/main/Session_2_5_PyTorch_101_assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

randomNumber = torch.tensor([0,1,2,3,4,5,6,7,8,9])

In [3]:
class TestDataSet(Dataset):
    def __init__(self):
        self.data = torchvision.datasets.MNIST('/content/mnist', train=False, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
        
        self.random = torch.randint(0, 10, (len(self.data),))
        
    def __getitem__(self, index):
        image, label = self.data[index]
        randomInput = randomNumber.eq(self.random[index]).long()
        return image, label, randomInput

    def __len__(self):
        return len(self.data)

In [4]:
class TrainDataSet(Dataset):
    def __init__(self):
        self.data = torchvision.datasets.MNIST('/content/mnist', train=True, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
        self.random = torch.randint(0, 10, (len(self.data),))
        
    def __getitem__(self, index):
        image, label = self.data[index]
        randomInput = randomNumber.eq(self.random[index]).long()
        return image, label, randomInput

    def __len__(self):
        return len(self.data)

In [5]:
class Network(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3)
        
        self.fc1 = nn.Linear(in_features=128*9*9,out_features=64)
        self.fc2 = nn.Linear(in_features=64,out_features=10)
        
        self.fc3 = nn.Linear(in_features=10,out_features=20)

        self.out = nn.Linear(in_features=20,out_features=19)

    def forward(self, image, random):

        x = image
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) 
        x = self.conv4(x)
        x = F.relu(x)
        x = x.reshape(-1, 128*9*9)
        # x = x.reshape(1, -1)
        x = self.fc1(x)
        x = F.relu(x)
        mnist_output = self.fc2(x)
        x = mnist_output+random
        x = self.fc3(x)
        x = F.relu(x)
        x = self.out(x)
        mnist_output = F.softmax(mnist_output, dim=1)
        x = F.softmax(x, dim=1)
        return mnist_output, x

In [6]:
def get_num_correct(images, labels, random, random_label):
    
    return images.argmax(dim=1).eq(labels).sum().item(), random.argmax(dim=1).eq(random_label).sum().item()


In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network  = Network()
network.to(device)
testDataset = TestDataSet()
trainDataset = TrainDataSet()
testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=200, shuffle= True)
trainDataLoader = torch.utils.data.DataLoader(trainDataset, batch_size=200, shuffle= True)
optimizer = optim.Adam(network.parameters(), lr=0.001)
totalEpoch = 10
for epoch in range(totalEpoch):
    total_mnist_loss = 0
    total_mnist_correct = 0
    total_random_correct = 0
    total_random_loss = 0
    for batch in trainDataLoader:
        images, labels, randoms = batch
        images, labels, randoms = images.to(device), labels.to(device), randoms.to(device)
        random_labels = labels+randoms.argmax(dim=1)
        mnist_predict, random_predict = network(images, randoms)
        mnist_loss = F.cross_entropy(mnist_predict, labels)
        random_loss = F.cross_entropy(random_predict, random_labels)
        total_mnist_loss += mnist_loss.item()
        total_random_loss += random_loss.item()
        loss = mnist_loss+random_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result = get_num_correct(mnist_predict,labels,random_predict,random_labels)
        total_mnist_correct += result[0]
        total_random_correct += result[1]

    mnist_accuracy = total_mnist_correct/len(trainDataset)
    random_accuracy = total_random_correct/len(trainDataset)
    print(
        "epoch:", epoch, 
        "MNIST { Correct:", total_mnist_correct, 
        "Accuracy: %.2f" %mnist_accuracy,
        "Loss: %.2f" % total_mnist_loss,
        "} {RANDOM Correct:", total_random_correct, 
        "Accuracy: %.2f" %random_accuracy,
        "Loss: %.2f" % total_random_loss,
        "} { Total Correct:", total_mnist_correct+total_random_correct, 
        "Accuracy: %.2f" %((random_accuracy+mnist_accuracy)/2),
        "loss: %.2f }" % (total_mnist_loss+total_random_loss)
    )

epoch: 0 MNIST { Correct: 47015 Accuracy: 0.78 Loss: 503.20 } {RANDOM Correct: 5809 Accuracy: 0.10 Loss: 874.50 } { Total Correct: 52824 Accuracy: 0.44 loss: 1377.71 }
epoch: 1 MNIST { Correct: 53535 Accuracy: 0.89 Loss: 470.78 } {RANDOM Correct: 6510 Accuracy: 0.11 Loss: 870.80 } { Total Correct: 60045 Accuracy: 0.50 loss: 1341.58 }
epoch: 2 MNIST { Correct: 53823 Accuracy: 0.90 Loss: 469.07 } {RANDOM Correct: 6872 Accuracy: 0.11 Loss: 869.49 } { Total Correct: 60695 Accuracy: 0.51 loss: 1338.56 }
epoch: 3 MNIST { Correct: 54059 Accuracy: 0.90 Loss: 467.70 } {RANDOM Correct: 7451 Accuracy: 0.12 Loss: 868.15 } { Total Correct: 61510 Accuracy: 0.51 loss: 1335.84 }
epoch: 4 MNIST { Correct: 54148 Accuracy: 0.90 Loss: 467.02 } {RANDOM Correct: 7918 Accuracy: 0.13 Loss: 866.83 } { Total Correct: 62066 Accuracy: 0.52 loss: 1333.85 }
epoch: 5 MNIST { Correct: 54213 Accuracy: 0.90 Loss: 466.63 } {RANDOM Correct: 8609 Accuracy: 0.14 Loss: 864.69 } { Total Correct: 62822 Accuracy: 0.52 loss: 13

In [16]:
totalEpoch = 20
for epoch in range(totalEpoch):
    total_mnist_loss = 0
    total_mnist_correct = 0
    total_random_correct = 0
    total_random_loss = 0
    for batch in trainDataLoader:
        images, labels, randoms = batch
        images, labels, randoms = images.to(device), labels.to(device), randoms.to(device)
        random_labels = labels+randoms.argmax(dim=1)
        mnist_predict, random_predict = network(images, randoms)
        mnist_loss = F.cross_entropy(mnist_predict, labels)
        random_loss = F.cross_entropy(random_predict, random_labels)
        total_mnist_loss += mnist_loss.item()
        total_random_loss += random_loss.item()
        loss = mnist_loss+random_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result = get_num_correct(mnist_predict,labels,random_predict,random_labels)
        total_mnist_correct += result[0]
        total_random_correct += result[1]

    mnist_accuracy = total_mnist_correct/len(trainDataset)
    random_accuracy = total_random_correct/len(trainDataset)
    print(
        "epoch:", epoch, 
        "MNIST { Correct:", total_mnist_correct, 
        "Accuracy: %.2f" %mnist_accuracy,
        "Loss: %.2f" % total_mnist_loss,
        "} {RANDOM Correct:", total_random_correct, 
        "Accuracy: %.2f" %random_accuracy,
        "Loss: %.2f" % total_random_loss,
        "} { Total Correct:", total_mnist_correct+total_random_correct, 
        "Accuracy: %.2f" %((random_accuracy+mnist_accuracy)/2),
        "loss: %.2f }" % (total_mnist_loss+total_random_loss)
    )

epoch: 0 MNIST { Correct: 59746 Accuracy: 1.00 Loss: 440.16 } {RANDOM Correct: 13034 Accuracy: 0.22 Loss: 847.30 } { Total Correct: 72780 Accuracy: 0.61 loss: 1287.46 }
epoch: 1 MNIST { Correct: 59791 Accuracy: 1.00 Loss: 439.99 } {RANDOM Correct: 14445 Accuracy: 0.24 Loss: 841.36 } { Total Correct: 74236 Accuracy: 0.62 loss: 1281.34 }
epoch: 2 MNIST { Correct: 59815 Accuracy: 1.00 Loss: 439.82 } {RANDOM Correct: 15937 Accuracy: 0.27 Loss: 834.59 } { Total Correct: 75752 Accuracy: 0.63 loss: 1274.41 }
epoch: 3 MNIST { Correct: 59832 Accuracy: 1.00 Loss: 439.76 } {RANDOM Correct: 17406 Accuracy: 0.29 Loss: 827.43 } { Total Correct: 77238 Accuracy: 0.64 loss: 1267.19 }
epoch: 4 MNIST { Correct: 59854 Accuracy: 1.00 Loss: 439.62 } {RANDOM Correct: 18932 Accuracy: 0.32 Loss: 820.26 } { Total Correct: 78786 Accuracy: 0.66 loss: 1259.88 }
epoch: 5 MNIST { Correct: 59863 Accuracy: 1.00 Loss: 439.58 } {RANDOM Correct: 20141 Accuracy: 0.34 Loss: 813.84 } { Total Correct: 80004 Accuracy: 0.67 lo