# Tools

In [1]:
from fastai.vision.all import *
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


# Data

In [2]:
# Fetch the MNIST_SAMPLE dataset; `path` is the local directory that the
# cells below index into (path/'train'/..., path/'valid'/...).
path = untar_data(URLs.MNIST_SAMPLE)

In [3]:
# Sorted listing of the training image files, one list per digit folder.
train_dir = path/'train'
threes = (train_dir/'3').ls().sorted()
sevens = (train_dir/'7').ls().sorted()
threes

(#6131) [Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10000.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10011.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10031.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10034.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10042.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10052.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/1007.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10074.png'),Path('/home/skiiyuru/.fastai/data/mnist_sample/train/3/10091.png')...]

In [5]:
# Open every image file and turn it into a tensor (one list per digit).
three_tensors = [tensor(Image.open(img_file)) for img_file in threes]
seven_tensors = [tensor(Image.open(img_file)) for img_file in sevens]
len(three_tensors),len(seven_tensors)

(6131, 6265)

In [6]:
# Stack each list of image tensors into a single float tensor and divide by
# 255 (presumably mapping 8-bit pixel values into [0, 1] -- confirm).
stacked_threes, stacked_sevens = (
    torch.stack(imgs).float()/255 for imgs in (three_tensors, seven_tensors)
)
stacked_threes.shape

torch.Size([6131, 28, 28])

In [11]:
# Build the validation stacks the same way as the training stacks: list each
# digit folder, open every image as a tensor, stack, convert to float and
# divide by 255.
# The listings are sorted, as the training listings are, so the order of the
# validation items (and therefore how they fall into batches) does not depend
# on the order in which the filesystem happens to return directory entries.
valid_3_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'3').ls().sorted()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o))
                            for o in (path/'valid'/'7').ls().sorted()])
valid_7_tens = valid_7_tens.float()/255
valid_3_tens.shape,valid_7_tens.shape

(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))

# Train

In [8]:
# Training inputs: both digit stacks joined along the first axis, with every
# 28x28 image flattened into a row of 784 values.
all_train_images = torch.cat((stacked_threes, stacked_sevens))
train_x = all_train_images.view(-1, 28*28)
# Training targets: 1 for every 3 and 0 for every 7, shaped as a column so
# there is one label per input row.
three_labels, seven_labels = [1]*len(threes), [0]*len(sevens)
train_y = tensor(three_labels + seven_labels).unsqueeze(1)
train_x.shape,train_y.shape

(torch.Size([12396, 784]), torch.Size([12396, 1]))

In [9]:
# The training dataset is a plain list of (image_row, label) pairs.
dset = [(img, lbl) for img, lbl in zip(train_x, train_y)]
# Peek at the first pair to confirm what one item looks like.
x,y = dset[0]
x.shape,y

(torch.Size([784]), tensor([1]))

In [12]:
# Validation inputs: flattened 3s followed by flattened 7s.
valid_x = torch.cat((valid_3_tens, valid_7_tens)).view(-1, 28*28)
# Validation targets use the same encoding as training: 1 for a 3, 0 for a 7.
valid_labels = [1]*len(valid_3_tens) + [0]*len(valid_7_tens)
valid_y = tensor(valid_labels).unsqueeze(1)
# Pair every input row with its label.
valid_dset = [(img, lbl) for img, lbl in zip(valid_x, valid_y)]

In [13]:
def init_params(size, std=1.0):
    """Create a randomly initialised parameter tensor that tracks gradients.

    size: shape passed straight to `torch.randn` (an int or a tuple of ints).
    std:  factor the standard-normal samples are multiplied by.

    Returns the scaled tensor with `requires_grad` switched on.
    """
    scaled_noise = torch.randn(size) * std
    return scaled_noise.requires_grad_()

In [14]:
# One weight per pixel of a flattened 28x28 image (as a column), plus a
# single bias value. The weights are drawn first, then the bias.
n_pixels = 28*28
weights = init_params((n_pixels,1))
bias = init_params(1)

In [15]:
def linear1(xb):
    """Linear model: multiply the batch `xb` by the global `weights` and add
    the global `bias`, returning one raw (pre-sigmoid) score per row."""
    weighted_sum = torch.matmul(xb, weights)
    return weighted_sum + bias

In [16]:
# Raw model outputs for the whole training set with the freshly initialised
# parameters (no training has happened yet).
preds = linear1(train_x)
preds

tensor([[ 4.3848],
        [ 6.2323],
        [ 2.5124],
        ...,
        [ 2.3593],
        [-1.3629],
        [ 9.1532]], grad_fn=<AddBackward0>)

