In [None]:
import asyncio
import logging
from rose.metrics import GREATER_THAN_THRESHOLD
from rose.rl.reinforcement_learner import SequentialReinforcementLearner

from radical.asyncflow import WorkflowEngine
from radical.asyncflow import ConcurrentExecutionBackend
from concurrent.futures import ProcessPoolExecutor
from radical.asyncflow.logging import init_default_logger

# Module-level logger used by the training / grading tasks below.
logger = logging.getLogger(name=__name__)

In [None]:
import os

# Name under which this run is reported to trackio.
run_name = "roserun-1"

try:
    from google.colab import userdata  # only importable inside Google Colab
except ImportError:
    # Not running on Colab. setdefault keeps an HF_TOKEN / HF_HOME the user already
    # exported instead of overwriting them; otherwise fall back to no token and the
    # shared cache directory.
    os.environ.setdefault("HF_TOKEN", "")
    os.environ.setdefault("HF_HOME", "/work/hdd/bdyk/apark4/huggingface")
else:
    try:
        os.environ["HF_TOKEN"] = userdata.get('hf_token')
    except Exception as exc:
        # Reading the secret failed; continue without a token rather than aborting,
        # and do not apply the non-Colab cache path.
        print(f"Warning: could not read Colab secret 'hf_token': {exc}")
import torch
import re
import random
import time
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer
import trackio
import gc
# Start the trackio run that GRPOConfig(report_to="trackio") logs into.
trackio.init(
    name=run_name,
    project="huggingface",
    space_id="iznoanygod/trackio",
    embed=False,
)

In [None]:
# The execution backend, workflow engine and learner are created in the final cell,
# inside the try/finally that shuts the learner down. Creating them here as well
# started a second ProcessPoolExecutor-backed engine that was rebound by that cell
# and never shut down.

In [None]:
# Token budgets: prompts are capped at max_prompt_length, and max_seq_length is
# used as the completion budget (GRPO max_completion_length).
max_prompt_length = 1024
max_seq_length = 2048

# LoRA rank used when a fresh adapter is created.
lora_rank = 16

