In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random, jit, grad
import scipy

In [3]:
import cr.sparse as crs
from cr.sparse import la
from cr.sparse import dict
from cr.sparse import pursuit
from cr.sparse import data

# Dictionary Setup

In [4]:
M = 32  # number of rows of the dictionary Phi (length of the measurement y)
N = 64  # number of columns (atoms) of Phi, i.e. length of the representation x
K = 3   # number of non-zero entries requested for the sparse representation x

In [5]:
# Fixed seed so that every run of the notebook draws the same dictionary.
key = random.PRNGKey(0)
# Draw the M x N dictionary; the cells below treat its columns as atoms.
# Note: `dict` here is cr.sparse.dict (imported above), which shadows the built-in dict.
Phi = dict.gaussian_mtx(key, M,N)



In [6]:
Phi.shape

(32, 64)

In [7]:
dict.coherence(Phi)

DeviceArray(0.5880293, dtype=float32)

# Signal Setup

In [8]:
# Draw one length-N representation with K non-zero entries; omega holds the
# indices of those non-zero entries.
# NOTE(review): this reuses the very same `key` that generated Phi. JAX recommends
# deriving fresh keys with random.split so independent draws do not share random
# bits; it is left unchanged here because every recorded output below depends on it.
x, omega = data.sparse_normal_representations(key, N, K, 1)
# Drop the singleton axis (presumably left by requesting a single representation,
# the last argument above) so that x is a plain 1-D vector.
x = jnp.squeeze(x)
x

