# Example: LLMs Learn to Navigate Mazes from Experience (BabyAI Benchmark)


## Setup

In [None]:
import asyncio
import random
from typing import List, Optional, Tuple
from uuid import UUID

import altair as alt
import pandas as pd
import yaml
from balrog.environments import make_env
from omegaconf import OmegaConf
from tensorzero import AsyncTensorZeroGateway
from tensorzero.util import uuid7
from tqdm import trange

Load the configuration for the BALROG environments.

In [None]:
# Parse the BALROG settings from YAML, then wrap the plain dict in an OmegaConf
# container so nested keys are reachable as attributes (e.g. `config.tasks`).
with open("config.yml") as f:
    config_dict = yaml.safe_load(f)

config = OmegaConf.create(config_dict)

In [None]:
# Upper bound on simultaneous TensorZero inference calls, shared by every
# episode below. Lower it if the upstream provider (e.g. OpenAI) rate-limits you.
MAX_CONCURRENT_T0_REQUESTS = 50
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_T0_REQUESTS)

## Helper Functions

The `run_episode` function executes a single episode of the agent for a BabyAI task (maze navigation game).

In [None]:
async def run_episode(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    episode_idx: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> dict:
    """Play one episode of a BALROG task with the TensorZero ``act`` function.

    At every step the current text observation is sent to ``client.inference``.
    Variants whose name contains ``"history"`` also receive up to
    ``history_length`` previous observation/action pairs. The parsed
    ``action`` is checked with ``env.check_action_validity`` and executed.
    If inference or parsing fails, a random action from a fixed BabyAI action
    list is used instead. When the episode ends, ``episode_return`` and
    ``progression`` are sent as episode-level feedback.

    Args:
        client: TensorZero gateway client used for inference and feedback.
        variant_name: Variant of the ``act`` function to call.
        env_name: BALROG environment name (e.g. ``"babyai"``).
        task_name: Task within that environment.
        episode_idx: Index of this episode. It is added to ``seed`` to form
            the environment reset seed.
        config: BALROG configuration passed to ``make_env``.
        semaphore: Shared semaphore bounding concurrent inference requests.
        history_length: Maximum number of past observation/action pairs
            included in the prompt (history variants only). ``0`` means none.
        seed: Offset added to ``episode_idx`` to form the reset seed.
        test: If True, feedback is sent with ``dryrun=True``.

    Returns:
        A dict with the variant, task, input/output token counts, episode
        return, number of steps, ``env.failed_candidates``, every entry of
        ``env.get_stats()``, the reset seed actually used, the episode index
        and the TensorZero episode id.
    """
    episode_log = {
        "variant": variant_name,
        "task": task_name,
        "input_tokens": 0,
        "output_tokens": 0,
    }
    # Only variants with "history" in their name feed past turns to the model.
    use_history = "history" in variant_name
    episode_id = uuid7()
    env = make_env(env_name, task_name, config)
    env_seed = episode_idx + seed
    obs, _ = env.reset(seed=env_seed)
    mission = obs["mission"]
    episode_return = 0
    num_steps = 0  # stays 0 if env.max_steps is 0 and the loop never runs
    history = []
    for step in range(env.max_steps):
        # Read the observation before the try block. This guarantees `state`
        # is bound for the history entry below, and this lookup does not need
        # to occupy a semaphore slot.
        state = obs["text"]["long_term_context"]
        # `history[-0:]` would be the WHOLE list, so treat 0 explicitly as
        # "no history".
        recent_history = history[-history_length:] if history_length > 0 else []
        try:
            # Only the network call is throttled by the shared semaphore.
            async with semaphore:
                response = await client.inference(
                    function_name="act",
                    variant_name=variant_name,
                    input={
                        "system": {
                            "mission": mission,
                        },
                        "messages": [
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "text",
                                        "arguments": {
                                            "observation": state,
                                            "history": "\n".join(recent_history),
                                        },
                                    }
                                ],
                            }
                        ],
                    },
                    episode_id=episode_id,
                    cache_options={"enabled": "on"},
                )
            episode_log["input_tokens"] += response.usage.input_tokens
            episode_log["output_tokens"] += response.usage.output_tokens
            action = response.output.parsed["action"]
            # Check if action is valid and set to default if not
            action = env.check_action_validity(action)
        except Exception as e:
            # Best-effort fallback: any inference/parsing failure should not
            # abort the episode, so log it and take a random action.
            print(f"Error occurred: {type(e).__name__}: {e}")
            print("Choosing a random legal move as fallback.")
            action = random.choice(
                [
                    "turn left",
                    "turn right",
                    "go forward",
                    "pick up",
                    "drop",
                    "toggle",
                ]
            )
        if use_history:
            history.append(f"Observation:{state}\n\nYour Response:\n{action}\n")
        obs, reward, terminated, truncated, info = env.step(action)
        num_steps = step + 1
        episode_return += reward
        if terminated or truncated:
            break
    # Read the stats once and reuse them for both the feedback and the log.
    stats = env.get_stats()
    await client.feedback(
        metric_name="episode_return",
        episode_id=episode_id,
        value=episode_return,
        dryrun=test,
    )
    await client.feedback(
        metric_name="progression",
        episode_id=episode_id,
        value=stats["progression"],
        dryrun=test,
    )
    episode_log["episode_return"] = episode_return
    episode_log["num_steps"] = num_steps
    episode_log["failed_candidates"] = env.failed_candidates
    episode_log.update(stats)
    # Log the seed actually passed to env.reset (previously this logged only
    # episode_idx, which was wrong whenever `seed` != 0).
    episode_log["seed"] = env_seed
    episode_log["episode_idx"] = episode_idx
    episode_log["episode_id"] = episode_id
    return episode_log

