In [1]:
import jax
import jax.numpy as jnp
from jax import vmap, jit, random, lax
from jax import make_jaxpr

# Scan
1. Cumulative sum

In [2]:
@jit
def cumsum_step(carry, x):
    """Scan step for `cumsum`: add the current slice to the running total and emit it."""
    carry = carry + x
    return carry, carry

def cumsum(xs):
    """Cumulative sum along the leading axis of `xs`, implemented with lax.scan.

    Returns `(total, running)` where `running[k] == xs[0] + ... + xs[k]` and
    `total == running[-1]`.

    The initial carry is a zero array shaped like one slice of `xs`
    (`xs.shape[1:]`), so multi-dimensional inputs work too. A scalar `0` would
    only fit 1-D inputs. The carry dtype follows the same promotion as adding
    the Python scalar 0 to `xs`.

    The step is deliberately not called `body_fn`: `cumsum` looks up its step
    by global name at call time, so sharing a name that later cells rebind
    would silently change what this function computes.
    """
    xs = jnp.asarray(xs)
    init = jnp.zeros(xs.shape[1:], dtype=jnp.result_type(xs.dtype, 0))
    return lax.scan(cumsum_step, init=init, xs=xs)

In [3]:
x = jnp.arange(10)
%time cumsum(x)
#%time vmap(cumsum)(jnp.stack((x, x)))

CPU times: user 21.1 ms, sys: 1.52 ms, total: 22.7 ms
Wall time: 22.3 ms


(Array(45, dtype=int32),
 Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))

2. Accumulate: scan without an input sequence (`xs=None` with an explicit `length`)

In [4]:
@jit
def accumulate_step(carry, _):
    """Scan step for `accumulate`: ignore the (absent) input slice, add one, emit the count."""
    carry = carry + 1
    return carry, carry

def accumulate(length):
    """Count from 1 to `length` with lax.scan and no input sequence.

    Returns `(final_count, counts)` with `counts == [1, 2, ..., length]`.

    `length` must be a concrete Python int: it fixes the number of scan steps
    and the output shape at trace time, so it cannot be a traced value
    (e.g. vmapped over an array of lengths).

    The step has its own name rather than the generic `body_fn` so that
    rebinding `body_fn` in other cells cannot change what this or other
    scan wrappers compute.
    """
    return lax.scan(accumulate_step, init=0, xs=None, length=length)

In [5]:
accumulate(10)
# vmap(accumulate)(jnp.array([10, 5])) is not supported: scan's `length` must be a
# concrete Python int (it fixes the output shape), so it cannot be batched.

(Array(10, dtype=int32, weak_type=True),
 Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int32, weak_type=True))

3. Simulate: a fixed number of steps from a given initial state, batched with `vmap`

In [6]:
@jit
def simulate_step(carry, _):
    """Scan step for `simulate`: advance the state by one and record it."""
    carry = carry + 1
    return carry, carry

def simulate(init, length=10):
    """Run `length` increment steps starting from `init` with lax.scan.

    Returns `(final_state, trajectory)` where `trajectory[k] == init + k + 1`.
    `length` defaults to the original 10 steps. Because it is a keyword with a
    default, `vmap(simulate)(inits)` still maps only over `init`. `length` must
    be a concrete Python int.

    The step has a dedicated name instead of the shared `body_fn`, so other
    cells rebinding `body_fn` cannot alter this function's behavior.
    """
    return lax.scan(simulate_step, init=init, xs=None, length=length)

In [7]:
# Unbatched call: simulate(0)
# vmap batches simulate over the leading axis of the initial states, giving one
# independent 10-step trajectory per element.
vmap(simulate)(jnp.ones(2))

(Array([11., 11.], dtype=float32),
 Array([[ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.],
        [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]], dtype=float32))

4. Simulate with a stopping condition: the state stops changing once it reaches a threshold

In [8]:
def is_done(x, threshold=10):
    """Return whether the state `x` has reached `threshold` (inclusive).

    Uses a plain comparison, so it works on Python scalars (returning a bool)
    and elementwise on JAX arrays, e.g. inside a scan step as
    `jnp.where(is_done(c), c, c + 1)`.

    Args:
        x: state value or array of states.
        threshold: stopping level; defaults to the original hard-coded 10.
    """
    return x >= threshold

In [10]:
# Boundary case: the comparison is `>=`, so reaching exactly 10 counts as done.
is_done(10)

True

