In [233]:
#Do all Installations
!pip install torch
!pip install torchsummary



In [234]:
#Do all Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision import datasets
import matplotlib.pyplot as plt
from torchsummary import summary

In [235]:
#Do all Downloads
# MNIST training split. ToTensor converts each image to a float tensor, then
# Normalize standardises it with the commonly used MNIST mean / std values.
mnist = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

In [236]:
#Create combined dataset
import torch.nn.functional as F


def build_combined_dataset(base_dataset, num_digits=10):
  """Pair every (image, label) sample with each possible second digit.

  For each sample of ``base_dataset`` and each ``second`` in
  ``range(num_digits)`` one 4-tuple is produced:

      (image,
       one_hot(second)          float32, length num_digits,
       one_hot(label)           int64,   length num_digits,
       one_hot(label + second)  float32, length 2 * num_digits - 1)

  so the result has ``num_digits * len(base_dataset)`` entries, with the
  ``num_digits`` tuples built from one image stored consecutively. The image
  tensor is shared (not copied) between those tuples, and the one-hot vectors
  are rows of small shared lookup tables, so treat them as read-only.

  Args:
    base_dataset: iterable of ``(image, label)`` pairs; ``label`` must be an
      int (or 0-d tensor) in ``[0, num_digits)``.
    num_digits: number of values each digit can take (10 for MNIST).

  Returns:
    list of 4-tuples as described above.

  Raises:
    ValueError: if a label lies outside ``[0, num_digits)``.
  """
  num_sums = 2 * num_digits - 1  # sums range over 0 .. 2 * (num_digits - 1)
  # Build the one-hot tables once; indexing a row is far cheaper than calling
  # torch.tensor + F.one_hot for each of the generated samples.
  # The second-digit vector is float because the network concatenates it with
  # float log-probabilities before feeding a Linear layer.
  second_onehots = F.one_hot(torch.arange(num_digits), num_classes=num_digits).float()
  label_onehots = F.one_hot(torch.arange(num_digits), num_classes=num_digits)
  sum_onehots = F.one_hot(torch.arange(num_sums), num_classes=num_sums).float()

  samples = []
  for image, label in base_dataset:
    # Plain int: avoids re-wrapping a tensor with torch.tensor() (which copies
    # and raises a UserWarning), and makes the range check below exact.
    label = int(label)
    if not 0 <= label < num_digits:
      # Negative indices would otherwise silently wrap to the last table row.
      raise ValueError(f"label {label} outside [0, {num_digits})")
    for second in range(num_digits):
      samples.append(
          (image, second_onehots[second], label_onehots[label], sum_onehots[label + second])
      )
  return samples


Combined_Dataset = build_combined_dataset(mnist)

print("Dataset size = ",len(Combined_Dataset))

  # This is added back by InteractiveShellApp.init_path()
  if sys.path[0] == '':


Dataset size =  600000


In [237]:
#Create class and iterator
class MnistNumberCombinedDataset(Dataset):
  """Map-style Dataset serving samples from an in-memory sequence.

  Args:
    data: indexable sequence of samples to serve. When omitted (``None``), the
      module-level ``Combined_Dataset`` list is used, which keeps the original
      no-argument constructor working unchanged.
  """

  def __init__(self, data=None):
    # Resolve the global only when no data is given, so the class can be reused
    # (and tested) with any sequence instead of one notebook variable.
    self.data = Combined_Dataset if data is None else data

  def __getitem__(self, index):
    # Return the stored sample unchanged.
    return self.data[index]

  def __len__(self):
    return len(self.data)

# Wrap the in-memory Combined_Dataset list so the training DataLoader can batch it.
ComData = MnistNumberCombinedDataset()

In [238]:
#Create Network class

class Network(nn.Module):
  """Two-headed model: classify a digit image, then predict its sum with a second digit.

  ``forward(image, data)`` returns ``(out1_pred, out_pred2)``:
    * ``out1_pred`` -- (batch, 10) log-probabilities for the image digit.
    * ``out_pred2`` -- (batch, 19) unbounded scores for the sum 0..18, computed
      from ``out1_pred`` concatenated with the 10-wide ``data`` vector.
  Both outputs are also stored on ``self.out1_pred`` / ``self.out_pred2``.
  """

  def __init__(self):
    super().__init__()

    # ---- Image branch (sizes for a 1 x 28 x 28 input) ----
    # input size = 28 | output size = 24 | numb_channels = 10
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)

    # input size = 24 | output size = 12 | numb_channels = 10
    self.maxpool1 = nn.MaxPool2d(2, 2)

    # input size = 12 | output size = 10 | numb_channels = 16   (12 - 3 + 1)
    self.conv2 = nn.Conv2d(in_channels=10, out_channels=16, kernel_size=3)

    # Channel-wise dropout on the conv2 feature maps (active only in train mode).
    self.conv2drop = nn.Dropout2d()

    # input size = 10 | output size = 5 | numb_channels = 16
    self.maxpool2 = nn.MaxPool2d(2, 2)

    # input = 16 channels * 5 * 5 = 400 flattened features | output size = 50
    self.fc1 = nn.Linear(in_features=400, out_features=50)

    # input size = 50 | output size = 10 (digit logits)
    self.out1 = nn.Linear(in_features=50, out_features=10)

    # ---- Sum branch ----
    # input size = 10 (digit log-probs) + 10 (second digit) | output size = 50
    self.fc2 = nn.Linear(in_features=20, out_features=50)

    # input size = 50 | output size = 70
    self.fc3 = nn.Linear(in_features=50, out_features=70)

    # input size = 70 | output size = 19 (one score per possible sum 0..18)
    self.out2 = nn.Linear(in_features=70, out_features=19)

  def forward(self, image, data):
    """Run both heads.

    Args:
      image: image batch; the layer sizes above assume (batch, 1, 28, 28).
      data: (batch, 10) encoding of the second digit (e.g. one-hot). Any
        numeric dtype is accepted; it is cast to the image-branch dtype.

    Returns:
      tuple ``(out1_pred, out_pred2)`` of shapes (batch, 10) and (batch, 19).
    """
    x = F.relu(self.maxpool1(self.conv1(image)))
    x = F.relu(self.maxpool2(self.conv2drop(self.conv2(x))))
    x = x.reshape(x.shape[0], -1)  # (batch, 400)
    x = F.relu(self.fc1(x))

    # No ReLU on the logits: clamping them at 0 blocks gradients for every
    # negative logit. dim=1 is explicit (an implicit dim is deprecated).
    self.out1_pred = F.log_softmax(self.out1(x), dim=1)

    # Cast explicitly rather than relying on torch.cat type promotion
    # (F.one_hot produces int64 vectors).
    concat = torch.cat((self.out1_pred, data.to(self.out1_pred.dtype)), dim=1)

    x = F.relu(self.fc2(concat))
    x = F.relu(self.fc3(x))

    # Linear output layer: a final ReLU would leave any unit with a negative
    # pre-activation permanently without gradient.
    self.out_pred2 = self.out2(x)

    return (self.out1_pred, self.out_pred2)

