In [1]:
import torch

In [2]:
# Record the PyTorch version in the cell output; a bare expression that is not
# the last line of a cell is evaluated but never displayed.
print(f"{torch.__version__=}")
# Seed the global RNG so every torch.randn call below is reproducible.
torch.manual_seed(0)

<torch._C.Generator at 0x7ac2f7f2b250>

## vmap

In [3]:
# Two random 1D vectors of length 100; torch.dot is defined for exactly this case.
x: torch.Tensor = torch.randn((100,))
y: torch.Tensor = torch.randn((100,))
x_dot_y = x.dot(y)
print(f"{x_dot_y=}")

x_dot_y=tensor(-11.9096)


In [4]:
x: torch.Tensor = torch.randn(10, 100)
y: torch.Tensor = torch.randn(10, 100)
# torch.dot only accepts 1D tensors, so calling it on a batch of vectors fails.
# The failure is the point of this cell: catch and display it so that
# "Restart Kernel -> Run All" does not stop here.
try:
    x_dot_y = torch.dot(x, y)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: 1D tensors expected, but got 2D and 2D tensors

In [None]:
x: torch.Tensor = torch.randn(10, 100)
y: torch.Tensor = torch.randn(10, 100)
# Manual batching: compute one dot product per row pair and stack the scalar
# results. The output length follows the inputs rather than a hard-coded
# batch size, and no placeholder values are allocated first.
x_dot_y = torch.stack([torch.dot(x_row, y_row) for x_row, y_row in zip(x, y)])
print(f"{x_dot_y=}")

In [None]:
# vmap lifts torch.dot to work over a batch: in_dims=(0, 0) maps over
# dimension 0 of both arguments, so each call to torch.dot sees one row pair.
batched_dot_product = torch.func.vmap(torch.dot, in_dims=(0, 0))
x_dot_y_using_vmap = batched_dot_product(x, y)
print(f"{x_dot_y_using_vmap=}")

In [None]:
# The explicit Python loop and the vmap version must agree up to float tolerance.
assert x_dot_y.allclose(x_dot_y_using_vmap)

## grad

In [None]:
def sin_x(x: torch.Tensor) -> torch.Tensor:
    """Return the sine of ``x``."""
    return torch.sin(x)


# torch.func.grad turns sin_x into a function that evaluates its derivative.
grad_sin_x = torch.func.grad(sin_x)
x = torch.randn([])  # 0-dim (scalar) tensor
# d/dx sin(x) == cos(x)
assert torch.allclose(grad_sin_x(x), x.cos())

In [None]:
# grad composes with itself: differentiating the derivative function gives the
# second derivative, and d^2/dx^2 sin(x) == -sin(x).
grad_grad_sin_x = torch.func.grad(grad_sin_x)
second_derivative = grad_grad_sin_x(x)
assert torch.allclose(second_derivative, -torch.sin(x))

## grad + vmap

In [None]:
from torch.func import grad, vmap

# Number of examples in the batch and length of each feature vector.
batch_size = 3
feature_size = 5

def model(weights: torch.Tensor, feature_vec: torch.Tensor) -> torch.Tensor:
    """Apply a single linear unit followed by ReLU to one feature vector.

    Args:
        weights: 1D weight vector.
        feature_vec: 1D feature vector of the same length as ``weights``.

    Returns:
        Scalar tensor ``relu(feature_vec . weights)``.
    """
    pre_activation = torch.dot(feature_vec, weights)
    return torch.relu(pre_activation)

def compute_loss(weights: torch.Tensor, example: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared-error loss of ``model`` on a single example.

    Args:
        weights: 1D weight vector passed through to ``model``.
        example: 1D feature vector for one sample.
        target: Target value for that sample.

    Returns:
        Mean of the squared difference between the prediction and ``target``.
    """
    prediction = model(weights, example)
    squared_error = (prediction - target).pow(2)
    return squared_error.mean()  # mean squared error (MSE)

# One shared weight vector plus a random batch of (example, target) pairs.
weights = torch.randn((feature_size,), requires_grad=True)
examples = torch.randn((batch_size, feature_size))
targets = torch.randn((batch_size,))

In [None]:
# grad differentiates with respect to the first positional argument (weights).
grad_of_loss = grad(compute_loss)
# in_dims=(None, 0, 0): share the same weights across the batch and map over
# dimension 0 of examples and targets, giving one gradient per sample.
grad_of_loss_per_sample = vmap(grad_of_loss, in_dims=(None, 0, 0))

inputs = (weights, examples, targets)
grad_weight_per_example = grad_of_loss_per_sample(weights, examples, targets)
print(grad_weight_per_example)

## functional_call

In [None]:
# A batch of 4 inputs with 3 features each, and targets of the same shape.
x = torch.randn((4, 3))
t = torch.randn((4, 3))

# NOTE: this rebinds `model`, which was a plain function in the grad + vmap
# section, to an nn.Module.
model = torch.nn.Linear(3, 3)

# functional_call runs the module's forward pass using the parameters passed
# in explicitly; with the module's own parameters it must match model(x).
params = {name: param for name, param in model.named_parameters()}
y = torch.func.functional_call(model, params, (x,))

assert torch.allclose(y, model(x))

In [None]:
def compute_loss(
    params: dict[str, torch.Tensor], x: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """MSE loss of the global ``model`` evaluated with explicit parameters.

    Args:
        params: Mapping from parameter name to tensor, used in place of the
            module's own parameters for this call.
        x: Input batch passed to the module.
        t: Target tensor compared against the module output.

    Returns:
        ``mse_loss`` between the module output and ``t``.
    """
    prediction = torch.func.functional_call(model, params, (x,))
    return torch.nn.functional.mse_loss(input=prediction, target=t)


# grad differentiates with respect to the first argument, here the parameter
# dict, so the gradients come back keyed by parameter name.
grad_of_loss = grad(compute_loss)
grad_weights = grad_of_loss({name: param for name, param in model.named_parameters()}, x, t)

In [None]:
# Ensemble configuration. NOTE: `batch_size` is rebound here (it was 3 in the
# grad + vmap section).
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
# Independently initialised linear layers that all share one architecture.
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
# One minibatch fed to every model; its width is taken from in_features so it
# cannot drift out of sync with the layers' input size.
data = torch.randn(batch_size, in_features)

## stack_module_state

In [None]:
def forward_call(params, buffers, data):
    """Run one forward pass with explicitly supplied parameters and buffers.

    ``models[0]`` is passed as the module to call; the tensors used for the
    call are the ``params`` and ``buffers`` given here.
    """
    base_model = models[0]
    return torch.func.functional_call(base_model, (params, buffers), (data,))

# in_dims=(0, 0, None): map over dimension 0 of params and buffers (one slice
# per model) while passing the same `data` to every call.
vmap_forward_call = vmap(forward_call, in_dims=(0, 0, None))

# Combine the state of all models into stacked params / buffers.
params, buffers = torch.func.stack_module_state(models)

output = vmap_forward_call(params, buffers, data)

# One output batch per model.
expected_shape = (num_models, batch_size, out_features)
assert output.shape == expected_shape