In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [76]:
from numpy.random import randn
import numpy as np
import os

In [77]:
# Problem dimensions used by every cell below:
#   N     - number of rows in the input/target arrays
#   D_in  - number of input features
#   H     - hidden-layer width
#   D_out - number of output features
N = 16
D_in = 6
H = 5
D_out = 2

In [78]:
# Inputs and targets are drawn from separately seeded standard-normal streams
# so that each array can be regenerated on its own.
np.random.seed(1)
x = np.random.standard_normal(size=(N, D_in))

np.random.seed(2)
y = np.random.standard_normal(size=(N, D_out))

In [79]:
def _seeded_normal(seed, *shape):
    """Reseed NumPy's global generator with `seed`, then draw a standard-normal array of `shape`."""
    np.random.seed(seed)
    return randn(*shape)


# Every parameter gets its own fixed seed; the draw order matches the
# layer order (first-layer weight/bias, then second-layer weight/bias).
w1 = _seeded_normal(3, D_in, H)
b1 = _seeded_normal(23, H)
w2 = _seeded_normal(4, H, D_out)
b2 = _seeded_normal(24, D_out)

# Sanity check of the generated shapes.
w1.shape, b1.shape, w2.shape, b2.shape

((6, 5), (5,), (5, 2), (2,))

In [80]:
# Notebook-wide default for whether the linear layers carry a bias term.
use_bias = False


class Net(nn.Module):
    """Two-layer perceptron: linear -> sigmoid -> linear.

    Args:
        d_in: number of input features (default 6).
        hidden: width of the hidden layer (default 5).
        d_out: number of output features (default 2).
        bias: whether both linear layers have a bias term. ``None`` (the
            default) defers to the module-level ``use_bias`` flag, read at
            construction time, so ``Net()`` behaves exactly as before.
    """

    def __init__(self, d_in=6, hidden=5, d_out=2, bias=None):
        super().__init__()
        if bias is None:
            bias = use_bias
        self.l1 = nn.Linear(d_in, hidden, bias=bias)
        self.l2 = nn.Linear(hidden, d_out, bias=bias)

    def forward(self, x):
        """Return ``l2(sigmoid(l1(x)))``; no activation is applied to the output."""
        hidden = torch.sigmoid(self.l1(x))
        return self.l2(hidden)

In [81]:
class DAD(nn.Module):
    """Wrap a model and record per-layer inputs and output gradients.

    For every direct child of ``model`` a forward hook stores the first
    positional input passed to that child, and a backward hook stores the
    first entry of the gradient with respect to that child's output.

    Attributes:
        activations: list of recorded child inputs, in forward-call order.
        grads: list of recorded output gradients, in backward-call order.
        model: the wrapped module.

    Hooks are installed lazily on the first ``forward`` call and only once.
    (Registering them on every call, as was done previously, stacked one
    extra pair of hooks per call and duplicated every recorded tensor.)
    Both lists grow across calls; reassign them to ``[]`` to start over.
    """

    def __init__(self, model):
        super(DAD, self).__init__()
        self.grads = []
        self.activations = []
        self.model = model
        # Book-keeping so hooks are installed exactly once and can be removed.
        self._hooks_registered = False
        self._hook_handles = []

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model, installing the recording hooks on first use."""
        if not self._hooks_registered:
            self._register_hooks()
        return self.model(*inputs, **kwargs)

    def _register_hooks(self):
        """Attach one forward and one backward hook to each direct child of the model."""
        for name, child in self.model.named_children():
            self._hook_handles.append(
                child.register_forward_hook(self.hook_wrapper('forward', name)))
            # Kept on the legacy backward-hook API to preserve the notebook's
            # existing behaviour across PyTorch versions.
            self._hook_handles.append(
                child.register_backward_hook(self.hook_wrapper('backward', name)))
        self._hooks_registered = True

    def remove_hooks(self):
        """Detach every hook installed by this wrapper; the next forward re-installs them."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []
        self._hooks_registered = False

    def hook_wrapper(self, hook_type, layer):
        """Build a hook callback for ``hook_type``.

        Args:
            hook_type: ``'forward'`` or ``'backward'`` (case-insensitive).
            layer: name of the child the hook is attached to (informational).

        Returns:
            A callable with the three-argument signature PyTorch uses for
            both forward hooks ``(module, input, output)`` and backward
            hooks ``(module, grad_input, grad_output)``.

        Raises:
            ValueError: if ``hook_type`` is neither 'forward' nor 'backward'.
        """
        kind = hook_type.lower()
        if kind not in ('forward', 'backward'):
            raise ValueError(
                "hook_type must be 'forward' or 'backward', got %r" % (hook_type,))

        def hook_save(module, second, third):
            # Forward: `second` is the tuple of positional inputs.
            # Backward: `third` is the tuple of gradients w.r.t. the outputs.
            # The target list is looked up at call time so that reassigning
            # `self.activations` / `self.grads` is respected.
            if kind == 'forward':
                candidates, store = second, self.activations
            else:
                candidates, store = third, self.grads
            first = next(iter(candidates), None)
            if first is not None:
                store.append(first)

        return hook_save

