In [4]:
from time import time 

from jax import numpy as jnp, lax, random
from jax import vmap, jit

In [5]:
# set a random key
key = random.PRNGKey(seed=1)

In [6]:
# parameters
d = 1
dt = 0.01
K = int(1e3)
N = int(1e4)
x_init = -1 * jnp.ones((K, d))
subkeys = random.split(key, K)

1. sample brownian motion with a finite-time horizon vectorized also in the time steps

In [4]:
def brownian_finite(x_init, key):  

    # brownian increments                                                                
    db = jnp.sqrt(dt) * random.normal(key, (N,) + x_init.shape)

    # cumulative sum                                                       
    x = jnp.cumsum(db, axis=0)                                             

    # add B(0) = 0                                                       
    x = jnp.insert(x, 0, 0, axis=0)

    # set origin
    x += x_init

    return x

brownian_finite = jit(brownian_finite)

In [5]:
%time x = brownian_finite(x_init, key)

2024-01-22 16:00:52.871682: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %pad.167 = f32[10001,1000,3]{2,1,0} pad(f32[1,1000,3]{2,1,0} %broadcast.234, f32[] %constant.22), padding=0_10000x0_0x0_0, metadata={op_name="jit(brownian_finite)/jit(main)/scatter[update_jaxpr=None update_consts=() dimension_numbers=ScatterDimensionNumbers(update_window_dims=(1, 2), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) indices_are_sorted=False unique_indices=False mode=GatherScatterMode.FILL_OR_DROP]" source_file="/tmp/ipykernel_111961/2002951467.py" source_line=10}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/

CPU times: user 28 s, sys: 15.8 s, total: 43.9 s
Wall time: 11 s


In [6]:
x.shape

(10001, 1000, 3)

2. sample brownian motion with a finite-time horizon with a loop in the time-steps

In [7]:
def brownian_finite(x_init, key):

    # dimension
    d = x_init.shape[0]

    # initialize trajectory
    xt = x_init

    # trajectory list
    x = [xt]
    for i in range(N):

        # brownian increments
        key, subkey = random.split(key)    
        dbt = jnp.sqrt(dt) * random.normal(key, (d,))

        # update
        xt += dbt

        # save position
        x += [xt]
        
    return jnp.stack(x)

brownian_finite = vmap(brownian_finite, in_axes=(0, 0), out_axes=1)
#brownian_finite = jit(brownian_finite)#, static_argnums=(1))

In [8]:
x = brownian_finite(x_init, subkeys)

In [9]:
def apply_fn(state):
    
    # unpack values
    xt, key = state
    d = xt.shape[0]

    # brownian motion
    key, subkey = random.split(key)
    eta = random.normal(key, (d,))     

    # sde update
    xt = xt + jnp.sqrt(dt) * eta
    
    return xt, key

3. sample brownian motion with a finite-time horizon with foriloop

In [10]:
def body_fn_loop(i, state):
    return apply_fn(state)
    
def brownian_finite(x_init, key):
    return lax.fori_loop(0, N, body_fn_loop, (x_init, key))

brownian_finite = jit(vmap(brownian_finite))

In [11]:
%time x_T, _ = brownian_finite(x_init, subkeys)

CPU times: user 1.83 s, sys: 396 ms, total: 2.23 s
Wall time: 1.12 s


4. sample brownian motion with a finite-time horizon with scan

In [12]:
def body_fn_scan(state, x):
    state = apply_fn(state)
    return state, state[0]
    #return state, None

def simulate_brownian_finite(x_init, key):
    return lax.scan(body_fn_scan, init=(x_init, key), xs=None, length=N)

simulate_brownian_finite = jit(vmap(simulate_brownian_finite))

In [13]:
%time (x_T, key), x = simulate_brownian_finite(x_init, subkeys)

CPU times: user 2 s, sys: 382 ms, total: 2.38 s
Wall time: 1.19 s


In [13]:
#x_T