[inductor] decomposition for complex addition (pytorch#110740)
Tracks pytorch#98161

Complex number support in PyTorch is not ideal today: complex operations mostly fall back to the aten runtime, with the exception of `torch.angle`, which is handled in [#105609](pytorch#105609). A better approach in general is to decompose complex operations first, so that more fusion opportunities are exposed, and then have Triton handle non-contiguous (strided) tensor operations more efficiently. This change adds a decomposition for complex addition.
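As an illustration, a repro along the following lines (hypothetical, not taken from the PR; a CUDA device is assumed so that Triton codegen kicks in) exercises the new decomposition. Three complex64 elements view as six float32 values, which lines up with the `xnumel = 6` in the generated kernel below:

```python
import torch

def fn(a, b):
    return a + b

# complex64 inputs; viewed as float32 they hold 6 interleaved re/im values
x = torch.tensor([1 + 1j, -2 + 2j, 3 - 3j], device="cuda")
y = torch.tensor([1 - 1j, 2 + 2j, -3 + 3j], device="cuda")

# Run with TORCH_LOGS="output_code" to dump the generated Triton kernel.
compiled = torch.compile(fn)
torch.testing.assert_close(compiled(x, y), fn(x, y))
```

With the decomposition in place, Inductor emits a single fused Triton kernel for the addition: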

```python
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
```

Pull Request resolved: pytorch#110740
Approved by: https://github.com/jansel
htyu authored and xuhancn committed Nov 8, 2023
1 parent 0e8fec6 commit 9877b2b
Showing 4 changed files with 101 additions and 10 deletions.
9 changes: 9 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -611,6 +611,15 @@ def fn(x, y):
 
         self.common(fn, (x, y))
 
+    def test_add_complex(self):
+        def fn(a, b, alpha):
+            return torch.add(a, b, alpha=alpha)
+
+        x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
+        y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
+
+        self.common(fn, (x, y, 2))
+
     def test_concat_add_inplace(self):
         def fn(x, y, z):
             return torch.cat([x, y], dim=1).add_(z)
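For context, `torch.add(x, y, alpha=a)` computes `x + a * y`; that is the semantics the new test exercises through `self.common` with `alpha=2`. A quick eager-mode illustration (not part of the diff):

```python
import torch

x = torch.tensor([1 + 1j, -1 + 1j])
y = torch.tensor([2 - 2j, 3 + 3j])

# alpha scales the second operand before the addition
torch.testing.assert_close(torch.add(x, y, alpha=2), x + 2 * y)
```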
13 changes: 13 additions & 0 deletions torch/_inductor/decomposition.py
@@ -269,6 +269,19 @@ def angle(x):
     return ret + nan
 
 
+@register_decomposition([aten.add])
+def add(x, y, *, alpha=None):
+    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
+    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
+    if not x_is_complex_tensor or not y_is_complex_tensor:
+        return NotImplemented
+    z = y
+    if alpha is not None:
+        z = alpha * y
+    complex_type = torch.promote_types(x.dtype, y.dtype)
+    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
+
+
 @register_decomposition([aten.conj_physical])
 def conj_physical(self):
     assert not self.is_complex(), "TODO: implement this"
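The decomposition above relies on the fact that a complex tensor viewed as its real dtype interleaves real and imaginary components, so elementwise addition on the real views is exactly complex addition. A standalone sketch of that identity (not part of the commit):

```python
import torch

x = torch.tensor([1 + 1j, -2 + 2j], dtype=torch.complex64)
y = torch.tensor([3 - 3j, 1 + 0j], dtype=torch.complex64)

# complex64 viewed as float32 doubles the last dimension: [re0, im0, re1, im1]
real_sum = (x.view(torch.float32) + y.view(torch.float32)).view(torch.complex64)
torch.testing.assert_close(real_sum, x + y)
```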
66 changes: 64 additions & 2 deletions torch/_inductor/ir.py
@@ -132,7 +132,6 @@ def _check_tensorbox(nodes):
             sympy.Symbol,
             sympy.logic.boolalg.Boolean,
             Expr,
-            torch._inductor.ir.ExpandView,
         ),
     ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"

@@ -4390,6 +4389,69 @@ def apply_constraint(self):
         return super().apply_constraint()
 
 
+@dataclasses.dataclass
+class ComplexView(ExternKernelAlloc):
+    """View a complex number as two dtyped numbers or vice versa"""
+
+    def should_allocate(self):
+        return False
+
+    def has_aliasing(self):
+        return True
+
+    def get_alias_names(self):
+        # Signal to codegen that our output buffer isn't safe to reuse
+        return [self.inputs[0].get_name()]
+
+    def __init__(
+        self,
+        layout,
+        kernel,
+        tensor_args,
+        nontensor_args,
+    ):
+        super().__init__(
+            layout,
+            tuple(tensor_args),
+            tuple(nontensor_args),
+        )
+        # We need output buffers for generating kernel arguments in the
+        # abi-compatible mode, where we retrieve outputs by passing each
+        # individual output through the abi-compatible interface.
+        self.outputs: Sequence[Any] = []
+        self.kernel = kernel
+
+    @classmethod
+    def create(cls, kernel, *args, **kwargs):
+        context = V.graph.fake_mode
+        with context:
+            (
+                example_output,
+                tensor_args,
+                non_tensor_args,
+                unflatten_args,
+                schema,
+            ) = cls.process_kernel(kernel, *args, **kwargs)
+
+        device = FallbackKernel.find_device(tensor_args, example_output)
+        assert device, "Not sure where to find device info"
+
+        packed = ComplexView(
+            MultiOutputLayout(device), kernel, tensor_args, non_tensor_args
+        )
+
+        layout = FixedLayout(
+            example_output.device,
+            example_output.dtype,
+            convert_shape_to_inductor(example_output.size()),
+            convert_shape_to_inductor(example_output.stride()),
+        )
+        outputs = MultiOutput(layout, packed, [])
+
+        packed.outputs = [outputs]
+        return outputs
+
+
 @dataclasses.dataclass
 class MultiOutputLayout(IRNode):
     device: torch.device
@@ -4437,7 +4499,7 @@ def should_allocate(self):
 
     def has_aliasing(self):
         return any(
-            isinstance(inp, FallbackKernel) and inp.has_aliasing()
+            isinstance(inp, (FallbackKernel, ComplexView)) and inp.has_aliasing()
             for inp in self.inputs
         )

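The aliasing that `ComplexView` reports via `get_alias_names` mirrors eager behavior: `view(dtype)` does not copy, so the output shares storage with its input and the buffer cannot be reused. A small eager-mode sketch (not from the commit):

```python
import torch

x = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64)
f = x.view(torch.float32)  # aliases x's storage: [1., 1., 2., 2.]
f[0] = 5.0
assert x[0] == 5 + 1j  # the write is visible through the complex tensor
```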
23 changes: 15 additions & 8 deletions torch/_inductor/lowering.py
@@ -543,6 +543,10 @@ def _to_dtype_bitcast(x):
 
 @register_lowering(aten.view.dtype, type_promotion_kind=None)
 def _view_dtype(x: TensorBox, dtype: torch.dtype):
+    if dtype.is_complex or x.get_dtype().is_complex:
+        return TensorBox.create(
+            ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
+        )
     return to_dtype_bitcast(x, dtype, copy=True)
 
 
@@ -1602,17 +1606,20 @@ def _warn_complex_not_supported():
 
 # There are some types (CPU) which we accept as input but not as
 # output.
-def unsupported_input_tensor(t: torch._subclasses.FakeTensor):
+def unsupported_input_tensor(t: torch._subclasses.FakeTensor, parent=None):
     "Do not support reading or writing to this tensor"
     if t.is_complex():
+        # Complex views are supported with IR ComplexView
+        if parent and parent.target == torch.ops.aten.view.dtype:
+            return False
         _warn_complex_not_supported()
         return True
     return False
 
 
-def unsupported_output_tensor(t: torch._subclasses.FakeTensor):
+def unsupported_output_tensor(t: torch._subclasses.FakeTensor, parent=None):
     "Do not support writing tensor but can read from it"
-    if unsupported_input_tensor(t):
+    if unsupported_input_tensor(t, parent):
         return True
     return t.is_cpu and config.disable_cpp_codegen
@@ -1626,7 +1633,7 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
     if node.target is aten.lift_fresh_copy.default:
         return False
 
-    def check_skip_condition(node, is_output):
+    def check_skip_condition(node, parent, is_output):
         if not isinstance(node, torch.fx.Node):
             return False

@@ -1638,20 +1645,20 @@ def check_skip_condition(node, is_output):
                 continue
 
             if is_output:
-                if unsupported_output_tensor(meta):
+                if unsupported_output_tensor(meta, parent):
                     return True
             else:
-                if unsupported_input_tensor(meta):
+                if unsupported_input_tensor(meta, parent):
                     return True
 
         return False
 
     # only skip codegen if there is a cpu output, not input
     for arg in tree_flatten((node.args, node.kwargs))[0]:
-        if check_skip_condition(arg, is_output=False):
+        if check_skip_condition(arg, node, is_output=False):
             return True
 
-    return check_skip_condition(node, is_output=True)
+    return check_skip_condition(node, node, is_output=True)
 
 
 def make_fallback(op, layout_constraint=None, warn=True):
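End to end, the parent-aware check above lets `aten.view.dtype` on complex tensors lower through `ComplexView` instead of forcing a fallback to eager. A hypothetical repro (device and values are illustrative):

```python
import torch

@torch.compile
def fn(a, b):
    # explicit real-dtype views exercise the aten.view.dtype lowering
    return (a.view(torch.float32) + b.view(torch.float32)).view(torch.complex64)

a = torch.tensor([1 + 1j, 2 - 2j], dtype=torch.complex64, device="cuda")
b = torch.tensor([3 + 3j, -1 + 1j], dtype=torch.complex64, device="cuda")
torch.testing.assert_close(fn(a, b), a + b)
```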
