In [1]:
import jax.numpy as jnp
from jax import grad, jit, lax, random, vmap
import numpy as np

### key

In [2]:
# Generate key which is used to generate random numbers
key = random.PRNGKey(1)

### matrix-matrix product using vmap

In [3]:
# vectors
u = jnp.array([1, 2, 3])
v = jnp.array([1, 0, -1])

# matrix
a = jnp.array([[0, 1, 0], [1, 1, 1]])
b = jnp.array([[0, 1], [1, 0], [0, 0]])
c = 2 * jnp.eye(2)
d = jnp.arange(1, 4+1).reshape(2, 2)
b.shape

(3, 2)

In [4]:
# vector product
jnp.vdot(u, v)

Array(-2, dtype=int32)

In [5]:
# matrix vector product
mv = vmap(jnp.vdot, in_axes=(0, None), out_axes=0)
mv(a, u)

Array([2, 6], dtype=int32)

In [6]:
# matrix vector product having more arguments
def vdot_custom(dt, u, v):
    return jnp.vdot(u, v)

mv = vmap(vdot_custom, in_axes=(None, 0, None), out_axes=0)
mv(0.01, a, u)

Array([2, 6], dtype=int32)

In [7]:
# matrix vector product
mv = vmap(jnp.vdot, in_axes=(1, None), out_axes=0)
mv(b, u)

Array([2, 1], dtype=int32)

In [8]:
# matrix matrix product
mm = vmap(mv, (None, 1), 1)
mm(c, d)

Array([[2., 4.],
       [6., 8.]], dtype=float32)

In [9]:
a = jnp.array([[1, 3], [23, -5]])
b = jnp.array([[11, 7], [19, 13]])
a, b

(Array([[ 1,  3],
        [23, -5]], dtype=int32),
 Array([[11,  7],
        [19, 13]], dtype=int32))

### batch matrix-matrix product using vmap and matmul

In [10]:
K = 10
a_batch = jnp.tile(a, (K, 1, 1))
b_batch = jnp.tile(b, (K, 1, 1))

In [11]:
# batch matrix matrix product
mm_batch = vmap(jnp.matmul, in_axes=(0, 0), out_axes=0)
mm_batch(a_batch, b_batch).shape

(10, 2, 2)

### batch norm using vmap

In [12]:
u, jnp.linalg.norm(u)

(Array([1, 2, 3], dtype=int32), Array(3.7416575, dtype=float32))

In [13]:
# batched vector
K = 10
u_batch = random.normal(key, (K, 1))

In [14]:
norm_u_batch = vmap(jnp.linalg.norm, in_axes=0, out_axes=0)(u_batch)
norm_u = jnp.linalg.norm(u_batch, axis=1)
norm_u_batch, norm_u

(Array([0.690805  , 0.48744103, 1.155789  , 0.12108463, 0.19598432,
        0.5078766 , 0.91568655, 1.70968   , 0.36749417, 0.14315689],      dtype=float32),
 Array([0.690805  , 0.48744103, 1.155789  , 0.12108463, 0.19598432,
        0.5078766 , 0.91568655, 1.70968   , 0.36749417, 0.14315689],      dtype=float32))

### ReLu layer

In [15]:
def ReLU(x):
    """ Rectified Linear Unit (ReLU) activation function
    """
    return jnp.maximum(0, x)

def relu_layer(params, x):
    """ Simple ReLu layer for single sample
    """
    return ReLU(jnp.dot(params[0], x) + params[1])

def relu_layer_batch(params, x):
    """ Error prone batch version
    """
    return ReLU(jnp.dot(x, params[0].T) + params[1])

@jit
def relu_layer_vmap(params, x):
    """ vmap version of the ReLU layer 
    """
    return vmap(relu_layer, in_axes=(None, 0), out_axes=0)(params, x)

In [16]:
# batch size
batch_dim = 32

# input dimension
feature_dim = 100

# hidden layer dimension
hidden_dim = 512

