2 changes: 1 addition & 1 deletion test/export/test_converter.py
@@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841
)
# qnnpack not supported on s390x
@xfailIfS390X
def test_ts2ep_convert_quantized_model(self):
def test_ts2ep_convert_quantized_model1(self):
class Standalone(torch.nn.Module):
def __init__(self):
super().__init__()
15 changes: 6 additions & 9 deletions test/export/test_passes.py
@@ -640,16 +640,13 @@ def forward(self, x):
self.assertExpectedInline(
without_token_ep.graph_module.code.strip(),
"""\
def forward(self, token, obj_attr, x):
with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
getitem = with_effects[0]
getitem_1 = with_effects[1]
getitem_2 = with_effects[2]; with_effects = None
def forward(self, obj_attr, x):
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
getitem_3 = with_effects_1[0]
getitem_4 = with_effects_1[1]; with_effects_1 = None
return (getitem_3, getitem_4)""", # noqa: B950
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None
return (takes_foo_default,)""", # noqa: B950
)

def test_fakify_script_objects(self):
12 changes: 7 additions & 5 deletions test/export/test_torchbind.py
@@ -461,9 +461,9 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None
return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
)
self.assertExpectedInline(
@@ -1087,10 +1087,12 @@ def forward(self, token, tq, x):
str(ep.graph_module.graph).strip(),
"""\
graph():
%token : [num_users=1] = placeholder[target=token]
%tq : [num_users=2] = placeholder[target=tq]
%x : [num_users=1] = placeholder[target=x]
%queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {})
return (tq,)""", # noqa: B950
%with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
return (getitem, tq)""", # noqa: B950
)

def test_deepcopy(self):
98 changes: 98 additions & 0 deletions test/higher_order_ops/test_with_effects.py
@@ -870,6 +870,104 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
finally:
handle.destroy()

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_export_invoke_subgraph(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
recorded_list = []

@torch.library.custom_op("mylib::record_memory", mutates_args=())
def record_memory(prefix: str, module_name: str) -> None:
torch.cuda.synchronize()
mem_alloc = torch.cuda.memory_allocated() / 1024**2
mem_reserved = torch.cuda.memory_reserved() / 1024**2
memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB"
recorded_list.append(memory_str)

@record_memory.register_fake
def record_memory_fake(prefix, module_name):
return

record_memory.register_effect(_EffectType.ORDERED)

class N(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(1024, 1024)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(1024, 1024)

@torch.compiler.nested_compile_region
def forward(self, x):
torch.ops.mylib.record_memory("forward", "N")
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod_list = torch.nn.ModuleList(N() for _ in range(3))

def forward(self, x):
for m in self.mod_list:
x = m(x)
torch.ops.mylib.record_memory("forward", "N")
return (x,)

model = M().to("cuda")
torch.cuda.reset_peak_memory_stats()

x = torch.randn(32, 1024, requires_grad=True, device="cuda")

ep = torch.export.export(model, (x,))
ep = ep.run_decompositions()
self.assertEqual(len(list(ep.graph_module.named_modules())), 2)

self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None
getitem = invoke_subgraph[0]
getitem_1 = invoke_subgraph[1]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None
getitem_2 = invoke_subgraph_1[0]
getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None
repeated_subgraph0_2 = self.repeated_subgraph0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None
getitem_4 = invoke_subgraph_2[0]
getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None
with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None
getitem_6 = with_effects[0]; with_effects = None
return (getitem_6, getitem_5)""",
)

self.assertExpectedInline(
ep.graph_module.repeated_subgraph0.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None
getitem = with_effects[0]; with_effects = None
permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None
addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None
addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None
return (getitem, addmm_1)""",
)

recorded_list.clear()
# TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph
# eagerly twice. Once for get_output_metadata and then once for
# InvokeSubgraphAutogradOp. This causes record_memory to be called twice.
with torch.no_grad():
out2 = ep.module()(x)
self.assertEqual(len(recorded_list), 4)
self.assertTrue(torch.allclose(model(x)[0], out2[0]))


if __name__ == "__main__":
run_tests()
18 changes: 18 additions & 0 deletions torch/_guards.py
@@ -713,6 +713,9 @@ def __init__(self) -> None:
self.lazy_bwd_cache: dict[
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
] = defaultdict(dict)
self.effects_cache: dict[
str, set
] = {} # Maps identifier -> set of effect types

def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
self.dynamo_installed_submodules[fn_id].append(identifier)
@@ -751,6 +754,21 @@ def get_lazy_bwd_entry(

return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None))

def add_effects(self, identifier: str, effects: set) -> None:
"""Store the effect types for a given invoke_subgraph identifier."""
if prev_effects := self.effects_cache.get(identifier, None):
assert effects == prev_effects, (
"Different number of effects were found for invoke_subgraph "
f"call with identifier {identifier}. \n"
f"Previously we had the following effects: {prev_effects}.\n"
f"But now we have: {effects}."
)
self.effects_cache[identifier] = effects

def get_effects(self, identifier: str) -> Optional[set]:
"""Retrieve the effect types for a given invoke_subgraph identifier."""
return self.effects_cache.get(identifier, None)


class HopDispatchSetCache:
def __init__(self) -> None:
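As a rough usage note for the two methods added above: the metadata pass records which effects a subgraph identifier produced, and later passes look the set up again so a token input can be threaded in. A minimal, self-contained sketch (the dictionary and helper names below are stand-ins, not the cache object patched in this file):

effects_cache: dict[str, set] = {}  # identifier -> set of effect types

def add_effects(identifier: str, effects: set) -> None:
    # Re-registering the same identifier is only allowed with an identical set.
    prev = effects_cache.get(identifier)
    if prev:
        assert effects == prev, f"effects changed for {identifier}: {prev} -> {effects}"
    effects_cache[identifier] = effects

def get_effects(identifier: str):
    return effects_cache.get(identifier)

# First (metadata) trace of subgraph_0 discovers one ordered effect and records it.
add_effects("subgraph_0", {"ORDERED"})
# A later trace of the same identifier reuses the recorded set.
assert get_effects("subgraph_0") == {"ORDERED"}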
50 changes: 50 additions & 0 deletions torch/_higher_order_ops/invoke_subgraph.py
@@ -80,6 +80,7 @@ def __call__(
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator))
for o in operands
if o is not None
Contributor commented: when do you see None as input?

@angelayi (Contributor, Author) replied on Nov 18, 2025: The effect tokens are passed in as None here since we will eventually just discard these inputs. (An illustrative sketch follows this hunk.)

), (
f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}"
)
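For context on the review thread above: a minimal, illustrative sketch (not the actual HOP machinery; the function name below is made up) of why None can appear among the operands. The effect-token slot is threaded through the call but holds only a placeholder that is discarded later, so the type check simply skips it.

import torch

# Toy stand-in for the relaxed operand check above: a leading token slot may
# be None because the real effect token is discarded before the subgraph runs.
def validate_operands(*operands):
    assert all(
        isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator))
        for o in operands
        if o is not None  # skip the discarded effect-token placeholder
    ), f"invoke_subgraph operands must be tensors/ints/SymInts/Generators, got {operands}"

token_placeholder = None  # effect token not materialized at this stage
validate_operands(token_placeholder, torch.randn(2), 3)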
@@ -562,7 +563,34 @@ def _(ctx, subgraph, identifier, *operands):
do_auto_functionalize_v2,
)

# (in the functionalization metadata phase) Capture tokens before
tokens_before = dict(ctx.mode._tokens)

# Check if this subgraph has effects stored in the cache
invoke_subgraph_cache = get_invoke_subgraph_cache()
effects = None
if invoke_subgraph_cache:
effects = invoke_subgraph_cache.get_effects(identifier)

if effects:
assert len(effects) == 1, "Multiple effects within a subgraph NYI"
tokens = ctx.mode._tokens
effects = next(iter(effects))
token_input = tokens[effects]

operands = (token_input, *operands)

def wrap_subgraph(subgraph):
def wrapped_subgraph(token, *args):
res = subgraph(*args)
return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res

return wrapped_subgraph

subgraph = wrap_subgraph(subgraph)

unwrapped_operands = ctx.unwrap_tensors(operands)

hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands)
if can_auto_functionalize(hop_instance):
# NOTE: [auto_functionalize x invoke_subgraph caching]
@@ -587,6 +615,28 @@ def _(ctx, subgraph, identifier, *operands):
# of invoke_subgraph ops if input aliasing/mutation is detected.
functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)

if effects:
(new_token, *out) = out
ctx.mode._tokens[effects] = new_token

# (in the functionalization metadata phase) Capture tokens after and see if
# there are any differences (there are new effects or the token value for an
# effect type has changed)
tokens_after = dict(ctx.mode._tokens)
discovered_effects = set()
for effect_type, token in tokens_after.items():
if effect_type not in tokens_before or tokens_before[effect_type] is not token:
discovered_effects.add(effect_type)

if discovered_effects:
assert ctx.mode._allow_token_discovery, (
f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}."
)
# Store discovered effects in the cache by identifier
if invoke_subgraph_cache:
invoke_subgraph_cache.add_effects(identifier, discovered_effects)

return ctx.wrap_tensors(out)


15 changes: 15 additions & 0 deletions torch/_library/effects.py
@@ -35,13 +35,28 @@ def _set_default_effect(self) -> None:
if namespace == "higher_order":
return

        # These classes do not have side effects as they just store quantization
        # params, so we don't need to mark them as ordered
skip_classes = (
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase",
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase",
"__torch__.torch.classes.quantized.EmbeddingPackedParamsBase",
"__torch__.torch.classes.quantized.LinearPackedParamsBase",
"__torch__.torch.classes.xnnpack.Conv2dOpContext",
"__torch__.torch.classes.xnnpack.LinearOpContext",
"__torch__.torch.classes.xnnpack.TransposeConv2dOpContext",
)

opname = f"{namespace}::{opname}"
if torch._C._get_operation_overload(opname, overload) is not None:
# Since we call this when destroying the library, sometimes the
# schema will be gone already at that time.
schema = torch._C._get_schema(opname, overload)
for arg in schema.arguments:
if isinstance(arg.type, torch.ClassType):
type_str = arg.type.str() # pyrefly: ignore[missing-attribute]
if type_str in skip_classes:
continue
self._effect = EffectType.ORDERED
return