base_model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Directory holding adapter checkpoints (also the trainer's output_dir).
lora_id = "math_lora"

In [None]:
# System message asking for <reasoning>/<answer> sections and a \boxed{} final answer.
# Built line by line; the leading and trailing "" entries reproduce the surrounding newlines.
SYSTEM_PROMPT = "\n".join([
    "",
    "Respond in the following format:",
    "<reasoning>",
    "...",
    "</reasoning>",
    "<answer>",
    "...",
    "</answer>",
    "You must use LaTeX to format mathematical expressions, and you must use \\boxed{...} to indicate the final answer.",
    "",
])

In [None]:
def get_answer(expr: str):
    """Return the contents of the last ``\\boxed{...}`` in ``expr``.

    Braces are matched by nesting depth, so ``\\boxed{\\frac{1}{2}}`` yields
    ``\\frac{1}{2}``. A non-greedy regex would stop at the first ``}``, which
    made e.g. ``\\frac{1}{2}`` and ``\\frac{1}{3}`` compare equal. The last box
    is used because a final answer is conventionally boxed at the end; earlier
    boxes may be intermediate results. If the last box is unterminated
    (e.g. a truncated completion), earlier boxes are tried.

    Args:
        expr: Text that may contain ``\\boxed{...}`` expressions.

    Returns:
        The stripped contents of the selected box, or ``None`` when ``expr`` is
        not a string, has no complete boxed expression, or the box is empty.
    """
    if not isinstance(expr, str):
        return None
    marker = "\\boxed{"
    start = expr.rfind(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(expr)):
            char = expr[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    answer = expr[content_start:pos].strip()
                    return answer or None
        # Unterminated box: fall back to the previous occurrence.
        start = expr.rfind(marker, 0, start)
    return None

def correctness_reward_func(prompts, completions, ground_truth, **kwargs):
    """Reward 2.0 when the completion's boxed answer matches the reference's, else 0.0.

    Args:
        prompts: Prompts for the batch (unused; part of the reward-function signature).
        completions: Conversational completions; ``completion[0]["content"]`` is the text.
        ground_truth: Reference solution strings, one per completion.
        **kwargs: Any other fields passed by the trainer (ignored).

    Returns:
        One float reward per completion. A completion with no extractable answer
        always scores 0.0, even if the reference has no boxed answer either
        (previously ``None == None`` was rewarded).
    """
    rewards = []
    for _prompt, completion, ground in zip(prompts, completions, ground_truth):
        predicted = get_answer(completion[0]["content"])
        expected = get_answer(ground)
        is_match = predicted is not None and predicted == expected
        rewards.append(2.0 if is_match else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions laid out as ``<reasoning>`` then ``<answer>`` sections.

    The expected layout is::

        <reasoning>
        ...
        </reasoning>
        <answer>
        ...
        </answer>

    ``re.DOTALL`` lets each ``...`` section span several lines. Without it, ``.``
    stops at newlines, so any multi-line reasoning could never match. A single
    trailing newline after ``</answer>`` is optional, because a completion may
    end immediately at the closing tag.

    Args:
        completions: Conversational completions; ``completion[0]["content"]`` is the text.
        **kwargs: Other trainer-supplied fields (ignored).

    Returns:
        0.5 for each completion matching the layout, 0.0 otherwise.
    """
    pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?"
    responses = [completion[0]["content"] for completion in completions]
    return [
        0.5 if re.fullmatch(pattern, response, flags=re.DOTALL) else 0.0
        for response in responses
    ]

In [None]:
def to_prompt_completion(example):
    """Turn a dataset row into a chat prompt plus its reference solution.

    Args:
        example: Row with ``problem`` and ``solution`` fields.

    Returns:
        A dict with ``prompt`` (a system turn carrying ``SYSTEM_PROMPT`` followed by
        the problem as the user turn) and ``ground_truth`` (the solution, converted
        to ``str`` and stripped).
    """
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_turn = {'role': 'user', 'content': example['problem']}
    reference = str(example["solution"]).strip()
    return dict(prompt=[system_turn, user_turn], ground_truth=reference)

In [None]:
def load_llama_or_latest_checkpoint(
    base_model_id: str,
    lora_id: str,
    dtype=torch.bfloat16,
    device_map="auto",
):
    """Load the base model with a LoRA adapter, resuming from the latest checkpoint.

    If the directory ``lora_id`` exists and ``get_last_checkpoint`` finds a
    checkpoint in it, the base model is loaded and that checkpoint's adapter is
    attached in trainable mode. Otherwise the base model is loaded and a fresh
    LoRA adapter (rank ``lora_rank``) is attached.

    Args:
        base_model_id: Hub id or path of the base model and its tokenizer.
        lora_id: Directory holding the adapter checkpoints.
        dtype: dtype passed to ``from_pretrained``.
        device_map: Device map passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer, loaded_from, iteration)``. ``loaded_from`` is the
        checkpoint path, or ``base_model_id`` when starting fresh. ``iteration``
        is the step parsed from a ``checkpoint-<n>`` directory name, and 0 when
        starting fresh or when the name cannot be parsed.
    """
    last_checkpoint = None
    if os.path.isdir(lora_id):
        last_checkpoint = get_last_checkpoint(lora_id)

    tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left", use_fast=True)
    # Pad with EOS; this overrides any pad token the tokenizer already defines.
    tokenizer.pad_token = tokenizer.eos_token
    iteration = 0
    if last_checkpoint is not None:
        print(f"Found LoRA checkpoint at: {last_checkpoint}")
        print(f"Loading base model: {base_model_id}")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            dtype=dtype,
            device_map=device_map,
        )
        # Parse the step from the directory name only, so "checkpoint-<n>" appearing
        # elsewhere in the path cannot be picked up.
        checkpoint_name = os.path.basename(os.path.normpath(last_checkpoint))
        m = re.fullmatch(r"checkpoint-(\d+)", checkpoint_name)
        if m:
            iteration = int(m.group(1))
        else:
            print(f"Warning: could not parse iteration from {checkpoint_name}, leaving iteration=0")
        # Attach LoRA adapter weights
        print("Applying LoRA adapter from checkpoint...")
        model = PeftModel.from_pretrained(base_model, last_checkpoint, is_trainable=True)
        loaded_from = last_checkpoint
    else:
        print(f"No checkpoint found, loading base model: {base_model_id}")
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            dtype=dtype,
            device_map=device_map,
        )
        model = get_peft_model(model, lora_config)
        loaded_from = base_model_id

    return model, tokenizer, loaded_from, iteration

In [None]:
gpu_count = torch.cuda.device_count()


def memory_stats():
    """Print allocated and reserved CUDA memory (MiB) per GPU, then each GPU's full summary."""
    mib = 1024 ** 2
    device_ids = range(gpu_count)
    for label, query in (
        ("memory allocated: ", torch.cuda.memory_allocated),
        ("memory reserved: ", torch.cuda.memory_reserved),
    ):
        print(label, [query(dev) / mib for dev in device_ids])
    for dev in device_ids:
        print(torch.cuda.memory_summary(dev))

In [None]:
def _release_cuda_memory():
    """Collect unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so tensors kept alive only by reference cycles are
    freed before ``torch.cuda.empty_cache()`` returns cached blocks.
    """
    gc.collect()
    torch.cuda.empty_cache()


def _is_correct(completion_text, ground_truth):
    """True when the completion has a boxed answer equal to the reference's."""
    predicted = get_answer(completion_text)
    return predicted is not None and predicted == get_answer(ground_truth)


async def run(rl, **kwargs):
    """Register the GRPO update task and the accuracy stop criterion on ``rl``, then learn.

    Args:
        rl: Learner to register the tasks on (a ``SequentialReinforcementLearner``
            in this notebook).
        **kwargs: Forwarded to ``rl.learn`` (e.g. ``max_iter``).
    """
    # ========================================================================
    # UPDATE TASK
    # ========================================================================
    @rl.update_task(as_executable=False)
    async def update(*args, **kwargs) -> dict:
        """Train the LoRA adapter with GRPO up to ``iteration + 50`` total steps.

        Resumes from the latest checkpoint in ``lora_id`` when one exists; otherwise
        starts from a fresh adapter on the base model.
        """
        model, tokenizer, loaded_from, iteration = load_llama_or_latest_checkpoint(
            base_model_id=base_model_id,
            lora_id=lora_id,
            dtype=torch.bfloat16,
        )
        model.print_trainable_parameters()
        logger.info("Model Loaded...")
        training_args = GRPOConfig(
            learning_rate = 5e-6,
            adam_beta1 = 0.9,
            adam_beta2 = 0.99,
            weight_decay = 0.1,
            warmup_ratio = 0.1,
            lr_scheduler_type = "cosine",
            optim = "paged_adamw_8bit",
            logging_steps = 1,
            generation_batch_size = 8,
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 1,
            bf16=True,
            gradient_checkpointing=False,
            num_generations = 8,
            max_prompt_length = max_prompt_length,
            max_completion_length = max_seq_length,
            num_train_epochs = 1,
            save_steps = 10,
            # iteration is the resumed checkpoint's step, so this trains ~50 more steps.
            max_steps = iteration+50,
            max_grad_norm = 0.1,
            report_to = "trackio",
            run_name=run_name,  # same name as the trackio run
            output_dir = lora_id,
        )
        dataset = load_dataset("qwedsacf/competition_math", split="train")
        mapped = dataset.map(to_prompt_completion, remove_columns=dataset.column_names).shuffle()
        logger.info("Configured...")
        trainer = GRPOTrainer(
            model = model,
            processing_class = tokenizer,
            reward_funcs = [
                strict_format_reward_func,
                correctness_reward_func,
            ],
            args = training_args,
            train_dataset = mapped,
        )
        logger.info("Starting Training...")
        memory_stats()
        if loaded_from == base_model_id:
            trainer.train(resume_from_checkpoint=False)
        else:
            trainer.train(resume_from_checkpoint=loaded_from)
        logger.info("Finished Training...")
        memory_stats()
        del model, tokenizer, trainer, dataset, mapped
        _release_cuda_memory()
        logger.info("Cleaned up memory...")
        memory_stats()

    # ========================================================================
    # 3. STOP CRITERION TASK
    # ========================================================================
    @rl.as_stop_criterion(metric_name='MODEL_REWARD', threshold=.5, operator=GREATER_THAN_THRESHOLD, as_executable=False)
    async def check_reward(*args, **kwargs):
        """Estimate the latest adapter's accuracy on randomly sampled problems.

        Returns:
            The fraction (0.0-1.0) of sampled problems whose generated boxed answer
            matches the reference solution's. Previously each batch's sum was
            halved, which capped the metric at 0.5 so ``> .5`` could never stop.
        """
        batch_size = 8
        num_batches = 1
        model, tokenizer, _, _ = load_llama_or_latest_checkpoint(
            base_model_id=base_model_id,
            lora_id=lora_id,
            dtype=torch.bfloat16,
        )
        model.eval()
        dataset = load_dataset("qwedsacf/competition_math", split="train")
        mapped = dataset.map(to_prompt_completion, remove_columns=dataset.column_names)
        total_correct = 0.0
        logger.info("Starting grading...")
        memory_stats()
        for _ in range(num_batches):
            batch = mapped.shuffle().select(range(batch_size))
            # Each row's "prompt" is already a complete chat (system + user turns),
            # so it is the conversation itself, not the content of another turn.
            conversations = [ex["prompt"] for ex in batch]
            inputs = tokenizer.apply_chat_template(
                conversations,
                add_generation_prompt=True,
                tokenize=True,
                padding=True,
                return_dict=True,  # also return attention_mask for the left-padded batch
                return_tensors="pt",
            ).to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_seq_length,
                    do_sample=True,  # sampling; don't mix with beam search (num_beams>1)
                    top_p=0.9,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # generate() returns prompt + new tokens; grade only the new tokens, otherwise
            # the \boxed{...} example in the system prompt would be read as the answer.
            new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
            texts = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
            total_correct += sum(
                1.0 for text, ex in zip(texts, batch) if _is_correct(text, ex["ground_truth"])
            )
            del batch, inputs, outputs, new_tokens, texts
        logger.info("Finished grading...")
        memory_stats()
        del model, tokenizer, dataset, mapped
        _release_cuda_memory()
        logger.info("Cleaned up memory...")
        memory_stats()
        return total_correct / (num_batches * batch_size)

    # Run
    logger.info("Starting Reinforcement Learning with ROSE...")
    await rl.learn(skip_simulation_step=True, **kwargs)
    logger.info("Reinforcement Learning completed!")

try:
    engine = await ConcurrentExecutionBackend(ProcessPoolExecutor())
    asyncflow = await WorkflowEngine.create(engine)
    rl = SequentialReinforcementLearner(asyncflow)

    init_default_logger(logging.INFO)
    await run(rl, max_iter=1)
except Exception as e:
    print(f'Learner Failed with: {e}')
finally:
    await rl.shutdown()
    logging.getLogger().handlers.clear()