diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 3849c3843e7c0..cc99b3dde0065 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -611,6 +611,15 @@ def fn(x, y):
 
         self.common(fn, (x, y))
 
+    def test_add_complex(self):
+        def fn(a, b, alpha):
+            return torch.add(a, b, alpha=alpha)
+
+        x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
+        y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
+
+        self.common(fn, (x, y, 2))
+
     def test_concat_add_inplace(self):
         def fn(x, y, z):
             return torch.cat([x, y], dim=1).add_(z)
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 875c39efcbffa..eb0d394f204fc 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -269,6 +269,19 @@ def angle(x):
     return ret + nan
 
 
+@register_decomposition([aten.add])
+def add(x, y, *, alpha=None):
+    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
+    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
+    if not x_is_complex_tensor or not y_is_complex_tensor:
+        return NotImplemented
+    z = y
+    if alpha is not None:
+        z = alpha * y
+    complex_type = torch.promote_types(x.dtype, y.dtype)
+    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
+
+
 @register_decomposition([aten.conj_physical])
 def conj_physical(self):
     assert not self.is_complex(), "TODO: implement this"
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 9401bd6ea6fe5..7972c5bc9a89f 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -132,7 +132,6 @@ def _check_tensorbox(nodes):
             sympy.Symbol,
             sympy.logic.boolalg.Boolean,
             Expr,
-            torch._inductor.ir.ExpandView,
         ),
     ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
 
@@ -4390,6 +4389,69 @@ def apply_constraint(self):
         return super().apply_constraint()
 
 
+@dataclasses.dataclass
+class ComplexView(ExternKernelAlloc):
+    """View a complex number as two dtyped numbers or vice versa"""
+
+    def should_allocate(self):
+        return False
+
+    def has_aliasing(self):
+        return True
+
+    def get_alias_names(self):
+        # Signal to codegen that our output buffer isn't safe to reuse
+        return [self.inputs[0].get_name()]
+
+    def __init__(
+        self,
+        layout,
+        kernel,
+        tensor_args,
+        nontensor_args,
+    ):
+        super().__init__(
+            layout,
+            tuple(tensor_args),
+            tuple(nontensor_args),
+        )
+        # We need output buffers for generating kernel arguments in the
+        # abi-compatible mode, where we retrieve outputs by passing each
+        # individual output through the abi-compatible interface.
+        self.outputs: Sequence[Any] = []
+        self.kernel = kernel
+
+    @classmethod
+    def create(cls, kernel, *args, **kwargs):
+        context = V.graph.fake_mode
+        with context:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema,
+            ) = cls.process_kernel(kernel, *args, **kwargs)
+
+        device = FallbackKernel.find_device(tensor_args, example_output)
+        assert device, "Not sure where to find device info"
+
+        packed = ComplexView(
+            MultiOutputLayout(device), kernel, tensor_args, non_tensor_args
+        )
+
+        layout = FixedLayout(
+            example_output.device,
+            example_output.dtype,
+            convert_shape_to_inductor(example_output.size()),
+            convert_shape_to_inductor(example_output.stride()),
+        )
+        outputs = MultiOutput(layout, packed, [])
+
+        packed.outputs = [outputs]
+        return outputs
+
+
 @dataclasses.dataclass
 class MultiOutputLayout(IRNode):
     device: torch.device
@@ -4437,7 +4499,7 @@ def should_allocate(self):
 
     def has_aliasing(self):
         return any(
-            isinstance(inp, FallbackKernel) and inp.has_aliasing()
+            isinstance(inp, (FallbackKernel, ComplexView)) and inp.has_aliasing()
             for inp in self.inputs
         )
 
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 972bfeb255cde..621710df1b7e1 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -543,6 +543,10 @@ def _to_dtype_bitcast(x):
 
 @register_lowering(aten.view.dtype, type_promotion_kind=None)
 def _view_dtype(x: TensorBox, dtype: torch.dtype):
+    if dtype.is_complex or x.get_dtype().is_complex:
+        return TensorBox.create(
+            ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
+        )
     return to_dtype_bitcast(x, dtype, copy=True)
 
 
@@ -1602,17 +1606,20 @@ def _warn_complex_not_supported():
 
 
 # There are some types (CPU) which we accept as input but not as
 # output.
-def unsupported_input_tensor(t: torch._subclasses.FakeTensor):
+def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None):
     "Do not support reading or writing to this tensor"
     if t.is_complex():
+        # Complex views are supported with IR ComplexView
+        if parent and parent.target == torch.ops.aten.view.dtype:
+            return False
         _warn_complex_not_supported()
         return True
     return False
 
 
-def unsupported_output_tensor(t: torch._subclasses.FakeTensor):
+def unsupported_output_tensor(t: torch._subclasses.FakeTensor, parent=None):
     "Do not support writing tensor but can read from it"
-    if unsupported_input_tensor(t):
+    if unsupported_input_tensor(t, parent):
         return True
     return t.is_cpu and config.disable_cpp_codegen
@@ -1626,7 +1633,7 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
     if node.target is aten.lift_fresh_copy.default:
         return False
 
-    def check_skip_condition(node, is_output):
+    def check_skip_condition(node, parent, is_output):
         if not isinstance(node, torch.fx.Node):
             return False
 
@@ -1638,20 +1645,20 @@ def check_skip_condition(node, is_output):
                 continue
 
             if is_output:
-                if unsupported_output_tensor(meta):
+                if unsupported_output_tensor(meta, parent):
                     return True
             else:
-                if unsupported_input_tensor(meta):
+                if unsupported_input_tensor(meta, parent):
                     return True
 
         return False
 
     # only skip codegen if there is a cpu output, not input
     for arg in tree_flatten((node.args, node.kwargs))[0]:
-        if check_skip_condition(arg, is_output=False):
+        if check_skip_condition(arg, node, is_output=False):
             return True
 
-    return check_skip_condition(node, is_output=True)
+    return check_skip_condition(node, node, is_output=True)
 
 
 def make_fallback(op, layout_constraint=None, warn=True):
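
A minimal standalone sketch (separate from the patch) of the identity the new
aten.add decomposition relies on: Tensor.view(dtype) reinterprets a complex
tensor as interleaved (real, imag) pairs of its real dtype, and elementwise
addition of those pairs is exactly complex addition. The values mirror
test_add_complex above; the alpha handling follows the decomposition.

    import torch

    x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
    y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
    alpha = 2

    # Scale y first, as the decomposition does when alpha is not None.
    z = alpha * y

    # complex64 viewed as float32 doubles the last dimension: shape [8]
    # becomes [16] of interleaved (real, imag) components. Add in the real
    # domain, then view the result back as the promoted complex dtype.
    decomposed = (x.view(x.real.dtype) + z.view(y.real.dtype)).view(
        torch.promote_types(x.dtype, y.dtype)
    )

    torch.testing.assert_close(decomposed, torch.add(x, y, alpha=alpha))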
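
The dedicated ir.ComplexView node presumably exists because such a view
changes the element count, which the elementwise to_dtype_bitcast lowering
cannot express; routing aten.view.dtype through an aliasing extern kernel
keeps the reinterpretation out of generated pointwise code, while the parent
argument threaded through unsupported_input_tensor lets exactly those
aten.view.dtype nodes avoid the blanket complex fallback.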