Skip to content

Commit

Permalink
[reland][aotinductor] Add example_value metadata to nodes (#113986)
Browse files Browse the repository at this point in the history
Test Plan:
`TORCH_LOGS=dynamo,inductor,aot  CUDA_VISIBLE_DEVICES=7 TORCH_COMPILE_DEBUG=0 TORCHINDUCTOR_MAX_AUTOTUNE=1 buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010  caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --local-model /tmp/409501788/66/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend="AOT_INDUCTOR"`

Without passes:
`BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.32, Time per iter: 2.22ms, Threads: 1, QPS: 921146.83, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 66.15s`

With passes:
`BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.49, Time per iter: 2.21ms, Threads: 1, QPS: 925450.82, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 261.11s`

Differential Revision: D51436878

Pull Request resolved: #113986
Approved by: https://github.com/zhxchen17
  • Loading branch information
angelayi authored and pytorchmergebot committed Nov 19, 2023
1 parent 33c6cae commit 72a8329
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch._inductor
import torch.fx._pytree as fx_pytree
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.utils import aot_inductor_launcher, cache_dir
Expand Down Expand Up @@ -318,6 +319,21 @@ def forward(self, x, y):
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

def test_simple_split(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

example_inputs = (torch.randn(2, 8, device=self.device),)
counters.clear()
self.check_model(Model(), example_inputs)
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)

def test_missing_output(self):
class Model(torch.nn.Module):
def __init__(self):
Expand Down
6 changes: 6 additions & 0 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,8 @@ def placeholder(self, target, args, kwargs):
arg.node.meta["val"] = self.current_node.meta["val"]
if "tensor_dict" in self.current_node.meta:
arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
if "example_value" in self.current_node.meta:
arg.node.meta["example_value"] = self.current_node.meta["example_value"]
return arg

def output(self, target, args, kwargs):
Expand All @@ -925,6 +927,10 @@ def run_node(self, n):
result_proxy = super().run_node(n)
if "val" in self.current_node.meta:
result_proxy.node.meta["val"] = self.current_node.meta["val"]
if "example_value" in self.current_node.meta:
result_proxy.node.meta["example_value"] = self.current_node.meta[
"example_value"
]
if self.current_node.op != "output":
result_proxy.node._rename(
getattr(self.current_node, "name", result_proxy.node.name)
Expand Down

0 comments on commit 72a8329

Please sign in to comment.