[pt2] add meta for ormqr
ghstack-source-id: fbff1e0b23bbb25e31b2339c1c67c2fa0b29205c
Pull Request resolved: #106278
nkaretnikov committed Jul 30, 2023
1 parent eab3b26 commit add5157
Showing 4 changed files with 92 additions and 8 deletions.
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
@@ -2830,7 +2830,6 @@ def forward(self, x):
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/de...
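For context, a minimal sketch (not part of the diff) of why this xfail can be dropped: AOTAutograd traces ormqr calls through fake tensors, which requires the meta kernel this commit registers in torch/_meta_registrations.py. Shapes here are illustrative.

import torch
from functorch.compile import aot_function, nop

def f(a, tau, other):
    return torch.ormqr(a, tau, other)

# nop leaves the traced graphs unchanged; the tracing itself now goes
# through the new ormqr meta kernel instead of failing.
traced = aot_function(f, nop)
out = traced(torch.randn(10, 5), torch.randn(5), torch.randn(10, 4))
print(out.shape)  # torch.Size([10, 4])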
6 changes: 0 additions & 6 deletions test/test_meta.py
@@ -617,7 +617,6 @@ def run_meta_crossref(
torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
- torch.ormqr : {f64, c64, c128, f32},
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8},
torch.frexp : {f64, f16, bf16, f32},
@@ -715,7 +714,6 @@ def run_meta_crossref(
torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
torch.kthvalue: {f16}, # aten::kthvalue.values
torch.median: {f16}, # aten::median, aten::median.dim_values
- torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out
}

meta_function_device_expected_failures_only_outplace['cuda'] = {
@@ -809,8 +807,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
- aten.ormqr.default : {c64, c128, f64, f32},
- aten.ormqr.out : {c64, c128, f64, f32},
aten.tensordot.out : {c64, i8, f64, c128, i64, bf16, f32, i32, i16, u8},
aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
@@ -890,8 +886,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.median.default: {f16}, # aten::median
aten.median.dim: {f16}, # aten::median.dim_values
aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward
- aten.ormqr.default: {f32, f64}, # aten::ormqr
- aten.ormqr.out: {f32, f64}, # aten::ormqr.out
aten.rrelu_with_noise.default: {f16}, # aten::rrelu_with_noise
aten.tensordot.out: {f16}, # aten::tensordot.out
aten.unique_consecutive.default: {f16}, # aten::unique_consecutive
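For context, a minimal sketch (not part of the diff): these expected-failure entries can be dropped because torch.ormqr now dispatches to the new meta kernel, so calling it on meta tensors yields output metadata instead of raising. Shapes are illustrative.

import torch

a = torch.empty(10, 5, device="meta")     # stand-in for a QR factorization result
tau = torch.empty(5, device="meta")       # Householder scalars
other = torch.empty(10, 4, device="meta")
out = torch.ormqr(a, tau, other)          # no real compute, only metadata
print(out.shape, out.dtype)               # torch.Size([10, 4]) torch.float32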
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1563,7 +1563,6 @@ def f(t):
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
- xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
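For context, a minimal sketch (not part of the diff): with the meta kernel available, make_fx can trace ormqr under symbolic shapes, which is what removing this xfail reflects. Inputs are illustrative.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, tau, other):
    return torch.ormqr(a, tau, other)

gm = make_fx(f, tracing_mode="symbolic")(
    torch.randn(10, 5), torch.randn(5), torch.randn(10, 4)
)
print(gm.graph)  # includes a call to aten.ormqr.default with symbolic sizes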
92 changes: 92 additions & 0 deletions torch/_meta_registrations.py
@@ -1150,6 +1150,98 @@ def _linalg_det_meta(A):
    return det, LU, pivots


@register_meta(aten.ormqr)
@out_wrapper()
def ormqr(
    input: Tensor,
    tau: Tensor,
    other: Tensor,
    left: bool = True,
    transpose: bool = False,
) -> Tensor:
    torch._check(
        input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
    )
    torch._check(
        other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
    )

    left_size_condition = -2 if left else -1
    torch._check(
        other.shape[left_size_condition] >= tau.shape[-1],
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
    )
    torch._check(
        other.shape[left_size_condition] == input.shape[-2],
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
    )

    torch._check(
        tau.shape[-1] <= input.shape[-1],
        lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.ormqr: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    torch._check(
        input.ndim == other.ndim,
        lambda: (
            f"torch.ormqr: Expected other to have the same number of dimensions as input, "
            f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )

    if input.ndim > 2:
        expected_batch_shape = input.shape[:-2]
        actual_batch_tau_shape = tau.shape[:-1]
        torch._check(
            actual_batch_tau_shape == expected_batch_shape,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
            ),
        )

        actual_batch_other_shape = other.shape[:-2]
        torch._check(
            actual_batch_other_shape == expected_batch_shape,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of other to be "
                f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and tau to have the same dtype, "
            f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
        ),
    )
    torch._check(
        other.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and other to have the same dtype, "
            f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
        ),
    )

    checkSameDevice("torch.ormqr", tau, input, "tau")
    checkSameDevice("torch.ormqr", other, input, "other")

    return torch.empty_strided(
        size=other.shape,
        stride=make_contiguous_strides_for(other.shape, row_major=False),
        dtype=other.dtype,
        device=other.device,
    )


def _padding_check_valid_input(input, padding, *, dim):
    torch._check(
        len(padding) == 2 * dim,
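For context, a minimal sketch (not part of the diff) of what the registration provides: under FakeTensorMode, torch.ormqr reports output metadata without running LAPACK. Note the column-major strides coming from make_contiguous_strides_for(..., row_major=False), presumably chosen to match the layout of the real kernel's output. Shapes are illustrative.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    a = torch.randn(10, 5)
    tau = torch.randn(5)
    other = torch.randn(10, 4)
    out = torch.ormqr(a, tau, other)

print(out.shape)    # torch.Size([10, 4])
print(out.stride()) # (1, 10), i.e. column-major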
