<div class="align-center">
<a href="https://oumi.ai/"><img src="https://oumi.ai/docs/en/latest/_static/logo/header_logo.png" height="200"></a>

[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)
[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)
[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)
</div>

👋 Welcome to Open Universal Machine Intelligence (Oumi)!

🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models: from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large-scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.

🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.

⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi).

# Train a Letter Counting Model using GRPO

This notebook focuses on a popular LLM prompt: "How many R's are in the word strawberry?" First, we use a custom evaluation function to measure how well several popular models count letters in words. Then, we use GRPO (Group Relative Policy Optimization) to train a model to improve its performance on this task.
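Counting letters is trivial for a program, which is what makes this task a convenient benchmark: the correct answer can always be computed exactly. A minimal sketch of that ground truth follows; the helper name `count_letter` is our own, and it assumes matching should ignore case.

```python
def count_letter(word: str, letter: str) -> int:
    """Return how many times `letter` occurs in `word`, ignoring case."""
    if len(letter) != 1:
        raise ValueError("letter must be a single character")
    return word.lower().count(letter.lower())


# The famous example: "strawberry" contains three r's.
print(count_letter("strawberry", "r"))  # prints 3
```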

## Prerequisites

### Machine Requirements

❗**NOTICE:** This notebook doesn't run on Colab due to memory requirements.

We recommend running this notebook on a machine with a GPU that has at least TODO GB of VRAM. If your local machine can't run it, you can run it on a cloud platform instead. The commands below open a VSCode instance backed by a GCP node with 4 A100 GPUs, from which you can run the notebook.

```bash
# Run on your local machine
gcloud auth application-default login  # Authenticate with GCP
make gcpcode ARGS="--resources.accelerators A100:4"
```

### Oumi Installation

First, let's install Oumi and vLLM. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html). Here, we include Oumi's GPU dependencies.

In [None]:
%pip install "oumi[gpu]"

### API Keys

As part of this notebook, you can evaluate frontier models from OpenAI, Google, and Anthropic on the letter counting task. To evaluate any of these models, set the corresponding API keys below. To evaluate Llama 3.1 405B on Google Cloud Vertex AI, also set your GCP region and project ID.

In [1]:
import os

os.environ["OPENAI_API_KEY"] = ""  # Set your OpenAI API key here.
os.environ["GEMINI_API_KEY"] = ""  # Set your Gemini API key here.
os.environ["ANTHROPIC_API_KEY"] = ""  # Set your Anthropic API key here.

# Set your GCP region and project ID to query Llama 3.1 405B on Vertex AI.
REGION = ""  # Set your GCP region here.
PROJECT_ID = ""  # Set your GCP project ID here.
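If you only set some of the keys, it can help to check which API-backed models are actually usable before editing `model_names` in the Evaluation section. The helper below is a small sketch: the `API_KEY_ENV_VARS` mapping and the `models_with_keys` name are our own, derived from the `api_key_env_varname` fields in this notebook's evaluation configs. `llama_405b` is omitted because it authenticates through GCP rather than an API key.

```python
import os
from typing import Mapping, Optional

# Environment variable read by each API-backed model config in this notebook.
API_KEY_ENV_VARS = {
    "gpt_4o": "OPENAI_API_KEY",
    "o1_preview": "OPENAI_API_KEY",
    "gemini_pro": "GEMINI_API_KEY",
    "claude_sonnet": "ANTHROPIC_API_KEY",
}


def models_with_keys(env: Optional[Mapping[str, str]] = None) -> list:
    """Return the API-backed models whose API key is set to a non-empty value."""
    env = os.environ if env is None else env
    return [
        model
        for model, var in API_KEY_ENV_VARS.items()
        if env.get(var, "").strip()  # Empty or whitespace-only keys don't count.
    ]


print(models_with_keys())
```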

### Tutorial Directory

Next, we'll set up a directory to use for this tutorial.

In [2]:
from pathlib import Path

tutorial_dir = "letter_counting_tutorial"

Path(tutorial_dir).mkdir(parents=True, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Silence HF tokenizers parallelism warnings.

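## Dataset

We use the `oumi-ai/oumi-letter-count` dataset, which Oumi exposes as `LetterCountGrpoDataset`. Let's load its validation split and inspect an example.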

In [3]:
from pprint import pprint

from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset

dataset = LetterCountGrpoDataset(split="validation")
print("-" * 80)
# print("Raw example:")
# pprint(dataset.raw(0).to_dict())
print("Evaluation example:")
pprint(dataset.conversation(0).to_dict())

[2025-04-08 21:32:11,092][oumi][rank0][pid:10161][MainThread][INFO]][base_map_dataset.py:91] Creating map dataset (type: LetterCountGrpoDataset)... dataset_name: 'oumi-ai/oumi-letter-count'
[2025-04-08 21:32:14,480][oumi][rank0][pid:10161][MainThread][INFO]][base_map_dataset.py:487] Dataset Info:
	Split: validation
	Version: 0.0.0
	Dataset size: 22894322
	Download size: 5697295
	Size: 28591617 bytes
	Rows: 10000
	Columns: ['conversation_id', 'messages', 'metadata']
[2025-04-08 21:32:14,598][oumi][rank0][pid:10161][MainThread][INFO]][base_map_dataset.py:426] Loaded DataFrame with shape: (10000, 3). Columns:
conversation_id    object
messages           object
metadata           object
dtype: object
--------------------------------------------------------------------------------
Evaluation example:
{'conversation_id': 'oumi_letter_count_0',
 'messages': [{'content': "Could you determine the count of 'l's in "
                          "'substantial'?",
               'role': 'user'},
    

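Each example pairs a natural-language question about one letter in one word with a known answer. To make that structure concrete, here is a hypothetical sketch of how an example in this style could be produced. The prompt templates and the `answer` field are illustrative assumptions, not the actual schema or generation code of `oumi-ai/oumi-letter-count`.

```python
import random

# Hypothetical prompt templates; the real dataset phrases its questions in many ways.
TEMPLATES = [
    "Could you determine the count of '{letter}'s in '{word}'?",
    "How many '{letter}'s are in the word '{word}'?",
]


def make_example(word: str, letter: str, rng: random.Random) -> dict:
    """Build one conversation-style example together with its ground-truth count."""
    prompt = rng.choice(TEMPLATES).format(letter=letter, word=word)
    return {
        "messages": [{"role": "user", "content": prompt}],
        "answer": word.lower().count(letter.lower()),  # Hypothetical label field.
    }


example = make_example("substantial", "l", random.Random(0))
print(example)
```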
## Evaluation

The `count_letters` task used below is a custom evaluation function defined in `src/oumi/evaluation/registry/count_letters_task.py` in the Oumi repository.
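The evaluation reports `accuracy`, `num_correct_answers`, `num_incorrect_answers`, and `num_invalid_answers`. Conceptually, each response is parsed for a numeric answer and compared to the ground truth. The sketch below illustrates that bookkeeping; it is not Oumi's implementation, and it assumes, hypothetically, that the model's final answer is the last integer in its response and that responses containing no integer are invalid.

```python
import re
from typing import Optional


def extract_answer(response: str) -> Optional[int]:
    """Return the last integer in the response, or None if there is none."""
    matches = re.findall(r"-?\d+", response)
    return int(matches[-1]) if matches else None


def score(responses: list, targets: list) -> dict:
    """Tally correct / incorrect / invalid answers and compute accuracy."""
    if len(responses) != len(targets):
        raise ValueError("responses and targets must have the same length")
    correct = incorrect = invalid = 0
    for response, target in zip(responses, targets):
        prediction = extract_answer(response)
        if prediction is None:
            invalid += 1  # No parsable answer.
        elif prediction == target:
            correct += 1
        else:
            incorrect += 1
    total = correct + incorrect + invalid
    return {
        "accuracy": correct / total if total else 0.0,
        "num_correct_answers": correct,
        "num_incorrect_answers": incorrect,
        "num_invalid_answers": invalid,
    }


print(score(["There are 3 r's.", "I count 2.", "I'm not sure."], [3, 3, 1]))
```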

In [None]:
# Number of examples to evaluate per model. Use a larger value (e.g., 100)
# for a more reliable accuracy estimate.
NUM_SAMPLES = 5

model_names = [
    "llama_3b",
    # Uncomment any models you wish to evaluate - you can evaluate multiple at once.
    # "gpt_4o",
    # "o1_preview",
    # "gemini_pro",
    # "llama_405b",
    # "claude_sonnet",
]

In [None]:
configs = {
    "llama_3b": """
      model:
        model_name: "meta-llama/Llama-3.2-3B-Instruct"
        # model_name: "HuggingFaceTB/SmolLM2-135M-Instruct"  # Much smaller model for quick tests.
        model_max_length: 131072
        torch_dtype_str: "bfloat16"
        attn_implementation: "sdpa"
        trust_remote_code: True

      generation:
        max_new_tokens: 2048

      tasks:
        - evaluation_backend: custom
          task_name: count_letters

      # inference_engine: VLLM
      output_dir: "letter_counting_tutorial/evaluation"
      """,
    "gpt_4o": """
      model:
        model_name: "gpt-4o"

      inference_engine: OPENAI

      inference_remote_params:
        api_key_env_varname: "OPENAI_API_KEY"
        max_retries: 3
        num_workers: 100
        politeness_policy: 60
        connection_timeout: 300

      generation:
        max_new_tokens: 8192
        temperature: 0.0

      tasks:
        - evaluation_backend: custom
          task_name: count_letters
      """,
    "o1_preview": """
      model:
        model_name: "o1-preview"

      inference_engine: OPENAI

      inference_remote_params:
        api_key_env_varname: "OPENAI_API_KEY"
        max_retries: 3
        num_workers: 100
        politeness_policy: 60
        connection_timeout: 300

      generation:
        max_new_tokens: 8192
        temperature: 1.0

      tasks:
        - evaluation_backend: custom
          task_name: count_letters
      """,
    "gemini_pro": """
      model:
        model_name: "gemini-2.5-pro-preview-03-25"

      inference_engine: GOOGLE_GEMINI

      inference_remote_params:
        api_key_env_varname: "GEMINI_API_KEY"
        max_retries: 3
        num_workers: 2
        politeness_policy: 60
        connection_timeout: 300

      generation:
        max_new_tokens: 8192
        temperature: 0.0

      tasks:
        - evaluation_backend: custom
          task_name: count_letters
      """,
    "llama_405b": f"""
      model:
        model_name: "meta/llama-3.1-405b-instruct-maas"

      inference_engine: GOOGLE_VERTEX

      inference_remote_params:
        api_url: "https://{REGION}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{REGION}/endpoints/openapi/chat/completions"
        max_retries: 3
        num_workers: 10
        politeness_policy: 60
        connection_timeout: 300

      generation:
        max_new_tokens: 8192
        temperature: 0.0

      tasks:
        - evaluation_backend: custom
          task_name: count_letters
      """,
    "claude_sonnet": """
      model:
        model_name: "claude-3-7-sonnet-latest"

      inference_engine: ANTHROPIC

      inference_remote_params:
        api_key_env_varname: "ANTHROPIC_API_KEY"
        max_retries: 3
        num_workers: 5
        politeness_policy: 65
        connection_timeout: 300

      generation:
        max_new_tokens: 8192
        temperature: 0.0

      tasks:
        - evaluation_backend: custom
          task_name: count_letters
      """,
}

In [None]:
from oumi.core.configs import EvaluationConfig
from oumi.core.evaluation import Evaluator

results = {}

for model_name in model_names:
    # Create the evaluation config from the YAML string.
    config_yaml: str = configs[model_name]
    config = EvaluationConfig.from_str(config_yaml)
    config.tasks[0].num_samples = NUM_SAMPLES

    # Run the evaluation.
    evaluator = Evaluator()
    evaluator_out = evaluator.evaluate(config)

    # Record the results.
    results[model_name] = evaluator_out[0].get_results()

In [None]:
print(f"Total samples: {NUM_SAMPLES}")
for model_name, result in results.items():
    print("-" * 80)
    print(f"Model: {model_name}")
    print(f"Accuracy: {result['accuracy']}")
    correct = result["num_correct_answers"]
    incorrect = result["num_incorrect_answers"]
    invalid = result["num_invalid_answers"]
    print(f"Num correct, incorrect, invalid: {correct}, {incorrect}, {invalid}")
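When evaluating several models at once, a compact ranking can be easier to read than the per-model printout. A small sketch, assuming each entry in `results` has the keys printed above and a numeric `accuracy`; the `format_leaderboard` name is our own.

```python
def format_leaderboard(results: dict) -> str:
    """Render per-model results as a fixed-width table, best accuracy first."""
    header = f"{'Model':<16}{'Accuracy':>10}{'Correct':>9}{'Incorrect':>11}{'Invalid':>9}"
    lines = [header, "-" * len(header)]
    ranked = sorted(results.items(), key=lambda item: item[1]["accuracy"], reverse=True)
    for name, r in ranked:
        lines.append(
            f"{name:<16}{r['accuracy']:>10.3f}{r['num_correct_answers']:>9}"
            f"{r['num_incorrect_answers']:>11}{r['num_invalid_answers']:>9}"
        )
    return "\n".join(lines)
```

In the notebook, call it as `print(format_leaderboard(results))` after the evaluation loop has finished.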

## GRPO

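Now let's train a model on the letter counting task with GRPO. The cell below writes an Oumi training config that fine-tunes `meta-llama/Llama-3.2-3B-Instruct` with TRL's GRPO trainer, scoring completions with the `count_letters` reward function.

Set `training.enable_wandb` to `True` if you want to log your training run to Weights & Biases. You must also log in to W&B, e.g. by running `wandb login`.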

In [None]:
%%writefile $tutorial_dir/train.yaml

model:
  model_name: "meta-llama/Llama-3.2-3B-Instruct"
  model_max_length: 8192
  torch_dtype_str: "bfloat16"
  attn_implementation: "sdpa"

data:
  train:
    datasets:
      - dataset_name: "oumi-ai/oumi-letter-count"
        split: "train"

training:
  trainer_type: "TRL_GRPO"
  save_steps: 500
  max_steps: 500
  per_device_train_batch_size: 2
  gradient_accumulation_steps: 1
  learning_rate: 5e-5

  reward_functions: ["count_letters"]

  ddp_find_unused_parameters: False
  optimizer: "adafactor"
  compile: True

  grpo:
    num_generations: 4

  dataloader_num_workers: "auto"
  dataloader_prefetch_factor: 32

  logging_steps: 10
  output_dir: "letter_counting_tutorial/llama_3b_grpo"
  # Set this to True if you want to log to Weights and Biases.
  enable_wandb: False


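GRPO samples a group of completions for each prompt (here, `grpo.num_generations: 4`), scores each one with the functions listed in `reward_functions`, and uses each completion's reward relative to the rest of its group as its advantage, so no separate value model is needed. The sketch below illustrates that idea in plain Python. The reward (1.0 for a correct count, 0.0 otherwise) and the `group_advantages` helper are our own simplifications, not Oumi's `count_letters` reward or TRL's batched implementation.

```python
import re
import statistics


def correctness_reward(completion: str, target: int) -> float:
    """Hypothetical reward: 1.0 if the last integer in the completion equals target."""
    numbers = re.findall(r"-?\d+", completion)
    return 1.0 if numbers and int(numbers[-1]) == target else 0.0


def group_advantages(rewards: list, eps: float = 1e-4) -> list:
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # Population std of the group's rewards.
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled completions for "How many r's are in 'strawberry'?" (target = 3).
completions = ["The answer is 3.", "There are 2.", "3", "I think 4."]
rewards = [correctness_reward(c, 3) for c in completions]
advantages = group_advantages(rewards)
print(rewards)  # [1.0, 0.0, 1.0, 0.0]
print(advantages)  # Correct completions get a positive advantage, incorrect ones negative.
```

Note that if every completion in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal.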
## Evaluating our Trained Model

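After training finishes (for example, by running `oumi train -c letter_counting_tutorial/train.yaml` in a terminal), let's evaluate the trained model, saved to `letter_counting_tutorial/llama_3b_grpo`, on the same task.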

In [None]:
# Create the evaluation config from the YAML string.
config_yaml: str = configs["llama_3b"]
config = EvaluationConfig.from_str(config_yaml)
config.tasks[0].num_samples = NUM_SAMPLES
config.model.model_name = "letter_counting_tutorial/llama_3b_grpo"

# Run the evaluation.
evaluator = Evaluator()
evaluator_out = evaluator.evaluate(config)

# Record the results.
trained_model_results = evaluator_out[0].get_results()

print(f"Accuracy: {trained_model_results['accuracy']}")
correct = trained_model_results["num_correct_answers"]
incorrect = trained_model_results["num_incorrect_answers"]
invalid = trained_model_results["num_invalid_answers"]
print(f"Num correct, incorrect, invalid: {correct}, {incorrect}, {invalid}")
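With a small `NUM_SAMPLES`, accuracy estimates are noisy, so a difference between the base and trained model may not be meaningful. One standard way to gauge this is a Wilson score confidence interval for a binomial proportion. A sketch follows; the `wilson_interval` helper is our own, and `z = 1.96` corresponds to roughly 95% confidence.

```python
import math


def wilson_interval(num_correct: int, total: int, z: float = 1.96) -> tuple:
    """Wilson score interval for the accuracy num_correct / total."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = num_correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half_width = (z / denom) * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2))
    return max(0.0, center - half_width), min(1.0, center + half_width)
```

For example, `wilson_interval(correct, correct + incorrect + invalid)` after the cell above. With 4 correct answers out of 5, the interval is roughly (0.38, 0.96), which is why a larger `NUM_SAMPLES` is needed for a meaningful comparison.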