diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv
index 9ec7553e55504..8e8c8902f92ae 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_inference.csv
@@ -28,6 +28,7 @@ mnasnet1_0,pass,0
 mobilenet_v2,pass,0
 mobilenet_v3_large,pass,0
 nvidia_deeprecommender,pass,0
+opacus_cifar10,pass,44
 phlippe_densenet,pass,0
 phlippe_resnet,pass,0
 pyhpc_isoneutral_mixing,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
index fadc2c3597d6b..c94f6bb3621fe 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
@@ -23,6 +23,7 @@ maml_omniglot,pass,9
 mnasnet1_0,pass,9
 mobilenet_v2,pass,9
 nvidia_deeprecommender,pass,9
+opacus_cifar10,pass,44
 phlippe_densenet,pass,9
 phlippe_resnet,pass,9
 pytorch_CycleGAN_and_pix2pix,pass,9
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
index 9ec7553e55504..8e8c8902f92ae 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
@@ -28,6 +28,7 @@ mnasnet1_0,pass,0
 mobilenet_v2,pass,0
 mobilenet_v3_large,pass,0
 nvidia_deeprecommender,pass,0
+opacus_cifar10,pass,44
 phlippe_densenet,pass,0
 phlippe_resnet,pass,0
 pyhpc_isoneutral_mixing,pass,0
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
index c11bdee229aa7..8b9453c92bb08 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
@@ -23,6 +23,7 @@ maml_omniglot,pass,9
 mnasnet1_0,pass,9
 mobilenet_v2,pass,9
 nvidia_deeprecommender,pass,9
+opacus_cifar10,pass,44
 phlippe_densenet,pass,9
 phlippe_resnet,pass,9
 pytorch_CycleGAN_and_pix2pix,pass,9
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index c7eaa77c50122..52bcb72701f0d 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -211,6 +211,7 @@ class CI(NamedTuple):
     "cait_m36_384",  # Accuracy
     "pnasnet5large",  # OOM
     "xcit_large_24_p8_224",  # OOM https://github.com/pytorch/pytorch/issues/95984
