
Commit

[pt2] add metas for mode ops
ghstack-source-id: e6a304794e8d82c0a1305cccdf236432680e3e6c
Pull Request resolved: #106273
nkaretnikov committed Jul 31, 2023
1 parent 1e8dc9f · commit d68482e
Showing 4 changed files with 3 additions and 5 deletions.
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
@@ -2810,7 +2810,6 @@ def forward(self, x):
 xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
-xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
 xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
 xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
 skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
2 changes: 0 additions & 2 deletions test/test_meta.py
@@ -627,7 +627,6 @@ def run_meta_crossref(
 torch.histogram : {f64, f32},
 torch.histogramdd : {f64, f32},
 torch.kthvalue : {f64, i32, i64, u8, i16, bf16, i8, f32},
-torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
 torch.nn.functional.ctc_loss : {f64, f32},
 torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
 torch.nn.functional.one_hot : {i64},
@@ -828,7 +827,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 aten.histogram.bin_ct : {f32, f64},
 aten.histogram.bins_tensor : {f32, f64},
 aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
-aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8},
 aten.nll_loss2d_forward.default : {bf16, f32, f64},
 aten.rrelu_with_noise.default : {bf16, f32, f64},
 aten.segment_reduce.default : {bf16, f32, f16, f64},
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1546,7 +1546,6 @@ def f(t):
 xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
 xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
 xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
-xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
 xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
 xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
 xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
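
The xfail removed above dates from when symbolic tracing of 'mode' failed with "aten.mode.default - couldn't find symbolic meta function/decomposition". A minimal sketch of what this change enables (illustrative only, not part of the commit; assumes a build that includes it):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(t):
    # torch.mode returns a (values, indices) pair; both outputs now have meta support
    values, indices = torch.mode(t, dim=1)
    return values, indices

# Tracing with symbolic shapes previously hit the missing-meta error quoted above.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
print(gm.graph)
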
4 changes: 3 additions & 1 deletion torch/_meta_registrations.py
@@ -3056,10 +3056,12 @@ def meta_median(input):
         aten.median.dim_values,
         aten.nanmedian.dim,
         aten.nanmedian.dim_values,
+        aten.mode.default,
+        aten.mode.values,
     ]
 )
 @out_wrapper("values", "indices")
-def meta_median_dim(input, dim=-1, keepdim=False):
+def meta_median_mode_dim(input, dim=-1, keepdim=False):
     if device_hint(input) == "cuda":
         utils.alert_not_deterministic("median CUDA with indices output")
     dim = utils.reduction_dims(input.shape, (dim,))
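
Both mode overloads are routed to the meta function that already served median and nanmedian, so output shapes and dtypes for torch.mode can now be computed without real data. A rough usage sketch (illustrative only, not taken from the commit; assumes a build that includes it):

import torch

# On the "meta" device only metadata is tracked, so this call exercises the new meta kernel.
x = torch.empty(3, 4, device="meta")
values, indices = torch.mode(x, dim=1)
print(values.shape, indices.shape, indices.dtype)  # torch.Size([3]) torch.Size([3]) torch.int64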
