diff --git a/aot_example.py b/aot_example.py
new file mode 100644
index 0000000000..792ea57f90
--- /dev/null
+++ b/aot_example.py
@@ -0,0 +1,69 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, ProfilerActivity
+
+from functorch.compile import aot_module, clear_compile_cache
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.net2 = nn.Linear(10000, 10000)
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        output1 = self.relu(self.net1(x))
+        output2 = self.relu(self.net2(output1))
+        output3 = self.relu(self.net3(output2))
+        return self.net4(output3)
+
+def hook(grad):
+    print("gradient hook fired")
+    return grad + 1
+
+def compiler_fn(fx_module: torch.fx.GraphModule, _):
+    # fx_module.graph.print_tabular()
+    return fx_module
+
+# A basic AOT example to demonstrate that all gradient hooks
+# fire after the compiled AOT module's backward pass.
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        device = "cuda"
+
+        # create the model and move it to the device
+        model = ToyModel().to(device)
+        for parameter in model.parameters():
+            parameter.register_hook(hook)
+
+        loss_fn = nn.MSELoss()
+        optimizer = optim.SGD(model.parameters(), lr=0.001)
+        aot_print_module = aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
+
+        for i in range(1):
+            optimizer.zero_grad()
+            outputs = aot_print_module(torch.randn(20, 10).to(device))
+            labels = torch.randn(20, 5).to(device)
+            loss = loss_fn(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+        clear_compile_cache()
+
+    prof.export_chrome_trace("aot_1.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/aot_example_2.py b/aot_example_2.py
new file mode 100644
index 0000000000..8b4ace7344
--- /dev/null
+++ b/aot_example_2.py
@@ -0,0 +1,98 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, ProfilerActivity
+
+from functorch.compile import aot_module, clear_compile_cache
+
+class ToyModel1(nn.Module):
+    def __init__(self):
+        super(ToyModel1, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(self.net1(x))
+
+class ToyModel2(nn.Module):
+    def __init__(self):
+        super(ToyModel2, self).__init__()
+        self.net2 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(self.net2(x))
+
+class ToyModel3(nn.Module):
+    def __init__(self):
+        super(ToyModel3, self).__init__()
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        return self.relu(self.net3(x))
+
+class ToyModel4(nn.Module):
+    def __init__(self):
+        super(ToyModel4, self).__init__()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        return self.net4(x)
+
+def hook(grad):
+    print("gradient hook fired")
+    return grad + 1
+
+def compiler_fn(fx_module: torch.fx.GraphModule, _):
+    # fx_module.graph.print_tabular()
+    return fx_module
+
+# An AOT example to demonstrate that gradient hooks can be
+# fired in between chained compiled AOT modules.
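+# (Presumably the hooks interleave because each stage is compiled separately,
+# so a stage's parameter hooks fire as soon as that stage's backward runs.)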
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        device = "cuda"
+
+        # create the models and move them to the device
+        models = []
+        models.append(ToyModel1().to(device))
+        models.append(ToyModel2().to(device))
+        models.append(ToyModel3().to(device))
+        models.append(ToyModel4().to(device))
+
+        for model in models:
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+        loss_fn = nn.MSELoss()
+        aot_print_modules = []
+        for model in models:
+            aot_print_modules.append(aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn))
+
+        for i in range(1):
+            outputs = torch.randn(20, 10).to(device)
+            for aot_print_module in aot_print_modules:
+                outputs = aot_print_module(outputs)
+            labels = torch.randn(20, 5).to(device)
+            loss = loss_fn(outputs, labels)
+            loss.backward()
+
+            print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+        clear_compile_cache()
+
+    prof.export_chrome_trace("aot_2.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/aot_example_3.py b/aot_example_3.py
new file mode 100644
index 0000000000..188e735b54
--- /dev/null
+++ b/aot_example_3.py
@@ -0,0 +1,100 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, ProfilerActivity
+
+from functorch.compile import aot_module, clear_compile_cache
+
+class ToyModel1(nn.Module):
+    def __init__(self):
+        super(ToyModel1, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, parameters):
+        return self.relu(self.net1(x))
+
+class ToyModel2(nn.Module):
+    def __init__(self):
+        super(ToyModel2, self).__init__()
+        self.net2 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, parameters):
+        return self.relu(self.net2(x))
+
+class ToyModel3(nn.Module):
+    def __init__(self):
+        super(ToyModel3, self).__init__()
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+
+    def forward(self, x, parameters):
+        return self.relu(self.net3(x))
+
+class ToyModel4(nn.Module):
+    def __init__(self):
+        super(ToyModel4, self).__init__()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x, parameters):
+        return self.net4(x)
+
+def hook(grad):
+    print("gradient hook fired")
+    return grad + 1
+
+def compiler_fn(fx_module: torch.fx.GraphModule, _):
+    # fx_module.graph.print_tabular()
+    return fx_module
+
+# An AOT example to demonstrate that gradient hook firing can be
+# delayed when parameters are passed as unused inputs to the
+# preceding compiled AOT module in the chain.
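+# (A possible explanation: aot_module traces the extra parameters as inputs of
+# the previous stage's compiled graph, so autograd treats them as used there
+# and only fires their hooks once that stage's compiled backward completes.)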
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        device = "cuda"
+
+        # create the models and move them to the device
+        models = []
+        models.append(ToyModel1().to(device))
+        models.append(ToyModel2().to(device))
+        models.append(ToyModel3().to(device))
+        models.append(ToyModel4().to(device))
+
+        for model in models:
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+        loss_fn = nn.MSELoss()
+        aot_print_modules = []
+        for model in models:
+            aot_print_modules.append(aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn))
+
+        for i in range(1):
+            outputs = torch.randn(20, 10).to(device)
+            for j in range(len(aot_print_modules)):
+                outputs = aot_print_modules[j](outputs, list(aot_print_modules[j + 1].parameters()) if j + 1 < len(aot_print_modules) else None)
+            labels = torch.randn(20, 5).to(device)
+            loss = loss_fn(outputs, labels)
+            loss.backward()
+
+            print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+        clear_compile_cache()
+
+    prof.export_chrome_trace("aot_3.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/dynamo_example.py b/dynamo_example.py
new file mode 100644
index 0000000000..8ef00757bc
--- /dev/null
+++ b/dynamo_example.py
@@ -0,0 +1,157 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.net2 = nn.Linear(10000, 10000)
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        return self.net4(self.relu(self.net3(self.relu(self.net2(self.relu(self.net1(x)))))))
+
+# A very rough PoC that splits the input fx graph into small graphs, and then
+# stitches the compiled AOT graphs back into one.
+def graph_break_compiler(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
+    print("graph_break_compiler() called with FX graph:")
+    gm.graph.print_tabular()
+    print()
+
+    # 1. Splitting the gm into small graphs. Currently, we just naively split
+    # the graph every 3 ops.
+    def insert_node(graph: fx.Graph, node: fx.Node):
+        if len(graph.nodes) == 0:
+            graph.create_node(node.op, node.target, node.args, node.kwargs, node.name, node.type)
+        else:
+            with graph.inserting_before(list(graph.nodes)[0]):
+                graph.node_copy(node)
+
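+    # For example, a traced chain of 7 call ops (net1, relu, net2, relu, net3,
+    # relu, net4) would be split into sub-graphs of at most 3 ops each, working
+    # backwards from the output. (Illustrative only.)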
+    graphs = list()
+    magic_number = 3  # magic number to break the graph
+    count = magic_number
+    outputs = {}
+    for node in reversed(gm.graph.nodes):
+        # Note that we don't clear the outputs from the last graph,
+        # given they are also the arguments for the output node.
+        if count == magic_number:
+            count = 0
+            graphs.insert(0, fx.Graph())
+            if len(graphs) > 1:
+                # set the output of the new graph
+                graphs[0].output(tuple(value for _, value in outputs.items()))
+                # set the input of the previous graph
+                # reverse the order to make the order match the above output
+                for key, _ in reversed(outputs.items()):
+                    with graphs[1].inserting_before(list(graphs[1].nodes)[0]):
+                        graphs[1].placeholder(key)
+
+        if node.op != "output" and node.op != "placeholder":
+            count = count + 1
+
+        insert_node(graphs[0], node)
+        if node.name in outputs:
+            outputs.pop(node.name)
+        # TODO: Deal with default arguments?
+        for arg in node.args:
+            # Somehow the output node's args is a tuple of tuples.
+            if type(arg) is tuple:
+                for a in arg:
+                    outputs[str(a)] = a
+                continue
+            outputs[str(arg)] = arg
+        # TODO: Do we care about the kwargs?
+        # for key, value in node.kwargs.items():
+        #     outputs.add(value)
+
+    print("graph_break_compiler() called with split graphs:")
+    for graph in graphs:
+        graph.print_tabular()
+    print()
+
+    # 2. Compiling the split graphs using AOT.
+    gms = [fx.GraphModule(gm, graph) for graph in graphs]
+    aot_compileds = []
+    for g in gms:
+        aot_compiled = BACKENDS["aot_autograd"](g, None)
+        assert aot_compiled is not None, "aot compilation failed"
+        aot_compileds.append(aot_compiled)
+
+    print(f"AOT compiled all {len(aot_compileds)} modules\n")
+
+    # 3. Stitching the compiled graphs back into a single fx gm to return.
+    assert len(aot_compileds) == len(graphs)
+    final_graph = fx.Graph()
+    last_aot = None
+    for i in range(len(aot_compileds)):
+        arguments = list()
+        j = 0
+        for node in graphs[i].nodes:
+            if node.op != "placeholder":
+                break
+            if i == 0:
+                last_node = final_graph.node_copy(node)
+            else:
+                assert last_aot is not None
+                last_node = final_graph.call_method("__getitem__", (last_aot, j))
+            j = j + 1
+            arguments.append(last_node)
+
+        last_aot = final_graph.call_function(aot_compileds[i].forward, tuple(arguments))
+    final_graph.output(last_aot)
+    final_graph_module = fx.GraphModule(gm, final_graph)
+
+    print("graph_break_compiler() called with stitched graph:")
+    final_graph_module.graph.print_tabular()
+    print()
+
+    return final_graph_module.forward  # return a python callable
+
+def hook(grad):
+    print("gradient hook fired")
+    # Return the modified gradient so the hook's effect is observable.
+    return grad + 1
+
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with torchdynamo.optimize(graph_break_compiler):
+            device = "cuda"
+            # device = "cpu"
+
+            # create the model and move it to the device
+            model = ToyModel().to(device)
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+            loss_fn = nn.MSELoss()
+            optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+            for i in range(1):
+                optimizer.zero_grad()
+                outputs = model(torch.randn(20, 10).to(device))
+                labels = torch.randn(20, 5).to(device)
+                loss = loss_fn(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+    prof.export_chrome_trace("trace.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/dynamo_example2.py b/dynamo_example2.py
new file mode 100644
index 0000000000..aedb0942f0
--- /dev/null
+++ b/dynamo_example2.py
@@ -0,0 +1,177 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.net2 = nn.Linear(10000, 10000)
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        return self.net4(self.relu(self.net3(self.relu(self.net2(self.relu(self.net1(x)))))))
+
+class GraphProducer(object):
+    def __init__(self):
+        self.possible_input_names: set[str] = set()
+        self.inputs: list[str] = []  # a stable ordering of the set above, since set iteration order is arbitrary
+        self.code: list[fx.Node] = []
+        self.possible_output_names: set[str] = set()
+        self.outputs: list[fx.Node] = []
+        self._graph: fx.Graph = None
+
+    def graph(self):
+        if self._graph is not None:
+            return self._graph
+
+        self._graph = fx.Graph()
+        for input in self.inputs:
+            self._graph.placeholder(input)
+        for code in self.code:
+            self._graph.node_copy(code)
+        self._graph.output(tuple(self.outputs))
+        return self._graph
+
+# A very rough PoC that splits the input fx graph into small graphs, and then
+# stitches the compiled AOT graphs back into one.
+def graph_break_compiler(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
+    print("graph_break_compiler() called with FX graph:")
+    gm.graph.print_tabular()
+    print()
+
+    # 1. Splitting the gm into small graphs. Currently, we just naively split
+    # the graph every 3 ops.
+    magic_number = 3  # magic number to break the graph
+    count = 0
+    graphs = [GraphProducer()]
+    # Init the first graph's possible outputs from the full graph's output node.
+    # Somehow the args for the output node is a tuple of args.
+    graphs[0].possible_output_names = {str(arg) for arg in list(gm.graph.nodes)[-1].args[0]}
+    for node in reversed(gm.graph.nodes):
+        # Note that we don't clear the outputs from the last graph,
+        # given they are also the arguments for the output node.
+        if count == magic_number:
+            assert len(graphs) > 0
+            graphs[0].inputs = list(graphs[0].possible_input_names)
+            count = 0
+            graphs.insert(0, GraphProducer())
+
+            # Set the possible outputs of the new graph.
+            # Any remaining possible outputs of the next graph could either be
+            # produced by the current graph or be inputs of the full graph.
+            # Therefore, carry them over.
+            graphs[0].possible_output_names = set(graphs[1].possible_output_names)  # copy, don't alias
+            for input in graphs[1].inputs:
+                graphs[0].possible_output_names.add(input)
+
+        if node.op == "output" or node.op == "placeholder":
+            continue
+
+        count = count + 1
+        graphs[0].code.insert(0, node)
+        if node.name in graphs[0].possible_output_names:
+            graphs[0].outputs.append(node)
+            graphs[0].possible_output_names.remove(node.name)
+
+        if node.name in graphs[0].possible_input_names:
+            graphs[0].possible_input_names.remove(node.name)
+
+        # TODO: Deal with default arguments?
+        for arg in node.args:
+            graphs[0].possible_input_names.add(str(arg))
+        # TODO: Do we care about the kwargs?
+        # for key, value in node.kwargs.items():
+        #     outputs.add(value)
+    graphs[0].inputs = list(graphs[0].possible_input_names)
+
+    print("graph_break_compiler() called with split graphs:")
+    for graph in graphs:
+        graph.graph().print_tabular()
+    print()
+
+    # 2. Compiling the split graphs using AOT.
+    gms = [fx.GraphModule(gm, graph.graph()) for graph in graphs]
+    aot_compileds = []
+    for g in gms:
+        aot_compiled = BACKENDS["aot_autograd"](g, None)
+        assert aot_compiled is not None, "aot compilation failed"
+        aot_compileds.append(aot_compiled)
+
+    print(f"AOT compiled all {len(aot_compileds)} modules\n")
+
+    # 3. Stitching the compiled graphs back into a single fx gm to return.
+    assert len(aot_compileds) == len(graphs)
+    final_graph = fx.Graph()
+    arg_map = {}  # To keep track of all possible inputs to sub-graphs.
+
+    for node in gm.graph.nodes:
+        if node.op != "placeholder":
+            break
+        last_node = final_graph.node_copy(node)
+        arg_map[last_node.name] = last_node
+
+    # Call each AOT compiled sub-module with the args in arg_map that match the recorded input names of the corresponding graph.
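+    # (Each compiled sub-module returns a tuple of outputs, so we index into it
+    # with __getitem__ and reuse the original node names below.)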
+    for i in range(len(graphs)):
+        output = final_graph.call_function(aot_compileds[i].forward, tuple([arg_map[input] for input in graphs[i].inputs]))
+        # Unpack the outputs and name them properly so that they can be fetched correctly by subsequent graphs.
+        for j in range(len(graphs[i].outputs)):
+            getitem = final_graph.call_method("__getitem__", (output, j))
+            getitem.name = graphs[i].outputs[j].name
+            arg_map[getitem.name] = getitem
+
+    final_graph.node_copy(list(gm.graph.nodes)[-1])
+    final_graph_module = fx.GraphModule(gm, final_graph)
+
+    print("graph_break_compiler() called with stitched graph:")
+    final_graph_module.graph.print_tabular()
+    print()
+
+    return final_graph_module.forward  # return a python callable
+
+def hook(grad):
+    print("gradient hook fired")
+    # Return the modified gradient so the hook's effect is observable.
+    return grad + 1
+
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with torchdynamo.optimize(graph_break_compiler):
+            device = "cuda"
+            # device = "cpu"
+
+            # create the model and move it to the device
+            model = ToyModel().to(device)
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+            loss_fn = nn.MSELoss()
+            optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+            for i in range(1):
+                optimizer.zero_grad()
+                outputs = model(torch.randn(20, 10).to(device))
+                labels = torch.randn(20, 5).to(device)
+                loss = loss_fn(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+    prof.export_chrome_trace("new_trace.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/dynamo_example_eager_backend.py b/dynamo_example_eager_backend.py
new file mode 100644
index 0000000000..07993b6289
--- /dev/null
+++ b/dynamo_example_eager_backend.py
@@ -0,0 +1,70 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.net2 = nn.Linear(10000, 10000)
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        return self.net4(self.relu(self.net3(self.relu(self.net2(self.relu(self.net1(x)))))))
+
+
+def graph_break_compiler(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
+    print("graph_break_compiler() called with FX graph:")
+    gm.graph.print_tabular()
+    print()
+
+    return gm.forward  # return a python callable
+
+def hook(grad):
+    print("gradient hook fired")
+    # Return the modified gradient so the hook's effect is observable.
+    return grad + 1
+
+# An example to demonstrate that gradient hooks fire correctly
+# with the dynamo eager backend.
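+# (Since the backend just returns gm.forward, autograd runs the graph eagerly,
+# so the hooks should fire exactly as in uncompiled eager mode.)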
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with torchdynamo.optimize(graph_break_compiler):
+            device = "cuda"
+            # device = "cpu"
+
+            # create the model and move it to the device
+            model = ToyModel().to(device)
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+            loss_fn = nn.MSELoss()
+            optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+            for i in range(1):
+                optimizer.zero_grad()
+                outputs = model(torch.randn(20, 10).to(device))
+                labels = torch.randn(20, 5).to(device)
+                loss = loss_fn(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+    prof.export_chrome_trace("eager_backend.json")
+
+if __name__ == "__main__":
+    demo_basic()
diff --git a/dynamo_example_print.py b/dynamo_example_print.py
new file mode 100644
index 0000000000..b394815148
--- /dev/null
+++ b/dynamo_example_print.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import torch.fx as fx
+import torch.nn as nn
+import torch.optim as optim
+
+from typing import List
+
+import torchdynamo
+from torchdynamo.optimizations import BACKENDS
+
+from torch.profiler import profile, record_function, ProfilerActivity
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10000)
+        self.net2 = nn.Linear(10000, 10000)
+        self.net3 = nn.Linear(10000, 10000)
+        self.relu = nn.ReLU()
+        self.net4 = nn.Linear(10000, 5)
+
+    def forward(self, x):
+        output1 = self.relu(self.net1(x))
+        print()
+        output2 = self.relu(self.net2(output1))
+        print()
+        output3 = self.relu(self.net3(output2))
+        print()
+        return self.net4(output3)
+
+def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
+    print("my_compiler() called with FX graph:")
+    gm.graph.print_tabular()
+
+    compiled = BACKENDS["aot_autograd"](gm, example_inputs)
+    if compiled is not None:
+        print("aot compiled")
+        return compiled
+
+    return gm.forward  # return a python callable
+
+def hook(grad):
+    print("gradient hook fired")
+    # Return the modified gradient so the hook's effect is observable.
+    return grad + 1
+
+# An example to demonstrate that manual graph breaks work: the print() calls in the forward function above force the breaks.
+def demo_basic():
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with torchdynamo.optimize(my_compiler):
+            device = "cuda"
+
+            # create the model and move it to the device
+            model = ToyModel().to(device)
+            for parameter in model.parameters():
+                parameter.register_hook(hook)
+
+            loss_fn = nn.MSELoss()
+            optimizer = optim.SGD(model.parameters(), lr=0.001)
+
+            for i in range(1):
+                optimizer.zero_grad()
+                outputs = model(torch.randn(20, 10).to(device))
+                labels = torch.randn(20, 5).to(device)
+                loss = loss_fn(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                print(f"{os.getpid()}: iteration {i}, loss {loss}")
+
+    prof.export_chrome_trace("manual.json")
+
+if __name__ == "__main__":
+    demo_basic()