In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [17]:
from numpy.random import randn
import numpy as np
import os

In [18]:
# Problem dimensions shared by the NumPy and PyTorch versions of the network.
N = 16      # number of samples in the batch
D_in = 6    # input features per sample
H = 5       # hidden units
D_out = 2   # output values per sample

In [19]:
# Inputs and targets are drawn from a standard normal distribution.  The
# global NumPy RNG is re-seeded before each draw, so each array is
# reproducible on its own.
np.random.seed(1)
x = np.random.randn(N, D_in)    # inputs, one row per sample

np.random.seed(2)
y = np.random.randn(N, D_out)   # regression targets, one row per sample

In [20]:
# Initial parameters for the two-layer network.  Every array gets its own
# seed, so changing one shape does not perturb the values of the others.
np.random.seed(3)
w1 = np.random.randn(D_in, H)     # input -> hidden weights

np.random.seed(23)
b1 = np.random.randn(H)           # hidden bias

np.random.seed(4)
w2 = np.random.randn(H, D_out)    # hidden -> output weights

np.random.seed(24)
b2 = np.random.randn(D_out)       # output bias

# Show the shapes as a quick sanity check.
tuple(arr.shape for arr in (w1, b1, w2, b2))

((6, 5), (5,), (5, 2), (2,))

In [41]:
# Notebook-wide default for whether the linear layers carry a bias term.
use_bias = False


class Net(nn.Module):
    """Two-layer perceptron: linear -> sigmoid -> linear.

    Args:
        d_in: Number of input features (default 6).
        hidden: Number of hidden units (default 5).
        d_out: Number of outputs (default 2).
        bias: Whether both linear layers have a bias term.  ``None`` (the
            default) falls back to the module-level ``use_bias`` flag, read
            at construction time, so ``Net()`` behaves exactly as before.
    """

    def __init__(self, d_in=6, hidden=5, d_out=2, bias=None):
        super().__init__()
        if bias is None:
            bias = use_bias
        # Layers are created in this order (l1, then l2) so parameter
        # initialisation draws from the torch RNG in a stable sequence.
        self.l1 = nn.Linear(d_in, hidden, bias=bias)
        self.l2 = nn.Linear(hidden, d_out, bias=bias)

    def forward(self, x):
        """Return the network output for input ``x``.

        The sigmoid is applied to the hidden layer only; the output layer is
        left linear.
        """
        hidden = torch.sigmoid(self.l1(x))
        return self.l2(hidden)

