Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pt2] add meta for ormqr #106278

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2830,7 +2830,6 @@ def forward(self, x):
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/de...
Expand Down
6 changes: 0 additions & 6 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,6 @@ def run_meta_crossref(
torch.masked_select : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.ormqr : {f64, c64, c128, f32},
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8},
torch.frexp : {f64, f16, bf16, f32},
Expand Down Expand Up @@ -715,7 +714,6 @@ def run_meta_crossref(
torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
torch.kthvalue: {f16}, # aten::kthvalue.values
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out
}

meta_function_device_expected_failures_only_outplace['cuda'] = {
Expand Down Expand Up @@ -809,8 +807,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
aten.nonzero.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, c32, b8, i16, u8},
aten.ormqr.default : {c64, c128, f64, f32},
aten.ormqr.out : {c64, c128, f64, f32},
aten.tensordot.out : {c64, i8, f64, c128, i64, bf16, f32, i32, i16, u8},
aten._to_sparse.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten._to_sparse.sparse_dim : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
Expand Down Expand Up @@ -890,8 +886,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.median.default: {f16}, # aten::median
aten.median.dim: {f16}, # aten::median.dim_values
aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward
aten.ormqr.default: {f32, f64}, # aten::ormqr
aten.ormqr.out: {f32, f64}, # aten::ormqr.out
aten.rrelu_with_noise.default: {f16}, # aten::rrelu_with_noise
aten.tensordot.out: {f16}, # aten::tensordot.out
aten.unique_consecutive.default: {f16}, # aten::unique_consecutive
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1563,7 +1563,6 @@ def f(t):
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
Expand Down
92 changes: 92 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,98 @@ def _linalg_det_meta(A):
return det, LU, pivots


@register_meta(aten.ormqr)
@out_wrapper()
def ormqr(
input: Tensor,
tau: Tensor,
other: Tensor,
left: bool = True,
transpose: bool = False,
) -> Tensor:
torch._check(
input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
)
torch._check(
other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
)

left_size_condition = -2 if left else -1
torch._check(
other.shape[left_size_condition] >= tau.shape[-1],
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
)
torch._check(
other.shape[left_size_condition] == input.shape[-2],
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
)

torch._check(
tau.shape[-1] <= input.shape[-1],
lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
)

torch._check(
input.ndim - tau.ndim == 1,
lambda: (
f"torch.ormqr: Expected tau to have one dimension less than input, "
f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
),
)
torch._check(
input.ndim == other.ndim,
lambda: (
f"torch.ormqr: Expected other to have the same number of dimensions as input, "
f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
),
)

if input.ndim > 2:
expected_batch_shape = input.shape[:-2]
actual_batch_tau_shape = tau.shape[:-1]
torch._check(
actual_batch_tau_shape == expected_batch_shape,
lambda: (
f"torch.ormqr: Expected batch dimensions of tau to be "
f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
),
)

actual_batch_other_shape = other.shape[:-2]
torch._check(
actual_batch_other_shape == expected_batch_shape,
lambda: (
f"torch.ormqr: Expected batch dimensions of other to be "
f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
),
)

torch._check(
tau.dtype == input.dtype,
lambda: (
f"torch.ormqr: Expected input and tau to have the same dtype, "
f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
),
)
torch._check(
other.dtype == input.dtype,
lambda: (
f"torch.ormqr: Expected input and other to have the same dtype, "
f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
),
)

checkSameDevice("torch.ormqr", tau, input, "tau")
checkSameDevice("torch.ormqr", other, input, "other")

return torch.empty_strided(
size=other.shape,
stride=make_contiguous_strides_for(other.shape, row_major=False),
dtype=other.dtype,
device=other.device,
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very similar to linalg_householder_product meta.



def _padding_check_valid_input(input, padding, *, dim):
torch._check(
len(padding) == 2 * dim,
Expand Down