## Tutorial: Getting Started

### 0️⃣ Set up environment and install dependencies

Please make sure you have set up the environment and installed required libraries by following the steps in the rl-swarm README.md and running the `run_rl_swarm.sh` script

### 1️⃣ Import dependencies

In this step, we import all the necessary libraries and set up logging

In [1]:
import os
import sys

# Get the path to src/ relative to current notebook
src_path = os.path.join(os.getcwd(), 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)

In [2]:
import torch
from typing import List

# Import huggingface transformers for loading pre-trained language models
from transformers import AutoModelForCausalLM

# Import genrl_swarm modules for data, game management, rewards, and training
from genrl_swarm.communication.distributed.null_comm import NullCommunicationBackend
from genrl_swarm.data.text_data_managers import SimpleTextDataManager
from genrl_swarm.game import BaseGameManager
from genrl_swarm.state import GameState
from genrl_swarm.rewards import text_games_reward_utils
from genrl_swarm.trainer.GRPOTrainer import GRPOTrainerModule

import logging

# Set up root logger to display INFO level logs
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

# This ensures that log messages from imported modules propagate to the root logger
logging.getLogger().setLevel(logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


### 2️⃣ Define Constants and Utility Functions

In [3]:
# We define maximum number of rounds of RL training. Each round can consist of multiple stages but we limit this example to a single stage.
MAX_ROUNDS = 10
MAX_STAGES = 1

In [4]:
# The GSM8k dataset stores answers like "#### 42". We define a function to extract that answer.
# You can modify this function if you're working with a different dataset format.
def extract_hash_answer(text: str) -> str | None: 
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

### 3️⃣ Prepare the Dataset with SimpleTextDataManager

In [5]:
# The system prompt guides the model to produce answers in a specific format.
# You can modify this to control how the model should think and answer.
# This system prompt will be prepended to each question for both training and evaluation.
SYSTEM_PROMPT = """
You are given a math problem, and you want to come up with the best possible answer. 
Think through the solution of the problem step by step and then state your final answer.
An ideal solution will satisfy three important criteria:
  1) Correct step-by-step reasoning.
  2) Clear and concise explanation.
  3) Final answer in the form: Answer: $Answer (without quotes)
Remember to put your answer on its own line after \"Answer:\".
"""

In [6]:
# SimpleTextDataManager handles dataset loading, preprocessing and feeding into the RL game.
# You can modify num_train_samples or num_evaluation_samples for larger or smaller training sets.
data_manager = SimpleTextDataManager(
    train_dataset="openai/gsm8k",
    evaluation_dataset="openai/gsm8k",
    data_subset="main",
    num_train_samples=2,
    column_name_map={'question': 'question', 'answer': 'answer'},
    column_preprocessing_map={'answer': extract_hash_answer},
    system_prompt=SYSTEM_PROMPT
)

### 4️⃣ Define Reward Conditions

In [7]:
# Here we define two types of reward conditions:
# - format_reward_condition checks if the model produces output in correct format
# - correctness_reward_condition checks if the actual answer matches expected answer
# You can adjust weights to control the importance of format vs correctness.
reward_conditions = [
    text_games_reward_utils.format_reward_condition(pattern=r"\nAnswer: \d+", weight=0.5),
    text_games_reward_utils.correctness_reward_condition(
        pattern=r'Answer: .*?([\d,]+(?:\.\d+)?)', weight=2.0)
]

### 5️⃣ Test Reward Function (Sanity Check)

In [8]:
# Before running full training, it's useful to manually verify reward calculation.
# We create some example completions and check how rewards are assigned.
completions = [
    "Question: 2+2\nAnswer: 4", 
    "Question: 2+2\nAnswer: 5",
    "Question: 2+2\nAnswer: 4.000",
    "Question: 2+1\nAnswer 3"
]

correct_answers = [4, 4, 4, 3]

# Calculate rewards for these samples
rewards = text_games_reward_utils.calculate_reward(
    completions=completions,
    correct_answers=correct_answers,
    reward_conditions=reward_conditions
)

print(rewards)

[2.5, 0.5, 2.5, 0.0]


In [9]:
# RewardManager manages how rewards are computed during training rounds. We initialize reward manager here with the reward conditions we defined in the previous step.
reward_manager = text_games_reward_utils.get_default_reward_manager(
    reward_conditions=reward_conditions, 
    max_rounds=MAX_ROUNDS
)

### 6️⃣ Load Model and Trainer

In [10]:
# We load a pretrained language model from HuggingFace.
# You can swap this model to experiment with different LLM backbones.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
models = [AutoModelForCausalLM.from_pretrained(model_name)]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [11]:
# GRPOTrainerModule handles reinforcement learning updates.
# This is where the RL optimization happens.
trainer = GRPOTrainerModule(models)

2025-06-23 15:02:42,811 - genrl_swarm.logging_utils.global_defs - INFO - Invalid log type: None. Default to terminal logging


### 7️⃣ Initialize the Game Manager

In [12]:
# GameState keeps track of the current round and stage.
game_state = GameState(round=0, stage=0)

In [None]:
# BaseGameManager orchestrates the full RL game loop.
# You can adjust max_stage and max_round for more complex multi-stage setups.
game_manager = BaseGameManager(
    max_stage=MAX_STAGES,
    max_round=MAX_ROUNDS,
    game_state=game_state,
    reward_manager=reward_manager,
    trainer=trainer,
    data_manager=data_manager,
    run_mode="train",
    communication=NullCommunicationBackend()
)

In [14]:
# Defining an evaluation function to calculate rewards for the evaluation data, game_state, rewards and trainer defined by game_manager.
@torch.no_grad()
def evaluate(game_manager: BaseGameManager) -> List[float]:
    completions = []
    correct_answers = []

    eval_data = game_manager.data_manager.get_eval_data(split='test')[:10]
    for idx, world_state in eval_data:
        prompt = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": world_state.environment_states['question']}
            ]
        input_ids = game_manager.trainer.processing_class.apply_chat_template(prompt, tokenize=True, add_generation_prompt=True, return_tensors="pt")
        input_ids = input_ids.to(game_manager.trainer.model.device)
        outputs = game_manager.trainer.model(input_ids)
        outputs = game_manager.trainer.model.generate(input_ids, attention_mask = torch.ones_like(input_ids), generation_config=game_manager.trainer.generation_config)

        answer = game_manager.trainer.processing_class.decode(outputs[0], skip_special_tokens=True)
        completions.append(answer)
        correct_answers.append(answer)
    
    rewards = text_games_reward_utils.calculate_reward(
        completions=completions,
        correct_answers=correct_answers,
        reward_conditions=reward_conditions
    )

    return rewards

