Skip to content

Commit

Permalink
[pt2] add meta for linalg_ldl_solve
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
nkaretnikov committed May 14, 2023
1 parent 874f010 commit 777e564
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,6 @@ def f(a, b, c, d, e):
xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
xfail('linalg.ldl_solve', ''), # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
Expand Down
40 changes: 40 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,46 @@ def linalg_ldl_factor_ex_meta(
return LD, pivots, info


@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
@out_wrapper()
def linalg_ldl_solve_meta(
LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False
) -> Tensor:
squareCheckInputs(LD, "torch.linalg.ldl_solve")
checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
check(
B.ndim >= 2,
lambda: (
f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
f"but it has {B.ndim} dimensions instead"
),
)
expected_pivots_shape = LD.shape[:-1]
check(
expected_pivots_shape == pivots.shape,
lambda: (
f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
f"but got pivots with shape {pivots.shape} instead"
),
)
check(
utils.is_integer_dtype(pivots.dtype),
lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
)
check(
LD.dtype == B.dtype,
lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
)
B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
return torch.empty_strided(
size=B_broadcast_size,
stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
dtype=B.dtype,
device=B.device,
)


# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
if mode == "reduced":
Expand Down

0 comments on commit 777e564

Please sign in to comment.