In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
from exp.nb_02 import *
from torch import nn
import torch.nn.functional as F

In [None]:
#export

from collections import defaultdict

def assert_shape(x, shape:list):
    """Validate ``x.shape`` against a spec such as ``[8, 3, None, None]``.

    The number of dimensions of ``x`` must equal ``len(shape)``. Every ``int``
    entry of the spec has to equal the matching dimension of ``x``; entries of
    any other type (e.g. ``None`` or a string) act as wildcards.

    Raises:
        AssertionError: carrying ``(x.shape, shape)`` when the rank or any
            integer-pinned dimension differs.
    """
    assert len(x.shape) == len(shape), (x.shape, shape)
    for actual, expected in zip(x.shape, shape):
        if not isinstance(expected, int):
            continue  # wildcard entry: any size is accepted here
        assert actual == expected, (x.shape, shape)


def assert_shapes(x, x_shape, y, y_shape):
    """Check two arrays against shape specs that may share named dimensions.

    Each array is first checked on its own with ``assert_shape`` (integer
    entries must match exactly, everything else is a wildcard). After that,
    every string entry in a spec is treated as a dimension name: all axes
    labelled with the same string, within or across the two arrays, must have
    the same size.

    Example: ``assert_shapes(a, ['bs', 784], b, ['bs'])`` requires
    ``a.shape[0] == b.shape[0]``.

    Raises:
        AssertionError: carrying ``(x.shape, x_shape, y.shape, y_shape)`` when
            a named dimension takes more than one size.
    """
    assert_shape(x, x_shape)
    assert_shape(y, y_shape)

    # Map each dimension name to the set of sizes observed for it.
    sizes_by_name = defaultdict(set)
    for arr, spec in ((x, x_shape), (y, y_shape)):
        # assert_shape above guarantees len(arr.shape) == len(spec).
        for size, label in zip(arr.shape, spec):
            if isinstance(label, str):
                sizes_by_name[label].add(size)

    for sizes in sizes_by_name.values():
        # Report the shapes, not the arrays themselves, so the failure message
        # stays readable for large tensors.
        assert len(sizes) == 1, (x.shape, x_shape, y.shape, y_shape)

