# Llama Agent Inference (GPU)

A notebook demonstrating how to use an agent powered by a local model (Llama 3.1 8B) to solve problems in the GSM8k test set using a calculator. **The tutorial assumes access to a GPU machine**.
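
Before starting, it can help to check that the model will fit on your GPU. The estimate below is a rough back-of-envelope sketch. It assumes the published Llama 3.1 8B architecture (about 8.03B parameters, 32 layers, 8 key/value heads, head dimension 128) and bf16 storage, and it uses a hypothetical per-sequence context budget (`CONTEXT_TOKENS`), which is not an ldp setting. It ignores activations and CUDA/allocator overhead, so treat the result as a lower bound.

```python
# Back-of-envelope GPU memory estimate for Llama 3.1 8B in bf16.
# Architecture numbers are assumed from the published model config.
# CONTEXT_TOKENS is a hypothetical per-sequence budget, not an ldp setting.

GIB = 2**30
BYTES_PER_VALUE = 2  # bf16 stores each value in 2 bytes

NUM_PARAMS = 8.03e9
NUM_LAYERS = 32
NUM_KV_HEADS = 8  # grouped-query attention: fewer KV heads than query heads
HEAD_DIM = 128

BATCH_SIZE = 8  # same batch_size as the agent config in this notebook
CONTEXT_TOKENS = 4096  # hypothetical: prompt + generated tokens per sequence


def weights_gib() -> float:
    """Memory needed just to hold the model weights."""
    return NUM_PARAMS * BYTES_PER_VALUE / GIB


def kv_cache_gib(batch_size: int, context_tokens: int) -> float:
    """Memory for the attention key/value cache across a full batch."""
    # Keys and values (factor 2) are cached per layer, per KV head, per token.
    bytes_per_token = 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE
    return batch_size * context_tokens * bytes_per_token / GIB


w = weights_gib()
kv = kv_cache_gib(BATCH_SIZE, CONTEXT_TOKENS)
print(f"weights ~ {w:.1f} GiB, KV cache ~ {kv:.1f} GiB, total ~ {w + kv:.1f} GiB")
# weights ~ 15.0 GiB, KV cache ~ 4.0 GiB, total ~ 19.0 GiB
```

Under these assumptions, the total sits below the 24 GB of an RTX 4090, which leaves some headroom for the overhead the estimate ignores.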

Note that the model does not consistently produce valid tool-call syntax, so many rollouts fail.

TODO: add EI notebook demonstrating how to address this

NB: To run this notebook, install `ldp` with the `nn` extra, plus aviary with the `gsm8k` extra:

```bash
pip install "ldp[nn]" "fhaviary[gsm8k]"
```

In [None]:
from aviary.envs.gsm8k import GSM8kDataset

from ldp.alg import Evaluator, EvaluatorConfig
from ldp.alg.callbacks import Callback
from ldp.data_structures import Trajectory
from ldp.nn import AgentLMConfig, SimpleLocalLLMAgent, TorchDType

In [None]:
class AccuracyCallback(Callback):
    """Simple callback that logs accuracy of each batch."""

    async def after_eval_step(self, trajectories: list[Trajectory]):
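        if not trajectories:
            return
        # CalculatorEnvironment returns a terminal reward of 1 if
        # the agent solved the problem correctly. A rollout that failed
        # before taking any step has no steps, so count it as unsolved.
        num_solved = sum(
            bool(t.steps) and t.steps[-1].reward == 1 for t in trajectories
        )
        pass_rate = num_solved / len(trajectories)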
        print(f"Pass rate: {100 * pass_rate:.2f}%")
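
You can check the pass-rate logic without loading the model by running the same computation on stand-in trajectories. This is a minimal sketch. `FakeStep` and `FakeTrajectory` are hypothetical stand-ins that mimic only the `steps[-1].reward` path the callback reads. They are not ldp classes.

```python
from dataclasses import dataclass, field


@dataclass
class FakeStep:
    """Hypothetical stand-in for an ldp transition; only `reward` is modeled."""

    reward: float


@dataclass
class FakeTrajectory:
    """Hypothetical stand-in for ldp's Trajectory; only `steps` is modeled."""

    steps: list = field(default_factory=list)


def pass_rate(trajectories) -> float:
    """Fraction of trajectories whose final step earned a reward of exactly 1."""
    if not trajectories:
        return 0.0
    solved = sum(bool(t.steps) and t.steps[-1].reward == 1 for t in trajectories)
    return solved / len(trajectories)


batch = [
    FakeTrajectory([FakeStep(0.0), FakeStep(1.0)]),  # solved on the final step
    FakeTrajectory([FakeStep(0.0), FakeStep(0.0)]),  # wrong final answer
    FakeTrajectory([]),  # rollout failed before any step
    FakeTrajectory([FakeStep(1.0)]),  # solved in one step
]
print(f"Pass rate: {100 * pass_rate(batch):.2f}%")  # Pass rate: 50.00%
```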

In [None]:
lm_config = AgentLMConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    dtype=TorchDType.bf16,
    chat_template="llama3.1_chat_template_thought.jinja",
    max_new_tokens=100,
    # Parameters for async inference
    batch_size=8,  # fits onto a single 4090 with these params
    max_wait_interval=10.0,
)
agent = SimpleLocalLLMAgent(lm_config)
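
The `batch_size` and `max_wait_interval` parameters control how concurrent generation requests are grouped on the GPU. A rough mental model (an assumption about the behavior, not ldp's actual implementation) is a batcher that collects requests until either `batch_size` requests are queued or `max_wait_interval` seconds have passed since the first one arrived, and then runs them together. The sketch below illustrates that policy with `asyncio`. `collect_batch` and the integer "requests" are hypothetical.

```python
import asyncio


async def collect_batch(queue: asyncio.Queue, batch_size: int, max_wait: float) -> list:
    """Gather up to `batch_size` items, waiting at most `max_wait` seconds
    after the first item arrives (hypothetical dynamic-batching illustration)."""
    batch = [await queue.get()]  # block until at least one request exists
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while len(batch) < batch_size:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break  # waited long enough; run a partial batch
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break  # no more requests arrived before the deadline
    return batch


async def demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    for request in range(10):  # 10 requests arrive at once
        queue.put_nowait(request)
    batches = []
    while not queue.empty():
        batches.append(await collect_batch(queue, batch_size=8, max_wait=0.05))
    return batches


batches = asyncio.run(demo())
print([len(b) for b in batches])  # [8, 2]
```

The first batch fills immediately. The remaining two requests run as a partial batch once the wait interval expires. A longer wait tends to produce fuller batches at the cost of latency. Inside this notebook, an event loop is already running, so replace `asyncio.run(demo())` with `await demo()`.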

In [None]:
dataset = GSM8kDataset(split="test")
callback = AccuracyCallback()
evaluator = Evaluator(
    config=EvaluatorConfig(
        batch_size=64,
        num_eval_iterations=1,  # Only run one batch, then exit
        max_rollout_steps=10,
    ),
    agent=agent,
    dataset=dataset,
    callbacks=[callback],
)

# Llama 3.1 8B does not reliably follow the tool-calling syntax, so expect
# several (caught) errors and a pass rate below 10%.
await evaluator.evaluate()
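
To see what "invalid tool-call syntax" means in practice, you can check a raw completion against the expected shape. This is a minimal sketch. It assumes Meta's Llama 3.1 JSON tool-calling style (`{"name": ..., "parameters": {...}}`), but the exact format that ldp's chat template expects may differ. The tool names and the `expr` argument are hypothetical.

```python
import json
from typing import Any, Dict, Optional, Tuple

# Hypothetical tool names; check the environment's actual tool list.
KNOWN_TOOLS = {"calculator", "submit_answer"}


def parse_tool_call(text: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Return (tool_name, arguments) if `text` is a well-formed JSON tool call,
    else None. Assumes the shape {"name": <str>, "parameters": <object>}."""
    try:
        payload = json.loads(text.strip())
    except json.JSONDecodeError:
        return None  # not JSON at all, e.g. prose or a truncated completion
    if not isinstance(payload, dict):
        return None
    name, params = payload.get("name"), payload.get("parameters")
    if not isinstance(name, str) or name not in KNOWN_TOOLS:
        return None  # missing or unknown tool name
    if not isinstance(params, dict):
        return None  # arguments must be a JSON object
    return name, params


samples = [
    '{"name": "calculator", "parameters": {"expr": "48 / 2"}}',  # valid
    '{"name": "calculator", "parameters": {"expr": "48 / 2"',  # cut off mid-call
    "I will use the calculator: 48 / 2",  # prose instead of a call
]
print([parse_tool_call(s) is not None for s in samples])  # [True, False, False]
```

One possible source of cut-off calls like the second sample is a generation limit such as `max_new_tokens=100`. This is an illustration only, not a diagnosis of the failures above.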