In [239]:
#initialize Network object
network = Network()

# Collect (name, shape) for every learnable tensor, then print one per line.
named_shapes = [(param_name, tensor.shape) for param_name, tensor in network.named_parameters()]
for param_name, shape in named_shapes:
  print(param_name, '\t\t', shape)

conv1.weight 		 torch.Size([10, 1, 5, 5])
conv1.bias 		 torch.Size([10])
conv2.weight 		 torch.Size([16, 10, 3, 3])
conv2.bias 		 torch.Size([16])
fc1.weight 		 torch.Size([50, 400])
fc1.bias 		 torch.Size([50])
out1.weight 		 torch.Size([10, 50])
out1.bias 		 torch.Size([10])
fc2.weight 		 torch.Size([50, 20])
fc2.bias 		 torch.Size([50])
fc3.weight 		 torch.Size([70, 50])
fc3.bias 		 torch.Size([70])
out2.weight 		 torch.Size([19, 70])
out2.bias 		 torch.Size([19])


In [240]:
def combined_loss_function(out1, labels1, out2, labels2, weight1=0.8, weight2=0.2):
  """Weighted sum of the digit-classification loss and the sum-prediction loss.

  Args:
    out1: (batch, num_classes) digit scores; raw logits or log-probabilities
      both work, since cross-entropy applies log_softmax itself.
    labels1: (batch, num_classes) one-hot digit labels.
    out2: (batch, num_sums) sum predictions.
    labels2: (batch, num_sums) sum targets, compared with mean squared error.
    weight1: weight of the classification loss (default 0.8).
    weight2: weight of the sum loss (default 0.2).

  Returns:
    0-d tensor ``weight1 * CE(out1, argmax(labels1)) + weight2 * MSE(out2, labels2)``.
  """
  # argmax already yields int64 class indices; re-wrapping it in torch.tensor()
  # only copied it and raised a UserWarning.
  targets1 = labels1.argmax(dim=1)
  loss1 = F.cross_entropy(out1, targets1)  # Cross entropy for image
  loss2 = F.mse_loss(out2, labels2.to(out2.dtype))  # MSE for sum prediction
  # With the defaults, classification gets 80% of the weight and the sum 20%.
  return weight1 * loss1 + weight2 * loss2

In [241]:
def get_num_correct(preds, labels):
  """Count rows whose highest-scoring prediction matches the label.

  Args:
    preds: (batch, num_classes) scores (logits or log-probabilities).
    labels: (batch, num_classes) one-hot labels, or (batch,) class indices.

  Returns:
    int: number of correctly predicted rows.
  """
  # One-hot rows are reduced to class indices. argmax is already int64, so no
  # torch.tensor() re-wrap (which copied and warned) is needed.
  targets = labels.argmax(dim=1) if labels.dim() > 1 else labels
  return preds.argmax(dim=1).eq(targets).sum().item()

In [None]:
#Train
# Seed shuffling and dropout so a rerun of this cell is reproducible.
torch.manual_seed(0)
torch.set_grad_enabled(True)

NUM_EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# shuffle=True matters here: the combined dataset stores the 10 samples built
# from one image consecutively, so unshuffled batches cover only ~7 images each.
train_loader = torch.utils.data.DataLoader(ComData, batch_size=BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=LEARNING_RATE)

# Make sure Dropout2d is active even if the model was put in eval mode earlier.
network.train()

for epoch in range(NUM_EPOCHS):

    total_loss = 0
    total_correct = 0

    for images, datas, labels1, labels2 in train_loader: # Get Batch
        # The second-digit one-hot may arrive as int64; the network concatenates
        # it with float log-probabilities, so cast it up front.
        datas = datas.float()

        out1, out2 = network(images, datas) # Pass Batch
        loss = combined_loss_function(out1, labels1, out2, labels2) # Calculate combined Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss += loss.item()
        total_correct += get_num_correct(out1, labels1)

    print(
        "epoch", epoch,
        "total_correct:", total_correct,
        "loss:", total_loss,
        # Fraction of combined samples whose image digit was classified correctly.
        "accuracy:", total_correct / max(len(ComData), 1),
    )


  """
  
