92 changes: 79 additions & 13 deletions torch/_decomp/decompositions.py
@@ -343,6 +343,10 @@ def to_real_dtype(dtype: torch.dtype):
elif dtype == torch.complex128:
return torch.float64

# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction

@register_decomposition(aten.l1_loss)
def l1_loss(
@@ -371,6 +375,7 @@ def l1_loss_backward(


@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
@@ -618,11 +623,6 @@ def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
return grad_output * (mask.type_as(grad_output) * scale)


@register_decomposition(aten.reciprocal)
def reciprocal(self: Tensor) -> Tensor:
return 1 / self


@register_decomposition(aten.logit)
@pw_cast_for_int_to_real
def logit(self: Tensor, eps: Optional[float] = None) -> Tensor:
@@ -1034,14 +1034,14 @@ def logical_not(self: Tensor) -> Tensor:
return ~self.to(dtype=torch.bool)


# Actually, I'm just not sure how to implement this correctly (maybe you need a special case for floating point?)
# @register_decomposition(aten.xlogy)
# def xlogy(self: Tensor, other: Tensor) -> Tensor:
# return aten.where(aten.isnan(self),
# self,
# aten.where(self == aten.new_zeros(self, ()),
# aten.new_zeros(self, ()),
# self * aten.log(other)))
@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
return aten.where(aten.isnan(self),
self,
aten.where(self == aten.new_zeros(self, ()),
aten.new_zeros(self, ()),
self * aten.log(other)))
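

(Not part of the diff: a quick sanity check of the semantics the new xlogy decomposition encodes — a sketch assuming only standard torch ops. xlogy returns 0 wherever self is 0, even where log(other) would give -inf or nan, and propagates NaNs from self.)

import torch

x = torch.tensor([0.0, 2.0, float("nan")])
y = torch.tensor([0.0, 3.0, 1.0])

# Same structure as the decomposition: NaN in x wins, x == 0 forces 0,
# otherwise x * log(y).
expected = torch.where(torch.isnan(x), x,
                       torch.where(x == 0, torch.zeros_like(x), x * torch.log(y)))
print(torch.allclose(torch.xlogy(x, y), expected, equal_nan=True))  # True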


@register_decomposition(aten.var.correction)
@@ -1152,3 +1152,69 @@ def cudnn_batch_norm_backward(
epsilon,
[True, True, True],
)


@register_decomposition(aten.rot90.default)
def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor: # noqa: B006
total_dims = self.dim()
total_rot_dims = len(dims)
assert total_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {total_rot_dims}"
assert total_dims >= 2, f"expected total dims >= 2, but got total dims = {total_dims}"
assert dims[0] != dims[1] and abs(dims[0] - dims[1]) != total_dims,\
f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
assert dims[0] < total_dims and dims[0] >= -total_dims, f"Rotation dim0 out of range, dim0 = {dims[0]}"
assert dims[1] < total_dims and dims[1] >= -total_dims, f"Rotation dim1 out of range, dim1 = {dims[1]}"
k = k % 4
if k == 1:
return self.flip(dims[1]).transpose(dims[0], dims[1])
elif k == 2:
return self.flip(dims)
elif k == 3:
return self.flip(dims[0]).transpose(dims[0], dims[1])
else:
return self.clone(memory_format=torch.contiguous_format)
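

(Not part of the diff: a small check, assuming only standard torch ops, that the flip/transpose composition above matches torch.rot90 for the k=1 and k=2 cases.)

import torch

x = torch.arange(6).reshape(2, 3)
# k=1: flip the second rotation dim, then swap the two dims.
print(torch.equal(torch.rot90(x, 1, [0, 1]), x.flip(1).transpose(0, 1)))  # True
# k=2: flip both rotation dims.
print(torch.equal(torch.rot90(x, 2, [0, 1]), x.flip([0, 1])))             # True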


@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1)) # type: ignore[misc]

if self.dim() <= 1:
return self

if dim0 == dim1:
return self
perm = list(range(self.dim()))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
return torch.permute(self, perm)


@register_decomposition(aten.t.default)
def t(self: Tensor) -> Tensor:
return self.transpose(0, 0 if self.dim() < 2 else 1)
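

(Not part of the diff: a short illustration, assuming only standard torch ops, of the two decompositions above — transpose as a two-element permutation, and t() as a no-op below rank 2.)

import torch

x = torch.randn(2, 3, 4)
perm = list(range(x.dim()))
perm[0], perm[2] = perm[2], perm[0]
print(torch.equal(x.transpose(0, 2), x.permute(perm)))  # True

v = torch.randn(5)
print(torch.equal(v.t(), v))  # True: t() only swaps dims for 2-d inputs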


def check_stack_inputs(tensors: List[Tensor]):
entry_shape = tensors[0].shape
for i in range(1, len(tensors)):
assert tensors[i].shape == entry_shape, (f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
f"and {tensors[i].shape} at entry {i}")


def get_stack_inputs(tensors: List[Tensor], dim: int):
check_stack_inputs(tensors)
return [t.unsqueeze(dim) for t in tensors]


@register_decomposition(aten.stack.default)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
assert len(tensors) > 0, "stack expects a non-empty TensorList"
wrapped_dim = utils.canonicalize_dim(tensors[0].dim() + 1, dim)
if wrapped_dim < tensors[0].dim() and not tensors[0].is_sparse:
check_stack_inputs(tensors)
result_sizes = list(tensors[0].shape)
result_sizes.insert(wrapped_dim, len(tensors))
out = torch.cat(tensors, wrapped_dim)
return out.view(result_sizes)
else:
return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)
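
(Not part of the diff: a sketch, assuming only standard torch ops, of the two code paths above — cat followed by a view when the stack dim already exists on the inputs, and unsqueeze-then-cat otherwise.)

import torch

ts = [torch.randn(2, 3) for _ in range(4)]

# Fast path (wrapped dim < input rank): cat along the dim, then view to the stacked shape.
print(torch.equal(torch.cat(ts, 0).view(4, 2, 3), torch.stack(ts, 0)))  # True

# General path (stacking onto a brand-new trailing dim): unsqueeze each input, then cat.
print(torch.equal(torch.cat([t.unsqueeze(2) for t in ts], 2), torch.stack(ts, 2)))  # True
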
10 changes: 6 additions & 4 deletions torch/_prims/utils.py
@@ -228,7 +228,7 @@ def validate_exclusive_idx(shape: Sequence, ex_idx: int):

# "Wraps" a dim (up to one time) for the given rank, allowing
# dims to be specified using negative indices
def canonicalize_idx(rank: int, idx: int) -> int:
def canonicalize_dim(rank: int, idx: int) -> int:
# TODO: add a comment for why this is
_rank = rank if rank != 0 else 1

@@ -237,6 +237,8 @@ def canonicalize_idx(rank: int, idx: int) -> int:

if idx < 0:
_idx = idx + _rank
else:
_idx = idx

if _idx < 0 or _idx > _rank:
msg = "Received out of bounds index {0} for tensor of rank {1}!".format(
@@ -251,9 +253,9 @@ def canonicalize_idx(rank: int, idx: int) -> int:
# mapping negative offsets to positive ones
def canonicalize_dims(rank: int, indices: DimsType) -> DimsType:
if isinstance(indices, int):
return canonicalize_idx(rank, indices)
return canonicalize_dim(rank, indices)

return tuple(canonicalize_idx(rank, x) for x in indices)
return tuple(canonicalize_dim(rank, x) for x in indices)


def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
@@ -676,7 +678,7 @@ def compute_reduction_output_shape(
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
if dims is None:
return tuple(range(len(shape)))
dims = tuple(canonicalize_idx(len(shape), idx) for idx in dims)
dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
if len(dims) != len(set(dims)):
raise RuntimeError("duplicate value in the list of dims")
return dims
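
(Not part of the diff: a few illustrative calls to the renamed helper, assuming the torch._prims.utils module path used in this file.)

from torch._prims import utils

# Negative dims wrap once: for rank 4, -1 -> 3 and -4 -> 0.
print(utils.canonicalize_dim(4, -1))        # 3
print(utils.canonicalize_dims(4, (0, -2)))  # (0, 2)
# Rank 0 is treated as rank 1, so dim -1 on a scalar tensor still canonicalizes to 0.
print(utils.canonicalize_dim(0, -1))        # 0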
2 changes: 1 addition & 1 deletion torch/_refs/__init__.py
@@ -1123,7 +1123,7 @@ def tensor_split(
indices_or_sections: Union[Tensor, DimsType],
dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
_dim = utils.canonicalize_idx(a.ndim, dim)
_dim = utils.canonicalize_dim(a.ndim, dim)
if a.ndim == 0:
msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
raise ValueError(msg)