In [None]:
# load data remember to normalize the data
# Linear class + backward
# relu class + backward
# loss function ( cross_entropy )
# create model has class forward and backward

In [39]:
import torch, torch.nn as nn
from torch.nn import functional as F , init
import gzip, pickle, math

In [12]:
from fastai import datasets

In [40]:
def normalize(x, m, s): return (x-m)/s

def get_data():
    MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    
    # convert to tensor
    x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train,y_train,x_valid,y_valid))
    
    x_train_mean = x_train.mean()
    x_train_std  = x_train.std()
    
    x_train = normalize(x_train, x_train_mean, x_train_std)
    x_valid = normalize(x_valid, x_train_mean, x_train_std) # still use train mean and std
    
    return x_train, y_train, x_valid, y_valid


In [41]:
x_train, y_train, x_valid, y_valid = get_data()

In [78]:
x, y = x_train[:100], y_train[:100]

In [177]:
def init_w_b(in_f, out_f):
    with torch.no_grad():
        weight = init.kaiming_normal_(torch.Tensor(in_f , out_f), a=0) # we use a = 0 instead of math.sqrt(5) as pytorch doc
        bias   = init.uniform_(torch.Tensor(out_f))

    weight.g = None
    bias.g = None
    return weight, bias

In [255]:
class Lin():
    def __init__(self, w, b):
        self.weight, self.bias = w, b
        
    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        self.x = x
        self.out = self.x@self.weight + self.bias
        return self.out
    
    def backward(self):
        self.x.g = self.out.g @ self.weight.t()
        self.weight.g = (self.x.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.bias.g = self.out.g.sum(0)
        
        

In [240]:
class MSE():
    def __call__(self, x, target):
        self.inp = x 
        self.target = target
        return (self.inp.squeeze() - target.float()).pow(2).mean()
    
    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze() - self.target.float()).unsqueeze(-1) / self.target.shape[0]

In [241]:
class Model():
    def __init__(self, w1, b1):
        self.layers = [Lin(w1,b1)]
        self.loss = MSE()
    
    def __call__(self, x, target):
        for l in self.layers: x = l(x)
        return self.loss(x, target)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers) : l.backward()

In [242]:
lr = 1e-3

In [271]:
w1, b1 = init_w_b(784,1) 
m = Model(w1, b1)
for i in range(10):
    loss = m(x, y)
    m.backward()
    with torch.no_grad():
        for l in m.layers:
            if hasattr(l, 'weight'):
                l.weight -= l.weight.g * lr
                l.bias   -= l.bias.g   * lr
                l.weight.g.zero_()
                l.bias  .g.zero_()
        
    
    

tensor(1557.1709)
tensor(737.8953)
tensor(529.7811)
tensor(461.7998)
tensor(427.6084)
tensor(402.9778)
tensor(382.1999)
tensor(363.7706)
tensor(347.1635)
tensor(332.0988)
