In [1]:
import jax.numpy as jnp
from jax import lax


def cumsum(arr):
  """Running sum of `arr` along its leading axis, built on `lax.scan`.

  Args:
    arr: array whose leading axis is scanned. The scan starts from the
      scalar carry 0.

  Returns:
    The stacked per-step partial sums produced by the scan.
  """

  def accumulate(total_so_far, element):
    # The updated total fills both return slots: the carry handed to the
    # next step and the value emitted for this step.
    next_total = total_so_far + element
    emitted = total_so_far + element
    return next_total, emitted

  # lax.scan returns (final_carry, stacked_outputs); only the outputs are kept.
  scan_outputs = lax.scan(accumulate, 0, arr)
  return scan_outputs[1]


# Quick check on a four-element float vector; expected running totals: [1, 3, 6, 10].
arr = jnp.asarray([1.0, 2.0, 3.0, 4.0])
cumulative_sum = cumsum(arr)
print(cumulative_sum)


[ 1.  3.  6. 10.]


In [2]:
from jax import grad
import jax.numpy as jnp

def loss(arr):
    """Scalar objective: the sum of every running total returned by `cumsum(arr)`."""
    running_totals = cumsum(arr)
    return jnp.sum(running_totals)

# Differentiate the scalar loss with respect to the input vector `arr`.
grad(loss)(arr)

Array([4., 3., 2., 1.], dtype=float32)

In [3]:
from jax import make_jaxpr

# Show the traced program for the gradient of `loss`; the printed jaxpr contains
# the forward scan followed by a second scan that runs with reverse=True.
make_jaxpr(grad(loss))(arr)

{ lambda ; a:f32[4]. let
    _:f32[] b:f32[4] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[] d:f32[]. let
          e:f32[] = add c d
          f:f32[] = add c d
        in (e, f) }
      length=4
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] 0.0 a
    _:f32[] = reduce_sum[axes=(0,)] b
    g:f32[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] 1.0
    _:f32[] h:f32[4] = scan[
      _split_transpose=False
      jaxpr={ lambda ; i:f32[] j:f32[]. let
          k:f32[] = add_any j i
          l:f32[] = add_any j i
        in (k, l) }
      length=4
      linear=(True, True)
      num_carry=1
      num_consts=0
      reverse=True
      unroll=1
    ] 0.0 g
  in (h,) }

In [4]:
def pick_0(arr):
    """Return the entry at index 0 of `arr` (its first element along the leading axis)."""
    leading_entry = arr[0]
    return leading_entry

# Trace `pick_0` on a 2x2 integer matrix to see which primitives the indexing becomes.
make_jaxpr(pick_0)(jnp.array([[1, 2], [3, 4]]))

{ lambda ; a:i32[2,2]. let
    b:i32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=None] a
    c:i32[2] = squeeze[dimensions=(0,)] b
  in (c,) }

In [5]:

def cumsum(arr):
  """Running sum of `arr` along its leading axis, built on `lax.scan`.

  The initial carry is a zero array shaped like a single slice `arr[i]`, so
  the row width is taken from the input instead of being fixed at 3.

  Args:
    arr: array to accumulate along axis 0, e.g. shape (steps, width).

  Returns:
    Array with the same shape as `arr` whose i-th slice is the sum of
    slices 0..i of `arr`.
  """

  def scan_fn(carry, x):
    # The updated total is both the next carry and this step's output.
    total = carry + x
    return total, total

  # Zeros in the default float dtype (as the original literal carry was),
  # with one entry per element of a slice along axis 0.
  init = jnp.zeros(jnp.shape(arr)[1:])
  _, result = lax.scan(scan_fn, init, arr)
  return result


# Build a (4, 3) input whose i-th row is [i, i, i], and show it.
arr = jnp.stack([i * jnp.array([1.0, 1.0, 1.0]) for i in range(4)])
print(arr)

# Accumulate down the rows and show the running totals.
cumulative_sum = cumsum(arr)
print(cumulative_sum)


[[0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]]
[[0. 0. 0.]
 [1. 1. 1.]
 [3. 3. 3.]
 [6. 6. 6.]]
