In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1) #input -? OUtput? RF
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv5 = nn.Conv2d(256, 512, 3)
        self.conv6 = nn.Conv2d(512, 1024, 3)
        self.linear = nn.Linear(in_features=3*3*1024, out_features = 10)
        
        # Taking a 10 class node one hot encoded integer (0-9) and convert it to 19 
        # class node output for sum because sum can range from 0 to 18

        self.linear_sum1 = nn.Linear(in_features=10, out_features = 100) 
        self.linear_sum2 = nn.Linear(in_features=100, out_features = 19)
        #self.linear_sum3 = nn.Linear(in_features=50, out_features = 19)

    def forward(self, x_image, x_rnd):
        x_image = self.pool1(F.relu(self.conv2(F.relu(self.conv1(x_image)))))
        x_image = self.pool2(F.relu(self.conv4(F.relu(self.conv3(x_image)))))
        x_image = F.relu(self.conv6(F.relu(self.conv5(x_image))))
        
        # flatten the array for FC layer
        x_image = x_image.view(x_image.size()[0], -1)
        x_image = self.linear(x_image)
        
        x_one_batch_softmax = F.log_softmax(x_image, dim=1)
        
        x_sum = self.linear_sum2(F.relu(self.linear_sum1(x_rnd)))
        x_sum = F.log_softmax(x_sum, dim=1)
        
        return x_one_batch_softmax, x_sum

In [3]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
network = Net().to(device)

#summary(model, input_size=(1, 28, 28))



In [4]:
device

device(type='cuda')

In [5]:
from torch.utils.data import Dataset
import torch.nn.functional as F

class MyDataset(Dataset):
    def __init__(self, data_array):
        self.RNdata = data_array
        self.MNISTdata = torchvision.datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ]))
        
    def __getitem__(self, index):
        r = self.RNdata[index]
        RNnumber = F.one_hot(r, num_classes = 10).type(torch.float32).requires_grad_(True)
        
        MNISTimage, MNISTlabel = self.MNISTdata[index] 
        
        SUMlabel = r.item() + MNISTlabel #SUM OF MNIST AND RND NUMBER
        
        return RNnumber, SUMlabel, MNISTimage, MNISTlabel
    
    def __len__(self):
        return len(self.RNdata)

In [6]:
my_tensor = torch.randint(low=0,high=9,size=(60000,))

myRNMNIST = MyDataset(my_tensor)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
torch.manual_seed(1)
batch_size = 128
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 

train_RNMNIST_loader = torch.utils.data.DataLoader(myRNMNIST, batch_size = batch_size, shuffle=True)

In [8]:
class MyDataset_test(Dataset):
    def __init__(self, data_array):
        self.RNdata = data_array
        self.MNISTdata = torchvision.datasets.MNIST('../data', train=False, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ]))
        
    def __getitem__(self, index):
        r = self.RNdata[index]
        RNnumber = F.one_hot(r, num_classes = 10).type(torch.float32)
        #RNnumber = RNnumber.clone().detach().type(torch.float32).requires_grad_(True)

        MNISTimage, MNISTlabel = self.MNISTdata[index] 
        
        SUMlabel = r.item() + MNISTlabel #SUM OF MNIST AND RND NUMBER
        
        return RNnumber, SUMlabel, MNISTimage, MNISTlabel
    
    def __len__(self):
        return len(self.RNdata)

In [9]:
my_tensor_test = torch.randint(low=0,high=9,size=(10000,))

myRNMNIST_test = MyDataset_test(my_tensor_test)

test_RNMNIST_loader = torch.utils.data.DataLoader(myRNMNIST_test, batch_size = batch_size, shuffle=True)

In [10]:
from tqdm import tqdm
torch.set_printoptions(edgeitems=128)

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    for batch_idx, (rn_number, rn_sum_label, mnist_image, mnist_label) in enumerate(pbar):

        rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
        mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)
        
        optimizer.zero_grad()

        output_image, output_rnd_sum = model(mnist_image, rn_number)
        
        loss_mnist = F.nll_loss(output_image, mnist_label)
        loss_sum = F.nll_loss(output_rnd_sum, rn_sum_label)
        loss = loss_mnist + loss_sum
        loss.backward()
        optimizer.step()
        
        pbar.set_description(desc= f'MNIST loss={loss_mnist.item()} SUM loss={loss_sum.item()} batch_id={batch_idx}')


