Skip to content

Commit

Permalink
Avoid accessing imag of non-complex tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Oct 24, 2023
1 parent 63bcb11 commit e8be014
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,25 +271,27 @@ def angle(x):

@register_decomposition([aten.add])
def add(x, y, **kwargs):
x_is_complex = torch.is_tensor(x) and x.is_complex()
y_is_complex = torch.is_tensor(y) and y.is_complex()
if not x_is_complex and not y_is_complex:
x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
if not x_is_complex_tensor and not y_is_complex_tensor:
return NotImplemented
r = y.real
i = y.imag
i = y.imag if y_is_complex_tensor else 0.0
alpha = kwargs.get("alpha")
if alpha is not None:
r *= alpha
i *= alpha
r = x.real + r
i = x.imag + i
if x_is_complex_tensor:
i = x.imag + i
complex_type = x.dtype if x_is_complex_tensor else y.dtype
return (
torch.where(
torch.arange(2, device=x.device, dtype=torch.uint8) == 0,
r.unsqueeze(-1),
i.unsqueeze(-1),
)
.view(x.dtype)
.view(complex_type)
.squeeze(-1)
)

Expand Down

0 comments on commit e8be014

Please sign in to comment.