In [2]:
# script for https://github.com/pytorch/functorch/issues/801

import time
from functorch import vmap, jacrev, jacfwd, vjp
import torch
import torch.nn as nn

_ = torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D1 = 2  # x, y
D2 = 3  # u, v, p
B = 100

model = nn.Sequential(
    nn.Linear(D1, 512), nn.Tanh(),
    nn.Linear(512, D2),
).to(device)


def predict(x):
    return model(x)

In [3]:
# Documentation example
# works when batch size is 1
x = torch.randn(D1).to(device)
x_ = x.clone().requires_grad_()
y = predict(x_)
I_N = torch.eye(D2).to(device)

def get_vjp(v):
    return torch.autograd.grad(y, x_, v)

jacobian, = vmap(get_vjp)(I_N)
print(jacobian.shape)  # torch.Size([3, 2])

torch.Size([3, 2])


In [4]:
# Not work for batched input
x = torch.randn(B, D1).to(device)
x_ = x.clone().requires_grad_()
y = predict(x_)
I_N = torch.eye(D2).to(device)

def get_vjp(y, x, v):
    print(x.requires_grad)  # print False
    print(y.requires_grad)  # print False
    return torch.autograd.grad(y, x, v)

# jacobian = vmap(vmap(get_vjp, in_dims=(None, None, 0)), in_dims=(0, 0, None))(y, x_, I_N)
# # Got error element 0 of tensors does not require grad and does not have a grad_fn


In [5]:
3, 4
4, 6

3, 6

(3, 6)

In [6]:
x = torch.randn(B, D1).to(device)
I_N = torch.eye(D2).to(device)

def get_vjp(x, v):
    (_, vjpfunc) = vjp(predict, x)
    return vjpfunc(v)[0]

jacobian = vmap(vmap(get_vjp, in_dims=(None, 0)), in_dims=(0, None))(x, I_N[[0, 2]])
print(jacobian.shape)  # [100, 3, 2]

torch.Size([100, 2, 2])


jacobian

In [35]:
x = torch.randn(B, D1).to(device)
I_N = torch.eye(D2).to(device)

def get_vjp(x, v):
    print(x.shape)
    (pred, vjpfunc) = vjp(predict, x)
    print(pred.shape)
    return vjpfunc(v)[0], pred

jacobian, pred = vmap(vmap(get_vjp, in_dims=(None, 0)), in_dims=(0, None))(x, I_N[[0, 2]])
print(jacobian.shape)
print(pred.shape)

torch.Size([2])
torch.Size([3])
torch.Size([100, 2, 2])
torch.Size([100, 2, 3])


In [46]:
x = torch.randn(B, D1).to(device)
I_N = torch.eye(D2).to(device)

def get_vjp(x, v):
    print(x.shape)
    (pred, vjpfunc) = vjp(predict, x)
    print(pred.shape)
    print(pred.stride())
    return vjpfunc(v)[0], pred

jacobian, pred = vmap(vmap(get_vjp, in_dims=(None, 0)), in_dims=(0, None))(x, I_N[[0, 2]])
print(jacobian.shape)
print(pred.shape)
print(pred.stride())

torch.Size([2])
torch.Size([3])
(1,)
torch.Size([100, 2, 2])
torch.Size([100, 2, 3])
(3, 0, 1)


hessian

In [36]:
# x = torch.randn(B, D1).to(device)
# I_N = torch.eye(D2).to(device)

# def get_vjp(x, v):
#     print(x.shape)
#     (pred, vjpfunc) = vjp(predict, x)
#     print(pred.shape)
#     jacobian, hessianfunc = vjp(vjpfunc, x)
#     print(jacobian.shape)
#     hess = hessianfunc(v)[0]

#     return hess, jacobian, pred

# hess, jacobian, pred = vmap(vmap(get_vjp, in_dims=(None, 0)), in_dims=(0, None))(x, I_N)
# print(jacobian.shape)
# print(pred.shape)

get pred, jacobian and hessian together

In [50]:
x = torch.randn(B, D1).to(device)
I_N1 = torch.eye(D2).to(device)
I_N2 = torch.eye(D1).to(device)

def get_vjp(x, v1, v2):
    print(x.shape)
    def jacofunc(x):
        (pred, vjpfunc) = vjp(predict, x)
        return vjpfunc(v1)[0], pred
    (jacobian, hessianfunc, pred) = vjp(jacofunc, x, has_aux=True)
    print(pred.shape)
    print(jacobian.shape)

    hess = hessianfunc(v2)[0]

    return hess, jacobian, pred

hess, jacobian, pred= vmap(vmap(vmap(get_vjp, in_dims=(None, None, 0)), in_dims=(None, 0, None)), in_dims=(0, None, None))(x, I_N1, I_N2)
print(hess.shape)
print(hess.stride())
print(jacobian.shape)
print(jacobian.stride())
print(pred.shape)
print(pred.stride())

torch.Size([2])
torch.Size([3])
torch.Size([2])
torch.Size([100, 3, 2, 2])
(12, 4, 2, 1)
torch.Size([100, 3, 2, 2])
(6, 2, 0, 1)
torch.Size([100, 3, 2, 3])
(3, 0, 0, 1)


In [51]:
x = torch.randn(B, D1).to(device)
I_N1 = torch.eye(D2).to(device)
I_N2 = torch.eye(D1).to(device)

def get_vjp(x, v1, v2):
    print(x.shape)
    def jacofunc(x):
        (pred, vjpfunc) = vjp(predict, x)
        return vjpfunc(v1)[0], pred
    def selected_jac(x):
        jac, pred = jacofunc(x)
        return jac[:2,:], pred
    (jacobian, hessianfunc, pred) = vjp(selected_jac, x, has_aux=True)
    print(pred.shape)
    print(jacobian.shape)

    hess = hessianfunc(v2)[0]

    return hess, jacobian, pred

hess, jacobian, pred= vmap(vmap(vmap(get_vjp, in_dims=(None, None, 0)), in_dims=(None, 0, None)), in_dims=(0, None, None))(x, I_N1, I_N2)
print(hess.shape)
print(hess.stride())
print(jacobian.shape)
print(jacobian.stride())
print(pred.shape)
print(pred.stride())

torch.Size([2])
torch.Size([3])
torch.Size([2])
torch.Size([100, 3, 2, 2])
(12, 4, 2, 1)
torch.Size([100, 3, 2, 2])
(6, 2, 0, 1)
torch.Size([100, 3, 2, 3])
(3, 0, 0, 1)
