In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from numpy.random import randn
import numpy as np
import os

In [3]:
N = 16     # number of samples (rows of x and y)
D_in = 6   # input features
H = 5      # hidden units
D_out = 2  # output features

In [4]:
# Row index splitting the N samples into two halves.
mid = divmod(N, 2)[0]

In [5]:
# Inputs and targets each get their own fixed seed so they are reproducible.
np.random.seed(1)
x = randn(N, D_in)    # inputs, shape (N, D_in)
np.random.seed(2)
y = randn(N, D_out)   # targets, shape (N, D_out)

# First half / second half of the samples, split at row `mid`.
x1, x2 = np.split(x, [mid])
y1, y2 = np.split(y, [mid])

In [6]:
# Fixed seeds so the NumPy reference weights are reproducible; they are
# copied into the PyTorch networks further down.
np.random.seed(3)
w1 = randn(D_in, H)   # layer-1 weights, shape (D_in, H)
np.random.seed(4)
w2 = randn(H, D_out)  # layer-2 weights, shape (H, D_out)
(w1.shape, w2.shape)

((6, 5), (5, 2))

In [7]:
class Net(nn.Module):
    """Two-layer bias-free MLP: Linear -> sigmoid -> Linear.

    The default sizes (6 -> 5 -> 2) match D_in, H and D_out used in this
    notebook; pass other sizes to reuse the same architecture.

    Args:
        d_in: number of input features.
        hidden: number of hidden units.
        d_out: number of output features.
    """

    def __init__(self, d_in=6, hidden=5, d_out=2):
        super().__init__()
        # bias=False so the layers mirror the bias-free NumPy reference
        # (sigmoid(x @ w1) @ w2).
        self.l1 = nn.Linear(d_in, hidden, bias=False)
        self.l2 = nn.Linear(hidden, d_out, bias=False)

    def forward(self, x):
        """Return l2(sigmoid(l1(x)))."""
        hidden_act = torch.sigmoid(self.l1(x))
        return self.l2(hidden_act)

In [8]:
# Three networks, all loaded with the NumPy reference weights.
# nn.Linear stores its weight as (out_features, in_features), so the NumPy
# (in, out) matrices are transposed before copying.
net = Net()
net1 = Net()
net2 = Net()
with torch.no_grad():  # copy into the parameters without recording autograd history
    for model in (net, net1, net2):
        for p, a in zip(model.parameters(), [w1, w2]):
            p.copy_(torch.as_tensor(a.T, dtype=p.dtype))

In [9]:
def hook_wrapper(site, hook_type, layer):
    """Build a module hook that prints the first non-None item it receives.

    A 'forward' hook (case-insensitive) prints the first element of its
    second argument (the module inputs). A 'backward' hook prints the first
    element of its third argument (the gradients w.r.t. the module outputs).
    Any other hook_type prints only the section banners.

    Args:
        site: free-form tag included in the printed label.
        hook_type: 'forward' or 'backward' (case-insensitive).
        layer: layer name included in the printed label.

    Returns:
        A callable with the (module, inputs, outputs) hook signature.
    """
    print(f"**** {site}, {hook_type}, {layer}  ****")
    label = f"Site:{site}-Type:{hook_type}-Layer:{layer}"

    def _show_first(items):
        # Only the first entry is considered; it is printed unless it is None.
        first = next(iter(items), None)
        if first is not None:
            print(label)
            print(first)

    def hook_save(module, in_items, out_items):
        print('------IN---------')
        kind = hook_type.lower()
        if kind == 'forward':
            _show_first(in_items)
        print('\n------------OUT-----------')
        if kind == 'backward':
            _show_first(out_items)
        print('****************************')

    return hook_save

In [10]:
# def hook_wrapper(site, hook_type, layer, save_to='', debug=False):
#     if debug:
#         print(f"**** {site}, {hook_type}, {layer} ****")

#     name = os.path.join(save_to, f"Site:{site}-Type:{hook_type}-Layer:{layer}")

#     def hook_save(a, in_grad, out_grad):
#         if hook_type.lower() == 'forward':
#             for i, b in enumerate(in_grad):
#                 if b is not None:
#                     np.save(name + f"-IO:in-Index:{i}.npy", b.clone().detach().numpy())
#                 break
#         if hook_type.lower() == 'backward':
#             for i, c in enumerate(out_grad):
#                 if c is not None:
#                     np.save(name + f"-IO:out-index:{i}.npy", c.clone().detach().numpy())
#                 break

#     return hook_save

In [11]:
# Attach a printing forward hook and a printing backward hook to every direct
# child layer of `net`. Handles from a previous run of this cell are removed
# first, so re-running it does not stack duplicate hooks.
for _old_handle in globals().get('hook_handles', []):
    _old_handle.remove()
hook_handles = []

