<a href="https://colab.research.google.com/github/romanodev/jax-pv/blob/master/sparse_preconditioner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import time
import sys
import jax.numpy as jnp
from jax import random
import jax
from jax import lax
from jax import grad, jit, vmap


In [None]:
# Build the same kind of n x n upper-triangular test system twice:
# once with JAX arrays (right-hand side of shape (n, 1)) and once with
# NumPy (1-D right-hand side), plus a reference solution for each.
n = 1000

#Jax---------------------------------------
# JAX PRNG keys must never be reused for independent draws, so A and b are
# generated from separate subkeys of the root key.
key = random.PRNGKey(1)
key_A, key_b = random.split(key)
A_jax = random.uniform(key_A, (n, n)) * jnp.tri(n).T  # keep the upper triangle only
b_jax = random.uniform(key_b, (n, 1))
x_jax = jnp.linalg.solve(A_jax, b_jax)  # reference solution, shape (n, 1)

#Numpy-------------------------------------
# Seeded generator so the NumPy system (and therefore the timings) is
# reproducible from run to run.
rng = np.random.default_rng(1)
A = rng.standard_normal((n, n)) * np.tri(n).T  # upper triangular
b = rng.standard_normal(n)
x = np.linalg.solve(A, b)  # reference solution, shape (n,)

In [None]:
def one_shot_numpy(A,b):
    """Solve the upper-triangular system ``A @ x = b`` by back substitution.

    Only the diagonal and the strict upper triangle of ``A`` are read; the
    unknowns are resolved from the last row upwards with one NumPy dot
    product per row.

    Args:
        A: ``(n, n)`` array; entries below the diagonal are ignored.
        b: right-hand side of shape ``(n,)`` or ``(n, k)``.

    Returns:
        ``x`` with the same shape as ``b``. Its dtype is the promotion of
        ``A``, ``b`` and float64, so integer input gives float64 and complex
        input stays complex. An empty ``b`` yields an empty result.
    """
    b = np.asarray(b)
    n = len(b)
    # Promote instead of hard-coding float64 so complex systems keep their
    # imaginary part and integer inputs are still divided in floating point.
    x = np.zeros(b.shape, dtype=np.result_type(A, b, float))
    if n == 0:
        return x
    # For the last row the slices are empty, so the dot product is zero and
    # this reduces to b[-1] / A[-1, -1].
    for i in range(n - 1, -1, -1):
        x[i] = (b[i] - np.dot(A[i, i + 1:], x[i + 1:])) / A[i, i]
    return x


In [None]:
%%timeit
# Benchmark the pure-NumPy back substitution on the NumPy system (A, b)
# built in the setup cell. NOTE: %%timeit runs this body inside its own
# function scope, so xcomp_numpy is not kept in the notebook namespace.
xcomp_numpy = one_shot_numpy(A,b)

100 loops, best of 3: 2.56 ms per loop


We first try `lax.scan`. However, we won't need to differentiate through the solve, so this version is only for testing.

In [None]:
@jit
def fsub_scan(A_jax,b_jax):
    """Solve the upper-triangular system ``A_jax @ x = b_jax`` with ``lax.scan``.

    Despite the ``fsub`` prefix this is *back* substitution: only the
    diagonal and strict upper triangle of ``A_jax`` are read, and the
    unknowns are resolved from the last row upwards.

    Args:
        A_jax: ``(n, n)`` matrix; entries below the diagonal are ignored.
        b_jax: right-hand side of shape ``(n, 1)`` (as built in this
            notebook) or ``(n,)``. For a 2-D ``b_jax`` only column 0 is
            solved, as before.

    Returns:
        The solution as a 1-D array of shape ``(n,)``. Its dtype is floating
        point even for integer inputs.
    """
    n = b_jax.shape[0]
    rhs = b_jax[:, 0] if b_jax.ndim == 2 else b_jax
    # Always floating point so the division below never truncates.
    dtype = jnp.result_type(A_jax.dtype, rhs.dtype, jnp.float32)
    if n == 0:
        return jnp.zeros((0,), dtype)

    cols = jnp.arange(n)  # loop-invariant, hoisted out of the scan body
    # The last unknown needs no sum; seed it before scanning upwards.
    x0 = jnp.zeros(n, dtype).at[-1].set(rhs[-1] / A_jax[-1, -1])

    def iteration(x, i):
        # Only columns j > i contribute. Masking both operands ignores
        # anything stored below the diagonal and the not-yet-solved entries.
        row = jnp.where(cols > i, A_jax[i], 0)
        known = jnp.where(cols > i, x, 0)
        xi = (rhs[i] - jnp.dot(row, known)) / A_jax[i, i]
        return x.at[i].set(xi), None

    x, _ = lax.scan(iteration, x0, jnp.arange(n - 2, -1, -1))
    return x


