Skip to content

Commit

Permalink
Fix node provenance tracking
Browse files Browse the repository at this point in the history
ghstack-source-id: 5a6f27f09248ec72e913daedbb5b3680d66883b3
Pull Request resolved: #95901
  • Loading branch information
Chillee committed Mar 3, 2023
1 parent 4b5ad34 commit 8c7ca85
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 39 deletions.
77 changes: 40 additions & 37 deletions torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,43 +307,42 @@ def placeholder(self, target: str, args, kwargs):
return tensor

def call_function(self, target, args, kwargs):
with ir.IRNode.current_origins(gather_origins(args, kwargs)):
if target is operator.getitem and isinstance(args[0], (list, tuple)):
return super().call_function(target, args, kwargs)

if hasattr(target, "_inductor_lowering_function"):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)

if target not in lowerings:
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target)
elif config.implicit_fallbacks:
error = (
MissingOperatorWithDecomp
if get_decompositions([target])
else MissingOperatorWithoutDecomp
)
log.info(
"Creating implicit fallback for:\n%s",
error.operator_str(target, args, kwargs),
)
make_fallback(target)
elif get_decompositions([target]):
# There isn't a good way to dynamically patch this in
# since AOT Autograd already ran. The error message tells
# the user how to fix it.
raise MissingOperatorWithDecomp(target, args, kwargs)
else:
raise MissingOperatorWithoutDecomp(target, args, kwargs)
if target is operator.getitem and isinstance(args[0], (list, tuple)):
return super().call_function(target, args, kwargs)

if hasattr(target, "_inductor_lowering_function"):
# passthrough lowerings from .pattern_matcher
return target(*args, **kwargs)

if target not in lowerings:
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target)
elif config.implicit_fallbacks:
error = (
MissingOperatorWithDecomp
if get_decompositions([target])
else MissingOperatorWithoutDecomp
)
log.info(
"Creating implicit fallback for:\n%s",
error.operator_str(target, args, kwargs),
)
make_fallback(target)
elif get_decompositions([target]):
# There isn't a good way to dynamically patch this in
# since AOT Autograd already ran. The error message tells
# the user how to fix it.
raise MissingOperatorWithDecomp(target, args, kwargs)
else:
raise MissingOperatorWithoutDecomp(target, args, kwargs)

try:
out = lowerings[target](*args, **kwargs)
return out
except Exception as e:
log.exception("Error from lowering")
raise LoweringException(e, target, args, kwargs) from e
try:
out = lowerings[target](*args, **kwargs)
return out
except Exception as e:
log.exception("Error from lowering")
raise LoweringException(e, target, args, kwargs) from e

def get_attr(self, target, args, kwargs):
# this is a constant
Expand Down Expand Up @@ -407,7 +406,11 @@ def finalize(self):
buf.decide_layout()

def run_node(self, n: torch.fx.Node):
with ir.IRNode.current_origins({n}):
origins = {n}
if n.op == "call_function":
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins):
if n.op == "call_function" and n.target in layout_constraints:
args, kwargs = self.fetch_args_kwargs_from_env(n)
args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
Expand Down
8 changes: 6 additions & 2 deletions torch/_inductor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,14 @@ def get_fused_kernel_name(node_schedule):
def gather_origins(args, kwargs):
import itertools

from .ir import ComputedBuffer, IRNode
from . import ir

def is_unrealized_node(n):
return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer)
if isinstance(n, ir.TensorBox):
return is_unrealized_node(n.data)
if isinstance(n, ir.StorageBox):
return is_unrealized_node(n.data)
return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
Expand Down

0 comments on commit 8c7ca85

Please sign in to comment.