Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions torch/_dynamo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@
# No longer used
optimize_ddp_lazy_compile = False

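# Guard object aliasing (`a is b`) with a Python lambda guard instead of a
# relational guard, to create more opportunities for the dict-tag optimization.
# NOTE: the option name is spelled "lamba" (sic); use that exact spelling.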
use_lamba_guard_for_object_aliasing = True

# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards = True
# Whether to apply torch._dynamo.disable() to FSDP2 hooks.
Expand Down
39 changes: 34 additions & 5 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def __init__(
str, torch._C._dynamo.guards.GuardManager
] = {}
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
self.object_aliasing_guard_codes: list[tuple[str, str]] = []
self.serialization_mode = serialization_mode
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
"pytorch/compiler:guard_nn_modules"
Expand Down Expand Up @@ -1992,11 +1993,19 @@ def DUPLICATE_INPUT(self, guard, source_b):
code = [f"{ref_b} is {ref_a}"]
self._set_guard_export_info(guard, code)

install_object_aliasing_guard(
self.get_guard_manager(guard),
self.get_guard_manager_from_source(source_b),
get_verbose_code_parts(code, guard),
)
if config.use_lamba_guard_for_object_aliasing:
# Save the code part so that we can install a lambda guard at the
# end. Read the Note - On Lambda guarding of object aliasing - to
# get more information.
code_part = code[0]
verbose_code_part = get_verbose_code_parts(code_part, guard)[0]
self.object_aliasing_guard_codes.append((code_part, verbose_code_part))
else:
install_object_aliasing_guard(
self.get_guard_manager(guard),
self.get_guard_manager_from_source(source_b),
get_verbose_code_parts(code, guard),
)

def WEAKREF_ALIVE(self, guard):
if self.serialization_mode == "save":
Expand Down Expand Up @@ -3261,6 +3270,26 @@ def add_code_part(code_part, guard, log_only=False):
["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"],
)

# Note - On Lambda guarding of object aliasing
# We previously installed object-aliasing guards as relational guards,
# but that undermined the recursive-dict guard optimization: placing the
# aliasing guard at a leaf prevented the parent dict node from
# qualifying as a recursive-dict guard root. Because aliasing guards are
# rare, we now emit them as epilogue guards via a small Python lambda.
# This repeats the access in Python (adding a bit of work), but the
# overhead is outweighed by the gains from enabling the recursive-dict
# guard optimization.
if (
config.use_lamba_guard_for_object_aliasing
and builder.object_aliasing_guard_codes
):
aliasing_code_parts, aliasing_verbose_code_parts = map(
list, zip(*builder.object_aliasing_guard_codes)
)
builder.add_python_lambda_leaf_guard_to_root(
aliasing_code_parts, aliasing_verbose_code_parts
)

aotautograd_guards: list[GuardEnvExpr] = (
self.output_graph.aotautograd_guards if self.output_graph else []
)
Expand Down
Loading