In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from numpy.random import randn
import numpy as np
import os

In [3]:
# Problem sizes: number of samples, input features, hidden units, output features.
N = 16
D_in = 6
H = 5
D_out = 2

In [4]:
# Row index that splits the N samples into two halves (floor division).
mid = divmod(N, 2)[0]

In [5]:
# Inputs and targets come from separately seeded draws of NumPy's global RNG,
# so each array is reproducible independently of the other.
np.random.seed(1)
x = randn(N, D_in)   # inputs, shape (N, D_in)
np.random.seed(2)
y = randn(N, D_out)  # targets, shape (N, D_out)

# First / second halves along the sample axis (rows [0, mid) and [mid, N)).
x1, x2 = np.split(x, [mid])
y1, y2 = np.split(y, [mid])

In [6]:
def _seeded_randn(seed, *shape):
    """Reseed NumPy's global RNG with ``seed`` and draw a standard-normal array of ``shape``."""
    np.random.seed(seed)
    return randn(*shape)

# Each parameter gets its own fixed seed so it can be regenerated in isolation.
w1 = _seeded_randn(3, D_in, H)   # (D_in, H)
b1 = _seeded_randn(23, H)        # (H,)
w2 = _seeded_randn(4, H, D_out)  # (H, D_out)
b2 = _seeded_randn(24, D_out)    # (D_out,)
tuple(arr.shape for arr in (w1, b1, w2, b2))

((6, 5), (5,), (5, 2), (2,))

In [7]:
use_bias = True


class Net(nn.Module):
    """Two-layer MLP: Linear -> sigmoid -> Linear.

    Args:
        d_in: number of input features (default 6, the same value as D_in).
        hidden: hidden-layer width (default 5, the same value as H).
        d_out: number of outputs (default 2, the same value as D_out).
        bias: whether both linear layers carry a bias term. ``None`` (the
            default) falls back to the module-level ``use_bias`` flag, read at
            construction time, which preserves the original behaviour.
    """

    def __init__(self, d_in=6, hidden=5, d_out=2, bias=None):
        super().__init__()
        if bias is None:
            bias = use_bias
        self.l1 = nn.Linear(d_in, hidden, bias=bias)
        self.l2 = nn.Linear(hidden, d_out, bias=bias)

    def forward(self, x):
        """Map ``x`` (last dim ``d_in``) to outputs with last dim ``d_out``."""
        x = torch.sigmoid(self.l1(x))
        return self.l2(x)