DeviceArray([ 0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              1.8160858 ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        , -0.48262328,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,  0.        ,  0.        ,  0.339889  ,
              0.        ,  0.        ,  0.        ,  0.        ,
              0.        ,

In [9]:
omega, omega.shape

(DeviceArray([ 8, 25, 55], dtype=int32), (3,))

In [10]:
# Noiseless measurement: a linear combination of the K atoms selected by x.
y = Phi @ x
y

DeviceArray([ 0.02390813, -0.1268274 , -0.4807552 ,  0.3702036 ,
             -0.19832093, -0.32299417,  0.20099871,  0.61086226,
              0.65007794,  0.31966653,  0.12593222, -0.20471574,
             -0.5031531 ,  0.39456028, -0.37295896,  0.46000382,
              0.5659018 ,  0.24554808, -0.08826719,  0.21251862,
             -0.03524983, -0.04629171, -0.45931473,  0.15922827,
             -0.18856226,  0.01440289, -0.04701412, -0.28536397,
              0.3746668 , -0.11432485, -0.22904383,  0.2154006 ],            dtype=float32)

# Development of the OMP algorithm

## First iteration

In [11]:
# Before any atom has been selected the residual is the measurement itself, so
# both squared norms start out as <y, y>. r is 1-D, hence a plain inner product.
r = y
norm_r_sqr = norm_y_sqr = r @ r
norm_r_sqr

DeviceArray(3.3290536, dtype=float32)

In [12]:
# Correlate every atom (column of Phi) with the measurement. p is computed once;
# later cells index it with the selected atoms (p[indices]) to form the
# right-hand side of the least-squares solve.
p = Phi.T @ y
p, p.shape

(DeviceArray([ 9.61995199e-02, -5.24857044e-01, -2.36030817e-02,
               1.93185583e-01,  2.01418445e-01,  5.07771730e-01,
              -6.01010323e-02, -4.06812727e-02,  1.71276557e+00,
               2.37727642e-01, -5.69290996e-01,  4.15185690e-01,
              -6.56433880e-01, -4.44442391e-01,  3.19987014e-02,
              -5.34418464e-01, -1.63692579e-01,  5.11043012e-01,
              -1.88053131e-01, -3.19571137e-01,  1.40016377e-01,
              -3.31941187e-01,  2.24794447e-01,  2.18159825e-01,
               5.78138232e-03, -1.53794974e-01, -1.93431973e-04,
              -2.35534534e-01,  8.89518335e-02, -1.14321038e-01,
              -4.75759268e-01, -2.76573420e-01, -4.11897033e-01,
              -1.75503299e-01, -2.99649894e-01, -5.52226186e-01,
               5.59622236e-02,  2.28143021e-01,  1.20019704e-01,
               7.01376438e-01, -1.53217733e-01,  5.12092233e-01,
               2.15275526e-01,  5.09653836e-02,  3.48680854e-01,
               3.18339586

In [13]:
# In the first iteration the residual equals y, so the atom/residual
# correlations h are simply p.
h = p
h, h.shape

(DeviceArray([ 9.61995199e-02, -5.24857044e-01, -2.36030817e-02,
               1.93185583e-01,  2.01418445e-01,  5.07771730e-01,
              -6.01010323e-02, -4.06812727e-02,  1.71276557e+00,
               2.37727642e-01, -5.69290996e-01,  4.15185690e-01,
              -6.56433880e-01, -4.44442391e-01,  3.19987014e-02,
              -5.34418464e-01, -1.63692579e-01,  5.11043012e-01,
              -1.88053131e-01, -3.19571137e-01,  1.40016377e-01,
              -3.31941187e-01,  2.24794447e-01,  2.18159825e-01,
               5.78138232e-03, -1.53794974e-01, -1.93431973e-04,
              -2.35534534e-01,  8.89518335e-02, -1.14321038e-01,
              -4.75759268e-01, -2.76573420e-01, -4.11897033e-01,
              -1.75503299e-01, -2.99649894e-01, -5.52226186e-01,
               5.59622236e-02,  2.28143021e-01,  1.20019704e-01,
               7.01376438e-01, -1.53217733e-01,  5.12092233e-01,
               2.15275526e-01,  5.09653836e-02,  3.48680854e-01,
               3.18339586

In [14]:
# Select the atom whose correlation with the residual has the largest magnitude.
i = pursuit.abs_max_idx(h)
i

DeviceArray(8, dtype=int32)

In [15]:
# Start the support (indices of the selected atoms) with the first pick.
indices = jnp.array([i])
indices, indices.shape

(DeviceArray([8], dtype=int32), (1,))

In [16]:
# Column of Phi belonging to the atom that was just selected.
atom = Phi[:, i]
atom, atom.shape

(DeviceArray([-0.0452036 , -0.06064072, -0.2012354 ,  0.23969077,
               0.02278345, -0.17967227,  0.1699537 ,  0.38612276,
               0.33129993,  0.21025042,  0.09814752, -0.12776868,
              -0.29690307,  0.18864553, -0.2137496 ,  0.0962122 ,
               0.20104681,  0.16185828, -0.04153139,  0.2210582 ,
              -0.08565383, -0.01555626, -0.24623081,  0.03879212,
              -0.07559849,  0.11092736,  0.00262051, -0.14456683,
               0.12789953, -0.07467625, -0.1764678 ,  0.21914414],            dtype=float32),
 (32,))

In [17]:
# Sub-dictionary of the selected atoms, kept 2-D (M x 1 for now) so that further
# columns can be appended in later iterations.
subdict = jnp.expand_dims(atom, axis=1)
subdict.shape

(32, 1)

In [18]:
# Initial 1 x 1 factor L; it is extended by pursuit.gram_chol_update and consumed
# by la.solve_spd_chol in the following iterations.
# NOTE(review): hard-coding 1 assumes the first atom has unit norm -- confirm that
# the dictionary columns are normalised.
L = jnp.ones((1,1))
L, L.shape

(DeviceArray([[1.]], dtype=float32), (1, 1))

In [19]:
# Restrict the correlations Phi.T @ y to the atoms selected so far.
p_I = p[indices]
p_I, p_I.shape

(DeviceArray([1.7127656], dtype=float32), (1,))

In [20]:
# With a single atom and L = [[1.]] no solve is needed: the coefficient is taken
# to be the correlation itself (this relies on the atom having unit norm).
x_I = p_I
x_I, x_I.shape

(DeviceArray([1.7127656], dtype=float32), (1,))

In [21]:
# Residual left after subtracting the current approximation subdict @ x_I from y.
r_new = y - subdict @ x_I
r_new, r_new.shape

(DeviceArray([ 0.10133129, -0.02296407, -0.13608614, -0.04033047,
              -0.23734362, -0.01525769, -0.09009214, -0.05047554,
               0.0826388 , -0.04044315, -0.04217148,  0.01412205,
               0.00537229,  0.0714547 , -0.00685599,  0.2952149 ,
               0.22155577, -0.0316772 , -0.01713365, -0.16610228,
               0.1114551 , -0.01964747, -0.03757909,  0.09278646,
              -0.05907978, -0.17558968, -0.05150243, -0.03775488,
               0.1556049 ,  0.01357806,  0.07320414, -0.15994191],            dtype=float32),
 (32,))

In [22]:
# Residual energy <r_new, r_new>; r_new is 1-D, so no transpose is involved.
norm_r_new_sqr = r_new @ r_new
norm_r_new_sqr

DeviceArray(0.395488, dtype=float32)

## Second iteration

In [23]:
# Carry the residual and its squared norm over from the previous iteration.
r = r_new
norm_r_sqr = norm_r_new_sqr

In [24]:
# Correlate every atom with the current residual.
h = Phi.T @ r
h, h.shape

(DeviceArray([-1.4950220e-01, -4.0264491e-02,  2.9967882e-02,
               5.4308653e-02, -2.5237434e-02, -2.0721941e-01,
              -1.0321927e-01, -3.8601544e-02, -5.5879354e-09,
               2.8200644e-01, -1.3619012e-01,  5.9649844e-02,
              -4.2853419e-02, -6.0242899e-03,  1.8234587e-01,
              -9.8339349e-02,  7.5748578e-02,  1.9492418e-02,
              -2.1028575e-01, -1.6349390e-01,  1.1980817e-02,
              -3.8731601e-02,  1.9039874e-01, -1.7291762e-02,
              -6.0108759e-02, -5.2030534e-01,  6.0183473e-02,
              -1.5187215e-02, -4.3200962e-03, -5.9144575e-02,
              -2.8118880e-02,  7.6223135e-02,  1.1758726e-02,
               9.9556088e-02, -2.0874816e-01, -6.7137294e-02,
               2.6948545e-02, -1.1548560e-01,  6.2108980e-03,
               1.6986775e-01, -1.2599348e-01,  2.0354834e-01,
               1.1221175e-01,  4.9667105e-02,  4.2487707e-02,
              -1.5251022e-02,  5.1390164e-02, -1.5082434e-01,
        

In [25]:
# Select the atom whose correlation with the residual has the largest magnitude.
i = pursuit.abs_max_idx(h)
i

DeviceArray(25, dtype=int32)

In [26]:
# Grow the support with the newly selected atom.
# NOTE(review): not idempotent -- re-running this cell on its own appends the
# same index a second time.
indices = jnp.append(indices, i)

In [27]:
indices

DeviceArray([ 8, 25], dtype=int32)

In [28]:
# Column of Phi belonging to the atom that was just selected.
atom = Phi[:, i]
atom, atom.shape

(DeviceArray([-7.1421169e-02, -1.0727050e-01,  1.4660509e-01,
               2.1629900e-01,  1.9570568e-01, -6.9225013e-02,
               1.6529614e-01,  8.8993214e-02,  2.3624465e-02,
               1.7819779e-01, -1.0951436e-04, -1.3379590e-01,
              -4.4796020e-02, -1.5135485e-01, -5.5924416e-02,
              -3.1435499e-01, -3.9888123e-01,  5.5221118e-02,
              -2.5144633e-02,  1.1392352e-01, -1.6872568e-01,
               4.4980709e-02,  7.2892308e-03, -2.2424569e-03,
              -4.5592640e-02,  1.9777423e-01,  1.2812763e-01,
               1.9953127e-01, -2.8159457e-01, -6.0718294e-02,
              -2.0496592e-01,  4.8059407e-01], dtype=float32),
 (32,))

In [29]:
# Inner products between the atoms already selected and the new atom; this is
# the input used to extend L in the next step.
b = subdict.T @ atom
b

DeviceArray([0.21398745], dtype=float32)

In [30]:
# Extend the triangular factor by one row/column for the new atom (1 x 1 -> 2 x 2
# here) rather than refactorising from scratch.
# NOTE(review): not idempotent -- re-running this cell on its own grows L again.
L = pursuit.gram_chol_update(L, b)

In [31]:
L, L.shape

(DeviceArray([[1.        , 0.        ],
              [0.21398745, 0.97683644]], dtype=float32),
 (2, 2))

In [32]:
# Append the new atom as an extra column of the sub-dictionary.
# NOTE(review): not idempotent -- re-running this cell on its own adds the column twice.
subdict = jnp.hstack((subdict, jnp.expand_dims(atom,1)))
subdict, subdict.shape

(DeviceArray([[-4.5203600e-02, -7.1421169e-02],
              [-6.0640719e-02, -1.0727050e-01],
              [-2.0123540e-01,  1.4660509e-01],
              [ 2.3969077e-01,  2.1629900e-01],
              [ 2.2783445e-02,  1.9570568e-01],
              [-1.7967227e-01, -6.9225013e-02],
              [ 1.6995370e-01,  1.6529614e-01],
              [ 3.8612276e-01,  8.8993214e-02],
              [ 3.3129993e-01,  2.3624465e-02],
              [ 2.1025042e-01,  1.7819779e-01],
              [ 9.8147519e-02, -1.0951436e-04],
              [-1.2776868e-01, -1.3379590e-01],
              [-2.9690307e-01, -4.4796020e-02],
              [ 1.8864553e-01, -1.5135485e-01],
              [-2.1374960e-01, -5.5924416e-02],
              [ 9.6212201e-02, -3.1435499e-01],
              [ 2.0104681e-01, -3.9888123e-01],
              [ 1.6185828e-01,  5.5221118e-02],
              [-4.1531391e-02, -2.5144633e-02],
              [ 2.2105820e-01,  1.1392352e-01],
              [-8.5653827e-02, -1.687256

In [33]:
# Restrict the correlations Phi.T @ y to the atoms selected so far.
p_I = p[indices]
p_I, p_I.shape

(DeviceArray([ 1.7127656 , -0.15379497], dtype=float32), (2,))

In [34]:
# Coefficients on the current support, obtained from the triangular factor L with
# right-hand side p_I (presumably solving (L @ L.T) x_I = p_I -- confirm in cr.sparse.la).
x_I = la.solve_spd_chol(L, p_I)
x_I, x_I.shape

(DeviceArray([ 1.8294473, -0.5452737], dtype=float32), (2,))

In [35]:
subdict.shape, x_I.shape

((32, 2), (2,))

In [36]:
# Residual left after subtracting the current approximation subdict @ x_I from y.
r_new = y - subdict @ x_I
r_new, r_new.shape

(DeviceArray([ 0.06766164, -0.07438019, -0.03266576,  0.04964414,
              -0.13328886, -0.03203979, -0.01979099, -0.04700333,
               0.05686402,  0.03219104, -0.05368321, -0.04392506,
               0.01558936, -0.03308657, -0.01240945,  0.11257917,
              -0.01940215, -0.02045245, -0.02599841, -0.1297762 ,
               0.02944765,  0.00669446, -0.00487384,  0.08703738,
              -0.07511929, -0.08069178,  0.01805644,  0.08791259,
              -0.01286474, -0.01081669, -0.01796781,  0.07654329],            dtype=float32),
 (32,))

In [37]:
# Residual energy <r_new, r_new>; r_new is 1-D, so no transpose is involved.
norm_r_new_sqr = r_new @ r_new
norm_r_new_sqr

DeviceArray(0.11177917, dtype=float32)

## Third iteration

In [38]:
# Carry the residual and its squared norm over from the previous iteration.
r = r_new
norm_r_sqr = norm_r_new_sqr

In [39]:
# Correlate every atom with the current residual.
h = Phi.T @ r
h, h.shape

(DeviceArray([-3.31946202e-02,  5.75344227e-02,  5.59823290e-02,
               9.25678760e-05, -9.62564629e-03, -1.04181372e-01,
              -4.98393290e-02, -8.28084722e-03,  7.17118382e-08,
               1.29959524e-01, -2.90520918e-02,  1.14879869e-02,
               4.52624299e-02, -2.05479097e-03,  1.11361682e-01,
              -2.27925349e-02, -4.38999683e-02, -9.57450196e-02,
              -7.19629154e-02, -6.24810196e-02, -6.13001734e-02,
               1.84691139e-02,  1.03331879e-02,  1.97156938e-03,
              -6.55644909e-02, -1.49011612e-08,  8.26045573e-02,
               6.91462308e-04, -2.74802223e-02, -4.61718291e-02,
              -1.07698608e-02, -3.57095189e-02, -2.38384418e-02,
               1.12424225e-01, -8.22637901e-02, -5.36019728e-02,
               5.99750578e-02, -1.18175037e-01,  1.97170563e-02,
               1.50790423e-01, -2.53476501e-02,  8.01033676e-02,
               7.57712498e-02,  1.50741547e-01, -7.60149583e-03,
              -7.99755752

In [40]:
# Select the atom whose correlation with the residual has the largest magnitude.
i = pursuit.abs_max_idx(h)
i

DeviceArray(55, dtype=int32)

In [41]:
# Grow the support with the newly selected atom.
# NOTE(review): not idempotent -- re-running this cell on its own appends the
# same index a second time.
indices = jnp.append(indices, i)
indices

DeviceArray([ 8, 25, 55], dtype=int32)

In [42]:
# Column of Phi belonging to the atom that was just selected.
atom = Phi[:, i]
atom, atom.shape

(DeviceArray([ 0.21045762, -0.20144783, -0.13104105,  0.11561288,
              -0.42733213, -0.08856863, -0.08201517, -0.13951491,
               0.1759708 ,  0.07012922, -0.15406491, -0.10959424,
               0.0424513 , -0.06203075, -0.03460468,  0.39294937,
               0.02434387, -0.06398986, -0.07348871, -0.3941284 ,
               0.1143724 ,  0.01079335, -0.02536266,  0.25801423,
              -0.21557917, -0.26950052,  0.02961024,  0.2161889 ,
               0.01908326, -0.02356782, -0.02202035,  0.14522955],            dtype=float32),
 (32,))

In [43]:
# Inner products between the atoms already selected and the new atom; this is
# the input used to extend L in the next step.
b = subdict.T @ atom
b

DeviceArray([-1.3205409e-04, -1.7591405e-01], dtype=float32)

In [44]:
# Extend the triangular factor by one row/column for the new atom (2 x 2 -> 3 x 3 here).
# NOTE(review): not idempotent -- re-running this cell on its own grows L again.
L = pursuit.gram_chol_update(L, b)
L, L.shape

(DeviceArray([[ 1.0000000e+00,  0.0000000e+00,  0.0000000e+00],
              [ 2.1398745e-01,  9.7683644e-01,  0.0000000e+00],
              [-1.3205409e-04, -1.8005656e-01,  9.8365623e-01]],            dtype=float32),
 (3, 3))

In [45]:
# Append the new atom as an extra column of the sub-dictionary.
# NOTE(review): not idempotent -- re-running this cell on its own adds the column twice.
subdict = jnp.hstack((subdict, jnp.expand_dims(atom,1)))
subdict, subdict.shape

(DeviceArray([[-4.5203600e-02, -7.1421169e-02,  2.1045762e-01],
              [-6.0640719e-02, -1.0727050e-01, -2.0144783e-01],
              [-2.0123540e-01,  1.4660509e-01, -1.3104105e-01],
              [ 2.3969077e-01,  2.1629900e-01,  1.1561288e-01],
              [ 2.2783445e-02,  1.9570568e-01, -4.2733213e-01],
              [-1.7967227e-01, -6.9225013e-02, -8.8568628e-02],
              [ 1.6995370e-01,  1.6529614e-01, -8.2015172e-02],
              [ 3.8612276e-01,  8.8993214e-02, -1.3951491e-01],
              [ 3.3129993e-01,  2.3624465e-02,  1.7597079e-01],
              [ 2.1025042e-01,  1.7819779e-01,  7.0129216e-02],
              [ 9.8147519e-02, -1.0951436e-04, -1.5406491e-01],
              [-1.2776868e-01, -1.3379590e-01, -1.0959424e-01],
              [-2.9690307e-01, -4.4796020e-02,  4.2451300e-02],
              [ 1.8864553e-01, -1.5135485e-01, -6.2030748e-02],
              [-2.1374960e-01, -5.5924416e-02, -3.4604684e-02],
              [ 9.6212201e-02, -3.143549

In [46]:
# Restrict the correlations Phi.T @ y to the atoms selected so far.
p_I = p[indices]
p_I, p_I.shape

(DeviceArray([ 1.7127656 , -0.15379497,  0.42454934], dtype=float32), (3,))

In [47]:
# Coefficients on the full support; the displayed values coincide with the
# non-zero entries of the original x.
x_I = la.solve_spd_chol(L, p_I)
x_I, x_I.shape

(DeviceArray([ 1.8160858 , -0.48262328,  0.3398889 ], dtype=float32), (3,))

In [48]:
# Residual after the third atom: only float32 round-off remains.
r_new = y - subdict @ x_I
r_new, r_new.shape

(DeviceArray([ 1.6763806e-08, -1.4901161e-08, -2.9802322e-08,
               0.0000000e+00, -4.4703484e-08,  0.0000000e+00,
               0.0000000e+00, -5.9604645e-08,  0.0000000e+00,
               0.0000000e+00, -1.4901161e-08, -1.4901161e-08,
               0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
               2.9802322e-08,  0.0000000e+00, -1.4901161e-08,
              -7.4505806e-09, -2.9802322e-08,  1.1175871e-08,
               3.7252903e-09,  0.0000000e+00,  1.4901161e-08,
              -1.4901161e-08, -2.5145710e-08,  3.7252903e-09,
               0.0000000e+00,  2.9802322e-08,  0.0000000e+00,
              -1.4901161e-08,  1.4901161e-08], dtype=float32),
 (32,))

In [49]:
# Residual energy <r_new, r_new>; r_new is 1-D, so no transpose is involved.
norm_r_new_sqr = r_new @ r_new
norm_r_new_sqr

DeviceArray(1.2001684e-14, dtype=float32)

In [58]:
from cr.sparse.pursuit import omp

In [59]:
# Run the library's OMP solver on the same problem so that its result can be
# compared with the hand-derived iterations above.
solution =  omp.solve(Phi, y, K)

In [60]:
solution.x_I

DeviceArray([ 1.8160858 , -0.48262328,  0.3398889 ], dtype=float32)

In [61]:
solution.I

DeviceArray([ 8, 25, 55], dtype=int32)

In [62]:
solution.r

DeviceArray([ 1.6763806e-08, -1.4901161e-08, -2.9802322e-08,
              0.0000000e+00, -4.4703484e-08,  0.0000000e+00,
              0.0000000e+00, -5.9604645e-08,  0.0000000e+00,
              0.0000000e+00, -1.4901161e-08, -1.4901161e-08,
              0.0000000e+00,  0.0000000e+00,  0.0000000e+00,
              2.9802322e-08,  0.0000000e+00, -1.4901161e-08,
             -7.4505806e-09, -2.9802322e-08,  1.1175871e-08,
              3.7252903e-09,  0.0000000e+00,  1.4901161e-08,
             -1.4901161e-08, -2.5145710e-08,  3.7252903e-09,
              0.0000000e+00,  2.9802322e-08,  0.0000000e+00,
             -1.4901161e-08,  1.4901161e-08], dtype=float32)

In [63]:
solution.r_norm_sqr

DeviceArray(1.2001684e-14, dtype=float32)

In [64]:
solution.iterations

3

In [67]:
def time_solve():
    """Run the eager OMP solver once and wait until all of its outputs are ready.

    JAX may hand back result arrays before the computation has finished
    (asynchronous dispatch), so every array of the solution is waited on
    explicitly; otherwise %timeit would not measure the complete solve.
    """
    result = omp.solve(Phi, y, K)
    for pending in (result.x_I, result.r, result.I, result.r_norm_sqr):
        pending.block_until_ready()

In [68]:
%timeit time_solve()

14.3 ms ± 77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [69]:
# JIT-compile the solver. Positional argument 2 (K) is declared static, i.e. it is
# treated as a compile-time constant and a different K triggers a re-compilation.
# static_argnums is written as an explicit 1-tuple: the former "(2)" was just the
# int 2 in parentheses, which reads like a tuple but is not one.
omp_solve = jax.jit(omp.solve, static_argnums=(2,))

In [71]:
# Call the jitted solver on the same problem. The residual energy is again ~0;
# it differs from the eager run only at the float32 round-off level.
sol = omp_solve(Phi, y, K)
sol.r_norm_sqr

DeviceArray(2.4747565e-14, dtype=float32)

In [73]:
def time_solve_jit():
    """Run the jitted OMP solver once and wait until all of its outputs are ready.

    Mirrors the eager timing helper: every array of the solution is waited on
    explicitly because JAX may return before the computation has finished
    (asynchronous dispatch).
    """
    result = omp_solve(Phi, y, K)
    for pending in (result.x_I, result.r, result.I, result.r_norm_sqr):
        pending.block_until_ready()

In [74]:
%timeit time_solve_jit()

49.3 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [75]:
# Speed-up of the jitted solver over the eager one: 14.3 ms per loop (eager
# %timeit) versus 49.3 us per loop (jitted %timeit). Both numbers are copied by
# hand from the timing outputs above and go stale if the timings are re-run.
14.3 * 1000 / 49.3

290.0608519269777