childrens = dict(net.named_children())
for k, ch in childrens.items():
    hook_handles.append(ch.register_forward_hook(hook_wrapper(0, 'forward', k)))
    # register_backward_hook is deprecated; use the full backward hook when
    # this torch version provides it.
    register_backward = getattr(ch, 'register_full_backward_hook', ch.register_backward_hook)
    hook_handles.append(register_backward(hook_wrapper(0, 'backward', k)))

**** 0, forward, l1  ****
**** 0, backward, l1  ****
**** 0, forward, l2  ****
**** 0, backward, l2  ****


In [12]:
# Names of the direct child modules of `net`, in registration order.
list(dict(net.named_children()))

['l1', 'l2']

In [13]:
# One step of manual NumPy backprop for the same sigmoid MLP, kept as a
# reference for the PyTorch gradients. The former `for t in range(10): ... break`
# only ever executed a single iteration, so the loop is dropped.
learning_rate = 1e-4

# Forward pass: sigmoid hidden layer, linear output, summed squared error.
h = 1/(1+np.exp(-x.dot(w1)))
print('h:', h.shape)
y_pred = h.dot(w2)
loss = np.square(y_pred-y).sum()

# Backward pass by the chain rule.
grad_y_pred = 2.0 * (y_pred - y)          # dL/dy_pred
print('grad_y_pred:', grad_y_pred.shape)

grad_w2 = h.T.dot(grad_y_pred)            # dL/dw2
print('grad_w2: ', grad_w2.shape)

grad_h = grad_y_pred.dot(w2.T)            # dL/dh
print('grad_h: ', grad_h.shape)
grad_w1 = x.T.dot(grad_h * h * (1-h))     # sigmoid'(z) = h * (1 - h)
print('grad_w1:', grad_w1.shape)

# NOTE: in-place SGD update. Afterwards w1/w2 no longer equal the weights
# copied into the PyTorch nets, and re-running this cell recomputes h and the
# gradients from the updated weights. As a result, comparisons with the
# PyTorch gradients can drift.
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2

h: (16, 5)
grad_y_pred: (16, 2)
grad_w2:  (5, 2)
grad_h:  (16, 5)
grad_w1: (6, 5)


In [14]:
x_tensor = torch.Tensor(x)   # NumPy inputs -> tensor of torch's default float dtype
p1 = net(x_tensor)           # call the module (not .forward) so the registered hooks fire