In [None]:
%%timeit
# Benchmark the lax.scan back substitution. block_until_ready() waits for
# JAX's asynchronously dispatched result, so the measurement covers the
# actual solve. The first call also includes JIT compilation, which likely
# explains the "slowest run" warning in this cell's output.
xcomp_jax  = fsub_scan(A_jax,b_jax).block_until_ready()

The slowest run took 9.70 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 19.9 ms per loop


In [None]:
@jit
def fsub_fori_loop(A_jax,b_jax):
    """Solve the upper-triangular system ``A_jax @ x = b_jax`` with ``lax.fori_loop``.

    Despite the ``fsub`` prefix this is *back* substitution: only the
    diagonal and strict upper triangle of ``A_jax`` are read, and the
    unknowns are resolved from the last row upwards.

    Args:
        A_jax: ``(n, n)`` matrix; entries below the diagonal are ignored.
        b_jax: right-hand side of shape ``(n, 1)`` (as built in this
            notebook) or ``(n,)``. For a 2-D ``b_jax`` only column 0 is
            solved, as before.

    Returns:
        The solution as a 1-D array of shape ``(n,)``. Its dtype is floating
        point even for integer inputs.
    """
    n = b_jax.shape[0]
    rhs = b_jax[:, 0] if b_jax.ndim == 2 else b_jax
    # Always floating point so the division below never truncates.
    dtype = jnp.result_type(A_jax.dtype, rhs.dtype, jnp.float32)
    if n == 0:
        return jnp.zeros((0,), dtype)

    cols = jnp.arange(n)  # loop-invariant, hoisted out of the loop body
    # The last unknown needs no sum; seed it before looping upwards.
    x0 = jnp.zeros(n, dtype).at[-1].set(rhs[-1] / A_jax[-1, -1])

    def run(k, x):
        # k counts up from 0 to n-2, so row i walks down from n-2 to 0.
        i = n - 2 - k
        # Only columns j > i contribute. Masking both operands ignores
        # anything stored below the diagonal and the not-yet-solved entries.
        row = jnp.where(cols > i, A_jax[i], 0)
        known = jnp.where(cols > i, x, 0)
        xi = (rhs[i] - jnp.dot(row, known)) / A_jax[i, i]
        return x.at[i].set(xi)

    return lax.fori_loop(0, n - 1, run, x0)



In [None]:
%%timeit
# Benchmark the lax.fori_loop back substitution. block_until_ready() waits
# for JAX's asynchronously dispatched result, so the measurement covers the
# actual solve. The first call also includes JIT compilation, which likely
# explains the "slowest run" warning in this cell's output.
xcomp_jax  = fsub_fori_loop(A_jax,b_jax).block_until_ready()

The slowest run took 5.30 times longer than the fastest. This could mean that an intermediate result is being cached.
10 loops, best of 3: 20 ms per loop


In [None]:
# %%timeit executes its cell body inside a private function, so results
# assigned in the benchmark cells above are not kept in the notebook
# namespace. Recompute both JAX solutions here and compare them with the
# jnp.linalg.solve reference (x_jax has shape (n, 1)).
xcomp_jax = fsub_fori_loop(A_jax, b_jax)
xcomp_scan = fsub_scan(A_jax, b_jax)
print(np.allclose(xcomp_jax, x_jax[:, 0]))
print(np.allclose(xcomp_scan, x_jax[:, 0]))

True


TODO: rewrite the back substitution using the approach described in https://github.com/google/jax/issues/2491