# Using Wrappers

Wrappers transform environments without modifying their core logic. JaxARC provides wrappers for:

- **Action transformation** - Convert between different action formats
- **Observation augmentation** - Add channels to observations
- **Action space flattening** - Simplify complex action spaces

Wrappers follow the delegation pattern:

1. **Core environment** handles only `Action` objects (mask-based selections).
2. **Wrappers** convert user-friendly formats to and from masks before delegating to the environment they wrap.
3. **Composability**: multiple wrappers can be stacked on top of one another.
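The toy classes below illustrate these three points. `ToyMaskEnv`, `ToyWrapper`, `ToyPointWrapper`, and `ToyExtraChannelWrapper` are hypothetical and exist only for illustration. They are not JaxARC classes and use plain NumPy instead of JAX. The core env accepts only a boolean mask, each wrapper overrides a single method, and every other attribute falls through to the wrapped env via `__getattr__`.

```python
from __future__ import annotations

import numpy as np


class ToyMaskEnv:
    """Hypothetical core env: only understands (operation, boolean mask) actions."""

    shape = (5, 5)

    def observe(self) -> np.ndarray:
        # Base observation: a single (H, W, 1) grid channel.
        return np.zeros((*self.shape, 1), dtype=np.int32)

    def step(self, operation: int, selection: np.ndarray) -> int:
        # Core logic sees only masks; here it just reports how many cells were selected.
        assert selection.shape == self.shape and selection.dtype == bool
        return int(selection.sum())


class ToyWrapper:
    """Forwards every attribute it does not override to the wrapped env."""

    def __init__(self, env) -> None:
        self._env = env

    def __getattr__(self, name: str):
        return getattr(self._env, name)


class ToyPointWrapper(ToyWrapper):
    """Action wrapper: {'operation', 'row', 'col'} -> (operation, mask)."""

    def step(self, action: dict) -> int:
        mask = np.zeros(self.shape, dtype=bool)
        mask[action["row"], action["col"]] = True
        return self._env.step(action["operation"], mask)


class ToyExtraChannelWrapper(ToyWrapper):
    """Observation wrapper: appends one all-ones channel to the observation."""

    def observe(self) -> np.ndarray:
        obs = self._env.observe()
        extra = np.ones((*self.shape, 1), dtype=obs.dtype)
        return np.concatenate([obs, extra], axis=-1)


# Stack wrappers: the wrapper applied last is the outermost one.
env = ToyExtraChannelWrapper(ToyPointWrapper(ToyMaskEnv()))
print(env.step({"operation": 2, "row": 2, "col": 3}))  # 1 (falls through to ToyPointWrapper.step)
print(env.observe().shape)  # (5, 5, 2)
```

Forwarding through `__getattr__` is what lets the outer observation wrapper still expose the inner action wrapper's `step`; JaxARC's own wrappers may implement delegation differently.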

## Setup: Base Environment

Let's start with a base environment that uses mask-based actions.

```python
from __future__ import annotations

import jax.random as jr

from jaxarc.configs import JaxArcConfig
from jaxarc.registration import make
from jaxarc.utils.core import get_config

# Set up the environment with minimal logging
config_overrides = [
    "dataset=mini_arc",
    "action=raw",
    "wandb.enabled=false",
    "logging.log_operations=false",
    "logging.log_rewards=false",
    "visualization.enabled=false",
]

hydra_config = get_config(overrides=config_overrides)
config = JaxArcConfig.from_hydra(hydra_config)

# Create base environment
env, env_params = make("Mini-Most_Common_color_l6ab0lf3xztbyxsu3p", config=config)

# Check the action space
action_space = env.action_space(env_params)
print(f"Base action space: {action_space}")
print(f"Action keys: {list(action_space.spaces.keys())}")
```

```text
2025-11-18 22:47:09.240 | DEBUG    | jaxarc.utils.dataset_manager:validate_dataset:212 - Dataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.240 | DEBUG    | jaxarc.utils.dataset_manager:ensure_dataset_available:81 - Dataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC
2025-11-18 22:47:09.243 | INFO     | jaxarc.parsers.mini_arc:_validate_grid_constraints:104 - MiniARC parser configured with optimal 5x5 grid constraints
2025-11-18 22:47:09.245 | INFO     | jaxarc.parsers.mini_arc:_scan_available_tasks:131 - Found 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)
2025-11-18 22:47:09.246 | DEBUG    | jaxarc.parsers.mini_arc:_load_t

Base action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), selection=MultiDiscreteSpace(num_values=[Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32), Array([2, 2, 2, 2, 2], dtype=int32)], dtype=int32, name='selection_mask')}, name='arc_action')
Action keys: ['operation', 'selection']
```


## Action Wrappers

Action wrappers convert user-friendly action formats into the mask-based `Action` objects that the core environment expects.

### PointActionWrapper

Converts point-based actions `{"operation": op, "row": r, "col": c}` into a selection mask that contains only the cell `(r, c)`.
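JAX arrays are immutable, so an in-place write such as `mask[r, c] = True` is not available inside jitted code. One way to build the single-cell mask is to compare index grids via broadcasting. The helper `point_to_mask` is hypothetical, not JaxARC's implementation; NumPy is used so the snippet runs standalone, and the same expressions work unchanged with `jax.numpy`.

```python
import numpy as np


def point_to_mask(row: int, col: int, height: int = 5, width: int = 5) -> np.ndarray:
    """Hypothetical helper: boolean (H, W) mask with exactly one True cell at (row, col)."""
    rows = np.arange(height)[:, None]  # shape (H, 1)
    cols = np.arange(width)[None, :]  # shape (1, W)
    # Broadcasting the two comparisons yields an (H, W) mask without in-place writes.
    return (rows == row) & (cols == col)


mask = point_to_mask(2, 3)
print(mask.astype(int))
```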

```python
from jaxarc.wrappers import PointActionWrapper

# Wrap environment
point_env = PointActionWrapper(env)

# Check new action space
point_action_space = point_env.action_space(env_params)
print(f"Point action space: {point_action_space}")
print(f"Action keys: {list(point_action_space.spaces.keys())}")

# Reset and take a point action
key = jr.PRNGKey(42)
state, timestep = point_env.reset(key, env_params)

print(f"\nInitial observation shape: {timestep.observation.shape}")

# Take a point action
action = {"operation": 2, "row": 2, "col": 3}
state, timestep = point_env.step(state, action, env_params)

print(f"Point action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
```

```text
Point action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
Action keys: ['operation', 'row', 'col']

Initial observation shape: (5, 5, 1)
Point action executed: {'operation': 2, 'row': 2, 'col': 3}
Reward: -0.005
```


### BboxActionWrapper

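When an operation should act on a rectangular region (for example, selecting, copying, or cutting a block of cells), use `BboxActionWrapper`. It takes two corners, `(r1, c1)` and `(r2, c2)`, and converts the rectangle between them into a selection mask: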
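The rectangle-to-mask conversion can be sketched with the same broadcasting approach. `bbox_to_mask` is a hypothetical helper. It assumes both corners are inclusive (consistent with the 2x3 region selected in the example that follows) and it normalizes corners given in either order; whether `BboxActionWrapper` does either is an assumption here. `np.minimum`/`np.maximum` are used instead of Python's `min`/`max` so the same code would also work with traced `jax.numpy` values.

```python
import numpy as np


def bbox_to_mask(r1: int, c1: int, r2: int, c2: int, height: int = 5, width: int = 5) -> np.ndarray:
    """Hypothetical helper: boolean mask covering the box between two corners, inclusive."""
    # Normalize so the box is valid regardless of corner order (an assumption).
    top, bottom = np.minimum(r1, r2), np.maximum(r1, r2)
    left, right = np.minimum(c1, c2), np.maximum(c1, c2)
    rows = np.arange(height)[:, None]
    cols = np.arange(width)[None, :]
    return (rows >= top) & (rows <= bottom) & (cols >= left) & (cols <= right)


mask = bbox_to_mask(r1=1, c1=1, r2=2, c2=3)
print(mask.astype(int))
print(mask.sum())  # 6 cells: 2 rows x 3 columns
```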

```python
from jaxarc.wrappers import BboxActionWrapper

# Wrap environment
bbox_env = BboxActionWrapper(env)

# Check action space
bbox_action_space = bbox_env.action_space(env_params)
print(f"Bbox action space: {bbox_action_space}")
print(f"Action keys: {list(bbox_action_space.spaces.keys())}")

# Reset and take a bbox action
key = jr.PRNGKey(43)
state, timestep = bbox_env.reset(key, env_params)

# Select a 2x3 region: rows 1-2, columns 1-3
action = {"operation": 0, "r1": 1, "c1": 1, "r2": 2, "c2": 3}
state, timestep = bbox_env.step(state, action, env_params)

print(f"\nBbox action executed: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
```

```text
Bbox action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), r1=DiscreteSpace(num_values=5, dtype=int32, name=''), c1=DiscreteSpace(num_values=5, dtype=int32, name=''), r2=DiscreteSpace(num_values=5, dtype=int32, name=''), c2=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='bbox_action')
Action keys: ['operation', 'r1', 'c1', 'r2', 'c2']

Bbox action executed: {'operation': 0, 'r1': 1, 'c1': 1, 'r2': 2, 'c2': 3}
Reward: -0.005
```


### FlattenActionWrapper

For RL algorithms that expect a single discrete action space, `FlattenActionWrapper` flattens the composite dict action space into one `DiscreteSpace` whose size is the product of the component sizes:
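For the point action space the product is 35 × 5 × 5 = 875 (operations × rows × columns), which matches the size printed in the output of the next cell. Flattening the raw mask-based space the same way would give 35 × 2^25 (about 1.17 billion) actions, which is why the example flattens `point_env` rather than `env`. The sketch below maps between flat indices and dict actions using NumPy's `unravel_index` and `ravel_multi_index`. It assumes C (row-major) order with `operation` varying slowest; the real wrapper's ordering is an assumption here, so a given index such as 752 may decode differently in JaxARC.

```python
import numpy as np

# Component sizes of the point action space: (operation, row, col).
POINT_DIMS = (35, 5, 5)
print(int(np.prod(POINT_DIMS)))  # 875

# Raw mask-based space: 35 operations x 2 choices for each of the 25 cells.
print(35 * 2 ** (5 * 5))  # 1174405120


def unflatten(index: int) -> dict:
    """Hypothetical decoder, assuming C order with `operation` varying slowest."""
    op, row, col = np.unravel_index(index, POINT_DIMS)
    return {"operation": int(op), "row": int(row), "col": int(col)}


def flatten(action: dict) -> int:
    """Inverse of `unflatten` under the same ordering assumption."""
    idx = (action["operation"], action["row"], action["col"])
    return int(np.ravel_multi_index(idx, POINT_DIMS))


decoded = unflatten(752)
print(decoded)  # {'operation': 30, 'row': 0, 'col': 2}
print(flatten(decoded))  # 752
```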

```python
from jaxarc.wrappers import FlattenActionWrapper

# Wrap point_env rather than env: the point action space is small enough
# to flatten for demonstration purposes
flat_env = FlattenActionWrapper(point_env)

# Check action space
flat_action_space = flat_env.action_space(env_params)
print(f"Flattened action space: {flat_action_space}")

# Reset and take a flattened action
key = jr.PRNGKey(44)
state, timestep = flat_env.reset(key, env_params)

# Sample a random action
# (reuses `key` for brevity; in real code, split it with jr.split first)
action = flat_action_space.sample(key)
state, timestep = flat_env.step(state, action, env_params)

print(f"\nFlattened action: {action}")
print(f"Reward: {float(timestep.reward):.3f}")
```

```text
Flattened action space: DiscreteSpace(num_values=875, dtype=int32, name='')

Flattened action: 752
Reward: -0.005
```


## Observation Wrappers

Observation wrappers add channels to the observation tensor, providing the agent with additional context.

### Basic Observation Wrappers

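Each of these wrappers appends one channel to the observation: the task's input grid, the answer (target) grid, or the current clipboard contents.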

```python
from jaxarc.wrappers import (
    AnswerObservationWrapper,
    ClipboardObservationWrapper,
    InputGridObservationWrapper,
)

# Start fresh
key = jr.PRNGKey(45)
state, timestep = env.reset(key, env_params)
print(f"Base observation shape: {timestep.observation.shape}")

# Add input grid channel
env_with_input = InputGridObservationWrapper(env)
state, timestep = env_with_input.reset(key, env_params)
print(f"+ InputGridObservationWrapper: {timestep.observation.shape}")

# Add answer grid channel
env_with_answer = AnswerObservationWrapper(env_with_input)
state, timestep = env_with_answer.reset(key, env_params)
print(f"+ AnswerObservationWrapper: {timestep.observation.shape}")

# Add clipboard channel
env_with_clipboard = ClipboardObservationWrapper(env_with_answer)
state, timestep = env_with_clipboard.reset(key, env_params)
print(f"+ ClipboardObservationWrapper: {timestep.observation.shape}")

print(f"\nTotal channels so far: {timestep.observation.shape[-1]}")
```

```text
Base observation shape: (5, 5, 1)
+ InputGridObservationWrapper: (5, 5, 2)
+ AnswerObservationWrapper: (5, 5, 3)
+ ClipboardObservationWrapper: (5, 5, 4)

Total channels so far: 4
```


### ContextualObservationWrapper

The `ContextualObservationWrapper` adds **demonstration pairs** from the task to the observation. This gives the agent access to other input/output examples that illustrate the task's transformation pattern.

Key features:
- Adds `2 * num_context_pairs` channels (input + output for each pair)
- During **training**: excludes the demonstration pair currently being solved, so its answer is not leaked into the context
- During **testing**: includes all demonstration pairs, since the pair being solved is a test pair rather than one of the demonstrations
- Pads with zeros if fewer demonstration pairs are available than requested
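The channel layout described in these bullets can be sketched as follows. `context_channels` is a hypothetical helper, not the wrapper's code. It assumes pairs are taken in dataset order, that each pair contributes its input channel followed by its output channel, and that missing pairs are zero-filled; the real wrapper may select or order pairs differently.

```python
from __future__ import annotations

import numpy as np


def context_channels(
    demo_inputs: np.ndarray,  # (N, H, W) demonstration input grids
    demo_outputs: np.ndarray,  # (N, H, W) demonstration output grids
    num_context_pairs: int,
    exclude_index: int | None = None,  # pair being solved (training), or None (testing)
) -> np.ndarray:
    """Return an (H, W, 2 * num_context_pairs) array of context channels."""
    # Skip the excluded pair, then keep at most `num_context_pairs` of the rest.
    keep = [i for i in range(len(demo_inputs)) if i != exclude_index]
    keep = keep[:num_context_pairs]
    height, width = demo_inputs.shape[1:]
    # Start from zeros so unused slots act as padding.
    channels = np.zeros((height, width, 2 * num_context_pairs), dtype=demo_inputs.dtype)
    for slot, i in enumerate(keep):
        channels[..., 2 * slot] = demo_inputs[i]
        channels[..., 2 * slot + 1] = demo_outputs[i]
    return channels


# Three toy 5x5 demonstration pairs: input k is filled with k, output k with 10 * k.
inputs = np.stack([np.full((5, 5), k) for k in (1, 2, 3)])
outputs = inputs * 10

train_ctx = context_channels(inputs, outputs, num_context_pairs=3, exclude_index=1)
test_ctx = context_channels(inputs, outputs, num_context_pairs=3)
print(train_ctx.shape, test_ctx.shape)  # (5, 5, 6) (5, 5, 6)
print(train_ctx[0, 0])  # [ 1 10  3 30  0  0]: pair 1 excluded, last slot padded
print(test_ctx[0, 0])  # [ 1 10  2 20  3 30]
```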

```python
from jaxarc.wrappers import ContextualObservationWrapper

# Add 3 demonstration pairs as context
env_with_context = ContextualObservationWrapper(env_with_clipboard, num_context_pairs=3)

key = jr.PRNGKey(45)
state, timestep = env_with_context.reset(key, env_params)

print("With ContextualObservationWrapper (3 pairs):")
print(f"  Observation shape: {timestep.observation.shape}")
print(f"  Added channels: {3 * 2} (3 pairs × 2 channels per pair)")

print(f"\nTotal channels: {timestep.observation.shape[-1]}")
```

```text
With ContextualObservationWrapper (3 pairs):
  Observation shape: (5, 5, 10)
  Added channels: 6 (3 pairs × 2 channels per pair)

Total channels: 10
```


## Combining Action and Observation Wrappers

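Action and observation wrappers can be chained freely, since each wrapper exposes the same `reset`, `step`, and `action_space` interface as the environment it wraps: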

```python
# Create a fully wrapped environment
wrapped_env = PointActionWrapper(env)
wrapped_env = InputGridObservationWrapper(wrapped_env)
wrapped_env = AnswerObservationWrapper(wrapped_env)

# Reset and inspect
key = jr.PRNGKey(46)
state, timestep = wrapped_env.reset(key, env_params)

print("Wrapped environment:")
print(f"  Action space: {wrapped_env.action_space(env_params)}")
print(f"  Observation shape: {timestep.observation.shape}")

# Take a point action
action = {"operation": 1, "row": 1, "col": 1}
state, timestep = wrapped_env.step(state, action, env_params)

print("\nAction executed successfully with enhanced observations")
```

```text
Wrapped environment:
  Action space: DictSpace({operation=DiscreteSpace(num_values=35, dtype=int32, name='operation'), row=DiscreteSpace(num_values=5, dtype=int32, name=''), col=DiscreteSpace(num_values=5, dtype=int32, name='')}, name='point_action')
  Observation shape: (5, 5, 3)

Action executed successfully with enhanced observations
```


## Summary


| Wrapper | Purpose | Example Use Case |
|---------|---------|------------------|
| **Action Wrappers** | | |
| `PointActionWrapper` | Dict action that selects a single cell | Agents that select one cell at a time |
| `BboxActionWrapper` | Dict action that selects a rectangular region | Agents that work with regions |
| `FlattenActionWrapper` | Flattens a dict action space into one discrete space | Standard RL algorithms (DQN, PPO) |
| **Observation Wrappers** | | |
| `InputGridObservationWrapper` | Adds an input grid channel | Keeping the original input visible as a reference |
| `AnswerObservationWrapper` | Adds an answer (target) grid channel | Training with supervision |
| `ClipboardObservationWrapper` | Adds a clipboard channel | Copy-paste operations |
| `ContextualObservationWrapper` | Adds demonstration-pair channels | Few-shot learning, pattern recognition |
| **Visualization Wrappers** | | |
| `StepVisualizationWrapper` | Enables detailed SVG rendering | Debugging agent actions and transitions |

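Wrappers change how an agent interacts with an environment without altering its core logic: action wrappers change the action format, observation wrappers add context channels, and visualization wrappers add rendering. Because each wrapper delegates to the one it wraps, they can be stacked in whatever combination a given agent needs for training and evaluation.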