In [2]:
import torch
import torchvision # provide access to datasets, models, transforms, utils, etc
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import random


# Autograd is already enabled by default in PyTorch; this call only makes it explicit.
torch.set_grad_enabled(True)

# Seed every RNG module the notebook imports so repeated runs are comparable:
# `random` draws the second summand in the dataset, `torch` drives weight
# initialisation; numpy is not used for randomness here but is seeded for completeness.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Held-out MNIST split (train=False). Each image is converted to a tensor.
_test_transform = transforms.Compose([transforms.ToTensor()])
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=_test_transform,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
# Training MNIST split (train=True). Each image is converted to a tensor.
_train_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=_train_transform,
    download=True,
)

In [134]:
class Network(nn.Module):
  """Two-headed CNN.

  Head 1 (``out1``) maps an image to 10 digit logits.
  Head 2 (``out2``) maps those 10 logits concatenated with a 10-way one-hot
  second number to 19 logits, one per possible sum 0..18.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
    # For a 28x28 input: conv5 -> 24, pool2 -> 12, conv5 -> 8, pool2 -> 4.
    self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=120)

    self.out1 = nn.Linear(in_features=120, out_features=10)
    # 20 inputs = 10 digit logits + 10-way one-hot second number;
    # 19 outputs = every possible sum 0..18.
    self.out2 = nn.Linear(in_features=20, out_features=19)

  def forward(self, image, rand_num):
    """Run both heads.

    Args:
      image: image batch, e.g. (batch, 1, 28, 28); the spatial size must
        reduce to 4x4 after the two conv/pool stages to match ``fc1``.
      rand_num: (batch, 10) one-hot of the second number. Any dtype is
        accepted (the dataset yields int64 from ``F.one_hot``); it is cast to
        the logits' dtype before concatenation.

    Returns:
      Tuple ``(digit_logits, sum_logits)`` of shapes (batch, 10) and (batch, 19).
    """
    # conv1 -> relu -> 2x2 max-pool
    x = self.conv1(image)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)

    # conv2 -> relu -> 2x2 max-pool
    x = self.conv2(x)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)

    # Flatten per sample. Unlike reshape(-1, 512), this never silently folds a
    # wrong-sized input into a different batch size; fc1 raises instead.
    x = torch.flatten(x, start_dim=1)

    # fc1 layer
    x = self.fc1(x)
    x = F.relu(x)

    # Digit head
    x = self.out1(x)

    # Sum head: explicit cast instead of relying on torch.cat type promotion
    # between float logits and an integer one-hot.
    sum_in = torch.cat((x, rand_num.to(dtype=x.dtype)), dim=1)
    sum_out = self.out2(sum_in)
    return x, sum_out

In [135]:
# Seed right before construction so the initial weights are identical on every
# run, no matter how many RNG draws earlier cells made. Re-running this cell
# creates a brand-new, untrained model.
INIT_SEED = 0
torch.manual_seed(INIT_SEED)
network = Network()

In [115]:
# Layer summary: nn.Module's repr lists each registered submodule.
print(repr(network))

Network(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=120, bias=True)
  (out1): Linear(in_features=120, out_features=10, bias=True)
  (out2): Linear(in_features=20, out_features=19, bias=True)
)


### Create Dataset

In [139]:
class MNISTRandDataset(Dataset):
    """Pairs each MNIST sample with a random second digit and their sum.

    Each item is ``(image, rand_num_tensor, label, sum_label)``:
    ``rand_num_tensor`` is the 10-way one-hot (int64, as returned by
    ``F.one_hot``) of a digit drawn from 0..9, and
    ``sum_label = label + rand_num`` (0..18 when labels are 0..9).

    Args:
        mnist: indexable dataset yielding ``(image, label)`` pairs, e.g.
            ``torchvision.datasets.MNIST``.
        seed: ``None`` (default, original behaviour) draws a fresh digit from
            the global ``random`` module on every access, so the same index
            yields different pairs across epochs/evaluations. An int instead
            pre-draws one digit per index from a private ``random.Random(seed)``,
            making the pairing fixed and reproducible (useful for a stable
            test set).
    """

    def __init__(self, mnist, seed=None):
        self.mnist = mnist
        if seed is None:
            self._rand_nums = None
        else:
            rng = random.Random(seed)  # private RNG: does not disturb global `random`
            self._rand_nums = [rng.randint(0, 9) for _ in range(len(mnist))]

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        image, label = self.mnist[idx]

        if self._rand_nums is None:
            rand_num = random.randint(0, 9)
        else:
            rand_num = self._rand_nums[idx]

        rand_num_tensor = F.one_hot(torch.tensor(rand_num), num_classes=10)
        sum_label = label + rand_num

        return image, rand_num_tensor, label, sum_label

In [None]:
def get_num_correct(img_out, labels, sum_out, sum_label):
  """Count argmax hits for both heads.

  Returns a tuple of Python ints: (correct digit predictions, correct sum predictions).
  """
  def _count_hits(logits, targets):
    predicted = logits.argmax(dim=1)
    return (predicted == targets).sum().item()

  digit_hits = _count_hits(img_out, labels)
  sum_hits = _count_hits(sum_out, sum_label)
  return digit_hits, sum_hits

In [140]:
# Wrap both splits so every item also carries a random second digit and the sum label.
train_ds, test_ds = (MNISTRandDataset(split) for split in (train_set, test_set))

In [142]:
# Training hyperparameters (same values as the recorded run).
BATCH_SIZE = 100
LEARNING_RATE = 0.01
NUM_EPOCHS = 1

# shuffle=True: DataLoader defaults to shuffle=False, which visits the samples
# in the same fixed order every epoch.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=LEARNING_RATE)

# NOTE: this continues training whatever weights `network` currently holds;
# re-run the cell that constructs `network` first to start from scratch.
network.train()

for epoch in range(NUM_EPOCHS):

    total_loss = 0.0
    total_correct_image = 0
    total_correct_sum = 0

    for images, rand_num_tensor, labels, sum_label in train_loader:
        img_out, sum_out = network(images, rand_num_tensor)

        # Joint objective: digit classification + sum classification.
        loss1 = F.cross_entropy(img_out, labels)
        loss2 = F.cross_entropy(sum_out, sum_label)
        loss = loss1 + loss2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `total_loss` is a sum of per-batch mean losses.
        total_loss += loss.item()
        correct_image, correct_sum = get_num_correct(img_out, labels, sum_out, sum_label)
        total_correct_image += correct_image
        total_correct_sum += correct_sum

    print(
        "epoch", epoch,
        "total_correct_image:", total_correct_image,
        "total_correct_sum:", total_correct_sum,
        "loss:", total_loss
    )

epoch 0 total_correct_image: 57728 total_correct_sum: 17315 loss: 1228.9090433120728


In [143]:
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=100)

total_correct_image = 0
total_correct_sum = 0

# eval(): inference mode for layers that care (none in Network today, but
# harmless and future-proof). no_grad(): grad mode is enabled globally in this
# notebook, so without it autograd would record a graph for every test batch.
# NOTE(review): test_ds draws a new random second digit on every access, so
# total_correct_sum varies between evaluations unless the pairing is fixed.
network.eval()
with torch.no_grad():
    for images, rand_num_tensor, labels, sum_label in test_loader:
        img_out, sum_out = network(images, rand_num_tensor)

        correct_image, correct_sum = get_num_correct(img_out, labels, sum_out, sum_label)
        total_correct_image += correct_image
        total_correct_sum += correct_sum

print(
    "total_correct_image:", total_correct_image,
    "total_correct_sum:", total_correct_sum
)

total_correct_image: 9844 total_correct_sum: 3531


In [146]:
# Use the counts from the evaluation cell instead of numbers copied from an
# earlier output, which go stale on any re-run.
n_test = len(test_set)
print(f'Image recognition accuracy: {total_correct_image/n_test} , Sum Prediction accuracy: {total_correct_sum/n_test}')

Image recognition accuracy: 0.9844 , Sum Prediction accuracy: 0.3531