In [82]:
# Build the torch model with PyTorch's default initialisation; the weights are
# replaced by the NumPy-generated values in the following cell.
net = Net()

In [83]:
# Load the NumPy-generated parameters into the torch model by *name*.
# Pairing `named_parameters()` positionally with `[w1, w2]` only works while
# the model has no biases: with biases enabled the order is
# l1.weight, l1.bias, l2.weight, l2.bias, so `w2` would land on `l1.bias`.
# nn.Linear stores its weight as (out_features, in_features), hence the `.T`.
init_values = {
    'l1.weight': w1.T,
    'l1.bias': b1,
    'l2.weight': w2.T,
    'l2.bias': b2,
}
with torch.no_grad():
    for n, m in net.named_parameters():
        # torch.tensor copies the data, so later in-place NumPy updates of
        # w1/w2 do not leak into the model; copy_ also checks the shape.
        m.copy_(torch.tensor(init_values[n], dtype=m.dtype))

In [84]:
# Adam over all of `net`'s parameters with learning rate 0.1.
# NOTE(review): no `optim.step()` appears in the visible cells — confirm
# whether the optimizer is actually used further down.
optim = torch.optim.Adam(net.parameters(), lr=0.1)
# Clear any existing gradients before the backward pass below.
optim.zero_grad()

In [85]:
# Forward pass of the torch model; torch.FloatTensor converts the NumPy input
# to a float32 tensor.
o = net(torch.FloatTensor(x))

In [88]:
# Inspect the torch model's predictions.
o

tensor([[-0.4900, -1.5350],
        [-0.8508, -0.9719],
        [-1.3890, -1.2847],
        [-0.3687,  0.1708],
        [-0.8254, -1.1405],
        [-0.4909, -0.9168],
        [-1.1991, -0.3313],
        [-1.0392, -0.4085],
        [-0.7178, -0.1601],
        [-1.2931,  0.3362],
        [-1.2456, -0.6059],
        [-1.2937, -0.8595],
        [-0.9857, -1.5140],
        [-0.7224, -0.4203],
        [-0.5407, -0.0063],
        [-0.7665, -0.8574]], grad_fn=<MmBackward>)

In [87]:
# Sum-of-squares error between predictions and targets (summed, not averaged),
# the same objective the manual NumPy cell computes with np.square(...).sum().
loss = torch.square(o-torch.Tensor(y)).sum()
loss

tensor(48.6326, grad_fn=<SumBackward0>)

In [35]:
# Backpropagate the torch loss; this fills `.grad` on `net`'s parameters.
loss.backward()

In [89]:
# Manual NumPy forward/backward pass of the same bias-free two-layer network
# (sigmoid hidden layer, linear output, sum-of-squares loss), printing the
# shape of every intermediate gradient.
# NOTE: the loop is written for 10 steps but the trailing `break` stops it
# after the first one.
# NOTE: `loss` is rebound here to a NumPy scalar (it held the torch loss
# tensor above), and w1/w2 are updated in place, so re-running this cell
# starts from already-modified weights.
for t in range(10):
    # Hidden activations: h = sigmoid(x @ w1)
    h = 1/(1+np.exp(-x.dot(w1)))
    print('h:', h.shape)
    
    # Linear output layer and summed squared error.
    y_pred = h.dot(w2)
    loss = np.square(y_pred-y).sum()
#     print(t, loss)
    
    # d(loss)/d(y_pred) for loss = sum((y_pred - y)^2)
    grad_y_pred = 2.0 * (y_pred - y)
    print('grad_y_pred:', grad_y_pred.shape)
    
    # d(loss)/d(w2) = h^T @ grad_y_pred
    grad_w2 = h.T.dot(grad_y_pred)
    print('grad_w2: ', grad_w2.shape)
    
    
    # Backpropagate through the output layer, then through the sigmoid,
    # whose derivative is h * (1 - h).
    grad_h = grad_y_pred.dot(w2.T)
    print('grad_h: ', grad_h.shape)
    grad_w1 = x.T.dot(grad_h * h * (1-h))
    print('grad_w1:', grad_w1.shape)
    
    # Plain gradient-descent step with learning rate 1e-4 (in place).
    w1 -= 1e-4 * grad_w1
    w2 -= 1e-4 * grad_w2
    break

h: (16, 5)
grad_y_pred: (16, 2)
grad_w2:  (5, 2)
grad_h:  (16, 5)
grad_w1: (6, 5)
