In [None]:
"""
Welcome to the PyTorch 2 Tutorial!

In this notebook, we will demonstrate introductory concepts for running TorchDynamo and the tools
to understand the graph capture process.
"""

'\nWelcome to PyTorch 2 Tutorial!\n\nIn this notebook, we will demonstrate introductory concepts for running TorchDynamo and the tools\nto understand the graph capture process.\n'

In [None]:
"""
Install the latest PyTorch Build.
ETA: 1 minute
"""
# resolve dependency conflict on colab and may not be necessary on local environement
!pip uninstall torch -y
!pip uninstall fastai -y
!pip uninstall torchtext -y

!pip install torch==2.3.1+cu121 --index-url https://download.pytorch.org/whl/cu121

Found existing installation: torch 2.2.1+cu121
Uninstalling torch-2.2.1+cu121:
  Successfully uninstalled torch-2.2.1+cu121
Found existing installation: fastai 2.7.14
Uninstalling fastai-2.7.14:
  Successfully uninstalled fastai-2.7.14
Found existing installation: torchtext 0.17.1
Uninstalling torchtext-0.17.1:
  Successfully uninstalled torchtext-0.17.1
Looking in indexes: https://download.pytorch.org/whl/nightly/cpu
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240422%2Bcpu-cp310-cp310-linux_x86_64.whl (192.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.0/192.0 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.2.1+cu121 requires torch==2.2.1, but you have torch 2.4.0.dev20240422+cpu which 

In [None]:
import torch

# Confirm the active build is the one pinned by the install cell (torch==2.3.1+cu121).
torch.__version__

'2.4.0.dev20240422+cpu'

In [None]:
"""
Example 1: Compiling a basic python function

In this first example we'll compile a basic python function using torch.compile
and demonstrate how to use logging to view the guards and recompiles if guards fail

"""
import torch

# Disable dynamic shapes for the purposes of this example
@torch.compile(dynamic=False)
def fn(x, y):
  x *= x
  x /= y
  return x + 1

# Before running our function,
# we set logs to observe different artifacts of the compilation process
# in this case setting the guards kwarg to True will print the relevant guards
torch._logging.set_logs(guards=True)
fn(torch.ones(2, 2), torch.ones(2, 2))

# Other useful options:
# guards - display generated guards
# graph - display captured graph
# output_code - display the generated output code from inductor
# bytecode - display the rewritten bytecode for the function you are compiling
# recompiles - print which guard failed if the function has been compiled previously
torch._logging.set_logs(recompiles=True, guards=True)
fn(torch.ones(4, 2), torch.ones(4, 2))



[2024-04-26 07:29:02,700] [1/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-04-26 07:29:02,701] [1/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # x *= x  # <ipython-input-2-2a36c5f136f4>:13 in fn
[2024-04-26 07:29:02,703] [1/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['y'], '_dynamo_dynamic_indices') == False           # x /= y  # <ipython-input-2-2a36c5f136f4>:14 in fn
[2024-04-26 07:29:02,705] [1/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:379 in init_ambient_guards
[2024-04-26 07:29:02,707] [1/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(138809320518080))  # _dynamo/output_graph.py:385 in init_ambient_guards
[2024-04-26 07:29:02,709] [1/0] torch._dynamo.guards.__guards: [DEBUG] ___compile_config_hash() == 'afe34f4bda34a849c5a40bd5f182778f'  # _dynamo/output_

tensor([[2., 2.],
        [2., 2.],
        [2., 2.],
        [2., 2.]])

In [None]:
"""Example 2: Specializaton on non-tensors and displaying the captured graph

In this example, dynamo will specialize on an int because it is a non-tensor value,
so a guard will be generated for it.

There will be two graphs generated because the guard failure will trigger a recompile and
the logs will be used to display these two graphs along with the recompile reason.
"""
import torch


@torch.compile()
def with_int(x, y, z, n):
  if n > 0:
    x += y

  return x * y + z

# View the captured graphs with graph=True, and recompiles to see the guard failure
# due to `flag` changing from True -> False
torch._logging.set_logs(graph=True, recompiles=True, guards=True)
with_int(torch.ones(2, 2), torch.ones(2, 2), torch.zeros(2, 2), 0)
with_int(torch.ones(2, 2), torch.ones(2, 2), torch.zeros(2, 2), 1)



[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_4 <eval_with_key>.57 opcode         name    target                   args          kwargs
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] -------------  ------  -----------------------  ------------  --------
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_x_    L_x_                     ()            {}
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_y_    L_y_                     ()            {}
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_z_    L_z_                     ()            {}
[2024-04-26 07:32:56,338] [2/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  mul     <built-in function mul>  (l_x_, l_y_)  {}
[2024-04-26 07:32:5

tensor([[2., 2.],
        [2., 2.]])

In [None]:
"""Example 3: Graph break on unsupported behavior

This example demonstrates a graph break - a region of code that dynamo doesn't support.
Dynamo compiles the current subgraph and generates a continuation function to call immediately
after the unsupported code. As a result the rewritten bytecode will consist of 1) a call to the compiled subgraph
2) the unmodified unsupported region and 3) the call to the continuation (these are named "resume_in_<function name>"
in the generated bytecode)
"""
import torch


@torch.compile()
def graph_break_0(x, y):
  z = x + y
  torch._dynamo.graph_break()
  z += x
  return z

# Observe the structure of the modified bytecode and the graph break that caused it
torch._logging.set_logs(graph_breaks=True, bytecode=True)
graph_break_0(torch.ones(4, 4), torch.zeros(4, 4))

[2024-04-26 07:36:59,801] [3/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: 'skip function graph_break in file /usr/local/lib/python3.10/dist-packages/torch/_dynamo/decorators.py'', skipped according skipfiles.SKIP_DIRS' from user code at:
[2024-04-26 07:36:59,801] [3/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "<ipython-input-4-70e0bc373d08>", line 15, in graph_break_0
[2024-04-26 07:36:59,801] [3/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     torch._dynamo.graph_break()
[2024-04-26 07:36:59,801] [3/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] 
[2024-04-26 07:37:01,407] [3/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE graph_break_0 <ipython-input-4-70e0bc373d08> line 12 
[2024-04-26 07:37:01,407] [3/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]  14           0 LOAD_FAST                0 (x)
[2024-04-26 07:37:01,407] [3/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]           

tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

In [None]:
"""Example 4: Graph break on data-dependent control flow

In this example, the program conditions on data within a tensor. This is not traceable into
our graph (because we need to run the computation to determine the runtime value)
so the graph is broken, data extracted from the tensor, and tracing continues afterward.


"""
import torch

@torch.compile()
def data_dep(x, y):
  if x.item() == 1: # <--- graph break here, extracting data from a tensor.
    y += 2
    x += 3
  else:
    y += 3

  return y

torch._logging.set_logs(graph_breaks=True, recompiles=True, bytecode=True)
data_dep(torch.ones(1), torch.ones(5, 5))
data_dep(torch.zeros(1), torch.ones(5, 5))






[2024-04-26 07:40:46,031] [5/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: Tensor.item from user code at:
[2024-04-26 07:40:46,031] [5/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "<ipython-input-5-1fcb39c11676>", line 13, in data_dep
[2024-04-26 07:40:46,031] [5/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     if x.item() == 1: # <--- graph break here, extracting data from a tensor.
[2024-04-26 07:40:46,031] [5/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] 
[2024-04-26 07:40:46,038] [5/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE data_dep <ipython-input-5-1fcb39c11676> line 11 
[2024-04-26 07:40:46,038] [5/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]  13           0 LOAD_FAST                0 (x)
[2024-04-26 07:40:46,038] [5/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]               2 LOAD_METHOD              0 (item)
[2024-04-26 07:40:46,038] [5/0_1] torch._dynamo.convert_f

tensor([[4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.],
        [4., 4., 4., 4., 4.]])

In [None]:
"""Example 5: Mutation of python object

In this example, the programs take in a python dicts and lists
and mutate the contents. Internally, dynamo tracks these mutations and constructs
the final state directly after calling the compiled graph.


"""
import torch
import logging

@torch.compile()
def mut_list(x, y):
  y.append(x * 2)
  y.append(x + 3)
  return y

torch._logging.set_logs(graph=True, bytecode=True)
mut_list(torch.ones(2, 2), [0, 0, 0])


@torch.compile()
def mut_dict(x, d):
  d["z"] = x + 2
  d["a"] = x
  return d

mut_dict(torch.ones(2, 2), {})




[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_12 <eval_with_key>.99 opcode         name    target                   args           kwargs
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] -------------  ------  -----------------------  -------------  --------
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder    l_x_    L_x_                     ()             {}
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  mul     <built-in function mul>  (l_x_, 2)      {}
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] call_function  add     <built-in function add>  (l_x_, 3)      {}
[2024-04-26 07:44:27,454] [7/0] torch._dynamo.output_graph.__graph: [DEBUG] output         output  output                   ((mul, add),)  {}
[2024-04-26 

{'z': tensor([[3., 3.],
         [3., 3.]]),
 'a': tensor([[1., 1.],
         [1., 1.]])}

In [None]:
"""Example 6: Using dynamo disable

In this example, the use of dynamo disable is demonstrated. This function can be used to
tell dynamo to always run this function in the python interpreter. It can be setup to apply recursively on all inner
functions or just on the highest level frame.

"""
import torch
import logging

def inner2(x):
  return x * 2

# By default, the disable
# applies to function and all recursively invoked frames
@torch._dynamo.disable(recursive=False)
def inner(x):
  return inner2(x) + 1

@torch.compile()
def disable_demo(x):
  z = inner(x)
  return z


torch._logging.set_logs(graph_breaks=True, bytecode=True)
disable_demo(torch.ones(5, 5))


[2024-04-26 07:50:56,479] [14/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: call torch._dynamo.disable() wrapped function <function inner at 0x7e3ef2fe9990> from user code at:
[2024-04-26 07:50:56,479] [14/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "<ipython-input-9-2bde79ae0f3d>", line 22, in disable_demo
[2024-04-26 07:50:56,479] [14/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     z = inner(x)
[2024-04-26 07:50:56,479] [14/0] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] 
[2024-04-26 07:50:56,487] [14/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG] ORIGINAL BYTECODE disable_demo <ipython-input-9-2bde79ae0f3d> line 20 
[2024-04-26 07:50:56,487] [14/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]  22           0 LOAD_GLOBAL              0 (inner)
[2024-04-26 07:50:56,487] [14/0_1] torch._dynamo.convert_frame.__bytecode: [DEBUG]               2 LOAD_FAST                0 (x)
[2024-04-26 07:50:56,487] [14/0_1

tensor([[3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.]])

In [None]:
"""Example 7: Functionalization and Decompositions

In this example we'll show how AOTAutograd will remove mutations from the graph
and also decompse complex ops into the core Aten opset.

"""
import torch


# addcdiv is a composed add and divide
# see https://pytorch.org/docs/stable/generated/torch.Tensor.addcdiv_.html
# for more info
@torch.compile()
def func_and_decomps(x, y, z):
  return x.addcdiv_(y, z)


# View the graph after decompositions and functionalization have been applied
# in the previous examples we only viewed the graph that dynamo captured
# AOTAutograd takes this graph and applies transformations to it.
torch._logging.set_logs(graph=True, aot_graphs=True)
func_and_decomps(torch.ones(2, 2), torch.zeros(2, 2), torch.ones(2, 2))




[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_19 <eval_with_key>.127 opcode       name      target    args                kwargs
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] -----------  --------  --------  ------------------  --------
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_x_      L_x_      ()                  {}
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_y_      L_y_      ()                  {}
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_z_      L_z_      ()                  {}
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__graph: [DEBUG] call_method  addcdiv_  addcdiv_  (l_x_, l_y_, l_z_)  {}
[2024-04-26 07:54:12,368] [17/0] torch._dynamo.output_graph.__gr

tensor([[1., 1.],
        [1., 1.]])

In [None]:
"""Example 8: Forward and Backward

In this example the program will generate a forward and backward graph using AOTAutograd.
Up until this point the previous examples had been generating inference graphs (ie only the forward pass)

"""
import torch

input = torch.ones(2, 2)
param = torch.ones(2, 2, requires_grad=True)

@torch.compile()
def fwd_bwd(input):
  return input * param

torch._logging.set_logs(aot_graphs=True)
out = fwd_bwd(input)

  self.pid = os.fork()
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO] TRACED GRAPH
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO]  ===== Forward graph 0 =====
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO]  <eval_with_key>.35 class GraphModule(torch.nn.Module):
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO]     def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO]         # File: <ipython-input-1-9069f7cc785b>:14, code: return input * param
[2024-04-27 06:53:16,056] [0/0] torch._functorch._aot_autograd.jit_compile_runtime_wrappers.__aot_graphs: [INFO]         mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor

In [None]:
"""Example 9: Dynamic Shapes

In this example, we will show how dynamic shapes will automatically get enabled
if the shapes of the inputs change during a recompilation. In the output code
the shapes values are passed to the kernel to take into account this dynamism.

"""

import torch

@torch.compile()
def fn(x, y):
  return x * y + 10

torch._logging.set_logs(guards=True)
fn(torch.ones(2, 2), torch.ones(2, 2))
torch._logging.set_logs(guards=True, recompiles=True, output_code=True)
fn(torch.ones(4, 2), torch.ones(4, 2))




[2024-04-26 07:59:44,623] [21/0] torch._dynamo.guards.__guards: [DEBUG] GUARDS:
[2024-04-26 07:59:44,624] [21/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # return x * y + 10  # <ipython-input-14-b1360c215c49>:13 in fn
[2024-04-26 07:59:44,627] [21/0] torch._dynamo.guards.__guards: [DEBUG] hasattr(L['y'], '_dynamo_dynamic_indices') == False           # return x * y + 10  # <ipython-input-14-b1360c215c49>:13 in fn
[2024-04-26 07:59:44,629] [21/0] torch._dynamo.guards.__guards: [DEBUG] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:379 in init_ambient_guards
[2024-04-26 07:59:44,631] [21/0] torch._dynamo.guards.__guards: [DEBUG] (___skip_backend_check() or ___current_backend() == ___lookup_backend(138808821124928))  # _dynamo/output_graph.py:385 in init_ambient_guards
[2024-04-26 07:59:44,633] [21/0] torch._dynamo.guards.__guards: [DEBUG] ___compile_config_hash() == '88a14d47e62622e2d97d70

tensor([[11., 11.],
        [11., 11.],
        [11., 11.],
        [11., 11.]])