In [303]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [304]:
import torch
torch.manual_seed(0)

<torch._C.Generator at 0x11caf3b90>

In [305]:
## Forward Functions

In [306]:
m = 5    # number of examples
n = 3    # feature size
nh = 10  # hidden features

# Constant-valued inputs so the layer outputs below are easy to verify by hand
A = torch.ones(n, nh)
b = torch.full((m, 1), 0.5)
x = torch.full((m, n), 0.5)

target = torch.full((m, 1), 4.)

In [307]:
# Linear layer: affine map applied to a batch of row vectors
def lin(x, A, b):
    """Apply the affine transform ``x @ A + b``.

    ``b`` is added using standard PyTorch broadcasting.
    """
    product = torch.matmul(x, A)
    return product + b

In [308]:
lin(x,A,b)

tensor([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]])

In [309]:
# Linear layer written with Einstein summation
def lin_es(x, A, b):
    """Affine transform ``x @ A + b`` expressed with ``torch.einsum``.

    The subscripts contract the shared index ``k`` of ``x`` (i, k) and
    ``A`` (k, j). ``b`` is then added with broadcasting.
    """
    contracted = torch.einsum('ik,kj->ij', x, A)
    return contracted + b

In [310]:
lin_es(x,A,b)

tensor([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]])

In [311]:
# ReLU activation
def relu(x):
    """Rectified linear unit: negative entries become 0, others pass through."""
    return x.clamp(min=0.)

In [312]:
# MSE (mean squared error) loss
def mse(y, y_hat):
    """Squared error summed over all elements, divided by ``len(y)``.

    For column-vector inputs of shape (m, 1) this is the mean squared error.
    NOTE(review): ``y`` and ``y_hat`` should share a shape. For example,
    (m,) against (m, 1) would silently broadcast to (m, m).
    """
    squared = (y - y_hat).pow(2)
    return squared.sum() / len(y)

In [313]:
y = torch.ones(5,1)*3
y_hat = torch.ones(5,1)

In [314]:
mse(y,y_hat)

tensor(4.)

In [315]:
## Backward Functions

In [316]:
def mse_grad(inp, out):
    """Store d(loss)/d(inp) for ``loss = ((inp - out)**2).sum() / len(inp)``.

    Differentiating element-wise gives ``2 * (inp - out) / len(inp)``. The
    result is not summed. ``out`` is the target, and the gradient is written
    to ``inp.g``.
    """
    n_rows = len(inp)
    inp.g = 2. * (inp - out) / n_rows

In [317]:
def relu_grad(inp, out):
    """Back-propagate through ReLU.

    Passes ``out.g`` through where ``inp`` was positive and zeroes it
    elsewhere. The result is written to ``inp.g``.
    """
    mask = (inp > 0).float()
    inp.g = mask * out.g

In [318]:
def lin_grad(inp, out, A, b):
    """Back-propagate through ``out = inp @ A + b``.

    Reads the upstream gradient ``out.g`` and stores:
      * ``inp.g`` -- gradient w.r.t. the layer input,
      * ``A.g``   -- gradient w.r.t. the weight matrix,
      * ``b.g``   -- gradient w.r.t. the bias, shaped exactly like ``b``.
    """
    inp.g = out.g @ A.t()
    A.g = inp.t() @ out.g
    # In the forward pass ``b`` was broadcast up to ``out``'s shape, so its
    # gradient is ``out.g`` summed over every broadcast dimension.
    # ``sum_to_size`` does this for any broadcast-compatible bias shape: the
    # per-row (m, 1) bias used in this notebook (identical to the previous
    # ``sum(1).unsqueeze(-1)``) as well as the usual per-feature (nh,)/(1, nh).
    b.g = out.g.sum_to_size(b.shape)

In [319]:
## Forward Backward Tests

In [320]:
m = 20 # number of examples
n = 5 # feature size
nh = 10 # hidden features

# Random inputs/targets used to check the manual gradients against autograd.
# The values depend on the seed set at the top and on this exact draw order.
x = torch.randn(m,n)
target = torch.randn(m,1)

