Skip to content

Commit

Permalink
Revert "Make torch_geometric models compatible with export (#123403)"
Browse files Browse the repository at this point in the history
This reverts commit 2ffab6e.

Reverted #123403 on behalf of https://github.com/atalman due to Related issue basic_gnn_gin ([comment](#123403 (comment)))
  • Loading branch information
pytorchmergebot committed Apr 5, 2024
1 parent 5b0ce8f commit 8c7d8f0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 20 deletions.
6 changes: 3 additions & 3 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,12 +1135,12 @@ def load(cls, model, example_inputs, device):
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs)

gm = torch.export._trace._export(
# TODO(angelayi): change this to predispatch
gm = torch.export._trace._export_to_torch_ir(
model,
example_args,
example_kwargs,
pre_dispatch=True,
).module()
)
with torch.no_grad():
so_path = torch._inductor.aot_compile(
gm, example_args, example_kwargs
Expand Down
17 changes: 0 additions & 17 deletions benchmarks/dynamo/torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,6 @@
torch.backends.cuda.matmul.allow_tf32 = True


def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
# Since it is unusual thing to do, we just reassign them to parameters
def state_dict_hook(module, destination, prefix, local_metadata):
for name, param in module.named_parameters():
if isinstance(destination[name], torch.Tensor) and not isinstance(
destination[name], torch.nn.Parameter
):
destination[name] = torch.nn.Parameter(destination[name])

model._register_state_dict_hook(state_dict_hook)


def setup_torchbench_cwd():
original_dir = abspath(os.getcwd())

Expand Down Expand Up @@ -279,9 +265,6 @@ def load_model(
extra_args=extra_args,
)
model, example_inputs = benchmark.get_module()
if model_name in ["basic_gnn_edgecnn", "basic_gnn_gcn", "basic_gnn_sage"]:
_reassign_parameters(model)

# Models that must be in train mode while training
if is_training and (
not use_eval_mode or model_name in self._config["only_training"]
Expand Down

0 comments on commit 8c7d8f0

Please sign in to comment.