Skip to content

Commit

Permalink
Decompose arange.default to arange.start_step (#99739)
Browse files Browse the repository at this point in the history
The aten op arange.default is not in the core aten IR, and should decompose into the arange.start_step op.
Pull Request resolved: #99739
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
angelayi authored and pytorchmergebot committed Apr 27, 2023
1 parent a67fa84 commit d06b93b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.addcmul_,
aten.addr,
aten.aminmax,
aten.arange.default,
aten.arange.start,
aten.avg_pool2d_backward,
aten.binary_cross_entropy,
aten.binary_cross_entropy_backward,
Expand Down
30 changes: 30 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3356,6 +3356,36 @@ def nansum(self, dim=None, keepdim=False, *, dtype=None):
return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype)


@register_decomposition([aten.arange.default, aten.arange.out])
@out_wrapper()
def arange_default(
end: NumberType,
*,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
pin_memory: bool = False,
):
return aten.arange.start_step(
0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
)


@register_decomposition([aten.arange.start])
def arange_start(
start: NumberType,
end: NumberType,
*,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
pin_memory: bool = False,
):
return aten.arange.start_step(
start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
)


def register_inplace(aten_op, outplace_op):
@register_decomposition(aten_op)
def inplace_op(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4341,7 +4341,7 @@ def empty_like(
)


@register_decomposition(aten.arange)
@register_decomposition([aten.arange.start_step, aten.arange.start_out])
@out_wrapper()
def arange(
start: NumberType = 0,
Expand Down

0 comments on commit d06b93b

Please sign in to comment.