In [10]:
import torch
from torch.func import vmap, jacrev

In [5]:
def final_loss(y1, y2):
    """Return the mean of the element-wise squared difference of y1 and y2."""
    squared_error = (y1 - y2) ** 2
    return squared_error.mean()

def build_full_loss(final_loss, model, x, y):
    """Build a loss function whose arguments are the model's parameters.

    The returned closure takes tensors positionally, pairs them with the
    names yielded by ``model.named_parameters()`` (in that order), evaluates
    ``model`` on ``x`` with those tensors substituted through
    ``torch.func.functional_call``, and returns ``final_loss(output, y)``.
    """
    def full_loss(*params):
        # Names are looked up on every call; zip stops at the shorter of the
        # two sequences, so only the supplied tensors are substituted.
        names = (name for name, _ in model.named_parameters())
        overrides = dict(zip(names, params))
        prediction = torch.func.functional_call(model, overrides, x)
        return final_loss(prediction, y)

    return full_loss

In [2]:
# Problem sizes for the toy regression example.
dim_in = 4
dim_out = 3
batch_size = 7

# Seed the global RNG before any random draw so the data and the layer's
# initial weights are the same on every fresh kernel run.
# NOTE: outputs recorded before this seed was added came from an unseeded
# run and will change the first time the notebook is re-executed.
torch.manual_seed(0)

# Random inputs and regression targets.
x = torch.randn(batch_size, dim_in)
y = torch.randn(batch_size, dim_out)

# Single linear layer mapping dim_in features to dim_out outputs.
m = torch.nn.Linear(dim_in, dim_out)

In [9]:
# Reference gradients from ordinary autograd: reset any existing grads,
# run one forward/backward pass, then show every parameter's .grad.
m.zero_grad()

output = m(x)
loss = final_loss(output, y)
loss.backward()

# One gradient per line, in m.parameters() order.
print(*(param.grad for param in m.parameters()), sep="\n")

tensor([[ 0.0036,  0.3164, -0.1004, -0.2653],
        [-0.0085, -0.0053, -0.0745, -0.2524],
        [ 0.5288,  0.2594,  0.4720, -0.5616]])
tensor([ 0.2808,  0.1608, -0.0993])


In [27]:
# Same gradients via torch.func: express the loss as a function of the
# parameters, then differentiate it with respect to positional args 0 and 1.
m.zero_grad()

full_loss = build_full_loss(final_loss, m, x, y)
param_jacobians = jacrev(full_loss, argnums=(0, 1))

# Bare last expression so the notebook displays the resulting tuple.
param_jacobians(*m.parameters())

(tensor([[ 0.0036,  0.3164, -0.1004, -0.2653],
         [-0.0085, -0.0053, -0.0745, -0.2524],
         [ 0.5288,  0.2594,  0.4720, -0.5616]], grad_fn=<ViewBackward0>),
 tensor([ 0.2808,  0.1608, -0.0993], grad_fn=<ViewBackward0>))