In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random


In [43]:
import numpy as np

In [27]:
import time

In [14]:
# Seed 0: an explicit PRNG key makes the draw below fully deterministic.
key1 = random.PRNGKey(0)
# Ten standard-normal samples derived from key1.
x1 = random.normal(key1, shape=(10,))
print(x1)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [15]:
# A different seed (3) yields a different, but equally deterministic, sample.
key2 = random.PRNGKey(3)
x2 = random.normal(key2, shape=(10,))
print(x2)

[ 0.18600447 -0.1762959   0.4396897  -1.3058784   1.7010686  -1.8713968
 -0.19887435  1.2654579  -1.0456703  -1.4045582 ]


In [17]:
# Same seed as key1: the printed values match x1 exactly, showing that the
# output of random.normal depends only on the key that is passed in.
key3 = random.PRNGKey(0)
x3 = random.normal(key3, shape=(10,))
print(x3)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [18]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Positive entries of ``x`` pass through unchanged and all other entries are
  mapped to ``alpha * exp(x) - alpha``; the result is then scaled by ``lmbda``.

  Args:
    x: array of inputs.
    alpha: scale of the exponential branch.
    lmbda: overall output scale.

  Returns:
    Array with the same shape as ``x``.
  """
  exponential_branch = alpha * jnp.exp(x) - alpha
  activated = jnp.where(x > 0, x, exponential_branch)
  return lmbda * activated


In [19]:
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()


4.24 ms ± 379 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()


527 µs ± 44.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
def sum_logistic(x):
  """Return the sum of the logistic function 1 / (1 + exp(-x)) over ``x``.

  Args:
    x: array of inputs.

  Returns:
    Scalar array holding the summed logistic values.
  """
  logistic = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(logistic)

# Evaluation points 0., 1., 2.
x_small = jnp.arange(3.0)
# grad() returns a new function computing d(sum_logistic)/dx.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))


[0.25       0.19661197 0.10499357]


In [22]:
# Tiny integer inputs for the dot-product / vmap demos:
# one weight vector of length 2 and three length-2 rows.
weight = jnp.array([0, 1])
batches = jnp.array([
    [1, 1],
    [1, 2],
    [0, 1],
])

In [49]:
# Seeded generator so the random demo inputs are identical on every run
# (the unseeded np.random.randn calls produced different data per kernel).
rng = np.random.default_rng(0)
# Larger random inputs for the matmul demo: a (1000, 2) weight matrix and
# 2555 batches of shape (15, 2).
weight = jnp.array(rng.standard_normal((1000, 2)))
batches = jnp.array(rng.standard_normal((2555, 15, 2)))

In [50]:
# Sanity-check the shapes of the current weight and batches arrays.
print(weight.shape, batches.shape)

(1000, 2) (2555, 15, 2)


In [51]:
def single_element(weight,input):
    """Return ``jnp.dot(weight, input)`` for a single (unbatched) input.

    Args:
        weight: left operand of the dot product.
        input: right operand of the dot product.

    Returns:
        The result of ``jnp.dot`` on the two operands.
    """
    product = jnp.dot(weight, input)
    return product

In [64]:
@jit
def matmul_1(weight,input):
    """Jit-compiled ``jnp.dot(weight, input.T)``.

    Args:
        weight: left operand.
        input: right operand; it is transposed before the product.

    Returns:
        The dot product of ``weight`` with the transpose of ``input``.
    """
    # Python side effect: this runs while JAX traces the function, so the
    # message marks a (re)trace rather than every call.
    print("matmul")
    transposed_input = jnp.transpose(input)
    return jnp.dot(weight, transposed_input)

In [66]:
# (1000, 2) weight times the transposed (15, 2) first batch -> (1000, 15).
print(matmul_1(weight, batches[0]).shape)

(1000, 15)


In [35]:
# Use the small demo vectors explicitly: in top-to-bottom order `weight` and
# `batches` have already been rebound to the (1000, 2) / (2555, 15, 2) random
# arrays, for which jnp.dot(weight, batches[0]) is a shape mismatch.
single_element(jnp.array([0, 1]), jnp.array([1, 1]))

Array(1, dtype=int32)

In [36]:

# Baseline for the vmap versions: apply single_element to each row in a plain
# Python loop. The small inputs are defined locally because, in top-to-bottom
# order, `weight`/`batches` refer to the large random arrays, whose shapes are
# incompatible with single_element.
weight_small = jnp.array([0, 1])
batches_small = jnp.array([[1, 1], [1, 2], [0, 1]])

t1 = time.perf_counter()
final_ans = jnp.stack([single_element(weight_small, arr) for arr in batches_small])
# JAX dispatches work asynchronously; wait for the result before stopping the clock.
final_ans.block_until_ready()
t2 = time.perf_counter()

print(final_ans)
# Report the measurement (previously t1/t2 were recorded but never used).
print(f"python loop: {t2 - t1:.6f} s")

[1 2 1]


In [40]:
# Map over axis 0 of the first argument; the second argument is passed to
# every call unchanged.
vmap_batch_mul = vmap(single_element, in_axes=(0, None))
# Map over axis 0 of the second argument; the first argument is passed to
# every call unchanged.
vmap_batch_mul2 = vmap(single_element, in_axes=(None, 0))

In [41]:
# Small demo inputs defined locally: in top-to-bottom order `weight`/`batches`
# are the large random arrays, whose shapes do not fit single_element.
weight_small = jnp.array([0, 1])
batches_small = jnp.array([[1, 1], [1, 2], [0, 1]])
# Both wrappers compute one dot product per row; they differ only in which
# positional argument carries the mapped axis.
print(vmap_batch_mul(batches_small, weight_small))
print(vmap_batch_mul2(weight_small, batches_small))

[1 2 1]
[1 2 1]


In [67]:
# Vectorise matmul_1 over axis 0 of its first argument; the second argument is
# shared across calls and results are stacked along axis 0.
vmap_matmul = vmap(matmul_1, in_axes=(0, None), out_axes=0)
# Note the argument order: each (15, 2) slice of `batches` is bound to
# matmul_1's `weight` parameter and the (1000, 2) `weight` array to `input`,
# so every mapped call yields a (15, 1000) product.
print(vmap_matmul(batches, weight).shape)

matmul
(2555, 15, 1000)
