69 changes: 69 additions & 0 deletions aot_example.py
@@ -0,0 +1,69 @@
import os
import torch
import torch.fx as fx
import torch.nn as nn
import torch.optim as optim

from typing import List

import torchdynamo
from torchdynamo.optimizations import BACKENDS

from torch.profiler import profile, ProfilerActivity

from functorch.compile import aot_module, clear_compile_cache

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10000)
        self.net2 = nn.Linear(10000, 10000)
        self.net3 = nn.Linear(10000, 10000)
        self.relu = nn.ReLU()
        self.net4 = nn.Linear(10000, 5)

    def forward(self, x):
        output1 = self.relu(self.net1(x))
        output2 = self.relu(self.net2(output1))
        output3 = self.relu(self.net3(output2))
        return self.net4(output3)

def hook(grad):
    print("gradient hook fired")
    return grad + 1

def compiler_fn(fx_module: torch.fx.GraphModule, _):
    # A pass-through compiler: return the traced FX module unchanged.
    # fx_module.graph.print_tabular()
    return fx_module
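
# A hedged variant of compiler_fn (an assumption, not part of the original
# example): a compiler factory that labels each graph it receives, which can
# help tell the forward compilation apart from the backward compilation, e.g.
# aot_module(model, fw_compiler=make_labeled_compiler("forward"),
#            bw_compiler=make_labeled_compiler("backward")).
def make_labeled_compiler(label):
    def labeled_compiler(fx_module: torch.fx.GraphModule, example_inputs):
        print(f"compiling {label} graph with {len(fx_module.graph.nodes)} nodes")
        return fx_module
    return labeled_compiler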

# A basic AOT example demonstrating that all of the gradient hooks
# fire after the compiled AOT module's backward has run.
def demo_basic():
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        device = "cuda"

        # create the model and move it to the device
        model = ToyModel().to(device)
        for parameter in model.parameters():
            parameter.register_hook(hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=0.001)
        aot_print_module = aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)

        for i in range(1):
            optimizer.zero_grad()
            outputs = aot_print_module(torch.randn(20, 10).to(device))
            labels = torch.randn(20, 5).to(device)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            print(f"{os.getpid()}: iteration {i}, loss {loss}")

        clear_compile_cache()

    prof.export_chrome_trace("aot_1.json")

if __name__ == "__main__":
    demo_basic()
96 changes: 96 additions & 0 deletions aot_example_2.py
@@ -0,0 +1,96 @@
import os
import torch
import torch.fx as fx
import torch.nn as nn
import torch.optim as optim

from typing import List

import torchdynamo
from torchdynamo.optimizations import BACKENDS

from torch.profiler import profile, ProfilerActivity

from functorch.compile import aot_module, clear_compile_cache

class ToyModel1(nn.Module):
    def __init__(self):
        super(ToyModel1, self).__init__()
        self.net1 = nn.Linear(10, 10000)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.net1(x))

class ToyModel2(nn.Module):
    def __init__(self):
        super(ToyModel2, self).__init__()
        self.net2 = nn.Linear(10000, 10000)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.net2(x))

class ToyModel3(nn.Module):
    def __init__(self):
        super(ToyModel3, self).__init__()
        self.net3 = nn.Linear(10000, 10000)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.net3(x))

class ToyModel4(nn.Module):
    def __init__(self):
        super(ToyModel4, self).__init__()
        self.net4 = nn.Linear(10000, 5)

    def forward(self, x):
        return self.net4(x)

def hook(grad):
    print("gradient hook fired")
    return grad + 1
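
# A hedged variant of the hook above (an assumption, not part of the original
# example): a hook factory that tags each print with the owning parameter's
# name, which makes it easier to see how hook firings interleave with the
# compiled backward graphs of the chained AOT modules. Usage:
# for name, parameter in model.named_parameters():
#     parameter.register_hook(make_named_hook(name))
def make_named_hook(name):
    def named_hook(grad):
        print(f"gradient hook fired for {name}")
        return grad + 1
    return named_hook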

def compiler_fn(fx_module: torch.fx.GraphModule, _):
    # fx_module.graph.print_tabular()
    return fx_module

# An AOT example demonstrating that gradient hooks can fire in
# between the chained compiled AOT modules.
def demo_basic():
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        device = "cuda"

        # create the models and move them to the device
        models = []
        models.append(ToyModel1().to(device))
        models.append(ToyModel2().to(device))
        models.append(ToyModel3().to(device))
        models.append(ToyModel4().to(device))

        for model in models:
            for parameter in model.parameters():
                parameter.register_hook(hook)

        loss_fn = nn.MSELoss()
        aot_print_modules = []
        for model in models:
            aot_print_modules.append(aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn))

        for i in range(1):
            outputs = torch.randn(20, 10).to(device)
            for aot_print_module in aot_print_modules:
                outputs = aot_print_module(outputs)
            labels = torch.randn(20, 5).to(device)
            loss = loss_fn(outputs, labels)
            loss.backward()

            print(f"{os.getpid()}: iteration {i}, loss {loss}")

        clear_compile_cache()

    prof.export_chrome_trace("aot_2.json")

if __name__ == "__main__":
    demo_basic()
97 changes: 97 additions & 0 deletions aot_example_3.py
@@ -0,0 +1,97 @@
import os
import torch
import torch.fx as fx
import torch.nn as nn
import torch.optim as optim

from typing import List

import torchdynamo
from torchdynamo.optimizations import BACKENDS

from torch.profiler import profile, ProfilerActivity

from functorch.compile import aot_module, clear_compile_cache

class ToyModel1(nn.Module):
    def __init__(self):
        super(ToyModel1, self).__init__()
        self.net1 = nn.Linear(10, 10000)
        self.relu = nn.ReLU()

    # The extra "parameters" argument is intentionally unused in forward.
    def forward(self, x, parameters):
        return self.relu(self.net1(x))

class ToyModel2(nn.Module):
    def __init__(self):
        super(ToyModel2, self).__init__()
        self.net2 = nn.Linear(10000, 10000)
        self.relu = nn.ReLU()

    def forward(self, x, parameters):
        return self.relu(self.net2(x))

class ToyModel3(nn.Module):
    def __init__(self):
        super(ToyModel3, self).__init__()
        self.net3 = nn.Linear(10000, 10000)
        self.relu = nn.ReLU()

    def forward(self, x, parameters):
        return self.relu(self.net3(x))

class ToyModel4(nn.Module):
    def __init__(self):
        super(ToyModel4, self).__init__()
        self.net4 = nn.Linear(10000, 5)

    def forward(self, x, parameters):
        return self.net4(x)

def hook(grad):
    print("gradient hook fired")
    return grad + 1

def compiler_fn(fx_module: torch.fx.GraphModule, _):
    # fx_module.graph.print_tabular()
    return fx_module
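
# A hedged helper (an assumption, not part of the original example) that
# precomputes, for each chained AOT module, the parameter list of the next
# module in the chain, or None for the last one. The demo below builds the
# same values inline while forwarding the activations through the chain.
def next_module_parameters(modules):
    return [
        list(modules[j + 1].parameters()) if j + 1 < len(modules) else None
        for j in range(len(modules))
    ]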

# An AOT example demonstrating that gradient hook firing can be
# delayed when a module's parameters are passed as unused inputs
# to the previous compiled AOT module in the chain.
def demo_basic():
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        device = "cuda"

        # create the models and move them to the device
        models = []
        models.append(ToyModel1().to(device))
        models.append(ToyModel2().to(device))
        models.append(ToyModel3().to(device))
        models.append(ToyModel4().to(device))

        for model in models:
            for parameter in model.parameters():
                parameter.register_hook(hook)

        loss_fn = nn.MSELoss()
        aot_print_modules = []
        for model in models:
            aot_print_modules.append(aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn))

        for i in range(1):
            outputs = torch.randn(20, 10).to(device)
            for j in range(len(aot_print_modules)):
                # Pass the next module's parameters as an (unused) extra input,
                # or None for the last module in the chain.
                outputs = aot_print_modules[j](
                    outputs,
                    list(aot_print_modules[j + 1].parameters())
                    if j + 1 < len(aot_print_modules)
                    else None,
                )
            labels = torch.randn(20, 5).to(device)
            loss = loss_fn(outputs, labels)
            loss.backward()

            print(f"{os.getpid()}: iteration {i}, loss {loss}")

        clear_compile_cache()

    prof.export_chrome_trace("aot_3.json")

if __name__ == "__main__":
    demo_basic()