In [4]:
from fastai import *
from fastai.vision.all import *

In [5]:
#downloading the MNIST dataset
path = untar_data(URLs.MNIST)

In [6]:
path, path.ls()

(Path('/home/sam/.fastai/data/mnist_png'),
 (#2) [Path('/home/sam/.fastai/data/mnist_png/testing'),Path('/home/sam/.fastai/data/mnist_png/training')])

In [7]:
#Getting the training datasets ready
'''
Important info:
#training images = 60,000
Each of the categories having roughly equal distribution: train_y.unique(return_counts=True)
'''
train_images_list = get_image_files(path/'training')
train_x_list = [tensor(Image.open(img_path)) for img_path in train_images_list]
train_y_list = [int(img_path.parent.name) for img_path in train_images_list]
train_x = (torch.stack(train_x_list).float()/255).view(-1,28*28)
train_y = tensor(train_y_list).view(-1,1)

train_x.shape, train_y.shape

(torch.Size([60000, 784]), torch.Size([60000, 1]))

In [8]:
train_dset = list(zip(train_x, train_y))

In [9]:
#Getting the validation datasets ready
'''
Important info:
#validation images = 10,000
Each of the categories having roughly equal distribution: valid_y.unique(return_counts=True)
'''
valid_images_list = get_image_files(path/'testing')
valid_x_list = [tensor(Image.open(img_path)) for img_path in valid_images_list]
valid_y_list = [int(img_path.parent.name) for img_path in valid_images_list]
valid_x = (torch.stack(valid_x_list).float()/255).view(-1,28*28)
valid_y = tensor(valid_y_list).view(-1,1)

valid_x.shape, valid_y.shape

(torch.Size([10000, 784]), torch.Size([10000, 1]))

In [10]:
valid_dset = list(zip(valid_x, valid_y))

#### Using fastai packages

In [11]:
#this is just to get a sense of the accuracy using resnet18
dls = ImageDataLoaders.from_folder(path, train='training',valid='testing')
learn = cnn_learner(dls, resnet18, pretrained=False,
                    loss_func=F.cross_entropy, metrics=accuracy, n_out=10)
learn.fit_one_cycle(1, 0.1)

  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")


epoch,train_loss,valid_loss,accuracy,time
0,0.106589,0.043339,0.9874,00:41


  return F.conv2d(input, weight, bias, self.stride,


#### Manual SGD & Model training 

In [12]:
train_dl = DataLoader(train_dset, batch_size=256)
#valid_dl = DataLoader(valid_dset, batch_size=256)

In [13]:
# function to calculate loss
def mnist_loss(pred, actual):
    l = nn.CrossEntropyLoss()
    return l(pred, actual.squeeze())

# function to calculate gradient
def calc_grad(xb, yb, model):
    pred = model(xb)
    loss = mnist_loss(pred, yb)
    loss.backward()    
    return loss

# function to define accuracy
def batch_accuracy(pred, actual):
    digit_pred = pred.max(dim=1)[1]
    return (digit_pred==actual.squeeze()).float().mean()

#function to train 1 epoch and print average batch loss
def train_epoch(model):
    batch_loss = []
    for xb,yb in train_dl:
        batch_loss.append(calc_grad(xb, yb, model))
        opt.step()
        opt.zero_grad()
    return tensor(batch_loss).mean()

In [14]:
#Optimizer
class BasicOptim:
    def __init__(self,params,lr): self.params,self.lr = list(params),lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

In [15]:
#Simple 2 activations function NN
simple_net = nn.Sequential(
    nn.Linear(28*28,100),
    nn.ReLU(),
    nn.Linear(100,30),
    nn.ReLU(),
    nn.Linear(30,10)
)

In [16]:
#random accuracy
batch_accuracy(simple_net(valid_x),valid_y)

tensor(0.1310)

In [17]:
opt = BasicOptim(simple_net.parameters(), lr=0.001)

In [18]:
#function to train model for multiple epochs
def train_model(model,epochs):
    print('{:<10}{:<15}{:<15}'.format('Epoch','Training Loss','Validation Accuracy'))
    for i in range(epochs):
        avg_bl = train_epoch(model)
        print('{:<10}{:<15,.2f}{:<15,.2f}'.format(i,avg_bl.item(),batch_accuracy(model(valid_x),valid_y).item()))

In [19]:
#model training call
train_model(simple_net, 100)

Epoch     Training Loss  Validation Accuracy
0         2.30           0.21           
1         2.30           0.26           
2         2.29           0.29           
3         2.28           0.33           
4         2.27           0.37           
5         2.27           0.39           
6         2.26           0.41           
7         2.25           0.41           
8         2.24           0.42           
9         2.22           0.42           
10        2.21           0.42           
11        2.20           0.42           
12        2.18           0.43           
13        2.16           0.43           
14        2.14           0.44           
15        2.12           0.45           
16        2.09           0.46           
17        2.06           0.48           
18        2.03           0.51           
19        1.99           0.53           
20        1.95           0.55           
21        1.91           0.56           
22        1.86           0.57           
23        1.