In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.optim.lr_scheduler import StepLR

# Where the magic happens...
# torch.manual_seed(3456)

# Needed to download MNIST dataset without HTTP Error
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

In [3]:
class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1,  6, 5),
            nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier =  nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N),
            nn.Softmax(1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x

In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True,transform=transform)
mnist_test_data = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True,transform=transform)

kwargs = {'batch_size': 1}

0.1%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw


113.5%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz




Extracting ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw
Processing...
Done!


In [6]:
# ---------- train_data ----------
with open('train_data.txt') as f:
    train_data = f.readlines()
    
# Strip new lines
train_data = [d.strip() for d in train_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
train_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in train_data]

# ---------- test data ----------
with open('test_data.txt') as f:
    test_data = f.readlines()
    
# Strip new lines
test_data = [d.strip() for d in test_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
test_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in test_data]

# ---------- network and optimizer ----------
model = MNIST_Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
# Tensorize
train_data = torch.tensor(train_data)

# Batch
train_data = [train_data[i: i+ kwargs['batch_size']] for i in range(0, len(train_data), kwargs['batch_size'])]
train_data = torch.stack(train_data[:-1])

# Tensorize
test_data = torch.tensor(test_data)

# Batch
test_data = [test_data[i: i+kwargs['batch_size']] for i in range(0, len(test_data), kwargs['batch_size'])]
test_data = torch.stack(test_data[:-1])

In [8]:
def test():
    model.eval()

    total = 0
    correct = 0
    for j, test_batch in enumerate(test_data):
        idx1, idx2, summation = test_batch[0]
        X1 = mnist_test_data[idx1][0].unsqueeze(0)
        X2 = mnist_test_data[idx2][0].unsqueeze(0)

        output1 = model(X1)
        output2 = model(X2)

        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        correct += (summation == (pred1 + pred2)).sum()
        total += len(test_batch)

    print('Test Accuracy: {}/{} ({:.0f}%)\n'.format(correct, total, 100. * correct / total)) 

In [9]:
def brute_force(output1, output2, summation):
    combinations =  torch.cartesian_prod(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
    probs = []
    for combination in combinations:
        probs += [output1[0][combination[0]] * output2[0][combination[1]]]
    probs = torch.stack(probs)
    indices = (combinations[:, 0] + combinations[:, 1] == summation).nonzero(as_tuple=True)[0]
    return -torch.log(probs[indices].sum())

In [11]:
from tqdm import tqdm

NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
      
    # train
    for i, batch in enumerate(tqdm(train_data)):

        model.train()
        optimizer.zero_grad()
        
        idx1, idx2, summation = batch[0]
        X1 = mnist_train_data[idx1][0].unsqueeze(0)
        X2 = mnist_train_data[idx2][0].unsqueeze(0)
        output1 = model(X1)
        output2 = model(X2)
        
        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        
        closs = brute_force(output1, output2, summation)
            
        closs.backward()
        optimizer.step()
        
        if i % 1000 == 0 and i != 0:
            test()

  3%|▎         | 1018/29999 [00:17<1:08:58,  7.00it/s]

Test Accuracy: 3474/4999 (69%)



  7%|▋         | 2017/29999 [00:34<1:06:54,  6.97it/s]

Test Accuracy: 4388/4999 (88%)



 10%|█         | 3014/29999 [00:51<1:04:40,  6.95it/s]

Test Accuracy: 4516/4999 (90%)



 13%|█▎        | 4013/29999 [01:09<1:02:17,  6.95it/s]

Test Accuracy: 4609/4999 (92%)



 17%|█▋        | 5013/29999 [01:26<59:54,  6.95it/s]  

Test Accuracy: 4644/4999 (93%)



 20%|██        | 6012/29999 [01:44<58:44,  6.81it/s]  

Test Accuracy: 4699/4999 (94%)



 23%|██▎       | 7012/29999 [02:02<1:05:10,  5.88it/s]

Test Accuracy: 4675/4999 (94%)



 27%|██▋       | 8009/29999 [02:21<1:13:30,  4.99it/s]

Test Accuracy: 4607/4999 (92%)



 30%|███       | 9016/29999 [02:40<59:58,  5.83it/s]  

Test Accuracy: 4749/4999 (95%)



 33%|███▎      | 10012/29999 [02:57<47:55,  6.95it/s]  

Test Accuracy: 4732/4999 (95%)



 37%|███▋      | 11015/29999 [03:15<47:13,  6.70it/s]  

Test Accuracy: 4589/4999 (92%)



 40%|████      | 12015/29999 [03:34<47:06,  6.36it/s]  

Test Accuracy: 4768/4999 (95%)



 43%|████▎     | 13011/29999 [03:53<41:58,  6.75it/s]

Test Accuracy: 4638/4999 (93%)



 47%|████▋     | 14010/29999 [04:12<43:38,  6.11it/s]  

Test Accuracy: 4747/4999 (95%)



 50%|█████     | 15013/29999 [04:30<38:24,  6.50it/s]

Test Accuracy: 4707/4999 (94%)



 53%|█████▎    | 16010/29999 [04:49<36:12,  6.44it/s]

Test Accuracy: 4768/4999 (95%)



 57%|█████▋    | 17013/29999 [05:07<36:22,  5.95it/s]

Test Accuracy: 4814/4999 (96%)



 60%|██████    | 18012/29999 [05:25<31:27,  6.35it/s]

Test Accuracy: 4780/4999 (96%)



 63%|██████▎   | 19012/29999 [05:44<31:02,  5.90it/s]

Test Accuracy: 4765/4999 (95%)



 67%|██████▋   | 20012/29999 [06:03<27:51,  5.97it/s]

Test Accuracy: 4798/4999 (96%)



 70%|███████   | 21016/29999 [06:22<23:15,  6.44it/s]

Test Accuracy: 4737/4999 (95%)



 73%|███████▎  | 22015/29999 [06:41<21:19,  6.24it/s]

Test Accuracy: 4739/4999 (95%)



 77%|███████▋  | 23016/29999 [07:00<17:53,  6.51it/s]

Test Accuracy: 4517/4999 (90%)



 80%|████████  | 24016/29999 [07:18<15:45,  6.33it/s]

Test Accuracy: 4779/4999 (96%)



 83%|████████▎ | 25012/29999 [07:37<12:59,  6.40it/s]

Test Accuracy: 4808/4999 (96%)



 87%|████████▋ | 26013/29999 [07:55<10:36,  6.27it/s]

Test Accuracy: 4747/4999 (95%)



 90%|█████████ | 27009/29999 [08:14<07:51,  6.34it/s]

Test Accuracy: 4824/4999 (96%)



 93%|█████████▎| 28011/29999 [08:33<05:12,  6.37it/s]

Test Accuracy: 4802/4999 (96%)



 97%|█████████▋| 29014/29999 [08:51<02:36,  6.28it/s]

Test Accuracy: 4815/4999 (96%)



100%|██████████| 29999/29999 [09:04<00:00, 55.07it/s]
