Fix node provenance tracking (pytorch#95901)
Before:
```
triton_fused_add_83_add_84_convolution_15_relu_12_relu_13_squeeze_46_var_mean_15_14
```

After:
```
triton_fused_add_83_add_84_relu_13_squeeze_46_var_mean_15_14
```
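
The dropped names (`convolution_15`, `relu_12`) came from inputs that were already realized as separate buffers; after this change, only origins of nodes that are actually inlined into the fused kernel contribute. For intuition, the fused kernel name is assembled from the origin FX nodes recorded on the scheduled buffers, roughly like the sketch below (hypothetical names, not the actual `get_fused_kernel_name` implementation):

```python
# Illustrative sketch only: join the names of the origin FX nodes recorded on
# the buffers being fused. Any stray origin (e.g. from an input that is its
# own kernel, like a convolution) directly inflates the name.
def fused_kernel_name(node_schedule):
    origin_names = sorted(
        {origin.name for node in node_schedule for origin in node.origins}
    )
    return "triton_fused_" + "_".join(origin_names)
```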

For this kernel:
```
@persistent_reduction(
    size_hints=[512, 64],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: '*fp32', 9: '*fp32', 10: 'i32', 11: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['in_out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), equal_to_1=())]}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr2, out_ptr3, out_ptr4, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 49
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (49*x0)), rmask & xmask, other=0)
    tmp8 = tl.load(in_ptr1 + (x0), xmask)
    tmp22 = tl.load(in_ptr2 + (x0), xmask)
    tmp24 = tl.load(in_ptr3 + (x0), xmask)
    tmp30 = tl.load(in_ptr4 + (x0), xmask)
    tmp2 = tl.where(rmask & xmask, tmp0, 0)
    tmp3 = tl.sum(tmp2, 1)[:, None]
    tmp4 = 49.0
    tmp5 = tmp3 / tmp4
    tmp6 = 0.1
    tmp7 = tmp5 * tmp6
    tmp9 = 0.9
    tmp10 = tmp8 * tmp9
    tmp11 = tmp7 + tmp10
    tmp12 = tmp0 - tmp5
    tmp13 = tmp12 * tmp12
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None]
    tmp17 = tmp16 / tmp4
    tmp18 = 1e-05
    tmp19 = tmp17 + tmp18
    tmp20 = tl.libdevice.rsqrt(tmp19)
    tmp21 = tmp12 * tmp20
    tmp23 = tmp21 * tmp22
    tmp25 = tmp23 + tmp24
    tmp26 = tl.where(0 != 0, 0, tl.where(0 > tmp25, 0, tmp25))
    tmp27 = 1.0208333333333333
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28 * tmp6
    tmp31 = tmp30 * tmp9
    tmp32 = tmp29 + tmp31
    tl.store(in_out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp5, xmask)
    tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp11, xmask)
    tl.store(out_ptr2 + (r1 + (49*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp26, rmask & xmask)
    tl.store(out_ptr3 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp20, xmask)
    tl.store(out_ptr4 + (x0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp32, xmask)
```
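
A kernel of this shape comes from a conv + batch norm + relu block in training mode, where Inductor fuses the batch-norm statistics (`var_mean`, the running-stat updates) and the relu into one persistent reduction. A minimal, hypothetical repro (module and shapes are illustrative, not the original benchmark; requires a CUDA GPU):

```python
import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 512, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(512)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

model = Block().cuda()  # training mode by default, so running stats are updated
compiled = torch.compile(model)
# Running with TORCH_COMPILE_DEBUG=1 dumps the generated Triton kernels (and
# their fused names) under torch_compile_debug/ for inspection.
out = compiled(torch.randn(1, 64, 7, 7, device="cuda"))
```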

Tbh this still isn't great provenance tracking, since ops like layer norms are decomposed before they reach Inductor. I might add some extra provenance tracking during decompositions.
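
For context, a normalization op after decomposition is just a handful of `var_mean` and pointwise nodes, so that is all provenance can name. A hand-written sketch of the idea (the real decompositions live in `torch._decomp`; this is not the exact code):

```python
import torch

# Roughly what a layer norm decomposes into: once the graph only contains
# var_mean/sub/mul/add/rsqrt nodes, kernel names can only be built from those.
def layer_norm_decomposed(x, weight, bias, eps=1e-5):
    var, mean = torch.var_mean(x, dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    return (x - mean) * rstd * weight + bias
```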

Pull Request resolved: pytorch#95901
Approved by: https://github.com/jansel, https://github.com/mlazos
Chillee authored and ydwu4 committed Mar 10, 2023 (commit ae6e2ee, parent fd4468b)
Showing 2 changed files with 46 additions and 39 deletions.
torch/_inductor/graph.py (40 additions, 37 deletions)

```diff
@@ -319,43 +319,42 @@ def placeholder(self, target: str, args, kwargs):
         return tensor
 
     def call_function(self, target, args, kwargs):
-        with ir.IRNode.current_origins(gather_origins(args, kwargs)):
-            if target is operator.getitem and isinstance(args[0], (list, tuple)):
-                return super().call_function(target, args, kwargs)
-
-            if hasattr(target, "_inductor_lowering_function"):
-                # passthrough lowerings from .pattern_matcher
-                return target(*args, **kwargs)
-
-            if target not in lowerings:
-                base_name = target.name().split(".")[0]
-                if base_name in FALLBACK_ALLOW_LIST:
-                    make_fallback(target)
-                elif config.implicit_fallbacks:
-                    error = (
-                        MissingOperatorWithDecomp
-                        if get_decompositions([target])
-                        else MissingOperatorWithoutDecomp
-                    )
-                    log.info(
-                        "Creating implicit fallback for:\n%s",
-                        error.operator_str(target, args, kwargs),
-                    )
-                    make_fallback(target)
-                elif get_decompositions([target]):
-                    # There isn't a good way to dynamically patch this in
-                    # since AOT Autograd already ran. The error message tells
-                    # the user how to fix it.
-                    raise MissingOperatorWithDecomp(target, args, kwargs)
-                else:
-                    raise MissingOperatorWithoutDecomp(target, args, kwargs)
+        if target is operator.getitem and isinstance(args[0], (list, tuple)):
+            return super().call_function(target, args, kwargs)
+
+        if hasattr(target, "_inductor_lowering_function"):
+            # passthrough lowerings from .pattern_matcher
+            return target(*args, **kwargs)
+
+        if target not in lowerings:
+            base_name = target.name().split(".")[0]
+            if base_name in FALLBACK_ALLOW_LIST:
+                make_fallback(target)
+            elif config.implicit_fallbacks:
+                error = (
+                    MissingOperatorWithDecomp
+                    if get_decompositions([target])
+                    else MissingOperatorWithoutDecomp
+                )
+                log.info(
+                    "Creating implicit fallback for:\n%s",
+                    error.operator_str(target, args, kwargs),
+                )
+                make_fallback(target)
+            elif get_decompositions([target]):
+                # There isn't a good way to dynamically patch this in
+                # since AOT Autograd already ran. The error message tells
+                # the user how to fix it.
+                raise MissingOperatorWithDecomp(target, args, kwargs)
+            else:
+                raise MissingOperatorWithoutDecomp(target, args, kwargs)
 
-            try:
-                out = lowerings[target](*args, **kwargs)
-                return out
-            except Exception as e:
-                log.exception("Error from lowering")
-                raise LoweringException(e, target, args, kwargs) from e
+        try:
+            out = lowerings[target](*args, **kwargs)
+            return out
+        except Exception as e:
+            log.exception("Error from lowering")
+            raise LoweringException(e, target, args, kwargs) from e
 
     def get_attr(self, target, args, kwargs):
         # this is a constant
@@ -422,7 +421,11 @@ def finalize(self):
             buf.decide_layout()
 
     def run_node(self, n: torch.fx.Node):
-        with ir.IRNode.current_origins({n}):
+        origins = {n}
+        if n.op == "call_function":
+            args, kwargs = self.fetch_args_kwargs_from_env(n)
+            origins |= gather_origins(args, kwargs)
+        with ir.IRNode.current_origins(origins):
             if n.op == "call_function" and n.target in layout_constraints:
                 args, kwargs = self.fetch_args_kwargs_from_env(n)
                 args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
```
torch/_inductor/utils.py (6 additions, 2 deletions)

```diff
@@ -256,10 +256,14 @@ def get_fused_kernel_name(node_schedule):
 def gather_origins(args, kwargs):
     import itertools
 
-    from .ir import ComputedBuffer, IRNode
+    from . import ir
 
     def is_unrealized_node(n):
-        return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer)
+        if isinstance(n, ir.TensorBox):
+            return is_unrealized_node(n.data)
+        if isinstance(n, ir.StorageBox):
+            return is_unrealized_node(n.data)
+        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)
 
     kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
     arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
```
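
Putting the `utils.py` change together, the updated helper reads roughly as follows (reconstructed from the hunk above; the final set-union line is not shown in the diff and is an assumption):

```python
def gather_origins(args, kwargs):
    import itertools

    from . import ir

    def is_unrealized_node(n):
        # Unwrap TensorBox/StorageBox and only count nodes that are still
        # unrealized Pointwise ops, i.e. ones that will be inlined into the
        # kernel currently being lowered. Realized buffers (e.g. a convolution
        # output) no longer leak their origins into the fused kernel name.
        if isinstance(n, ir.TensorBox):
            return is_unrealized_node(n.data)
        if isinstance(n, ir.StorageBox):
            return is_unrealized_node(n.data)
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    # Assumed final line: union all gathered origin sets.
    return set(itertools.chain(*arg_origins, *kwarg_origins))
```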
