Skip to content

Commit

Permalink
Decompose real,imag and complex to allow for a single kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Oct 24, 2023
1 parent c91f9b2 commit 63bcb11
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 20 deletions.
24 changes: 10 additions & 14 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,15 @@ def fn(x, y):

self.common(fn, (x, y))

def test_add_complex(self):
def fn(a, b):
return a + b

x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])

self.common(fn, (x, y))

def test_concat_add_inplace(self):
def fn(x, y, z):
return torch.cat([x, y], dim=1).add_(z)
Expand Down Expand Up @@ -638,19 +647,6 @@ def fn(a, b, c):
interger_real_input = torch.tensor([-1, 0, 1])
self.common(fn, (complex_input, real_input, interger_real_input))

def test_add(self):
def fn(a, b):
return a + b

x = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1]
)
y = torch.tensor(
[1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1]
)

self.common(fn, (x, y))

def test_sgn(self):
def fn(a):
return torch.sgn(a), torch.sgn(a + 1) - 1
Expand Down Expand Up @@ -3060,7 +3056,7 @@ def fn(x):
fn,
(torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),),
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

class ToComplex(nn.Module):
def forward(self, x):
Expand Down
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def run(*ex, **kwargs):
#
# Failed to find for loop/triton kernel:
#
"test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
Expand Down
56 changes: 51 additions & 5 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,60 @@ def angle(x):
nan = torch.where(torch.isnan(x), float("nan"), 0.0)
return ret + nan


@register_decomposition([aten.add])
def add(x, y):
if x.is_complex():
z1 = x.real + y.real
z2 = x.imag + y.imag;
return torch.complex(z1, z2);
def add(x, y, **kwargs):
x_is_complex = torch.is_tensor(x) and x.is_complex()
y_is_complex = torch.is_tensor(y) and y.is_complex()
if not x_is_complex and not y_is_complex:
return NotImplemented
r = y.real
i = y.imag
alpha = kwargs.get("alpha")
if alpha is not None:
r *= alpha
i *= alpha
r = x.real + r
i = x.imag + i
return (
torch.where(
torch.arange(2, device=x.device, dtype=torch.uint8) == 0,
r.unsqueeze(-1),
i.unsqueeze(-1),
)
.view(x.dtype)
.squeeze(-1)
)


@register_decomposition([aten.real])
def real(self):
if not torch.is_tensor(self) or not self.is_complex():
return NotImplemented
assert self.is_complex(), "real should only be called on a complex tensor"
if self.dtype == torch.complex32:
return self.view(torch.float16)[..., ::2]
elif self.dtype == torch.complex64:
return self.view(torch.float32)[..., ::2]
elif self.dtype == torch.complex128:
return self.view(torch.float64)[..., ::2]
else:
raise AssertionError("unsupported complex type")


@register_decomposition([aten.imag])
def imag(self):
if not torch.is_tensor(self) or not self.is_complex():
return NotImplemented
if self.dtype == torch.complex32:
return self.view(torch.float16)[..., 1::2]
elif self.dtype == torch.complex64:
return self.view(torch.float32)[..., 1::2]
elif self.dtype == torch.complex128:
return self.view(torch.float64)[..., 1::2]
else:
raise AssertionError("unsupported complex type")


@register_decomposition([aten.conj_physical])
def conj_physical(self):
Expand Down

0 comments on commit 63bcb11

Please sign in to comment.