diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 51200dc6b406..d9f7e8018264 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4638,7 +4638,7 @@ def merge_dicts(*dicts): add_docstr(torch.lu_solve, r""" -lu_solve(input, LU_data, LU_pivots, *, out=None) -> Tensor +lu_solve(b, LU_data, LU_pivots, *, out=None) -> Tensor Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted LU factorization of A from :meth:`torch.lu`. diff --git a/torch/overrides.py b/torch/overrides.py index e8a3933a1954..2af6e36ea914 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -505,7 +505,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.lt: lambda input, other, out=None: -1, torch.less: lambda input, other, out=None: -1, torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, - torch.lu_solve: lambda input, LU_data, LU_pivots, out=None: -1, + torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, torch.masked_fill: lambda input, mask, value: -1, torch.masked_scatter: lambda input, mask, source: -1,