<a href="https://colab.research.google.com/github/pandian-raja/EVA8/blob/main/Session_2_5_PyTorch_101_assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import packages

In [94]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary

randomNumber = torch.tensor([0,1,2,3,4,5,6,7,8,9])

Defining test Dataset: 

In [95]:
class TestDataSet(Dataset):
    def __init__(self):
        #Download all the mnist test images and convert into tensor
        self.data = torchvision.datasets.MNIST('/content/mnist', train=False, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
        #Create random number 
        self.random = torch.randint(0, 10, (len(self.data),))
        
    def __getitem__(self, index):
        image, label = self.data[index]
        # We can't return a int, So to make it flatten
        # With a random number, I used eq() to get 1D True/False values and converted them to long int.
        randomInput = randomNumber.eq(self.random[index]).long()
        return image, label, randomInput

    def __len__(self):
        return len(self.data)

Defining train Dataset: 

In [96]:
class TrainDataSet(Dataset):
    def __init__(self):
        self.data = torchvision.datasets.MNIST('/content/mnist', train=True, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
        self.random = torch.randint(0, 10, (len(self.data),))
        
    def __getitem__(self, index):
        image, label = self.data[index]
        randomInput = randomNumber.eq(self.random[index]).long()
        return image, label, randomInput

    def __len__(self):
        return len(self.data)

Creating model

In [97]:
class Network(nn.Module):

    def __init__(self):
        super().__init__()
        #3x3: SIZE { in = 28 , out= 26}, recptive field = 3, CHANNEL { in = 1, out= 16 }
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3)
        #3x3: SIZE { in = 26 , out= 24}, recptive field = 35, CHANNEL { in = 16, out= 32 }
        self.conv2 = nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3)
        #3x3: SIZE { in = 24 , out= 22}, recptive field = 7, CHANNEL { in = 32, out= 64 }
        self.conv3 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3)

        #MX: SIZE { in = 22, out = 11},  recptive field = 14, CHANNEL { in = 64, out= 64 } Implemented in forward()

        #3x3: SIZE { in = 11 , out= 9}, recptive field = 17, CHANNEL { in = 64, out= 128 }
        self.conv4 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3)

        self.fc1 = nn.Linear(in_features=128*9*9,out_features=64)
        self.fc2 = nn.Linear(in_features=64,out_features=10)

        self.fc3 = nn.Linear(in_features=10,out_features=20)
        self.out = nn.Linear(in_features=20,out_features=19)

    def forward(self, image, random):

        x = image
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) 
        x = self.conv4(x)
        x = F.relu(x)
        x = x.reshape(-1, 128*9*9)
        # x = x.reshape(1, -1)
        x = self.fc1(x)
        x = F.relu(x)
        # Now we have 10 channels and return for mnist and add with random  
        mnist_output = self.fc2(x)
        # Add mnist_output+random
        x = mnist_output+random
        # Expand channel to 20 
        x = self.fc3(x)
        x = F.relu(x)
        # Maximum value we get by adding label and random is 19. So made the output as 19 channel 
        x = self.out(x)
        mnist_output = F.softmax(mnist_output, dim=1)
        x = F.softmax(x, dim=1)
        return mnist_output, x

In [98]:
def get_num_correct(images, labels, random, random_label):
    
    return images.argmax(dim=1).eq(labels).sum().item(), random.argmax(dim=1).eq(random_label).sum().item()


In [99]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network  = Network()
network.to(device)
testDataset = TestDataSet()
trainDataset = TrainDataSet()
testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=200, shuffle= True)
trainDataLoader = torch.utils.data.DataLoader(trainDataset, batch_size=200, shuffle= True)
optimizer = optim.Adam(network.parameters(), lr=0.001)
totalEpoch = 40
for epoch in range(totalEpoch):
    total_mnist_loss = 0
    total_mnist_correct = 0
    total_random_correct = 0
    total_random_loss = 0
    for batch in trainDataLoader:
        images, labels, randoms = batch
        images, labels, randoms = images.to(device), labels.to(device), randoms.to(device)
        #truth data for random value
        random_labels = labels+randoms.argmax(dim=1)
        mnist_predict, random_predict = network(images, randoms)
        #calculate loss for mnist
        mnist_loss = F.cross_entropy(mnist_predict, labels)
        #calculate loss for random
        random_loss = F.cross_entropy(random_predict, random_labels)
        total_mnist_loss += mnist_loss.item()
        total_random_loss += random_loss.item()
        #Since we can't calculate gradiants twice (mnist_loss,random_loss), So I'm add both loss 
        loss = mnist_loss+random_loss
        optimizer.zero_grad()
        #Computes the gradient 
        loss.backward()
        optimizer.step()
        result = get_num_correct(mnist_predict,labels,random_predict,random_labels)
        total_mnist_correct += result[0]
        total_random_correct += result[1]

    mnist_accuracy = total_mnist_correct/len(trainDataset)
    random_accuracy = total_random_correct/len(trainDataset)
    print(
        "epoch:", epoch+1, 
        "MNIST {Correct:", total_mnist_correct, 
        "Accuracy: %.3f" %mnist_accuracy,
        "Loss: %.2f" % total_mnist_loss,
        "} RANDOM {Correct:", total_random_correct, 
        "Accuracy: %.3f" %random_accuracy,
        "Loss: %.2f" % total_random_loss,
        "} Total {Correct:", total_mnist_correct+total_random_correct, 
        "Accuracy: %.3f" %((random_accuracy+mnist_accuracy)/2),
        "loss: %.2f }" % (total_mnist_loss+total_random_loss)
    )

