In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from random import randrange

In [3]:
# Neural network: MNIST digit classifier plus a head that predicts digit sums
class Network(nn.Module):
    """Two-headed MNIST model.

    Image head: classifies a 1x28x28 image into one of 10 digits.
    Sum head: one-hot encodes the digit predicted by the image head together
    with an extra input digit and classifies their sum into one of 19 classes
    (0 .. 18, since 9 + 9 = 18).

    Both heads return log-probabilities (``log_softmax`` over dim 1), the form
    expected by ``F.nll_loss``.
    """

    def __init__(self):
        super(Network, self).__init__()
        # ---- Image branch ----
        # 5x5 kernel, stride 1, no padding: 28x28 -> 24x24, 1 -> 10 channels
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)

        # 5x5 kernel, stride 1, no padding: 12x12 -> 8x8, 10 -> 20 channels
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 20 channels x 4 x 4 remain after the second 2x2 max-pool
        self.fc1 = nn.Linear(in_features=20 * 4 * 4, out_features=50)
        self.out = nn.Linear(in_features=50, out_features=10)

        # ---- Sum branch ----
        # Input: two concatenated 10-way one-hot vectors (20 features).
        # Hidden widths 50 -> 70 -> 40; output: 19 possible sums (0 .. 18).
        self.sum_fc1 = nn.Linear(in_features=20, out_features=50)
        self.sum_fc2 = nn.Linear(in_features=50, out_features=70)
        self.sum_fc3 = nn.Linear(in_features=70, out_features=40)
        self.sum_out = nn.Linear(in_features=40, out_features=19)

    def forward(self, x, x_int):
        """Classify the image digit and the sum of that digit with ``x_int``.

        Args:
            x: image batch; the ``reshape(-1, 20 * 4 * 4)`` below requires
                28x28 single-channel images, i.e. shape (B, 1, 28, 28).
            x_int: integer digits in [0, 9], one per image. Accepts shape (B,),
                (B, 1), or a 0-d scalar when B == 1.

        Returns:
            Tuple ``(digit_log_probs, sum_log_probs)`` with shapes (B, 10) and
            (B, 19).
        """
        # ---- Image branch ----
        x = self.conv1(x)                               # 28x28 -> 24x24
        x = F.max_pool2d(x, kernel_size=2, stride=2)    # 24x24 -> 12x12
        x = F.relu(x)                                   # 10 channels

        x = self.conv2(x)                               # 12x12 -> 8x8
        x = self.conv2_drop(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)    # 8x8 -> 4x4
        x = F.relu(x)                                   # 20 channels

        x = x.reshape(-1, 20 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        # No ReLU on the logits: clamping them at 0 would give every negatively
        # scored class the same value and zero its gradient.
        x = F.log_softmax(self.out(x), dim=1)

        # ---- Sum branch ----
        # Digit predicted by the image branch, shape (B,) for any batch size
        # (including B == 1, so no squeeze/unsqueeze special case is needed).
        # argmax is not differentiable, so the sum loss does not train the
        # image branch; it only trains the sum_* layers.
        digit_pred = x.argmax(dim=1)
        # Flatten so (B,), (B, 1) and 0-d inputs all become (B,); one_hot
        # requires int64 indices.
        digit_in = x_int.reshape(-1).long()

        # (B, 10) + (B, 10) -> (B, 20) float features
        sum_t = torch.cat(
            (F.one_hot(digit_pred, num_classes=10),
             F.one_hot(digit_in, num_classes=10)),
            dim=1,
        ).float()

        sum_t = F.relu(self.sum_fc1(sum_t))
        sum_t = F.dropout(F.relu(self.sum_fc2(sum_t)), training=self.training)
        sum_t = F.dropout(F.relu(self.sum_fc3(sum_t)), training=self.training)
        # Same reasoning as the image head: no ReLU before log_softmax.
        sum_t = F.log_softmax(self.sum_out(sum_t), dim=1)

        return x, sum_t

In [None]:
# Dataset (not a DataLoader) that extends MNIST with a random extra digit.
# Only __getitem__ is overridden; loading, transforms and length come from MNIST.
class MNISTPlusDataset(MNIST):
    """MNIST dataset that pairs every image with a random digit in [0, 9].

    ``__getitem__`` returns a nested tuple::

        ((image, rand_digit), (label, rand_digit + label))

    ``rand_digit`` is a 0-d ``torch`` tensor drawn anew on every access, so
    the same index can yield a different extra digit each time it is read.
    """

    def __getitem__(self, index):
        # Image and label exactly as the parent MNIST dataset provides them.
        sample = super().__getitem__(index)
        image, label = sample[0], sample[1]

        # Drawn with Python's ``random`` module, so torch.manual_seed does not
        # control it; seed ``random`` separately for reproducibility.
        rand_digit = torch.tensor(randrange(10))

        features = (image, rand_digit)
        targets = (label, rand_digit + label)
        return features, targets