In [17]:
def mnist_loss(predictions, targets):
    """Mean distance between the sigmoid of the predictions and the targets.

    predictions: raw model outputs; they are squashed with a sigmoid here.
    targets:     tensor compared against 1 element-wise.

    Where the target is 1 the per-item loss is `1 - p`, otherwise it is `p`
    (p being the sigmoid output). The mean over all items is returned.
    """
    probs = predictions.sigmoid()
    per_item = torch.where(targets == 1, 1 - probs, probs)
    return per_item.mean()

In [18]:
# Wrap the training dataset in a DataLoader that yields mini-batches of 256
# (image_rows, labels), then pull the first batch to check its shapes.
# NOTE(review): no shuffle argument is passed, and dset was built as all 3s
# followed by all 7s -- if the loader iterates in order, most batches contain
# a single class. Confirm this is intended.
dl = DataLoader(dset, batch_size=256)
xb,yb = first(dl)
xb.shape,yb.shape

(torch.Size([256, 784]), torch.Size([256, 1]))

In [19]:
# Same batching for the validation pairs; validate_epoch iterates over this.
valid_dl = DataLoader(valid_dset, batch_size=256)

In [20]:
def calc_grad(xb, yb, model):
    """Run `model` on the batch `xb`, score it against `yb` with `mnist_loss`,
    and back-propagate so the gradients of the loss are accumulated on the
    tensors that require grad. Nothing is returned."""
    mnist_loss(model(xb), yb).backward()

In [21]:
def train_epoch(model, lr, params):
    """Put it all together: one pass of gradient descent over the training
    DataLoader.

    model:  callable mapping a batch of inputs to predictions.
    lr:     learning rate that scales each gradient step.
    params: iterable of the tensors to update (they must be the ones `model`
            actually uses, since their `.grad` is read here).

    Batches are read from the global `dl`. For every batch the gradients are
    computed with `calc_grad`, each parameter is moved against its gradient,
    and the gradient is reset so the next batch starts from zero.
    """
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        # The update is done under no_grad so autograd does not record it;
        # this replaces writing through `.data`, which bypasses autograd's
        # in-place safety checks.
        with torch.no_grad():
            for p in params:
                # A parameter that took no part in the forward pass has no
                # gradient yet; there is nothing to apply for it.
                if p.grad is None: continue
                p.sub_(p.grad*lr)
                p.grad.zero_()

In [22]:
def batch_accuracy(xb, yb):
    """Fraction of items in a batch that are classified correctly.

    xb: raw model outputs for the batch (despite the name, these are the
        predictions, as passed in by validate_epoch); a sigmoid is applied.
    yb: the targets the thresholded predictions are compared with.

    A sigmoid output above 0.5 counts as a positive prediction; the result is
    the mean of the element-wise matches.
    """
    predicted_positive = xb.sigmoid() > 0.5
    hits = predicted_positive == yb
    return hits.float().mean()

In [23]:
def validate_epoch(model):
    """Accuracy of `model` over the whole validation DataLoader.

    model: callable mapping a batch of inputs to raw predictions.

    Batches are read from the global `valid_dl` and scored with
    `batch_accuracy`. Each batch's accuracy is weighted by the number of items
    in that batch, so a smaller final batch does not count as much as a full
    one and the result is the accuracy over all items rather than a mean of
    per-batch means. Returns a Python float rounded to 4 decimal places.
    """
    weighted_acc, n_items = 0.0, 0
    # Evaluation needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        for xb,yb in valid_dl:
            batch_len = len(yb)
            weighted_acc += batch_accuracy(model(xb), yb).item() * batch_len
            n_items += batch_len
    return round(weighted_acc / n_items, 4)

In [24]:
# Let's test: train for a single epoch with a learning rate of 1 and report
# the validation accuracy afterwards.
lr = 1.
# The tensors train_epoch will update -- the same ones linear1 reads.
params = weights,bias
train_epoch(linear1, lr, params)
validate_epoch(linear1)

0.6542

In [25]:
# Train for 20 more epochs, printing the validation accuracy after each one
# on a single line (space-separated).
for i in range(20):
    train_epoch(linear1, lr, params)
    print(validate_epoch(linear1), end=' ')

0.8397 0.9052 0.9305 0.9452 0.9515 0.955 0.9613 0.9652 0.9677 0.9691 0.9706 0.9706 0.9711 0.9721 0.9721 0.9726 0.9731 0.9731 0.9736 0.974 