In [163]:
from fastai.vision.all import *

In [164]:
#Loading Data

In [165]:
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path

In [166]:
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

In [167]:
seven_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]

In [168]:
stacked_sevens = torch.stack(seven_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255

In [169]:
valid_3_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255

In [170]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)

In [171]:
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)

In [172]:
dset = list(zip(train_x,train_y))

In [173]:
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))

In [174]:
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

In [175]:
#Initialization

In [176]:
def init_params(size, variance=1.0): return (torch.randn(size)*variance).requires_grad_()

In [177]:
#Loss

In [178]:
def sigmoid(x): return 1/(1+torch.exp(-x))

In [179]:
def mnist_loss(predictions, targets): 
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

In [180]:
#Linear Layer

In [181]:
def linear1(xb): return xb@weights + bias

In [182]:
#Backpropogation

In [183]:
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward() #calculates gradient and adds to existing gradient

In [184]:
#Train epoch function

In [185]:
def train_epoch(model, lr, params):
    for xb, yb in dl: 
        calc_grad(xb, yb, model)
        for p in params:
            p.data -= p.grad*lr #tells Pytorch not to update the gradients using that calculation
            p.grad.zero_() #reset the gradients, _ is an inplace modifier

In [186]:
#Check Accuracy

In [187]:
def batch_accuracy(preds,yb):
    preds = preds.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [188]:
def validate_epoch(model):
    accs = [batch_accuracy(model(xb),yb) for xb, yb in valid_dl] #note: model(xb) = preds
    return round(torch.stack(accs).mean().item(), 4)

In [189]:
#Train model and check validation accuracy

In [190]:
weights = init_params((28*28,1))
bias = init_params(1)

In [191]:
params = weights, bias
epoch_num = 20
lr = 1.
for i in range(epoch_num):
    train_epoch(linear1, lr, params)
    print(validate_epoch(linear1), end=' ')

0.6581 0.85 0.9106 0.9316 0.9438 0.9526 0.9564 0.9593 0.9603 0.9623 0.9637 0.9652 0.9657 0.9672 0.9687 0.9706 0.9706 0.9716 0.9716 0.9726 