# Train Q-learning agent on Taxi-v3

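This notebook reuses `backend/taxi/rl_agent.py` to train a simple Q-learning agent on Taxi-v3, shows a sample successful trajectory, reports basic metrics, and saves the successful episodes to `rl_successful_episodes.json`. An optional final step generates natural-language explanations for a few of those episodes.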

In [None]:
# If running on Colab, install dependencies
import sys, os
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    !pip -q install -U gymnasium==0.26.2 numpy

# Ensure backend is importable
repo_root = os.getcwd()
backend_path = os.path.join(repo_root, 'backend')
if backend_path not in sys.path:
    sys.path.append(backend_path)

In [None]:
import json
import random

import gymnasium as gym
import numpy as np

from taxi.rl_agent import QLearningAgent, solve_taxi_v3_and_collect_data

In [None]:
# Train the agent with configurable hyperparameters.
EPISODES = 2000
ALPHA = 0.1
GAMMA = 0.99
EPSILON = 1.0
EPS_DECAY = 0.995
EPS_MIN = 0.01
SEED = 42  # fixes RNG state so Restart & Run All gives comparable results

# Seed the global RNGs. NOTE(review): this only makes training reproducible
# if QLearningAgent draws its randomness from `random`, `np.random`, or the env.
random.seed(SEED)
np.random.seed(SEED)

env = gym.make('Taxi-v3')
try:
    # Seed the environment once. Later resets without a seed continue from this
    # RNG state. Also seed the action space in case exploration uses sample().
    env.reset(seed=SEED)
    env.action_space.seed(SEED)
    agent = QLearningAgent(env, alpha=ALPHA, gamma=GAMMA, epsilon=EPSILON, epsilon_decay=EPS_DECAY, epsilon_min=EPS_MIN)
    successful_episodes = agent.train(episodes=EPISODES)
finally:
    # Release the environment even if training raises.
    env.close()

# Display the number of successes and the first successful trajectory.
len(successful_episodes), successful_episodes[:1]

In [None]:
# Compute basic metrics over training
total_success = len(successful_episodes)
avg_success_reward = float(np.mean([ep['total_reward'] for ep in successful_episodes])) if successful_episodes else 0.0
print('Successful episodes:', total_success)
print('Average reward among successes:', avg_success_reward)
# Save to JSON for downstream analysis or explanation generation
with open('rl_successful_episodes.json', 'w') as f:
    json.dump(successful_episodes, f, indent=2)
print('Saved to rl_successful_episodes.json')

In [None]:
# Optional: generate natural-language explanations using backend.llm.explanation_generator
# Requires OPENAI_API_KEY to be set if using OpenAI.
# The import lives inside the try so a missing `llm` package (or its
# dependencies) skips this step instead of crashing the notebook.
example_explanation = None
try:
    from llm.explanation_generator import generate_explanation_for_rl_steps
    explanations = generate_explanation_for_rl_steps(successful_episodes[:3])
    print('Generated', len(explanations), 'explanations. Example:')
    example_explanation = explanations[0] if explanations else {}
except Exception as e:
    # Deliberately best-effort: report the reason and continue.
    print('Explanation generation skipped due to error:', e)

# Last expression of the cell, so Jupyter renders it (nothing is shown if skipped).
example_explanation