In [None]:
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap, jacfwd
from lib.trainer.timer import TimeTracker
from lib.model.layers import MLP
from tqdm import tqdm

# --- Benchmark configuration -------------------------------------------------
device = "cuda"
steps = 10  # number of passes over the timed benchmark loop
n_unknowns = 159  # MLP input size and length of the evaluation point
n_residuals = 1135  # MLP output size
hidden_dim = 1000
num_layers = 8


# Collects the timings recorded under each benchmark label.
tracker = TimeTracker()

# Random evaluation points; only `params` is tracked by autograd.
params = torch.randn(n_unknowns, device=device, requires_grad=True)
params_no_grad = torch.randn(n_unknowns, device=device, requires_grad=False)

# Network whose input->output Jacobian is being benchmarked.
mlp = MLP(
    in_dim=n_unknowns,
    out_dim=n_residuals,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)
mlp = mlp.to(device)
print("MLP sum params:", sum(map(torch.numel, mlp.parameters())))


def f(x):
    """Return the output of the module-level ``mlp`` evaluated at ``x``."""
    return mlp(x)

def f_no_grad(x):
    """Evaluate the module-level ``mlp`` at ``x`` inside ``torch.no_grad()``.

    NOTE(review): torch.func transforms are documented to respect a
    ``no_grad`` block placed inside the transformed function, so
    ``jacrev(f_no_grad)`` may not yield the true Jacobian -- confirm that this
    is the comparison the benchmark intends.
    """
    with torch.no_grad():
        return mlp(x)

def _wait_for_device():
    """Block until queued CUDA work has finished (no-op without CUDA).

    CUDA kernels are launched asynchronously, so without this barrier the
    wall-clock time of one benchmark can leak into the next label.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# Each entry is (tracker label, zero-argument callable computing one Jacobian
# of `f` at `params`). Every label appears exactly once so that each variant
# contributes one measurement per step.
jacobian_benchmarks = [
    (
        "reverse_jacobian_vectorize_graph",
        lambda: jacobian(
            f, params, strategy="reverse-mode", vectorize=True, create_graph=True
        ),
    ),
    (
        "reverse_jacobian_vectorize",
        lambda: jacobian(
            f, params, strategy="reverse-mode", vectorize=True, create_graph=False
        ),
    ),
    (
        "reverse_jacobian_graph",
        lambda: jacobian(
            f, params, strategy="reverse-mode", vectorize=False, create_graph=True
        ),
    ),
    (
        "reverse_jacobian",
        lambda: jacobian(
            f, params, strategy="reverse-mode", vectorize=False, create_graph=False
        ),
    ),
    (
        "forward_jacobian_vectorize",
        lambda: jacobian(
            f, params, strategy="forward-mode", vectorize=True, create_graph=False
        ),
    ),
    ("reverse_vmapjac", lambda: vmap(jacrev(f))(params.unsqueeze(0))[0]),
    ("reverse_jac", lambda: jacrev(f)(params)),
    # `size=size` binds the chunk size per entry (avoids late-binding closures).
    *(
        (
            f"reverse_jac_chunk{size}",
            lambda size=size: jacrev(f, chunk_size=size)(params),
        )
        for size in (10, 1000, 10000, 100000)
    ),
    ("reverse_jac_no_grad", lambda: jacrev(f_no_grad)(params)),
    ("forward_jac", lambda: jacfwd(f)(params)),
    ("forward_jac_no_grad", lambda: jacfwd(f_no_grad)(params)),
]

# Make sure setup work (model transfer, tensor creation) is not billed to the
# first measurement.
_wait_for_device()

for _ in tqdm(range(steps)):
    for position, (label, compute) in enumerate(jacobian_benchmarks):
        if position == 0:
            # No measurement is running at the start of a step.
            tracker.start(label)
        else:
            # Close the previous label and open the next one in a single call.
            tracker.start(label, stop=True)
        compute()
        # Finish all queued kernels before the tracker switches labels.
        _wait_for_device()
    tracker.stop()

# NOTE(review): if `print_summary` prints by itself and returns None, this
# outer print emits an extra "None" line -- confirm against TimeTracker.
print(tracker.print_summary())

In [None]:
# Print whatever `tracker.print_summary()` returns for the timings recorded so far.
print(tracker.print_summary())

In [None]:
# Reverse-mode Jacobian of `f` at `params` via torch.func.jacrev.
# Left as a bare expression (presumably so the notebook displays the result).
jacrev(f)(params)

In [None]:
# Forward-mode Jacobian of `f` at `params` via torch.func.jacfwd.
# Left as a bare expression (presumably so the notebook displays the result).
jacfwd(f)(params)