In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)         self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv5 = nn.Conv2d(256, 512, 3)
        self.conv6 = nn.Conv2d(512, 1024, 3)
        self.linear = nn.Linear(in_features=3*3*1024, out_features = 10)
        
        # Taking a 10 node one hot encoded integer (0-9) and 
        # pass it to a 100 node FC Layer and further convert it to 19 
        # node output for sum as sum can range from 0 to 18 

        self.linear_sum1 = nn.Linear(in_features=10, out_features = 100) 
        self.linear_sum2 = nn.Linear(in_features=100, out_features = 19)

    def forward(self, x_image, x_rnd): 
        
        # x_rnd is the one hot encoded random number
        x_image = self.pool1(F.relu(self.conv2(F.relu(self.conv1(x_image)))))
        x_image = self.pool2(F.relu(self.conv4(F.relu(self.conv3(x_image)))))
        x_image = F.relu(self.conv6(F.relu(self.conv5(x_image))))
        
        # flatten the array for FC layer
        x_image = x_image.view(x_image.size()[0], -1)
        x_image = self.linear(x_image)
        
        x_one_batch_softmax = F.log_softmax(x_image, dim=1)
        
        x_sum = self.linear_sum2(F.relu(self.linear_sum1(x_rnd)))
        x_sum = F.log_softmax(x_sum, dim=1)
        
        # Network returns the 10 class soft max form of predicted MNIST (x_one_batch_softmax)
        # and 19 class soft max form of predicted SUM (x_sum)
        return x_one_batch_softmax, x_sum

In [3]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")



In [4]:
device

device(type='cuda')

In [5]:
from torch.utils.data import Dataset
import torch.nn.functional as F

class MyDataset(Dataset):
    def __init__(self, data_array):
        self.RNdata = data_array
        self.MNISTdata = torchvision.datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ]))
        
    def __getitem__(self, index):
        r = self.RNdata[index]
        RNnumber = F.one_hot(r, num_classes = 10).type(torch.float32).requires_grad_(True)
        
        MNISTimage, MNISTlabel = self.MNISTdata[index] 
        
        # DATA IS THE NUMBER ITSELF AND THE LABEL FOR CORRECT
        # SUM PREDICTION IS THE SUM OF THE MNIST LABEL AND THE 
        # NUMBER ITSELF
        SUMlabel = r.item() + MNISTlabel 
        
        return RNnumber, SUMlabel, MNISTimage, MNISTlabel
    
    def __len__(self):
        return len(self.RNdata)

In [6]:
# RANDOMLY GENERATE AN ARRAY OF 60000 NUMBERS AND MAKE IT THE TRAIN DATA SET
my_tensor = torch.randint(low=0,high=9,size=(60000,))

myRNMNIST_train = MyDataset(my_tensor)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
batch_size = 128

train_RNMNIST_loader = torch.utils.data.DataLoader(myRNMNIST_train, batch_size = batch_size, shuffle=True)

In [8]:
class MyDataset_test(Dataset):
    def __init__(self, data_array):
        self.RNdata = data_array
        self.MNISTdata = torchvision.datasets.MNIST('../data', train=False, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ]))
        
    def __getitem__(self, index):
        r = self.RNdata[index]
        RNnumber = F.one_hot(r, num_classes = 10).type(torch.float32)

        MNISTimage, MNISTlabel = self.MNISTdata[index] 
        
        # DATA IS THE NUMBER ITSELF AND THE LABEL FOR CORRECT
        # SUM PREDICTION IS THE SUM OF THE MNIST LABEL AND THE 
        # NUMBER ITSELF
        SUMlabel = r.item() + MNISTlabel 
        
        return RNnumber, SUMlabel, MNISTimage, MNISTlabel
    
    def __len__(self):
        return len(self.RNdata)

In [9]:
# RANDOMLY GENERATE AN ARRAY OF 10000 NUMBERS AND MAKE IT THE TEST DATA SET
my_tensor_test = torch.randint(low=0,high=9,size=(10000,))

myRNMNIST_test = MyDataset_test(my_tensor_test)

test_RNMNIST_loader = torch.utils.data.DataLoader(myRNMNIST_test, batch_size = batch_size, shuffle=True)

