From d42211c66355b415f6643d97ea450ed6b1075e41 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 4 Sep 2023 10:35:04 -0300 Subject: [PATCH 1/7] Fix inductor `sub` with symbolic integers. Fix: #108159 [ghstack-poisoned] --- .../inductor/test_torchinductor_dynamic_shapes.py | 2 -- torch/_inductor/lowering.py | 15 ++++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 8dcb24428f107..39a3a736c730c 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -257,8 +257,6 @@ def div(x): test(div) @onlyCPU - @unittest.expectedFailure - # Ref: https://github.com/pytorch/pytorch/issues/108159 def test_sub_constant_folding(self, device): def sub(x): return x - torch.zeros(3) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 3caab47287fef..d3b2eb608b0e2 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -377,12 +377,17 @@ def inner(*inputs: List[TensorBox], alpha=None): inputs[-1] = mul(inputs[-1], alpha) else: assert alpha is None + + tensor_inputs = [inp for inp in inputs if isinstance(inp, TensorBox)] + assert len(tensor_inputs) > 0, f"expected at least one tensor. Got: {[type(inp) for inp in inputs]}" + ref = tensor_inputs[0] + loaders = [x.make_loader() for x in inputs] - ranges = inputs[0].get_size() - dtype = override_return_dtype or inputs[0].get_dtype() - is_cuda = decode_device(inputs[0].get_device()).type == "cuda" + ranges = ref.get_size() + dtype = override_return_dtype or ref.get_dtype() + is_cuda = decode_device(ref.get_device()).type == "cuda" - for other in inputs[1:]: + for other in inputs: assert isinstance(other, ir.BaseConstant) or len(ranges) == len( other.get_size() ), f"ndim mismatch {fn} {ranges} {other.get_size()}" @@ -403,7 +408,7 @@ def inner_fn(index): device = i.get_device() break if not device: - device = inputs[0].get_device() + device = ref.get_device() device = override_device or device From cbd570d514b020a3d2a8dcccc9eb9c541e50664a Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 4 Sep 2023 10:50:19 -0300 Subject: [PATCH 2/7] Add comment. on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/lowering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d3b2eb608b0e2..ec8fc351182e5 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -378,6 +378,8 @@ def inner(*inputs: List[TensorBox], alpha=None): else: assert alpha is None + # Get the first TensorBox input as the reference tensor for + # sizes, data types, and device. tensor_inputs = [inp for inp in inputs if isinstance(inp, TensorBox)] assert len(tensor_inputs) > 0, f"expected at least one tensor. Got: {[type(inp) for inp in inputs]}" ref = tensor_inputs[0] From fbe751ae9ac27e05e76c2dd5c6677f055aab50a4 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 4 Sep 2023 11:26:52 -0300 Subject: [PATCH 3/7] Add support to ExpandView. on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/lowering.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index ec8fc351182e5..546afa1d38483 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -378,10 +378,10 @@ def inner(*inputs: List[TensorBox], alpha=None): else: assert alpha is None - # Get the first TensorBox input as the reference tensor for + # Get the first TensorBox/ExpandView input as the reference tensor for # sizes, data types, and device. - tensor_inputs = [inp for inp in inputs if isinstance(inp, TensorBox)] - assert len(tensor_inputs) > 0, f"expected at least one tensor. Got: {[type(inp) for inp in inputs]}" + tensor_inputs = [inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView))] + assert len(tensor_inputs) > 0, f"expected at least one tensor/expanded view. Got: {[type(inp) for inp in inputs]}" ref = tensor_inputs[0] loaders = [x.make_loader() for x in inputs] From 3b21c3360e29ac633a5db1aa35d677b937f8c1ba Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 4 Sep 2023 11:28:11 -0300 Subject: [PATCH 4/7] Fix lint issues. on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/lowering.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 546afa1d38483..156a61d18a907 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -380,8 +380,12 @@ def inner(*inputs: List[TensorBox], alpha=None): # Get the first TensorBox/ExpandView input as the reference tensor for # sizes, data types, and device. - tensor_inputs = [inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView))] - assert len(tensor_inputs) > 0, f"expected at least one tensor/expanded view. Got: {[type(inp) for inp in inputs]}" + tensor_inputs = [ + inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView)) + ] + assert ( + len(tensor_inputs) > 0 + ), f"expected at least one tensor/expanded view. Got: {[type(inp) for inp in inputs]}" ref = tensor_inputs[0] loaders = [x.make_loader() for x in inputs] From f16d94556bbf0fc49554fad423ee1129cda8d554 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 4 Sep 2023 14:28:21 -0300 Subject: [PATCH 5/7] Update on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/lowering.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 156a61d18a907..3548fbd43c4ee 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -383,10 +383,8 @@ def inner(*inputs: List[TensorBox], alpha=None): tensor_inputs = [ inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView)) ] - assert ( - len(tensor_inputs) > 0 - ), f"expected at least one tensor/expanded view. Got: {[type(inp) for inp in inputs]}" - ref = tensor_inputs[0] + # Use the first tensor found. Otherwise, fallback to using the first argument. + ref = tensor_inputs[0] if len(tensor_inputs) > 0 else inputs[0] loaders = [x.make_loader() for x in inputs] ranges = ref.get_size() From 9ed4fec58e617da1944c8f6e6c31518e51c5810b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 5 Sep 2023 17:43:22 -0300 Subject: [PATCH 6/7] Update on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/index_propagation.py | 4 +++- torch/_inductor/lowering.py | 23 ++++++----------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index bbadb024347bd..4fb2eaad64ea9 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -51,7 +51,9 @@ def identity(value: Any) -> Any: return value @staticmethod - def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + def constant(value: Union[int, float, bool, sympy.Expr], dtype: torch.dtype) -> TypedExpr: + if isinstance(value, sympy.Expr): + return TypedExpr(value, dtype) if is_boolean_dtype(dtype): expr = sympy.Integer(bool(value)) elif is_integer_dtype(dtype): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 3548fbd43c4ee..506faeed81a11 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -347,14 +347,12 @@ def const_func(x): ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView))) out = [] for x in inputs: - if isinstance(x, (int, float)): + if isinstance(x, (int, float, sympy.Expr)): out.append( ExpandView.create( ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) ) ) - elif isinstance(x, sympy.Expr): - out.append(IndexingConstant(x, ex.get_dtype(), ex.get_device())) else: out.append(x) @@ -377,21 +375,12 @@ def inner(*inputs: List[TensorBox], alpha=None): inputs[-1] = mul(inputs[-1], alpha) else: assert alpha is None - - # Get the first TensorBox/ExpandView input as the reference tensor for - # sizes, data types, and device. - tensor_inputs = [ - inp for inp in inputs if isinstance(inp, (TensorBox, ExpandView)) - ] - # Use the first tensor found. Otherwise, fallback to using the first argument. - ref = tensor_inputs[0] if len(tensor_inputs) > 0 else inputs[0] - loaders = [x.make_loader() for x in inputs] - ranges = ref.get_size() - dtype = override_return_dtype or ref.get_dtype() - is_cuda = decode_device(ref.get_device()).type == "cuda" + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + is_cuda = decode_device(inputs[0].get_device()).type == "cuda" - for other in inputs: + for other in inputs[1:]: assert isinstance(other, ir.BaseConstant) or len(ranges) == len( other.get_size() ), f"ndim mismatch {fn} {ranges} {other.get_size()}" @@ -412,7 +401,7 @@ def inner_fn(index): device = i.get_device() break if not device: - device = ref.get_device() + device = inputs[0].get_device() device = override_device or device From 0dbfb94fa0666d37758fe8ac353c26727922d03b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 6 Sep 2023 09:56:41 -0300 Subject: [PATCH 7/7] Update on "Fix inductor `sub` with symbolic integers." Fix: #108159 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov [ghstack-poisoned] --- torch/_inductor/index_propagation.py | 4 +--- torch/_inductor/lowering.py | 9 ++++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/index_propagation.py b/torch/_inductor/index_propagation.py index 4fb2eaad64ea9..bbadb024347bd 100644 --- a/torch/_inductor/index_propagation.py +++ b/torch/_inductor/index_propagation.py @@ -51,9 +51,7 @@ def identity(value: Any) -> Any: return value @staticmethod - def constant(value: Union[int, float, bool, sympy.Expr], dtype: torch.dtype) -> TypedExpr: - if isinstance(value, sympy.Expr): - return TypedExpr(value, dtype) + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: if is_boolean_dtype(dtype): expr = sympy.Integer(bool(value)) elif is_integer_dtype(dtype): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 506faeed81a11..8b141f0f3c24d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -347,12 +347,19 @@ def const_func(x): ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView))) out = [] for x in inputs: - if isinstance(x, (int, float, sympy.Expr)): + if isinstance(x, (int, float)): out.append( ExpandView.create( ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) ) ) + elif isinstance(x, sympy.Expr): + out.append( + ExpandView.create( + IndexingConstant(x, ex.get_dtype(), ex.get_device()), + list(ex.get_size()), + ) + ) else: out.append(x)