In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [447]:
from numpy.random import randn
import numpy as np
import os
import datetime
import time

In [446]:
def duration(begin, t_del=None):
    if t_del is None:
        t_del = datetime.datetime.fromtimestamp(time.time()) - datetime.datetime.fromtimestamp(begin)
    return t_del

In [262]:
N, D_in, H, D_out = 16, 6, 5, 2

In [263]:
np.random.seed(1)
x = randn(N, D_in)
np.random.seed(2)
y = randn(N, D_out)

In [264]:
np.random.seed(3)
w1 = randn(D_in, H) 
np.random.seed(23)
b1 = randn(H)
np.random.seed(4)
w2 = randn(H, D_out)
np.random.seed(24)
b2 = randn(D_out)
w1.shape, b1.shape, w2.shape, b2.shape

((6, 5), (5,), (5, 2), (2,))

In [265]:
use_bias = False
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(6, 5, bias=use_bias)
        self.l2 = nn.Linear(5, 2, bias=use_bias)
        
    def forward(self, x):
        x = torch.sigmoid(self.l1(x))
        x= self.l2(x)
        return x

In [266]:
class DAD(nn.Module):
    def __init__(self, model):
        super(DAD, self).__init__()
        self.in_activations = []
        self.out_activations = []
        self.model = model
        
    def forward(self, *inputs, **kwargs):
        childrens = dict(self.model.named_children())
        for k, ch in childrens.items():
            ch.register_forward_hook(self.hook_wrapper('forward', k))
        return self.model(*inputs, **kwargs)
    
    def hook_wrapper(self, hook_type, layer):
        def fw_hook(a, in_act, out_act):
#             print(f'----IN----------------------')
#             print(in_act)
#             print(f'{out_act.shape}----OUT----------------------')
#             print(out_act)
            self.out_activations.append(out_act)
            self.in_activations.append(in_act[0])
            
        return fw_hook

In [532]:
net = DAD(Net())

In [533]:
for (n, m), w in zip(net.named_parameters(), [w1, w2]):
    if 'bias' in n:
        m.data = torch.FloatTensor(w)
    elif 'weight' in n:
        m.data = torch.FloatTensor(w.T)

In [534]:
optim = torch.optim.Adam(net.parameters(), lr=0.1)
net.zero_grad()

In [535]:
o = net(torch.FloatTensor(x))

In [590]:
loss = torch.square(o-torch.Tensor(y)).sum()

In [597]:
def back(_optim, x, y, iters=3000):
    tot = datetime.timedelta(0)
    for i in range(iters):
        net = DAD(Net())
        _loss = torch.square(net(x)-y).sum()
        _st = time.time()
        _loss.backward()
        _optim.step()
        tot += duration(_st)
    return tot

In [598]:
bkt = back(optim, torch.FloatTensor(x), torch.FloatTensor(y))

In [599]:
f"{bkt}"

'0:00:00.528007'

In [540]:
# grads = []
# for op in net.out_activations:
#     grads.append(torch.autograd.grad(loss, op, retain_graph=True)[0])
#     print('------------')
# grads

In [600]:
def dadback(optim, x, y, iters=100):
    tot = datetime.timedelta(0)
    for i in range(iters):
        net = DAD(Net())
        _loss = torch.square(net(x)-y).sum()
        
        _st = time.time()
        grads = []
        for op in net.out_activations:
            grads.append(torch.autograd.grad(_loss, op, retain_graph=True)[0])
            
        for p, a, g in zip(net.parameters(), net.in_activations, grads):
            p.grad = torch.FloatTensor(a.T.mm(g)).T
            
        optim.step()
        tot += duration(_st)
    return tot

In [601]:
dadt = dadback(optim, torch.FloatTensor(x), torch.FloatTensor(y), 3000)

In [602]:
f"{dadt}"

'0:00:00.709344'

In [252]:
for t in range(10):
    h = 1/(1+np.exp(-x.dot(w1)))
    print('h:', h.shape)
    y_pred = h.dot(w2)
    loss1 = np.square(y_pred-y).sum()
