# Qwen3-0.6B batch evaluation on Taxi-v3

This notebook runs multiple Taxi-v3 episodes in which every action is chosen by a locally run Qwen3-0.6B model through the project backend (`backend/llm/client.py`, `backend/taxi/*`). It runs both on Colab and locally.

In [None]:
# If running on Colab, install dependencies
import sys, os
# Detect Colab by checking whether the google.colab module has already been imported in this kernel.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    # Minimal set; torch and transformers are often present, but upgrade to ensure correct versions
    # NOTE(review): `-U torch` may replace Colab's preinstalled torch build; confirm this is intended.
    # This is an IPython shell escape: it only works inside Jupyter/Colab, not as a plain .py script.
    !pip -q install -U gymnasium==0.26.2 transformers accelerate torch safetensors sentencepiece huggingface_hub

# Ensure the backend package path is available (repo structure: backend/taxi, backend/llm)
# Assumes the current working directory is the repository root; on Colab this
# presumably means cloning the repo and changing into it first — TODO confirm.
repo_root = os.getcwd()
backend_path = os.path.join(repo_root, 'backend')
# Putting backend/ (not the repo root) on sys.path lets later cells import
# `taxi` and `llm` as top-level packages.
if backend_path not in sys.path:
    sys.path.append(backend_path)

# On Colab, prefer GPU if available for Qwen
# A QWEN_DEVICE value that is already set is left untouched; otherwise pick cuda/cpu via torch.
if IN_COLAB and 'QWEN_DEVICE' not in os.environ:
    import torch
    os.environ['QWEN_DEVICE'] = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optionally control generation behavior
# setdefault keeps any QWEN_TEMPERATURE the user already exported
# (presumably read by the llm client — verify in backend/llm/client.py).
os.environ.setdefault('QWEN_TEMPERATURE', '0.2')
# You can set QWEN_MAX_NEW_TOKENS to limit output; default uses a safe value from the client
# os.environ['QWEN_MAX_NEW_TOKENS'] = '256'

In [None]:
import json
import time
from typing import Any, Dict, List, Optional, Tuple

# Imports from the backend
from taxi.environment import TaxiEnvironment
from taxi.state_utils import decode_state, describe_state_for_llm, get_prompt
from llm.client import get_qwen_action, _ensure_pipeline_ready, _get_client
# Use reusable action coercion helper (no Flask dependency)
from taxi.action_utils import coerce_action

In [None]:
# Warm up the Qwen pipeline (downloads model if needed).
# If it is not ready yet, poll readiness once per second for at most 60
# attempts; the final line re-checks and reports the status either way.
client = _get_client()
ready = _ensure_pipeline_ready(client)
if not ready:
    print('Model is loading in the background... waiting briefly (up to ~60s)')
    attempts_left = 60
    while attempts_left > 0:
        attempts_left -= 1
        time.sleep(1)
        if _ensure_pipeline_ready(client):
            break
print('Pipeline ready:', _ensure_pipeline_ready(client))

In [None]:
def qwen_policy_action(state: int) -> Tuple[Optional[int], Dict[str, Any]]:
    """Query Qwen for an action in the given environment state.

    The integer ``state`` is decoded, rendered as a text description,
    wrapped into a prompt and sent to ``get_qwen_action``. The ``'action'``
    field of the reply is then normalised with ``coerce_action``.

    Args:
        state: Encoded environment observation (an integer state).

    Returns:
        ``(action_code, payload)``, where ``payload`` is the full LLM result.
        ``action_code`` is ``None`` when the model output cannot be coerced
        into an action, hence the ``Optional[int]`` annotation.
    """
    state_desc = describe_state_for_llm(decode_state(state))
    prompt = get_prompt(state_desc)
    result = get_qwen_action(prompt)
    action = coerce_action(result.get('action'))
    return action, result

