In [6]:
import torch                     
import torch.nn as nn            
import torch.nn.functional as F
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [12]:
# Remove Inductor trace artifacts left by earlier runs (they are written under
# ./torch_compile_debug), so the directory only holds output from this session.
# shutil.rmtree(..., ignore_errors=True) mirrors `rm -rf` (no error if the
# directory is missing). It is plain Python, so it also works outside IPython
# and on non-POSIX systems.
import shutil

shutil.rmtree("torch_compile_debug", ignore_errors=True)

In [13]:
class SimpleModel(nn.Module):
    """Minimal conv net used to inspect what ``torch.compile`` generates.

    The network is a single ``Conv2d`` followed by flattening every dimension
    after the batch dimension. With the default arguments and an
    ``(N, 1, 32, 32)`` input, the output has shape ``(N, 6 * 28 * 28)``,
    i.e. ``(N, 4704)``.

    Args:
        in_channels: number of channels in the input image (default 1).
        out_channels: number of convolution filters (default 6).
        kernel_size: convolution kernel size (default 5).
    """

    def __init__(self, in_channels=1, out_channels=6, kernel_size=5):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """Apply the convolution and flatten the result to ``(batch, features)``."""
        x = self.conv(x)
        return torch.flatten(x, 1)

In [14]:
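# NOTE: alternative SimpleModel kept for reference only; it is not executed.
# With the (1, 1, 32, 32) input used below, the conv output flattens to
# 6 * 28 * 28 = 4704 features, so `fc` would need nn.Linear(4704, 10).
# 1176 = 6 * 14 * 14 only matches a 2x2-pooled (or 18 x 18) input.
# class SimpleModel(nn.Module):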
#     def __init__(self):
#         super(SimpleModel, self).__init__()
#         self.conv = nn.Conv2d(1, 6, 5)
#         self.fc = nn.Linear(1176, 10)

#     def forward(self, x):
#         x = F.relu(self.conv(x))
#         x = torch.flatten(x,1)
#         x = F.relu(self.fc(x))
#         return x

In [15]:
# Compile the model with Inductor and ask it to dump trace artifacts, including
# an SVG diagram of the graph it compiled (see the path in the output below).
torch.manual_seed(0)  # reproducible weights and input across re-runs
model = SimpleModel().to(device)
cmodel = torch.compile(model, backend="inductor",
                       options={'trace.graph_diagram': True,
                                'trace.enabled': True})

# Allocate the input directly on the target device instead of creating it on the
# CPU and copying it over with .to(device).
input_tensor = torch.randn(1, 1, 32, 32, device=device)
out = cmodel(input_tensor)
print(out)

# For comparison, draw the plain FX graph of the eager model (symbolic tracing
# only, no Dynamo/AOTAutograd) to model_graph.svg. Needs pydot and graphviz.
from torch.fx import passes, symbolic_trace
model_trace = symbolic_trace(model)

g = passes.graph_drawer.FxGraphDrawer(model_trace, 'fn')
with open("model_graph.svg", "wb") as f:
    f.write(g.get_dot_graph().create_svg())



Writing FX graph to file: /pytorch-examples/pytorch-torchcompile/torch_compile_debug/run_2023_03_15_20_01_10_252039-pid_373/aot_torchinductor/model__2_forward_3.2/graph_diagram.svg
tensor([[ 0.7080, -0.4106, -0.7436,  ...,  0.4363,  0.2918, -0.6914]],
       device='cuda:0', grad_fn=<AsStridedBackward0>)


In [None]:
import torch._dynamo
from torch._functorch.aot_autograd import aot_module_simplified

def toy_backend(gm, sample_inputs):
    """Toy ``torch.compile`` backend that renders the AOTAutograd graphs to SVG.

    Dynamo hands this backend an FX ``GraphModule`` plus example inputs. They are
    forwarded to AOTAutograd, which calls the compilers below with the traced
    forward/backward graphs. Each compiler draws the graph it receives and then
    returns that graph's own ``forward``, so execution is unchanged (no real
    code generation happens).

    Args:
        gm: FX GraphModule captured by TorchDynamo.
        sample_inputs: example inputs Dynamo used while tracing ``gm``.

    Returns:
        The callable produced by ``aot_module_simplified``.
    """
    from torch.fx import passes

    def _make_drawing_compiler(svg_path):
        # Build an AOTAutograd compiler that writes the graph it receives to svg_path.
        def _compiler(graph_module, example_inputs):
            drawer = passes.graph_drawer.FxGraphDrawer(graph_module, 'fn')
            with open(svg_path, "wb") as f:
                f.write(drawer.get_dot_graph().create_svg())
            return graph_module.forward

        return _compiler

    # Give the backward graph its own file. When bw_compiler is omitted,
    # aot_module_simplified reuses fw_compiler for the backward graph, so a
    # later .backward() would overwrite the forward diagram in "biscuit.svg".
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=_make_drawing_compiler("biscuit.svg"),
        bw_compiler=_make_drawing_compiler("biscuit_bw.svg"),
    )

# Clear Dynamo's compiled-code caches so compilations from earlier cells don't
# interfere with the one below.
torch._dynamo.reset()
model = SimpleModel().to(device)
# Create the input directly on the target device rather than on the CPU followed
# by a .to(device) copy.
input_tensor = torch.randn(1, 1, 32, 32, device=device)
cmodel = torch.compile(model, backend=toy_backend)
# torch.compile is lazy: this first call traces the model and invokes
# toy_backend, which writes the forward graph diagram to biscuit.svg.
out = cmodel(input_tensor)

In [15]:
# Ask Dynamo to explain how it traces `model`. With this torch version the result
# is a tuple; element 1 is the list of guards (see the output below).
e = torch._dynamo.explain(model, input_tensor)
guards = e[1]
print(guards)

torch.Size([1, 1176])
[{Guard(name='torch', source=<GuardSource.GLOBAL: 1>, create_fn=<function GuardBuilder.FUNCTION_MATCH at 0x7f8ef5efbeb0>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None), Guard(name='self.conv', source=<GuardSource.LOCAL_NN_MODULE: 2>, create_fn=<function GuardBuilder.NN_MODULE at 0x7f8ef5efbe20>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None), Guard(name='self', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.NN_MODULE at 0x7f8ef5efbe20>, is_volatile=False, guard_types=['ID_MATCH'], code_list=['___check_obj_id(self, 140251669772896)'], obj_weakref=<weakref at 0x7f8ee2cf3f60; to 'SimpleModel' at 0x7f8ee2f4ba60>, guarded_class_weakref=<weakref at 0x7f8ee12a2340; to 'type' at 0x2f97330 (SimpleModel)>), Guard(name='print', source=<GuardSource.GLOBAL: 1>, create_fn=<function GuardBuilder.BUILTIN_MATCH at 0x7f8ef5efbf40>, is_volatile=False, guard_types

In [14]:
# Render the FX graph of the eager `model` (plain symbolic tracing, no Dynamo)
# to asdf_graph.svg.
from torch.fx import passes, symbolic_trace

model_trace = symbolic_trace(model)
g = passes.graph_drawer.FxGraphDrawer(model_trace, 'fn')

with open("asdf_graph.svg", "wb") as f:
    dot_graph = g.get_dot_graph()
    svg_bytes = dot_graph.create_svg()
    f.write(svg_bytes)

In [9]:
from ctypes import c_void_p, c_long
import torch
import math
import random
# `device` is deliberately NOT imported from torch here, although the generated
# code originally imported it. This cell shares the notebook namespace with the
# cell that sets `device = torch.device('cuda:0' ...)`, and rebinding the name to
# the torch.device class would break every later `.to(device)`. The code below
# only uses `device=` as a keyword argument, so it does not need the name.
from torch import empty_strided, as_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
# Handle used below to compile the Triton kernel source asynchronously.
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


# Inductor-generated Triton kernel. Its source is passed as a string and compiled
# via AsyncCompile, so the string is runtime code and must not be edited.
# It is a pointwise, in-place add: for each flat index x < 4704 (= 1 * 6 * 28 * 28,
# matching buf0's shape in `call`), it adds in_ptr0[x // 784] to in_out_ptr0[x].
# Since 784 = 28 * 28 elements per channel, one value is broadcast per output
# channel. In `call` it receives the convolution output and primals_2
# (shape (6,)). That tensor is presumably the conv bias, because
# aten.convolution is called with bias=None.
triton__0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[8192], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4704
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 784)
    tmp0 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp1 = tl.load(in_ptr0 + (x1), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(in_out_ptr0 + (x2 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
''')


# Presumably blocks until the asynchronously compiled kernels are ready and binds
# them into this namespace, so `triton__0` becomes runnable. The compiler handle
# is then deleted; a later cell re-creates one for exploration.
async_compile.wait(globals())
del async_compile

def call(args):
    """Run the Inductor-compiled forward graph of SimpleModel.

    Args:
        args: list ``[primals_1, primals_2, primals_3]``. Judging by the shapes
            used in the ``__main__`` block of this cell, these are the conv weight
            (6, 1, 5, 5), a (6,) per-channel tensor (presumably the conv bias) and
            the input (1, 1, 32, 32). The list is cleared in place, presumably so
            the caller's references to the inputs are dropped early.

    Returns:
        Tuple whose first element is the flattened (1, 4704) output view. The
        remaining entries (the conv output buffer, weight and input) appear to be
        kept for the backward pass -- TODO confirm.
    """
    primals_1, primals_2, primals_3 = args
    args.clear()
    # All work below runs on CUDA device 0.
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Convolution of input primals_3 with weight primals_1, bias=None, stride 1,
        # no padding; the bias is added afterwards by the fused Triton kernel.
        buf0 = aten.convolution(primals_3, primals_1, None, (1, 1), (0, 0), (1, 1), False, (0, 0), 1)
        # Verify the convolution output has the size/stride the kernel assumes.
        assert_size_stride(buf0, (1, 6, 28, 28), (4704, 784, 28, 1))
        buf1 = buf0; del buf0  # reuse
        stream0 = get_cuda_stream(0)
        # In-place add of primals_2 per channel over all 4704 elements of buf1.
        triton__0.run(buf1, primals_2, 4704, grid=grid(4704), stream=stream0)
        del primals_2
        # Flatten (1, 6, 28, 28) -> (1, 4704) as a strided view, no copy.
        return (as_strided(buf1, (1, 4704), (4704, 1)), buf1, primals_1, primals_3, )


# Benchmark harness emitted by Inductor alongside the compiled code. In a notebook
# `__name__` is "__main__", so this runs whenever the cell is executed.
# It builds random inputs with the shapes/strides `call` expects (conv weight,
# per-channel tensor, input image) and prints the measured time per call. The
# unit is presumably seconds, e.g. the 0.000704 shown in the output.
# Requires a CUDA device, because the tensors are hard-coded to 'cuda:0'.
if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((6, 1, 5, 5), (25, 25, 5, 1), device='cuda:0', dtype=torch.float32)
    primals_2 = rand_strided((6, ), (1, ), device='cuda:0', dtype=torch.float32)
    primals_3 = rand_strided((1, 1, 32, 32), (1024, 1024, 32, 1), device='cuda:0', dtype=torch.float32)
    print_performance(lambda: call([primals_1, primals_2, primals_3]))


0.000704


In [11]:
# `async_compile` was deleted after the kernels above finished compiling, so
# create a fresh handle to explore it. The original line ended in a stray `.`
# (presumably a tab-completion leftover), which made the cell a SyntaxError.
# Use `async_compile.triton?` to read the Triton compile entry point's docs.
async_compile = AsyncCompile()
async_compile

<torch._inductor.codecache.AsyncCompile at 0x7f2e8e8f12a0>