In [None]:
class Model(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        n_in: number of input features per sample.
        n_hid: width of the hidden layer.
        n_out: number of outputs per sample.

    Attributes:
        layers: the three sub-modules, in application order. Still iterable
            like a list, so code looping over ``model.layers`` keeps working.
    """
    def __init__(self, n_in, n_hid, n_out):
        super().__init__()
        # nn.ModuleList (rather than a plain Python list) registers the layers
        # with this module, so parameters(), state_dict() and .to() see them.
        self.layers = nn.ModuleList([
            nn.Linear(n_in, n_hid),
            nn.ReLU(),
            nn.Linear(n_hid, n_out),
        ])

    def forward(self, x):
        """Apply each layer in order and return the final activations."""
        # Defining forward (instead of overriding __call__) keeps
        # nn.Module.__call__ in charge, so hooks run; model(x) is unchanged.
        for l in self.layers: x = l(x)
        return x

In [None]:
# Load training and validation inputs/targets. get_data is not defined in this
# notebook; presumably it arrives via the `from exp.nb_02 import *` above.
x_train, y_train, x_valid, y_valid = get_data()

In [None]:
# n = number of rows in x_train, m = number of columns per row.
n,m = x_train.shape
# Largest target value plus one -- the class count if labels run 0..max (TODO confirm).
c = y_train.max()+1

# Display the three values.
n,m,c

In [None]:
# Instantiate the Model class defined above with hard-coded sizes:
# 784 inputs, 50 hidden units, 10 outputs.
model = Model(n_in=784, n_hid=50, n_out=10)
# Loss function used by every training loop in this notebook.
loss_func = F.cross_entropy

In [None]:
#export
def accuracy(pred, gt, n_classes=10):
    """Fraction of samples whose arg-max prediction equals the target.

    Args:
        pred: 2-D tensor of scores, one row per sample and one column per class.
        gt: 1-D tensor of target class indices, one per row of ``pred``.
        n_classes: expected number of columns in ``pred``. The default of 10
            keeps the original check; pass ``None`` to accept any class count.

    Returns:
        A 0-dim float tensor: the mean of the per-sample "prediction correct"
        indicator.

    Raises:
        AssertionError: if the shapes do not match the description above.
    """
    assert_shape(pred, [None, n_classes])
    assert_shape(gt, [None])
    assert gt.shape[0] == pred.shape[0], (pred.shape, gt.shape)
    return (torch.argmax(pred, dim=1)==gt).float().mean()

In [None]:
# Mini-batch size used by the rest of the notebook.
bs=64

# Take the first mini-batch. Slice with `bs` rather than a literal 64 so the
# inputs stay aligned with the `y_train[:bs]` targets if bs is changed.
x_batch = x_train[:bs]
preds = model(x_batch)

# Show the first prediction row and the overall output shape.
preds[0], preds.shape

In [None]:
# Targets for the first bs samples, then the loss of `preds` against them.
y_batch = y_train[:bs]
loss_func(preds, y_batch)

In [None]:
# Accuracy of `preds` against `y_batch`.
accuracy(preds, y_batch)

In [None]:
# Learning rate: the factor applied to each gradient in the update steps below.
lr = 0.5
# Number of full passes over the training set.
epochs = 1

In [None]:
# Hand-written mini-batch gradient descent: for every batch compute the loss,
# back-propagate, then update and reset each layer's parameters manually.
for epoch in range(epochs):
    # (n-1)//bs + 1 equals ceil(n / bs) for n >= 1, so a short final batch is included.
    for i in range((n-1)//bs + 1):
        start_i = i*bs
        end_i = start_i+bs
        # Slicing past the end just yields a shorter last batch.
        x_b = x_train[start_i:end_i]
        y_b = y_train[start_i:end_i]
        loss = loss_func(model(x_b), y_b)
        
        loss.backward()
        # The in-place parameter updates must not be recorded by autograd.
        with torch.no_grad():
            for l in model.layers:
                # Only layers exposing a `weight` attribute are updated; others are skipped.
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias -= l.bias.grad * lr
                    # Reset the gradients so the next backward() starts from zero.
                    l.weight.grad.zero_()
                    l.bias.grad.zero_()

In [None]:
# Display the shapes of the training inputs and targets.
x_train.shape, y_train.shape

In [None]:
# Run the whole training set through the model in a single call and report accuracy.
# NOTE(review): this pass is not wrapped in torch.no_grad().
preds = model(x_train)
accuracy(preds, y_train)

In [None]:
class Optim:
    """Minimal gradient-descent optimiser: ``p <- p - lr * p.grad``.

    Args:
        params: iterable of tensors to optimise (e.g. ``model.parameters()``).
            It is materialised into a list so it can be traversed on every
            step, even when a one-shot generator is passed in.
        lr: learning rate applied to every parameter.
    """
    def __init__(self, params, lr=0.5):
        self.params = list(params)
        self.lr = lr

    def step(self):
        """Update every parameter in place from its current gradient."""
        # In-place updates must not be tracked by autograd.
        with torch.no_grad():
            for p in self.params:
                # Parameters that have not received a gradient are left untouched.
                if p.grad is None:
                    continue
                # Use the instance's learning rate, not a same-named global.
                p -= p.grad * self.lr

    def zero_grad(self):
        """Reset the gradient of every parameter that has one to zero."""
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

In [None]:
# Rebind `model` to a new nn.Sequential with the same layer sizes as the
# earlier Model(...) call: 784 -> 50 -> ReLU -> 10.
model = nn.Sequential(
    nn.Linear(784, 50),
    nn.ReLU(),
    nn.Linear(50, 10),
)

In [None]:
# Optimiser over all of the model's parameters, using Optim's default lr of 0.5.
opt = Optim(model.parameters())

In [None]:
# Mini-batch training loop in which the parameter update and the gradient
# reset are delegated to the `opt` object.
for epoch in range(epochs):
    # (n-1)//bs+1 equals ceil(n / bs) for n >= 1, so a short final batch is included.
    for i in range((n-1)//bs+1):
        start_i = i*bs
        end_i = start_i+bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)
        
        loss.backward()
        # Apply the update, then clear gradients before the next batch.
        opt.step()
        opt.zero_grad()

In [None]:
# Predictions for the whole training set, then accuracy and loss on them.
# NOTE(review): this pass is not wrapped in torch.no_grad().
bigpreds = model(x_train)
accuracy(bigpreds, y_train), loss_func(bigpreds, y_train)

In [None]:
#export
class Dataset():
    """Pairs an indexable collection of inputs with one of targets.

    ``x`` and ``y`` must have the same length (checked with an assert).
    Indexing returns a 2-tuple: ``x`` and ``y`` each indexed with the same
    key, so an integer key gives one sample and a slice gives a batch.
    """
    def __init__(self, x, y):
        n_inputs, n_targets = len(x), len(y)
        assert n_inputs == n_targets
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from the inputs."""
        return len(self.x)

    def __getitem__(self, i):
        """Return ``(x[i], y[i])`` for an index or slice ``i``."""
        inputs = self.x[i]
        targets = self.y[i]
        return inputs, targets

In [None]:
# Wrap the training and validation inputs/targets in Dataset objects.
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

In [None]:
# Sanity check: the Dataset reports the same length as the inputs it wraps.
assert len(train_ds) == len(x_train)

In [None]:
# Slicing a Dataset returns (inputs, targets), each sliced with the same range.
xb,yb=train_ds[0:5]
# Five rows of 28*28 features and five matching targets are expected.
assert xb.shape==(5,28*28)
assert yb.shape==(5,)
# Display the batch.
xb,yb

In [None]:
class DataLoader:
    """Iterates over a sliceable dataset in consecutive batches of ``bs`` items.

    Batches are produced in order by slicing ``ds``; the final batch is shorter
    when ``len(ds)`` is not a multiple of ``bs``.
    """
    def __init__(self, ds, bs):
        self.ds, self.bs = ds, bs

    def __iter__(self):
        """Lazily yield ``ds[start:start+bs]`` for each batch start offset."""
        batch_starts = range(0, len(self.ds), self.bs)
        yield from (self.ds[lo:lo + self.bs] for lo in batch_starts)

In [None]:
# Batches of size bs over the training Dataset.
train_dl = DataLoader(train_ds, bs)

In [None]:
# Batches of size bs over the validation Dataset.
valid_dl = DataLoader(valid_ds, bs)

In [None]:
from torch.utils.data import DataLoader

In [None]:
#export
def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Build the training and validation DataLoaders.

    Args:
        train_ds: dataset for training; loaded shuffled in batches of ``bs``.
        valid_ds: dataset for validation; loaded unshuffled in batches of
            ``2 * bs``.
        bs: training batch size.
        **kwargs: forwarded unchanged to both DataLoader constructors.

    Returns:
        A ``(train_dl, valid_dl)`` tuple.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, **kwargs)
    return train_dl, valid_dl

In [None]:
# Run notebook2script.py on 03_training.ipynb in a shell; presumably this
# exports the cells marked `#export` -- confirm against the script.
!python notebook2script.py 03_training.ipynb