#     print(t, loss)
    
    grad_y_pred = 2.0 * (y_pred - y)
    print('grad_y_pred:', grad_y_pred.shape)
    
    grad_w2 = h.T.dot(grad_y_pred)
    print('grad_w2: ', grad_w2.shape)
    
    
    grad_h = grad_y_pred.dot(w2.T)
    print('grad_h: ', grad_h.shape)
    grad_w1 = x.T.dot(grad_h * h * (1-h))
    print('grad_w1:', grad_w1.shape)
    
    w1 -= 1e-4 * grad_w1
    w2 -= 1e-4 * grad_w2
    break

h: (16, 5)
grad_y_pred: (16, 2)
grad_w2:  (5, 2)
grad_h:  (16, 5)
grad_w1: (6, 5)


In [254]:
grads = []
for op in net.out_activations[::-1]:
    grads.append(torch.autograd.grad(loss, op, retain_graph=True)[0])
    print('------------')
grads

------------
------------


[tensor([[-0.1430, -2.9317],
         [ 2.5737, -5.2007],
         [ 0.8187, -0.8654],
         [-1.7404,  2.8364],
         [ 0.4730, -0.4396],
         [-2.0773, -6.3996],
         [-2.4732,  1.5807],
         [-3.1482,  0.3821],
         [-1.3892, -2.6533],
         [-1.0834,  0.6593],
         [-0.7269, -0.8889],
         [-3.0894,  0.2824],
         [-1.2846, -2.5343],
         [-0.1636,  1.5470],
         [ 1.7658,  0.3033],
         [-0.9892, -6.1566]]),
 tensor([[-1.1260e-02, -4.1599e-01,  1.4838e-01, -7.6555e-02,  7.4904e-01],
         [-1.3344e-01, -1.5005e+00,  1.5068e-01, -9.4621e-01,  1.4238e+00],
         [-8.8707e-02, -2.8083e-01,  5.5806e-02, -1.0102e-01,  1.4115e-01],
         [ 3.0874e-01,  5.5254e-01, -4.4421e-01,  5.4637e-01, -2.3730e-01],
         [-1.5892e-02, -1.3514e-01,  8.0408e-02, -2.0471e-02,  1.1434e-01],
         [-6.6903e-01, -5.8514e-01,  2.7372e+00, -5.2326e-01,  7.1344e-01],
         [ 1.5730e-02,  8.8936e-01, -3.6619e-01,  4.7740e-03, -6.5570e-01],
  

In [260]:
torch.autograd.grad(loss, net.out_activations, retain_graph=True)[0]

tensor([[-1.1260e-02, -4.1599e-01,  1.4838e-01, -7.6555e-02,  7.4904e-01],
        [-1.3344e-01, -1.5005e+00,  1.5068e-01, -9.4621e-01,  1.4238e+00],
        [-8.8707e-02, -2.8083e-01,  5.5806e-02, -1.0102e-01,  1.4115e-01],
        [ 3.0874e-01,  5.5254e-01, -4.4421e-01,  5.4637e-01, -2.3730e-01],
        [-1.5892e-02, -1.3514e-01,  8.0408e-02, -2.0471e-02,  1.1434e-01],
        [-6.6903e-01, -5.8514e-01,  2.7372e+00, -5.2326e-01,  7.1344e-01],
        [ 1.5730e-02,  8.8936e-01, -3.6619e-01,  4.7740e-03, -6.5570e-01],
        [ 9.0980e-04,  6.0522e-01,  1.7625e-01,  4.8553e-01, -3.4322e-01],
        [-2.3060e-01, -1.1005e-01,  8.4526e-01, -1.7271e-01,  4.2672e-01],
        [ 1.2154e-02,  3.5271e-01, -1.2605e-01,  1.0658e-01, -1.4589e-01],
        [-2.4594e-02,  2.5999e-02,  3.6646e-01, -6.0715e-03,  1.8129e-01],
        [-2.7960e-03,  1.4174e-01,  1.5188e-01,  5.3409e-01, -1.2150e-01],
        [-3.3383e-01, -5.2667e-02,  3.8204e-01, -4.8522e-02,  2.7918e-01],
        [ 1.9281e-01,  3.

In [257]:
net.out_activations

[tensor([[ 4.8665e+00, -7.3381e-01,  3.3880e+00, -2.9801e+00,  6.3372e-01],
         [ 2.8089e+00, -3.4093e-01,  3.8159e+00, -9.8507e-01, -8.5849e-01],
         [-6.5411e-01,  9.8146e-01,  2.7933e+00,  2.1102e+00,  1.9178e+00],
         [-5.7586e-01, -1.4998e+00, -1.8408e+00, -1.0347e+00, -2.6430e+00],
         [ 2.3337e+00,  1.2396e+00,  1.3698e+00, -3.2489e+00,  1.2513e+00],
         [ 9.5344e-01, -2.8579e-01, -1.1350e-01, -8.4780e-01,  1.9675e+00],
         [-3.7026e+00,  3.9338e-02, -8.0749e-02,  6.2752e+00, -1.3347e-01],
         [-3.4696e+00,  1.1964e+00, -1.5698e-01, -7.9277e-01, -5.7189e-01],
         [ 1.3475e+00,  4.7524e-01, -1.2089e+00, -8.6037e-02,  1.3252e+00],
         [-3.0309e+00,  5.8706e-01, -8.0442e-01,  2.0985e+00, -1.6988e+00],
         [-2.8721e+00,  3.4168e-02,  7.9301e-01,  2.1079e+00, -5.1637e-01],
         [ 1.2504e+00,  3.0449e+00,  1.1749e+00,  2.3448e-01,  2.2014e+00],
         [ 1.5234e-01,  1.9598e+00,  2.2812e+00, -2.4935e+00,  1.9025e+00],
         [ 1

In [197]:
def dadback(net, grads, optim):
    for p, (g, act) in zip(net.parameters(), grads):
        p.grad = torch.FloatTensor(act.T.mm(g)).T
    optim.step()

In [198]:
%time dadback(net, grads, optim)

CPU times: user 4.28 ms, sys: 408 µs, total: 4.69 ms
Wall time: 2.58 ms


In [207]:
net.in_activations

[tensor([[ 1.6243, -0.6118, -0.5282, -1.0730,  0.8654, -2.3015],
         [ 1.7448, -0.7612,  0.3190, -0.2494,  1.4621, -2.0601],
         [-0.3224, -0.3841,  1.1338, -1.0999, -0.1724, -0.8779],
         [ 0.0422,  0.5828, -1.1006,  1.1447,  0.9016,  0.5025],
         [ 0.9009, -0.6837, -0.1229, -0.9358, -0.2679,  0.5304],
         [-0.6917, -0.3968, -0.6872, -0.8452, -0.6712, -0.0127],
         [-1.1173,  0.2344,  1.6598,  0.7420, -0.1918, -0.8876],
         [-0.7472,  1.6925,  0.0508, -0.6370,  0.1909,  2.1003],
         [ 0.1202,  0.6172,  0.3002, -0.3522, -1.1425, -0.3493],
         [-0.2089,  0.5866,  0.8390,  0.9311,  0.2856,  0.8851],
         [-0.7544,  1.2529,  0.5129, -0.2981,  0.4885, -0.0756],
         [ 1.1316,  1.5198,  2.1856, -1.3965, -1.4441, -0.5045],
         [ 0.1600,  0.8762,  0.3156, -2.0222, -0.3062,  0.8280],
         [ 0.2301,  0.7620, -0.2223, -0.2008,  0.1866,  0.4101],
         [ 0.1983,  0.1190, -0.6707,  0.3776,  0.1218,  1.1295],
         [ 1.1989,  0.185

In [157]:
def a(_):
    print(_)

In [159]:
a(10)

10


In [165]:
list(net.named_children())

[('model',
  Net(
    (l1): Linear(in_features=6, out_features=5, bias=True)
    (l2): Linear(in_features=5, out_features=2, bias=True)
  ))]

In [185]:
m = list(net.model.children())