Update on "Make CI error on inductor fallback when decomp is available"
Fixes #99446 

Remove the warning, as it annoyed end users who don't know what to do about it.

Instead, try to hold the line by preventing any decomp from being added without making
the corresponding change to inductor's fallbacks.

Note: we probably still need to better document how to update inductor's decomps;
for now it's pretty much "go ask the inductor team for advice".
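For illustration only, a rough sketch of the kind of guard described above (assumed signature, not the code in this stack), placed where inductor registers its fallback lowerings:

# hypothetical sketch: refuse to register a fallback for an op that already
# has a decomposition, unless the op is explicitly allow-listed
from torch._decomp import get_decompositions

def make_fallback(op, allow_list=()):
    if op not in allow_list and get_decompositions([op]):
        raise AssertionError(
            f"make_fallback({op}): a decomposition exists for this op; "
            "update inductor's decomp/fallback lists instead of silently falling back"
        )
    ...  # register the fallback lowering as before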

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
wconstab committed Apr 19, 2023
2 parents 00566c1 + 6914d17 commit c2bb55a
Showing 20 changed files with 202 additions and 57 deletions.
@@ -28,6 +28,7 @@ mnasnet1_0,pass,0
mobilenet_v2,pass,0
mobilenet_v3_large,pass,0
nvidia_deeprecommender,pass,0
opacus_cifar10,pass,44
phlippe_densenet,pass,0
phlippe_resnet,pass,0
pyhpc_isoneutral_mixing,pass,0
@@ -23,6 +23,7 @@ maml_omniglot,pass,9
mnasnet1_0,pass,9
mobilenet_v2,pass,9
nvidia_deeprecommender,pass,9
opacus_cifar10,pass,44
phlippe_densenet,pass,9
phlippe_resnet,pass,9
pytorch_CycleGAN_and_pix2pix,pass,9
@@ -28,6 +28,7 @@ mnasnet1_0,pass,0
mobilenet_v2,pass,0
mobilenet_v3_large,pass,0
nvidia_deeprecommender,pass,0
opacus_cifar10,pass,44
phlippe_densenet,pass,0
phlippe_resnet,pass,0
pyhpc_isoneutral_mixing,pass,0
@@ -23,6 +23,7 @@ maml_omniglot,pass,9
mnasnet1_0,pass,9
mobilenet_v2,pass,9
nvidia_deeprecommender,pass,9
opacus_cifar10,pass,44
phlippe_densenet,pass,9
phlippe_resnet,pass,9
pytorch_CycleGAN_and_pix2pix,pass,9
1 change: 1 addition & 0 deletions benchmarks/dynamo/common.py
@@ -211,6 +211,7 @@ class CI(NamedTuple):
"cait_m36_384", # Accuracy
"pnasnet5large", # OOM
"xcit_large_24_p8_224", # OOM https://github.com/pytorch/pytorch/issues/95984
"opacus_cifar10", # Fails to run https://github.com/pytorch/pytorch/issues/99201
]

CI_SKIP[CI("inductor", training=True)] = [
4 changes: 0 additions & 4 deletions benchmarks/dynamo/torchbench.py
@@ -73,8 +73,6 @@ def setup_torchbench_cwd():
"detectron2_maskrcnn",
# https://github.com/pytorch/torchdynamo/issues/145
"fambench_xlmr",
# https://github.com/pytorch/pytorch/issues/99201
"opacus_cifar10",
# TIMEOUT, https://github.com/pytorch/pytorch/issues/98467
"tacotron2",
# https://github.com/pytorch/pytorch/issues/99438
@@ -93,8 +91,6 @@ def setup_torchbench_cwd():
"pyhpc_equation_of_state",
"pyhpc_isoneutral_mixing",
"pyhpc_turbulent_kinetic_energy",
# Unusual training setup
"opacus_cifar10",
"maml",
# segfault: Internal Triton PTX codegen error
"timm_efficientdet",
2 changes: 1 addition & 1 deletion c10/cuda/CUDACachingAllocator.cpp
@@ -803,7 +803,7 @@ class CachingAllocatorConfig {
}

static bool expandable_segments() {
#ifndef EXPANDABLE_SEGMENTS_SUPPORTED
#ifndef PYTORCH_EXPANDABLE_SEGMENTS_SUPPORTED
if (instance().m_expandable_segments) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
43 changes: 43 additions & 0 deletions test/inductor/indirect_assert_helper.py
@@ -0,0 +1,43 @@
import sys

import torch


def first_arg(x, y):
    return x[y]


def second_arg(x, y):
    return x[:, y]


def same_pm_one(x, y):
    return x[y + 1, y - 1]


def same_pp_one(x, y):
    return x[y + 1, y + 1]


def store(x, y, z):
    x[y + 1, y + 1] = z


if __name__ == "__main__":
    _, fn_name, dims, dyn_shape = sys.argv
    assert fn_name in ("first_arg", "second_arg", "same_pm_one", "same_pp_one", "store")
    assert dims in ("2", "3")
    shape_x = (3, 2, 4) if dims == "3" else (3, 2)
    assert dyn_shape in ("True", "False")
    dynamic_shapes = dyn_shape == "True"

    x = torch.randn(shape_x, device="cuda")
    y = torch.arange(4, device="cuda")
    fn = vars()[fn_name]
    fn = torch.compile(dynamic=dynamic_shapes)(fn)
    if fn_name == "store":
        shape = (y.numel(),) + x.shape[2:]
        z = torch.randn(shape, device="cuda")
        fn(x, y, z)
    else:
        fn(x, y)
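The helper above indexes out of bounds on purpose: y is arange(4) while the first dimension of x has size 3, so the compiled kernel's new tl.device_assert should fire. A condensed sketch of what a single invocation boils down to (assumes a CUDA device):

import torch

def first_arg(x, y):
    return x[y]

x = torch.randn(3, 2, device="cuda")
y = torch.arange(4, device="cuda")  # contains 3, out of range for x.shape[0] == 3
torch.compile(first_arg)(x, y)      # expected to abort with an "index out of bounds" device assert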
21 changes: 21 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -8,6 +8,7 @@
import math
import os
import random
import subprocess
import sys
import time
import typing
@@ -6123,6 +6124,26 @@ def fn(x: torch.Tensor) -> torch.Tensor:
inps = torch.randn([5, 5])
fn_opt(inps)

    def test_indirect_device_assert(self):
        dir_path = os.path.dirname(os.path.realpath(__file__))
        test_path = os.path.join(dir_path, "indirect_assert_helper.py")
        fns = ("first_arg", "store", "second_arg", "same_pm_one", "same_pp_one")

        for fn, ndims, dyn_shape in itertools.product(fns, (2, 3), (True, False)):
            proc = subprocess.Popen(
                [sys.executable, test_path, fn, str(ndims), str(dyn_shape)],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stderr = proc.communicate()[1]
            self.assertTrue(
                any(
                    "index out of bounds" in err.decode("utf-8")
                    for err in stderr.splitlines()
                ),
                f"{fn}, {ndims}, {dyn_shape}",
            )


if HAS_CUDA and not TEST_WITH_ASAN:

1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
@@ -203,6 +203,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.elu_backward,
aten._embedding_bag,
aten.embedding_dense_backward,
aten._euclidean_dist.default,
aten.expand_as,
aten.eye,
aten.fill,
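A quick way to confirm the new entry is picked up (a sketch using the helper defined in this same file, not part of the diff):

import torch
from torch._decomp import core_aten_decompositions

assert torch.ops.aten._euclidean_dist.default in core_aten_decompositions()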
12 changes: 9 additions & 3 deletions torch/_inductor/codegen/common.py
@@ -601,20 +601,26 @@ def generate(
buffer: IndentedBuffer,
expr: typing.Union[str, CSEVariable],
write=True,
assignment=True,
) -> CSEVariable:
assert isinstance(expr, (str, CSEVariable)), type(expr)
assert write or assignment
if isinstance(expr, CSEVariable):
return expr
cache_key = expr
if cache_key not in self.cache:
var = self.newvar()
var = self.newvar() if assignment else None
self.cache[cache_key] = var
if write:
if V.kernel.current_node:
V.kernel.current_node.codegen_originating_info(
buffer, only_once=True
)
buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
if assignment:
line = f"{self.prefix}{var} = {expr}{self.suffix}"
else:
line = f"{expr}{self.suffix}"
buffer.writeline(line)

return self.cache[cache_key]
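Roughly, the new flag changes what gets written (an illustration with assumed buffer/expression names, not output captured from this PR):

# default: the cached expression is bound to a fresh temporary
cse.generate(buf, "tl.load(in_ptr0 + x0, xmask)")                     # -> tmp0 = tl.load(in_ptr0 + x0, xmask)
# assignment=False: the expression is emitted as a bare statement
cse.generate(buf, 'tl.device_assert(cond, "msg")', assignment=False)  # -> tl.device_assert(cond, "msg")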

@@ -721,7 +727,7 @@ def inner(*args, **kwargs):
return inner

@staticmethod
def indirect_indexing(index_var):
def indirect_indexing(index_var, size):
return sympy_symbol(str(index_var))

@staticmethod
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp.py
@@ -1917,7 +1917,7 @@ def indexing_div_rep(x, y):
return tmp_var

@staticmethod
def indirect_indexing(index_var):
def indirect_indexing(index_var, size):
return sympy.Symbol(str(index_var))

@staticmethod
50 changes: 50 additions & 0 deletions torch/_inductor/codegen/triton.py
@@ -1016,6 +1016,46 @@ def mask_loads(self, mask):
finally:
self._load_mask = prior

    def gen_assert_indirect_indexing(self, buffer, original_index, mask):
        if mask == "None":
            return
        body = self.current_node._body
        indirect_size = dict(zip(body.indirect_vars, body.indirect_max_sizes))
        indirect_name = body.indirect_new
        # Many indirect variables may be mapped to the same CSE'd variable
        # For example when you do x[y, y] for x = randn(3, 8)
        var_size = collections.defaultdict(set)
        for ind, size in indirect_size.items():
            var_size[indirect_name[ind]].add(V.kernel.rename_indexing(size))

        indirect_vars = [
            s for s in original_index.free_symbols if s.name.startswith("tmp")
        ]
        for var in indirect_vars:
            sizes = list(var_size[var])
            if all(isinstance(s, sympy.Integer) for s in sizes):
                size = min(sizes)
            else:
                # Should this go here or in TritonPrinter?
                def print_min(expr):
                    if len(expr) == 1:
                        return texpr(expr[0])
                    else:
                        return f"min({texpr(expr[0])}, {print_min(expr[1:])})"

                size = print_min(sizes)
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"((0 <= {var}) & ({var} < {size}))"
            cond_print = f"0 <= {var} < {size}"
            if not isinstance(original_index, sympy.Integer):
                var_mask = f"({mask})" if "&" in mask else mask
                var_mask = f" | ~{var_mask}"
            else:
                var_mask = ""
            line = f'tl.device_assert(({cond}){var_mask}, "index out of bounds: {cond_print}")'
            self.cse.generate(buffer, line, assignment=False)

def load(self, name: str, index: sympy.Expr):
var = self.args.input(name)
indirect_indexing = self.is_indirect_indexing(index)
@@ -1065,6 +1105,10 @@ def load(self, name: str, index: sympy.Expr):
else:
load_buffer = self.loads

# Assert that the loaded indices will not read garbage
if indirect_indexing and config.triton.assert_indirect_indexing:
self.gen_assert_indirect_indexing(load_buffer, original_index, mask)

result_var = self.cse.generate(load_buffer, line)
result_var.mask_vars = mask_vars

@@ -1079,7 +1123,13 @@ def load(self, name: str, index: sympy.Expr):

def store(self, name, index, value, mode=None):
var = self.args.output(name)
indirect_indexing = self.is_indirect_indexing(index)
original_index = index
index, mask_vars, mask, expand_str = self.indexing(index, dense_indexing=True)

if indirect_indexing and config.triton.assert_indirect_indexing:
self.gen_assert_indirect_indexing(self.stores, original_index, mask)

if mode is None:
line = f"tl.store({var} + ({index}), {value}, {mask})"
elif mode == "atomic_add":
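For a load like x[y] guarded by xmask, gen_assert_indirect_indexing above ends up emitting roughly the following line into the kernel (tmp0, s0 and xmask are assumed names for the indirect index temporary, its bound, and the load mask):

tl.device_assert(((0 <= tmp0) & (tmp0 < s0)) | ~xmask, "index out of bounds: 0 <= tmp0 < s0")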
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -231,6 +231,9 @@ class triton:
tiling_prevents_pointwise_fusion = True
tiling_prevents_reduction_fusion = True

# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True

# should we give different names to kernels
# Note: This is orthogonal to descriptive_names - this is deciding whether
# our triton kernel names should all be `triton_` (to maximize caching) or
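The new check defaults to on; opting out is a normal inductor config toggle (sketch, not part of this diff):

import torch._inductor.config as inductor_config

# skip generating the device-side bounds asserts
inductor_config.triton.assert_indirect_indexing = False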
14 changes: 10 additions & 4 deletions torch/_inductor/ir.py
@@ -3858,6 +3858,8 @@ def __init__(self, fn, args, var_ranges):
self.submodules = {"get_index": self.get_index}
self.subblocks = {}
self.indirect_vars = []
self.indirect_max_sizes = []
self.indirect_new = {}
self.root_block = LoopBodyBlock(self, fn, args)
self.indexing = None

@@ -3893,16 +3895,18 @@ def add_submodule(self, block, prefix):
self.submodules[name] = block
return name

def add_indirect(self):
def add_indirect(self, size):
name = f"indirect{len(self.indirect_vars)}"
var = sympy_symbol(name)
self.indirect_vars.append(var)
self.indirect_max_sizes.append(size)
return var

def replace_indirect(self, old, new):
"""Swap in a variable used in indirect indexing"""
if str(old) == str(new):
return
self.indirect_new[old] = new
self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}

def get_index(self, name):
@@ -3981,16 +3985,18 @@ def shim(mask, other):
)

@staticmethod
def indirect_indexing(index_proxy):
def indirect_indexing(index_proxy, size):
"""
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
"""

def set_indirect(new_var):
self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))
self.body.replace_indirect(
var, V.ops.indirect_indexing(new_var, size)
)

var = self.body.add_indirect()
var = self.body.add_indirect(size)
tracer.create_proxy(
"call_module",
self.body.add_submodule(set_indirect, f"set_{var}"),
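Taken together with the triton.py changes, the bookkeeping added here works roughly as follows (a sketch with assumed names, for x[y] where x.shape == (s0, 8)):

# CaptureIndexing.indirect_indexing(y_proxy, s0) calls body.add_indirect(s0), so:
#   body.indirect_vars      == [indirect0]
#   body.indirect_max_sizes == [s0]
# set_indirect/replace_indirect then record the CSE'd kernel temp for the index:
#   body.indirect_new[indirect0] == tmp0
# gen_assert_indirect_indexing zips these together to find the bound (s0) for tmp0.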
