Deprecate .mT,.T,.mH,.H on 0D tensors
As discussed with ngimel, this behaviour is not only undocumented, it is also an unnecessary edge case. See #90463 (comment)

ghstack-source-id: ae0ae52b967b2390abe97725b8f2f4b38fd7c3a4
Pull Request resolved: #92143
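For context, a minimal Python sketch (not part of this commit) of the usage being deprecated and the replacements suggested by the new warning messages; the tensor value is illustrative:

import torch

x = torch.tensor(3.0 + 4.0j)  # a 0-D (scalar) tensor

# After this change, each of these emits a one-time deprecation warning on 0-D input.
_ = x.T    # identity on 0-D tensors
_ = x.mT   # identity on 0-D tensors
_ = x.H    # equivalent to x.conj() on 0-D tensors
_ = x.mH   # equivalent to x.conj() on 0-D tensors

# Replacements suggested by the warnings:
y = x          # instead of x.T / x.mT
z = x.conj()   # instead of x.H / x.mH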
lezcano committed Jan 13, 2023
1 parent ec3941a commit e897bad
Showing 5 changed files with 27 additions and 9 deletions.
23 changes: 23 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -3543,6 +3543,11 @@ Tensor numpy_T(const Tensor &self) {
"or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor."
);
}
if (n == 0) {
// Added in PyTorch 2.0
TORCH_WARN_ONCE("Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases.");
}
DimVector transpose_dims;
for (int64_t i = n - 1; i >= 0; --i) {
transpose_dims.push_back(i);
@@ -3552,6 +3557,11 @@ Tensor numpy_T(const Tensor &self) {

Tensor matrix_H(const Tensor &self) {
const auto ndim = self.dim();
if (ndim == 0) {
// Added in PyTorch 2.0
TORCH_WARN_ONCE("Tensor.H is deprecated on 0-D tensors. Consider using x.conj().");
}
TORCH_CHECK(ndim == 2 || ndim == 0,
"tensor.H is only supported on matrices (2-D tensors). Got ", ndim, "-D tensor.",
ndim > 2 ? " For batches of matrices, consider using tensor.mH" : "");
@@ -3576,14 +3586,27 @@ Tensor _adjoint(const Tensor &self, const bool transpose, const char* const name
} // anonymous namespace

Tensor mT(const Tensor &self) {
if (self.dim() == 0) {
// Added in PyTorch 2.0
TORCH_WARN_ONCE("Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.");
}
return _adjoint(self, /*transpose=*/true, "mT");
}

Tensor mH(const Tensor &self) {
if (self.dim() == 0) {
// Added in PyTorch 2.0
TORCH_WARN_ONCE("Tensor.mH is deprecated on 0-D tensors. Consider using x.conj().");
}
return _adjoint(self, /*transpose=*/false, "mH");
}

Tensor adjoint(const Tensor &self) {
if (self.dim() == 0) {
TORCH_WARN_ONCE("adjoint() is deprecated on 0-D tensors. Consider using x.conj().");
}
return _adjoint(self, /*transpose=*/false, "adjoint()");
}

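As a rough sketch of the user-visible effect of the C++ changes above (assuming a build that includes this commit, and that torch.set_warn_always(True) makes the once-only warning observable from Python):

import warnings

import torch

torch.set_warn_always(True)  # ask PyTorch to re-emit TORCH_WARN_ONCE-style warnings on every call

scalar = torch.tensor(2.0)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = scalar.T  # still the identity on a 0-D tensor

print(torch.equal(out, scalar))          # True
print([str(w.message) for w in caught])  # expected to mention that .T is deprecated on 0-D tensors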
1 change: 0 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1210,7 +1210,6 @@ def forward(self, x):
return x.T

self.run_test(NumpyTranspose(), torch.randn(4, 7))
self.run_test(NumpyTranspose(), torch.tensor(-42.0))

# Conversion of Transpose depends on input shape to be known.
# The following test only works when onnx shape inference is enabled.
1 change: 0 additions & 1 deletion test/test_legacy_vmap.py
@@ -2040,7 +2040,6 @@ def op(t):
test = self._vmap_view_test
B0, B1, B2 = 7, 11, 13
test(op, (torch.rand(B0, 2, 3, 5),))
test(op, (torch.rand(B0),))
test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
7 changes: 2 additions & 5 deletions test/test_view_ops.py
@@ -1307,21 +1307,18 @@ def test_T(self, device):
self.assertEqual(t2, t1)
b = torch.randn(10, device=device)
self.assertEqual(b, b.T)
scalar = torch.tensor(5, device=device)
self.assertEqual(scalar, scalar.T)

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_transposes(self, device, dtype):
for op in ("T", "H", "mT", "mH", "adjoint"):
shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
for shape in shapes:
a = make_tensor(shape, device=device, dtype=dtype)
t1 = getattr(a, op)
if op == "adjoint":
t1 = t1()
t2 = a
if a.ndim != 0:
t2 = t2.transpose(-2, -1)
t2 = t2.transpose(-2, -1)
if op[-1] == "H" or op == "adjoint":
t2 = t2.conj()
self.assertEqual(t2, t1)
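The dropped 0-D cases aside, the invariant that test_transposes still checks can be summarised as a standalone sketch (the shape and dtype are illustrative):

import torch

a = torch.randn(2, 3, dtype=torch.complex64)

assert torch.equal(a.mT, a.transpose(-2, -1))              # "m" variants transpose the last two dims
assert torch.equal(a.mH, a.transpose(-2, -1).conj())       # "H" variants also conjugate
assert torch.equal(a.H, a.transpose(-2, -1).conj())        # .H itself is restricted to 2-D inputs
assert torch.equal(a.adjoint(), a.transpose(-2, -1).conj())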
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1656,13 +1656,13 @@ def _numpy_ref_transpose(a, dim0, dim1):
def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

shapes = ((1, 2, 3), (), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
return (SampleInput(make_arg(shape)) for shape in shapes)

def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

shapes = ((), (M, M), (M, L))
shapes = ((M, M), (M, L))
return (SampleInput(make_arg(shape)) for shape in shapes)

def error_inputs_T(self, device, has_ndims_error=False):
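For readers unfamiliar with the OpInfo helpers touched above, a minimal standalone version of the same sample-input pattern, using the public torch.testing.make_tensor in place of the internal helper (the shapes here are illustrative):

from functools import partial

import torch
from torch.testing import make_tensor

# Fix the common arguments once, then build one tensor per remaining shape;
# the 0-D shape () is no longer part of the list.
make_arg = partial(make_tensor, dtype=torch.float32, device="cpu", requires_grad=True)

shapes = ((1, 2, 3), (4, 4), (3, 5, 3))
samples = (make_arg(shape) for shape in shapes)

for t in samples:
    print(t.shape, t.requires_grad)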
