In [15]:
from time import time 

import jax.numpy as jnp
from jax import vmap, jit, random, lax
import matplotlib.pyplot as plt

In [48]:
# fixed-seed PRNG key so every run draws the same samples
key = random.PRNGKey(1)
(key, key.shape)

(Array([0, 1], dtype=uint32), (2,))

In [49]:
# parameters
d = 2                        # dimension of each trajectory
dt = 0.01                    # time-step size
K = 100                      # number of independent trajectories
N = 1000                     # number of time steps (time horizon T = N * dt)
x_init = jnp.ones((K, d))    # starting points: every trajectory starts at (1, ..., 1)

1. Sample Brownian motion on a finite time horizon, vectorized over the time steps as well as the trajectories (all increments are drawn at once and accumulated with a cumulative sum).

In [50]:
@jit
def brownian_finite(x_init, key):
    """Sample Brownian paths starting at `x_init`, vectorized over time steps.

    All N Gaussian increments (scaled by sqrt(dt)) are drawn in one call,
    accumulated with a cumulative sum along the leading time axis, prefixed
    with the B(0) = 0 row, and shifted by the starting positions.

    Args:
        x_init: starting positions; each increment has the same shape.
        key: JAX PRNG key used for all increments.

    Returns:
        Array of shape (N + 1,) + x_init.shape. Entry t is the position after
        t steps, and entry 0 equals x_init.

    Note: `N` and `dt` are read from the notebook globals when the function is
    traced. jit reuses that trace, so after changing them, re-run this cell.
    """
    increments = jnp.sqrt(dt) * random.normal(key, (N, *x_init.shape))

    # running sum of the increments, with B(0) = 0 inserted as row 0
    displacement = jnp.insert(jnp.cumsum(increments, axis=0), 0, 0, axis=0)

    # shift every path to its starting position
    return x_init + displacement

In [54]:
x = brownian_finite(x_init=x_init, key=key)  # method 1: every trajectory and every step in one jitted call

In [55]:
assert x.shape == (N + 1, K, d)  # (time points incl. t = 0, trajectories, dimension)
x.shape

(1001, 100, 2)

2. Sample Brownian motion on a finite time horizon with a Python loop over the time steps (vectorized over the trajectories with `vmap`).

In [56]:
def brownian_finite(x_init, key):
    """Sample one Brownian path step by step with a Python loop.

    The function is written for a single trajectory and is vectorized over
    trajectories by the `vmap` below. At each step the key is split, and the
    Gaussian increment is drawn from the fresh `subkey`. The carried `key` is
    only used for further splitting, so no key is ever consumed twice.

    Args:
        x_init: starting position of one trajectory.
        key: PRNG key for this trajectory.

    Returns:
        Array of shape (N + 1,) + x_init.shape: the starting point followed by
        the positions after each of the N steps of size `dt`.
    """
    scale = jnp.sqrt(dt)  # loop-invariant increment scale

    xt = x_init
    path = [xt]
    for _ in range(N):
        # carry `key` forward and consume `subkey` for this step's increment
        key, subkey = random.split(key)
        dbt = scale * random.normal(subkey, x_init.shape)

        xt = xt + dbt
        path.append(xt)

    return jnp.stack(path)

# Map over axis 0 of both x_init and key. out_axes=1 puts trajectories on axis 1,
# matching the (time, trajectory, dimension) layout of method 1.
# Not jitted: tracing would unroll the N-step Python loop (method 3 below uses
# lax.fori_loop to compile the loop instead).
brownian_finite = vmap(brownian_finite, in_axes=(0, 0), out_axes=1)

In [57]:
# one PRNG key per trajectory for the vmapped sampler
# NOTE(review): `key` was already consumed directly by method 1 above. JAX advises
# never reusing a key, so split off a fresh one first if these samples must be
# independent of method 1's.
subkeys = random.split(key, num=K)
x = brownian_finite(x_init, subkeys)

In [61]:
assert subkeys.shape == (K,) + key.shape  # split(key, K) stacks K keys shaped like `key`
subkeys.shape

3. Sample Brownian motion on a finite time horizon with `lax.fori_loop` (vectorized over the trajectories with `vmap`; only the final position is returned).

In [71]:
def body_fn(i, val):
    """Advance one trajectory by a single time step; loop body for `lax.fori_loop`.

    Args:
        i: loop index passed by `fori_loop`. It is unused because the step does
            not depend on time.
        val: carry tuple `(xt, key)` holding the current position and PRNG key.

    Returns:
        The updated carry `(xt + sqrt(dt) * noise, key)`. The noise comes from a
        freshly split `subkey`, and the new `key` is carried to the next step,
        so no key is used both for sampling and for splitting.
    """
    xt, key = val
    key, subkey = random.split(key)
    dbt = jnp.sqrt(dt) * random.normal(subkey, xt.shape)
    return xt + dbt, key

def brownian_finite(x_init, key):
    """Run `body_fn` N times on one trajectory and return the final carry.

    No intermediate positions are stored. The result is the tuple
    `(x_T, key_T)`: the position after N steps and the last PRNG key.
    """
    carry = (x_init, key)
    return lax.fori_loop(0, N, body_fn, carry)

# map over axis 0 of both arguments (trajectories), then compile the whole thing
brownian_finite = jit(vmap(brownian_finite, in_axes=(0, 0), out_axes=0))

In [75]:
# Keep only the final positions; the final PRNG keys are discarded. Indexing is
# used instead of `x_T, _ = ...` because binding `_` in the notebook namespace
# stops IPython from updating its last-output variable `_`.
x_T = brownian_finite(x_init, subkeys)[0]
assert x_T.shape == x_init.shape