## Batch NormalizationとLayer Normalization

## Hook

### .register_hook

In [45]:
# ライブラリ
from functools import partial


import torch
from torch import nn


In [16]:
# Leaf tensor; autograd accumulates its gradient into a.grad.
a = torch.ones(5, requires_grad=True)
b = 2 * a
# b is an intermediate (non-leaf) tensor: retain_grad() keeps its gradient
# available in b.grad after backward().
b.retain_grad()


def print_grad(grad):
    """Tensor hook: print the gradient computed for the hooked tensor.

    Returns None, so the gradient is passed on unchanged.
    """
    print(grad)


# Register the hook on b; it is called during backward() with the gradient
# with respect to b.
# Shorter equivalent using a lambda:
#     b.register_hook(lambda grad: print(grad))
b.register_hook(print_grad)

c = b.mean()

# Backward pass: this is what triggers print_grad.
c.backward()


tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


In [9]:
# Show the leaf tensor, the intermediate tensor and the scalar result, one per line.
print(a, b, c, sep="\n")

tensor([1., 1., 1., 1., 1.], requires_grad=True)
tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
tensor(2., grad_fn=<MeanBackward0>)


In [13]:
# Gradients of c with respect to a and to b, one per line.
print(a.grad, b.grad, sep="\n")

tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000])
tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


### .register_forward_hook()

In [120]:
# Define the model: four stride-2 Conv2d + ReLU stages, then Flatten and a Linear head.
# Feature-map sizes for a 1x28x28 input:
#   1x28x28 -> 4x14x14 -> 8x7x7 -> 16x4x4 -> 32x2x2 -> flatten (128) -> 10
conv_model = nn.Sequential(
    *(
        layer
        for in_channels, out_channels in ((1, 4), (4, 8), (8, 16), (16, 32))
        for layer in (
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
    ),
    nn.Flatten(),
    # 128 -> 10
    nn.Linear(128, 10),
)


In [114]:
# Maps "<name>_<str(module)>" to the shape of the output last seen for that module.
outputs = {}


def save_output(name, module, inp, out):
    """Forward hook: record the shape of `out` in the global `outputs` dict.

    `name` is a label supplied by the caller (bound via functools.partial when
    the hook is registered); `inp` is accepted but not used.
    """
    outputs[f'{name}_{module!s}'] = out.shape
    

In [121]:
# Attach save_output to every submodule so each forward pass stores output shapes in the dict.
for name, module in conv_model.named_modules():
    if not name:
        # The empty name is the model itself; do not hook it.
        continue
    module.register_forward_hook(partial(save_output, name))
        

In [122]:
def print_hooks(model):
    """Print every forward and backward hook registered on the modules of `model`.

    Modules come from ``model.named_modules()``. All forward hooks are printed
    first, then all backward hooks. A module that lacks the relevant hook
    registry attribute is treated as having no hooks of that kind.
    """
    # (label used in the message, attribute holding that kind of hook registry)
    hook_kinds = (("forward", "_forward_hooks"), ("backward", "_backward_hooks"))
    for kind, registry_attr in hook_kinds:
        for name, module in model.named_modules():
            # Look up the registry that is actually iterated: the backward pass
            # must not be guarded by the presence of "_forward_hooks".
            for hook in getattr(module, registry_attr, {}).values():
                print(f'Module {name} has {kind} hook : {hook}')

In [102]:
# List the hooks currently registered on conv_model's modules.
print_hooks(conv_model)

Module 0 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '0')
Module 1 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '1')
Module 2 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '2')
Module 3 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '3')
Module 4 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '4')
Module 5 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '5')
Module 6 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '6')
Module 7 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '7')
Module 8 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '8')
Module 9 has forward hook : functools.partial(<function save_output at 0x7fff138aa8b0>, '9')


In [103]:
from functools import partial

# Base function that partial() will specialise.
def power(base, exponent):
    """Return `base` raised to the power `exponent`."""
    return pow(base, exponent)

# Pre-bind exponent=2 to get a one-argument squaring function.
square = partial(power, exponent=2)

# Use the partially applied function.
print(square(5))  # prints 25


25