In [10]:
from tqdm import tqdm

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    for batch_idx, (rn_number, rn_sum_label, mnist_image, mnist_label) in enumerate(pbar):

        # TRAINING ON GPU
        rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
        mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)
        
        optimizer.zero_grad()

        output_image, output_rnd_sum = model(mnist_image, rn_number)
        
        loss_mnist = F.nll_loss(output_image, mnist_label) # LOSS FOR PREDICTING MNIST LABEL
        loss_sum = F.nll_loss(output_rnd_sum, rn_sum_label)# LOSS FOR PREDICTING SUM LABEL
        
        # ADDING BOTH THE SUM AS THE NETWORK NEEDS TO BE GOOD
        # AT BOTH PREDICITNG THE MNIST LABEL AND THE SUM LABEL
        # CORRECTLY AT THE SAME TIME
        loss = loss_mnist + loss_sum 
        loss.backward()
        optimizer.step()
        
        pbar.set_description(desc= f'MNIST loss={loss_mnist.item()} SUM loss={loss_sum.item()} batch_id={batch_idx}')


def test(model, device, test_loader):
    model.eval()
    test_loss_mnist = 0
    test_loss_sum = 0
    correct_mnist = 0
    correct_sum = 0
    with torch.no_grad():
        for rn_number, rn_sum_label, mnist_image, mnist_label in test_loader:
            
            rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
            mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)
            
            output_image, output_rnd_sum = model(mnist_image, rn_number)
            
            test_loss_mnist += F.nll_loss(output_image, mnist_label, reduction='sum').item()
            test_loss_sum += F.nll_loss(output_rnd_sum, rn_sum_label, reduction='sum').item()

            pred_mnist = output_image.argmax(dim=1, keepdim=True)  
            pred_sum = output_rnd_sum.argmax(dim=1, keepdim=True)
            
            correct_mnist += pred_mnist.eq(mnist_label.view_as(pred_mnist)).sum().item()
            correct_sum += pred_sum.eq(rn_sum_label.view_as(pred_sum)).sum().item()

    test_loss_mnist /= len(test_loader.dataset)
    test_loss_sum /= len(test_loader.dataset)

    print('\nTest set: Average MNIST loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_mnist, correct_mnist, len(test_loader.dataset),
        100. * correct_mnist / len(test_loader.dataset)))
    
    print('\nTest set: Average SUM loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_sum, correct_sum, len(test_loader.dataset),
        100. * correct_sum / len(test_loader.dataset)))

In [11]:
network = Net().to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

torch.set_printoptions(edgeitems=11)


for epoch in range(1, 11):
    train(network, device, train_RNMNIST_loader, optimizer, epoch)
    test(network, device, test_RNMNIST_loader)

MNIST loss=0.08297522366046906 SUM loss=2.466980457305908 batch_id=468: 100%|█| 



Test set: Average MNIST loss: 0.0602, Accuracy: 9806/10000 (98%)


Test set: Average SUM loss: 2.4803, Accuracy: 985/10000 (10%)



MNIST loss=0.0625964105129242 SUM loss=2.3915774822235107 batch_id=468: 100%|█| 



Test set: Average MNIST loss: 0.0319, Accuracy: 9894/10000 (99%)


Test set: Average SUM loss: 2.3591, Accuracy: 1040/10000 (10%)



MNIST loss=0.006198072340339422 SUM loss=2.343538761138916 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0330, Accuracy: 9889/10000 (99%)


Test set: Average SUM loss: 2.3308, Accuracy: 1047/10000 (10%)



MNIST loss=0.010398979298770428 SUM loss=2.3249571323394775 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0352, Accuracy: 9881/10000 (99%)


Test set: Average SUM loss: 2.3202, Accuracy: 1024/10000 (10%)



MNIST loss=0.005175141151994467 SUM loss=2.300950050354004 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0332, Accuracy: 9900/10000 (99%)


Test set: Average SUM loss: 2.3167, Accuracy: 1076/10000 (11%)



MNIST loss=0.025918133556842804 SUM loss=2.3093085289001465 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0282, Accuracy: 9911/10000 (99%)


Test set: Average SUM loss: 2.3151, Accuracy: 1054/10000 (11%)



MNIST loss=0.018872041255235672 SUM loss=2.3080291748046875 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0309, Accuracy: 9894/10000 (99%)


Test set: Average SUM loss: 2.3126, Accuracy: 1091/10000 (11%)



MNIST loss=0.017793577164411545 SUM loss=2.3040051460266113 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0293, Accuracy: 9924/10000 (99%)


Test set: Average SUM loss: 2.3102, Accuracy: 1116/10000 (11%)



MNIST loss=0.002159718656912446 SUM loss=2.2874672412872314 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0359, Accuracy: 9892/10000 (99%)


Test set: Average SUM loss: 2.3097, Accuracy: 1065/10000 (11%)



MNIST loss=0.005997654050588608 SUM loss=2.3240864276885986 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0289, Accuracy: 9909/10000 (99%)


Test set: Average SUM loss: 2.3079, Accuracy: 1079/10000 (11%)