epoch: 1 MNIST {Correct: 53744 Accuracy: 0.896 Loss: 470.44 } RANDOM {Correct: 5750 Accuracy: 0.096 Loss: 874.68 } Total {Correct: 59494 Accuracy: 0.496 loss: 1345.12 }
epoch: 2 MNIST {Correct: 58533 Accuracy: 0.976 Loss: 446.07 } RANDOM {Correct: 6471 Accuracy: 0.108 Loss: 871.16 } Total {Correct: 65004 Accuracy: 0.542 loss: 1317.23 }
epoch: 3 MNIST {Correct: 59043 Accuracy: 0.984 Loss: 443.44 } RANDOM {Correct: 6768 Accuracy: 0.113 Loss: 870.09 } Total {Correct: 65811 Accuracy: 0.548 loss: 1313.53 }
epoch: 4 MNIST {Correct: 59258 Accuracy: 0.988 Loss: 442.27 } RANDOM {Correct: 7054 Accuracy: 0.118 Loss: 869.20 } Total {Correct: 66312 Accuracy: 0.553 loss: 1311.47 }
epoch: 5 MNIST {Correct: 59410 Accuracy: 0.990 Loss: 441.55 } RANDOM {Correct: 7381 Accuracy: 0.123 Loss: 868.31 } Total {Correct: 66791 Accuracy: 0.557 loss: 1309.86 }
epoch: 6 MNIST {Correct: 59459 Accuracy: 0.991 Loss: 441.26 } RANDOM {Correct: 7746 Accuracy: 0.129 Loss: 867.17 } Total {Correct: 67205 Accuracy: 0.560 lo

In [100]:
totalEpoch = 20
for epoch in range(totalEpoch):
    total_mnist_loss = 0
    total_mnist_correct = 0
    total_random_correct = 0
    total_random_loss = 0
    for batch in trainDataLoader:
        images, labels, randoms = batch
        images, labels, randoms = images.to(device), labels.to(device), randoms.to(device)
        random_labels = labels+randoms.argmax(dim=1)
        mnist_predict, random_predict = network(images, randoms)
        mnist_loss = F.cross_entropy(mnist_predict, labels)
        random_loss = F.cross_entropy(random_predict, random_labels)
        total_mnist_loss += mnist_loss.item()
        total_random_loss += random_loss.item()
        loss = mnist_loss+random_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result = get_num_correct(mnist_predict,labels,random_predict,random_labels)
        total_mnist_correct += result[0]
        total_random_correct += result[1]

    mnist_accuracy = total_mnist_correct/len(trainDataset)
    random_accuracy = total_random_correct/len(trainDataset)
    print(
        "epoch:", epoch, 
        "MNIST { Correct:", total_mnist_correct, 
        "Accuracy: %.2f" %mnist_accuracy,
        "Loss: %.2f" % total_mnist_loss,
        "} {RANDOM Correct:", total_random_correct, 
        "Accuracy: %.2f" %random_accuracy,
        "Loss: %.2f" % total_random_loss,
        "} { Total Correct:", total_mnist_correct+total_random_correct, 
        "Accuracy: %.2f" %((random_accuracy+mnist_accuracy)/2),
        "loss: %.2f }" % (total_mnist_loss+total_random_loss)
    )

epoch: 0 MNIST { Correct: 59875 Accuracy: 1.00 Loss: 439.05 } {RANDOM Correct: 23969 Accuracy: 0.40 Loss: 790.81 } { Total Correct: 83844 Accuracy: 0.70 loss: 1229.86 }
epoch: 1 MNIST { Correct: 59893 Accuracy: 1.00 Loss: 439.02 } {RANDOM Correct: 23978 Accuracy: 0.40 Loss: 790.69 } { Total Correct: 83871 Accuracy: 0.70 loss: 1229.71 }
epoch: 2 MNIST { Correct: 59886 Accuracy: 1.00 Loss: 439.02 } {RANDOM Correct: 24162 Accuracy: 0.40 Loss: 789.70 } { Total Correct: 84048 Accuracy: 0.70 loss: 1228.72 }
epoch: 3 MNIST { Correct: 59887 Accuracy: 1.00 Loss: 439.01 } {RANDOM Correct: 24304 Accuracy: 0.41 Loss: 789.06 } { Total Correct: 84191 Accuracy: 0.70 loss: 1228.07 }
epoch: 4 MNIST { Correct: 59873 Accuracy: 1.00 Loss: 439.08 } {RANDOM Correct: 24311 Accuracy: 0.41 Loss: 789.04 } { Total Correct: 84184 Accuracy: 0.70 loss: 1228.12 }
epoch: 5 MNIST { Correct: 59893 Accuracy: 1.00 Loss: 439.01 } {RANDOM Correct: 24552 Accuracy: 0.41 Loss: 787.75 } { Total Correct: 84445 Accuracy: 0.70 lo

In [101]:
for epoch in range(totalEpoch):
    total_mnist_loss = 0
    total_mnist_correct = 0
    total_random_correct = 0
    total_random_loss = 0
    for batch in trainDataLoader:
        images, labels, randoms = batch
        images, labels, randoms = images.to(device), labels.to(device), randoms.to(device)
        random_labels = labels+randoms.argmax(dim=1)
        mnist_predict, random_predict = network(images, randoms)
        mnist_loss = F.cross_entropy(mnist_predict, labels)
        random_loss = F.cross_entropy(random_predict, random_labels)
        total_mnist_loss += mnist_loss.item()
        total_random_loss += random_loss.item()
        loss = mnist_loss+random_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        result = get_num_correct(mnist_predict,labels,random_predict,random_labels)
        total_mnist_correct += result[0]
        total_random_correct += result[1]

    mnist_accuracy = total_mnist_correct/len(trainDataset)
    random_accuracy = total_random_correct/len(trainDataset)
    print(
        "epoch:", epoch+60, 
        "MNIST {Correct:", total_mnist_correct, 
        "Accuracy: %.3f" %mnist_accuracy,
        "Loss: %.2f" % total_mnist_loss,
        "} RANDOM {Correct:", total_random_correct, 
        "Accuracy: %.3f" %random_accuracy,
        "Loss: %.2f" % total_random_loss,
        "} Total {Correct:", total_mnist_correct+total_random_correct, 
        "Accuracy: %.3f" %((random_accuracy+mnist_accuracy)/2),
        "loss: %.2f }" % (total_mnist_loss+total_random_loss)
    )

epoch: 60 MNIST {Correct: 59911 Accuracy: 0.999 Loss: 438.96 } RANDOM {Correct: 26313 Accuracy: 0.439 Loss: 778.34 } Total {Correct: 86224 Accuracy: 0.719 loss: 1217.29 }
epoch: 61 MNIST {Correct: 59908 Accuracy: 0.998 Loss: 438.93 } RANDOM {Correct: 26438 Accuracy: 0.441 Loss: 777.71 } Total {Correct: 86346 Accuracy: 0.720 loss: 1216.65 }
epoch: 62 MNIST {Correct: 59909 Accuracy: 0.998 Loss: 438.93 } RANDOM {Correct: 26511 Accuracy: 0.442 Loss: 777.33 } Total {Correct: 86420 Accuracy: 0.720 loss: 1216.26 }
epoch: 63 MNIST {Correct: 59915 Accuracy: 0.999 Loss: 438.93 } RANDOM {Correct: 26499 Accuracy: 0.442 Loss: 777.36 } Total {Correct: 86414 Accuracy: 0.720 loss: 1216.29 }
epoch: 64 MNIST {Correct: 59897 Accuracy: 0.998 Loss: 439.00 } RANDOM {Correct: 26621 Accuracy: 0.444 Loss: 776.67 } Total {Correct: 86518 Accuracy: 0.721 loss: 1215.67 }
epoch: 65 MNIST {Correct: 59917 Accuracy: 0.999 Loss: 438.90 } RANDOM {Correct: 26734 Accuracy: 0.446 Loss: 776.02 } Total {Correct: 86651 Accura