# generate a batch of vectors to process
x = random.normal(key, (batch_dim, feature_dim))

# generate Gaussian weights and biases
params = [
    random.normal(key, (hidden_dim, feature_dim)),
    random.normal(key, (hidden_dim, )),
]

out1 = jnp.stack([relu_layer(params, x[i, :]) for i in range(x.shape[0])])
out2 = relu_layer_batch(params, x)
out3 = relu_layer_vmap(params, x)

In [17]:
out3

Array([[ 0.4935343,  0.       ,  0.       , ...,  0.       ,  0.       ,
         0.7005646],
       [ 0.       ,  0.       ,  0.       , ...,  0.       ,  0.       ,
         0.8145983],
       [ 0.       ,  0.       ,  0.       , ...,  0.       ,  0.       ,
         7.39748  ],
       ...,
       [ 6.1218243,  0.       ,  0.       , ...,  6.7441397,  0.       ,
        11.019762 ],
       [ 0.       ,  0.       , 21.238564 , ...,  0.       ,  0.       ,
         0.       ],
       [ 0.       ,  6.7150664,  0.       , ...,  0.       ,  0.       ,
        10.949038 ]], dtype=float32)

### double well

In [18]:
def doublewell(x):
    alpha = jnp.ones(10)
    return jnp.sum(alpha * (x**2 - 1) ** 2)

def doublewell_batch(x):
    alpha = jnp.ones(10)
    return jnp.sum(alpha * (x ** 2 -1) **2, axis=1)

@jit
def doublewell_vmap(x):
    return vmap(doublewell, in_axes=0, out_axes=0)(x)
    #return jit(vmap(doublewell)(x))

In [19]:
d = 10
K = 100
x = random.normal(key, (K, d))

out1 = jnp.stack([doublewell(x[i, :]) for i in range(x.shape[0])])
out2 = doublewell_batch(x)
out3 = doublewell_vmap(x)

In [20]:
(out2 == out3)

Array([ True,  True,  True,  True,  True,  True,  True, False,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
       False,  True,  True,  True,  True,  True, False,  True,  True,
       False,  True, False,  True,  True,  True, False,  True,  True,
        True,  True, False,  True,  True, False,  True, False,  True,
       False, False, False, False, False,  True,  True,  True,  True,
       False, False,  True,  True,  True,  True, False, False,  True,
        True,  True, False,  True,  True,  True,  True, False,  True,
        True, False, False,  True, False,  True,  True,  True, False,
        True, False,  True,  True,  True, False, False,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True], dtype=bool)

### TracerBoolConversionError

#### if-else statements

In [6]:
def f(x):
    if x > 0:
        return True
    else:
        return False
    
def g(x):
    return jnp.where(x > 0, True, False)

In [7]:
# batch size
K = 10

x = random.normal(key, (K,))

In [13]:
#jit(f)(x)
jit(vmap(g))(x)

Array([ True, False, False,  True, False, False,  True,  True, False,
        True], dtype=bool)

#### While loops

In [21]:
cond_fn = lambda x: x <= 0
body_fn = lambda x: x + 1

def f(x):
    while cond_fn(x):
        x = body_fn(x)
    return x
    
def g(x):
    return lax.while_loop(cond_fn, body_fn, x)

In [24]:
#f(0), f(jnp.zeros(1))
#jit(f)(jnp.zeros(1))
#vmap(f)(np.arange(3.))
jit(vmap(g))(np.arange(3.))

Array([False,  True,  True], dtype=bool)

#### Where

In [36]:
N = 10

def f(x):
    return jnp.where(x <= 5)[0]

def g(x):
    return jnp.argwhere(x <= 5, size=N, fill_value=jnp.nan).squeeze()

x = jnp.arange(0, N+2, 2)
f(x), g(x)
#jit(f)(x)
jit(g)(x)

Array([ 0.,  1.,  2., nan, nan, nan, nan, nan, nan, nan],      dtype=float32, weak_type=True)