diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 4c71915ec9886..bbca9d4d5798e 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1440,7 +1440,6 @@ def f(a, b, c, d, e):
     xfail('isin', ''),  # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
     xfail('kron', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.ldl_solve', ''),  # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.lu', ''),  # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.lu_factor', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.lu_factor_ex', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index f768529029b50..ef2258467a6bf 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -564,6 +564,46 @@ def linalg_ldl_factor_ex_meta(
     return LD, pivots, info
 
 
+@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
+@out_wrapper()
+def linalg_ldl_solve_meta(
+    LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False
+) -> Tensor:
+    squareCheckInputs(LD, "torch.linalg.ldl_solve")
+    checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
+    linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
+    check(
+        B.ndim >= 2,
+        lambda: (
+            f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
+            f"but it has {B.ndim} dimensions instead"
+        ),
+    )
+    expected_pivots_shape = LD.shape[:-1]
+    check(
+        expected_pivots_shape == pivots.shape,
+        lambda: (
+            f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
+            f"but got pivots with shape {pivots.shape} instead"
+        ),
+    )
+    check(
+        utils.is_integer_dtype(pivots.dtype),
+        lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
+    )
+    check(
+        LD.dtype == B.dtype,
+        lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
+    )
+    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
+    return torch.empty_strided(
+        size=B_broadcast_size,
+        stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
+        dtype=B.dtype,
+        device=B.device,
+    )
+
+
 # parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
 def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
     if mode == "reduced":
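The sketch below (not part of the diff) shows how the new meta kernel can be exercised once registered: calling torch.linalg.ldl_solve on "meta"-device tensors propagates only the output's shape, strides, dtype, and device, with no real data involved. The concrete shapes and the int32 pivot dtype are illustrative assumptions, not something the diff itself prescribes.

```python
import torch

# All tensors live on the "meta" device, so only metadata is propagated;
# no LDL factorization or solve is actually performed.
LD = torch.randn(4, 4, device="meta")
pivots = torch.empty(4, dtype=torch.int32, device="meta")  # shape LD.shape[:-1], integer dtype
B = torch.randn(4, 2, device="meta")

out = torch.linalg.ldl_solve(LD, pivots, B)
print(out.shape, out.dtype, out.device)  # torch.Size([4, 2]) torch.float32 meta
```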