# Creating Agents for ARC Tasks

In JaxARC, an agent is simply a function that takes an observation and returns an action. The environment follows the `TimeStep` pattern: `env.reset` and `env.step` both return a `(state, timestep)` pair, where:
- `state` contains the internal environment state
- `timestep` contains observations, rewards, and termination flags
- Actions are drawn from the action space returned by `env.action_space(env_params)`
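
As a minimal sketch of that contract, the snippet below writes a random agent as a pure function of an observation and a PRNG key. `ToyTimeStep`, `random_agent`, and the `num_operations=35` default are illustrative assumptions that mirror the MiniARC setup used later in this guide. They are not part of the JaxARC API; the real `TimeStep` comes from `env.reset` and `env.step`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import jax.random as jr


class ToyTimeStep(NamedTuple):
    """Hypothetical stand-in for the environment's TimeStep (not the JaxARC class)."""

    observation: jax.Array
    reward: jax.Array
    done: jax.Array


def random_agent(observation, key, num_operations=35):
    """Map an observation to an action dict; only the observation's shape is used."""
    op_key, sel_key = jr.split(key)
    height, width = observation.shape[:2]
    return {
        # One integer operation index in [0, num_operations)
        "operation": jr.randint(op_key, (), 0, num_operations, dtype=jnp.int32),
        # One 0/1 flag per grid cell
        "selection": jr.bernoulli(sel_key, 0.5, (height, width)).astype(jnp.int32),
    }


timestep = ToyTimeStep(
    observation=jnp.zeros((5, 5, 1), dtype=jnp.int32),
    reward=jnp.float32(0.0),
    done=jnp.bool_(False),
)
action = random_agent(timestep.observation, jr.PRNGKey(0))
print(action["operation"].shape, action["selection"].shape)
```

A learning agent would keep the same signature but read the observation's contents, typically with an extra `params` argument for its weights.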

## Setup: Create an Environment

First, let's create a JaxARC environment. We'll use the MiniARC dataset with logging and visualization disabled, which keeps the interaction loop fast during agent development.

In [1]:
from __future__ import annotations

import jax
import jax.random as jr

from jaxarc.configs import JaxArcConfig
from jaxarc.registration import available_task_ids, make
from jaxarc.utils.core import get_config

# Configure environment with visualization and logging disabled for speed
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]

# Load configuration
hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)

# Get a task from MiniARC
available_ids = available_task_ids("Mini", config=config, auto_download=False)
task_id = available_ids[0]

# Create environment
env, env_params = make(f"Mini-{task_id}", config=config)

print(f"Environment created for task: {task_id}")
print(f"Action space: {env.action_space(env_params)}")

2025-11-03 22:12:27.955 | DEBUG    | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.955 | DEBUG    | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-03 22:12:27.958 | INFO     | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-03 22:12:27.959 | INFO     | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-03 22:12:27.962 | DEBUG    | jaxarc.utils.dataset_manager:va

Environment created for task: Most_Common_color_l6ab0lf3xztbyxsu3p
Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')


## Understanding the Environment Loop

Before creating an agent, let's understand how to interact with the environment using the TimeStep API.

In [2]:
# Initialize environment
key = jr.PRNGKey(42)
state, timestep = env.reset(key, env_params)

print(f"Observation shape: {timestep.observation.shape}")
print(f"Initial reward: {timestep.reward}")
print(f"Episode terminated: {timestep.last()}")

# Get action space for sampling
action_space = env.action_space(env_params)

# Sample and take a single action
key, action_key = jr.split(key)
action = action_space.sample(action_key)

print(f"\nSampled action: {action}")

# Step the environment
state, timestep = env.step(state, action, env_params)

print("\nAfter step:")
print(f"Reward: {timestep.reward}")
print(f"Episode done: {timestep.last()}")

Observation shape: (5, 5, 1)
Initial reward: 0.0
Episode terminated: False

Sampled action: {'operation': Array(6, dtype=int32), 'selection': Array([[0, 1, 0, 1, 0],
       [0, 1, 0, 1, 0],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1]], dtype=int32)}

After step:
Reward: -0.004999999888241291
Episode done: False
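
The sampled action above is a dictionary with an integer `operation` and a 5×5 binary `selection` mask. A hand-built action of the same shape can be sketched as follows. The operation index and the selected region are arbitrary illustrative values, and what each operation index means is not covered here.

```python
import jax.numpy as jnp

# Assumption: shapes and dtypes mirror the sampled action printed above.
grid_shape = (5, 5)

# Start from an empty mask and select the top-left 2x2 block of cells.
selection = jnp.zeros(grid_shape, dtype=jnp.int32).at[:2, :2].set(1)

action = {
    "operation": jnp.asarray(3, dtype=jnp.int32),  # arbitrary index in [0, 35)
    "selection": selection,
}

print(action["selection"])
print("Selected cells:", int(action["selection"].sum()))
```

Assuming the environment accepts any action that matches its action space, such a dictionary could be passed to `env.step(state, action, env_params)` in place of the sampled one.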


## Creating a Random Agent

The simplest agent samples random actions from the action space. This serves as a baseline for comparison with more sophisticated agents.

### Single Episode

In [3]:
# Run a single episode with a random agent
def run_random_episode(env, env_params, key, max_steps=100):
    """Run one episode with random actions."""
    # Reset environment
    reset_key, loop_key = jr.split(key)
    state, timestep = env.reset(reset_key, env_params)

    action_space = env.action_space(env_params)
    episode_reward = 0.0
    step_count = 0

    # Run episode
    while not timestep.last() and step_count < max_steps:
        # Sample random action
        loop_key, action_key = jr.split(loop_key)
        action = action_space.sample(action_key)

        # Step environment
        state, timestep = env.step(state, action, env_params)

        episode_reward += float(timestep.reward)
        step_count += 1

    return episode_reward, step_count


# Test the agent
key = jr.PRNGKey(123)
reward, steps = run_random_episode(env, env_params, key, max_steps=50)

print("Episode completed!")
print(f"Total reward: {reward:.2f}")
print(f"Steps taken: {steps}")
print(f"Average reward per step: {reward / max(steps, 1):.3f}")

Episode completed!
Total reward: -1.30
Steps taken: 20
Average reward per step: -0.065


## JAX-Accelerated Agent with Scan

For high performance, we can use `jax.lax.scan` to run many steps inside a single compiled loop. This pattern is used in PureJaxRL and similar high-throughput RL frameworks.

In [4]:
def make_jax_agent(env, env_params, num_steps):
    """Create a JIT-compiled agent using scan for efficiency."""
    action_space = env.action_space(env_params)

    def run_agent(key):
        # Reset environment
        reset_key, loop_key = jr.split(key)
        state, timestep = env.reset(reset_key, env_params)

        def step_fn(carry, _):
            """One step of the agent."""
            state, timestep, key = carry

            # Split into three independent keys: reset, action sampling, next iteration
            reset_key, action_key, next_key = jr.split(key, 3)

            # Handle episode termination with conditional reset
            def do_reset(_):
                return env.reset(reset_key, env_params)

            def continue_episode(_):
                return state, timestep

            state, timestep = jax.lax.cond(
                timestep.last(), do_reset, continue_episode, None
            )

            # Sample action and step
            action = action_space.sample(action_key)
            new_state, new_timestep = env.step(state, action, env_params)

            return (new_state, new_timestep, next_key), new_timestep.reward

        # Run scan over num_steps
        (final_state, final_timestep, _), rewards = jax.lax.scan(
            step_fn, (state, timestep, loop_key), None, length=num_steps
        )

        return rewards, final_timestep

    # JIT compile the entire function
    return jax.jit(run_agent)


# Create and run JIT-compiled agent
jax_agent = make_jax_agent(env, env_params, num_steps=100)

# First run includes compilation time
print("Compiling agent (first run)...")
key = jr.PRNGKey(456)
rewards, final_timestep = jax_agent(key)

print("Agent compiled and executed!")
print(f"Total reward: {float(rewards.sum()):.2f}")
print(f"Mean reward per step: {float(rewards.mean()):.3f}")
print(f"Max reward: {float(rewards.max()):.2f}")
print(f"Final episode terminated: {final_timestep.last()}")

Compiling agent (first run)...
Agent compiled and executed!
Total reward: -2.78
Mean reward per step: -0.028
Max reward: 0.63
Final episode terminated: False
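
To see the auto-reset logic in isolation, here is a sketch of the same `scan` plus `cond` pattern on a toy counter environment. `toy_reset`, `toy_step`, and the episode length of 3 are hypothetical stand-ins chosen for illustration; they are not JaxARC functions.

```python
import jax
import jax.numpy as jnp

EPISODE_LENGTH = 3  # hypothetical: the toy episode ends after 3 steps


def toy_reset():
    """Return the initial (step_counter, done) pair."""
    return jnp.int32(0), jnp.bool_(False)


def toy_step(counter):
    """Advance the counter; reward is 1.0 on the final step, else 0.0."""
    counter = counter + 1
    done = counter >= EPISODE_LENGTH
    reward = jnp.where(done, 1.0, 0.0)
    return counter, done, reward


def step_fn(carry, _):
    counter, done = carry
    # If the previous step ended the episode, start a new one first.
    counter, done = jax.lax.cond(
        done,
        lambda: toy_reset(),
        lambda: (counter, done),
    )
    counter, done, reward = toy_step(counter)
    return (counter, done), reward


@jax.jit
def rollout():
    # Seven steps span two full episodes plus the start of a third.
    _, rewards = jax.lax.scan(step_fn, toy_reset(), None, length=7)
    return rewards


print(rollout())  # [0. 0. 1. 0. 0. 1. 0.]
```

Both branches of `cond` must return values with the same structure and dtypes, which is why the reset branch and the continue branch each return a `(counter, done)` pair.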


## Vectorized Agent: Multiple Parallel Environments

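JAX's `vmap` lets us run multiple environments in parallel with a single function call by mapping the single-environment agent over a batch of PRNG keys. This increases throughput: the benchmark later in this guide reaches roughly 190,000 steps per second with 64 environments.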

In [5]:
def make_vectorized_agent(env, env_params, num_envs, num_steps):
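    """Create a vectorized agent that runs multiple environments in parallel.

    The batch size is taken from the leading axis of the keys passed at call
    time; `num_envs` is accepted for readability but is not used here.
    """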

    # Create the single-env agent
    single_agent = make_jax_agent(env, env_params, num_steps)

    # Vectorize it across multiple environments
    vectorized_agent = jax.vmap(single_agent)

    return vectorized_agent


# Create vectorized agent
num_envs = 16
num_steps = 100

print(f"Creating vectorized agent with {num_envs} parallel environments...")
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)

# Generate keys for each environment
key = jr.PRNGKey(789)
env_keys = jr.split(key, num_envs)

# Run all environments in parallel
print(f"Running {num_envs} environments × {num_steps} steps...")
all_rewards, all_final_timesteps = vec_agent(env_keys)

# Analyze results
print(f"\nResults across {num_envs} environments:")
print(f"Mean total reward: {float(all_rewards.sum(axis=1).mean()):.2f}")
print(f"Best environment reward: {float(all_rewards.sum(axis=1).max()):.2f}")
print(f"Worst environment reward: {float(all_rewards.sum(axis=1).min()):.2f}")
print(f"Mean reward per step: {float(all_rewards.mean()):.3f}")

# Total steps executed
total_steps = num_envs * num_steps
print(f"\nTotal steps executed: {total_steps:,}")

Creating vectorized agent with 16 parallel environments...
Running 16 environments × 100 steps...

Results across 16 environments:
Mean total reward: -4.22
Best environment reward: -1.86
Worst environment reward: -9.14
Mean reward per step: -0.042

Total steps executed: 1,600
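
The batch size comes from the leading dimension of the key array, not from the vectorized function itself. The sketch below shows this with a toy rollout; `toy_rollout` is a hypothetical stand-in for the single-environment agent and simply draws random rewards.

```python
import jax
import jax.random as jr

NUM_STEPS = 100


def toy_rollout(key):
    """Hypothetical single-env rollout: one random 'reward' per step."""
    return jr.normal(key, (NUM_STEPS,))


# vmap adds a leading batch axis to both the input keys and the output.
batched_rollout = jax.vmap(toy_rollout)

keys = jr.split(jr.PRNGKey(0), 16)
rewards = batched_rollout(keys)

print(rewards.shape)              # (16, 100): (num_envs, num_steps)
print(rewards.sum(axis=1).shape)  # (16,): one total per environment
```

This layout is why the analysis above sums over `axis=1` to get one total reward per environment.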


## Performance Benchmark

Let's measure the throughput of our vectorized agent to understand the performance benefits of JAX.

In [6]:
import time

# Benchmark configuration
num_envs = 64
num_steps = 256
num_runs = 3

print("Benchmarking vectorized agent...")
print(f"Configuration: {num_envs} envs × {num_steps} steps")
print("Warmup run (includes compilation)...\n")

# Create fresh agent
vec_agent = make_vectorized_agent(env, env_params, num_envs, num_steps)
key = jr.PRNGKey(999)
env_keys = jr.split(key, num_envs)

# Warmup run (includes compilation)
start = time.perf_counter()  # monotonic, high-resolution timer
rewards, _ = vec_agent(env_keys)
_ = rewards.block_until_ready()  # Wait for computation
warmup_time = time.perf_counter() - start

print(f"Warmup complete: {warmup_time:.2f}s (includes JIT compilation)")

# Timed runs
print(f"\nRunning {num_runs} timed iterations...")
times = []

for i in range(num_runs):
    key, subkey = jr.split(key)
    env_keys = jr.split(subkey, num_envs)

    start = time.perf_counter()
    rewards, _ = vec_agent(env_keys)
    _ = rewards.block_until_ready()  # Wait for computation
    elapsed = time.perf_counter() - start
    times.append(elapsed)

    print(f"  Run {i + 1}: {elapsed:.3f}s")

# Calculate statistics
mean_time = sum(times) / len(times)
total_steps = num_envs * num_steps
sps = total_steps / mean_time

print("\nPerformance Results:")
print(f"Mean execution time: {mean_time:.3f}s")
print(f"Steps per second (SPS): {sps:,.0f}")
print(f"Total steps per run: {total_steps:,}")

Benchmarking vectorized agent...
Configuration: 64 envs × 256 steps
Warmup run (includes compilation)...

Warmup complete: 2.36s (includes JIT compilation)

Running 3 timed iterations...
  Run 1: 0.085s
  Run 2: 0.085s
  Run 3: 0.087s

Performance Results:
Mean execution time: 0.086s
Steps per second (SPS): 190,766
Total steps per run: 16,384
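
The timing logic above can be factored into a small reusable helper. This is a sketch rather than part of JaxARC: `benchmark` is a hypothetical name, and a toy matrix multiply stands in for the vectorized agent. It blocks on the result before stopping the clock, because JAX dispatches work asynchronously and an unblocked timer would under-report the run time.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, num_runs=3):
    """Return (warmup_seconds, [run_seconds, ...]) for a JAX function."""
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))  # first call includes compilation
    warmup = time.perf_counter() - start

    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for all outputs to finish
        times.append(time.perf_counter() - start)
    return warmup, times


# Toy workload standing in for the vectorized agent.
matmul = jax.jit(lambda x: x @ x)
warmup, times = benchmark(matmul, jnp.ones((256, 256)))
print(f"warmup: {warmup:.3f}s, mean: {sum(times) / len(times):.5f}s")
```

Reporting the warmup run separately keeps one-off compilation cost out of the steady-state throughput figure.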


## Building Your Own Agent

To create a learning agent (not just random):

1. **Define a neural network** using Flax, Haiku, or Equinox
2. **Collect trajectories** using the scan pattern shown above
3. **Compute losses** from rewards and observations
4. **Update parameters** using Optax optimizers
5. **Repeat** the training loop
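
The five steps above can be sketched end to end on a toy problem. Everything here is a simplifying assumption rather than a JaxARC recipe: the "environment" is a hypothetical three-action bandit, the "network" is a bare vector of logits instead of a Flax, Haiku, or Equinox module, the loss is REINFORCE (one possible choice among many), and plain SGD stands in for an Optax optimizer.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

NUM_ACTIONS = 3
BATCH_STEPS = 64
LEARNING_RATE = 0.5


def toy_reward(action):
    """Hypothetical environment: action 2 pays 1.0, everything else 0.0."""
    return jnp.where(action == 2, 1.0, 0.0)


def collect(params, key):
    """Step 2: roll out the policy with scan, recording actions and rewards."""

    def step_fn(key, _):
        key, action_key = jr.split(key)
        action = jr.categorical(action_key, params)  # params are the logits
        return key, (action, toy_reward(action))

    _, (actions, rewards) = jax.lax.scan(step_fn, key, None, length=BATCH_STEPS)
    return actions, rewards


def loss_fn(params, actions, rewards):
    """Step 3: REINFORCE loss, the negative mean of log pi(a) * reward."""
    log_probs = jax.nn.log_softmax(params)[actions]
    return -jnp.mean(log_probs * rewards)


@jax.jit
def train_step(params, key):
    actions, rewards = collect(params, key)
    grads = jax.grad(loss_fn)(params, actions, rewards)
    # Step 4: plain SGD here; an Optax optimizer would replace this line.
    return params - LEARNING_RATE * grads, rewards.mean()


params = jnp.zeros(NUM_ACTIONS)  # Step 1: a "network" with one logit per action
key = jr.PRNGKey(0)
for _ in range(50):  # Step 5: repeat
    key, step_key = jr.split(key)
    params, mean_reward = train_step(params, step_key)

print("Preferred action:", int(jnp.argmax(params)))
print(f"Mean reward in last batch: {float(mean_reward):.2f}")
```

For an ARC agent, `collect` would wrap `env.step` as in the scan example earlier, and the logits would come from a network applied to `timestep.observation`.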