In [23]:
class DAD(nn.Module):
    """Wrap ``model`` and record tensors seen by its direct child modules.

    For every direct child of ``model`` (``model.named_children()``) two hooks
    are attached:

    * a forward hook that appends the child's first positional *input* to
      ``self.activations``;
    * a backward hook that appends the first element of ``grad_output`` (the
      gradient w.r.t. the child's output) to ``self.grads``.

    Hooks are attached exactly once, on the first ``forward`` call. Previously
    they were re-registered on every call, so each additional forward pass
    added another copy of every hook and tensors were recorded several times.

    Recorded tensors are stored as-is (not detached), so they keep any autograd
    graph they belong to alive until the lists are cleared.
    """

    def __init__(self, model):
        super(DAD, self).__init__()
        self.grads = []
        self.activations = []
        self.model = model
        # RemovableHandles from register_*_hook, kept so the hooks can be
        # detached later (``for h in net._recording_hooks: h.remove()``).
        self._recording_hooks = []
        self._hooks_registered = False

    def _register_hooks(self):
        """Attach one forward hook and one backward hook to each direct child."""
        for name, child in self.model.named_children():
            self._recording_hooks.append(
                child.register_forward_hook(self.hook_wrapper('forward', name)))
            # NOTE(review): register_backward_hook is deprecated in favour of
            # register_full_backward_hook; kept to preserve the original
            # behaviour (the full variant forbids in-place ops on module
            # inputs/outputs).
            self._recording_hooks.append(
                child.register_backward_hook(self.hook_wrapper('backward', name)))
        self._hooks_registered = True

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model, attaching the recording hooks on first use."""
        if not self._hooks_registered:
            self._register_hooks()
        return self.model(*inputs, **kwargs)

    def hook_wrapper(self, hook_type, layer):
        """Build a hook that records into ``activations`` or ``grads``.

        Args:
            hook_type: 'forward' or 'backward' (case-insensitive).
            layer: name of the child module; unused, kept for debugging.

        Returns:
            ``hook_save(module, first, second)``. For a forward hook PyTorch
            passes ``(inputs, output)``; for a backward hook
            ``(grad_input, grad_output)``.
        """
        def hook_save(module, first, second):
            kind = hook_type.lower()
            if kind == 'forward':
                # Record the first positional input, if present.
                if len(first) > 0 and first[0] is not None:
                    self.activations.append(first[0])
            if kind == 'backward':
                # Record the gradient w.r.t. the first output, if present.
                if len(second) > 0 and second[0] is not None:
                    self.grads.append(second[0])
        return hook_save

In [25]:
# Wrap a fresh Net so its child layers' inputs and output-gradients get recorded.
net = DAD(model=Net())

In [26]:
# Parameter names in registration order.
ks = [name for name, _param in net.named_parameters()]

In [27]:
# Every other name starting at index 0 (the weights, when layers have biases).
ks[0::2]

['model.l1.weight', 'model.l2.weight']

In [28]:
# Reverse order, every other name (the biases, when layers have biases).
ks[::-1][::2]

['model.l2.bias', 'model.l1.bias']

In [29]:
# All parameter names.
list(ks)

['model.l1.weight', 'model.l1.bias', 'model.l2.weight', 'model.l2.bias']

In [30]:
# list(net.model.named_children())

In [31]:
# for (n, m), w in zip(net.named_parameters(), [w1, w2]):
#     if 'bias' in n:
#         m.data = torch.FloatTensor(w)
#     elif 'weight' in n:
#         m.data = torch.FloatTensor(w.T)

In [32]:
# Adam over every parameter of the wrapped model; clear any stale gradients.
# NOTE(review): optim.step() is never called in the cells shown here.
optim = torch.optim.Adam(params=net.parameters(), lr=0.1)
optim.zero_grad()

In [33]:
# Forward the full batch; NumPy float64 inputs are converted to float32.
o = net(torch.as_tensor(x, dtype=torch.float32))



In [34]:
# Sum-of-squares error against the targets (same loss as the NumPy cell below).
loss = (o - torch.as_tensor(y, dtype=torch.get_default_dtype())).pow(2).sum()

In [35]:
# Backpropagate; this fires the backward hooks recorded by the DAD wrapper.
torch.autograd.backward(loss)

In [20]:
# NOTE(review): zip() pairs each parameter with itself, so p1 and p2 are the
# same tensor here -- presumably a placeholder for comparing two models; confirm.
# `p` is overwritten every iteration, so only the last pair's sum survives.
for p1, p2 in zip(net.parameters(), net.parameters()):
    print(*(q.grad.shape for q in (p1, p2)))
    p = torch.stack((p1, p2)).sum(dim=0)

torch.Size([5, 6]) torch.Size([5, 6])
torch.Size([5]) torch.Size([5])
torch.Size([2, 5]) torch.Size([2, 5])
torch.Size([2]) torch.Size([2])


tensor(44.2717, grad_fn=<SumBackward0>)

In [53]:
# NOTE(review): relies on p1/p2 leaked from the loop in the previous cell.
torch.add(p1, p2)

tensor([-0.4013, -0.6554], grad_fn=<AddBackward0>)

In [934]:
# [p.grad for p in net.parameters()]

In [908]:
# g = grad_h * h * (1-h)
# g.sum(0)

In [909]:
# b1, b2

In [37]:
# Manual NumPy forward/backward pass for a 2-layer sigmoid MLP with a
# sum-of-squares loss. Unlike the torch Net (when use_bias is True), no bias
# terms are applied here. Exactly one iteration runs, printing the shape of
# each intermediate and applying one SGD step (lr 1e-4) to w1/w2 in place --
# so re-running this cell keeps moving the weights.
for t in range(1):
    # forward
    h = 1 / (1 + np.exp(-np.dot(x, w1)))
    print('h:', h.shape)
    y_pred = np.dot(h, w2)
    loss = np.square(y_pred - y).sum()

    # backward
    grad_y_pred = 2.0 * (y_pred - y)
    print('grad_y_pred:', grad_y_pred.shape)
    grad_w2 = np.dot(h.T, grad_y_pred)
    print('grad_w2: ', grad_w2.shape)
    grad_h = np.dot(grad_y_pred, w2.T)
    print('grad_h: ', grad_h.shape)
    # sigmoid derivative: h * (1 - h)
    grad_w1 = np.dot(x.T, grad_h * h * (1 - h))
    print('grad_w1:', grad_w1.shape)

    # in-place gradient-descent step
    w1 -= 1e-4 * grad_w1
    w2 -= 1e-4 * grad_w2

h: (16, 5)
grad_y_pred: (16, 2)
grad_w2:  (5, 2)
grad_h:  (16, 5)
grad_w1: (6, 5)
