In [1]:
import jax 
import jax.numpy as jnp

def bad_update(arr):
    """Anti-pattern demo: overwrite element 0 with 999 via item assignment.

    For mutable containers (lists, NumPy arrays) this modifies `arr` in place
    and returns that same object. JAX arrays are immutable and reject item
    assignment, which is why `good_update` uses the `.at[...].set(...)` form
    instead.

    Args:
        arr: an indexable container with at least one element.

    Returns:
        The very same `arr` object, after the in-place write.
    """
    arr[0] = 999
    return arr

def good_update(arr):
    """Return a copy of `arr` with element 0 set to 999.

    Uses the functional `.at[0].set(...)` update, which builds a new array
    rather than writing into `arr`, so the caller's array is left unchanged.

    Args:
        arr: an array exposing the `.at[...]` update interface (e.g. a JAX array).

    Returns:
        A new array equal to `arr` except that index 0 holds 999.
    """
    updated = arr.at[0].set(999)
    return updated

# Start from an all-zeros vector so the single changed element is easy to spot.
arr = jnp.zeros(5)
print("og", arr)
# good_update returns a NEW array; it does not write into `arr`.
new_arr = good_update(arr)
print("new", new_arr)
# Print the original again to show it still holds zeros after the update.
print("og is safe", arr)

og [0. 0. 0. 0. 0.]
new [999.   0.   0.   0.   0.]
og is safe [0. 0. 0. 0. 0.]


In [2]:
import jax.numpy as jnp

def interact_with_env(state):
    """Toy environment transition: the next state is the current state plus one.

    Written for a single state; batching is left to the caller (e.g. `jax.vmap`).

    Args:
        state: a value supporting `+ 1` (Python number or array).

    Returns:
        `state + 1`.
    """
    next_state = state + 1
    return next_state

# Five independent environment states, stacked along the leading axis.
states = jnp.array([0, 10, 20, 30, 40])

# vmap lifts the single-state function so it maps over that leading axis,
# replacing an explicit Python loop over the states.
batch_interact = jax.vmap(interact_with_env)

# Each state is advanced independently: element i of the result is states[i] + 1.
next_states = batch_interact(states)
print(next_states)

[ 1 11 21 31 41]


In [3]:
import jax

def step_fn(carry, x):
    """One iteration of a counting loop, in the (carry, x) -> (carry, y) form used by `jax.lax.scan`.

    Args:
        carry: the running counter handed over from the previous iteration.
        x: the per-step input slot; unused here.

    Returns:
        A pair `(incremented, doubled)`: `incremented` is `carry + 1` and is
        passed to the next iteration; `doubled` is `incremented * 2` and is
        the value recorded for this step.
    """
    incremented = carry + 1
    return incremented, incremented * 2

# Initial carry: the counter starts at zero.
start_val = 0
# xs=None together with length=5 runs step_fn five times with no per-step
# input. scan returns the last carry and the per-step outputs stacked in order.
final_val, history = jax.lax.scan(step_fn, start_val, None, length=5)

print("final counter", final_val)
print("history for computation", history)

final counter 5
history for computation [ 2  4  6  8 10]


In [4]:
import jax
import jax.numpy as jnp
from typing import NamedTuple

# --- 1. The Container (The Carry) ---
# We use a NamedTuple so JAX treats it as a PyTree automatically.
class SimulationState(NamedTuple):
    """Carry passed from one `jax.lax.scan` iteration to the next.

    Bundles the physical state with the PRNG key so that randomness is
    threaded explicitly through the loop: each step reads `key`, splits it,
    and stores a fresh key back into the carry it returns.
    """
    position: jnp.ndarray  # current position; main() initialises it as a shape-(1,) array
    velocity: jnp.ndarray  # current velocity; main() initialises it as a shape-(1,) array
    key: jax.Array  # PRNG key for the NEXT step; must be replaced every step, never reused

# --- 2. The Step Function (The Logic) ---
def step_fn(carry: SimulationState, x):
    """Advance the noisy 1-D simulation by one step (body function for `jax.lax.scan`).

    Two independent random draws are needed per step: the agent's
    acceleration and the environment's wind. The carried key is therefore
    split three ways: one key to hand to the next step, and one per draw.
    A key is never used for more than one draw.

    Args:
        carry: state from the previous step (position, velocity, PRNG key).
        x: the per-step input slot; unused here.

    Returns:
        `(new_carry, new_position)`. `new_carry` holds the updated position
        and velocity plus the fresh key; `new_position` is the per-step
        output logged to trace the path.
    """
    # Split into (N + 1) keys: the first is saved for the next iteration,
    # the remaining N = 2 are consumed right now.
    carry_key, accel_key, wind_key = jax.random.split(carry.key, 3)

    # Agent's choice of acceleration: standard normal draw.
    accel = jax.random.normal(accel_key, shape=(1,))

    # Wind gust from the environment: uniform in [-0.5, 0.5). Uses its own
    # key. Sharing accel_key here would correlate the two draws.
    gust = jax.random.uniform(wind_key, shape=(1,), minval=-0.5, maxval=0.5)

    # Physics update: velocity decays by the 0.9 friction factor, then picks
    # up both random pushes; position integrates the new velocity.
    velocity = carry.velocity * 0.9 + accel + gust
    position = carry.position + velocity

    # Store carry_key, not carry.key: putting the old key back would make
    # every step produce exactly the same random numbers.
    advanced = SimulationState(
        position=position,
        velocity=velocity,
        key=carry_key,
    )
    return advanced, position

# --- 3. Run It ---
def main():
    """Run the simulation for 10 steps from a fixed seed and print the results.

    Prints the position recorded at every step, followed by the PRNG key left
    in the final carry. Because the only source of randomness is the seed
    key, re-running reproduces exactly the same numbers; changing the seed
    (e.g. to 43) gives a completely different trajectory.
    """
    # Master seed: every random draw in the run derives from this key.
    seed_key = jax.random.PRNGKey(42)

    # Start at rest at the origin.
    start = SimulationState(
        position=jnp.array([0.0]),
        velocity=jnp.array([0.0]),
        key=seed_key,
    )

    # No per-step inputs (xs=None), so the step count comes from `length`.
    end_state, positions = jax.lax.scan(step_fn, start, None, length=10)

    print("Path History (The 'y' output):")
    # Each logged position has shape (1,); squeeze drops that trailing axis.
    print(positions.squeeze())

    print("\nFinal Key (The 'new_carry'):")
    print(end_state.key)

# Script-style entry guard. When this cell is executed in the notebook the
# condition holds, so main() runs immediately and its output appears below.
if __name__ == "__main__":
    main()

Path History (The 'y' output):
[ 0.77300465  0.7917458   0.5242299   0.74356395  2.1912372   4.617793
  5.8071165   6.8608837   8.583112   10.373187  ]

Final Key (The 'new_carry'):
[1825841970 3512247751]
