In [1]:
!pip install torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
!pip install functorch

Looking in links: https://download.pytorch.org/whl/cpu/torch_stable.html
Collecting torch==1.11.0+cpu
  Downloading https://download.pytorch.org/whl/cpu/torch-1.11.0%2Bcpu-cp37-cp37m-linux_x86_64.whl (169.1 MB)
[K     |████████████████████████████████| 169.1 MB 36 kB/s 
[?25hCollecting torchvision==0.12.0+cpu
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.12.0%2Bcpu-cp37-cp37m-linux_x86_64.whl (14.7 MB)
[K     |████████████████████████████████| 14.7 MB 50.2 MB/s 
[?25hCollecting torchaudio==0.11.0+cpu
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-0.11.0%2Bcpu-cp37-cp37m-linux_x86_64.whl (2.7 MB)
[K     |████████████████████████████████| 2.7 MB 33.9 MB/s 
Installing collected packages: torch, torchvision, torchaudio
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0+cu111
    Uninstalling torch-1.10.0+cu111:
      Successfully uninstalled torch-1.10.0+cu111
  Attempting uninstall: torchvision
    Found existing installation

In [2]:
# Restart the runtime so the torch/functorch versions installed above are the
# ones imported by the following cells.
print("--> Restarting colab instance")
ipy_kernel = get_ipython().kernel
ipy_kernel.do_shutdown(True)

--> Restarting colab instance


{'restart': True, 'status': 'ok'}

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

from functorch import make_functional, vmap

import numpy as np
import matplotlib.pyplot as plt

In [9]:
# Problem size and optimisation hyper-parameters used by the cells below.
TRAIN_STEPS: int = 100        # number of training steps per train() run
LEARNING_RATE: float = 0.1    # SGD step size
NUM_LAYERS: int = 8           # number of stacked bias-free Linear layers
FEATURES: int = 64            # width (in == out) of every Linear layer
BATCH_SIZE: int = 16          # rows per randomly sampled input batch

In [10]:
class SimpleModel(nn.Module):
    """Stack of bias-free square ``nn.Linear`` layers reduced to a scalar.

    ``forward`` returns the sum of squared outputs, so calling the model
    directly yields a scalar loss that can be differentiated.

    Args:
        features: width of every layer (in_features == out_features).
            Defaults to the ``FEATURES`` config constant.
        num_layers: number of stacked layers. Defaults to ``NUM_LAYERS``.
    """

    def __init__(self, features=None, num_layers=None):
        super().__init__()
        # Resolve defaults at construction time rather than definition time so
        # the config cell can be edited and re-run without redefining the class.
        if features is None:
            features = FEATURES
        if num_layers is None:
            num_layers = NUM_LAYERS
        # Layers are created in order, so parameter names ("0.weight", ...) and
        # RNG consumption match a list-then-unpack construction.
        self.model = nn.Sequential(
            *(nn.Linear(features, features, bias=False) for _ in range(num_layers))
        )

    def forward(self, x):
        # Scalar "loss": sum of squares over every element of the output.
        return (self.model(x)**2).sum()

In [11]:
# Seed before sampling the data and initialising the weights: train() seeds
# only after these already exist, so without this the logged losses change on
# every fresh kernel.
torch.manual_seed(16)
input_data = torch.randn((BATCH_SIZE, FEATURES))

model = SimpleModel()
# TorchScript-compiled copy of the same model.
jit_model = torch.jit.script(model)

# Split the module into a stateless callable `functional_model(params, x)` and
# its parameter tensors.
functional_model, parameters = make_functional(model)

In [12]:
# Vanilla SGD over the functional parameters. momentum, dampening and
# weight_decay are left at their defaults (all 0), so the rule is
# p <- p - LEARNING_RATE * grad.
# NOTE(review): train_step_functional applies that same rule by hand and does
# not reference this optimizer.
optimizer = torch.optim.SGD(parameters, lr=LEARNING_RATE)

In [13]:
def train_step_functional(data, params, lr=None, model_fn=None):
    """Run one plain-SGD step on a functional model.

    Args:
        data: input batch passed to the model.
        params: iterable of parameter tensors. They are not modified.
        lr: step size; defaults to ``LEARNING_RATE``.
        model_fn: callable ``model_fn(params, data)`` returning a scalar loss;
            defaults to ``functional_model``.

    Returns:
        ``(loss, updated_params)``. ``loss`` is the detached scalar loss
        evaluated at ``params``. ``updated_params`` is a new list holding
        ``p - lr * grad`` for each parameter; these tensors do not require grad.
    """
    if lr is None:
        lr = LEARNING_RATE
    if model_fn is None:
        model_fn = functional_model
    # Fresh leaf copies: gradients are taken w.r.t. this step's values only and
    # the caller's tensors are never touched.
    params = [param.detach().requires_grad_() for param in params]
    out = model_fn(params, data)
    # Get the gradients directly instead of populating .grad on throwaway leaves.
    grads = torch.autograd.grad(out, params)
    # The parameter update itself must not be recorded by autograd.
    with torch.no_grad():
        updated_params = [param - lr * grad for param, grad in zip(params, grads)]
    return out.detach(), updated_params

In [14]:
def train(train_step_fn, params):
    """Run ``TRAIN_STEPS`` steps of ``train_step_fn`` on fresh random batches.

    ``train_step_fn(data, params)`` must return ``(loss, new_params)``. The
    loss is printed every 10 steps.

    Args:
        train_step_fn: the step function to drive.
        params: initial parameters handed to the first step.
    """
    torch.manual_seed(16)
    # One extra call on the fixed `input_data` batch whose result is discarded.
    # Presumably a warm-up (e.g. for compiled step functions). TODO confirm.
    train_step_fn(input_data, params)
    current_params = params
    for step in range(TRAIN_STEPS):
        batch = torch.randn(BATCH_SIZE, FEATURES)
        loss, current_params = train_step_fn(batch, current_params)
        if not step % 10:
            print(f"Loss at Step {step}: {loss}")

In [15]:
# Train the functional model starting from the parameters extracted by make_functional.
train(train_step_fn=train_step_functional, params=parameters)

Loss at Step 0: 0.13981828093528748
Loss at Step 10: 0.043717481195926666
Loss at Step 20: 0.01685665361583233
Loss at Step 30: 0.015003936365246773
Loss at Step 40: 0.014027919620275497
Loss at Step 50: 0.009594528004527092
Loss at Step 60: 0.009202130138874054
Loss at Step 70: 0.008044816553592682
Loss at Step 80: 0.007874097675085068
Loss at Step 90: 0.006949734874069691
