In [None]:
import asyncio
import logging
import random
from typing import List, Optional, Tuple
from uuid import UUID

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import yaml
from balrog.environments import make_env
from omegaconf import OmegaConf
from tensorzero import AsyncTensorZeroGateway
from tensorzero.util import uuid7
from tqdm import trange

logger = logging.getLogger(__name__)

# Plot styling used by every figure in this notebook.
sns.set_style("whitegrid")
sns.set_palette("colorblind")

# The six text actions listed in the BabyAI system prompt. run_episode samples
# one of these at random when the model's response cannot be used.
ACTION_SPACE = [
    "turn left", "turn right", "go forward",
    "pick up", "drop", "toggle",
]

Load config for BALROG environments

In [None]:
# Load the BALROG environment settings. The file is decoded as UTF-8 explicitly
# so the result does not depend on the platform's default text encoding.
with open("config.yml", "r", encoding="utf-8") as config_file:
    config_dict = yaml.safe_load(config_file)
# Wrap the plain dict so nested keys can be read as attributes
# (e.g. config.tasks.babyai_tasks below).
config = OmegaConf.create(config_dict)

In [None]:
# Upper bound on in-flight inference requests; this one semaphore is passed to
# every run below. Lower it if OpenAI starts rate-limiting you.
MAX_CONCURRENT_T0_REQUESTS = 50
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_T0_REQUESTS)

## Define run_episode function
The run_episode function executes a single rollout of an agent for a BabyAI task. 

In [None]:
async def run_episode(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    episode_idx: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> dict:
    """Roll out one episode of an `act` variant and report its feedback.

    At every step the TensorZero function `act` is asked for an action.
    - The system input is the mission.
    - The user input is the current text observation.
    - For variants whose name contains "history", the user input also holds the
      last `history_length` observation/action pairs.

    If inference fails for any reason, a random action from ACTION_SPACE is used
    instead. When the episode ends, `episode_return` and `progression` are sent
    as feedback for the episode.

    Args:
        client: Open TensorZero gateway client.
        variant_name: Variant of the `act` function to query.
        env_name: BALROG environment name passed to `make_env`.
        task_name: Task within the environment.
        episode_idx: Index of this episode; the environment is reset with
            seed `episode_idx + seed`.
        config: BALROG configuration passed to `make_env`.
        semaphore: Bounds the number of concurrent inference requests.
        history_length: Number of past steps included in the prompt for
            history variants. Values <= 0 disable the history.
        seed: Base seed added to `episode_idx`.
        test: Sent as `dryrun` on the feedback calls.

    Returns:
        A dict holding:
        - "variant", "task";
        - accumulated "input_tokens" / "output_tokens";
        - "episode_return", "num_steps", "failed_candidates";
        - every entry of `env.get_stats()` (including "progression");
        - "seed" (the seed actually passed to `env.reset`), "episode_idx" and
          "episode_id".
    """
    # The seed the environment is really reset with; logging this value (not
    # just the index) is what allows an episode to be replayed.
    episode_seed = episode_idx + seed
    episode_log = {
        "variant": variant_name,
        "task": task_name,
        "input_tokens": 0,
        "output_tokens": 0,
    }
    # history_length <= 0 disables history: a bare `history[-0:]` slice would
    # otherwise return the *whole* history instead of none of it.
    use_history = "history" in variant_name and history_length > 0
    episode_id = uuid7()
    env = make_env(env_name, task_name, config)
    obs, _ = env.reset(seed=episode_seed)
    mission = obs["mission"]
    episode_return = 0
    history = []
    # Tracked explicitly so it is defined even if env.max_steps == 0.
    num_steps = 0
    for step in range(env.max_steps):
        num_steps = step + 1
        # Read the observation before the try block so `state` is always bound
        # when the history is updated below, even if inference fails.
        state = obs["text"]["long_term_context"]
        recent_history = "\n".join(history[-history_length:]) if use_history else ""
        try:
            async with semaphore:
                response = await client.inference(
                    function_name="act",
                    variant_name=variant_name,
                    input={
                        "system": {
                            "mission": mission,
                        },
                        "messages": [
                            {
                                "role": "user",
                                "content": {
                                    "observation": state,
                                    "history": recent_history,
                                },
                            }
                        ],
                    },
                    episode_id=episode_id,
                )
                episode_log["input_tokens"] += response.usage.input_tokens
                episode_log["output_tokens"] += response.usage.output_tokens
            action = response.output.parsed["action"]
            # Check if the action is valid and replace it with the default if not.
            action = env.check_action_validity(action)
        except Exception as e:
            # Any failure (request error, unusable output, missing key) falls
            # back to a random action so the episode can continue.
            logger.error(f"Error occurred: {type(e).__name__}: {e}")
            logger.info("Choosing a random legal move as fallback.")
            action = random.choice(ACTION_SPACE)
        if use_history:
            history.append(f"Observation:{state}\n\nYour Response:\n{action}\n")
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_return += reward
        if terminated or truncated:
            break

    stats = env.get_stats()
    await client.feedback(
        metric_name="episode_return",
        episode_id=episode_id,
        value=episode_return,
        dryrun=test,
    )
    await client.feedback(
        metric_name="progression",
        episode_id=episode_id,
        value=stats["progression"],
        dryrun=test,
    )
    episode_log["episode_return"] = episode_return
    episode_log["num_steps"] = num_steps
    episode_log["failed_candidates"] = env.failed_candidates
    episode_log.update(stats)
    episode_log["seed"] = episode_seed
    episode_log["episode_idx"] = episode_idx
    episode_log["episode_id"] = episode_id
    return episode_log

We define a function to run multiple episodes of the agent for a BabyAI task in parallel. 

In [None]:
async def run_episodes(
    client: AsyncTensorZeroGateway,
    variant_name: str,
    env_name: str,
    task_name: str,
    num_episodes: int,
    config: OmegaConf,
    semaphore: asyncio.Semaphore,
    disable_progress_bar: bool = False,
    history_length: int = 2,
    seed: int = 0,
    test: bool = False,
) -> List[dict]:
    """Run `num_episodes` episodes of one variant on one task concurrently.

    Each episode is an independent asyncio task running `run_episode`. The
    number of concurrent inference calls is bounded by `semaphore`. A progress
    bar shows completed episodes and how many reached `progression == 1.0`.

    If any episode raises, the still-running episodes are cancelled and the
    progress bar is closed before the exception propagates.

    Returns:
        The per-episode logs produced by `run_episode`, in completion order
        (not episode-index order).
    """
    progress_bar = trange(
        num_episodes,
        desc=f"{env_name} {task_name} {variant_name}",
        disable=disable_progress_bar,
    )
    tasks: List[asyncio.Task] = []
    num_successes = 0
    episode_logs: List[dict] = []
    try:
        tasks = [
            asyncio.create_task(
                run_episode(
                    client=client,
                    variant_name=variant_name,
                    env_name=env_name,
                    task_name=task_name,
                    episode_idx=episode_idx,
                    config=config,
                    semaphore=semaphore,
                    history_length=history_length,
                    seed=seed,
                    test=test,
                )
            )
            for episode_idx in range(num_episodes)
        ]
        for next_done in asyncio.as_completed(tasks):
            episode_log = await next_done
            if episode_log["progression"] == 1.0:
                num_successes += 1
            episode_logs.append(episode_log)
            progress_bar.update(1)
            progress_bar.set_postfix(
                {"Success": f"{num_successes}/{len(episode_logs)}"},
                refresh=True,
            )
    finally:
        # cancel() is a no-op on finished tasks; on error it stops the other
        # episodes instead of leaving them running in the kernel's event loop.
        for task in tasks:
            task.cancel()
        progress_bar.close()
    return episode_logs

In [None]:
# Evaluation settings shared by every variant below.
task_names = config.tasks.babyai_tasks  # BabyAI tasks listed in config.yml
num_episodes = 20  # episodes per (variant, task) pair
seed = 200  # episode i of each task is reset with seed + i

## Baseline
### Generating an action

#### System prompt
We use the BALROG system prompt:
```.minijinja
You are an agent playing a simple navigation game.
Your goal is to {{ mission }}.
The following are the possible actions you can take in the game, followed by a short description of each action:

turn left: turn to the left
turn right: turn to the right
go forward: take one step forward
pick up: pick up the object below you
drop: drop the object that you are holding
toggle: manipulate the object in front of you

In a moment I will present you an observation.

Tips:

- Once the desired object you want to interact or pickup in front of you, you can use the 'toggle' action to interact with it.
- It doesn't make sense to repeat the same action over and over if the observation doesn't change.

PLAY!
```
in `config/functions/act/baseline/system.minijinja`.

#### User prompt
The user prompt is
```.minijinja
Current Observation:
{{ observation }}

Only respond with a JSON object with the following schema:

{
  "thinking": "..."
  "action": "..."
}

The "action" field is required and must always contain one of the above actions at a time and no other text.

Example:

User: a wall 4 steps forward
a wall 3 steps left
Agent: {"thinking": "Since there is a wall 4 steps forward and 3 steps to the left, the best action to take would be to turn right, allowing you to navigate without hitting the walls.", "action": "turn right"}
```
in `config/functions/act/baseline/user.minijinja`

In [None]:
results_baseline = []
for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="baseline",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_baseline.extend(results_task)

## Reasoning
Following the BALROG chain of thought example, we add the instruction 
```.minijinja
Current Observation:
{{ observation }}

Only respond with a JSON object with the following schema:

{
  "thinking": "..."
  "action": "..."
}

The "thinking" field should contain your thought process about what's the best course of action step by step.
The "action" field is required and must always contain one of the above actions at a time and no other text.

Example:

User: a wall 4 steps forward
a wall 3 steps left
Agent: {"thinking": "Since there is a wall 4 steps forward and 3 steps to the left, the best action to take would be to turn right, allowing you to navigate without hitting the walls.", "action": "turn right"}
```
to the user prompt in `config/functions/act/reasoning/user.minijinja`.

In [None]:
results_reasoning = []
for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            seed=seed,
            test=True,
        )
        results_reasoning.extend(results_task)

## History

We see if adding previous observations and actions helps the performance of the baseline.
The user prompt is changed to

```.minijinja
History:
{{ history }}

Current Observation:
{{ observation }}

Only respond with a JSON object with the following schema:

{
  "thinking": "..."
  "action": "..."
}

The "action" field is required and must always contain one of the above actions at a time and no other text.

Example:

User: a wall 4 steps forward
a wall 3 steps left
Agent: {"thinking": "Since there is a wall 4 steps forward and 3 steps to the left, the best action to take would be to turn right, allowing you to navigate without hitting the walls.", "action": "turn right"}
```
in `config/functions/act/history/user.minijinja`.
In the examples below, the `history` field contains the previous `history_length` (here, eight) observation/action pairs.

In [None]:
history_length = 8

results_history = []
for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history.extend(results_task)

## History and reasoning

This variant combines the reasoning variant and the history variant.
The user prompt is changed to
```.minijinja
History:
{{ history }}

Current Observation:
{{ observation }}

Only respond with a JSON object with the following schema:

{
  "thinking": "..."
  "action": "..."
}

The "thinking" field should contain your thought process about what's the best course of action step by step.
The "action" field is required and must always contain one of the above actions at a time and no other text.

Example:

User: a wall 4 steps forward
a wall 3 steps left
Agent: {"thinking": "Since there is a wall 4 steps forward and 3 steps to the left, the best action to take would be to turn right, allowing you to navigate without hitting the walls.", "action": "turn right"}
```
in `config/functions/act/history_and_reasoning/user.minijinja`.

In [None]:
results_history_and_reasoning = []
for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning.extend(results_task)

## Results

In [None]:
# One row per episode across all four variants; the "variant" column tells
# them apart.
all_results = [
    *results_baseline,
    *results_reasoning,
    *results_history,
    *results_history_and_reasoning,
]
df = pd.DataFrame(all_results)
del all_results

### Success rate

In [None]:
# Mean per-episode success (progression) for each variant with ±1 SEM error
# bars. seaborn aggregates the raw per-episode rows itself (errorbar="se"), so
# each point and its error bar are drawn once, in the order the variants first
# appear in `df`.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x="variant",
    y="progression",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Task success rate")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode return

In [None]:
# Mean episode return per variant with ±1 SEM error bars, aggregated by
# seaborn from the raw per-episode rows (errorbar="se") in a single call.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x="variant",
    y="episode_return",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode return")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode length

In [None]:
# Mean number of steps per episode for each variant with ±1 SEM error bars,
# aggregated by seaborn from the raw per-episode rows (errorbar="se").
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x="variant",
    y="num_steps",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode length")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode generated token count

In [None]:
# Mean generated (output) tokens per episode for each variant with ±1 SEM
# error bars, aggregated by seaborn from the raw per-episode rows.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x="variant",
    y="output_tokens",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode generated token count")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode input token count

In [None]:
# Mean input (prompt) tokens per episode for each variant with ±1 SEM error
# bars, aggregated by seaborn from the raw per-episode rows.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df,
    x="variant",
    y="input_tokens",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode input token count")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

## Improving performance with fine tuning

The results above show that the history_and_reasoning variant yields the best success rate.
Here we describe how to improve the performance of the history_and_reasoning variant by fine-tuning it on a separate set of random episodes.

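First, we run a set of episodes for each task using the history_and_reasoning variant. These use a different seed range from the evaluation episodes, and their feedback is not sent as a dry run.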

In [None]:
num_episodes_ft = 200
seed_ft = 0

for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes_ft,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed_ft,
            test=False,
        )


We provide two options for fine-tuning a model: using a notebook or using the TensorZero web interface.

### Fine-tuning using a notebook

To fine-tune a model, you can use the notebook at `recipes/supervised-fine-tuning/metrics/openai/`. 
You will need to change the values of the following variables to:
```.python
CONFIG_PATH = "../../../../examples/babyai/config/tensorzero.toml"
FUNCTION_NAME = "act"
METRIC_NAME = "episode_return"
TEMPLATE_VARIANT_NAME = "history_and_reasoning"
FLOAT_METRIC_THRESHOLD = 0.7
```
After running the notebook, add the fine-tuned model to your `tensorzero.toml` using the `model_id` given in the notebook. Then add a `history_and_reasoning_ft` variant to the function `act` that uses the fine-tuned model.

### Fine-tuning using the tensorzero web interface

To fine-tune a model using the TensorZero web interface, go to http://localhost:4000 and click on "Supervised Fine-tuning".

![alt text](img/homepage.png "Homepage")

Then, just fill in the form and click on "Start Fine-tuning Job".

![alt text](img/sft_form.png "SFT form")

When the job is finished, copy and paste the fine-tuned model config into your `tensorzero.toml` file, and add a `history_and_reasoning_ft` variant to the function `act` that uses it.

### Evaluating the fine-tuned model

After fine-tuning, you can run the following code to evaluate the fine-tuned model.

In [None]:
results_history_and_reasoning_ft = []
for task_name in task_names:
    async with AsyncTensorZeroGateway("http://localhost:3000", timeout=180.0) as client:
        results_task = await run_episodes(
            client=client,
            variant_name="history_and_reasoning_ft",
            env_name="babyai",
            task_name=task_name,
            num_episodes=num_episodes,
            config=config,
            semaphore=semaphore,
            disable_progress_bar=False,
            history_length=history_length,
            seed=seed,
            test=True,
        )
        results_history_and_reasoning_ft.extend(results_task)

## Results

We combine the results of the fine-tuned model with the results of all the variants evaluated above.

In [None]:
df_ft = pd.DataFrame(results_history_and_reasoning_ft)

# Rebuild the combined frame from the raw result lists instead of appending to
# `df`. Re-running this cell then cannot duplicate the fine-tuned rows, and the
# frame gets a fresh 0..n-1 index (pd.concat would keep duplicate labels).
df = pd.DataFrame(
    results_baseline
    + results_reasoning
    + results_history
    + results_history_and_reasoning
    + results_history_and_reasoning_ft
)

We see below that the fine-tuned model performs better than the history_and_reasoning variant!

### Success rate

In [None]:
# Mean per-episode success (progression) per variant, now including the
# fine-tuned one, with ±1 SEM error bars aggregated by seaborn (errorbar="se").
# reset_index guards against duplicate index labels in case `df` was built
# with pd.concat.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df.reset_index(drop=True),
    x="variant",
    y="progression",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Task success rate")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode return

In [None]:
# Mean episode return per variant (including the fine-tuned one) with ±1 SEM
# error bars aggregated by seaborn. reset_index guards against duplicate index
# labels in case `df` was built with pd.concat.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df.reset_index(drop=True),
    x="variant",
    y="episode_return",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode return")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode length

In [None]:
# Mean number of steps per episode for each variant (including the fine-tuned
# one) with ±1 SEM error bars aggregated by seaborn. reset_index guards against
# duplicate index labels in case `df` was built with pd.concat.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df.reset_index(drop=True),
    x="variant",
    y="num_steps",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode length")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode generated token count

In [None]:
# Mean generated (output) tokens per episode for each variant (including the
# fine-tuned one) with ±1 SEM error bars aggregated by seaborn. reset_index
# guards against duplicate index labels in case `df` was built with pd.concat.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df.reset_index(drop=True),
    x="variant",
    y="output_tokens",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode generated token count")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()

### Episode input token count

In [None]:
# Mean input (prompt) tokens per episode for each variant (including the
# fine-tuned one) with ±1 SEM error bars aggregated by seaborn. reset_index
# guards against duplicate index labels in case `df` was built with pd.concat.
fig, ax = plt.subplots(figsize=(8, 6))
sns.pointplot(
    data=df.reset_index(drop=True),
    x="variant",
    y="input_tokens",
    estimator="mean",
    errorbar="se",
    linestyle="none",
    capsize=0.1,
    err_kws={"linewidth": 1},
    color="C0",
    markers="o",
    ax=ax,
)

ax.set_title("Episode input token count")
ax.set_ylabel("Value ± 1 SEM")
ax.set_xlabel("Variant")
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()

plt.show()