+    "opacus_cifar10",  # Fails to run https://github.com/pytorch/pytorch/issues/99201
 ]
 
 CI_SKIP[CI("inductor", training=True)] = [
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index ebe5736ea3244..475581e9bc3b7 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -73,8 +73,6 @@ def setup_torchbench_cwd():
     "detectron2_maskrcnn",
     # https://github.com/pytorch/torchdynamo/issues/145
     "fambench_xlmr",
-    # https://github.com/pytorch/pytorch/issues/99201
-    "opacus_cifar10",
     # TIMEOUT, https://github.com/pytorch/pytorch/issues/98467
     "tacotron2",
     # https://github.com/pytorch/pytorch/issues/99438
@@ -93,8 +91,6 @@ def setup_torchbench_cwd():
     "pyhpc_equation_of_state",
     "pyhpc_isoneutral_mixing",
     "pyhpc_turbulent_kinetic_energy",
-    # Unusual training setup
-    "opacus_cifar10",
     "maml",
     # segfault: Internal Triton PTX codegen error
     "timm_efficientdet",
diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 838e7a1402330..3958e1dfc503e 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -803,7 +803,7 @@ class CachingAllocatorConfig {
   }
 
   static bool expandable_segments() {
-#ifndef EXPANDABLE_SEGMENTS_SUPPORTED
+#ifndef PYTORCH_EXPANDABLE_SEGMENTS_SUPPORTED
     if (instance().m_expandable_segments) {
      TORCH_WARN_ONCE("expandable_segments not supported on this platform")
    }
diff --git a/test/inductor/indirect_assert_helper.py b/test/inductor/indirect_assert_helper.py
new file mode 100644
index 0000000000000..272f90f855bbc
--- /dev/null
+++ b/test/inductor/indirect_assert_helper.py
@@ -0,0 +1,43 @@
+import sys
+
+import torch
+
+
+def first_arg(x, y):
+    return x[y]
+
+
+def second_arg(x, y):
+    return x[:, y]
+
+
+def same_pm_one(x, y):
+    return x[y + 1, y - 1]
+
+
+def same_pp_one(x, y):
+    return x[y + 1, y + 1]
+
+
+def store(x, y, z):
+    x[y + 1, y + 1] = z
+
+
+if __name__ == "__main__":
+    _, fn_name, dims, dyn_shape = sys.argv
+    assert fn_name in ("first_arg", "second_arg", "same_pm_one", "same_pp_one", "store")
+    assert dims in ("2", "3")
+    shape_x = (3, 2, 4) if dims == "3" else (3, 2)
+    assert dyn_shape in ("True", "False")
+    dynamic_shapes = dyn_shape == "True"
+
+    x = torch.randn(shape_x, device="cuda")
+    y = torch.arange(4, device="cuda")
+    fn = vars()[fn_name]
+    fn = torch.compile(dynamic=dynamic_shapes)(fn)
+    if fn_name == "store":
+        shape = (y.numel(),) + x.shape[2:]
+        z = torch.randn(shape, device="cuda")
+        fn(x, y, z)
+    else:
+        fn(x, y)
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 2eec07a247b91..318653ef5553d 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8,6 +8,7 @@
 import math
 import os
 import random
+import subprocess
 import sys
 import time
 import typing
@@ -6123,6 +6124,26 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         inps = torch.randn([5, 5])
         fn_opt(inps)
 
+    def test_indirect_device_assert(self):
+        dir_path = os.path.dirname(os.path.realpath(__file__))
+        test_path = os.path.join(dir_path, "indirect_assert_helper.py")
+        fns = ("first_arg", "store", "second_arg", "same_pm_one", "same_pp_one")
+
+        for fn, ndims, dyn_shape in itertools.product(fns, (2, 3), (True, False)):
+            proc = subprocess.Popen(
+                [sys.executable, test_path, fn, str(ndims), str(dyn_shape)],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            )
+            stderr = proc.communicate()[1]
+            self.assertTrue(
+                any(
+                    "index out of bounds" in err.decode("utf-8")
+                    for err in stderr.splitlines()
+                ),
+                f"{fn}, {ndims}, {dyn_shape}",
+            )
+
 
 
 if HAS_CUDA and not TEST_WITH_ASAN:
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 60c613f685d87..8cef5c4abd396 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -601,20 +601,26 @@ def generate(
         buffer: IndentedBuffer,
         expr: typing.Union[str, CSEVariable],
         write=True,
+        assignment=True,
     ) -> CSEVariable:
         assert isinstance(expr, (str, CSEVariable)), type(expr)
+        assert write or assignment
         if isinstance(expr, CSEVariable):
             return expr
         cache_key = expr
         if cache_key not in self.cache:
-            var = self.newvar()
+            var = self.newvar() if assignment else None
             self.cache[cache_key] = var
             if write:
                 if V.kernel.current_node:
                     V.kernel.current_node.codegen_originating_info(
                         buffer, only_once=True
                     )
-                buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
+                if assignment:
+                    line = f"{self.prefix}{var} = {expr}{self.suffix}"
+                else:
+                    line = f"{expr}{self.suffix}"
+                buffer.writeline(line)
 
         return self.cache[cache_key]
 
@@ -721,7 +727,7 @@ def inner(*args, **kwargs):
         return inner
 
     @staticmethod
-    def indirect_indexing(index_var):
+    def indirect_indexing(index_var, size):
         return sympy_symbol(str(index_var))
 
     @staticmethod
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index d8f16a6dc1ab7..0e36f3b5e46df 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -1917,7 +1917,7 @@ def indexing_div_rep(x, y):
         return tmp_var
 
     @staticmethod
-    def indirect_indexing(index_var):
+    def indirect_indexing(index_var, size):
         return sympy.Symbol(str(index_var))
 
     @staticmethod
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index db9acebb09c44..01492c207411a 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1016,6 +1016,46 @@ def mask_loads(self, mask):
         finally:
             self._load_mask = prior
 
+    def gen_assert_indirect_indexing(self, buffer, original_index, mask):
+        if mask == "None":
+            return
+        body = self.current_node._body
+        indirect_size = dict(zip(body.indirect_vars, body.indirect_max_sizes))
+        indirect_name = body.indirect_new
+        # Many indirect variables may be mapped to the same CSE'd variable
+        # For example when you do x[y, y] for x = randn(3, 8)
+        var_size = collections.defaultdict(set)
+        for ind, size in indirect_size.items():
+            var_size[indirect_name[ind]].add(V.kernel.rename_indexing(size))
+
+        indirect_vars = [
+            s for s in original_index.free_symbols if s.name.startswith("tmp")
+        ]
+        for var in indirect_vars:
+            sizes = list(var_size[var])
+            if all(isinstance(s, sympy.Integer) for s in sizes):
+                size = min(sizes)
+            else:
+                # Should this go here or in TritonPrinter?
+                def print_min(expr):
+                    if len(expr) == 1:
+                        return texpr(expr[0])
+                    else:
+                        return f"min({texpr(expr[0])}, {print_min(expr[1:])})"
+
+                size = print_min(sizes)
+            # The conditions need to be in parens because of Python's operator precedence.
+            # It'd be less error-prone to use and/or/not, which is supported by triton
+            cond = f"((0 <= {var}) & ({var} < {size}))"
+            cond_print = f"0 <= {var} < {size}"
+            if not isinstance(original_index, sympy.Integer):
+                var_mask = f"({mask})" if "&" in mask else mask
+                var_mask = f" | ~{var_mask}"
+            else:
+                var_mask = ""
+            line = f'tl.device_assert(({cond}){var_mask}, "index out of bounds: {cond_print}")'
+            self.cse.generate(buffer, line, assignment=False)
+
     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)
         indirect_indexing = self.is_indirect_indexing(index)
@@ -1065,6 +1105,10 @@ def load(self, name: str, index: sympy.Expr):
         else:
             load_buffer = self.loads
 
+        # Assert that the loaded indices will not read garbage
+        if indirect_indexing and config.triton.assert_indirect_indexing:
+            self.gen_assert_indirect_indexing(load_buffer, original_index, mask)
+
         result_var = self.cse.generate(load_buffer, line)
         result_var.mask_vars = mask_vars
 
@@ -1079,7 +1123,13 @@ def load(self, name: str, index: sympy.Expr):
 
     def store(self, name, index, value, mode=None):
         var = self.args.output(name)
+        indirect_indexing = self.is_indirect_indexing(index)
+        original_index = index
         index, mask_vars, mask, expand_str = self.indexing(index, dense_indexing=True)
+
+        if indirect_indexing and config.triton.assert_indirect_indexing:
+            self.gen_assert_indirect_indexing(self.stores, original_index, mask)
+
         if mode is None:
             line = f"tl.store({var} + ({index}), {value}, {mask})"
         elif mode == "atomic_add":
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 7299decfbc2fc..c37bdc836d416 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -231,6 +231,9 @@ class triton:
     tiling_prevents_pointwise_fusion = True
     tiling_prevents_reduction_fusion = True
 
+    # assert that indirect indexing does not read / write out of bounds
+    assert_indirect_indexing = True
+
     # should we give different names to kernels
     # Note: This is orthogonal to descriptive_names - this is deciding whether
     # our triton kernel names should all be `triton_` (to maximize caching) or
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 74a532e1f71ce..b6a26359d2624 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3858,6 +3858,8 @@ def __init__(self, fn, args, var_ranges):
         self.submodules = {"get_index": self.get_index}
         self.subblocks = {}
         self.indirect_vars = []
+        self.indirect_max_sizes = []
+        self.indirect_new = {}
         self.root_block = LoopBodyBlock(self, fn, args)
         self.indexing = None
 
@@ -3893,16 +3895,18 @@ def add_submodule(self, block, prefix):
         self.submodules[name] = block
         return name
 
-    def add_indirect(self):
+    def add_indirect(self, size):
         name = f"indirect{len(self.indirect_vars)}"
         var = sympy_symbol(name)
         self.indirect_vars.append(var)
+        self.indirect_max_sizes.append(size)
         return var
 
     def replace_indirect(self, old, new):
         """Swap in a variable used in indirect indexing"""
         if str(old) == str(new):
             return
+        self.indirect_new[old] = new
         self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
 
     def get_index(self, name):
@@ -3981,16 +3985,18 @@ def shim(mask, other):
             )
 
         @staticmethod
