[export] Fix duplicated params for AOTInductor. (#108354)
Summary:

Test Plan:
python benchmarks/dynamo/huggingface.py --bfloat16 --accuracy --inference --device cuda --export --only BertForMaskedLM

python benchmarks/dynamo/huggingface.py --bfloat16 --accuracy --inference --device cuda --export-aot-inductor --only  BertForMaskedLM

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: #108354
Approved by: https://github.com/angelayi, https://github.com/desertfire
zhxchen17 authored and pytorchmergebot committed Sep 1, 2023
1 parent e18f512 commit d96446b
Showing 2 changed files with 30 additions and 10 deletions.
16 changes: 16 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -180,6 +180,22 @@ def forward(self, a):
         actual = AOTInductorModelRunner.run(model, example_inputs, expected)
         self.assertTrue(same(actual, expected))
 
+    def test_duplicated_params(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.p = torch.nn.Parameter(torch.rand(6))
+                self.q = self.p
+
+            def forward(self, x):
+                return self.p * x + self.q
+
+        model = Model()
+        example_inputs = (torch.rand(6),)
+        expected = model(*example_inputs)
+        actual = torch._export.export(model, example_inputs)(*example_inputs)
+        self.assertTrue(same(actual, expected))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
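Note on the new test (an illustration, not part of the diff): assigning an existing Parameter to a second attribute registers the same tensor under two names, and the default named_parameters() deduplicates by identity, so one of the names is silently dropped during export. A minimal standalone sketch of that behavior, assuming a PyTorch build that supports the remove_duplicate keyword:

    # Standalone sketch; Model mirrors the test case added above.
    import torch

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.p = torch.nn.Parameter(torch.rand(6))
            self.q = self.p  # alias: the same Parameter registered under a second name

        def forward(self, x):
            return self.p * x + self.q

    m = Model()
    print([name for name, _ in m.named_parameters()])                        # ['p'] -- alias deduplicated
    print([name for name, _ in m.named_parameters(remove_duplicate=False)])  # ['p', 'q'] -- both names kept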
24 changes: 14 additions & 10 deletions torch/_export/__init__.py
@@ -320,11 +320,11 @@ def export(
                 UserErrorType.ANTI_PATTERN,
                 f"Consider annotating your code using constrain_as_*(). {str(e)}")
 
-    params_buffers: OrderedDict[str, Union[torch.Tensor, torch.nn.Parameter]] = OrderedDict()
-    for name, param in gm_torch_level.named_parameters():
+    params_buffers: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}
+    for name, param in gm_torch_level.named_parameters(remove_duplicate=False):
         params_buffers[name] = param
 
-    for name, buffer in gm_torch_level.named_buffers():
+    for name, buffer in gm_torch_level.named_buffers(remove_duplicate=False):
         params_buffers[name] = buffer
 
     fake_args, fake_kwargs, fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs)
@@ -343,7 +343,7 @@ def export(
     # When aot_export lifts the params, we lose the nn_module_stack
     # and source_fn from the param nodes as they are treated as fresh inputs
     # Therefore, we manually extract them before calling into aot_export
-    params_buffers_to_node_meta = OrderedDict()
+    params_buffers_to_node_meta = {}
     for node in gm_torch_level.graph.nodes:
         target = node.target
         meta = node.meta
@@ -388,19 +388,20 @@ def export(
     )
     gm_torch_level.recompile()
 
-    param_buffer_table = {}
+    param_buffer_table: Dict[str, str] = {}
     if isinstance(f, torch.nn.Module):
         param_lookup = {}
         buffer_lookup = {}
         for name, param in f.named_parameters():
             param_lookup[id(param)] = name
         for name, buffer in f.named_buffers():
             buffer_lookup[id(buffer)] = name
-        for dynamo_name, dynamo_param in gm_torch_level.named_parameters():
+        for dynamo_name, dynamo_param in gm_torch_level.named_parameters(remove_duplicate=False):
             assert dynamo_name not in param_buffer_table
             if id(dynamo_param) in param_lookup:
                 param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)]
-        for dynamo_name, dynamo_buffer in gm_torch_level.named_buffers():
+
+        for dynamo_name, dynamo_buffer in gm_torch_level.named_buffers(remove_duplicate=False):
             assert dynamo_name not in param_buffer_table
             if id(dynamo_buffer) in buffer_lookup:
                 param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)]
@@ -491,8 +492,12 @@ def to_str_dict(sig_component: Dict[Any, Any]):
         flat_args,
     )
 
-    if isinstance(f, torch.nn.Module):
+    # TODO(zhxchen17) Properly handle duplicated buffers.
+    translated_params_buffers = {param_buffer_table.get(name, name): tensor for name, tensor in params_buffers.items()}
+    if isinstance(f, torch.nn.Module) and (len(translated_params_buffers) ==
+            len(export_graph_signature.parameters) + len(export_graph_signature.buffers)):
         _replace_param_buffer_names(param_buffer_table, export_graph_signature)
+        params_buffers = translated_params_buffers
 
     module_call_signatures = {fqn: ModuleCallSignature(inputs=[], outputs=[], **specs) for fqn, specs in module_call_specs.items()}
 
@@ -509,8 +514,7 @@ def to_str_dict(sig_component: Dict[Any, Any]):
         # TODO(zhxchen17) Remove this field.
         CallSpec(in_spec, orig_out_spec),
         # TODO(zhxchen17) Return empty state_dict for functions.
-        {param_buffer_table.get(name, name): tensor for name, tensor in params_buffers.items()}
-        if isinstance(f, torch.nn.Module) else param_buffer_table,
+        params_buffers,
         range_constraints,
         equality_constraints,
         [ModuleCallEntry("", ModuleCallSignature(inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec))] +
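Reading of the torch/_export/__init__.py change (a sketch; the variable names are taken from the diff, the surrounding scaffolding is assumed): param_buffer_table maps dynamo-level names to original module names via id(), so two aliased dynamo names can translate to the same original name and the translated dict can shrink. Renaming is therefore only applied when the translated dict still covers every parameter and buffer in the export graph signature; otherwise the dynamo-level names are kept, per the TODO about duplicated buffers.

    # Illustrative sketch of the added guard; not a verbatim copy of the diff.
    translated = {param_buffer_table.get(name, name): tensor
                  for name, tensor in params_buffers.items()}
    expected = len(export_graph_signature.parameters) + len(export_graph_signature.buffers)
    if isinstance(f, torch.nn.Module) and len(translated) == expected:
        # Every lifted param/buffer maps to a distinct original name, so renaming is safe.
        _replace_param_buffer_names(param_buffer_table, export_graph_signature)
        params_buffers = translated
    # If aliases collapsed some names (len(translated) < expected), keep the dynamo-level names.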