W1 = torch.randn(n, nh)
# NOTE(review): the biases are shaped (m, 1), i.e. one value per example
# broadcast across the output columns, rather than the usual per-feature
# (nh,) / (1,). The bias gradients computed below are reduced to this same
# (m, 1) shape.
b1 = torch.randn(m,1)
W2 = torch.randn(nh,1)
b2 = torch.randn(m,1)

In [321]:
x.shape

torch.Size([20, 5])

In [322]:
l1 = lin_es(x,W1,b1) # m x nh

In [323]:
l1.shape

torch.Size([20, 10])

In [324]:
l2 = relu(l1) # m x nh

In [325]:
l2.shape

torch.Size([20, 10])

In [326]:
y = lin_es(l2,W2,b2) # m x 1

In [327]:
y.shape

torch.Size([20, 1])

In [328]:
mse_grad(y, target) # m x 1

In [329]:
print(y.g.shape)

torch.Size([20, 1])


In [330]:
lin_grad(l2,y,W2,b2)

In [331]:
print(l2.g.shape)

torch.Size([20, 10])


In [332]:
relu_grad(l1,l2)

In [333]:
print(l1.g.shape)

torch.Size([20, 10])


In [334]:
lin_grad(x,l1,W1,b1)

In [335]:
print(x.g.shape)

torch.Size([20, 5])


In [336]:
def forward_and_backward(inp, targ):
    """Run one forward pass and one manual backward pass of the 2-layer MLP.

    Uses the notebook-level parameters ``W1, b1, W2, b2``. The loss is printed
    rather than returned. Gradients are left in the ``.g`` attribute of
    ``inp`` and of each parameter.
    """
    # Forward: linear -> ReLU -> linear -> MSE
    hidden_pre = lin_es(inp, W1, b1)
    hidden = relu(hidden_pre)
    prediction = lin_es(hidden, W2, b2)
    print(mse(prediction, targ))

    # Backward: visit the same ops in reverse; each call fills in `.g`
    mse_grad(prediction, targ)
    lin_grad(hidden, prediction, W2, b2)
    relu_grad(hidden_pre, hidden)
    lin_grad(inp, hidden_pre, W1, b1)

In [337]:
forward_and_backward(x, target)

tensor(19.7892)


In [338]:
# Save for testing against later
w1g = W1.g.clone()
w2g = W2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x.g.clone()

In [339]:
# Run with Pytorch to check implementation

In [340]:
# Independent, autograd-tracked copies of the input and parameters
xt2, w12, w22, b12, b22 = (
    t.clone().requires_grad_(True) for t in (x, W1, W2, b1, b2)
)

In [341]:
def forward(inp, targ):
    """Forward pass of the 2-layer MLP using the autograd-tracked copies
    ``w12, b12, w22, b22`` defined in the notebook.

    Returns:
        The MSE loss between the network output and ``targ``. Calling
        ``.backward()`` on it populates ``.grad`` on the tracked tensors.
    """
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # Use the `targ` argument rather than the notebook-global `target`;
    # otherwise the parameter is silently ignored.
    return mse(out, targ)

In [342]:
loss = forward(xt2, target)
print(loss)

tensor(19.7892, grad_fn=<DivBackward0>)


In [343]:
loss.backward()

In [344]:
def test_near(a,b,tol=1e-3):
    assert (a-b).abs().max() < tol, f"Not near zero"

In [345]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

In [346]:
## Refactoring 

In [347]:
class Module():
    """Base class for layers with a hand-written backward pass.

    Subclasses implement ``forward(*args)`` and ``_backward(out, *args)``.
    Calling a layer caches its inputs in ``self.args`` and its output in
    ``self.out``. ``backward()`` later hands both to ``_backward``.
    """

    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        # Accept *args so that a subclass missing this override reports
        # 'Not implemented' instead of a TypeError about argument counts.
        raise NotImplementedError('Not implemented')

    def _backward(self, out, *args):
        # Subclasses write gradients for their inputs/parameters here.
        raise NotImplementedError('Not implemented')

    def backward(self):
        self._backward(self.out, *self.args)

