In [141]:
import ai
import numpy as np

In [142]:
def load_data(file):
    dict = np.load(file, allow_pickle=True)
    return dict

In [143]:
train_file = 'MNIST/train.npy'
test_file = 'MNIST/test.npy'

In [144]:
class MNIST(ai.Module):
    def __init__(self):
        self.conv1 = ai.Conv2d(1, 8, kernel_size=3, stride=1)
        self.conv2 = ai.Conv2d(8, 16, kernel_size=3, stride=1)
        self.fc1 = ai.Linear(2304, 128)
        self.fc2 = ai.Linear(128, 10)
        
    def forward(self, x):
        o1 = ai.G.relu(self.conv1.forward(x))
        o2 = ai.G.relu(self.conv2.forward(o1))
        o3 = ai.G.dropout(ai.G.maxpool2d(o2), p=0.75)
        o4 = ai.G.dropout(ai.G.relu(self.fc1.forward(o3)), p=0.5)
        o5 = self.fc2.forward(o4)
        o6 = ai.G.softmax(o5)
        return o6

In [145]:
mnist = MNIST()
print(mnist)

MNIST(
  conv1: Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0), bias=True)
  conv2: Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0), bias=True)
  fc1: Linear(input_features=2304, output_features=128, bias=True)
  fc2: Linear(input_features=128, output_features=10, bias=True)
)


In [146]:
L = ai.Loss(loss_fn='CrossEntropyLoss')
optim = ai.Optimizer(mnist.parameters(), optim_fn='Adadelta', lr=1e-3)

In [147]:
train_dict = load_data(train_file)
inputs = train_dict.item()['data']
outputs = train_dict.item()['labels']

In [148]:
del train_dict

In [149]:
it, epoch = 0, 0
loss = np.inf
m = 8

In [150]:
def evaluate():
    ai.G.grad_mode = False
    file = test_file
    dict = load_data(file)
    inputs = dict.item()['data']
    outputs = dict.item()['labels']
    correct, total = 0, 0
    test_m = m
    for batch in range(int(len(outputs) / m)):
        input = inputs[batch * test_m : (batch + 1) * test_m].reshape(test_m, 1, 28, 28) / 255
        input =  np.stack([_ for _ in input], axis = -1)
        output = np.array(outputs[batch * test_m : (batch + 1) * test_m])
        scores = mnist.forward(input)
        preds = np.argmax(scores.data, axis=0)
        correct += np.sum(np.equal(output, preds))
        total += test_m
    accuracy = float(correct / total)
    ai.G.grad_mode = True
    return accuracy

In [None]:
while epoch < 10:
    epoch += 1
    it = 0
    for batch in range(int(len(outputs) / m)):
    # for batch in range(1):
        input = inputs[batch * m : (batch + 1) * m].reshape(m, 1, 28, 28) / 255
        input =  np.stack([_ for _ in input], axis = -1)
        output = outputs[batch * m : (batch + 1) * m]
        onehot = np.zeros((10, m))
        for _ in range(m):
            onehot[output[_], _] = 1.0
        scores = mnist.forward(input)
        loss = L.loss(scores, onehot)
        loss.backward()
        optim.step()        # update parameters with optimization functions
        optim.zero_grad()   # clearing the backprop list and resetting the gradients to zero
        if it%100 == 0:
            print('epoch: {}, iter: {}, loss: {}'.format(epoch, it, loss.data[0, 0]))
        it += 1
    print('\n\n', 'Epoch {} completed. Accuracy: {}'.format(epoch, evaluate()))

using Adadelta
epoch: 1, iter: 0, loss: 2.3025889220413527
epoch: 1, iter: 100, loss: 2.299660308329791
epoch: 1, iter: 200, loss: 2.306180821471351
epoch: 1, iter: 300, loss: 2.3213434113052767
epoch: 1, iter: 400, loss: 2.3129695848127594
epoch: 1, iter: 500, loss: 2.3319523321922553
epoch: 1, iter: 600, loss: 2.301929027806004
epoch: 1, iter: 700, loss: 2.33746649371023
epoch: 1, iter: 800, loss: 2.3355563101670116
epoch: 1, iter: 900, loss: 2.293141123555913
epoch: 1, iter: 1000, loss: 2.3110276834634016
epoch: 1, iter: 1100, loss: 2.2541160312853354
epoch: 1, iter: 1200, loss: 2.3212566739882536
epoch: 1, iter: 1300, loss: 2.306361155346287
epoch: 1, iter: 1400, loss: 2.2610592577530726
epoch: 1, iter: 1500, loss: 2.2992121039228035
epoch: 1, iter: 1600, loss: 2.3184077971560346
epoch: 1, iter: 1700, loss: 2.3007350787573664
epoch: 1, iter: 1800, loss: 2.3135613867358797
epoch: 1, iter: 1900, loss: 2.3124007018661135
epoch: 1, iter: 2000, loss: 2.311620675471689
epoch: 1, iter: 21

In [None]:
# mnist.save()