In [80]:
class DAD(nn.Module):
    """Wrap a model and record the output of each direct child on every forward pass.

    After calling the wrapper, ``out_activations`` holds the output tensors of
    the wrapped model's direct children (``model.named_children()``), in the
    order their forward hooks fired during the most recent pass.  The tensors
    are the ones that take part in the autograd graph, so they can be passed
    to ``torch.autograd.grad`` afterwards.

    Args:
        model: The ``nn.Module`` whose child activations should be captured.
    """

    def __init__(self, model):
        super().__init__()
        self.out_activations = []
        self.model = model

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model and capture its children's outputs.

        Hooks are installed for the duration of this call only.  Registering
        them on every call without removing them would stack hooks, so each
        later forward pass would append every activation several times.
        """
        # Bind a fresh list (rather than clearing in place) so that a list
        # obtained from an earlier pass is not emptied under the caller's feet.
        self.out_activations = []
        handles = [
            child.register_forward_hook(self.hook_wrapper('forward', name))
            for name, child in self.model.named_children()
        ]
        try:
            return self.model(*inputs, **kwargs)
        finally:
            # Always detach the hooks, even if the wrapped model raised.
            for handle in handles:
                handle.remove()

    def hook_wrapper(self, hook_type, layer):
        """Build a hook that stores a module's output in ``out_activations``.

        Args:
            hook_type: Kind of hook; only ``'forward'`` (case-insensitive)
                records anything.
            layer: Name of the child the hook is attached to (currently not
                used by the hook itself).

        Returns:
            A callable with the forward-hook signature
            ``(module, inputs, output)``.
        """
        def hook_save(module, inputs, output):
            if hook_type.lower() == 'forward':
                self.out_activations.append(output)
        return hook_save

In [81]:
# Wrap a freshly constructed Net in the DAD wrapper.
net = DAD(model=Net())

In [82]:
# Copy the NumPy-initialised parameters into the PyTorch network.
#
# Parameters are matched by name instead of by position: zipping
# named_parameters() against [w1, w2] only lines up when the layers have no
# bias.  With biases enabled the second parameter is l1's bias, which would
# otherwise receive w2.
numpy_params = {
    'l1.weight': w1.T,   # nn.Linear stores weights as (out_features, in_features)
    'l1.bias': b1,
    'l2.weight': w2.T,
    'l2.bias': b2,
}

with torch.no_grad():
    for name, param in net.named_parameters():
        # Drop any wrapper prefix such as "model." and keep "<layer>.<kind>".
        key = '.'.join(name.split('.')[-2:])
        source = numpy_params[key]
        assert tuple(source.shape) == tuple(param.shape), (
            f'shape mismatch for {name}: {source.shape} vs {tuple(param.shape)}'
        )
        # The dtype conversion (float64 -> float32) makes a copy, so later
        # in-place updates of the NumPy arrays do not leak into the model.
        param.copy_(torch.as_tensor(source, dtype=param.dtype))

In [83]:
# Adam optimizer over all parameters of the wrapped network, with any
# existing gradients cleared right away.
optim = torch.optim.Adam(params=net.parameters(), lr=0.1)
optim.zero_grad()

In [84]:
# Forward pass on the NumPy inputs, converted to a float32 tensor.
o = net(torch.as_tensor(x, dtype=torch.float32))

In [85]:
# Sum-of-squares error between the network output and the targets.
loss = torch.sum(torch.square(o - torch.Tensor(y)))
# loss.backward() is left disabled here; a later cell queries gradients with
# torch.autograd.grad instead.

In [87]:
# Outputs recorded by the forward hooks, in the order the hooks fired.
net.out_activations

[tensor([[ 4.8668e+00, -7.3101e-01,  3.3871e+00, -2.9787e+00,  6.3070e-01],
         [ 2.8090e+00, -3.3843e-01,  3.8158e+00, -9.8379e-01, -8.6114e-01],
         [-6.5381e-01,  9.8198e-01,  2.7928e+00,  2.1106e+00,  1.9173e+00],
         [-5.7622e-01, -1.5003e+00, -1.8400e+00, -1.0349e+00, -2.6426e+00],
         [ 2.3340e+00,  1.2404e+00,  1.3691e+00, -3.2487e+00,  1.2505e+00],
         [ 9.5366e-01, -2.8573e-01, -1.1422e-01, -8.4773e-01,  1.9675e+00],
         [-3.7027e+00,  3.8523e-02, -8.0047e-02,  6.2751e+00, -1.3253e-01],
         [-3.4696e+00,  1.1946e+00, -1.5707e-01, -7.9400e-01, -5.7030e-01],
         [ 1.3477e+00,  4.7508e-01, -1.2094e+00, -8.6205e-02,  1.3253e+00],
         [-3.0312e+00,  5.8591e-01, -8.0356e-01,  2.0979e+00, -1.6977e+00],
         [-2.8721e+00,  3.3434e-02,  7.9310e-01,  2.1075e+00, -5.1581e-01],
         [ 1.2508e+00,  3.0449e+00,  1.1739e+00,  2.3414e-01,  2.2010e+00],
         [ 1.5272e-01,  1.9597e+00,  2.2800e+00, -2.4939e+00,  1.9022e+00],
         [ 1

In [88]:
# Manual forward/backward pass of the same two-layer network in NumPy.
# The loop body ends with an unconditional `break`, so exactly one
# gradient-descent step is performed.
for t in range(10):
    # Forward: sigmoid hidden layer, then a linear output layer (no biases).
    h = 1 / (1 + np.exp(-np.dot(x, w1)))
    print('h:', h.shape)
    y_pred = np.dot(h, w2)
    loss1 = np.sum(np.square(y_pred - y))

    # Backward: derivative of the sum-of-squares loss w.r.t. the prediction.
    grad_y_pred = 2.0 * (y_pred - y)
    print('grad_y_pred:', grad_y_pred.shape)

    grad_w2 = np.dot(h.T, grad_y_pred)
    print('grad_w2: ', grad_w2.shape)

    # Propagate through the output layer, then through the sigmoid, whose
    # derivative is h * (1 - h).
    grad_h = np.dot(grad_y_pred, w2.T)
    print('grad_h: ', grad_h.shape)
    grad_w1 = np.dot(x.T, grad_h * h * (1 - h))
    print('grad_w1:', grad_w1.shape)

    # Both gradients are computed before either weight matrix is modified.
    # The updates are in place, so re-running this cell takes another step.
    w1 -= 1e-4 * grad_w1
    w2 -= 1e-4 * grad_w2
    break

h: (16, 5)
grad_y_pred: (16, 2)
grad_w2:  (5, 2)
grad_h:  (16, 5)
grad_w1: (6, 5)


In [89]:
# Gradient of the loss with respect to each recorded activation, walking the
# activations from last to first.  Each entry is the 1-tuple returned by
# torch.autograd.grad; retain_graph=True keeps the graph alive so it can be
# queried again on the next iteration.
grads = []
for op in reversed(net.out_activations):
    grads += [torch.autograd.grad(outputs=loss, inputs=op, retain_graph=True)]
    print('------------')
grads

------------
------------


[(tensor([[-0.1413, -2.9189],
          [ 2.5752, -5.1889],
          [ 0.8237, -0.8552],
          [-1.7391,  2.8386],
          [ 0.4770, -0.4280],
          [-2.0736, -6.3905],
          [-2.4692,  1.5844],
          [-3.1440,  0.3854],
          [-1.3851, -2.6450],
          [-1.0799,  0.6617],
          [-0.7229, -0.8840],
          [-3.0839,  0.2942],
          [-1.2801, -2.5237],
          [-0.1607,  1.5531],
          [ 1.7682,  0.3077],
          [-0.9864, -6.1461]]),),
 (tensor([[-1.1251e-02, -4.1559e-01,  1.4767e-01, -7.6393e-02,  7.4549e-01],
          [-1.3361e-01, -1.5004e+00,  1.5012e-01, -9.4552e-01,  1.4183e+00],
          [-8.7819e-02, -2.8038e-01,  5.4787e-02, -1.0068e-01,  1.4001e-01],
          [ 3.1013e-01,  5.5268e-01, -4.4444e-01,  5.4632e-01, -2.3732e-01],
          [-1.5456e-02, -1.3437e-01,  7.7126e-02, -2.0308e-02,  1.1226e-01],
          [-6.7072e-01, -5.8676e-01,  2.7300e+00, -5.2336e-01,  7.1121e-01],
          [ 1.5831e-02,  8.8895e-01, -3.6771e-01,  4.7

In [90]:
# Gradient for the last recorded activation (grads was built over the
# reversed list); the second [0] unwraps the 1-tuple from torch.autograd.grad.
grads[0][0]

tensor([[-0.1413, -2.9189],
        [ 2.5752, -5.1889],
        [ 0.8237, -0.8552],
        [-1.7391,  2.8386],
        [ 0.4770, -0.4280],
        [-2.0736, -6.3905],
        [-2.4692,  1.5844],
        [-3.1440,  0.3854],
        [-1.3851, -2.6450],
        [-1.0799,  0.6617],
        [-0.7229, -0.8840],
        [-3.0839,  0.2942],
        [-1.2801, -2.5237],
        [-0.1607,  1.5531],
        [ 1.7682,  0.3077],
        [-0.9864, -6.1461]])

In [91]:
# Gradient for the second-to-last recorded activation, unwrapped from its 1-tuple.
grads[1][0]

tensor([[-1.1251e-02, -4.1559e-01,  1.4767e-01, -7.6393e-02,  7.4549e-01],
        [-1.3361e-01, -1.5004e+00,  1.5012e-01, -9.4552e-01,  1.4183e+00],
        [-8.7819e-02, -2.8038e-01,  5.4787e-02, -1.0068e-01,  1.4001e-01],
        [ 3.1013e-01,  5.5268e-01, -4.4444e-01,  5.4632e-01, -2.3732e-01],
        [-1.5456e-02, -1.3437e-01,  7.7126e-02, -2.0308e-02,  1.1226e-01],
        [-6.7072e-01, -5.8676e-01,  2.7300e+00, -5.2336e-01,  7.1121e-01],
        [ 1.5831e-02,  8.8895e-01, -3.6771e-01,  4.7716e-03, -6.5634e-01],
        [ 9.5444e-04,  6.0509e-01,  1.7426e-01,  4.8471e-01, -3.4437e-01],
        [-2.3081e-01, -1.1065e-01,  8.4144e-01, -1.7256e-01,  4.2448e-01],
        [ 1.2254e-02,  3.5236e-01, -1.2711e-01,  1.0650e-01, -1.4620e-01],
        [-2.4561e-02,  2.5442e-02,  3.6400e-01, -6.1153e-03,  1.7985e-01],
        [-1.8056e-03,  1.4175e-01,  1.4796e-01,  5.3447e-01, -1.2277e-01],
        [-3.3380e-01, -5.2801e-02,  3.8041e-01, -4.8383e-02,  2.7747e-01],
        [ 1.9438e-01,  3.

In [92]:
# NumPy gradient of the loss w.r.t. the prediction, 2 * (y_pred - y), from the
# manual step; shown for comparison with the autograd result.
grad_y_pred

array([[-0.14128343, -2.91890857],
       [ 2.57515734, -5.18890603],
       [ 0.82365203, -0.85523361],
       [-1.73909922,  2.8385638 ],
       [ 0.47695332, -0.42796761],
       [-2.07359372, -6.39049049],
       [-2.46923772,  1.58436757],
       [-3.14404445,  0.38541964],
       [-1.38512863, -2.64496736],
       [-1.07988552,  0.66171606],
       [-0.72291186, -0.88400441],
       [-3.08388612,  0.29422503],
       [-1.28009607, -2.52365231],
       [-0.16070638,  1.55309877],
       [ 1.76819895,  0.30772287],
       [-0.9863842 , -6.1461239 ]])