In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import numpy as jnp

In [6]:
from jax import random
key = random.PRNGKey(42)

D = 2
M = 8
R = 10
W = random.normal(key, shape=(D,M,R))

In [7]:
W[0].shape

(8, 10)

In [8]:
%timeit W[0] @ W[0].T

2.3 ms ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%timeit jnp.dot(W[0], W[0].T)

2.37 ms ± 79.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
reg = jnp.ones((R,R))

W = random.normal(key, shape=(D,M,R))

reg *= jnp.dot(W[1].T, W[1])

In [24]:
reg 

DeviceArray([[10.472877  ,  2.1634684 ,  1.8259398 ,  2.998856  ,
              -4.078107  , -1.8218795 , -0.7573281 ,  0.42361373,
              -0.7675868 ,  2.695261  ],
             [ 2.1634684 ,  4.092049  ,  2.0903633 ,  0.4909489 ,
              -4.6681037 , -1.7279721 , -0.706576  , -1.2958506 ,
              -2.7669683 , -1.8912007 ],
             [ 1.8259398 ,  2.0903633 ,  7.2151275 , -2.1220875 ,
               2.914107  , -3.8662233 ,  1.044636  , -0.3260084 ,
              -5.878888  ,  1.1603925 ],
             [ 2.998856  ,  0.4909489 , -2.1220875 ,  7.8024697 ,
              -3.1010823 , -2.6626341 , -2.8245208 ,  0.47037816,
              -1.0786489 ,  2.2232282 ],
             [-4.078107  , -4.6681037 ,  2.914107  , -3.1010823 ,
              17.220627  ,  1.680552  ,  3.5457404 ,  2.077625  ,
               0.6772353 ,  4.780658  ],
             [-1.8218795 , -1.7279721 , -3.8662233 , -2.6626341 ,
               1.680552  ,  6.0490055 , -0.11766604, -0.35652307,
   

In [8]:
from tkm.utils import dotkron
from jax import jit

N,D = 5300, 2
M = 8
R = 10

Mati = random.normal(key,(N,M))
Matd = random.normal(key,(N,R))


In [27]:
%time C = dotkron(Mati,Matd)

CPU times: user 11.3 s, sys: 145 ms, total: 11.5 s
Wall time: 11.6 s


In [6]:
from jax import jit 

dotkron_compiled = jit(dotkron)

In [29]:
%time dotkron_compiled(Mati,Matd)

In [10]:
from tkm.utils import vmap_dotkron

In [16]:
%time c = dotkron(Mati,Matd)

CPU times: user 10.5 s, sys: 110 ms, total: 10.6 s
Wall time: 10.7 s


In [14]:
%timeit C = vmap_dotkron(Mati,Matd)

639 µs ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
C.shape

(5300, 80)

In [12]:
vmap_dotkron_compiled = jit(vmap_dotkron)


In [13]:
%timeit C_ = vmap_dotkron_compiled(Mati,Matd)

126 µs ± 5.28 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
jnp.equal(c,C)

DeviceArray([[ True,  True,  True, ...,  True,  True,  True],
             [ True,  True,  True, ...,  True,  True,  True],
             [ True,  True,  True, ...,  True,  True,  True],
             ...,
             [ True,  True,  True, ...,  True,  True,  True],
             [ True,  True,  True, ...,  True,  True,  True],
             [ True,  True,  True, ...,  True,  True,  True]], dtype=bool)

In [19]:
jnp.equal(c,C_).sum()/C_.size

DeviceArray(1., dtype=float32)

In [21]:
x = random.normal(key, (5300,2))

In [25]:
jnp.power(x[:,0, None], jnp.arange(M)).shape

(5300, 8)

In [29]:
from tkm.features import polynomial
from functools import partial
polynomial_compiled = jit(partial(polynomial, M=M))

In [65]:
%timeit polynomial(x[:,0],M)
%timeit polynomial_compiled(x[:,0])

2.34 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.59 ms ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [36]:
from tkm.features import polynomial_
polynomial_compiled_ = jit(polynomial_)

2.39 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [66]:
%timeit polynomial_(x[:,0],jnp.arange(M))
%timeit polynomial_compiled_(x[:,0],jnp.arange(M))

2.41 ms ± 95 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.23 ms ± 328 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
from tkm.features import polynomial_vmap

%timeit pv = polynomial_vmap(x[:,0],jnp.arange(M))

2.68 ms ± 93.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [56]:
pv = polynomial_vmap(x[:,0],jnp.arange(M))
pv.shape

(5300, 8)

In [58]:
polynomial_vmap_compiled = jit(partial(polynomial_vmap,rangeM=jnp.arange(M)))

In [59]:
%timeit polynomial_vmap_compiled(x[:,0])

1.59 ms ± 8.42 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [69]:
from tkm.features import compile_feature_map

poly = compile_feature_map(M=M)

In [70]:
%timeit poly(x[:,0])

1.59 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# predict

In [76]:
X = x
N, D = X.shape
M = W[0].shape[0]
score = jnp.ones((N,1))

for d in range(D): #TODO JAX vmap?
    score *= jnp.dot(
        polynomial(X[:,d],M) , 
        W[d]
    )

score = jnp.sum(score, 1)
score.shape

(5300,)

In [90]:
from jax import vmap
from tkm.features import compile_feature_map

poly = compile_feature_map(M=8)

s = vmap(lambda x,y :jnp.dot(poly(x),y), (1,0),)(X, W)
s.shape

(2, 5300, 10)

In [78]:
W.shape

(2, 8, 10)

In [96]:
sc = s.prod(0)

In [97]:
sco = sc.sum(1)

In [98]:
jnp.equal(score,sco).sum() / score.size

DeviceArray(1., dtype=float32)

In [None]:
# from tkm.model import predict