-        def indirect_indexing(index_proxy):
+        def indirect_indexing(index_proxy, size):
             """
             Flow data from tensors into indexing formulas.
             Introduce a call_module to update the indexing.
             """
 
             def set_indirect(new_var):
-                self.body.replace_indirect(var, V.ops.indirect_indexing(new_var))
+                self.body.replace_indirect(
+                    var, V.ops.indirect_indexing(new_var, size)
+                )
 
-            var = self.body.add_indirect()
+            var = self.body.add_indirect(size)
             tracer.create_proxy(
                 "call_module",
                 self.body.add_submodule(set_indirect, f"set_{var}"),
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 854dfb021a181..0c0df7da11fcb 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1882,7 +1882,8 @@ def gather(x, dim, index, sparse_grad=False):
     # and backward tracing is taken care of by AOT Autograd
     assert isinstance(x, TensorBox)
     assert index.get_dtype() == torch.int64
-    offset = len(x.get_size()) == 0
+    size = x.get_size()
+    offset = len(size) == 0
     dim = _validate_dim(x, dim, offset)
 
     x_loader = x.make_loader()
@@ -1891,7 +1892,7 @@ def gather(x, dim, index, sparse_grad=False):
     def fn(idx):
         idx = list(idx)
         if len(idx) != 0:
-            idx[dim] = ops.indirect_indexing(index_loader(idx))
+            idx[dim] = ops.indirect_indexing(index_loader(idx), size[dim])
         return x_loader(idx)
 
     return Pointwise.create(
@@ -1912,12 +1913,15 @@ def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=
     weight_loader = weight.make_loader()
     indices_loader = indices.make_loader()
     indices_ndim = len(indices.get_size())
-    new_size = [*indices.get_size(), *weight.get_size()[1:]]
+    weight_size = weight.get_size()
+    new_size = [*indices.get_size(), *weight_size[1:]]
 
     def fn(idx):
         assert len(idx) == len(new_size), f"{idx} != {new_size}"
         var_index = indices_loader(idx[:indices_ndim])
-        weight_idx = [ops.indirect_indexing(var_index)] + [*idx[indices_ndim:]]
+        weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [
+            *idx[indices_ndim:]
+        ]
         return weight_loader(weight_idx)
 
     return Pointwise.create(
@@ -1997,9 +2001,10 @@ def index(x, indices):
 
     def fn(idx):
         assert len(idx) == len(output_size)
+        assert len(indices_loaders) == len(indexed_size)
         new_index = [
-            ops.indirect_indexing(loader(idx[start_offset:end_offset]))
-            for loader in indices_loaders
+            ops.indirect_indexing(loader(idx[start_offset:end_offset]), size)
+            for loader, size in zip(indices_loaders, indexed_size)
         ]
         new_index = [*idx[:start_offset], *new_index, *idx[end_offset:]]
         return x_loader(new_index)
@@ -2097,6 +2102,7 @@ def index_put_(self, indices, values, accumulate=False):
         *output_size,
         *x_size[start_offset + len(indices_sizes) :],
     ]
+    indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
     values = expand(values, expected_vals_size)
     # all guards are set above during broadcast_tensors and expand
 
@@ -2104,8 +2110,8 @@ def output_indexer(index):
         assert len(index) == len(expected_vals_size)
         new_index = [
-            ops.indirect_indexing(loader(index[start_offset:end_offset]))
-            for loader in indices_loaders
+            ops.indirect_indexing(loader(index[start_offset:end_offset]), size)
+            for loader, size in zip(indices_loaders, indexed_size)
         ]
         new_index = [*index[:start_offset], *new_index, *index[end_offset:]]
         return new_index
@@ -2224,7 +2230,7 @@ def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool =
     if isinstance(index, TensorBox) and len(index.get_size()) == 0:
         index = view(index, [1])
 
-    assert -len(self.get_size()) <= dim < len(self.get_size())
+    dim = _validate_dim(self, dim)
 
     self.realize()
     V.graph.realize_users_of(self.get_name())
@@ -2232,8 +2238,13 @@ def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool =
     src_loader = src.make_loader() if isinstance(src, TensorBox) else None
 
     def output_indexer(idx):
+        # self is captured from the end of the function, so it may have 0 dims
+        shape = self.get_size()
+        ndim = len(shape)
         indirect_idx = list(idx)
-        indirect_idx[dim] = ops.indirect_indexing(index_loader(idx))
+        indirect_idx[dim] = ops.indirect_indexing(
+            index_loader(idx), 1 if ndim == 0 else shape[dim]
+        )
         return indirect_idx
 
     def fn(idx):
@@ -2306,16 +2317,18 @@ def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2
         if scale:
             scales[i] = scale
 
-    def scale(x, scale):
+    def scale(x, scale, size):
         x = ops.index_expr(x, torch.float32)
         x = ops.mul(x, ops.constant(scale, torch.float32))
         x = ops.to_dtype(x, torch.int32)
-        return ops.indirect_indexing(x)
+        return ops.indirect_indexing(x, size)
 
     def fn(idx):
         x = idx[-n:]
         b = idx[:-n]
-        return x_loader([*b, *[scale(i, s) for i, s in zip(x, scales)]])
+        return x_loader(
+            [*b, *[scale(i, s, size) for i, s, size in zip(x, scales, i_sizes)]]
+        )
 
     return Pointwise.create(
         device=x.get_device(),
@@ -2439,8 +2452,9 @@ def fn(idx):
         t_y = ops.sub(real_y, in_y)
 
         def load_bounded(fy, fx):
-            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
-            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
+            # TODO(Lezcano) Here we may not need to set up a device_size
+            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1), iH)
+            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1), iW)
             return x_loader([n, c, iy, ix])
 
         iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
