In [1]:
# Reload imported modules automatically before executing each cell, so edits
# to imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from copy import deepcopy
from functools import partial
from itertools import combinations, product
from typing import Any, Dict, List, Sequence, Tuple

import functorch as ft
import torch
from datasets import load_dataset
from torch.autograd import grad
from torch.autograd.functional import hessian, jacobian
from torch.nn.utils.stateless import functional_call
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
)

### Compute second-order derivative

The input space is the following
$$x \in \mathbb{R}^2$$

The output space is
$$y \in \mathbb{R}$$

The function $f$ is a scalar function
$$f: \mathbb{R}^2 \to \mathbb{R}$$

defined as
$$f(x) \coloneqq Wx + b, \qquad W \in \mathbb{R}^{1 \times 2}, \; b \in \mathbb{R}$$

and the loss function is the squared error (the MSE of a single sample)
$$L(x, y; W, b) \coloneqq (f(x) - y)^2 = (Wx + b - y)^2$$

For convenience define
$$D = (x, y)$$

$$\theta = (W, b)$$

In [3]:
# Seed so that the random data and the model initialisation are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Data tensors. We only differentiate w.r.t. the model parameters, so the data
# does not need to track gradients (doing so would only add autograd overhead
# to the timings below).
inputs = torch.rand((2,))
targets = torch.rand((1,))

# A single linear layer f(x) = Wx + b with W of shape (1, 2) and b of shape (1,).
# `Sequential` is kept so the parameter names are "0.weight" / "0.bias".
model = torch.nn.Sequential(
    torch.nn.Linear(2, 1),
)

In [4]:
# make model functional manually
params_dict = {k: v.detach().requires_grad_() for k, v in model.named_parameters()}
names = deepcopy(tuple(params_dict.keys()))
params_tuple = deepcopy(tuple(params_dict.values()))

# make model functional using functorch
func_model, params = ft.make_functional(model)

In [5]:
def loss(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """MSE loss of the stateful global ``model`` (no functionalisation).

    This is the baseline formulation. The parameters are read from ``model``
    itself, so derivatives are taken w.r.t. ``model.parameters()``.

    Args:
        inputs: model inputs, passed straight to ``model``.
        targets: regression targets, broadcast against the predictions.

    Returns:
        Scalar tensor ``mean((model(inputs) - targets) ** 2)``.
    """
    preds = model(inputs)
    return torch.mean((preds - targets) ** 2)


def functional_call_loss(
    param_dict: Dict[str, torch.Tensor], inputs: torch.Tensor, targets: torch.Tensor
) -> torch.Tensor:
    """MSE of the global ``model`` evaluated with the parameters in ``param_dict``.

    ``functional_call`` runs ``model`` with its parameters replaced by the
    tensors in ``param_dict`` (keyed by parameter name). Gradients can
    therefore be taken w.r.t. those tensors instead of the module's own
    parameters.
    """
    residual = functional_call(model, param_dict, inputs) - targets
    return residual.pow(2).mean()


def functional_call_loss_tuples(
    *args, names: Sequence[str], inputs: torch.Tensor, targets: torch.Tensor
) -> torch.Tensor:
    """MSE of the global ``model`` with its parameters passed positionally.

    Taking the parameters as ``*args`` makes the loss a plain function of
    tensors, ``f(*tensors)``. This is the calling convention used with
    ``torch.autograd.functional.hessian``.

    Args:
        *args: parameter tensors, in the same order as ``names``.
        names: parameter names (e.g. from ``model.named_parameters()``), one
            per tensor in ``args``; a tuple or a list.
        inputs: model inputs.
        targets: regression targets.

    Returns:
        Scalar MSE tensor.

    Raises:
        ValueError: if the number of tensors differs from the number of names.
    """
    # `zip` would silently truncate on a mismatch. `functional_call` would then
    # use the module's own parameters for the missing names, and derivatives
    # would be silently wrong or missing.
    if len(args) != len(names):
        raise ValueError(
            f"got {len(args)} parameter tensors for {len(names)} names"
        )
    param_dict = dict(zip(names, args))
    preds = functional_call(model, param_dict, inputs)
    return torch.mean((preds - targets) ** 2)


def functorch_loss(
    params: Tuple[torch.Tensor, ...],
    inputs: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    """MSE of the functorch-functionalised model ``func_model``.

    Args:
        params: all parameter tensors, in the order produced by
            ``ft.make_functional`` (here weight, then bias). Under functorch
            transforms these may be plain tensors rather than ``Parameter``s.
        inputs: model inputs.
        targets: regression targets.

    Returns:
        Scalar MSE tensor.
    """
    preds = func_model(params, inputs)
    return torch.mean((preds - targets) ** 2)

In [6]:
losses = (
    loss(inputs, targets),
    functional_call_loss(params_dict, inputs, targets),
    functional_call_loss_tuples(
        *params_tuple, names=names, inputs=inputs, targets=targets
    ),
    functorch_loss(params, inputs=inputs, targets=targets),
)

# All four formulations evaluate the same model on the same data, so they
# must agree before we compare their speed and derivatives below.
for other in losses[1:]:
    torch.testing.assert_close(other, losses[0])

losses

(tensor(0.7264, grad_fn=<MeanBackward0>),
 tensor(0.7264, grad_fn=<MeanBackward0>),
 tensor(0.7264, grad_fn=<MeanBackward0>),
 tensor(0.7264, grad_fn=<MeanBackward0>))

In [7]:
# Baseline: forward pass + MSE through the stateful nn.Module.
%timeit loss(inputs, targets)

35.6 µs ± 552 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# Same loss, with the parameters substituted via `functional_call` from a dict.
%timeit functional_call_loss(params_dict, inputs, targets)

91 µs ± 3.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
# Same loss, parameters passed positionally (*args) and zipped with `names`.
%timeit functional_call_loss_tuples(*params_tuple, names=names, inputs=inputs, targets=targets)

95.1 µs ± 2.66 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
# Same loss through the functorch stateless model `func_model(params, inputs)`.
%timeit functorch_loss(params, inputs=inputs, targets=targets)

80.2 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


---

### Derivatives wrt parameters

Compute 
$\nabla_W L(D; \theta)$
and
$\nabla_b L(D; \theta)$

For the definition of the scalar-by-vector derivative, see [here](https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-vector).


Analytical solutions

$$\frac{d L}{d W} = \left[ \frac{d L}{d w_1}, \; \frac{d L}{d w_2}\right] = \left[2 x_1(Wx + b - y) , \; 2 x_2 (Wx + b - y)\right]$$

$$\frac{d L}{d b} = 2 (Wx + b - y) \cdot \frac{\partial (Wx + b)}{\partial b} = 2 (Wx + b - y)$$

In [11]:
# Analytical first-order derivatives, with residual r = Wx + b - y:
#   dL/dW = 2 r x,   dL/db = 2 r
with torch.no_grad():
    residual = model(inputs) - targets  # evaluate the model once, reuse below
    dw = 2 * residual * inputs
    db = 2 * residual

dw, db

(tensor([-0.4424, -0.5870]), tensor([-1.7045]))

In [12]:
# Baseline gradient: autograd w.r.t. the stateful model's parameters.
%timeit grad(loss(inputs, targets), model.parameters())

grad(loss(inputs, targets), model.parameters())

161 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


(tensor([[-0.4424, -0.5870]]), tensor([-1.7045]))

In [13]:
# Same gradient, taken w.r.t. the tensors in `params_dict` that
# `functional_call` substitutes into the model.
%timeit grad(functional_call_loss(params_dict, inputs, targets), params_dict.values())

grad(functional_call_loss(params_dict, inputs, targets), params_dict.values())

241 µs ± 7.03 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


(tensor([[-0.4424, -0.5870]]), tensor([-1.7045]))

In [14]:
# Same gradient via the *args formulation, taken w.r.t. the tensors in
# `params_tuple`.
%timeit grad(functional_call_loss_tuples(*params_tuple, names=names, inputs=inputs, targets=targets), params_tuple)

grad(
    functional_call_loss_tuples(
        *params_tuple, names=names, inputs=inputs, targets=targets
    ),
    params_tuple,
)

245 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


(tensor([[-0.4424, -0.5870]]), tensor([-1.7045]))

In [15]:
# functorch transform: `ft.grad` differentiates `functorch_loss` w.r.t. its
# first argument, `params`.
%timeit ft.grad(functorch_loss)(params, inputs, targets)

ft.grad(functorch_loss)(params, inputs, targets)

403 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


(tensor([[-0.4424, -0.5870]], grad_fn=<TBackward0>),
 tensor([-1.7045], grad_fn=<MulBackward0>))

---

### Second-order derivatives wrt parameters

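Compute the Hessian $\nabla_\theta^2 L(D; \theta)$: the blocks $\nabla_W^2 L(D; \theta)$ and $\nabla_b^2 L(D; \theta)$, as well as the mixed second derivatives between $W$ and $b$.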

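The Hessian is the derivative of the (scalar-by-vector) gradient above. For the definitions, see [scalar-by-vector](https://en.wikipedia.org/wiki/Matrix_calculus#Scalar-by-vector) and [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix).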


Analytical solutions

$$
\frac{\partial^2 L}{\partial w_1^2} = 2x_1^2
$$

$$
\frac{\partial^2 L}{\partial w_2^2} = 2x_2^2
$$

$$
\frac{\partial^2 L}{\partial b^2} = 2
$$

$$
\frac{\partial^2 L}{\partial w_1 \; \partial w_2} = \frac{\partial^2 L}{\partial w_2 \; \partial w_1} = 2x_1 x_2
$$

$$
\frac{\partial^2 L}{\partial w_1 \; \partial b} = \frac{\partial^2 L}{\partial b \; \partial w_1} = 2x_1
$$

$$
\frac{\partial^2 L}{\partial w_2 \; \partial b} = \frac{\partial^2 L}{\partial b \; \partial w_2} = 2x_2
$$


In [16]:
# Analytical second derivatives. For this linear model they depend only on x,
# not on the parameters.
with torch.no_grad():
    d2_ww_diag = 2 * inputs.pow(2)  # d^2L/dw_i^2 = 2 x_i^2
    d2_bb = torch.tensor(2)  # d^2L/db^2 = 2
    d2_w1w2 = 2 * inputs.prod()  # d^2L/(dw_1 dw_2) = 2 x_1 x_2
    d2_wb = 2 * inputs  # d^2L/(dw_i db) = 2 x_i
    print(d2_ww_diag, d2_bb, d2_w1w2, d2_wb)

tensor([0.1347, 0.2372]) tensor(2) tensor(0.1788) tensor([0.5191, 0.6888])


In [17]:
def my_hessian():
    """Hessian of ``loss`` w.r.t. ``model.parameters()`` by repeated backward.

    The gradient is first computed with ``create_graph=True`` so that it is
    itself differentiable. Every scalar entry of every gradient tensor is then
    differentiated w.r.t. every parameter, one backward pass per entry.

    Returns:
        A list with one entry per (parameter, gradient tensor) pair, iterated
        with the parameter in the outer loop. Each entry is a list holding,
        for each scalar of the flattened gradient tensor, its derivative w.r.t.
        that parameter (shaped like the parameter). ``create_graph=True`` keeps
        these second derivatives attached to the autograd graph.
    """
    param_list = list(model.parameters())
    first_order = grad(loss(inputs, targets), param_list, create_graph=True)

    blocks = []
    for wrt in param_list:
        for g in first_order:
            blocks.append(
                [grad(entry, wrt, create_graph=True)[0] for entry in g.flatten()]
            )
    return blocks

In [18]:
# Hand-rolled Hessian: one backward pass per scalar gradient entry.
%timeit my_hessian()

my_hessian()

975 µs ± 7.69 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


[[tensor([[0.1347, 0.1788]], grad_fn=<TBackward0>),
  tensor([[0.1788, 0.2372]], grad_fn=<TBackward0>)],
 [tensor([[0.5191, 0.6888]], grad_fn=<TBackward0>)],
 [tensor([0.5191], grad_fn=<MulBackward0>),
  tensor([0.6888], grad_fn=<MulBackward0>)],
 [tensor([2.], grad_fn=<MulBackward0>)]]

In [19]:
def my_hessian_functional_call():
    """Hessian of ``functional_call_loss`` w.r.t. the tensors in ``params_dict``.

    The gradient w.r.t. the ``params_dict`` tensors is computed with
    ``create_graph=True``. Each of its scalar entries is then differentiated
    again w.r.t. each tensor.

    Returns:
        A list with one entry per (tensor, gradient tensor) pair, iterated with
        the tensor being differentiated against in the outer loop. Each entry
        lists, for every scalar of the flattened gradient, its derivative
        w.r.t. that tensor.
    """
    wrt_tensors = list(params_dict.values())
    first_order = grad(
        functional_call_loss(params_dict, inputs, targets),
        wrt_tensors,
        create_graph=True,
    )

    blocks = []
    for wrt in wrt_tensors:
        for g in first_order:
            blocks.append(
                [grad(entry, wrt, create_graph=True)[0] for entry in g.flatten()]
            )
    return blocks

In [20]:
# Same hand-rolled Hessian, differentiating the `functional_call` loss.
%timeit my_hessian_functional_call()

my_hessian_functional_call()

1.12 ms ± 37.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


[[tensor([[0.1347, 0.1788]], grad_fn=<TBackward0>),
  tensor([[0.1788, 0.2372]], grad_fn=<TBackward0>)],
 [tensor([[0.5191, 0.6888]], grad_fn=<TBackward0>)],
 [tensor([0.5191], grad_fn=<MulBackward0>),
  tensor([0.6888], grad_fn=<MulBackward0>)],
 [tensor([2.], grad_fn=<MulBackward0>)]]

In [21]:
def l(*params):
    """Loss as a plain function of the parameter tensors.

    Wraps ``functional_call_loss_tuples`` with the global ``names``, ``inputs``
    and ``targets``, which are looked up at call time. The result has the
    ``f(*tensors)`` signature used by ``torch.autograd.functional.hessian``.
    """
    fixed_kwargs = dict(names=names, inputs=inputs, targets=targets)
    return functional_call_loss_tuples(*params, **fixed_kwargs)


def torch_hessian(loss, params, **hessian_kwargs):
    """Hessian of ``loss`` at ``params`` via ``torch.autograd.functional.hessian``.

    Note: the parameter name ``loss`` shadows the global ``loss`` function
    inside this wrapper; it is kept for backward compatibility.

    Args:
        loss: scalar-valued function, called as ``loss(*params)`` when
            ``params`` is a tuple of tensors, or ``loss(params)`` for a single
            tensor.
        params: a tensor or tuple of tensors at which to evaluate the Hessian.
        **hessian_kwargs: forwarded to ``hessian`` (e.g. ``create_graph``,
            ``strict``, ``vectorize``). Omitted by default, so existing calls
            behave exactly as before.

    Returns:
        A tensor for a single input tensor, or a tuple of tuples of blocks
        ``H[i][j]`` for a tuple of inputs.
    """
    return hessian(loss, params, **hessian_kwargs)

In [22]:
# `torch.autograd.functional.hessian` on the *args loss `l`; the result is a
# tuple of tuples of Hessian blocks.
%timeit torch_hessian(l, params_tuple)

torch_hessian(l, params_tuple)

950 µs ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


((tensor([[[[0.1347, 0.1788]],
  
           [[0.1788, 0.2372]]]]),
  tensor([[[0.5191],
           [0.6888]]])),
 (tensor([[[0.5191, 0.6888]]]), tensor([[2.]])))

In [23]:
def functorch_hessian():
    """Hessian of ``functorch_loss`` w.r.t. ``params`` using ``functorch.hessian``.

    Differentiates w.r.t. argument 0 (``params``) and evaluates at the global
    ``params``, ``inputs`` and ``targets``.
    """
    hessian_fn = ft.hessian(functorch_loss, argnums=0)
    return hessian_fn(params, inputs, targets)

In [24]:
# functorch's composable `hessian` transform on `functorch_loss`.
%timeit functorch_hessian()

functorch_hessian()

1.57 ms ± 47.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


((tensor([[[[0.1347, 0.1788]],
  
           [[0.1788, 0.2372]]]], grad_fn=<ViewBackward0>),
  tensor([[[0.5191],
           [0.6888]]], grad_fn=<ViewBackward0>)),
 (tensor([[[0.5191, 0.6888]]], grad_fn=<ViewBackward0>),
  tensor([[2.]], grad_fn=<ViewBackward0>)))

---
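### Hessian-vector product

Compute $H v$ with $H = \nabla_\theta^2 L(D; \theta)$ and $v = \nabla_\theta L(D; \theta)$ by differentiating the gradient a second time (double backward), without materialising $H$.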

In [25]:
# Hessian-vector product H v with v = the gradient itself, via double backward.
loss_value = loss(inputs, targets)
first_order_grads = grad(loss_value, model.parameters(), create_graph=True)

# `grad_outputs` is only the vector of the vector-Jacobian product. Detached
# copies make explicit that it is held constant; the values are unchanged.
hvp_vector = tuple(g.detach() for g in first_order_grads)
hvp = grad(first_order_grads, model.parameters(), grad_outputs=hvp_vector)
hvp

(tensor([[-1.0494, -1.3924]]), tensor([-4.0431]))

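Check analytically that this is correct. For this model the Hessian does not depend on $\theta$: with $z = (x_1, x_2, 1)$ it equals $H = 2 z z^\top$ (the entries derived above). The product must therefore match $H \nabla_\theta L(D; \theta)$.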

In [26]:
# For f(x) = Wx + b the loss is L = (theta . z - y)^2 with theta = (w_1, w_2, b)
# and z = (x_1, x_2, 1). The Hessian is therefore constant: H = 2 z z^T.
# This reproduces the entries derived above for any input dimension.
with torch.no_grad():
    z = torch.cat((inputs, torch.ones(1)))
    H = 2 * torch.outer(z, z)

# Flattened gradient, in the same (w_1, w_2, b) order as H.
v = grad(loss(inputs, targets), model.parameters())
v = torch.cat(tuple(i.flatten() for i in v))

# multiply by gradients
manual_hvp = torch.matmul(H, v)

# Must match the autograd Hessian-vector product from the previous cell.
torch.testing.assert_close(manual_hvp, torch.cat([h.flatten() for h in hvp]))
manual_hvp

tensor([-1.0494, -1.3924, -4.0431])

---
### Check how `grad` works

In [29]:
# Recompute the first-order gradients with a graph (`create_graph=True`) so
# that they can be differentiated a second time.
first_order_grads = grad(
    functional_call_loss(params_dict, inputs, targets),
    params_dict.values(),
    create_graph=True,
)

# `zip` pairs each tensor with its own gradient, so only the diagonal Hessian
# blocks (weight/weight and bias/bias) are computed here.
print("correct")
for param, grads in zip(params_dict.values(), first_order_grads):
    flat_grad = grads.flatten()
    # this gives the right answers: one backward pass per scalar entry of the
    # gradient, i.e. one row of the Hessian block for `param`
    for idx in range(len(flat_grad)):
        print(grad(flat_grad[idx], param, create_graph=True))

print("\nwrong: it sums the gradients")
for param, grads in zip(params_dict.values(), first_order_grads):
    flat_grad = grads.flatten()
    # this does not work since it sums the gradients in the tuple: with no
    # `grad_outputs`, every scalar output is weighted by 1, so the result is
    # the sum of the per-entry rows printed above
    print(grad(tuple(flat_grad), param, create_graph=True))

correct
(tensor([[0.1347, 0.1788]], grad_fn=<TBackward0>),)
(tensor([[0.1788, 0.2372]], grad_fn=<TBackward0>),)
(tensor([2.], grad_fn=<MulBackward0>),)

wrong: it sums the gradients
(tensor([[0.3135, 0.4160]], grad_fn=<TBackward0>),)
(tensor([2.], grad_fn=<MulBackward0>),)


In [30]:
# Repeat the Hessian-vector product with a batch of one sample, shape (1, 2).
# Fresh names are used so that the globals `inputs` / `targets` (read at call
# time by `l`, `my_hessian`, `functorch_hessian`, ...) are not silently
# replaced by tensors of a different shape.
batch_inputs = torch.rand((1, 2))
batch_targets = torch.rand((1, 1))

batch_grads = grad(
    loss(batch_inputs, batch_targets), model.parameters(), create_graph=True
)
# The vector of the vector-Jacobian product is the gradient itself, held constant.
batch_hvp = grad(
    batch_grads,
    model.parameters(),
    grad_outputs=tuple(g.detach() for g in batch_grads),
)
batch_hvp

(tensor([[-1.0215, -2.4176]]), tensor([-3.9874]))