In [39]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
from jax import numpy as jnp

In [41]:
from jax import random
key = random.PRNGKey(42)

D = 2
M = 8
R = 10
W = random.normal(key, shape=(D,M,R))

In [42]:
W[0].shape

(8, 10)

In [43]:
%timeit W[0] @ W[0].T

419 µs ± 2.48 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [44]:
%timeit jnp.dot(W[0], W[0].T)

552 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
reg = jnp.ones((R,R))

W = random.normal(key, shape=(D,M,R))

reg *= jnp.dot(W[1].T, W[1])

In [46]:
reg 

Array([[10.47288   ,  2.1634696 ,  1.8259413 ,  2.9988558 , -4.0781074 ,
        -1.8218789 , -0.7573269 ,  0.42361394, -0.7675856 ,  2.695262  ],
       [ 2.1634696 ,  4.092049  ,  2.090363  ,  0.49094906, -4.6681066 ,
        -1.7279721 , -0.7065751 , -1.2958506 , -2.7669685 , -1.8912005 ],
       [ 1.8259413 ,  2.090363  ,  7.2151246 , -2.1220863 ,  2.9141078 ,
        -3.8662224 ,  1.044639  , -0.3260081 , -5.878887  ,  1.1603943 ],
       [ 2.9988558 ,  0.49094906, -2.1220863 ,  7.8024697 , -3.1010823 ,
        -2.6626344 , -2.8245208 ,  0.47037813, -1.0786488 ,  2.2232292 ],
       [-4.0781074 , -4.6681066 ,  2.9141078 , -3.1010823 , 17.220644  ,
         1.6805526 ,  3.5457447 ,  2.077627  ,  0.67723554,  4.780664  ],
       [-1.8218789 , -1.7279721 , -3.8662224 , -2.6626344 ,  1.6805526 ,
         6.0490055 , -0.11766616, -0.35652265,  5.9354253 , -2.1349566 ],
       [-0.7573269 , -0.7065751 ,  1.044639  , -2.8245208 ,  3.5457447 ,
        -0.11766616, 12.093737  ,  1.1507288 

In [47]:
from tkm.kron import dotkron
from jax import jit

N,D = 5300, 2
M = 8
R = 10

Mati = random.normal(key,(N,M))
Matd = random.normal(key,(N,R))


In [48]:
# %time C = dotkron(Mati,Matd)

In [49]:
from jax import jit 

dotkron_compiled = jit(dotkron)

In [50]:
# %time dotkron_compiled(Mati,Matd)

In [51]:
from tkm.kron import vmap_dotkron

In [52]:
%time c = dotkron(Mati,Matd)

CPU times: total: 8.5 s
Wall time: 8.62 s


In [53]:
%timeit C = vmap_dotkron(Mati,Matd)

549 µs ± 26.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [54]:
C.shape

(5300, 80)

In [55]:
vmap_dotkron_compiled = jit(vmap_dotkron)


In [56]:
%timeit C_ = vmap_dotkron_compiled(Mati,Matd)

545 µs ± 26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [57]:
jnp.equal(c,C)

Array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]], dtype=bool)

In [58]:
jnp.equal(c,C_).sum()/C_.size

Array(1., dtype=float32)

In [59]:
x = random.normal(key, (5300,2))

In [60]:
jnp.power(x[:,0, None], jnp.arange(M)).shape

(5300, 8)

In [61]:
from tkm.features import polynomial
from functools import partial
polynomial_compiled = jit(partial(polynomial, M=M))

In [62]:
%timeit polynomial(x[:,0],M)
%timeit polynomial_compiled(x[:,0])

1.11 ms ± 100 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
684 µs ± 39 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [63]:
from tkm.features import polynomial_
polynomial_compiled_ = jit(polynomial_)

In [64]:
%timeit polynomial_(x[:,0],jnp.arange(M))
%timeit polynomial_compiled_(x[:,0],jnp.arange(M))

1.02 ms ± 8.64 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
789 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [65]:
from tkm.features import polynomial_vmap

%timeit pv = polynomial_vmap(x[:,0],jnp.arange(M))

2.16 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [66]:
pv = polynomial_vmap(x[:,0],jnp.arange(M))
pv.shape

(5300, 8)

In [67]:
polynomial_vmap_compiled = jit(partial(polynomial_vmap,rangeM=jnp.arange(M)))

In [68]:
%timeit polynomial_vmap_compiled(x[:,0])

643 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [79]:
from tkm.features import compile_feature_map

poly = compile_feature_map(feature_map=polynomial,M=M)

In [80]:
%timeit poly(x[:,0])

831 µs ± 26.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# predict

In [74]:
X = x
N, D = X.shape
M = W[0].shape[0]
score = jnp.ones((N,1))

for d in range(D): #TODO JAX vmap?
    score *= jnp.dot(
        polynomial(X[:,d],M) , 
        W[d]
    )

score = jnp.sum(score, 1)
score.shape

(5300,)

In [76]:
from jax import vmap
from tkm.features import compile_feature_map

poly = compile_feature_map(feature_map=polynomial, M=8)

s = vmap(lambda x,y :jnp.dot(poly(x),y), (1,0),)(X, W)
s.shape

(2, 5300, 10)

In [None]:
W.shape

(2, 8, 10)

In [81]:
sc = s.prod(0)

In [82]:
sco = sc.sum(1)

In [83]:
jnp.equal(score,sco).sum() / score.size

Array(1., dtype=float32)

In [None]:
# from tkm.model import predict