@@ -2474,11 +2488,12 @@ def reflection_pad2d(x, padding):
     w = V.graph.sizevars.guard_static_shape(w)
 
     def reflect(x, size, offset):
+        size_num = size
         size = ops.constant(size - 1, torch.int32)
         x = ops.index_expr(x, torch.int32)
         x = ops.sub(x, ops.constant(offset, torch.int32))
         x = ops.sub(size, ops.abs(ops.sub(size, ops.abs(x))))
-        return ops.indirect_indexing(x)
+        return ops.indirect_indexing(x, size_num)
 
     def fn(idx):
         *b, x, y = idx
@@ -2503,13 +2518,14 @@ def reflection_pad2d_backward(grad_output, x, padding):
     h = V.graph.sizevars.guard_static_shape(h) - 1
     w = V.graph.sizevars.guard_static_shape(w) - 1
     grad_loader = grad_output.make_loader()
+    *_, h_grad, w_grad = grad_output.get_size()
 
     def fn(idx):
         *b, x, y = idx
 
         def load_from_output(x, y):
-            x = ops.indirect_indexing(ops.index_expr(x, torch.int32))
-            y = ops.indirect_indexing(ops.index_expr(y, torch.int32))
+            x = ops.indirect_indexing(ops.index_expr(x, torch.int32), h_grad)
+            y = ops.indirect_indexing(ops.index_expr(y, torch.int32), w_grad)
             return grad_loader([*b, x, y])
 
         def index_range_condition(index_range):
