@@ -376,3 +376,25 @@ def lu_solve(b, LU_data, LU_pivots, **kwargs):
376376 _pivots = LU_pivots - 1
377377 x = jax .scipy .linalg .lu_solve ((LU_data , _pivots ), b )
378378 return x
379+
380+ @register_function (torch .linalg .tensorsolve )
381+ def linalg_tensorsolve (A , b , dims = None ):
382+ # examples:
383+ # A = torch.randn(2, 3, 6), b = torch.randn(3, 2)
384+ # A = torch.randn(2, 3, 6), b = torch.randn(2, 3) -> torch.Size([3, 6])
385+ # A = torch.randn(9, 2, 6, 3) b = torch.randn(6, 3) -> torch.Size([6, 3])
386+ # A = torch.randn(9, 2, 3, 6) b = torch.randn(6, 3) -> torch.Size([3, 6])
387+ # A = torch.randn(18, 6, 3) b = torch.randn(18) -> torch.Size([6, 3])
388+ # A = torch.randn(3, 8, 4, 6) b = torch.randn(4, 6) -> torch.Size([4,6])
389+ # A = torch.randn(3, 8, 1, 2, 2, 6) b = torch.randn(3, 4, 2) -> torch.Size([2, 2, 6])
390+
391+ # torch allows b to be shaped differently.
392+ # especially when axes are moved using dims.
393+ # ValueError: After moving axes to end, leading shape of a must match shape of b. got a.shape=(3, 2, 6), b.shape=(2, 3)
394+ # So we are handling the moveaxis and forcing b's shape to match what jax expects
395+ if dims is not None :
396+ A = jnp .moveaxis (A , dims , len (dims ) * (A .ndim - 1 ,))
397+ dims = None
398+ if A .shape [:b .ndim ] != b .shape :
399+ b = jnp .reshape (b , A .shape [:b .ndim ])
400+ return jnp .linalg .tensorsolve (A , b , axes = dims )
0 commit comments