# Building Environments in Aviary

In this tutorial we'll focus on constructing a language agent environment where the agent employs a calculator **tool** to answer questions from the Grade School Mathematics 8K (GSM8K) dataset introduced in [1]. GSM8K consists of linguistically diverse grade school math word problems designed to assess multi-step mathematical reasoning. The GSM8K dataset comprises a training set of 7,473 questions and a test set of 1,319 questions. The language agent is equipped with the following tools:

1. **Calculator(Expression)**: Return the result of a numerical expression.
2. **Check(Answer)**: Check if the answer is correct.

The GSM8K environment can be represented as a Markov Decision Process $(\mathcal{V}, \mathcal{S}, \mathcal{A}, \mathcal{T}, R, \gamma)$ with the following components:


- $\mathcal{V}:$ The vocabulary - the English alphabet, together with punctuation symbols.
- $\mathcal{S}:$ The state - GSM8K question, current step in the reasoning process.
- $\mathcal{A}:$ The action set - $\{\text{Calculator}, \text{Check}\}$.
- $\mathcal{T}:$ The deterministic transition function.
- $R:$ The reward function:

    $$
    R(s, a) =
    \begin{cases}
    1 & \text{if the action submits a correct answer}, \\
    -1 & \text{if the action is an invalid tool call}, \\
    0 & \text{otherwise}.
    \end{cases}
    $$

- $\gamma:$ The discount factor.
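
To make the reward structure concrete, the sketch below computes the discounted return $\sum_t \gamma^t r_t$ for a short episode. The outcome labels, the helper names `reward` and `discounted_return`, and the value $\gamma = 0.5$ are illustrative assumptions and are not part of Aviary.

```python
def reward(outcome: str) -> float:
    """Map a (hypothetical) outcome label to the reward defined above."""
    return {"correct_answer": 1.0, "invalid_tool_call": -1.0}.get(outcome, 0.0)


def discounted_return(rewards: list[float], gamma: float = 1.0) -> float:
    """Sum of gamma**t * r_t over the steps of one episode."""
    return sum(gamma**t * r for t, r in enumerate(rewards))


# A hypothetical episode: two successful calculator calls, then a correct answer.
episode = ["calculator_ok", "calculator_ok", "correct_answer"]
rewards = [reward(outcome) for outcome in episode]

print(rewards)  # [0.0, 0.0, 1.0]
print(discounted_return(rewards, gamma=0.5))  # 0.25
```

With $\gamma = 1$ the return of this episode is simply 1, so the discount factor only matters if shorter solutions should be preferred over longer ones.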

Below we define the GSM8K environment in our language agent library, **Aviary**. The framework for defining an environment loosely follows the template for reinforcement learning environments from Spinning Up in Deep RL [2]. Custom Aviary environments should inherit from the `Environment` class and implement the **reset** and **step** methods, in addition to defining functions for the tools the agent has access to.

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "your_API_key"

In [None]:
import contextlib
import json
from typing import Literal

from aviary.envs.gsm8k.env import SafeMathEvaluator
from pydantic import BaseModel, ConfigDict

from aviary.core import (
    Environment,
    Message,
    TaskDataset,
    Tool,
    ToolRequestMessage,
    ToolResponseMessage,
)


class CalculatorEnvConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    correct_reward: float = 1.0
    incorrect_reward: float = 0.0
    tool_failure_reward: float = -1.0
    tool_success_reward: float = 0.0
    rel_tol: float = 1e-4

    done_on_failure: bool = True


class CalculatorEnv(Environment[None]):
    def __init__(
        self,
        problem_id: str,
        problem: str,
        answer: float,
        config: CalculatorEnvConfig | None = None,
    ):
        """An environment for solving simple math problems using a calculator tool and a check answer tool.

        Args:
            problem_id: The unique identifier for the problem (index in the GSM8K dataset).
            problem: The calculation problem in string format.
            answer: The ground truth (correct) answer to the calculation problem.
            config: Configuration for environment behavior and rewards.

        Attributes:
            calc_tool: The calculator tool used for performing calculations.
            check_tool: The tool for checking the agent's answer against the correct answer.
            tools: List of available tools for the environment.
        """
        self.problem_id = problem_id
        self.problem = problem
        self.answer = float(answer)
        self.config = config or CalculatorEnvConfig()

        self.calc_tool = Tool.from_function(self.calculator)
        self.check_tool = Tool.from_function(self.check_answer)
        self.tools = [self.calc_tool, self.check_tool]

    async def reset(self) -> tuple[list[Message], list[Tool]]:
        """Resets the environment, returning the initial problem statement and available tools.

        Returns:
            A tuple containing:
                - A list with the problem message for the agent to view.
                - A list of tools available for the agent to interact with.
        """
        self.state = None  # this environment is effectively stateless
        return [Message(content=self.problem)], self.tools

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        """Processes an action (tool request) and returns the result.

        Args:
            action: The action message containing tool requests.

        Returns:
            A tuple containing:
                - A list of response messages from the tools e.g. the output of the calculator
                - The total reward accumulated from the tool calls.
                - A flag indicating if the episode is done (e.g., answer is correct or a tool failure).
                - Always False, as this environment does not use truncation.
        """
        # We must use a tool at each step
        if not action.tool_calls:
            return (
                [
                    Message(
                        content="Must call one of the provided tools (calculator or check_answer)."
                    )
                ],
                self.config.tool_failure_reward,
                self.config.done_on_failure,
                False,
            )

        # Detect invalid tool calls
        valid_action, invalid_action = self.filter_invalid_tool_calls(action)

        invalid_response_msgs = [
            ToolResponseMessage.from_call(tool_call, content="")
            for tool_call in invalid_action.tool_calls
        ]

        # Execute tool calls
        if valid_action.tool_calls:
            results = await self.exec_tool_calls(valid_action)
            response_msgs = []
            total_reward = 0.0
            any_done = False

            for tool_call, result in zip(valid_action.tool_calls, results, strict=True):
                response, reward, done = json.loads(result.content)

                response_msgs.append(
                    ToolResponseMessage.from_call(tool_call, content=str(response))
                )

                total_reward += reward
                any_done |= done

            return response_msgs + invalid_response_msgs, total_reward, any_done, False

        return (
            invalid_response_msgs,
            self.config.tool_failure_reward * len(invalid_response_msgs),
            self.config.done_on_failure,
            False,
        )

    # We define our tools, check_answer and calculator, below

    def check_answer(self, answer: str) -> tuple[bool, float, Literal[True]]:
        """Check if the proposed answer is correct.

        Args:
            answer: Proposed answer.

        Returns:
            Three-tuple of if correct, associated reward (correct_reward if correct,
                tool_failure_reward if tool failure, otherwise incorrect_reward), and
                True indicating done.
        """
        try:
            correct: bool = (
                abs(float(answer) - self.answer)
                / (abs(self.answer) + self.config.rel_tol)
                < self.config.rel_tol
            )
            reward = (
                self.config.correct_reward if correct else self.config.incorrect_reward
            )
        except ValueError:
            return False, self.config.tool_failure_reward, True
        else:
            return correct, reward, True

    def calculator(self, expr: str) -> tuple[float | str, float, bool]:
        """Calculate a mathematical expression using a secure evaluator.

        Args:
            expr: A valid mathematical expression.

        Returns:
            A three-tuple where the first element is the float evaluation if successful,
                or a string containing the failure cause if unsuccessful, the second
                element is the reward associated with success or failure, and the third
                element is a boolean indicating if this action is terminal.
        """
        try:
            expr = expr.strip()
            result = SafeMathEvaluator.evaluate(expr)
            with contextlib.suppress(ValueError):  # If possible, downcast float to int
                if int(result) == result:
                    result = int(result)
        except Exception as exc:
            return (
                f"Error using calculator: {exc!r}.",
                self.config.tool_failure_reward,
                self.config.done_on_failure,
            )
        return result, self.config.tool_success_reward, False
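
The comparison inside `check_answer` is easier to follow in isolation. The sketch below copies that expression into a standalone helper; the name `is_close` and the sample inputs are hypothetical, while the formula and the default `rel_tol` of `1e-4` come from the code above.

```python
def is_close(proposed: str, truth: float, rel_tol: float = 1e-4) -> bool:
    """Same test as check_answer: |x - y| / (|y| + rel_tol) < rel_tol."""
    return abs(float(proposed) - truth) / (abs(truth) + rel_tol) < rel_tol


print(is_close("10", 10.0))       # True: exact match
print(is_close("10.0005", 10.0))  # True: relative error is about 5e-5
print(is_close("10.01", 10.0))    # False: relative error is about 1e-3
print(is_close("0", 0.0))         # True: rel_tol in the denominator avoids 0/0

# A non-numeric answer raises ValueError, which check_answer
# converts into tool_failure_reward.
try:
    is_close("ten", 10.0)
except ValueError:
    print("not a number")
```

Adding `rel_tol` to the denominator keeps the check defined when the ground truth is zero, at the cost of a very tight absolute tolerance in that case.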

We next define `GSM8kDataset`, a subclass of `TaskDataset` that loads the GSM8K dataset from Hugging Face and provides a convenience method, `get_new_env_by_idx`, for accessing specific questions.

In [None]:
from enum import StrEnum

import datasets
import pandas as pd

# SEE: https://huggingface.co/datasets/openai/gsm8k
GSM8K_PUBLIC_SOURCE = "openai/gsm8k"


class GSM8kDataset(TaskDataset):
    """A dataset class for GSM8K.

    Attributes:
        config: Configuration for the Calculator environment.
        src_df: DataFrame containing dataset problems, processed to include
                numerical answers and unique problem IDs.

    """

    class Split(StrEnum):
        train_full = "train_full"  # full training set from OpenAI
        train = "train"  # 80% of train_full
        val = "val"  # 20% of train_full
        test = "test"

    def __init__(
        self,
        split: Split | str,
        config: CalculatorEnvConfig | None = None,
        hf_source: str = GSM8K_PUBLIC_SOURCE,
    ):
        self.config = config

        if isinstance(split, str):
            split = self.Split(split)

        src_df = self._get_df_from_hf(hf_source, split)

        # Assign problem ID for the env
        src_df["problem_id"] = split.value + "_" + src_df.index.astype(str)

        # attempt to extract a numerical answer
        try:
            src_df["answer_num"] = src_df["answer"].apply(
                # answer is formatted as: <some text>\n#### <answer_num>
                lambda a: float(a.split("#### ")[1].replace(",", ""))
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to extract numerical answer from 'answer' column"
            ) from e

        self.src_df = src_df

    def _get_df_from_hf(self, hf_source: str, split: Split) -> "pd.DataFrame":
        """Loads the GSM8K dataset from the Hugging Face hub and processes it based on the specified split.

        Args:
            hf_source: The Hugging Face source identifier for the GSM8K dataset.
            split: The specified dataset split.

        Returns:
            DataFrame containing the GSM8K problems, filtered and split according to the specification.
        """
        # All non-test splits are derived from train
        hf_split = "test" if split == self.Split.test else "train"

        kw = {}
        if hf_source == GSM8K_PUBLIC_SOURCE:
            kw["name"] = "main"  # as opposed to "socratic"

        src_df = (
            datasets.load_dataset(hf_source, split=hf_split, **kw)
            .to_pandas()
            .reset_index(drop=True)
        )
        if split == self.Split.train:
            src_df = src_df[src_df.index % 5 != 0]
        elif split == self.Split.val:
            src_df = src_df[src_df.index % 5 == 0]
        return src_df

    def get_new_env_by_idx(self, idx: int) -> CalculatorEnv:
        """Creates a new Calculator environment instance for a specific problem in the dataset.

        Args:
            idx: Index of the problem in the dataset.

        Returns:
            A new Calculator environment initialized with the problem ID,
            question, answer, and configuration.
        """
        row = self.src_df.iloc[idx]
        return CalculatorEnv(
            problem_id=row["problem_id"],
            problem=row["question"],
            answer=row["answer_num"],
            config=self.config,
        )

    def __len__(self) -> int:
        """Returns the number of problems in the dataset."""
        return len(self.src_df)
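
Two details of `GSM8kDataset` can be checked without downloading anything: how the numerical answer is parsed, and how `train` and `val` are carved out of the original training set. The sketch below reuses the same expressions on made-up data; the helper `parse_answer`, the sample answer string, and the ten-row index are assumptions for illustration only.

```python
def parse_answer(raw: str) -> float:
    """Extract the number after '#### ', dropping thousands separators."""
    return float(raw.split("#### ")[1].replace(",", ""))


# Hypothetical answer text in the "<some text>\n#### <answer_num>" format.
raw = "She sold 1,000 + 250 = 1,250 items.\n#### 1,250"
print(parse_answer(raw))  # 1250.0

# Every fifth row (index 0, 5, 10, ...) goes to val; the rest go to train.
indices = list(range(10))
val = [i for i in indices if i % 5 == 0]
train = [i for i in indices if i % 5 != 0]
print(val)    # [0, 5]
print(train)  # [1, 2, 3, 4, 6, 7, 8, 9]
```

Because the filtered DataFrame keeps its original index in the code above, `get_new_env_by_idx(0)` on the `train` split returns the row whose original index is 1, so its `problem_id` is `train_1`.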

In the next cell, we perform a **rollout** with the `ToolSelector` agent, using GPT-4o as the base LLM.

In [None]:
# ToolSelector is a simple language agent that directly prompts a language model to call a tool.
from aviary.core import ToolSelector

dataset = GSM8kDataset(split="train")
total_reward = 0
verbose = True  # Whether to explicitly print the environment observations and actions.
n_questions = 3
max_steps = 5  # Maximum number of actions the agent can take.

for i in range(n_questions):
    print(f"Evaluating problem number {i + 1}:\n")
    env = dataset.get_new_env_by_idx(i)
    obs, tools = await env.reset()
    tool_selector = ToolSelector(model_name="gpt-4o", accum_messages=True)
    for _ in range(max_steps):
        action = await tool_selector(obs, tools)

        if verbose:
            print(f"Observation is {obs}\n")
            print(f"Action is {action}\n")

        obs, reward, done, _ = await env.step(action)
        total_reward += reward
        print()
        if done:
            break

print(f"Accuracy is {total_reward / n_questions * 100}%")

Evaluating problem number 1:

Observation is [Message(role='user', content='Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?')]

Action is Tool request message '' for tool calls: calculator(expr='(50/60) * 12') [id=call_AmfsBdqTgU3jV9Ni3dWWDi7t]


Observation is [ToolResponseMessage(role='tool', content='10', name='calculator', tool_call_id='call_AmfsBdqTgU3jV9Ni3dWWDi7t')]

Action is Tool request message '' for tool calls: check_answer(answer='10') [id=call_1HKFkh8BhOs7BBPFAojuoAyE]


Evaluating problem number 2:

Observation is [Message(role='user', content='Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?')]

Action is Tool request message '' for tool calls: calculator(expr='100 / 2') [id=call_Ts9iE2ZZSjBVOTEFl

## References

[1] Cobbe, K., Kosaraju, V., Bavarian, M., Chen, M., Jun, H., Kaiser, L., Plappert, M., Tworek, J., Hilton, J., Nakano, R., Hesse, C. and Schulman, J., 2021. [Training verifiers to solve math word problems](https://arxiv.org/abs/2110.14168). arXiv preprint arXiv:2110.14168.

[2] Achiam, J., 2018. [Spinning Up in Deep Reinforcement Learning](https://spinningup.openai.com/en/latest/).