We define a function that runs multiple episodes of the agent for a BabyAI task concurrently (as asyncio tasks) and returns the list of per-episode logs.

In [None]:
async def run_episodes(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    num_episodes: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    disable_progress_bar: bool = False,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> List[dict]:
    """Run ``num_episodes`` episodes of one task concurrently and collect their logs.

    Each episode is scheduled as an asyncio task calling ``run_episode`` with
    ``episode_idx`` in ``range(num_episodes)``. Logs are gathered in
    completion order, not index order. A tqdm progress bar shows how many
    finished episodes reached ``progression == 1.0``.

    Args:
        client, variant_name, env_name, task_name, config, semaphore,
        history_length, seed, test: Forwarded unchanged to ``run_episode``.
        num_episodes: Number of episodes to run.
        disable_progress_bar: Hide the tqdm progress bar.

    Returns:
        The list of per-episode log dicts returned by ``run_episode``.

    Raises:
        Whatever the first failing ``run_episode`` raises. The remaining
        episode tasks are then cancelled and the progress bar is closed.
    """
    progress_bar = trange(
        num_episodes,
        desc=f"{env_name} {task_name} {variant_name}",
        disable=disable_progress_bar,
    )

    tasks = [
        asyncio.create_task(
            run_episode(
                client=client,
                variant_name=variant_name,
                env_name=env_name,
                task_name=task_name,
                episode_idx=episode_idx,
                config=config,
                semaphore=semaphore,
                history_length=history_length,
                seed=seed,
                test=test,
            )
        )
        for episode_idx in range(num_episodes)
    ]

    num_successes = 0
    episode_logs = []
    try:
        for next_done in asyncio.as_completed(tasks):
            episode_log = await next_done
            if episode_log["progression"] == 1.0:
                num_successes += 1
            episode_logs.append(episode_log)
            progress_bar.update(1)
            progress_bar.set_postfix(
                {"Success": f"{num_successes}/{len(episode_logs)}"},
                refresh=True,
            )
    finally:
        progress_bar.close()
        # On failure (or cancellation), do not leave the remaining episodes
        # running in the background. This is a no-op after normal completion.
        for task in tasks:
            if not task.done():
                task.cancel()
    return episode_logs

In [None]:
# Evaluation settings shared by every variant: each task is played for
# `num_episodes` episodes, with environment reset seeds starting at `seed`.
seed, num_episodes = 200, 20
task_names = config.tasks.babyai_tasks

## Baseline

The `baseline` variant uses a simple system prompt that guides the LLM to navigate the maze.

You can find the prompts in `config/functions/act/baseline`.

In [None]:
results_baseline = []

for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="baseline",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_baseline.extend(results_task)

## Reasoning

The `reasoning` variant uses a system prompt that guides the LLM to reason about the best course of action.

You can find the prompts in `config/functions/act/reasoning`.

In [None]:
results_reasoning = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_reasoning.extend(results_task)

## History

The `history` variant uses the previous observations and actions to guide the LLM to navigate the maze.
We add the most recent `history_length` observation/action pairs (8 in the cells below) to the `history` field of the prompt.

You can find the prompts in `config/functions/act/history`.

In [None]:
history_length = 8

results_history = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history.extend(results_task)

## History and Reasoning

The `history_and_reasoning` variant combines the reasoning variant and the history variant.

You can find the prompts in `config/functions/act/history_and_reasoning`.

In [None]:
results_history_and_reasoning = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning.extend(results_task)

## Results

In [None]:
# One row per episode across all four prompting variants.
df = pd.DataFrame(
    [
        *results_baseline,
        *results_reasoning,
        *results_history,
        *results_history_and_reasoning,
    ]
)

### Success Rate

In [None]:
# Mean and standard error of the `progression` metric for each variant.
summary = (
    df.groupby("variant")["progression"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Task Success Rate")

chart

### Episode Return

In [None]:
# Mean and standard error of the episode return for each variant.
summary = (
    df.groupby("variant")["episode_return"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Return")

chart

### Episode Length

In [None]:
# Mean and standard error of the number of steps per episode for each variant.
summary = (
    df.groupby("variant")["num_steps"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Length")

chart

### Episode Generated Token Count

In [None]:
# Mean and standard error of generated (output) tokens per episode, per variant.
summary = (
    df.groupby("variant")["output_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Generated Token Count")

chart

### Episode Input Token Count

In [None]:
# Mean and standard error of prompt (input) tokens per episode, per variant.
summary = (
    df.groupby("variant")["input_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Input Token Count")

chart

## Improving Performance with Supervised Fine-tuning (SFT)

The results above show that the `history_and_reasoning` variant yields the best success rate.
Here we describe how to improve the performance of the `history_and_reasoning` variant by fine-tuning it on a separate set of random episodes.

First we run a large set of episodes for each task using the `history_and_reasoning` variant to generate data for fine-tuning.

In [None]:
num_episodes_ft = 200
seed_ft = 0

for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes_ft,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed_ft,
            test=False,
        )


We provide two options for fine-tuning a model: using a notebook or using the TensorZero UI.
You can fine-tune on episodes that successfully completed the task (based on the `progression` feedback), or on episodes that achieved a sufficiently high `episode_return` (e.g. 0.7).

See the `README.md` file for more details.


### Evaluating the Fine-tuned Variant

After fine-tuning, create a `history_and_reasoning_sft` variant and run the following code to evaluate it.

In [None]:
results_history_and_reasoning_ft = []
for task_name in task_names:
    async with await AsyncTensorZeroGateway.build_http(
        gateway_url="http://localhost:3000", timeout=180.0
    ) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning_sft",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning_ft.extend(results_task)

## Results

We combine the results of the fine-tuned variant with the results of all the variants evaluated above (including `history_and_reasoning`).

In [None]:
# Per-episode logs of the fine-tuned variant.
df_ft = pd.DataFrame(results_history_and_reasoning_ft)

# Drop any rows for the fine-tuned variant(s) left over from an earlier run of
# this cell before appending. Otherwise re-running the cell duplicates them and
# skews every mean/SEM below. Reset the index so row labels stay unique.
ft_variants = set(df_ft["variant"]) if not df_ft.empty else set()
df = pd.concat(
    [df[~df["variant"].isin(ft_variants)], df_ft],
    ignore_index=True,
)

We see below that the fine-tuned model performs better than the `history_and_reasoning` variant!

### Success Rate

In [None]:
# Mean and standard error of the `progression` metric for each variant,
# now including the fine-tuned one.
summary = (
    df.groupby("variant")["progression"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Task Success Rate")

chart

### Episode Return

In [None]:
# Mean and standard error of the episode return for each variant,
# now including the fine-tuned one.
summary = (
    df.groupby("variant")["episode_return"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Return")

chart

### Episode Length

In [None]:
# Mean and standard error of the number of steps per episode for each variant,
# now including the fine-tuned one.
summary = (
    df.groupby("variant")["num_steps"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Length")

chart

### Episode Generated Token Count

In [None]:
# Mean and standard error of generated (output) tokens per episode, per variant,
# now including the fine-tuned one.
summary = (
    df.groupby("variant")["output_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Generated Token Count")

chart

### Episode Input Token Count

In [None]:
# Mean and standard error of prompt (input) tokens per episode, per variant,
# now including the fine-tuned one.
summary = (
    df.groupby("variant")["input_tokens"]
    .agg(["mean", "sem"])
    .reset_index()
)

# One horizontal bar per variant at its mean value.
bars = (
    alt.Chart(summary)
    .mark_bar(color="#1f77b4")
    .encode(
        x=alt.X("mean:Q", title="Value ± 1 SEM", scale=alt.Scale(zero=False)),
        y=alt.Y("variant:N", title="Variant"),
    )
)

# Whiskers spanning mean - SEM .. mean + SEM, computed by a Vega-Lite transform.
error_bars = (
    alt.Chart(summary)
    .transform_calculate(
        low="datum.mean - datum.sem",
        high="datum.mean + datum.sem",
    )
    .mark_errorbar(color="black")
    .encode(
        x=alt.X("low:Q", title="Value ± 1 SEM"),
        x2="high:Q",
        y="variant:N",
    )
)

# Overlay the whiskers on the bars.
chart = alt.layer(bars, error_bars).properties(title="Episode Input Token Count")

chart