In [15]:
# Let's evaluate the model before we start training
untrained_model_rewards = evaluate(game_manager)
print(untrained_model_rewards)
print("Average reward for model before GRPO training:", sum(untrained_model_rewards) / len(untrained_model_rewards))

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Average reward for model before GRPO training: 0.0


### 8️⃣ Run the Game Loop

In [16]:
# This kicks off the reinforcement learning game!
# It will repeatedly generate completions, and begin training the model.
game_manager.run_game()

2025-06-23 15:03:50,651 - genrl_swarm.logging_utils.global_defs - INFO - Starting round: 1/10.
Map: 100%|██████████| 2/2 [00:00<00:00, 595.78 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 1146.61 examples/s]
2025-06-23 15:04:17,614 - genrl_swarm.logging_utils.global_defs - INFO - {'train/loss': 0.13290411233901978, 'train/rewards': 0.5}
2025-06-23 15:04:19,293 - genrl_swarm.logging_utils.global_defs - INFO - Starting round: 2/10.
Map: 100%|██████████| 2/2 [00:00<00:00, 728.11 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 1209.95 examples/s]
2025-06-23 15:04:27,014 - genrl_swarm.logging_utils.global_defs - INFO - {'train/loss': 0.0, 'train/rewards': 0.0}
2025-06-23 15:04:29,247 - genrl_swarm.logging_utils.global_defs - INFO - Starting round: 3/10.
Map: 100%|██████████| 2/2 [00:00<00:00, 486.04 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 1450.06 examples/s]
2025-06-23 15:04:44,409 - genrl_swarm.logging_utils.global_defs - INFO - {'train/loss': 0.14718927443027496, 'tr

In [None]:
new_rewards = evaluate(game_manager)
print(new_rewards)
print("Average reward for model after GRPO training:", sum(new_rewards) / len(new_rewards))

[0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0]
Average reward for model after GRPO training: 0.2


### ✅ Summary
- Loaded GSM8k dataset with preprocessing
- Defined reward functions
- Set up the genrl_swarm framework for RL
- Ran single-agent single-stage RL training using GRPO
The full genrl_swarm package allows much more complex multi-agent setups.

### 👉 Next Steps

Looking to build your own swarms? Start by looking at examples in the `genrl_swarm/examples` folder. They can be run by executing the corresponding recipe configuration in the `genrl_swarm/recipes` folder by changing the final line in the launch script `run_rl_swarm.sh` to point to the new configuration.

For example, to run the multistage version of gsm8k, you can run the following command while in the root `rl-swarm-private` repository. Make sure to set the relevant environment variables like done in the run_rl_swarm script.

```bash  
python "$ROOT/genrl-swarm-zh1p4ng/src/genrl_swarm/runner/swarm_launcher.py" \
    --config-path "$ROOT/genrl-swarm-zh1p4ng/recipes/multistage_math" \
    --config-name "msm_gsm8k_grpo.yaml"
```