# About Hooks

1. PyTorch provides many hook functions.
2. For example, calling `.backward()` makes PyTorch run automatic differentiation (backpropagation) to compute the gradients of the leaf nodes.
3. After backpropagation, the computation graph is destroyed (freed), and the gradients of intermediate (non-leaf) nodes are not kept.
4. However, we can register (註冊) a hook to capture the gradient of an intermediate node.

# Automatic differentiation and backpropagation

In [7]:
import torch

# A tiny autograd graph: the leaf x feeds y = x^2, which feeds z = y^3 (= x^6).
x = torch.tensor(2.0, requires_grad=True)
y = x.pow(2)  # 4.0 (non-leaf)
z = y.pow(3)  # 64.0 (non-leaf)

# requires_grad propagates from the leaf x to every tensor computed from it.
print(
    f"x join gradient descent: {x.requires_grad}\n"
    f"y join gradient descent: {y.requires_grad}"
)
print(f"x = {x}, y = {y}, z = {z}")


x join gradient descent: True
y join gradient descent: True
x = 2.0, y = 4.0, z = 64.0


In [4]:
# Reverse-mode autodiff from the scalar z back to the leaf x.
z.backward()

# Leaf tensor: dz/dx = 6 * x**5 = 192 is stored in x.grad.
print(f"gradient of x = {x.grad}")
# Non-leaf tensor: y.grad stays None (PyTorch also emits a UserWarning here).
print(f"gradient of y = {y.grad}")

gradient of x = 192.0
gradient of y = None


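UserWarning (truncated in this export): PyTorch warns when the .grad attribute of a non-leaf tensor is accessed, because backward() does not populate it unless .retain_grad() was called on that tensor.
  return self._grad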


# Define a Hook Function

In [11]:
# Rebuild the graph from scratch: the previous backward() freed it, and the
# hook in the next cell must be attached to a fresh y.
x = torch.tensor(2.0, requires_grad=True)
y = torch.pow(x, 2)  # y = 4
z = torch.pow(y, 3)  # z = 64

print(
    f"x join gradient descent: {x.requires_grad}",
    f"y join gradient descent: {y.requires_grad}",
    f"x = {x}, y = {y}, z = {z}",
    sep="\n",
)

x join gradient descent: True
y join gradient descent: True
x = 2.0, y = 4.0, z = 64.0


In [12]:
##########################################
# 1. Define a hook function.
# 2. Register it on a non-leaf tensor so it receives that tensor's gradient.
# 3. tensor.register_hook(fn): autograd calls fn(grad) with d(z)/d(tensor)
#    during backward().
##########################################
def save_gradient(grad):
    """Hook for ``y``: stash the incoming gradient in the global ``gradient``.

    Autograd calls this during ``backward()`` with dz/dy. Returning ``None``
    leaves the gradient that flows on toward ``x`` unchanged.
    """
    global gradient  # module-level so the value outlives the backward pass
    gradient = grad


hook = y.register_hook(save_gradient)  # register the function as a hook on tensor y


##########################################
z.backward()
# The hook has fired; detach it so any later backward pass through y cannot
# silently overwrite the global `gradient`.
hook.remove()

print(f"gradient of x = {x.grad}")   # 192
print(f"gradient of y = {y.grad}")   # None: y is non-leaf, so .grad is still not populated
print(f"gradient of y = {gradient}") # 48 = dz/dy = 3*y^2, captured by the hook


gradient of x = 192.0
gradient of y = None
gradient of y = 48.0


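# Backpropagation with a non-scalar output
1. Reference: https://blog.csdn.net/comli_cn/article/details/104664494
2. When the output is not a scalar, `backward()` needs a weight tensor `v` with the same shape as the output, and it computes the vector-Jacobian product `x.grad = v · J`, where `J = d(output)/d(x)`.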

In [14]:
import torch

# Leaf input: a 1x2 row vector x = [[x0, x1]] = [[2, 4]].
x = torch.tensor([[2, 4]], requires_grad=True, dtype=torch.float)
# y = [[x0^2 + x1, x0 + x1^3]], built out-of-place with torch.stack rather
# than by writing into a preallocated zeros buffer.
y = torch.stack([
    x[0, 0] ** 2 + x[0, 1],
    x[0, 0] + x[0, 1] ** 3,
]).reshape(1, 2)
output = 2 * y

# output is not a scalar, so backward() needs a weight tensor v with the same
# shape as output; it then computes x.grad = v @ J (a vector-Jacobian product).
# v is a constant, so it does not need requires_grad.
gradient = torch.tensor([[1, 2]], dtype=torch.float)
output.backward(gradient)
print(x.grad)

############### visualize ###############
# x = [[x0, x1]] = [[2, 4]]
#
# output = 2 * y = [[2*(x0^2 + x1), 2*(x0 + x1^3)]]
#
# Jacobian J = d(output)/d(x):
#
# J = | d(out_0)/d(x0)  d(out_0)/d(x1) |  =  | 4*x0  2      |  =  | 8   2 |
#     | d(out_1)/d(x0)  d(out_1)/d(x1) |     | 2     6*x1^2 |     | 2  96 |
#
# x.grad = v @ J, with v = gradient = [1 2]:
#        = [1*8 + 2*2, 1*2 + 2*96] = [12, 194]
###########################################



tensor([[ 12., 194.]])