In [348]:
class Mse(Module):
    """Mean-squared-error loss layer (a loss, not an activation).

    ``forward`` returns ``((inp - target)**2).sum() / len(inp)`` and also
    keeps it in ``self.loss``. ``_backward`` writes the matching gradient
    into ``inp.g``.
    """

    def forward(self, inp, target):
        residual = inp - target
        self.loss = residual.pow(2).sum() / len(inp)
        return self.loss

    def _backward(self, out, inp, target):
        # d/d(inp) of sum((inp - target)**2) / len(inp)
        n_rows = len(inp)
        inp.g = 2. * (inp - target) / n_rows

In [349]:
y = torch.ones(20,1)

In [350]:
loss = Mse()
loss_model = loss(y,target)
print(loss_model)
loss.backward()

tensor(1.3912)


In [351]:
y.g[:5] # internal value for gradient is passed to input

tensor([[ 0.0037],
        [ 0.0651],
        [ 0.1921],
        [ 0.1056],
        [-0.0140]])

In [352]:
loss.args[0].g[:5] # internal value

tensor([[ 0.0037],
        [ 0.0651],
        [ 0.1921],
        [ 0.1056],
        [-0.0140]])

In [353]:
y.g = torch.ones(20,1)
loss.args[0].g[:5] # also changes

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])

In [354]:
class Lin(Module):
    """Linear (affine) layer computing ``out = inp @ W + b``.

    ``W`` and ``b`` are stored as given (not copied), so the gradients written
    in ``_backward`` appear on the caller's tensors as ``W.g`` and ``b.g``.
    """

    def __init__(self, W, b):
        self.W = W
        self.b = b

    def forward(self, inp):
        return torch.einsum('ik,kj->ij', inp, self.W) + self.b

    def _backward(self, out, inp):
        inp.g = out.g @ self.W.t()
        self.W.g = inp.t() @ out.g
        # Sum the upstream gradient over the dimensions along which the bias
        # was broadcast in forward(), giving a gradient shaped like b. This
        # matches the previous sum(1).unsqueeze(-1) for a per-row (m, 1) bias
        # and is also correct for per-feature (nh,) / (1, nh) biases.
        self.b.g = out.g.sum_to_size(self.b.shape)

In [355]:
Lin_1 = Lin(W1,b1)
out = Lin_1(x)
out.shape

torch.Size([20, 10])

In [356]:
class Relu(Module):
    "ReLU activation layer for Neural Networks"

    def forward(self, inp):
        return inp.clamp(min=0.)

    def _backward(self, out, inp):
        # Gradient flows only through entries that were positive in forward().
        passed = (inp > 0).float()
        inp.g = passed * out.g
        return inp.g

In [357]:
# Let's test our model: chain the layers by hand, then backprop in reverse order.

l1 = Lin(W1,b1)
l2 = Relu()
l3 = Lin(W2,b2)
loss = Mse()

# Forward through every layer; `out` ends up holding the loss value
out = loss(l3(l2(l1(x))), target)

# Backward: loss first, then the layers from last to first
loss.backward()
l3.backward()
l2.backward()
l1.backward()

In [358]:
class_w1g = l1.W.g

In [359]:
# The refactoring was successful if no error is raised
test_near(w12.grad, class_w1g)

In [360]:
class Model():
    """A stack of layers followed by a loss layer, using manual backprop.

    Args:
        layers: layers applied in order. Defaults to
            ``[Lin(W1, b1), Relu(), Lin(W2, b2)]`` built from the
            notebook-level parameters, as before.
        loss: loss layer. Defaults to ``Mse()``.

    Each layer must be callable and expose ``backward()``.
    """

    def __init__(self, layers=None, loss=None):
        # Defaults are built here rather than in the signature, so each Model
        # gets fresh layer objects and uses the current W1/b1/W2/b2.
        if layers is None:
            layers = [Lin(W1, b1), Relu(), Lin(W2, b2)]
        self.layers = list(layers)  # copy so the caller's list isn't aliased
        self.loss = Mse() if loss is None else loss

    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x, targ)

    def backward(self):
        # Reverse order of the forward pass: loss first, then last layer to first.
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

In [361]:
model = Model()
loss = model(x, target)
model.backward()

In [362]:
model_w1g = model.layers[0].W.g

In [363]:
# The refactoring was successful if no error is raised
test_near(w12.grad, model_w1g)

In [364]:
# Fin