In [2]:
import torch
from typing import List
import torch.nn as nn
import torch.fx as fx

In [3]:
class Model(nn.Module):
    """Bias-free multilayer perceptron used as the tracing/compilation example.

    Submodules of ``self.layers`` (the indices are what traced graphs refer to
    as ``layers.0`` ... ``layers.4``):

        0: Linear(in_dim -> h_dim, bias=False)
        1: ReLU
        2: Linear(h_dim -> h_dim, bias=False)
        3: ReLU
        4: Linear(h_dim -> out_dim, bias=False)

    Args:
        in_dim: size of the last dimension of the input.
        h_dim: width of both hidden layers.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, h_dim, out_dim):
        super().__init__()
        # Built in the same order as the layer listing above, so parameter
        # initialisation consumes the RNG identically.
        stages = [
            nn.Linear(in_dim, h_dim, bias=False),
            nn.ReLU(),
            nn.Linear(h_dim, h_dim, bias=False),
            nn.ReLU(),
            nn.Linear(h_dim, out_dim, bias=False),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through ``self.layers`` and return the result as a 1-tuple.

        NOTE(review): the tuple wrapper is presumably there so traced / exported
        graphs have a tuple output — keep it unless downstream cells change.
        """
        out = self.layers(x)
        return (out,)


In [4]:
m = Model(128, 32, 16)
# Smoke-test one eager forward pass before tracing.
x = m(torch.rand((128,)))
# Symbolically trace the module; `graph` holds the resulting fx.GraphModule.
graph = fx.symbolic_trace(m)
# print_readable() prints the generated forward code AND returns it as a
# string; the trailing `;` stops the notebook from echoing that string again.
graph.print_readable();


class Model(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        layers_0 = getattr(self.layers, "0")(x);  x = None
        layers_1 = getattr(self.layers, "1")(layers_0);  layers_0 = None
        layers_2 = getattr(self.layers, "2")(layers_1);  layers_1 = None
        layers_3 = getattr(self.layers, "3")(layers_2);  layers_2 = None
        layers_4 = getattr(self.layers, "4")(layers_3);  layers_3 = None
        return (layers_4,)
        


'class Model(torch.nn.Module):\n    def forward(self, x):\n        # No stacktrace found for following nodes\n        layers_0 = getattr(self.layers, "0")(x);  x = None\n        layers_1 = getattr(self.layers, "1")(layers_0);  layers_0 = None\n        layers_2 = getattr(self.layers, "2")(layers_1);  layers_1 = None\n        layers_3 = getattr(self.layers, "3")(layers_2);  layers_2 = None\n        layers_4 = getattr(self.layers, "4")(layers_3);  layers_3 = None\n        return (layers_4,)\n        '

In [7]:
from torch._decomp import core_aten_decompositions
import torch._dynamo
from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module
import pprint

def toy_backend(gm: fx.GraphModule, sample_inputs, decompositions=None):
    """Toy ``torch.compile`` backend that routes the Dynamo graph through AOTAutograd.

    AOTAutograd lowers ``gm`` to an ATen-level FX graph and passes it to the
    nested ``my_compiler``, which prints the graph (plus the attributes of its
    output node) and returns it unchanged, so the lowered graph is what runs.

    Args:
        gm: FX graph module handed over by ``torch.compile``.
        sample_inputs: example inputs the graph was captured with.
        decompositions: optional decomposition table forwarded to
            ``aot_module_simplified``. ``None`` (the default) keeps
            AOTAutograd's default behavior; bind e.g.
            ``core_aten_decompositions()`` via ``functools.partial`` to lower
            further.

    Returns:
        The callable produced by ``aot_module_simplified``.
    """
    def my_compiler(gm, sample_inputs):
        # Placeholder "compiler": inspect the ATen graph, then return it as-is.
        print("Decomposed fx Graph in Aten IR:")
        print(gm.graph)
        for node in gm.graph.nodes:
            if node.op == 'output':
                pprint.pprint(vars(node), indent=4)
        return gm

    # Invoke AOTAutograd. No bw_compiler is given; NOTE(review): AOTAutograd
    # presumably reuses fw_compiler for the backward graph — confirm per version.
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=my_compiler,
        decompositions=decompositions,
    )

# Always compile the underlying eager module: torch.compile returns an
# OptimizedModule exposing the original as `_orig_mod`, so re-running this cell
# does not wrap an already-compiled module in a second compile layer.
m = torch.compile(getattr(m, "_orig_mod", m), backend=toy_backend)
x = m(torch.rand((1, 128)))

Decomposed fx Graph in Aten IR:
graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=2] = placeholder[target=primals_4]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%primals_1,), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%primals_4, %t), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
    %t_1 : [num_users=2] = call_function[target=torch.ops.aten.t.default](args = (%primals_2,), kwargs = {})
    %mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu, %t_1), kwargs = {})
    %relu_1 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%mm_1,), kwargs = {})
    %t_2 : [num_users=2] = call_function[target=torc

In [7]:
from torch.export import export 
ex_in = torch.rand((128,))
m = Model(128, 32, 16)
prog = export(m, args=(ex_in,))
# Sanity check: the exported program must reproduce the eager module's output.
assert torch.allclose(prog.module()(ex_in)[0], m(ex_in)[0])
# NOTE(review): the saved output of this cell shows only two aten.linear nodes
# and no relu, which does not match the three-Linear / two-ReLU Model defined
# earlier — it appears stale (likely from an older Model); re-run to refresh.
print(prog.graph)

graph():
    %p_layers_0_weight : [num_users=1] = placeholder[target=p_layers_0_weight]
    %p_layers_1_weight : [num_users=1] = placeholder[target=p_layers_1_weight]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_layers_0_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_layers_1_weight), kwargs = {})
    return (linear_1,)
