In [18]:
import torch
from torch.autograd.functional import jacobian

In [128]:
def func(input_data):
    """
    Maps every input vector to the pair
    (sum of its entries + 4, product of its entries).

    Takes a tensor of shape [batch_size, input_size] and produces
    a tensor of shape [batch_size, 2]. A single vector of shape
    [input_size] is accepted as well and yields a tensor of shape [2].
    """
    assert 1 <= input_data.ndim <= 2
    # Both reductions run over the last axis, so batched and unbatched
    # inputs go through the same code path.
    shifted_sum = torch.sum(input_data, dim = -1) + 4
    product = torch.prod(input_data, dim = -1)
    return torch.stack([shifted_sum, product], dim = -1)

# Sanity check: evaluate func on a batch of three vectors, then on a single
# unbatched vector.
for sample in ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], [2.0, 3.0, 4.0]):
    print(func(torch.tensor(sample)))

tensor([[ 10.,   6.],
        [ 19., 120.],
        [ 28., 504.]])
tensor([13., 24.])


In [31]:
# Reference Jacobian from torch for one unbatched point; the result is
# indexed as [output, input].
single_point = torch.tensor([2.0, 3.0])
jacobian(func, inputs = single_point)

tensor([[1., 1.],
        [3., 2.]])

In [33]:
# For a batched input torch differentiates every output element with respect
# to every input element, so the result is 4-D and the blocks that pair
# different batch rows are zero.
batch_points = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
jacobian(func, inputs = batch_points)

tensor([[[[1., 1.],
          [0., 0.]],

         [[3., 2.],
          [0., 0.]]],


        [[[0., 0.],
          [1., 1.]],

         [[0., 0.],
          [5., 4.]]]])

In [63]:
# Per-sample gradients of each output column, obtained with one
# torch.autograd.grad call per column and a grad_outputs vector of ones.
data = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], requires_grad = True)
result = func(data)
print(result)
for column in (0, 1):
    selected = result[:, column]
    # The graph is kept after the first column only; the second call may free it.
    print(torch.autograd.grad(
        selected,
        data,
        grad_outputs = torch.ones_like(selected),
        retain_graph = column == 0))

tensor([[ 5.,  6.],
        [ 9., 20.],
        [13., 42.]], grad_fn=<StackBackward>)
(tensor([[1., 1.],
        [1., 1.],
        [1., 1.]]),)
(tensor([[3., 2.],
        [5., 4.],
        [7., 6.]]),)


In [131]:
def calc_jacobian(func, input_data, verbose = True):
    """
    Calculates the output and the Jacobian of the function at the given
    input data. The input data is of shape [batch_size, input_size],
    while the output data is of shape [batch_size, output_size] and the
    Jacobian is of size [batch_size, input_size, output_size], i.e.
    jacobian[b, j, i] is d output[b, i] / d input[b, j].

    Parameters:
        func: callable taking a [batch_size, input_size] tensor and
            returning a [batch_size, output_size] tensor. Each output row
            must depend only on the matching input row, because one
            backward pass per output column is shared by the whole batch.
        input_data: floating-point tensor of shape [batch_size, input_size]
            or [input_size]. The caller's tensor is left untouched (its
            requires_grad flag is not changed).
        verbose: if True (the default), prints the shapes of both results.

    Returns:
        Tuple (output_data, jacobian_data). For a 1-D input the leading
        batch dimension is removed from both. The Jacobian has the dtype
        and device of the input.
    """
    assert input_data.ndim in [1, 2]
    reshaped = input_data.ndim == 1
    if reshaped:
        input_data = input_data.reshape((1, -1))
    # Differentiate with respect to a detached leaf that shares the caller's
    # storage, so enabling gradients here never leaks to the caller's tensor.
    input_data = input_data.detach().requires_grad_(True)
    output_data = func(input_data)
    output_size = output_data.shape[1]
    # new_empty inherits dtype and device from the input, so float64 inputs
    # are not downcast and non-CPU inputs keep their device.
    jacobian_data = input_data.new_empty(
        (output_data.shape[0], input_data.shape[1], output_size))
    for i in range(output_size):
        column = output_data[:, i]
        jacobian_data[:, :, i] = torch.autograd.grad(
            column,
            input_data,
            # ones_like matches the column's shape, dtype and device.
            grad_outputs = torch.ones_like(column),
            # The graph is needed again for every column except the last.
            retain_graph = i + 1 < output_size)[0]
    if reshaped:
        output_data = output_data.squeeze(dim=0)
        jacobian_data = jacobian_data.squeeze(dim=0)
    if verbose:
        print(output_data.shape, jacobian_data.shape)
    return output_data, jacobian_data

# Try calc_jacobian on a batch of three vectors and on a single vector.
# Only the return value of the final call is shown as the cell output.
batch_input = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0], [2.0, 3.0, 4.0, 5.0]])
calc_jacobian(func, batch_input)
vector_input = torch.tensor([1.0, 2.0, 3.0, 4.0])
calc_jacobian(func, vector_input)

torch.Size([3, 2]) torch.Size([3, 4, 2])
torch.Size([2]) torch.Size([4, 2])


(tensor([14., 24.], grad_fn=<SqueezeBackward1>), tensor([[ 1., 24.],
         [ 1., 12.],
         [ 1.,  8.],
         [ 1.,  6.]]))