Give std/var correction overloads proper defaults #56398

Closed · wants to merge 43 commits
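
In short: the std/var "correction" overloads gain proper defaults (dim=None, correction=None, where None is interpreted as Bessel's correction), and the legacy unbiased overloads lose their C++ defaults via cpp_no_default_args. A minimal sketch of the resulting Python-level behavior (illustrative, not part of the diff):

    import torch

    x = torch.randn(4, 5)

    # correction=None is treated as correction=1 (Bessel's correction),
    # so the defaulted overload matches the old default behavior:
    assert torch.allclose(torch.std(x), torch.std(x, dim=None, correction=1))

    # correction=0 reproduces unbiased=False:
    assert torch.allclose(torch.var(x, dim=0, correction=0),
                          torch.var(x, dim=0, unbiased=False))
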
Changes from all commits (43 commits)
36c6bd3  Give std/var correction overloads proper defaults (peterbell10, Apr 19, 2021)
e4d9c3b  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 19, 2021)
529bdde  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 19, 2021)
dedabe4  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 20, 2021)
799ed78  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 20, 2021)
4fcadca  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 21, 2021)
979949d  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 21, 2021)
1dd29da  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 22, 2021)
2269c44  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 23, 2021)
6fc51eb  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 27, 2021)
6aa9037  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 29, 2021)
acb3027  Update on "Give std/var correction overloads proper defaults" (peterbell10, May 4, 2021)
6a5f0a4  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jun 14, 2021)
2617376  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jun 24, 2021)
58b5273  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 6, 2021)
348d660  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 7, 2021)
e278586  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 9, 2021)
c2dc061  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 9, 2021)
64c4dc1  Update on "Give std/var correction overloads proper defaults" (peterbell10, Aug 9, 2021)
41daf37  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 6, 2022)
0ffad1e  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 7, 2022)
5d57ca3  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 12, 2022)
c14b6e6  Update on "Give std/var correction overloads proper defaults" (peterbell10, Apr 15, 2022)
87242a4  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 30, 2022)
44f3fa8  Update on "Give std/var correction overloads proper defaults" (peterbell10, Jul 30, 2022)
54524e8  Rebase on "Give std/var correction overloads proper defaults" (peterbell10, Sep 28, 2022)
703868b  Update on "Give std/var correction overloads proper defaults" (peterbell10, Sep 28, 2022)
95ff3c8  Update on "Give std/var correction overloads proper defaults" (peterbell10, Sep 29, 2022)
6a07423  Update on "Give std/var correction overloads proper defaults" (peterbell10, Sep 29, 2022)
8126831  Add test for correction=None on "Give std/var correction overloads pr… (peterbell10, Oct 3, 2022)
2da5b57  Update on "Give std/var correction overloads proper defaults" (peterbell10, Oct 4, 2022)
4754b00  Update on "Give std/var correction overloads proper defaults" (peterbell10, Oct 13, 2022)
f4a65ed  Update on "Give std/var correction overloads proper defaults" (peterbell10, Oct 15, 2022)
920628c  Rebase and fix merge conflicts on "Give std/var correction overloads … (peterbell10, Oct 16, 2022)
094dfd7  Update on "Give std/var correction overloads proper defaults" (peterbell10, Oct 16, 2022)
20a04f4  Update on "Give std/var correction overloads proper defaults" (peterbell10, Oct 20, 2022)
6bd08d7  Rebase and fix merge conflicts on "Give std/var correction overloads … (peterbell10, Oct 21, 2022)
f8c6451  Split std/var opinfos into two on "Give std/var correction overloads … (peterbell10, Nov 1, 2022)
356e2d5  Update on "Give std/var correction overloads proper defaults" (peterbell10, Nov 1, 2022)
cda59b4  Update on "Give std/var correction overloads proper defaults" (peterbell10, Nov 2, 2022)
a9f0084  Update on "Give std/var correction overloads proper defaults" (peterbell10, Nov 2, 2022)
66ba1ef  Update on "Give std/var correction overloads proper defaults" (peterbell10, Nov 7, 2022)
c8e750a  Rebase on "Give std/var correction overloads proper defaults" (peterbell10, Dec 4, 2022)
3 changes: 3 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -1744,6 +1744,9 @@ static Tensor& std_var_out(
   const auto correction = correction_opt.value_or(1);
   ScalarType dtype = get_dtype_from_result(result, {});
   auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
+  TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
+              "result type ", self.scalar_type(), " can't be cast to the "
+              "desired output type ", result.scalar_type());
 
   if (iter.numel() == 0) {
     // Trivial reduction
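
The new TORCH_CHECK rejects out= tensors whose dtype cannot hold the result, instead of failing later inside the reduction. A hedged sketch of the user-visible effect (the exact message text may differ):

    import torch

    x = torch.randn(3, 4)
    out = torch.empty(4, dtype=torch.int64)  # float result can't cast to int64
    try:
        torch.var(x, dim=0, out=out)
    except RuntimeError as e:
        # Expected along the lines of:
        # "result type Float can't be cast to the desired output type Long"
        print(e)
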
40 changes: 28 additions & 12 deletions aten/src/ATen/native/native_functions.yaml
@@ -5176,12 +5176,14 @@
 - func: std(Tensor self, bool unbiased=True) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
+  cpp_no_default_args: ["unbiased"]
 
 - func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
+  cpp_no_default_args: ["unbiased"]
 
-- func: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
+- func: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
@@ -5192,12 +5194,14 @@
 - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
 - func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
-- func: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- func: std_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
   dispatch:
@@ -5207,15 +5211,17 @@
 - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
-- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
 
 - func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
+  cpp_no_default_args: ["unbiased"]
 
-- func: std.correction_out(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+- func: std.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA: std_out
@@ -5224,15 +5230,17 @@
 - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
+  cpp_no_default_args: ["unbiased"]
 
 - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
+  cpp_no_default_args: ["unbiased"]
 
-- func: std.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> Tensor
+- func: std.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
 
-- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   variants: function
@@ -5645,13 +5653,15 @@
 - func: var(Tensor self, bool unbiased=True) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
+  cpp_no_default_args: ["unbiased"]
 
 - func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
   tags: canonical
+  cpp_no_default_args: ["unbiased"]
 
-- func: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
+- func: var.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
@@ -5660,36 +5670,41 @@
 
 - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
+  cpp_no_default_args: ["unbiased"]
 
-- func: var.correction_out(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+- func: var.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA: var_out
 
 - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
+  cpp_no_default_args: ["unbiased"]
 
 - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
+  cpp_no_default_args: ["unbiased"]
 
-- func: var.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> Tensor
+- func: var.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> Tensor
   device_check: NoCheck # TensorIterator
   variants: function, method
 
-- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   variants: function
 
 - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
 - func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
-- func: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- func: var_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
   dispatch:
@@ -5699,8 +5714,9 @@
 - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
+  cpp_no_default_args: ["unbiased"]
 
-- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   device_check: NoCheck # TensorIterator
   variants: function
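
With these schema defaults, correction fully generalizes the old unbiased flag: the sum of squared deviations is divided by N - correction (clamped at zero), and correction=None means 1. A small sketch of that relationship (illustrative):

    import torch

    x = torch.randn(10)
    n = x.numel()

    # unbiased=True <-> correction=1; unbiased=False <-> correction=0;
    # other integers generalize the divisor:
    manual = ((x - x.mean()) ** 2).sum() / (n - 2)
    assert torch.allclose(torch.var(x, correction=2), manual)
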
4 changes: 4 additions & 0 deletions test/functorch/test_aotdispatch.py
@@ -1979,7 +1979,9 @@ def forward(self, x):
     xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
     xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('std', 'unbiased'), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('std_mean', 'unbiased'), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('sum_to_size', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
@@ -1994,7 +1996,9 @@ def forward(self, x):
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
     xfail('unflatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('var', ''), # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('var', 'unbiased'), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('var_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
+    xfail('var_mean', 'unbiased'), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('view_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
 }
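
The newly xfailed 'unbiased' entries correspond to the OpInfo variants split out for the legacy overloads, which exercise calls along these lines (a sketch, not the actual OpInfo sample inputs):

    import torch

    x = torch.randn(4, 4)
    torch.std(x, unbiased=True)                        # std(Tensor, bool unbiased)
    torch.var(x, dim=0, unbiased=False, keepdim=True)  # var.dim overload
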
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_opinfo.py
@@ -332,6 +332,7 @@ def process(device_type):
     "scatter_reduce.prod": {f16, f32, f64},
     "segment_reduce.lengths": {f16, f32, f64},
     "sparse.sampled_addmm": {f32, f64},
+    "std_mean.unbiased": {f16},
     "stft": {f32, f64},
     "svd_lowrank": {f32, f64},
     "tensor_split": {b8, f16, f32, f64, i32, i64},
8 changes: 4 additions & 4 deletions tools/autograd/derivatives.yaml
@@ -1514,12 +1514,12 @@
   self: unsqueeze_to(grad, dim, self.sym_sizes())
   result: auto_linear
 
-- name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
+- name: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
   self: std_backward(result, grad, self, dim, correction, keepdim)
   # pointwise (variance) + sum + sqrt
   result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0)
 
-- name: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- name: std_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim)
   result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0)
   # linear
@@ -1723,12 +1723,12 @@
   self: grad.squeeze(dim)
   result: auto_linear
 
-- name: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
+- name: var.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor
   self: var_backward(grad, self, dim, correction, keepdim)
   # pointwise + sum
   result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
 
-- name: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor)
+- name: var_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor)
   self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim)
   result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
   # linear
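
Only the schema strings change here; the backward formulas themselves are untouched. Under the new defaults they can be spot-checked with gradcheck (a sketch; double precision as usual):

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    assert gradcheck(lambda t: torch.var(t, dim=1, correction=0), (x,))
    assert gradcheck(lambda t: torch.std(t, dim=1, correction=1), (x,))
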
2 changes: 1 addition & 1 deletion torch/_decomp/decompositions.py
@@ -1556,7 +1556,7 @@ def xlogy(self: Tensor, other: Tensor) -> Tensor:
 @reduction_complex_to_real
 def std_decomposition(
     x: Tensor,
-    dim: Optional[List[int]],
+    dim: Optional[List[int]] = None,
     correction: Optional[int] = None,
     keepdim: bool = False,
 ):
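
For reference, this decomposition computes std as the square root of var with identical arguments; an equivalent eager-mode sketch (the helper name here is illustrative):

    import torch

    def std_ref(x, dim=None, correction=1, keepdim=False):
        # std decomposes as sqrt of var with the same arguments
        return torch.sqrt(torch.var(x, dim=dim, correction=correction, keepdim=keepdim))

    x = torch.randn(4, 5)
    assert torch.allclose(std_ref(x, dim=0), torch.std(x, dim=0))
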
6 changes: 3 additions & 3 deletions torch/_inductor/lowering.py
@@ -3292,7 +3292,7 @@ def mean(x, axis=None, keepdim=False, *, dtype=None):
 
 
 @register_lowering([aten.var, prims.var])
-def var_(x, axis, correction=1, keepdim=False):
+def var_(x, axis=None, correction=1, keepdim=False):
     size = x.get_size()
     axis = _validate_reduction_axis(x, axis)
     diffs = square(sub(x, mean(x, axis, keepdim=True)))
@@ -3307,7 +3307,7 @@ def var_(x, axis=None, correction=1, keepdim=False):
 
 
 @register_lowering(aten.var_mean)
-def var_mean(x, dim, unbiased=True, keepdim=False, correction=None):
+def var_mean(x, dim=None, unbiased=True, keepdim=False, correction=None):
     if correction is None:
         correction = int(unbiased)
     return [
@@ -3317,7 +3317,7 @@ def var_mean(x, dim=None, unbiased=True, keepdim=False, correction=None):
 
 
 @register_lowering(aten.std)
-def std(x, axis, correction=1, keepdim=False):
+def std(x, axis=None, correction=1, keepdim=False):
     return sqrt(var_(x, axis, correction, keepdim=keepdim))
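
The var_mean lowering maps the legacy flag via correction = int(unbiased), so unbiased=True acts as correction=1 and unbiased=False as correction=0. A quick eager-mode check of that equivalence (illustrative):

    import torch

    x = torch.randn(8, 8)
    v1, m1 = torch.var_mean(x, dim=0, unbiased=False)
    v2, m2 = torch.var_mean(x, dim=0, correction=0)
    assert torch.allclose(v1, v2) and torch.allclose(m1, m2)
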
14 changes: 2 additions & 12 deletions torch/_tensor_docs.py
@@ -4785,12 +4785,7 @@ def callable(a, b) -> number
 add_docstr_all(
     "std",
     r"""
-std(dim, unbiased=True, keepdim=False) -> Tensor
-
-See :func:`torch.std`
-
-.. function:: std(unbiased=True) -> Tensor
-   :noindex:
+std(dim=None, *, correction=1, keepdim=False) -> Tensor
 
 See :func:`torch.std`
 """,
@@ -5738,12 +5733,7 @@ def callable(a, b) -> number
 add_docstr_all(
     "var",
     r"""
-var(dim, unbiased=True, keepdim=False) -> Tensor
-
-See :func:`torch.var`
-
-.. function:: var(unbiased=True) -> Tensor
-   :noindex:
+var(dim=None, *, correction=1, keepdim=False) -> Tensor
 
 See :func:`torch.var`
 """,
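
The rewritten docstrings advertise a single correction-based signature for the Tensor.std/Tensor.var methods; the method forms mirror the functional ones (a sketch):

    import torch

    x = torch.randn(5, 3)
    assert torch.allclose(x.std(dim=0, correction=0), x.std(dim=0, unbiased=False))
    assert torch.allclose(x.var(), x.var(dim=None, correction=1))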