In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from functools import partial
from functorch import (
    make_functional_with_buffers, vmap, grad,
)

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalise each of
# the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Number of images per batch, shared by the train and test loaders below.
batch_size = 1

# Training split: downloaded into ./data if missing, reshuffled on every pass.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing, iterated in a fixed order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names (presumably indexed by the CIFAR-10 integer label;
# verify against the dataset's own `classes` attribute before relying on the order).
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)

In [None]:
class Classifier(nn.Module):
    """Small convolutional network producing 10 output scores per image.

    Two 5x5 convolutions (each followed by ReLU and 2x2 max pooling) feed three
    fully connected layers. ``fc1`` expects ``16 * 5 * 5`` flattened features,
    which is what the two unpadded conv + pool stages yield for 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; the single pooling layer is reused
        # after both convolutions.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head ending in 10 output units.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) output scores for a batch ``x``."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = self.pool(F.relu(conv(hidden)))
        # Keep the batch axis and collapse everything else.
        hidden = torch.flatten(hidden, start_dim=1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        # No activation on the last layer.
        return self.fc3(hidden)

In [None]:
# Instantiate the network and move it to the GPU.
model = Classifier()
model = model.to(device="cuda")

# Split the module into a stateless callable plus its parameters and buffers, so
# the forward pass can be invoked as functional_model(params, buffers, inputs).
functional_model, params, buffers = make_functional_with_buffers(model)

In [None]:
def sgd_optimizer(weights, gradients, learning_rate):
    """Apply one plain gradient-descent update to every weight.

    Args:
        weights: Iterable of current weights.
        gradients: Iterable of gradients, paired positionally with ``weights``.
        learning_rate: Step size that scales each gradient.

    Returns:
        A new list holding ``weight - learning_rate * gradient`` for each pair.
        The inputs are not modified; pairing stops at the shorter iterable.
    """
    updated = []
    for weight, gradient in zip(weights, gradients):
        step = learning_rate * gradient
        updated.append(weight - step)
    return updated


def compute_loss_stateless(params, buffers, sample, target):
    """Cross-entropy loss of the functional model on a single example.

    ``sample`` and ``target`` arrive without a batch axis, so a leading axis of
    size one is added to each before the model and the loss are evaluated.

    Args:
        params: Parameters passed to ``functional_model``.
        buffers: Buffers passed to ``functional_model``.
        sample: One input example (no batch axis).
        target: The label for ``sample`` (no batch axis).

    Returns:
        The cross-entropy loss for this one example.
    """
    batched_target = torch.unsqueeze(target, 0)
    logits = functional_model(params, buffers, torch.unsqueeze(sample, 0))
    return F.cross_entropy(logits, batched_target)


# Per-sample gradient function. `grad` differentiates the loss with respect to its
# first argument (params); `vmap` then maps that over the batch.
compute_gradient = vmap(
    grad(compute_loss_stateless),
    # params and buffers are shared (None); sample and target are mapped over dim 0.
    in_dims=(None, None, 0, 0),
)

In [None]:
def functional_step(input, target, weights, buffers, learning_rate=1e-3):
    """Perform one SGD step on a batch and return the updated weights.

    Gradients are evaluated at the ``weights`` passed in (not at the module-level
    initial parameters), so successive calls actually follow the optimisation
    trajectory.

    Args:
        input: Batch of inputs; dim 0 is the batch axis mapped by ``compute_gradient``.
        target: Batch of labels; dim 0 is the batch axis.
        weights: Current model weights, in the order expected by the functional model.
        buffers: Model buffers, passed through unchanged.
        learning_rate: SGD step size. Defaults to ``1e-3``.

    Returns:
        A list of updated weights, each with the same shape as its input weight.
    """
    # Detach so the update does not extend an autograd graph across steps; the
    # gradients come from functorch's `grad` transform, not from `.backward()`.
    weights = [weight.detach() for weight in weights]
    # `compute_gradient` is vmapped over the batch, so every returned gradient has
    # a leading per-sample axis.
    per_sample_gradients = compute_gradient(weights, buffers, input, target)
    # Average over that axis so each gradient matches its weight's shape; without
    # this, broadcasting would add a batch dimension to the updated weights.
    gradients = [gradient.mean(dim=0) for gradient in per_sample_gradients]
    return sgd_optimizer(weights, gradients, learning_rate)

In [None]:
def train(train_step_fn, weights, buffers):
    """Run one pass over ``train_loader`` and return the resulting weights.

    Args:
        train_step_fn: Callable invoked as
            ``train_step_fn(input, target, weights, buffers)`` that returns the
            updated weights for the next batch.
        weights: Weights to start the pass from.
        buffers: Model buffers, forwarded unchanged to ``train_step_fn``.

    Returns:
        The weights produced by the last call to ``train_step_fn`` (or the
        ``weights`` argument unchanged if the loader yields no batches).
    """
    # Wrap the loader itself rather than `enumerate(...)` so tqdm can read its
    # length and display a total / ETA; the batch index was never used.
    for batch_input, batch_target in tqdm(train_loader):
        batch_input = batch_input.to("cuda")
        batch_target = batch_target.to("cuda")
        weights = train_step_fn(batch_input, batch_target, weights, buffers)
    return weights

In [None]:
# Start from the initial parameters and carry the trained weights from one epoch
# into the next; restarting every epoch from `params` would discard all progress
# made by the previous epoch.
weights = params
for epoch in range(2):
    weights = train(functional_step, weights, buffers)