`torch.compile` was introduced in PyTorch 2.0 and is intended to replace TorchScript (`torch.jit`).

## Overview
Code compiled with `torch.compile` goes through two stages, TorchDynamo and TorchInductor:

1. TorchDynamo: intercepts the function's Python bytecode at run time (through CPython's frame evaluation hook, PEP 523), traces it, and captures the PyTorch operations into an FX Graph.

2. TorchInductor: compiles the FX Graph into efficient code, fusing operators where possible.
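To see what stage 1 hands over to stage 2, one can pass a custom backend to `torch.compile`. A backend is a callable that receives the FX `GraphModule` captured by TorchDynamo together with example inputs, and returns a callable to run. The sketch below is a debugging aid, not part of the original example: the backend name `inspect_backend` is hypothetical, and it skips Inductor entirely by returning the graph's own `forward`. It prints the captured graph and records which functions it calls, and it runs on CPU, so no GPU or Triton is needed.

```python
from typing import Callable, List

import torch
import torch.fx

captured_targets: List[object] = []  # call targets seen in the captured FX graph


def inspect_backend(gm: torch.fx.GraphModule,
                    example_inputs: List[torch.Tensor]) -> Callable:
    """Hypothetical debugging backend: show Dynamo's FX graph, then run it as-is."""
    print(gm.code)  # Python source regenerated from the FX graph
    for node in gm.graph.nodes:
        if node.op == "call_function":
            captured_targets.append(node.target)
    return gm.forward  # no Inductor: execute the captured graph eagerly


def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.matmul(x, y)
    return torch.nn.functional.softmax(z, dim=1)


compiled_fn = torch.compile(fn, backend=inspect_backend)
x, y = torch.rand(4, 4), torch.rand(4, 4)
out = compiled_fn(x, y)  # tracing (and the print in the backend) happens on this first call
```

At this level the graph still contains torch-level calls such as `torch.matmul`; the lowering to aten operators such as `aten.add` happens after Dynamo hands the graph off and before Inductor generates code.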

The higher-level FX Graph operators (e.g., `aten.add`) are converted into a loop-level IR, where operator fusion is performed: compatible operations (e.g., Add + ReLU) are merged into a single loop.
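The effect of fusion can be sketched in plain Python. This illustrates the idea only; it is not Inductor's IR or generated code, and the function names are hypothetical. The unfused version runs two loops and materializes an intermediate buffer (on a GPU, an extra round trip through global memory), while the fused version performs the add and the ReLU in the same loop body.

```python
from typing import List


def add_relu_unfused(a: List[float], b: List[float]) -> List[float]:
    # "Kernel" 1: elementwise add, writing an intermediate buffer
    tmp = [x + y for x, y in zip(a, b)]
    # "Kernel" 2: a second pass that re-reads the intermediate buffer
    return [max(t, 0.0) for t in tmp]


def add_relu_fused(a: List[float], b: List[float]) -> List[float]:
    out = []
    for x, y in zip(a, b):
        # add and ReLU in one loop body: no intermediate buffer
        out.append(max(x + y, 0.0))
    return out
```

Both produce the same result; the fused form simply reads each input once and writes each output once.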

Next, Inductor decides how to map these loops onto the hardware, e.g., by choosing tile (block) sizes and indexing each tile with `tl.program_id`.
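The tiling step can be mimicked in plain Python as well. The loop is split into blocks of `BLOCK` elements, each handled by one "program" whose index `pid` plays the role of `tl.program_id(0)`, with a mask guarding the ragged last block. `BLOCK = 4` and the function names are assumptions for illustration: Inductor chooses real block sizes per kernel, and on a GPU the programs run in parallel rather than in a Python loop.

```python
import math
from typing import List

BLOCK = 4  # hypothetical tile size


def add_relu_program(pid: int, a: List[float], b: List[float],
                     out: List[float], n: int) -> None:
    """One simulated program instance; `pid` stands in for tl.program_id(0)."""
    # Like: offsets = pid * BLOCK + tl.arange(0, BLOCK)
    offsets = [pid * BLOCK + i for i in range(BLOCK)]
    for off in offsets:
        if off < n:  # like: mask = offsets < n
            out[off] = max(a[off] + b[off], 0.0)


def launch_add_relu(a: List[float], b: List[float]) -> List[float]:
    n = len(a)
    out = [0.0] * n
    grid = math.ceil(n / BLOCK)  # number of programs to launch
    for pid in range(grid):      # sequential here, parallel on a GPU
        add_relu_program(pid, a, b, out, n)
    return out
```

With 10 elements and `BLOCK = 4`, three programs are launched and the mask keeps the last one from touching offsets 10 and 11.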

Finally, during code generation, these loops are emitted as Python source containing `@triton.jit` kernels (saved under `/tmp/torchinductor_xxx`). On GPUs, TorchInductor thus uses Triton as its backend, and Triton compiles these kernels into optimized device code.

## Dissecting an Example
Create a toy example `torch_compile_builtin_fusion.py`:

```python
import torch

# Two random 100x100 inputs on the GPU
a = torch.rand((100, 100), device='cuda')
b = torch.rand((100, 100), device='cuda')

def fn(x, y):
    z = torch.matmul(x, y)
    return torch.nn.functional.softmax(z, dim=1)  # row-wise softmax

compiled_fn = torch.compile(fn)  # compilation happens lazily on the first call
print(compiled_fn(a, b))
```

```
tensor([[0.0042, 0.0006, 0.0176,  ..., 0.0003, 0.0017, 0.0257],
        [0.0402, 0.0009, 0.0153,  ..., 0.0012, 0.0012, 0.0075],
        [0.0427, 0.0011, 0.0052,  ..., 0.0008, 0.0016, 0.0079],
        ...,
        [0.0029, 0.0002, 0.0040,  ..., 0.0004, 0.0002, 0.0206],
        [0.0013, 0.0015, 0.0185,  ..., 0.0024, 0.0001, 0.0043],
        [0.0134, 0.0003, 0.0135,  ..., 0.0007, 0.0002, 0.0140]],
       device='cuda:0')
```


We can print the output of both stages: `graph_code` shows the FX Graph captured by TorchDynamo as Python code, and `output_code` shows the Triton code generated by Inductor.

```bash
TORCH_LOGS="graph_code,output_code" python torch_compile_builtin_fusion.py
```

```
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code] TRACED GRAPH
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]  ===== __compiled_fn_1_764aecdc_de0b_44f3_b87f_7dc1542901a0 =====
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]  /home/tk/Desktop/jupyter/simp-intelligence/.pixi/envs/default/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]     def forward(self, L_y_: "f32[100, 100][100, 1]cuda:0", L_x_: "f32[100, 100][100, 1]cuda:0"):
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         l_y_ = L_y_
V0113 15:02:14.896000 541798 site-packages/torch/_dynamo/
```

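In the full `output_code` log, the softmax is computed by a single fused Triton kernel, `triton_per_fused__softmax_0` (the matmul itself is, by default, typically dispatched to an external library kernel rather than generated in Triton).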
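Why can softmax become a single kernel? In Inductor's kernel names, `per` marks a persistent reduction: because each row here has only 100 elements, a whole row can be kept in one program, so the row max, the exponentials, their sum, and the final division are all computed from a single load of the row, without writing intermediates back to memory. The plain-Python sketch below mimics that per-row computation under these assumptions; it is an illustration of the idea, not the generated kernel, and the function names are hypothetical.

```python
import math
from typing import List


def softmax_row(row: List[float]) -> List[float]:
    """One simulated 'program' of a persistent reduction: the row is loaded once and reused."""
    m = max(row)                           # reduction 1: row max (numerical stability)
    exps = [math.exp(v - m) for v in row]  # pointwise on the already-loaded row
    s = sum(exps)                          # reduction 2: normalizer
    return [e / s for e in exps]           # pointwise again, then a single store


def softmax_dim1(z: List[List[float]]) -> List[List[float]]:
    # One program per row, matching softmax(z, dim=1) in fn
    return [softmax_row(r) for r in z]
```

An unfused implementation would instead write the max, the exponentials, and the sums to separate buffers between kernels.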

However, the example above only calls built-in torch functions, which leaves little room for `torch.compile` to optimize (and many stages are not applied at all).
Let's use a more customized toy script, `torch_compile_custom_toy.py`, so that we can look deeper into the Inductor stages:

```bash
rm -rf ./torch_compile_debug
TORCH_COMPILE_DEBUG=1 python torch_compile_custom_toy.py
ls ./torch_compile_debug/run_*/torchinductor
```

tensor([ 0.2285,  0.6833, -0.0843,  0.1018, -0.4560,  1.2637,  0.0872,  0.1274,
         0.0481,  0.3608])
tensor([-0.0387, -0.0628, -0.2498,  0.5246, -0.1643,  0.0197,  0.1885, -0.0735,
         0.2820,  1.0181])
tensor([ 0.1763, -0.3590,  0.4897, -0.0080, -0.0753,  0.2756, -0.0122, -0.2288,
         0.5179, -0.0016])
tensor([-1.0154, -0.2763, -0.6568,  0.0159, -0.5978,  0.8283,  0.4505,  0.0409,
        -0.3274, -0.8447])
tensor([ 0.0906, -0.2199,  0.0134, -0.1255,  0.4967,  0.4357, -0.0527, -0.0063,
         0.2549,  0.0277])
tensor([-0.9464, -0.2473,  0.3010,  0.1190, -1.5808, -0.3638, -0.1464, -0.6412,
         0.0051,  0.0452])
tensor([-0.0885,  0.5260,  0.1119, -0.0861, -0.5349, -0.6390,  0.0381, -0.1691,
         0.3603,  0.4051])
tensor([-0.1284,  0.0015,  0.2658,  0.0079,  0.1679, -0.2159, -0.0157, -0.2208,
         0.3690, -0.0501])
tensor([-4.2978e-01, -7.7197e-02,  4.8542e-01,  6.5259e-01, -4.1098e-01,
         2.3311e-04, -5.0730e-01, -2.9696e-01,  4.8463e-02, -3.8781e-01