In [26]:
def apply_fn(carry):
    """Advance the state by one step."""
    return carry + 1

def simulate_fht(carry, x, threshold=10):
    """Scan step that advances the state until it reaches `threshold`, then freezes it.

    While `carry < threshold` the state is incremented via `apply_fn`. After
    that it is returned unchanged. Given enough steps, the final carry of an
    integer scan is therefore `max(init, threshold)`.

    Args:
        carry: current state.
        x: per-step input from lax.scan; unused (the scan is driven by `length`).
        threshold: stopping level; defaults to the original hard-coded 10.

    Returns:
        `(new_carry, None)`. No per-step output is stacked.
    """
    carry = lax.cond(carry < threshold, apply_fn, lambda state: state, carry)
    # Branch-free equivalent: jnp.where(carry >= threshold, carry, carry + 1)
    return carry, None

In [27]:
%timeit lax.scan(simulate_fht, init=0, xs=None, length=int(1e6))

1.04 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


5. Other examples

In [9]:
@jit
def scan_fun(x, y):
    """Mean of the products x[k] * y[k] over k < len(x), computed with lax.scan.

    Demonstrates that a scan body can close over an array (`y`) that is not
    part of the scanned sequence. The carry holds the running sum and the step
    index used to read from `y`.

    Args:
        x: 1-D array scanned along its leading axis.
        y: 1-D array with at least `len(x)` elements. Extra elements are ignored.

    Returns:
        Scalar `sum(x[k] * y[k] for k < len(x)) / len(x)`.

    Raises:
        ValueError: if `y` is shorter than `x`. JAX clamps out-of-bounds
            indices instead of raising, so without this check `y[i]` would
            silently reuse the last element of `y`.
    """
    n = x.shape[0]
    # Shapes are static under jit, so this check runs once at trace time.
    if y.shape[0] < n:
        raise ValueError(
            f"y must have at least len(x)={n} elements, got {y.shape[0]}"
        )

    def body(carry, x_k):
        total, i = carry
        # `y` is captured from the enclosing scope and indexed with the traced counter `i`.
        return (total + x_k * y[i], i + 1), None

    (res, _), _ = lax.scan(body, (0., 0), x)
    return res / n


In [10]:
n, m = 50, 150
x = jnp.arange(n, dtype=jnp.float32)          # 0., 1., ..., n - 1
y = jnp.arange(0, m, 2, dtype=jnp.float32)    # 0., 2., ..., m - 2  (m // 2 values)

In [11]:
# 50 float32 values: 0., 1., ..., 49.
x

Array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49.],      dtype=float32)

In [12]:
# 75 float32 values: 0., 2., ..., 148. scan_fun reads only the first len(x) of them.
y

Array([  0.,   2.,   4.,   6.,   8.,  10.,  12.,  14.,  16.,  18.,  20.,
        22.,  24.,  26.,  28.,  30.,  32.,  34.,  36.,  38.,  40.,  42.,
        44.,  46.,  48.,  50.,  52.,  54.,  56.,  58.,  60.,  62.,  64.,
        66.,  68.,  70.,  72.,  74.,  76.,  78.,  80.,  82.,  84.,  86.,
        88.,  90.,  92.,  94.,  96.,  98., 100., 102., 104., 106., 108.,
       110., 112., 114., 116., 118., 120., 122., 124., 126., 128., 130.,
       132., 134., 136., 138., 140., 142., 144., 146., 148.],      dtype=float32)

In [13]:
# sum(x[k] * y[k] for k < 50) / 50 == 2 * sum(k**2 for k < 50) / 50 == 1617
scan_fun(x, y)

Array(1617., dtype=float32)

In [2]:
# lax.scan over a tuple of arrays: `xs` may be any pytree whose leaves share a
# leading axis, and each step receives the matching slice of every leaf.
def func11(arr, extra):
    """Scan `arr` paired with an all-ones array of the same shape.

    Each step adds `a * w + extra` to the running total and emits the total
    as it was *before* the update (an exclusive running sum). Returns
    `(final_total, totals_before_each_step)`.
    """
    weights = jnp.ones(arr.shape)

    def step(total, pair):
        a, w = pair
        return (total + a * w + extra, total)

    return lax.scan(step, 0., (arr, weights))

# Inspect the traced program with: make_jaxpr(func11)(jnp.arange(16), 5.)
func11(jnp.arange(3), 1.)

(Array(6., dtype=float32), Array([0., 1., 3.], dtype=float32, weak_type=True))