In [3]:
import numpy as np
import torch
import torch.nn as nn
from training import train
from NP import NP
import matplotlib.pyplot as plt
import random
from PIL import Image
from training import train
from NP import NP
import torchvision
import torchvision.datasets as datasets
# Load both MNIST splits (training first, then test) from ./data, downloading
# them if they are not present. transform=None keeps the samples as returned
# by torchvision, without any preprocessing applied.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train_split, download=True, transform=None)
    for is_train_split in (True, False)
)

In [4]:
def img_to_vector(imgs):
    """Convert an iterable of images into a single float tensor.

    Every element of ``imgs`` is turned into an int32 NumPy array with
    ``np.asarray`` and the results are stacked along a new leading
    dimension, so the output holds one entry per image.

    Args:
        imgs: Iterable of array-like images (anything ``np.asarray``
            accepts). All of them must yield arrays of the same shape.

    Returns:
        A float tensor whose first dimension indexes the images.
    """
    per_image = []
    for img in imgs:
        pixels = np.asarray(img, dtype="int32")
        per_image.append(torch.from_numpy(pixels))
    return torch.stack(per_image, dim=0).float()

def vector_to_image(vec):
    """Render a tensor of 784 values as a 28x28 RGB PIL image.

    Args:
        vec: Tensor holding exactly 28 * 28 elements. It is detached from
            the autograd graph before being handed to NumPy.

    Returns:
        The image built by ``Image.fromarray`` from the 28x28 array,
        converted to RGB mode.
    """
    pixels = vec.detach().reshape(28, 28).numpy()
    return Image.fromarray(pixels).convert('RGB')

def _dataset_to_tensors(dataset):
    """Split a dataset of (image, digit) samples into images and tensors.

    The dataset is indexed exactly once per sample, so each image is only
    fetched a single time.

    Args:
        dataset: Indexable object with ``len`` whose items are
            ``(image, digit)`` pairs.

    Returns:
        A tuple ``(images, arrays, digits)``: the list of images, the float
        tensor produced by ``img_to_vector(images)``, and a LongTensor of
        labels reshaped to one column (shape ``[N, 1]``).
    """
    samples = [dataset[i] for i in range(len(dataset))]
    images = [image for image, _digit in samples]
    digits = [digit for _image, digit in samples]
    return images, img_to_vector(images), torch.LongTensor(digits).view(-1, 1)


# Training split: list of images, stacked pixel tensor, and [N, 1] labels.
mnist_trainset_images, mnist_trainset_arrays, mnist_trainset_digits = (
    _dataset_to_tensors(mnist_trainset)
)

# Test split, prepared the same way. The training loop further down prints
# mnist_testset_digits[0] and the prediction for mnist_testset_arrays[0]
# after every epoch, so these names must exist on a fresh kernel.
mnist_testset_images, mnist_testset_arrays, mnist_testset_digits = (
    _dataset_to_tensors(mnist_testset)
)

In [5]:
def batch(data, batched_size):
    """Split ``data`` into consecutive slices of ``batched_size`` items.

    Only complete batches are produced: trailing items that do not fill a
    whole batch are dropped.

    Args:
        data: Sliceable sequence supporting ``len`` (list, tensor, ...).
        batched_size: Number of items in each batch.

    Returns:
        A list of slices of ``data``, in their original order.
    """
    full_batches = len(data) // batched_size
    stop = full_batches * batched_size
    return [data[start:start + batched_size] for start in range(0, stop, batched_size)]

# Cut the training images and their integer labels into mini-batches of 32,
# using the same slicing for both so batch i of one matches batch i of the other.
batched_arrays, batched_digits = (
    batch(tensor, 32) for tensor in (mnist_trainset_arrays, mnist_trainset_digits)
)

# One-hot encoding of the labels: a row of 10 zeros per sample with a 1 written
# at the label's index. It is kept as a list of row tensors and batched as well.
onehot = list(
    torch.zeros(len(mnist_trainset_arrays), 10).scatter(-1, mnist_trainset_digits, 1)
)
batched_onehot = batch(onehot, 32)

# Each training item is [image batch, integer-label batch]. The one-hot batches
# are not stored in `data`; they only determine how many items are produced.
data = [
    [batched_arrays[idx], batched_digits[idx]]
    for idx, _ in enumerate(batched_onehot)
]

