In [1]:
import torch
import torch._dynamo
from torchvision import models
from torch.profiler import profile, record_function, ProfilerActivity

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

In [2]:
def model(x):
    return x**4 + x**3 + x**2 + x

x = torch.rand(10000,10000, requires_grad=True).to(device)

In [3]:
from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CUDA],
            ) as prof:
    out = model(x).sum().backward()

print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))

prof.export_chrome_trace("no_compile_trace.json")

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us     144.535ms        80.60%     144.535ms     144.535ms             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.434ms         6.93%      12.434ms       2.487ms             5  
void at::

STAGE:2023-03-24 23:05:33 2552:2552 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-03-24 23:05:33 2552:2552 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-03-24 23:05:33 2552:2552 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [4]:
torch._dynamo.reset()
compiled_model = torch.compile(model,options={'trace.graph_diagram':True,
                                'trace.enabled':True})
# out = compiled_model(x)

In [5]:
from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CUDA],
            ) as prof:
    out = compiled_model(x).sum().backward()

print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=10))

prof.export_chrome_trace("compiled_trace.json")

STAGE:2023-03-24 23:05:33 2552:2552 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


Writing FX graph to file: /pytorch-examples/friday24/torch_compile_debug/run_2023_03_24_23_05_35_131295-pid_2552/aot_torchinductor/model__0_forward_1.0/graph_diagram.svg
Writing FX graph to file: /pytorch-examples/friday24/torch_compile_debug/run_2023_03_24_23_05_35_131295-pid_2552/aot_torchinductor/model__0_backward_2.1/graph_diagram.svg




-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us     144.811ms        96.16%     144.811ms     144.811ms             1  
                                       triton__0d1d2d3d         0.00%       0.000us         0.00%       0.000us       0.000us       2.512ms         1.67%       2.512ms       2.512ms             1  
         

STAGE:2023-03-24 23:05:36 2552:2552 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-03-24 23:05:36 2552:2552 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [6]:
from torch.fx import passes, symbolic_trace
model_sym = symbolic_trace(model)

g = passes.graph_drawer.FxGraphDrawer(model_sym, 'fn')
with open("forward.svg", "wb") as f:
    f.write(g.get_dot_graph().create_svg())

In [7]:
import torch._dynamo
from torch._functorch.aot_autograd import aot_module_simplified

def toy_backend(gm, sample_inputs): 
    def fw(gm, sample_inputs):
        return gm.forward
    
    def bw(gm, sample_inputs):
        g = passes.graph_drawer.FxGraphDrawer(gm, 'fn')
        with open("backward.svg", "wb") as f:
            f.write(g.get_dot_graph().create_svg())
        return gm.forward

    # Invoke AOTAutograd
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=fw,
        bw_compiler=bw
    )

torch._dynamo.reset()
compiled_model = torch.compile(model, backend=toy_backend)

out = compiled_model(x).sum().backward()


