In [1]:
import torch
from torch import nn
from torch.autograd.functional import vhp
import copy
import numpy as np


In [2]:
def pow_reducer(x):
  return x.pow(3).sum()


def pow_adder_reducer(x, y):
  return (2 * x.pow(4) + 3 * y.pow(3) + x.pow(3)*y.pow(5)).sum()


The inputs are what the Hessian is computed with respect to. Here it is just x's; for us it will be model parameters. This is fine, because loss functions are set up that way: inputs are fixed and the gradients are with respect to params. I am a little unclear on how to pass params to vhp. 

In [3]:
ishape = (4)
inputs = torch.rand(ishape)
v = torch.ones(ishape)

In [4]:

vhp(pow_reducer, inputs, v)


(tensor(1.7065), tensor([1.6322, 5.6431, 5.6715, 1.2844]))

In [5]:
print(pow_reducer(inputs))
print(6*inputs)

tensor(1.7065)
tensor([1.6322, 5.6431, 5.6715, 1.2844])


In [6]:
vhp(pow_reducer, inputs, v, create_graph=True)


(tensor(1.7065, grad_fn=<SumBackward0>),
 tensor([1.6322, 5.6431, 5.6715, 1.2844], grad_fn=<MulBackward0>))

In [7]:
nv = 3
inputs = (torch.rand(nv), torch.rand(nv)) #  need an x and a y
v = (torch.ones(nv), torch.ones(nv))
lp, vH = vhp(pow_adder_reducer, inputs, v, create_graph=True)
print('pow_adder_reducer:', lp)
print('check pow_adder_reducer:', pow_adder_reducer(*inputs),'\n')
print('vhp:', vH)

x, y = inputs
H = []
vmat = torch.zeros(2, nv)
for i in range(nv):
    vmat[0,i] = v[0][i]
    vmat[1,i] = v[1][i]
    H.append(torch.tensor([[(24*x[i]**2 + 6*x[i]*y[i]**5), (15*x[i]**2*y[i]**4)],\
                           [(15*x[i]**2*y[i]**4), (18*y[i]+20*x[i]**3*y[i]**3)]]))

vH_check = []
for i in range(nv):
    vH_check.append(vmat[:,i].t() @ H[i])

print(vH_check,'\n')

for i in range(nv):
    print(vH[0][i]-vH_check[i][0], vH[1][i]-vH_check[i][1])

pow_adder_reducer: tensor(4.2575, grad_fn=<SumBackward0>)
check pow_adder_reducer: tensor(4.2575) 

vhp: (tensor([32.3443,  0.7036,  4.0947], grad_fn=<AddBackward0>), tensor([35.9528,  5.5127,  0.5465], grad_fn=<AddBackward0>))
[tensor([32.3443, 35.9528]), tensor([0.7036, 5.5127]), tensor([4.0947, 0.5465])] 

tensor(0., grad_fn=<SubBackward0>) tensor(0., grad_fn=<SubBackward0>)
tensor(0., grad_fn=<SubBackward0>) tensor(0., grad_fn=<SubBackward0>)
tensor(0., grad_fn=<SubBackward0>) tensor(0., grad_fn=<SubBackward0>)


In [8]:
print(4*1.23, 6*4.56)
print(inputs)
print(np.array([19.3102, 49.1303, 32.1819])/(6*4.56))

4.92 27.36
(tensor([0.9320, 0.1704, 0.4131]), tensor([0.8937, 0.3059, 0.0304]))
[0.70578216 1.7956981  1.17623904]


# VHP with respect to model parameters

The following is from a stackoverflow question about applying vhp when models with parameters are involved, as opposed to simpler functions where we don't have to iterate over parameters.

https://stackoverflow.com/questions/68492748/trouble-with-minimal-hvp-on-pytorch-model

This in turn linked to

https://discuss.pytorch.org/t/hvp-w-r-t-model-parameters/83520



In [15]:
def capture_gradients(model):
    g = {}
    for k, v in model.named_parameters():
        gnext = v.grad
        if gnext is not None:
            next_entry = {k: gnext.clone().detach()}
        else:
            next_entry = {k: None}
        g.update(next_entry)
    return g
  

In [16]:

# your loss function
def objective(X):
    return torch.sum(0.25 * torch.sum(X)**4)

# Following are utilities to make nn.Module "functional", in the sense of 
#    being from or compatible with the torch.autograd.functional library. 
#
# borrowed from the link I posted in comment
def del_attr(obj, names): # why, why, why? But it definitely breaks without this. 
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(model):
    orig_params = tuple(model.parameters())
    orig_grad = capture_gradients(model)
    # Remove all the parameters in the model, because reasons. 
    names = []
    for name, p in list(model.named_parameters()):
        del_attr(model, name.split("."))
        names.append(name)
    return orig_params, orig_grad, names

def restore_model(model, names, params, grad):
    load_weights(model, names, params, as_params=True)
    for k, v in model.named_parameters():
        print('grad is type', type(grad))
        print('grad is', grad)
        if grad[k]:
            v.grad = grad[k].clone().detach()
        else:
            v.grad = None
                


def load_weights(model, names, params, as_params=False):
    for name, p in zip(names, params):
        if not as_params:
            set_attr(model, name.split("."), p)
        else:
            set_attr(model, name.split("."), torch.nn.Parameter(p))
 
        


