In [85]:
from time import time 

import jax.numpy as jnp
from jax import vmap, jit, random, lax
import matplotlib.pyplot as plt

In [118]:
# Root PRNG key with a fixed seed so every run draws the same random numbers.
key = random.PRNGKey(1)

In [119]:
# parameters
d = 3
dt = 0.01
K = 1000
x_init = 0 * jnp.ones((K, d))
r_lim = 1
#x_init

In [111]:
# stop condition
def cond_fn(x, r_lim):
    """Return True once ``x`` is on or outside the ball of radius ``r_lim``.

    The test is ``norm(x) >= r_lim``. The result is a scalar boolean JAX array.

    NOTE: a later cell redefines ``cond_fn`` with a one-argument signature for
    ``lax.while_loop``. Code that needs this two-argument version must run
    before that cell.
    """
    return jnp.linalg.norm(x) >= r_lim

1. Sample the first hitting time (FHT) of Brownian motion with a Python `while` loop over the time steps

In [70]:
def brownian_fht(dt, x_init, key, r_lim=1.0):
    """Simulate one Brownian path until it first reaches the sphere of radius ``r_lim``.

    This version uses a pure-Python ``while`` loop and keeps the whole trajectory.
    It cannot be ``jit``/``vmap``-transformed, because the loop condition has to
    be a concrete Python bool (see the ``lax`` version below).

    Parameters
    ----------
    dt : float
        Time step. Each increment is ``sqrt(dt) * N(0, 1)`` per coordinate.
    x_init : 1-D array of length d
        Starting position.
    key : jax PRNG key
        Root key. It is split once per step so every increment uses a fresh key.
    r_lim : float, default 1.0
        Radius of the stopping sphere. The default matches the notebook's
        ``r_lim`` config value; pass it explicitly if that value changes.

    Returns
    -------
    jnp.ndarray of shape (n_steps + 1, d)
        Positions from ``x_init`` up to and including the first one with
        ``norm >= r_lim``. If ``x_init`` already satisfies that, only
        ``x_init`` is returned.
    """
    xt = jnp.asarray(x_init)
    path = [xt]
    # The stop test is inlined on purpose: a later cell rebinds the global
    # ``cond_fn`` to an incompatible one-argument version for ``lax.while_loop``.
    while jnp.linalg.norm(xt) < r_lim:
        # ``key`` is kept only as the parent for future splits; the draw uses
        # ``subkey``. Sampling with ``key`` itself would reuse it as the parent
        # of the next split and correlate the increments.
        key, subkey = random.split(key)
        xt = xt + jnp.sqrt(dt) * random.normal(subkey, xt.shape)
        path.append(xt)
    return jnp.stack(path)

In [71]:
# One trajectory, started from the first row of ``x_init``.
x = brownian_fht(dt=dt, x_init=x_init[0], key=key)

2. Sample the first hitting time (FHT) of Brownian motion with `jax.lax.while_loop`

In [138]:
def cond_fn(val):
    """``lax.while_loop`` predicate: keep stepping while the walker is strictly inside the ball.

    ``val`` is the loop carry ``(step, xt, key)``. The loop stops once
    ``norm(xt) >= r_lim``, the same rule as the Python-loop version.

    Reads the module-level ``r_lim``. Under ``jit`` its value is baked in when
    the function is first traced.
    """
    _, xt, _ = val
    return jnp.linalg.norm(xt) < r_lim


def body_fn(val):
    """Advance the carry ``(step, xt, key)`` by one Brownian increment.

    Each coordinate of the increment has variance ``dt``; ``dt`` is read from
    the module level at trace time.
    """
    step, xt, key = val
    # Split first and sample with the fresh subkey. The returned ``key`` is
    # only ever used as the parent for later splits.
    key, subkey = random.split(key)
    dbt = jnp.sqrt(dt) * random.normal(subkey, xt.shape)
    return step + 1, xt + dbt, key


def _brownian_fht_single(x_init, key):
    """Run one walker from ``x_init`` until it exits.

    Returns the final carry ``(steps, x_exit, key)``.
    """
    return lax.while_loop(cond_fn, body_fn, (0, x_init, key))


# Batched over the rows of ``x_init`` and ``key``, then compiled. Under vmap the
# loop runs until every walker has exited; walkers that finished earlier keep
# their carry, so ``steps`` is per walker.
# Note: this rebinds ``brownian_fht``, replacing the Python-loop version from
# the earlier cell.
brownian_fht = jit(vmap(_brownian_fht_single, in_axes=(0, 0), out_axes=0))

In [139]:
# Re-seed, then give each of the K walkers its own independent key.
key = random.PRNGKey(1)
subkeys = random.split(key, num=K)
# Per-walker step counts and exit positions; the returned keys are discarded.
steps, x_fht, _ = brownian_fht(x_init, subkeys)