In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torch.optim.lr_scheduler import StepLR

# Where the magic happens...
# torch.manual_seed(3456)

# Needed to download MNIST dataset without HTTP Error
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

In [2]:
class MNIST_Net(nn.Module):
    def __init__(self, N=10):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1,  6, 5),
            nn.MaxPool2d(2, 2), # 6 24 24 -> 6 12 12
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5), # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2), # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier =  nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, N),
            nn.Softmax(1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train_data = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True,transform=transform)
mnist_test_data = torchvision.datasets.MNIST(root='./MNIST', train=False, download=True,transform=transform)

kwargs = {'batch_size': 1}

In [4]:
# ---------- train_data ----------
with open('train_data.txt') as f:
    train_data = f.readlines()
    
# Strip new lines
train_data = [d.strip() for d in train_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
train_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in train_data]

# ---------- test data ----------
with open('test_data.txt') as f:
    test_data = f.readlines()
    
# Strip new lines
test_data = [d.strip() for d in test_data]

# Convert strings (e.g. "(datum_i, datum_j, sum)") to tuples of ints
test_data = [tuple(int(e) for e in d.strip("()").split(",")) for d in test_data]

# ---------- network and optimizer ----------
model = MNIST_Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Tensorize
train_data = torch.tensor(train_data)

# Batch
train_data = [train_data[i: i+ kwargs['batch_size']] for i in range(0, len(train_data), kwargs['batch_size'])]
train_data = torch.stack(train_data[:-1])

# Tensorize
test_data = torch.tensor(test_data)

# Batch
test_data = [test_data[i: i+kwargs['batch_size']] for i in range(0, len(test_data), kwargs['batch_size'])]
test_data = torch.stack(test_data[:-1])

In [6]:
def test():
    model.eval()

    total = 0
    correct = 0
    for j, test_batch in enumerate(test_data):
        idx1, idx2, summation = test_batch[0]
        X1 = mnist_test_data[idx1][0].unsqueeze(0)
        X2 = mnist_test_data[idx2][0].unsqueeze(0)

        output1 = model(X1)
        output2 = model(X2)

        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        correct += (summation == (pred1 + pred2)).sum()
        total += len(test_batch)

    print('Test Accuracy: {}/{} ({:.0f}%)\n'.format(correct, total, 100. * correct / total)) 

In [7]:
def brute_force(output1, output2, summation):
    combinations =  torch.cartesian_prod(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
    probs = []
    for combination in combinations:
        probs += [output1[0][combination[0]] * output2[0][combination[1]]]
    probs = torch.stack(probs)
    indices = (combinations[:, 0] + combinations[:, 1] == summation).nonzero(as_tuple=True)[0]
    return -torch.log(probs[indices].sum())

In [8]:
from tqdm import tqdm

NUM_EPOCHS = 1

for epoch in range(NUM_EPOCHS):
      
    # train
    for i, batch in enumerate(tqdm(train_data)):

        model.train()
        optimizer.zero_grad()
        
        idx1, idx2, summation = batch[0]
        X1 = mnist_train_data[idx1][0].unsqueeze(0)
        X2 = mnist_train_data[idx2][0].unsqueeze(0)
        output1 = model(X1)
        output2 = model(X2)
        
        pred1 = output1.argmax(dim=1, keepdim=False)
        pred2 = output2.argmax(dim=1, keepdim=False)
        
        closs = brute_force(output1, output2, summation)
            
        closs.backward()
        optimizer.step()
        
        if i % 1000 == 0 and i != 0:
            test()

  3%|▎         | 1009/29999 [00:47<6:14:02,  1.29it/s]

Test Accuracy: 2000/4999 (40%)



  7%|▋         | 2007/29999 [01:33<4:24:14,  1.77it/s]

Test Accuracy: 4446/4999 (89%)



 10%|█         | 3005/29999 [02:21<8:21:51,  1.12s/it] 

Test Accuracy: 4538/4999 (91%)



 13%|█▎        | 4003/29999 [03:06<5:06:43,  1.41it/s]

Test Accuracy: 4624/4999 (92%)



 17%|█▋        | 5005/29999 [03:53<21:07:07,  3.04s/it]

Test Accuracy: 4587/4999 (92%)



 20%|██        | 6006/29999 [04:38<5:07:32,  1.30it/s] 

Test Accuracy: 4633/4999 (93%)



 23%|██▎       | 7002/29999 [05:22<5:53:05,  1.09it/s]

Test Accuracy: 4653/4999 (93%)



 27%|██▋       | 8004/29999 [06:07<4:28:37,  1.36it/s]

Test Accuracy: 4722/4999 (94%)



 30%|███       | 9005/29999 [06:52<16:16:50,  2.79s/it]

Test Accuracy: 4489/4999 (90%)



 33%|███▎      | 10003/29999 [07:40<4:32:06,  1.22it/s]

Test Accuracy: 4659/4999 (93%)



 37%|███▋      | 11009/29999 [08:24<3:46:53,  1.39it/s]

Test Accuracy: 4731/4999 (95%)



 40%|████      | 12004/29999 [09:11<15:38:18,  3.13s/it]

Test Accuracy: 4781/4999 (96%)



 43%|████▎     | 13009/29999 [09:56<2:53:25,  1.63it/s] 

Test Accuracy: 4697/4999 (94%)



 47%|████▋     | 14006/29999 [10:40<6:35:24,  1.48s/it]

Test Accuracy: 4654/4999 (93%)



 50%|█████     | 15005/29999 [11:24<2:48:29,  1.48it/s]

Test Accuracy: 4721/4999 (94%)



 53%|█████▎    | 16006/29999 [12:09<2:10:21,  1.79it/s]

Test Accuracy: 4743/4999 (95%)



 57%|█████▋    | 17006/29999 [12:54<3:35:15,  1.01it/s]

Test Accuracy: 4774/4999 (95%)



 60%|██████    | 18002/29999 [13:40<3:37:06,  1.09s/it]

Test Accuracy: 4800/4999 (96%)



 63%|██████▎   | 19009/29999 [14:25<1:55:43,  1.58it/s]

Test Accuracy: 4739/4999 (95%)



 67%|██████▋   | 20004/29999 [15:10<8:01:05,  2.89s/it] 

Test Accuracy: 4824/4999 (96%)



 70%|███████   | 21002/29999 [15:55<1:58:37,  1.26it/s]

Test Accuracy: 4737/4999 (95%)



 73%|███████▎  | 22007/29999 [16:43<3:05:25,  1.39s/it]

Test Accuracy: 4759/4999 (95%)



 77%|███████▋  | 23008/29999 [17:29<1:25:32,  1.36it/s]

Test Accuracy: 4799/4999 (96%)



 80%|████████  | 24007/29999 [18:17<1:12:53,  1.37it/s]

Test Accuracy: 4759/4999 (95%)



 83%|████████▎ | 25005/29999 [19:03<1:22:01,  1.01it/s]

Test Accuracy: 4781/4999 (96%)



 87%|████████▋ | 26004/29999 [19:48<51:21,  1.30it/s]  

Test Accuracy: 4726/4999 (95%)



 90%|█████████ | 27005/29999 [20:32<2:23:53,  2.88s/it]

Test Accuracy: 4797/4999 (96%)



 93%|█████████▎| 28007/29999 [21:17<25:02,  1.33it/s]  

Test Accuracy: 4778/4999 (96%)



 97%|█████████▋| 29003/29999 [22:03<16:42,  1.01s/it]

Test Accuracy: 4804/4999 (96%)



100%|██████████| 29999/29999 [22:33<00:00, 22.17it/s]