In [6]:
class CNN(torch.nn.Module):
    """Small convolutional classifier that outputs 10 scores per image.

    Expected input size: ``[batch_size, 1, 28, 28]``.

    Pipeline: 3x3 convolution (1 -> 1 channel, padding 1) -> ReLU ->
    2x2 max-pool with stride 2 -> flatten to rows of 196 values ->
    Linear(196, 64) -> ReLU -> Linear(64, 10). The returned values are the
    raw outputs of the last linear layer; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and creation order are part of the model's
        # state_dict layout, so they are kept exactly as they are.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=196, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network on ``x`` and return one row of 10 scores per sample."""
        features = self.pool(self.relu(self.conv1(x)))

        # Collapse the pooled feature maps into rows of 196 values for fc1.
        flat = features.view(-1, 196)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

In [7]:
# Seed PyTorch's RNG before the network is built so the randomly initialised
# weights are the same on every Restart & Run All. Weight initialisation is
# the only source of randomness visible in this notebook.
torch.manual_seed(0)

cnn = CNN()
# CrossEntropyLoss is fed the network's raw outputs and integer class labels.
lossf = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=3e-3)

In [14]:
def train(cnn, data, batch_size, epochs):
    """Train ``cnn`` on pre-batched data, printing progress after each epoch.

    This function reads several module-level names instead of taking them as
    arguments: ``optimizer`` and ``lossf`` drive the updates, and
    ``mnist_testset_arrays`` / ``mnist_testset_digits`` supply the sample
    used for the per-epoch sanity check. Note that ``optimizer`` is stepped
    regardless of which model is passed in, so it must have been built over
    the parameters of ``cnn``.

    Args:
        cnn: Model called on inputs reshaped to ``[batch_size, 1, 28, 28]``.
        data: Iterable of ``(images, labels)`` pairs. Every pair must hold
            exactly ``batch_size`` 28x28 images and ``batch_size`` integer
            class labels, because both are reshaped with ``batch_size``.
        batch_size: Number of samples in every batch of ``data``.
        epochs: Number of full passes over ``data``.

    Returns:
        None. After each epoch the summed batch loss, the label of the first
        test sample and the model's output for that sample are printed.
    """
    for epoch in range(epochs):
        epoch_loss = 0.0
        for point, labels in data:
            optimizer.zero_grad()
            outputs = cnn(point.reshape(batch_size, 1, 28, 28))

            loss = lossf(outputs, labels.reshape(batch_size))
            loss.backward()
            optimizer.step()
            # Accumulate a plain Python float: adding the loss tensor itself
            # would keep every batch's autograd history alive for the epoch.
            epoch_loss += loss.item()

        print('EPOCH {} LOSS {}'.format(epoch, epoch_loss))

        # Sanity check on the first test sample: expected label, then the
        # model's 10 output scores. No gradients are needed for this pass.
        print(mnist_testset_digits[0])
        with torch.no_grad():
            print(cnn(mnist_testset_arrays[0].reshape(1, 1, 28, 28)))

In [15]:
# Train for 100 epochs; batch_size must match the 32-sample batches held in `data`.
train(cnn, data, batch_size=32, epochs=100)

EPOCH 0 LOSS 367.8582763671875
tensor([7])
tensor([[ -4.4486,  -6.3734,   1.2488,  -1.9938, -16.3647,  -1.3526, -16.1080,
           8.6392,  -0.8749,  -5.7601]], grad_fn=<AddmmBackward>)
EPOCH 1 LOSS 324.76666259765625
tensor([7])
tensor([[ -5.3012,  -4.4660,  -2.8189,  -2.9155, -10.1711,  -2.9353, -11.5502,
           7.5768,  -0.4602,  -2.4791]], grad_fn=<AddmmBackward>)
EPOCH 2 LOSS 307.7736511230469
tensor([7])
tensor([[-15.5866,  -5.5307,  -3.5803,  -6.3941, -14.4833, -12.0358, -24.1186,
          15.1500,  -4.7284,  -3.3258]], grad_fn=<AddmmBackward>)
EPOCH 3 LOSS 289.4705810546875
tensor([7])
tensor([[-13.1868,  -3.0702,  -0.0507,  -3.8630,  -9.7421, -11.5946, -18.9273,
           9.8030,  -3.7424,  -3.3035]], grad_fn=<AddmmBackward>)
EPOCH 4 LOSS 266.4327087402344
tensor([7])
tensor([[-18.8301,  -7.1300,  -5.8119,  -4.5020, -14.1582, -11.4972, -21.8731,
          13.3599,  -6.5411,  -2.4796]], grad_fn=<AddmmBackward>)
EPOCH 5 LOSS 255.9978485107422
tensor([7])
tensor([[-12.089