------IN---------
Site:0-Type:forward-Layer:l1
tensor([[ 1.6243, -0.6118, -0.5282, -1.0730,  0.8654, -2.3015],
        [ 1.7448, -0.7612,  0.3190, -0.2494,  1.4621, -2.0601],
        [-0.3224, -0.3841,  1.1338, -1.0999, -0.1724, -0.8779],
        [ 0.0422,  0.5828, -1.1006,  1.1447,  0.9016,  0.5025],
        [ 0.9009, -0.6837, -0.1229, -0.9358, -0.2679,  0.5304],
        [-0.6917, -0.3968, -0.6872, -0.8452, -0.6712, -0.0127],
        [-1.1173,  0.2344,  1.6598,  0.7420, -0.1918, -0.8876],
        [-0.7472,  1.6925,  0.0508, -0.6370,  0.1909,  2.1003],
        [ 0.1202,  0.6172,  0.3002, -0.3522, -1.1425, -0.3493],
        [-0.2089,  0.5866,  0.8390,  0.9311,  0.2856,  0.8851],
        [-0.7544,  1.2529,  0.5129, -0.2981,  0.4885, -0.0756],
        [ 1.1316,  1.5198,  2.1856, -1.3965, -1.4441, -0.5045],
        [ 0.1600,  0.8762,  0.3156, -2.0222, -0.3062,  0.8280],
        [ 0.2301,  0.7620, -0.2223, -0.2008,  0.1866,  0.4101],
        [ 0.1983,  0.1190, -0.6707,  0.3776,  0.1218,  1.

In [15]:
# Summed squared error against the same targets as the NumPy reference.
loss = (p1 - torch.Tensor(y)).pow(2).sum()

In [16]:
# Backpropagate the loss; the registered backward hooks print as they fire.
torch.autograd.backward(loss)

------IN---------

------------OUT-----------
Site:0-Type:backward-Layer:l2
tensor([[-0.1465, -2.9575],
        [ 2.5708, -5.2243],
        [ 0.8088, -0.8860],
        [-1.7431,  2.8321],
        [ 0.4652, -0.4630],
        [-2.0848, -6.4180],
        [-2.4813,  1.5732],
        [-3.1565,  0.3753],
        [-1.3973, -2.6701],
        [-1.0906,  0.6543],
        [-0.7350, -0.8989],
        [-3.1005,  0.2585],
        [-1.2938, -2.5556],
        [-0.1695,  1.5346],
        [ 1.7610,  0.2945],
        [-0.9949, -6.1776]])
****************************
------IN---------

------------OUT-----------
Site:0-Type:backward-Layer:l1
tensor([[-1.1275e-02, -4.1676e-01,  1.4981e-01, -7.6879e-02,  7.5620e-01],
        [-1.3308e-01, -1.5006e+00,  1.5181e-01, -9.4760e-01,  1.4349e+00],
        [-9.0469e-02, -2.8171e-01,  5.7864e-02, -1.0170e-01,  1.4345e-01],
        [ 3.0594e-01,  5.5226e-01, -4.4375e-01,  5.4646e-01, -2.3724e-01],
        [-1.6758e-02, -1.3668e-01,  8.7033e-02, -2.0798e-02,  1.1853e-

In [27]:
# Show every named parameter of `net` with its current values.
named = dict(net.named_parameters())
for k, p in named.items():
    print(k, p)

l1.weight Parameter containing:
tensor([[ 1.7886, -0.3548, -1.3139, -0.4047, -1.1850, -0.7130],
        [ 0.4365, -0.0827,  0.8846, -0.5454, -0.2056,  0.6252],
        [ 0.0965, -0.6270,  0.8813, -1.5465,  1.4861, -0.1605],
        [-1.8635, -0.0438,  1.7096,  0.9824,  0.2367, -0.7688],
        [-0.2774, -0.4772,  0.0500, -1.1011, -1.0238, -0.2300]],
       requires_grad=True)
l2.weight Parameter containing:
tensor([[ 0.0506, -0.9959, -0.4183, -0.6477,  0.3323],
        [ 0.5000,  0.6936, -1.5846,  0.5986, -1.1475]], requires_grad=True)


In [23]:
# Fresh network whose weights are then overwritten via n.load_state_dict(...) from 'chk.tar'.
n = Net()

In [26]:
# Detach the accumulated gradients in place, skipping parameters that have
# no gradient yet (p.grad is None until a backward pass reaches them).
# NOTE(review): after a plain loss.backward() the gradients carry no autograd
# history, so this is effectively a no-op. If the intent was to reset the
# gradients, use net.zero_grad() instead.
for p in net.parameters():
    if p.grad is not None:
        p.grad.detach_()

In [32]:
# Restore `n` from a state_dict checkpoint on disk.
# torch.load unpickles the file. weights_only=True (torch >= 1.13) restricts
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code. map_location keeps loading device-independent.
# NOTE(review): 'chk.tar' must already exist in the working directory; it is
# not written by the cells here, so confirm where it comes from.
checkpoint_path = 'chk.tar'
try:
    state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
except TypeError:
    # Older torch versions reject the weights_only argument; this fallback
    # is the unrestricted (unsafe) load, so only use it with trusted files.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
n.load_state_dict(state_dict)

<All keys matched successfully>

In [37]:
"Name: {}".format('AK')

'Name: AK'

In [44]:
# Small 2x2 float32 example for checking transpose behaviour.
# torch.tensor is the recommended factory; the legacy torch.FloatTensor
# constructor is discouraged.
t = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device='cpu')

In [45]:
# Display the transpose of t. (Tensor.mm needs a second matrix argument;
# calling it with none raises TypeError.)
t.T

tensor([[1., 3.],
        [2., 4.]])

In [46]:
# The untransposed tensor, for comparison with its transpose.
t

tensor([[1., 2.],
        [3., 4.]])

In [50]:
# NumPy reference dL/dw2 (the same expression as grad_w2).
h.T @ grad_y_pred

array([[ -3.65050639, -20.05829342],
       [ -8.77522329, -11.72429864],
       [ -5.49211195, -17.99616749],
       [ -7.0682281 ,  -3.19438619],
       [ -7.95513326, -15.59246294]])

In [52]:
# The same product computed in PyTorch, to compare with the NumPy result.
torch.Tensor(h).T @ torch.Tensor(grad_y_pred)

tensor([[ -3.6505, -20.0583],
        [ -8.7752, -11.7243],
        [ -5.4921, -17.9962],
        [ -7.0682,  -3.1944],
        [ -7.9551, -15.5925]])

In [55]:
# Gradient shapes follow nn.Linear's (out_features, in_features) weight layout.
list(map(lambda param: param.grad.shape, net.parameters()))

[torch.Size([5, 6]), torch.Size([2, 5])]

In [64]:
# Print each parameter's gradient. These are stored (out, in), i.e. transposed
# relative to the NumPy grad_w1 / grad_w2.
named_grads = dict(net.named_parameters())
for l, ch in named_grads.items():
    print(l, ch.grad)

l1.weight tensor([[-0.1147,  0.2385, -0.1600,  2.0195,  0.8291,  0.4634],
        [-5.2419,  3.7720,  1.7932,  3.2138, -1.9054,  4.8540],
        [ 0.9208, -0.0698, -1.5408, -6.1658, -2.3991, -0.7603],
        [-1.2793,  2.9918,  0.6647,  0.7110, -1.0412,  3.3018],
        [ 6.0451, -2.3750, -1.7649, -4.2975,  2.3372, -5.0113]])
l2.weight tensor([[ -3.6777,  -8.8142,  -5.5268,  -7.0952,  -7.9853],
        [-20.1403, -11.7940, -18.0874,  -3.2346, -15.6785]])
