In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms

In [2]:
class Net(nn.Module):
    """CNN + MLP that predicts an MNIST digit and its sum with a second digit.

    Image branch: six 3x3 convolutions with two 2x2 max-pools in between turn a
    1-channel image into a 1024 x 3 x 3 feature map, which ``linear`` projects
    to a 100-d embedding (no activation is applied to that embedding). For a
    28x28 MNIST image the spatial size goes 28 -> 14 -> 7 -> 5 -> 3, which is
    what the ``3*3*1024`` input size of ``linear`` requires.

    Fusion head: the 100-d embedding is concatenated with a 10-d vector (the
    one-hot second digit), giving 110 features that go through an MLP
    110 -> 200 -> 50 -> 29. The 29 outputs are split into 10 MNIST-class scores
    and 19 sum-class scores (sums 0..18), and each part is log-softmaxed.
    """

    def __init__(self):
        super().__init__()
        # NOTE: layers are created in a fixed order on purpose. The order
        # decides parameters() ordering and how the global RNG is consumed by
        # weight initialisation.
        # --- image branch ---
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # the unpadded convs shrink each spatial side by 2
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3)
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3)
        self.linear = nn.Linear(3 * 3 * 1024, 100)

        # --- fusion MLP: 100 (image) + 10 (one-hot digit) = 110 inputs;
        # 29 outputs = 10 MNIST classes + 19 sum classes ---
        self.linear_sum1 = nn.Linear(110, 200)
        self.linear_sum2 = nn.Linear(200, 50)
        self.linear_sum3 = nn.Linear(50, 29)

    def forward(self, x_mnist, x_sum):
        """Predict the MNIST class and the sum class.

        Args:
            x_mnist: image batch of shape (batch, 1, H, W), where H and W must
                reduce to 3x3 after ``conv6`` (28x28 for MNIST).
            x_sum: float tensor of shape (batch, 10) that is concatenated to the
                image embedding; in this notebook it is the one-hot random digit.

        Returns:
            Tuple ``(mnist_log_probs, sum_log_probs)`` of log-probabilities with
            shapes (batch, 10) and (batch, 19).
        """
        # image branch
        feats = F.relu(self.conv1(x_mnist))
        feats = self.pool1(F.relu(self.conv2(feats)))
        feats = F.relu(self.conv3(feats))
        feats = self.pool2(F.relu(self.conv4(feats)))
        feats = F.relu(self.conv5(feats))
        feats = F.relu(self.conv6(feats))
        embedding = self.linear(feats.view(feats.size(0), -1))

        # fusion head on [image embedding | digit vector]
        hidden = torch.cat((embedding, x_sum), dim=1)
        hidden = F.relu(self.linear_sum1(hidden))
        hidden = F.relu(self.linear_sum2(hidden))
        logits = self.linear_sum3(hidden)

        # first 10 scores -> MNIST class, remaining 19 -> sum class
        mnist_log_probs = F.log_softmax(logits[:, :10], dim=1)
        sum_log_probs = F.log_softmax(logits[:, 10:], dim=1)
        return mnist_log_probs, sum_log_probs

In [3]:
!pip install torchsummary
from torchsummary import summary
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")



In [4]:
# Display the device selected in the previous cell ("cuda" if available, else "cpu").
device

device(type='cuda')

In [5]:
from torch.utils.data import Dataset
import torch.nn.functional as F

class MyDataset(Dataset):
    """MNIST images paired with random digits for the digit + sum task.

    Item ``index`` combines ``data_array[index]`` (a random digit) with MNIST
    sample ``index`` and returns ``(RNnumber, SUMlabel, MNISTimage, MNISTlabel)``:

    * ``RNnumber``   -- float32 one-hot vector of length 10 for the random digit
    * ``SUMlabel``   -- random digit + MNIST label (the sum-class target)
    * ``MNISTimage`` -- normalised image tensor from torchvision MNIST
    * ``MNISTlabel`` -- MNIST class of the image

    Args:
        data_array: 1-D integer tensor of random digits; values must lie in
            [0, 9] for the 10-class one-hot encoding.
        root: directory where the MNIST files are stored.
        train: use the MNIST training split (default) or the test split.
        download: download MNIST into ``root`` if it is missing.

    Raises:
        ValueError: if ``data_array`` is longer than the MNIST split. Item ``i``
            reads MNIST sample ``i``, so the extra indices would fail later with
            an IndexError in the middle of training.
    """

    def __init__(self, data_array, root='../data', train=True, download=True):
        self.RNdata = data_array
        self.MNISTdata = torchvision.datasets.MNIST(root, train=train, download=download,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ]))
        if len(self.RNdata) > len(self.MNISTdata):
            raise ValueError(
                f'data_array has {len(self.RNdata)} entries but the MNIST split '
                f'only has {len(self.MNISTdata)} images')

    def __getitem__(self, index):
        r = self.RNdata[index]
        # The one-hot digit is model *input*, not a learnable parameter, so it
        # must not require grad. This also matches MyDataset_test.
        RNnumber = F.one_hot(r, num_classes=10).type(torch.float32)

        MNISTimage, MNISTlabel = self.MNISTdata[index]

        # The sum-class target is the random digit plus the MNIST label.
        SUMlabel = r.item() + MNISTlabel

        return RNnumber, SUMlabel, MNISTimage, MNISTlabel

    def __len__(self):
        return len(self.RNdata)