def run_episode(max_steps: int = 200, verbose: bool = False) -> Dict[str, Any]:
    """Play one Taxi-v3 episode with every action chosen by Qwen.

    Args:
        max_steps: Maximum number of environment steps before giving up.
        verbose: If True, print one line per step with action and rewards.

    Returns:
        A dict with ``success``, ``ended_early``, ``total_reward`` and
        ``steps`` (a list of per-step ``state``/``action``/``reward`` dicts).
        If the LLM produces an action that cannot be coerced, the episode
        stops immediately and the dict also carries ``reason`` and
        ``llm_last`` (the last raw LLM payload).
    """
    env = TaxiEnvironment()
    # try/finally guarantees the environment is closed on every exit path:
    # the normal end, the early "invalid action" return, and any exception
    # raised by the LLM call or env.step.
    try:
        state = env.observation
        total_reward = 0.0
        steps: List[Dict[str, Any]] = []
        success = False
        for t in range(max_steps):
            action, payload = qwen_policy_action(state)
            if action is None:
                # If the model is not ready or returns invalid output, end early
                return {
                    'success': False,
                    'ended_early': True,
                    'reason': 'Invalid/empty action from LLM',
                    'total_reward': float(total_reward),
                    'steps': steps,
                    'llm_last': payload,
                }
            next_state, reward, done = env.step(action)
            steps.append({
                'state': int(state),
                'action': int(action),
                'reward': float(reward),
            })
            total_reward += reward
            state = next_state
            if verbose:
                print(f'Step {t}: action={action}, reward={reward}, total={total_reward}')
            if done:
                success = (reward == 20)  # Taxi-v3 gives +20 on successful drop-off
                break
        return {
            'success': bool(success),
            'ended_early': False,
            'total_reward': float(total_reward),
            'steps': steps,
        }
    finally:
        env.close()

def run_batch(n_episodes: int = 20, max_steps: int = 200, verbose_every: int = 0) -> Dict[str, Any]:
    """Run ``n_episodes`` Qwen-driven episodes and aggregate their results.

    Args:
        n_episodes: Number of episodes to play.
        max_steps: Step limit forwarded to each ``run_episode`` call.
        verbose_every: If non-zero, print a summary line after every
            ``verbose_every``-th episode.

    Returns:
        A dict with the raw per-episode results (``episodes``), the fraction
        of successful episodes (``success_rate``), the mean total reward
        (``average_reward``) and ``count``. Both ratios use a denominator of
        at least 1, so ``n_episodes == 0`` yields 0.0 instead of raising.
    """
    episodes: List[Dict[str, Any]] = []
    for idx in range(1, n_episodes + 1):
        episode = run_episode(max_steps=max_steps, verbose=False)
        episodes.append(episode)
        if verbose_every and idx % verbose_every == 0:
            print(f"Episode {idx}/{n_episodes}: success={episode.get('success')}, total_reward={episode.get('total_reward')}")

    denominator = max(1, n_episodes)
    n_success = sum(1 for ep in episodes if ep.get('success'))
    reward_sum = sum((ep.get('total_reward', 0.0) for ep in episodes), 0.0)
    return {
        'episodes': episodes,
        'success_rate': n_success / denominator,
        'average_reward': reward_sum / denominator,
        'count': n_episodes,
    }

In [None]:
# Run the batch: 10 episodes of at most 200 steps, summarising after each episode
metrics = run_batch(
    n_episodes=10,
    max_steps=200,
    verbose_every=1,
)
# Bare expression on the last line so the notebook displays the metrics dict
metrics

In [None]:
# Optionally save detailed results to JSON.
# Serialize to a string first: json.dump writes chunks to the file as it
# encodes, so a non-serializable value (e.g. inside a raw 'llm_last' payload)
# would otherwise leave a truncated, invalid JSON file behind.
out = 'qwen_taxi_results.json'
serialized = json.dumps(metrics, indent=2)
with open(out, 'w', encoding='utf-8') as f:
    f.write(serialized)
print('Saved to', out)