### forwardでhookを発動

In [104]:
# Random input of shape (1, 1, 28, 28).
X = torch.randn(1, 1, 28, 28)
# Run the forward pass by calling the model; this is the step that fires the forward hooks.
output = conv_model(X)

In [105]:
# Display the output shapes recorded by the forward hooks.
outputs

{'0_Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))': torch.Size([1, 4, 14, 14]),
 '1_ReLU()': torch.Size([1, 4, 14, 14]),
 '2_Conv2d(4, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))': torch.Size([1, 8, 7, 7]),
 '3_ReLU()': torch.Size([1, 8, 7, 7]),
 '4_Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))': torch.Size([1, 16, 4, 4]),
 '5_ReLU()': torch.Size([1, 16, 4, 4]),
 '6_Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))': torch.Size([1, 32, 2, 2]),
 '7_ReLU()': torch.Size([1, 32, 2, 2]),
 '8_Flatten(start_dim=1, end_dim=-1)': torch.Size([1, 128]),
 '9_Linear(in_features=128, out_features=10, bias=True)': torch.Size([1, 10])}

### .register_full_backward_hook()

In [123]:
# Maps "<name>_<str(module)>" to the grad_in value last passed to the hook for that module.
grads = {}


def save_grad_in(name, module, grad_in, grad_out):
    """Backward hook: store `grad_in` in the global `grads` dict.

    `name` is a label supplied by the caller (bound via functools.partial when
    the hook is registered); `grad_out` is accepted but not used.
    """
    grads[f'{name}_{module!s}'] = grad_in

# Attach save_grad_in to every submodule so the backward pass stores each grad_in in the dict.
for name, module in conv_model.named_modules():
    if not name:
        # The empty name is the model itself; do not hook it.
        continue
    module.register_full_backward_hook(partial(save_grad_in, name))
        

In [124]:
# List the hooks again now that backward hooks have been registered as well.
print_hooks(conv_model)

Module 0 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '0')
Module 1 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '1')
Module 2 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '2')
Module 3 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '3')
Module 4 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '4')
Module 5 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '5')
Module 6 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '6')
Module 7 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '7')
Module 8 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '8')
Module 9 has forward hook : functools.partial(<function save_output at 0x7fff13803670>, '9')
Module 0 has backward hook : functools.partial(<function save_grad_in 

In [125]:
# Forward pass on a fresh random input, reduce the output to a scalar loss,
# then backpropagate so the registered backward hooks fire.
X = torch.rand(1, 1, 28, 28)
output = conv_model(X)
loss = torch.mean(output)
loss.backward()

In [126]:
# Display the gradients captured by the backward hooks.
grads

{'9_Linear(in_features=128, out_features=10, bias=True)': (tensor([[-0.0210,  0.0354,  0.0042,  0.0016,  0.0038, -0.0273, -0.0010, -0.0174,
            0.0268, -0.0150,  0.0187, -0.0035,  0.0170, -0.0215,  0.0199,  0.0014,
           -0.0111,  0.0262, -0.0164,  0.0233, -0.0164, -0.0035,  0.0147, -0.0020,
           -0.0214, -0.0130, -0.0064,  0.0097, -0.0075, -0.0067, -0.0044,  0.0045,
            0.0053, -0.0059, -0.0050, -0.0138,  0.0003, -0.0249, -0.0132, -0.0007,
           -0.0176,  0.0183, -0.0150,  0.0093, -0.0287, -0.0007, -0.0147, -0.0030,
           -0.0048, -0.0078, -0.0018, -0.0045, -0.0029,  0.0115,  0.0107, -0.0242,
           -0.0164, -0.0035, -0.0091,  0.0232, -0.0076,  0.0029, -0.0004,  0.0220,
            0.0130,  0.0022, -0.0078,  0.0218, -0.0102, -0.0238, -0.0298, -0.0016,
           -0.0092,  0.0134, -0.0125, -0.0140,  0.0018,  0.0047,  0.0087, -0.0159,
           -0.0145, -0.0051,  0.0111,  0.0060, -0.0080,  0.0121,  0.0036, -0.0019,
            0.0108, -0.0007,  