@@ -2875,6 +2891,8 @@ def max_pool2d_with_indices_backward(
             grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
         )
 
+    indices_size = indices.get_size()
+
     def fn(idx):
         *prefix, h, w = idx
         index_test = ops.index_expr(h * width + w, torch.int32)
@@ -2904,12 +2922,14 @@ def fn(idx):
                 ops.indirect_indexing(
                     ops.int_minimum(
                         ph, ops.sub(phend, ops.constant(1, torch.int32))
-                    )
+                    ),
+                    indices_size[-2],
                 ),
                 ops.indirect_indexing(
                     ops.int_minimum(
                         pw, ops.sub(pwend, ops.constant(1, torch.int32))
-                    )
+                    ),
+                    indices_size[-1],
                 ),
             ]
 
@@ -3346,12 +3366,14 @@ def fn(idx):
                 ops.indirect_indexing(
                     ops.int_minimum(
                         ph, ops.sub(phend, ops.constant(1, torch.int32))
-                    )
+                    ),
+                    pooled_height,
                 ),
                 ops.indirect_indexing(
                     ops.int_minimum(
                         pw, ops.sub(pwend, ops.constant(1, torch.int32))
-                    )
+                    ),
+                    pooled_width,
                 ),
             ]
         ),
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 15e3e07cefd6d..4c401a6d40fc4 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -84,6 +84,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only_with_cc: int):
             compile_meta["constants"][self.fn.arg_names.index(k)] = v
         compile_meta["num_warps"] = cfg.num_warps
         compile_meta["num_stages"] = cfg.num_stages