In [6]:
# RANDOMLY GENERATE AN ARRAY OF 60000 DIGITS (0-9) AND MAKE IT THE TRAIN DATA SET.
# torch.randint's `high` is exclusive, so high=10 is required to ever draw 9;
# with high=9 the sums 17 and 18 of the 19-class sum head never appear.
my_tensor = torch.randint(low=0, high=10, size=(60000,))

# Pair each random digit with the MNIST training image at the same index.
myRNMNIST_train = MyDataset(my_tensor)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
batch_size = 128

# Reshuffle on every pass so each epoch sees the (digit, image) pairs in a new order.
train_RNMNIST_loader = torch.utils.data.DataLoader(
    dataset=myRNMNIST_train,
    batch_size=batch_size,
    shuffle=True,
)

In [8]:
class MyDataset_test(Dataset):
    """MNIST *test* split paired with random digits.

    For each index it returns
    ``(one-hot random digit, random digit + MNIST label, MNIST image, MNIST label)``.
    Element ``index`` of ``data_array`` is paired with MNIST test sample ``index``.
    """

    def __init__(self, data_array):
        self.RNdata = data_array
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.MNISTdata = torchvision.datasets.MNIST(
            '../data', train=False, download=True, transform=normalize)

    def __getitem__(self, index):
        digit = self.RNdata[index]
        one_hot_digit = F.one_hot(digit, num_classes=10).float()

        image, label = self.MNISTdata[index]

        # the sum-class target is the random digit plus the MNIST label
        return one_hot_digit, digit.item() + label, image, label

    def __len__(self):
        return len(self.RNdata)

In [9]:
# RANDOMLY GENERATE AN ARRAY OF 10000 DIGITS (0-9) AND MAKE IT THE TEST DATA SET.
# torch.randint's `high` is exclusive, so high=10 is required to ever draw 9.
my_tensor_test = torch.randint(low=0, high=10, size=(10000,))

# Pair each random digit with the MNIST test image at the same index.
myRNMNIST_test = MyDataset_test(my_tensor_test)

# Evaluation only aggregates loss/accuracy over the whole set, so shuffling adds
# nothing. A fixed order keeps evaluation passes reproducible and comparable.
test_RNMNIST_loader = torch.utils.data.DataLoader(myRNMNIST_test, batch_size=batch_size, shuffle=False)

