From 1d75210e730228740aab5f1bed7d53d489df4a23 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 5 Sep 2025 09:33:50 +0530
Subject: [PATCH 01/12] feat: add a tutorial on regional aot.

---
 recipes_source/regional_aot.py | 233 +++++++++++++++++++++++++++++++++
 1 file changed, 233 insertions(+)
 create mode 100644 recipes_source/regional_aot.py

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
new file mode 100644
index 0000000000..845518b6f9
--- /dev/null
+++ b/recipes_source/regional_aot.py
@@ -0,0 +1,233 @@
+
+"""
+Reducing AoT cold start compilation time with regional compilation
+============================================================================
+
+**Author:** `Sayak Paul `, `Charles Bensimon `, `Angela Yi `
+
+In our [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed
+how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for
+just-in-time (JiT) compilation.
+
+This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you
+are not familiar with AoT and `torch.export`, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
+
+Prerequisites
+----------------
+
+* PyTorch 2.6 or later
+* Familiarity with regional compilation
+* Familiarity with AoT and `torch.export`
+
+Setup
+-----
+Before we begin, we need to install ``torch`` if it is not already
+available.
+
+.. code-block:: sh
+
+    pip install torch
+
+.. note::
+    This feature is available starting with the 2.6 release.
+"""
+
+from time import perf_counter
+
+######################################################################
+# Steps
+# -----
+#
+# In this recipe, we will follow pretty much the same steps as the regional compilation recipe mentioned above:
+#
+# 1. Import all necessary libraries.
+# 2. Define and initialize a neural network with repeated regions.
+# 3. Measure the compilation time of the full model and the regional compilation with AoT.
+#
+# First, let's import the necessary libraries for loading our data:
+#
+#
+
+import torch
+torch.set_grad_enabled(False)
+
+
+##########################################################
+# We will use the same neural network structure as the regional compilation recipe.
+#
+# We will use a network, composed of repeated layers. This mimics a
+# large language model, that typically is composed of many Transformer blocks. In this recipe,
+# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region.
+# We will then create a ``Model`` which is composed of 64 instances of this
+# ``Layer`` class.
+#
+class Layer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(10, 10)
+        self.relu1 = torch.nn.ReLU()
+        self.linear2 = torch.nn.Linear(10, 10)
+        self.relu2 = torch.nn.ReLU()
+
+    def forward(self, x):
+        a = self.linear1(x)
+        a = self.relu1(a)
+        a = torch.sigmoid(a)
+        b = self.linear2(a)
+        b = self.relu2(b)
+        return b
+
+
+class Model(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(10, 10)
+        self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])
+
+    def forward(self, x):
+        # In regional compilation, the self.linear is outside of the scope of `torch.compile`.
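+        # (Each of the 64 ``Layer`` blocks below shares one compute pattern;
+        # they are the repeated region that regional compilation targets.)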
+        x = self.linear(x)
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+####################################################
+# Since we're compiling the model ahead-of-time, we need to prepare representative
+# input examples, that we expect the model to see during actual deployments.
+#
+# Let's create an instance of `Model` and pass it some sample input data.
+#
+
+model = Model().cuda()
+input = torch.randn(10, 10, device="cuda")
+output = model(input)
+print(f"{output.shape=}")
+
+####################################################
+# Now, let's compile our model ahead-of-time. We will use `input` created above to pass
+# to `torch.export`. This will yield a `torch.export.ExportedProgram` which we can compile.
+
+path = torch._inductor.aoti_compile_and_package(
+    torch.export.export(
+        model, args=input, kwargs={},
+    )
+)
+
+####################################################
+# We can load from this `path` and use it to perform inference.
+
+compiled_binary = torch._inductor.aoti_load_package(path)
+output_compiled = compiled_binary(input)
+print(f"{output_compiled.shape=}")
+
+###################################################
+# Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
+#
+# Since the compute pattern is shared by all the blocks that
+# are repeated in a model (``Layer`` instances in this case), we can just
+# compile a single block and let the inductor reuse it.
+
+model = Model().cuda()
+path = torch._inductor.aoti_compile_and_package(
+    torch.export.export(
+        model.layers[0],
+        args=input,
+        kwargs={},
+        inductor_configs={
+            # compile artifact w/o saving params in the artifact
+            "aot_inductor.package_constants_in_so": False,
+        }
+    )
+)
+
+###################################################
+# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
+# a state_dict containing tensor values of all lifted parameters and buffers alongside
+# other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
+# not serialize the model parameters in the generated artifact.
+#
+# Now, when loading the compiled binary, we can reuse the existing parameters of
+# each block. This lets us take advantage of the compiled binary obtained above.
+#
+
+for layer in model.layers:
+    compiled_layer = torch._inductor.aoti_load_package(path)
+    compiled_layer.load_constants(
+        layer.state_dict(), check_full_update=True, user_managed=True
+    )
+    layer.forward = compiled_layer
+
+#####################################################
+# Just like JiT regional compilation, compiling regions within a model ahead-of-time
+# leads to significantly reduced cold start times. The actual number will vary from
+# model to model.
+#
+# Even though full model compilation offers the fullest scope of optimizations,
+# for practical purposes and depending on the type of model, we have seen regional
+# compilation (both JiT and AoT) providing similar speed benefits, while drastically
+# reducing the cold start times.
+
+###################################################
+# Next, let's measure the compilation time of the full model and the regional compilation.
+#
+# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
+# In the code below, we measure the total time spent in the first invocation. While this method is not
+# precise, it provides a good estimate since the majority of the time is spent in
+# compilation.
+
+
+def measure_latency(fn, input):
+    # Reset the compiler caches to ensure no reuse between different runs
+    torch.compiler.reset()
+    with torch._inductor.utils.fresh_inductor_cache():
+        start = perf_counter()
+        fn(input)
+        torch.cuda.synchronize()
+        end = perf_counter()
+        return end - start
+
+def aot_compile_model(regional=False):
+    input = torch.randn(10, 10, device="cuda")
+    model = Model().cuda()
+
+    inductor_configs = {}
+    if regional:
+        inductor_configs = {"aot_inductor.package_constants_in_so": False}
+    path = torch._inductor.aoti_compile_and_package(
+        torch.export.export(
+            model.layers[0] if regional else model,
+            args=input,
+            kwargs={},
+            inductor_configs=inductor_configs,
+        )
+    )
+
+    if regional:
+        for layer in model.layers:
+            compiled_layer = torch._inductor.aoti_load_package(path)
+            compiled_layer.load_constants(
+                layer.state_dict(), check_full_update=True, user_managed=True
+            )
+            layer.forward = compiled_layer
+    else:
+        compiled_layer = torch._inductor.aoti_load_package(path)
+
+    return model
+
+input = torch.randn(10, 10, device="cuda")
+full_model_compilation_latency = measure_latency(aot_compile_model(), input)
+print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")
+
+regional_compilation_latency = measure_latency(aot_compile_model(regional=True), input)
+print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")
+
+assert regional_compilation_latency < full_model_compilation_latency
+
+############################################################################
+# Conclusion
+# -----------
+#
+# This recipe shows how to control the cold start time when compiling your model ahead-of-time.
+# This becomes effective when your model has repeated blocks, like typically seen in large generative
+# models.

From 4a870c862b6c3f01f10b53fd312f25eebee6a9c5 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 5 Sep 2025 09:59:00 +0530
Subject: [PATCH 02/12] Apply suggestions from code review

Co-authored-by: Angela Yi
---
 recipes_source/regional_aot.py | 31 ++++++++++++-------------------
 1 file changed, 12 insertions(+), 19 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 845518b6f9..243eae56ac 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -3,21 +3,21 @@
 Reducing AoT cold start compilation time with regional compilation
 ============================================================================
 
-**Author:** `Sayak Paul `, `Charles Bensimon `, `Angela Yi `
+**Author:** `Sayak Paul `, `Charles Bensimon `, `Angela Yi `
 
 In our [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed
 how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for
 just-in-time (JiT) compilation.
 
 This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you
-are not familiar with AoT and `torch.export`, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
+are not familiar with AOTInductor and `torch.export`, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
 
 Prerequisites
 ----------------
 
 * PyTorch 2.6 or later
 * Familiarity with regional compilation
-* Familiarity with AoT and `torch.export`
+* Familiarity with AOTInductor and `torch.export`
 
 Setup
 -----
@@ -109,9 +109,7 @@ def forward(self, x):
 # to `torch.export`. This will yield a `torch.export.ExportedProgram` which we can compile.
 
 path = torch._inductor.aoti_compile_and_package(
-    torch.export.export(
-        model, args=input, kwargs={},
-    )
+    torch.export.export(model, args=(input,))
 )
 
 ####################################################
@@ -130,15 +128,11 @@ def forward(self, x):
 
 model = Model().cuda()
 path = torch._inductor.aoti_compile_and_package(
-    torch.export.export(
-        model.layers[0],
-        args=input,
-        kwargs={},
-        inductor_configs={
-            # compile artifact w/o saving params in the artifact
-            "aot_inductor.package_constants_in_so": False,
-        }
-    )
+    torch.export.export(model.layers[0], args=(input,)),
+    inductor_configs={
+        # compile artifact w/o saving params in the artifact
+        "aot_inductor.package_constants_in_so": False,
+    }
 )
 
 ###################################################
@@ -197,10 +191,9 @@ def aot_compile_model(regional=False):
     path = torch._inductor.aoti_compile_and_package(
         torch.export.export(
             model.layers[0] if regional else model,
-            args=input,
-            kwargs={},
-            inductor_configs=inductor_configs,
-        )
+            args=(input,)
+        ),
+        inductor_configs=inductor_configs,
     )
 
     if regional:

From 0c45ef9dde7f6f457f6002fdb709f6e4f2a25971 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 5 Sep 2025 08:11:35 +0200
Subject: [PATCH 03/12] reviewer feedback and fixes.

---
 recipes_source/regional_aot.py | 76 ++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 32 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 243eae56ac..d8d93db55b 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -152,6 +152,9 @@ def forward(self, x):
     )
     layer.forward = compiled_layer
 
+output_regional_compiled = model(input)
+print(f"{output_regional_compiled.shape=}")
+
 #####################################################
 # Just like JiT regional compilation, compiling regions within a model ahead-of-time
 # leads to significantly reduced cold start times. The actual number will vary from
@@ -171,56 +174,65 @@ def forward(self, x):
 # compilation.
 
 
-def measure_latency(fn, input):
-    # Reset the compiler caches to ensure no reuse between different runs
-    torch.compiler.reset()
-    with torch._inductor.utils.fresh_inductor_cache():
-        start = perf_counter()
-        fn(input)
-        torch.cuda.synchronize()
-        end = perf_counter()
-        return end - start
+def measure_compile_time(input, regional=False):
+    start = perf_counter()
+    model = aot_compile_load_model(regional=regional)
+    torch.cuda.synchronize()
+    end = perf_counter()
+    # make sure the model works.
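+    # (This check runs after ``end`` is recorded, so the forward pass is not
+    # included in the measured compile-and-load time.)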
+    _ = model(input)
+    return end - start
 
-def aot_compile_model(regional=False):
+def aot_compile_load_model(regional=False) -> torch.nn.Module:
     input = torch.randn(10, 10, device="cuda")
     model = Model().cuda()
 
     inductor_configs = {}
     if regional:
         inductor_configs = {"aot_inductor.package_constants_in_so": False}
-    path = torch._inductor.aoti_compile_and_package(
-        torch.export.export(
-            model.layers[0] if regional else model,
-            args=(input,)
-        ),
-        inductor_configs=inductor_configs,
-    )
-
-    if regional:
-        for layer in model.layers:
-            compiled_layer = torch._inductor.aoti_load_package(path)
-            compiled_layer.load_constants(
-                layer.state_dict(), check_full_update=True, user_managed=True
-            )
-            layer.forward = compiled_layer
-    else:
-        compiled_layer = torch._inductor.aoti_load_package(path)
+    # Reset the compiler caches to ensure no reuse between different runs
+    torch.compiler.reset()
+    with torch._inductor.utils.fresh_inductor_cache():
+        path = torch._inductor.aoti_compile_and_package(
+            torch.export.export(
+                model.layers[0] if regional else model,
+                args=(input,)
+            ),
+            inductor_configs=inductor_configs,
+        )
+
+        if regional:
+            for layer in model.layers:
+                compiled_layer = torch._inductor.aoti_load_package(path)
+                compiled_layer.load_constants(
+                    layer.state_dict(), check_full_update=True, user_managed=True
+                )
+                layer.forward = compiled_layer
+        else:
+            model = torch._inductor.aoti_load_package(path)
 
     return model
 
 input = torch.randn(10, 10, device="cuda")
-full_model_compilation_latency = measure_latency(aot_compile_model(), input)
+full_model_compilation_latency = measure_compile_time(input, regional=False)
 print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds")
 
-regional_compilation_latency = measure_latency(aot_compile_model(regional=True), input)
+regional_compilation_latency = measure_compile_time(input, regional=True)
 print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds")
 
 assert regional_compilation_latency < full_model_compilation_latency
 
+############################################################################
+# There may also be layers in a model incompatible with compilation. So,
+# full compilation will result in a fragmented computation graph resulting
+# in potential latency degradation. In these cases, regional compilation
+# can be beneficial.
+#
+
 ############################################################################
 # Conclusion
 # -----------
 #
-# This recipe shows how to control the cold start time when compiling your model ahead-of-time.
-# This becomes effective when your model has repeated blocks, like typically seen in large generative
-# models.
+# This recipe shows how to control the cold start time when compiling your
+# model ahead-of-time.This becomes effective when your model has repeated
+# blocks, like typically seen in large generative models.

From cb620cf34b5e2b568e3aa804106d96861085166a Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 5 Sep 2025 21:31:40 +0530
Subject: [PATCH 04/12] up

---
 recipes_source/regional_aot.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index d8d93db55b..52ae717d1d 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -10,14 +10,14 @@
 just-in-time (JiT) compilation.
 
 This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you
-are not familiar with AOTInductor and `torch.export`, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
+are not familiar with AOTInductor and ``torch.export``, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
 
 Prerequisites
 ----------------
 
 * PyTorch 2.6 or later
 * Familiarity with regional compilation
-* Familiarity with AOTInductor and `torch.export`
+* Familiarity with AOTInductor and ``torch.export``
 
 Setup
 -----
@@ -85,7 +85,7 @@ def __init__(self):
         self.layers = torch.nn.ModuleList([Layer() for _ in range(64)])
 
     def forward(self, x):
-        # In regional compilation, the self.linear is outside of the scope of `torch.compile`.
+        # In regional compilation, the self.linear is outside of the scope of ``torch.compile``.
         x = self.linear(x)
         for layer in self.layers:
             x = layer(x)
@@ -96,7 +96,7 @@ def forward(self, x):
 # Since we're compiling the model ahead-of-time, we need to prepare representative
 # input examples, that we expect the model to see during actual deployments.
 #
-# Let's create an instance of `Model` and pass it some sample input data.
+# Let's create an instance of ``Model`` and pass it some sample input data.
 #
 
 model = Model().cuda()
@@ -105,8 +105,8 @@ def forward(self, x):
 print(f"{output.shape=}")
 
 ####################################################
-# Now, let's compile our model ahead-of-time. We will use `input` created above to pass
-# to `torch.export`. This will yield a `torch.export.ExportedProgram` which we can compile.
+# Now, let's compile our model ahead-of-time. We will use ``input`` created above to pass
+# to ``torch.export``. This will yield a ``torch.export.ExportedProgram`` which we can compile.
 
 path = torch._inductor.aoti_compile_and_package(
     torch.export.export(model, args=(input,))
@@ -136,7 +136,7 @@ def forward(self, x):
 )
 
 ###################################################
-# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
+# An exported program (```torch.export.ExportedProgram```) contains the Tensor computation,
 # a state_dict containing tensor values of all lifted parameters and buffers alongside
 # other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
 # not serialize the model parameters in the generated artifact.
@@ -168,7 +168,7 @@ def forward(self, x):
 # Next, let's measure the compilation time of the full model and the regional compilation.
 #
-# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
+# ```torch.compile``` is a JIT compiler, which means that it compiles on the first invocation.
 # In the code below, we measure the total time spent in the first invocation. While this method is not
 # precise, it provides a good estimate since the majority of the time is spent in
 # compilation.
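
For reference, the regional AoT pattern built up above — export one repeated block without baking its weights into the artifact, then load one compiled copy per block and rebind each block's own parameters — condenses to the following sketch (assuming the ``Model`` and ``Layer`` definitions from the recipe; ``example_input`` is an illustrative name):

.. code-block:: python

    import torch

    model = Model().cuda()
    example_input = torch.randn(10, 10, device="cuda")

    # Compile a single repeated block, keeping parameters out of the artifact.
    path = torch._inductor.aoti_compile_and_package(
        torch.export.export(model.layers[0], args=(example_input,)),
        inductor_configs={"aot_inductor.package_constants_in_so": False},
    )

    # Load one copy of the compiled kernel per block and bind that block's own weights.
    for layer in model.layers:
        compiled_layer = torch._inductor.aoti_load_package(path)
        compiled_layer.load_constants(
            layer.state_dict(), check_full_update=True, user_managed=True
        )
        layer.forward = compiled_layer

    print(model(example_input).shape)

Note that ``torch.export.export`` specializes to the example input's shape here; deployments that see varying shapes would additionally need its ``dynamic_shapes`` argument, which this recipe does not cover.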
From b1ef178214808c03e9e44b122b4585f68a0ebd02 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 5 Sep 2025 21:32:54 +0530
Subject: [PATCH 05/12] up

---
 recipes_source/regional_aot.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 52ae717d1d..d5e586a561 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -113,7 +113,7 @@ def forward(self, x):
 )
 
 ####################################################
-# We can load from this `path` and use it to perform inference.
+# We can load from this ``path`` and use it to perform inference.
 
 compiled_binary = torch._inductor.aoti_load_package(path)
 output_compiled = compiled_binary(input)
 print(f"{output_compiled.shape=}")
@@ -136,7 +136,7 @@ def forward(self, x):
 )
 
 ###################################################
-# An exported program (```torch.export.ExportedProgram```) contains the Tensor computation,
+# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
 # a state_dict containing tensor values of all lifted parameters and buffers alongside
 # other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
 # not serialize the model parameters in the generated artifact.
@@ -168,7 +168,7 @@ def forward(self, x):
 # Next, let's measure the compilation time of the full model and the regional compilation.
 #
-# ```torch.compile``` is a JIT compiler, which means that it compiles on the first invocation.
+# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
 # In the code below, we measure the total time spent in the first invocation. While this method is not
 # precise, it provides a good estimate since the majority of the time is spent in
 # compilation.

From 702b218dc503c447817d222790e888c95c3a0dcf Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 5 Sep 2025 21:34:33 +0530
Subject: [PATCH 06/12] Apply suggestions from code review

Co-authored-by: Svetlana Karslioglu
---
 recipes_source/regional_aot.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index d5e586a561..66d3ae365c 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -5,9 +5,9 @@
 
 **Author:** `Sayak Paul `, `Charles Bensimon `, `Angela Yi `
 
-In our [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed
+In the [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed
 how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for
-just-in-time (JiT) compilation.
+just-in-time (JIT) compilation.
 
 This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you
 are not familiar with AOTInductor and ``torch.export``, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html).
@@ -38,7 +38,7 @@
 # Steps
 # -----
 #
-# In this recipe, we will follow pretty much the same steps as the regional compilation recipe mentioned above:
+# In this recipe, we will follow the same steps as the regional compilation recipe mentioned above:
 #
 # 1. Import all necessary libraries.
 # 2. Define and initialize a neural network with repeated regions.
@@ -137,7 +137,7 @@ def forward(self, x):
 
 ###################################################
 # An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
-# a state_dict containing tensor values of all lifted parameters and buffers alongside
+# a ``state_dict`` containing tensor values of all lifted parameters and buffers alongside
 # other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
 # not serialize the model parameters in the generated artifact.
 #
@@ -156,7 +156,7 @@ def forward(self, x):
 print(f"{output_regional_compiled.shape=}")
 
 #####################################################
-# Just like JiT regional compilation, compiling regions within a model ahead-of-time
+# Just like JIT regional compilation, compiling regions within a model ahead-of-time
 # leads to significantly reduced cold start times. The actual number will vary from
 # model to model.
 #
@@ -234,5 +234,5 @@ def aot_compile_load_model(regional=False) -> torch.nn.Module:
 # -----------
 #
-# This recipe shows how to control the cold start time when compiling your
-# model ahead-of-time.This becomes effective when your model has repeated
-# blocks, like typically seen in large generative models.
+# This recipe shows how to control the cold start time when compiling your
+# model ahead-of-time. This becomes effective when your model has repeated
+# blocks, which is typically seen in large generative models.

From 551c6fda4a4c41fbb470c4044c29c13a60306d7c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 6 Sep 2025 07:36:23 +0530
Subject: [PATCH 07/12] headings and subheadings.

---
 recipes_source/regional_aot.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 66d3ae365c..50d23b27dc 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -32,8 +32,6 @@
     This feature is available starting with the 2.6 release.
 """
 
-from time import perf_counter
-
 ######################################################################
 # Steps
 # -----
@@ -46,13 +44,16 @@
 #
 # First, let's import the necessary libraries for loading our data:
 #
-#
 
 import torch
 torch.set_grad_enabled(False)
 
+from time import perf_counter
 
-##########################################################
+###################################################################################
+# Defining the Neural Network
+# ---------------------------
+#
 # We will use the same neural network structure as the regional compilation recipe.
 #
 # We will use a network, composed of repeated layers. This mimics a
@@ -92,7 +93,10 @@ def forward(self, x):
     return x
 
 
-####################################################
+##################################################################################
+# Compiling the model ahead-of-time
+# ---------------------------------
+#
 # Since we're compiling the model ahead-of-time, we need to prepare representative
 # input examples, that we expect the model to see during actual deployments.
 #
@@ -104,7 +108,7 @@ def forward(self, x):
 output = model(input)
 print(f"{output.shape=}")
 
-####################################################
+###############################################################################################
 # Now, let's compile our model ahead-of-time. We will use ``input`` created above to pass
 # to ``torch.export``. This will yield a ``torch.export.ExportedProgram`` which we can compile.
@@ -112,14 +116,17 @@ def forward(self, x):
 
 path = torch._inductor.aoti_compile_and_package(
     torch.export.export(model, args=(input,))
 )
 
-####################################################
+#################################################################
 # We can load from this ``path`` and use it to perform inference.
 
 compiled_binary = torch._inductor.aoti_load_package(path)
 output_compiled = compiled_binary(input)
 print(f"{output_compiled.shape=}")
 
-###################################################
+######################################################################################
+# Compiling _regions_ of the model ahead-of-time
+# ----------------------------------------------
+#
 # Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
 #
 # Since the compute pattern is shared by all the blocks that
 # are repeated in a model (``Layer`` instances in this case), we can just
 # compile a single block and let the inductor reuse it.
@@ -166,13 +173,10 @@ def forward(self, x):
 # reducing the cold start times.
 
 ###################################################
+# Measuring compilation time
+# --------------------------
 # Next, let's measure the compilation time of the full model and the regional compilation.
 #
-# ``torch.compile`` is a JIT compiler, which means that it compiles on the first invocation.
-# In the code below, we measure the total time spent in the first invocation. While this method is not
-# precise, it provides a good estimate since the majority of the time is spent in
-# compilation.
-
 
 def measure_compile_time(input, regional=False):
     start = perf_counter()

From 18af854d0f64c4d54f02f2f764a69f499ca8ab98 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 6 Sep 2025 07:38:57 +0530
Subject: [PATCH 08/12] remove pt 2.6 note

---
 recipes_source/regional_aot.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 50d23b27dc..4f1302871f 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -27,9 +27,6 @@
 .. code-block:: sh
 
     pip install torch
-
-.. note::
-    This feature is available starting with the 2.6 release.
 """
 
 ######################################################################

From 6e3c76309ed40ada1b91c71b5ec7c044eedf7b61 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Sat, 6 Sep 2025 07:39:21 +0530
Subject: [PATCH 09/12] Update recipes_source/regional_aot.py

Co-authored-by: Svetlana Karslioglu
---
 recipes_source/regional_aot.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 4f1302871f..9efcf93b15 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -3,7 +3,7 @@
 Reducing AoT cold start compilation time with regional compilation
 ============================================================================
 
-**Author:** `Sayak Paul `, `Charles Bensimon `, `Angela Yi `
+**Author:** `Sayak Paul `_, `Charles Bensimon `_, `Angela Yi `_
 
 In the [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed
 how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for
 just-in-time (JIT) compilation.

From 767fd3ff09c595fb7bbc66ec0978b6d7c78c7404 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 6 Sep 2025 07:41:34 +0530
Subject: [PATCH 10/12] add entry to card

---
 recipes_index.rst | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/recipes_index.rst b/recipes_index.rst
index 53239633b6..66db2d7e43 100644
--- a/recipes_index.rst
+++ b/recipes_index.rst
@@ -333,6 +333,13 @@ from our full-length tutorials.
    :link: recipes/distributed_comm_debug_mode.html
    :tags: Distributed-Training
 
+.. customcarditem::
+   :header: Reducing AoT cold start compilation time with regional compilation
+   :card_description: Learn how to use regional compilation to control AoT cold start compile time
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: recipes/regional_aot.html
+   :tags: Model-Optimization
+
 .. End of tutorial card section
 .. -----------------------------------------
 
@@ -378,6 +385,7 @@ from our full-length tutorials.
    recipes/torch_compile_caching_tutorial
    recipes/torch_compile_caching_configuration_tutorial
    recipes/regional_compilation
+   recipes/regional_aot
    recipes/intel_extension_for_pytorch.html
    recipes/intel_neural_compressor_for_pytorch
    recipes/distributed_device_mesh

From ad48868c53a6f56056e5694afabbe76df931d351 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 6 Sep 2025 07:46:26 +0530
Subject: [PATCH 11/12] lint

---
 recipes_source/regional_aot.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py
index 9efcf93b15..3cd9659558 100644
--- a/recipes_source/regional_aot.py
+++ b/recipes_source/regional_aot.py
@@ -50,7 +50,7 @@
 ###################################################################################
 # Defining the Neural Network
 # ---------------------------
-# 
+#
 # We will use the same neural network structure as the regional compilation recipe.
 #
 # We will use a network, composed of repeated layers. This mimics a
@@ -93,12 +93,12 @@ def forward(self, x):
 ##################################################################################
 # Compiling the model ahead-of-time
 # ---------------------------------
-# 
+#
 # Since we're compiling the model ahead-of-time, we need to prepare representative
 # input examples, that we expect the model to see during actual deployments.
-# 
+#
 # Let's create an instance of ``Model`` and pass it some sample input data.
-# 
+#
 
 model = Model().cuda()
 input = torch.randn(10, 10, device="cuda")
@@ -123,7 +123,7 @@ def forward(self, x):
 ######################################################################################
 # Compiling _regions_ of the model ahead-of-time
 # ----------------------------------------------
-# 
+#
 # Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
 #
 # Since the compute pattern is shared by all the blocks that
@@ -141,13 +141,13 @@ def forward(self, x):
 
 ###################################################
 # An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
-# a ``state_dict`` containing tensor values of all lifted parameters and buffers alongside 
+# a ``state_dict`` containing tensor values of all lifted parameters and buffers alongside
 # other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
 # not serialize the model parameters in the generated artifact.
 #
 # Now, when loading the compiled binary, we can reuse the existing parameters of
 # each block. This lets us take advantage of the compiled binary obtained above.
-# +# for layer in model.layers: compiled_layer = torch._inductor.aoti_load_package(path) @@ -187,17 +187,17 @@ def measure_compile_time(input, regional=False): def aot_compile_load_model(regional=False) -> torch.nn.Module: input = torch.randn(10, 10, device="cuda") model = Model().cuda() - + inductor_configs = {} if regional: inductor_configs = {"aot_inductor.package_constants_in_so": False} - + # Reset the compiler caches to ensure no reuse between different runs torch.compiler.reset() with torch._inductor.utils.fresh_inductor_cache(): path = torch._inductor.aoti_compile_and_package( torch.export.export( - model.layers[0] if regional else model, + model.layers[0] if regional else model, args=(input,) ), inductor_configs=inductor_configs, @@ -224,16 +224,16 @@ def aot_compile_load_model(regional=False) -> torch.nn.Module: assert regional_compilation_latency < full_model_compilation_latency ############################################################################ -# There may also be layers in a model incompatible with compilation. So, +# There may also be layers in a model incompatible with compilation. So, # full compilation will result in a fragmented computation graph resulting # in potential latency degradation. In these case, regional compilation # can be beneficial. -# +# ############################################################################ # Conclusion # ----------- # -# This recipe shows how to control the cold start time when compiling your +# This recipe shows how to control the cold start time when compiling your # model ahead-of-time. This becomes effective when your model has repeated # blocks, which is typically seen in large generative models. From 3db167f0e75ec1d90f9e49f93d6154efcd25bcc8 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 9 Sep 2025 07:47:49 +0530 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: Svetlana Karslioglu --- recipes_source/regional_aot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipes_source/regional_aot.py b/recipes_source/regional_aot.py index 3cd9659558..cba082519e 100644 --- a/recipes_source/regional_aot.py +++ b/recipes_source/regional_aot.py @@ -5,12 +5,12 @@ **Author:** `Sayak Paul `_, `Charles Bensimon `_, `Angela Yi `_ -In the [regional compilation recipe](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), we showed +In the `regional compilation recipe `__, we showed how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for just-in-time (JIT) compilation. This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you -are not familiar with AOTInductor and ``torch.export``, we recommend you to check out [this tutorial](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html). +are not familiar with AOTInductor and ``torch.export``, we recommend you to check out `this tutorial `__. Prerequisites ----------------