+        compile_meta["debug"] = config.triton.assert_indirect_indexing
         if warm_cache_only_with_cc:
             triton.compile(
                 self.fn,
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 006d6c85fa470..d614b43c6a954 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -76,7 +76,7 @@ def masked(mask, body, other):
         return f"ops.masked({mask}, {body()}, {other})"
 
     @staticmethod
-    def indirect_indexing(index_var):
+    def indirect_indexing(index_var, size):
         return sympy_symbol(f"({str(index_var)})")
 
     @classmethod
diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py
index 702ebe1857c36..3bfae39e49eb0 100644
--- a/torch/distributed/_spmd/distribute.py
+++ b/torch/distributed/_spmd/distribute.py
@@ -46,7 +46,7 @@
     maybe_disable_fake_tensor_mode,
     proxy_slot,
 )
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten
 
 # patch aot_function so that we can pass the full (non-sharded) input to capture the graph
 # pyre-fixme
@@ -86,11 +86,12 @@ def is_shard(self) -> bool:
 
     @classmethod
     def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
+        dim: int = 0
         if node.target == aten.sym_size:
-            dim: int = cast(int, node.args[1])
+            dim = cast(int, node.args[1])
             return cls(
-                global_value=dtensor.size(cast(int, node.args[1])),
-                local_value=dtensor.to_local().size(cast(int, node.args[1])),
+                global_value=dtensor.size(dim),
+                local_value=dtensor.to_local().size(dim),
                 mesh=dtensor.device_mesh,
             )
         elif node.target == aten.sym_numel:
@@ -100,7 +101,7 @@ def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
                 mesh=dtensor.device_mesh,
             )
         elif node.target == aten.sym_stride:
-            dim: int = cast(int, node.args[1])  # type: ignore[no-redef]
+            dim = cast(int, node.args[1])
             return cls(
                 global_value=dtensor.stride(dim),
                 local_value=dtensor.to_local().stride(dim),
@@ -714,9 +715,7 @@ def _convert_to_distributed(
     """
     global logger
    logger = get_logger("spmd_exp")
-    operators = {
-        getattr(operator, name) for name in dir(operator) if not name.startswith("_")
-    }
+    operators = {getattr(operator, name) for name in operator.__all__}
     node_to_obj: Dict[fx.Node, Any] = {}
     # map local op node in traced_f to its corresponding subgraph of
     # DTensor ops.
@@ -781,19 +780,11 @@ def _convert_to_distributed(
             dsymints[0].mesh == d.mesh for d in dsymints
         ), "all DSymInts must have the same mesh. "
 
-        local_args = tree_map(
-            lambda a: a.local_value if isinstance(a, DSymInt) else a, args
-        )
-        local_kwargs = tree_map(
-            lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
-        )
+        local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
+        local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)
 
-        global_args = tree_map(
-            lambda a: a.global_value if isinstance(a, DSymInt) else a, args
-        )
-        global_kwargs = tree_map(
-            lambda a: a.global_value if isinstance(a, DSymInt) else a, kwargs
-        )
+        global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
+        global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)
 
         node.args = local_args
         node.kwargs = local_kwargs
diff --git a/torch/distributed/_spmd/experimental_ops.py b/torch/distributed/_spmd/experimental_ops.py
index 0e39bece45d92..66b9803427e80 100644
--- a/torch/distributed/_spmd/experimental_ops.py
+++ b/torch/distributed/_spmd/experimental_ops.py
@@ -258,7 +258,7 @@ def _prop_select(op_schema: OpSchema) -> OutputSharding:
     # if they are larger than dim.
     new_placements: List[Placement] = []
     for p in placements:
-        # Using isinstance instead of is_shard to so that mypy won't complain
+        # Using isinstance instead of is_shard so that mypy won't complain
        # about accessing dim attribute.
        if isinstance(p, Shard) and p.dim > dim:
            new_placements.append(Shard(p.dim - 1))