In [1]:
from functools import partial
import jax.scipy as jsp
from jax import lax
import jax.numpy as np
import numpy as onp
import jax


def linear_solve(a, b):
  """Solve ``a @ x = b`` via a reused LU factorization with implicit gradients.

  ``a`` is LU-factorized once and the factors serve both the forward solve and
  the transposed solve needed for reverse-mode differentiation. Because the
  solve is wrapped in ``lax.custom_linear_solve``, gradients are derived from
  the linear operator ``matvec`` (``a @ x``) rather than by differentiating
  through the LU factorization itself.

  Args:
    a: square coefficient matrix; assumed non-singular (not checked here).
    b: right-hand side accepted by both ``np.dot(a, b)`` and
      ``jsp.linalg.lu_solve`` (in this notebook a length-n vector).

  Returns:
    ``x`` with the same shape as ``b`` such that ``a @ x`` equals ``b``.
  """
  # custom_linear_solve does not differentiate through `solve` /
  # `transpose_solve` (its docs state they need not be differentiable), so any
  # tangent computed for the LU factors would be thrown away. stop_gradient
  # skips that wasted work; the primal factors are identical.
  a_factors = jsp.linalg.lu_factor(lax.stop_gradient(a))

  def solve(matvec, x):
    # Solves a @ y = x; `matvec` is unused because the LU factors suffice.
    return jsp.linalg.lu_solve(a_factors, x)

  def transpose_solve(vecmat, x):
    # Solves a.T @ y = x (trans=1); used when the solve is transposed for grads.
    return jsp.linalg.lu_solve(a_factors, x, trans=1)

  # Differentiable definition of the operator: gradients w.r.t. `a` flow
  # through this closure, not through the factorization above.
  matvec = partial(np.dot, a)
  return lax.custom_linear_solve(matvec, b, solve, transpose_solve)

def loss(solve):
  """Wrap a linear solver into a scalar objective suitable for ``jax.grad``.

  Args:
    solve: callable ``solve(a, b)`` expected to return an array ``x`` solving
      ``a x = b`` (in this notebook: ``np.linalg.solve`` or ``linear_solve``).

  Returns:
    A function of ``(a, b)`` returning the sum of the entries of
    ``solve(a, b)``.
  """
  def summed_solution(a, b):
    solution = solve(a, b)
    return solution.sum()

  return summed_solution



In [3]:
rs = onp.random.RandomState(0)
a = rs.randn(500, 500)
a = jax.device_put(a.T @ a + 0.1 * np.eye(500))
b = jax.device_put(rs.randn(500))

# general purpose solve
# current
grad = jax.jit(jax.grad(loss(np.linalg.solve)))
%timeit jax.device_get(grad(a, b))
# 33.8 ms per loop
# new
grad = jax.jit(jax.grad(loss(linear_solve)))
%timeit jax.device_get(grad(a, b))
# 10.1 ms per loop

The slowest run took 41.51 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 8.42 ms per loop
The slowest run took 45.25 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 9.18 ms per loop