def test(model, device, test_loader):
    model.eval()
    test_loss_mnist = 0
    test_loss_sum = 0
    correct_mnist = 0
    correct_sum = 0
    with torch.no_grad():
        for rn_number, rn_sum_label, mnist_image, mnist_label in test_loader:
            
            rn_number, rn_sum_label = rn_number.to(device), rn_sum_label.to(device)
            mnist_image, mnist_label = mnist_image.to(device), mnist_label.to(device)
            
            output_image, output_rnd_sum = model(mnist_image, rn_number)
            #test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            
            test_loss_mnist += F.nll_loss(output_image, mnist_label, reduction='sum').item()
            test_loss_sum += F.nll_loss(output_rnd_sum, rn_sum_label, reduction='sum').item()

            pred_mnist = output_image.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            pred_sum = output_rnd_sum.argmax(dim=1, keepdim=True)
            
            correct_mnist += pred_mnist.eq(mnist_label.view_as(pred_mnist)).sum().item()
            correct_sum += pred_sum.eq(rn_sum_label.view_as(pred_sum)).sum().item()
            #correct = correct_mnist + correct_sum

    test_loss_mnist /= len(test_loader.dataset)
    test_loss_sum /= len(test_loader.dataset)

    print('\nTest set: Average MNIST loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_mnist, correct_mnist, len(test_loader.dataset),
        100. * correct_mnist / len(test_loader.dataset)))
    
    print('\nTest set: Average SUM loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss_sum, correct_sum, len(test_loader.dataset),
        100. * correct_sum / len(test_loader.dataset)))

In [11]:
network = Net().to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

for epoch in range(1, 10):
    train(network, device, train_RNMNIST_loader, optimizer, epoch)
    test(network, device, test_RNMNIST_loader)

MNIST loss=0.10195460170507431 SUM loss=2.4684128761291504 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0526, Accuracy: 9834/10000 (98%)


Test set: Average SUM loss: 2.4593, Accuracy: 1026/10000 (10%)



MNIST loss=0.01747068576514721 SUM loss=2.348778009414673 batch_id=468: 100%|█| 



Test set: Average MNIST loss: 0.0425, Accuracy: 9859/10000 (99%)


Test set: Average SUM loss: 2.3549, Accuracy: 1039/10000 (10%)



MNIST loss=0.01674645021557808 SUM loss=2.325608015060425 batch_id=468: 100%|█| 



Test set: Average MNIST loss: 0.0329, Accuracy: 9908/10000 (99%)


Test set: Average SUM loss: 2.3281, Accuracy: 1098/10000 (11%)



MNIST loss=0.026637546718120575 SUM loss=2.3302671909332275 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0287, Accuracy: 9908/10000 (99%)


Test set: Average SUM loss: 2.3206, Accuracy: 983/10000 (10%)



MNIST loss=0.02953304350376129 SUM loss=2.3173446655273438 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0315, Accuracy: 9901/10000 (99%)


Test set: Average SUM loss: 2.3149, Accuracy: 1051/10000 (11%)



MNIST loss=0.006319492589682341 SUM loss=2.316699504852295 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0348, Accuracy: 9893/10000 (99%)


Test set: Average SUM loss: 2.3112, Accuracy: 1085/10000 (11%)



MNIST loss=0.009280189871788025 SUM loss=2.309617519378662 batch_id=468: 100%|█|



Test set: Average MNIST loss: 0.0302, Accuracy: 9913/10000 (99%)


Test set: Average SUM loss: 2.3112, Accuracy: 1113/10000 (11%)



MNIST loss=0.00025560209178365767 SUM loss=2.326871633529663 batch_id=468: 100%|



Test set: Average MNIST loss: 0.0300, Accuracy: 9918/10000 (99%)


Test set: Average SUM loss: 2.3120, Accuracy: 1034/10000 (10%)



MNIST loss=0.0012243717210367322 SUM loss=2.307300329208374 batch_id=468: 100%|█



Test set: Average MNIST loss: 0.0337, Accuracy: 9905/10000 (99%)


Test set: Average SUM loss: 2.3100, Accuracy: 1062/10000 (11%)