# This is how we trick vhp into doing the Hessian with respect to params and not other inputs. 
def loss_wrt_params(*new_params):
    load_weights(mlp, names, new_params) # Weird! We removed the params before. 

    x = torch.ones((Arows,))
    out = mlp(x)
    loss = objective(out)
    loss.backward(retain_graph=True)
    return loss

# your simple MLP model
class SimpleMLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, out_dim),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)

    
    
    

In [17]:
# your model instantiation
Arows = 2
Acols = 2
mlp = SimpleMLP(Arows, Acols)

v_to_dot = tuple([p.clone().detach() for p in mlp.parameters()])



In [18]:
for n, p in mlp.named_parameters():
    print(n,':',p)

layers.0.weight : Parameter containing:
tensor([[ 0.0313, -0.1776],
        [ 0.3404,  0.3687]], requires_grad=True)
layers.0.bias : Parameter containing:
tensor([ 0.2553, -0.6846], requires_grad=True)


In [19]:
#make model's parameters functional
orig_params, orig_grad, names = make_functional(mlp)
params2pass = tuple(p.detach().requires_grad_() for p in orig_params)


In [20]:

loss_value, hessian = torch.autograd.functional.vhp(loss_wrt_params, params2pass, \
                                                    v_to_dot, strict=True)

restore_model(mlp, names, orig_params, orig_grad)
loss_wrt_params(*orig_params) # this calls backward on loss = objective(out)
grad = capture_gradients(mlp)

print('loss:', loss_value)
print('H(params):', hessian)

#L=torch.sum(0.25 * torch.sum(Y)**4)
#Y = W * x + B
# dLdY = sum(Y)
# dYdWij = sum(x)

grad is type <class 'dict'>
grad is {'layers.0.weight': None, 'layers.0.bias': None}
grad is type <class 'dict'>
grad is {'layers.0.weight': None, 'layers.0.bias': None}
loss: tensor(7.9300e-05)
H(params): (tensor([[0.0071, 0.0071],
        [0.0071, 0.0071]]), tensor([0.0071, 0.0071]))


In [21]:
capture_gradients(mlp)

{'layers.0.weight': tensor([[0.0024, 0.0024],
         [0.0024, 0.0024]]),
 'layers.0.bias': tensor([0.0024, 0.0024])}

In [22]:
grad

{'layers.0.weight': tensor([[0.0024, 0.0024],
         [0.0024, 0.0024]]),
 'layers.0.bias': tensor([0.0024, 0.0024])}

In [23]:

# objective(x)

# Argh, what is a leaf node? 

The following is meant to be an example of what leaf nodes are. I think it might be helpful, even though the blog post it came from is super hard to understand and explains itself poorly. http://www.bnikolic.co.uk/blog/pytorch/python/2021/03/15/pytorch-leaf.html

In [18]:

import torch

x=torch.ones(10, requires_grad=True)
y=torch.ones(10, requires_grad=True)

# The remaining nodes are not leaves:
def H(z1, z2):
    ret = torch.sin(z1**3)*torch.cos(z2**2)
    return ret
# dHdz1 = cos(z1**3)*cos(z2**2)*3*z1**2
# dHdz2 = -sin(z1**3)*sin(z2**2)*2*z2

def G(z1, z2):
    return torch.exp(z1)+torch.log(z2)

def F(z1, z2):
    return z1**3*z2**0.5

h=H(x,y)

g=G(x,y)

f=F(h,g)
f.retain_grad()
f.sum().backward() # must sum to get a scalar, otherwise backward() will barf. 
                    #  No gradients are loaded until backawrd() is called. 
   

In [19]:
hchk = np.sin(1)*np.cos(1)
gchk = np.exp(1) + np.log(1)

In [20]:
dFdh = 3*hchk**2 *np.sqrt(gchk)
dFdg = 0.5* hchk**3 / np.sqrt(gchk)

In [21]:
dgdx = np.exp(1)
dgdy = 1.0
dhdx = np.cos(1)*np.cos(1)*3
dhdy = -np.sin(1)*np.sin(1)*2
#-sin(z1**3)*sin(z2**2)*2*z2


In [22]:
dFdx = dFdh*dhdx + dFdg*dgdx
dFdy = dFdh*dhdy + dFdg*dgdy
print('dFdx =', dFdx)
print('dFdy =', dFdy)


dFdx = 0.9728684287087229
dFdy = -1.419366770452894


In [23]:
print(dFdx - x.grad)
print(dFdy - y.grad)

tensor([-1.1921e-07, -1.1921e-07, -1.1921e-07, -1.1921e-07, -1.1921e-07,
        -1.1921e-07, -1.1921e-07, -1.1921e-07, -1.1921e-07, -1.1921e-07])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [24]:
f.grad

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [25]:
a = torch.tensor([2., 3.], requires_grad=True)
b = torch.tensor([6., 4.], requires_grad=True)

In [26]:
Q = 3*a**3 - b**2

In [27]:
Q.retain_grad()

In [28]:
Q.sum().backward()

In [29]:
a.grad


tensor([36., 81.])

In [30]:
b.grad

tensor([-12.,  -8.])