# Tutorial 01: CartPole Basic Control

This tutorial demonstrates how to train a reinforcement learning agent on the CartPole balancing task using Myriad's platform.

## Learning Objectives
1. Understand the CartPole environment structure
2. Configure training runs using `Config` objects (the same structure as the Hydra YAML configs)
3. Train a DQN agent using the platform
4. Analyze training and evaluation results
5. Visualize training progress

**The CartPole task:** Balance a pole on a moving cart by pushing the cart left or right. The goal is to keep the pole upright for as long as possible, up to a maximum of 500 steps per episode.

**Estimated runtime:** roughly 30-60 seconds on an Apple M1 CPU

## Setup

In [None]:
import jax
import numpy as np
import matplotlib.pyplot as plt

from myriad.configs.default import (
    AgentConfig,
    Config,
    EnvConfig,
    RunConfig,
    WandbConfig,
)
from myriad.envs.cartpole.tasks.control import make_env
from myriad.platform import train_and_evaluate

SEED = 42

---

# Section 1: Understanding the Environment

Let's explore the CartPole environment structure.

## 1.1 Environment Configuration

The environment configuration has two parts:
- **Physics Config:** Physical parameters (gravity, masses, pole length, etc.)
- **Task Config:** Task parameters (max steps, termination thresholds)

In [None]:
env = make_env()

print(f"Physics: {env.config.physics}")
print(f"Task: {env.config.task}")
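To make the two-part layout concrete, here is a minimal sketch of how such a configuration could be organized. The class names (`PhysicsConfigSketch`, `TaskConfigSketch`, `EnvConfigSketch`), the field names, and the default values are illustrative assumptions taken from the classic CartPole benchmark, not Myriad's actual definitions.

```python
from typing import NamedTuple


class PhysicsConfigSketch(NamedTuple):
    # Hypothetical fields: physical parameters of the cart and pole.
    gravity: float = 9.8            # m/s^2
    cart_mass: float = 1.0          # kg
    pole_mass: float = 0.1          # kg
    pole_half_length: float = 0.5   # m


class TaskConfigSketch(NamedTuple):
    # Hypothetical fields: parameters that define when an episode ends.
    max_steps: int = 500
    x_threshold: float = 2.4          # m
    theta_threshold: float = 0.2095   # rad (about 12 degrees)


class EnvConfigSketch(NamedTuple):
    physics: PhysicsConfigSketch = PhysicsConfigSketch()
    task: TaskConfigSketch = TaskConfigSketch()


cfg = EnvConfigSketch()
# NamedTuples are immutable, so a variant is built with _replace.
heavy_pole = cfg._replace(physics=cfg.physics._replace(pole_mass=0.5))
print(cfg.physics.pole_mass, heavy_pole.physics.pole_mass)
```

Keeping physics separate from task parameters means the same dynamics can be reused with different episode lengths or termination rules.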

## 1.2 Action and Observation Spaces

**Actions:** Discrete(2)
- 0 = push left
- 1 = push right

**Observations:** [x, x_dot, theta, theta_dot]
- x: Cart position (m)
- x_dot: Cart velocity (m/s)
- theta: Pole angle from vertical (rad, 0=upright)
- theta_dot: Pole angular velocity (rad/s)

In [None]:
action_space = env.get_action_space(env.config)
obs_shape = env.get_obs_shape(env.config)

print(f"Action space: Discrete({action_space.n})")
print(f"Observation shape: {obs_shape}")
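To see how the two actions move the four observation components, the sketch below advances the state by one explicit Euler step using the equations of the classic CartPole benchmark. The constants, the 0.02 s time step, and the function name `cartpole_step` are assumptions for illustration; Myriad's physics implementation may differ.

```python
import math

# Constants from the classic CartPole benchmark (assumed, not read from Myriad).
GRAVITY = 9.8
CART_MASS = 1.0
POLE_MASS = 0.1
POLE_HALF_LENGTH = 0.5
FORCE_MAG = 10.0
DT = 0.02


def cartpole_step(obs, action):
    """Advance (x, x_dot, theta, theta_dot) by one explicit Euler step."""
    x, x_dot, theta, theta_dot = obs
    force = FORCE_MAG if action == 1 else -FORCE_MAG  # 1 = push right, 0 = push left
    total_mass = CART_MASS + POLE_MASS
    pole_mass_length = POLE_MASS * POLE_HALF_LENGTH
    sin_t, cos_t = math.sin(theta), math.cos(theta)

    temp = (force + pole_mass_length * theta_dot**2 * sin_t) / total_mass
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        POLE_HALF_LENGTH * (4.0 / 3.0 - POLE_MASS * cos_t**2 / total_mass)
    )
    x_acc = temp - pole_mass_length * theta_acc * cos_t / total_mass

    return (
        x + DT * x_dot,
        x_dot + DT * x_acc,
        theta + DT * theta_dot,
        theta_dot + DT * theta_acc,
    )


upright = (0.0, 0.0, 0.0, 0.0)
pushed_right = cartpole_step(upright, action=1)
pushed_left = cartpole_step(upright, action=0)
print(pushed_right)
print(pushed_left)
```

Starting from rest, pushing right gives the cart a positive velocity and the pole a negative angular velocity under this sign convention; pushing left mirrors both.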

## 1.3 Initial State

The environment returns each observation as a `PhysicsState` NamedTuple, so individual components can be read by name, for example `obs.theta` or `obs.x`.

The platform automatically converts these observations to arrays for training.

In [None]:
key = jax.random.PRNGKey(SEED)
obs, state = env.reset(key, env.params, env.config)

print(f"Observation: {obs}")
print(f"As array: {obs.to_array()}")
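The pattern of a named observation that can also be flattened is easy to reproduce. The following is a minimal sketch, assuming a hypothetical `ObsSketch` class that returns a plain Python list from `to_array()`; the real `PhysicsState` presumably returns a JAX array and may carry additional fields.

```python
from typing import List, NamedTuple


class ObsSketch(NamedTuple):
    # Hypothetical stand-in for PhysicsState, in the order [x, x_dot, theta, theta_dot].
    x: float
    x_dot: float
    theta: float
    theta_dot: float

    def to_array(self) -> List[float]:
        # Flatten the named fields into a positional vector.
        return [self.x, self.x_dot, self.theta, self.theta_dot]


obs_sketch = ObsSketch(x=0.01, x_dot=-0.02, theta=0.05, theta_dot=0.0)
print(obs_sketch.theta)       # access by name
print(obs_sketch.to_array())  # positional form for a network input
```

Named access keeps exploratory code readable, while the flattened form is what a function approximator consumes.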

## 1.4 Taking Random Actions

In [None]:
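for step in range(3):
    # Split the key so action sampling and the environment step use independent randomness
    key, action_key, step_key = jax.random.split(key, 3)
    action = action_space.sample(action_key)
    # The last return value is unused here
    next_obs, next_state, reward, done, _ = env.step(
        step_key, state, action, env.params, env.config
    )

    print(f"Step {step + 1}: action={action}, theta={next_obs.theta:.4f} rad, reward={reward}, done={done}")

    # Carry the new state forward and stop early if the episode ended
    state, obs = next_state, next_obs
    if done:
        break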
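The `done` flag in the loop above depends on the task's termination thresholds and step limit. As a rough sketch, assuming the thresholds of the classic CartPole benchmark (cart position beyond 2.4 m, pole angle beyond about 12 degrees) and the 500-step limit mentioned earlier, the check could look like this. The function name `is_done` and its parameters are hypothetical.

```python
def is_done(x, theta, step_count, x_threshold=2.4, theta_threshold=0.2095, max_steps=500):
    """Return True when the episode should end (assumed classic CartPole rules)."""
    cart_out_of_bounds = abs(x) > x_threshold
    pole_fell = abs(theta) > theta_threshold
    timed_out = step_count >= max_steps
    return cart_out_of_bounds or pole_fell or timed_out


print(is_done(x=0.0, theta=0.05, step_count=10))    # still balancing
print(is_done(x=0.0, theta=0.30, step_count=10))    # pole tipped too far
print(is_done(x=0.0, theta=0.00, step_count=500))   # step limit reached
```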

---

# Section 2: Training with the Platform

Now let's train a DQN agent. The platform handles all the complexity:
- Parallel environment execution (JAX vmap)
- Replay buffer management
- Agent updates and target network sync
- Automatic environment resets
- Metrics logging and evaluation
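Of the responsibilities listed above, replay buffer management is the easiest to picture in isolation. The sketch below is a plain-Python ring buffer with uniform sampling; the class name `ReplayBufferSketch` is hypothetical, and Myriad's buffer is presumably a JAX-based structure with a different interface.

```python
import random
from collections import deque


class ReplayBufferSketch:
    """Fixed-capacity buffer that evicts the oldest transition when full."""

    def __init__(self, capacity, seed=0):
        self._storage = deque(maxlen=capacity)
        self._rng = random.Random(seed)

    def add(self, obs, action, reward, next_obs, done):
        self._storage.append((obs, action, reward, next_obs, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement from the stored transitions.
        return self._rng.sample(list(self._storage), batch_size)

    def __len__(self):
        return len(self._storage)


buffer = ReplayBufferSketch(capacity=3)
for t in range(5):
    buffer.add(obs=t, action=t % 2, reward=1.0, next_obs=t + 1, done=False)

batch = buffer.sample(batch_size=2)
print(len(buffer), batch)
```

With a capacity of 3 and 5 insertions, only the three most recent transitions remain, which is the behavior `buffer_size` controls at a larger scale.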

## 2.1 Configure the Training Run

Create a configuration using Pydantic models (same structure as Hydra YAML configs):

In [None]:
config = Config(
    env=EnvConfig(name="cartpole-control"),
    agent=AgentConfig(
        name="dqn",
        learning_rate=1e-3,
        gamma=0.99,
        epsilon_start=1.0,
        epsilon_end=0.05,
        epsilon_decay_steps=10000,
        target_network_frequency=500,
        batch_size=64,
    ),
    run=RunConfig(
        seed=SEED,
        num_envs=16,
        total_timesteps=50000,
        buffer_size=10000,
        log_frequency=1000,
        eval_frequency=5000,
        eval_rollouts=10,
        eval_max_steps=500,
        scan_chunk_size=10,
    ),
    wandb=WandbConfig(enabled=False),
)

print(f"Training {config.run.total_timesteps:,} steps across {config.run.num_envs} parallel environments")
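The three epsilon settings describe an exploration schedule. A minimal sketch follows, assuming the schedule is a linear interpolation from `epsilon_start` to `epsilon_end` over `epsilon_decay_steps` and constant afterwards; the platform may use a different shape, and the function name `linear_epsilon` is hypothetical.

```python
def linear_epsilon(step, start=1.0, end=0.05, decay_steps=10_000):
    """Linearly anneal epsilon from start to end, then hold it at end."""
    fraction = min(max(step / decay_steps, 0.0), 1.0)
    return start + fraction * (end - start)


for step in (0, 5_000, 10_000, 50_000):
    print(step, round(linear_epsilon(step), 3))
```

Under this assumption the agent acts fully at random at step 0, explores about half the time at step 5,000, and stays at 5% random actions for the remaining 40,000 steps.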

## 2.2 Run Training

**This will take ~30-60 seconds:**

In [None]:
results = train_and_evaluate(config)

## 2.3 Training Summary

In [None]:
print(f"Total timesteps: {results.training_metrics.global_steps[-1]:,}")
print(f"Evaluation checkpoints: {len(results.eval_metrics.global_steps)}")

if results.eval_metrics.mean_return:
    print(f"Final mean return: {results.eval_metrics.mean_return[-1]:.2f}")
    print(f"Best mean return: {max(results.eval_metrics.mean_return):.2f}")
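As a sanity check on the summary, the expected counts can be worked out from the configuration. This sketch assumes that `total_timesteps` counts environment steps summed over all parallel environments and that evaluation and logging fire once per full interval; both are assumptions, so the actual numbers may differ slightly (for example, if an evaluation also runs at step 0).

```python
total_timesteps = 50_000
num_envs = 16
log_frequency = 1_000
eval_frequency = 5_000

# Each vectorized step advances every parallel environment by one step.
vectorized_steps = total_timesteps // num_envs
expected_evals = total_timesteps // eval_frequency
expected_logs = total_timesteps // log_frequency

print(vectorized_steps, expected_evals, expected_logs)
```

Under these assumptions the run consists of 3,125 vectorized steps, with about 10 evaluation checkpoints and 50 logged training points.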

---

# Section 3: Analyzing Results

`train_and_evaluate()` returns a `TrainingResults` object that holds the training and evaluation metrics.

## 3.1 Training Metrics

In [None]:
if results.training_metrics.loss:
    print(f"Initial loss: {results.training_metrics.loss[0]:.4f}")
    print(f"Final loss: {results.training_metrics.loss[-1]:.4f}")

if "td_error" in results.training_metrics.agent_metrics:
    td_errors = results.training_metrics.agent_metrics["td_error"]
    print(f"Initial TD error: {td_errors[0]:.4f}")
    print(f"Final TD error: {td_errors[-1]:.4f}")

if "q_value" in results.training_metrics.agent_metrics:
    q_values = results.training_metrics.agent_metrics["q_value"]
    print(f"Initial Q-value: {q_values[0]:.4f}")
    print(f"Final Q-value: {q_values[-1]:.4f}")
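For readers new to DQN, the TD error reported above measures the gap between the current Q-value estimate and a bootstrapped target. The sketch below shows the standard one-step DQN form for a single transition. How Myriad aggregates this into the logged `td_error` metric (mean, absolute value, and so on) is not specified here, and the function name is hypothetical.

```python
def td_error(q_online, q_target_next, action, reward, done, gamma=0.99):
    """One-step DQN TD error: (r + gamma * max_a' Q_target(s', a')) - Q(s, a)."""
    # No bootstrapping from the next state once the episode has terminated.
    bootstrap = 0.0 if done else gamma * max(q_target_next)
    target = reward + bootstrap
    return target - q_online[action]


# Q-values for two actions (left, right); the numbers are made up for illustration.
print(td_error(q_online=[1.0, 2.0], q_target_next=[3.0, 4.0], action=1, reward=1.0, done=False))
print(td_error(q_online=[1.0, 2.0], q_target_next=[3.0, 4.0], action=1, reward=1.0, done=True))
```

The target uses the target network's values for the next state, which is why `target_network_frequency` affects how quickly these errors settle.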

## 3.2 Evaluation Metrics

In [None]:
if results.eval_metrics.mean_return:
    for i, (step, mean_ret, std_ret, mean_len) in enumerate(
        zip(
            results.eval_metrics.global_steps,
            results.eval_metrics.mean_return,
            results.eval_metrics.std_return,
            results.eval_metrics.mean_length,
        )
    ):
        print(f"Checkpoint {i + 1} @ {step:,} steps: {mean_ret:.1f} ± {std_ret:.1f} (length: {mean_len:.0f})")
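Each checkpoint line condenses several evaluation rollouts into a mean and a standard deviation. A small sketch of that reduction follows, assuming the population standard deviation is used (the platform might use the sample version instead). The function name and the return values are made up for illustration.

```python
import statistics


def summarize_returns(returns):
    """Reduce per-rollout returns to (mean, population standard deviation)."""
    return statistics.fmean(returns), statistics.pstdev(returns)


rollout_returns = [1.0, 2.0, 3.0, 4.0]  # made-up episode returns
mean_ret, std_ret = summarize_returns(rollout_returns)
print(f"{mean_ret:.1f} ± {std_ret:.1f}")
```

With only `eval_rollouts=10` episodes per checkpoint, the standard deviation is itself a noisy estimate, so small differences between checkpoints should be read with care.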

## 3.3 Performance Summary

In [None]:
results.summary()

---

# Section 4: Visualizing Progress

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle("CartPole DQN Training Results", fontsize=16, fontweight="bold")

# Training Loss
if results.training_metrics.loss:
    ax = axes[0, 0]
    ax.plot(results.training_metrics.global_steps, results.training_metrics.loss, linewidth=2, color="#2E86AB")
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss", fontweight="bold")
    ax.grid(alpha=0.3)

# TD Error
if "td_error" in results.training_metrics.agent_metrics:
    ax = axes[0, 1]
    ax.plot(results.training_metrics.global_steps, results.training_metrics.agent_metrics["td_error"], linewidth=2, color="#A23B72")
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("TD Error")
    ax.set_title("Temporal Difference Error", fontweight="bold")
    ax.grid(alpha=0.3)

# Evaluation Returns
if results.eval_metrics.mean_return:
    ax = axes[1, 0]
    steps = results.eval_metrics.global_steps
    mean_ret = results.eval_metrics.mean_return
    std_ret = results.eval_metrics.std_return
    
    ax.plot(steps, mean_ret, linewidth=2, color="#F18F01", label="Mean Return")
    ax.fill_between(
        steps,
        np.array(mean_ret) - np.array(std_ret),
        np.array(mean_ret) + np.array(std_ret),
        alpha=0.3,
        color="#F18F01",
        label="±1 Std Dev",
    )
    ax.axhline(y=500, color="green", linestyle="--", alpha=0.7, label="Max (500)")
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Episode Return")
    ax.set_title("Evaluation Performance", fontweight="bold")
    ax.legend()
    ax.grid(alpha=0.3)

# Episode Lengths
if results.eval_metrics.mean_length:
    ax = axes[1, 1]
    ax.plot(results.eval_metrics.global_steps, results.eval_metrics.mean_length, linewidth=2, color="#6A994E")
    ax.axhline(y=500, color="green", linestyle="--", alpha=0.7, label="Max (500)")
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Episode Length")
    ax.set_title("Episode Length", fontweight="bold")
    ax.legend()
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()
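Per-step DQN losses are often noisy, so a smoothed curve can be easier to read. The helper below is a minimal sketch of a trailing moving average that could be applied to a metric list before plotting; `moving_average` is a hypothetical helper, not part of the platform.

```python
def moving_average(values, window):
    """Trailing moving average; early points average over the values seen so far."""
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            # Drop the value that just left the window.
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


print(moving_average([1, 2, 3, 4, 5], window=2))
```

The output has the same length as the input, so it can be plotted against the same `global_steps` axis.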

---

# Tutorial Complete!

## What You Learned
1. How to explore CartPole environment structure
2. How to configure training runs using `Config` objects
3. How to train agents using `train_and_evaluate()`
4. How to access and analyze `TrainingResults`
5. How to visualize training progress

## Next Steps
Try experimenting with:
- Different hyperparameters (`learning_rate`, `epsilon_decay_steps`, etc.)
- More parallel environments (`num_envs`)
- Longer training (`total_timesteps`)
- W&B logging (`wandb.enabled=True`)
- Other agents (PPO, SAC) when available