In [11]:
import torch
torch.manual_seed(42)

class Net(torch.nn.Module):

    def __init__(self, indim=1, outdim=1):
        super().__init__()
        self.actf = torch.tanh
        self.lin1 = torch.nn.Linear(indim, 100)
        self.lin2 = torch.nn.Linear(100, outdim)

    def forward(self, x):
        x = self.lin1(x)
        x = self.lin2(self.actf(x))
        return x.squeeze()

In [12]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


In [13]:
model_vec = Net(indim=3, outdim=3).to(DEVICE)

In [14]:
x = torch.rand(3)
x = x.requires_grad_().to(DEVICE)
print(x)

tensor([0.0806, 0.6256, 0.0947], device='cuda:0', grad_fn=<ToCopyBackward0>)


In [15]:
grad = torch.autograd.grad(model_vec(x), x, grad_outputs=torch.ones_like(model_vec(x)), create_graph=True, allow_unused=True)[0]
print(grad)

tensor([ 0.1809,  0.2170, -0.1783], device='cuda:0', grad_fn=<ViewBackward0>)


In [16]:
from torch.func import jacrev
jac = jacrev(model_vec)(x)
print(jac)

tensor([[-0.0622,  0.0219,  0.1662],
        [ 0.1229,  0.3468, -0.2792],
        [ 0.1202, -0.1517, -0.0654]], device='cuda:0', grad_fn=<ViewBackward0>)


We can see that $0.1809 = -0.0622 + 0.1229 + 0.1202$

In [19]:
print(grad[0])
print(jac.T[0].sum())

tensor(0.1809, device='cuda:0', grad_fn=<SelectBackward0>)
tensor(0.1809, device='cuda:0', grad_fn=<SumBackward0>)


In [25]:
torch.manual_seed(42)
z = torch.rand(2, 3)
z = z.requires_grad_().to(DEVICE)
print(z)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009]], device='cuda:0', grad_fn=<ToCopyBackward0>)


As we expect for each element of batch its Jacobian will be calculated, we expect new tensor whose size is [2, 3, 3] will result when we use `jacrev` over `model_vec` and input `z`. 

However, when tested, resulting tensor has the shape of [2, 3, 2, 3], with unwanted zero-vectors as row vectors.

In [29]:
jac = jacrev(model_vec)(z)
print(jac)
print(jac.shape)

tensor([[[[ 0.0026,  0.0146,  0.2084],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.1259,  0.3053, -0.2887],
          [ 0.0000,  0.0000,  0.0000]],

         [[ 0.0721, -0.1514, -0.1038],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [ 0.0207, -0.0149,  0.2170]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.1575,  0.3220, -0.2883]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0639, -0.1551, -0.1053]]]], device='cuda:0',
       grad_fn=<ViewBackward0>)
torch.Size([2, 3, 2, 3])


To solve such a problem, batched-input processing, we can use `vmap` function from `torch.func` module.

In [30]:
from torch.func import jacrev, vmap

jac_vmap = vmap(jacrev(model_vec))(z)
print(jac_vmap)
print(jac_vmap.shape)

tensor([[[ 0.0026,  0.0146,  0.2084],
         [ 0.1259,  0.3053, -0.2887],
         [ 0.0721, -0.1514, -0.1038]],

        [[ 0.0207, -0.0149,  0.2170],
         [ 0.1575,  0.3220, -0.2883],
         [ 0.0639, -0.1551, -0.1053]]], device='cuda:0',
       grad_fn=<ViewBackward0>)
torch.Size([2, 3, 3])