In [10]:
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Each batch is ``(rn_number, rn_sum_label, mnist_image, mnist_label)``. The
    model is optimised on the sum of two NLL losses, so it must predict the
    MNIST label and the sum label correctly at the same time.

    Args:
        model: network called as ``model(mnist_image, rn_number)`` and
            returning ``(mnist_log_probs, sum_log_probs)``.
        device: device every batch tensor is moved to.
        train_loader: iterable yielding batches as described above.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number; shown in the progress-bar description.
    """
    model.train()
    pbar = tqdm(train_loader)
    for batch_idx, (rn_number, rn_sum_label, mnist_image, mnist_label) in enumerate(pbar):
        rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
        mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)

        optimizer.zero_grad()

        output_image, output_rnd_sum = model(mnist_image, rn_number)

        loss_mnist = F.nll_loss(output_image, mnist_label)     # MNIST label loss
        loss_sum = F.nll_loss(output_rnd_sum, rn_sum_label)    # sum label loss

        # Optimise both heads jointly.
        loss = loss_mnist + loss_sum
        loss.backward()
        optimizer.step()

        pbar.set_description(
            desc=f'epoch={epoch} MNIST loss={loss_mnist.item():.4f} '
                 f'SUM loss={loss_sum.item():.4f} batch_id={batch_idx}')


def test(model, device, test_loader):
    model.eval()
    test_loss_mnist = 0
    test_loss_sum = 0
    correct_mnist = 0
    correct_sum = 0
    with torch.no_grad():
        for rn_number, rn_sum_label, mnist_image, mnist_label in test_loader:
            
            rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
            mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)
            
            output_image, output_rnd_sum = model(mnist_image, rn_number)
            
            test_loss_mnist += F.nll_loss(output_image, mnist_label, reduction='sum').item()
            test_loss_sum += F.nll_loss(output_rnd_sum, rn_sum_label, reduction='sum').item()

            pred_mnist = output_image.argmax(dim=1, keepdim=True)  
            pred_sum = output_rnd_sum.argmax(dim=1, keepdim=True)
            
            correct_mnist += pred_mnist.eq(mnist_label.view_as(pred_mnist)).sum().item()
            correct_sum += pred_sum.eq(rn_sum_label.view_as(pred_sum)).sum().item()

    test_loss_mnist /= len(test_loader.dataset)
    test_loss_sum /= len(test_loader.dataset)

    print('\nTest set: Average MNIST loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_mnist, correct_mnist, len(test_loader.dataset),
        100. * correct_mnist / len(test_loader.dataset)))
    
    print('\nTest set: Average SUM loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_sum, correct_sum, len(test_loader.dataset),
        100. * correct_sum / len(test_loader.dataset)))

In [11]:
torch.set_printoptions(edgeitems=11)

# Model and optimizer: SGD with momentum, shared by every epoch below.
network = Net().to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

# Alternate one training pass and one evaluation pass per epoch (epochs are 1-based).
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(network, device, train_RNMNIST_loader, optimizer, epoch)
    test(network, device, test_RNMNIST_loader)

MNIST loss=0.13151748478412628 SUM loss=2.334129571914673 batch_id=468: 100%|█| 



Test set: Average MNIST loss: 0.1669, Accuracy: 9526/10000 (95%)


Test set: Average SUM loss: 2.3467, Accuracy: 1181/10000 (12%)



MNIST loss=0.13226385414600372 SUM loss=1.6017498970031738 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0683, Accuracy: 9819/10000 (98%)


Test set: Average SUM loss: 1.4576, Accuracy: 4284/10000 (43%)



MNIST loss=0.04307926073670387 SUM loss=0.9536581039428711 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0352, Accuracy: 9894/10000 (99%)


Test set: Average SUM loss: 0.8700, Accuracy: 6945/10000 (69%)



MNIST loss=0.04952253773808479 SUM loss=0.37748631834983826 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0285, Accuracy: 9915/10000 (99%)


Test set: Average SUM loss: 0.4081, Accuracy: 9361/10000 (94%)



MNIST loss=0.09310784935951233 SUM loss=0.1534867137670517 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0295, Accuracy: 9910/10000 (99%)


Test set: Average SUM loss: 0.1465, Accuracy: 9789/10000 (98%)



MNIST loss=0.09910460561513901 SUM loss=0.21645613014698029 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0246, Accuracy: 9931/10000 (99%)


Test set: Average SUM loss: 0.0850, Accuracy: 9851/10000 (99%)



MNIST loss=0.010061712004244328 SUM loss=0.06789124757051468 batch_id=468: 100%|



Test set: Average MNIST loss: 0.0251, Accuracy: 9930/10000 (99%)


Test set: Average SUM loss: 0.0795, Accuracy: 9866/10000 (99%)



MNIST loss=0.011025783605873585 SUM loss=0.025114253163337708 batch_id=468: 100%



Test set: Average MNIST loss: 0.0251, Accuracy: 9933/10000 (99%)


Test set: Average SUM loss: 0.0734, Accuracy: 9882/10000 (99%)



MNIST loss=0.001581490971148014 SUM loss=0.009856115095317364 batch_id=468: 100%



Test set: Average MNIST loss: 0.0229, Accuracy: 9930/10000 (99%)


Test set: Average SUM loss: 0.0583, Accuracy: 9893/10000 (99%)



MNIST loss=0.000775829132180661 SUM loss=0.009833535179495811 batch_id=468: 100%



Test set: Average MNIST loss: 0.0228, Accuracy: 9944/10000 (99%)


Test set: Average SUM loss: 0.0577, Accuracy: 9904/10000 (99%)

