In [6]:
# All imports
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastai.vision.all import *
from fastbook import *
matplotlib.rc('image', cmap='Greys')

In [7]:
# All preparation work

# Read data from the path
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
threes = (path/'train/3').ls().sorted() #6131 images of 3
sevens = (path/'train/7').ls().sorted() # 6265 images of 7, total = 12396 of training data
valid = (path/'valid').ls().sorted()
tensor_threes = [tensor(Image.open(i)) for i in threes]
tensor_sevens = [tensor(Image.open(i)) for i in sevens]

# For vision related work, convert all pixel data in range 0 -1
stacked_tensors_threes = torch.stack(tensor_threes).float() / 255.0    #[6131, 28, 28]
stacked_tensors_sevens = torch.stack(tensor_sevens).float() / 255.0    #[6131, 28, 28]

# Do same for validation data, read data, convert pixel data 0-1, 
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255.0 
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255.0

# Prepare x,y for tensor usage
train_x = torch.cat([stacked_tensors_threes, stacked_tensors_sevens]).view(-1, 28*28)  # [12396, 784]
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)                       # [12396, 1]

valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))

dset = list(zip(train_x,train_y))
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

In [9]:
# Using Pytorch SGD
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()
def train_epoch(model):
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)        
def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')
        
def linear1(xb): 
    return xb@weights + bias

class BasicOptim:
    def __init__(self,params,lr): self.params,self.lr = list(params),lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None
            
lr = 1.0
linear_model = nn.Linear(28*28,1)
#opt = BasicOptim(linear_model.parameters(), lr)
opt = SGD(linear_model.parameters(), lr)
train_model(linear_model, 20)
dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,
                loss_func=mnist_loss, metrics=batch_accuracy)
# learn = Learner(dls, nn.Linear(28*28,1), opt_func=SGD,
#                 loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(10, lr=lr)

0.4932 0.8447 0.8354 0.9116 0.9321 0.9472 0.9555 0.9618 0.9653 0.9667 0.9696 0.9721 0.9731 0.975 0.9755 0.976 0.9775 0.978 0.978 0.9785 

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.636804,0.502146,0.495584,00:00
1,0.32902,0.289147,0.720805,00:00
2,0.128082,0.153073,0.864573,00:00
3,0.060638,0.097113,0.917566,00:00
4,0.03545,0.073222,0.937193,00:00
5,0.025341,0.059699,0.948479,00:00
6,0.02101,0.051057,0.957311,00:00
7,0.018966,0.045205,0.964181,00:00
8,0.017849,0.041041,0.966143,00:00
9,0.017129,0.037945,0.967615,00:00


In [33]:
# Train by Fast AI Optimised Function - One liner - better performance
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, pretrained=False,
                    loss_func=F.cross_entropy, metrics=accuracy)
learn.fit_one_cycle(1, 0.1)

epoch,train_loss,valid_loss,accuracy,time
0,0.143675,0.019299,0.997056,07:09
