In [1]:
import sys
sys.path.append("..")  # make the project-local `core` package importable

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, set_seed
from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

from core.trainers.ppo_trainer.custom_ppo import CustomPPOTrainer
from core.trainers.ppo_trainer.config import CustomPPOConfig
from core.custom_components.custom_reward_model.sentiment_reward_model import SentimentRewardModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def build_dataset(
    config,
    dataset_name="stanfordnlp/imdb",
    input_min_text_length=2,
    input_max_text_length=8,
    tokenizer=None,
    min_review_chars=200,
    split="train",
):
    """Build the prompt dataset for PPO training.

    Loads ``dataset_name`` (split ``split``), renames its ``text`` column to
    ``review``, keeps only reviews longer than ``min_review_chars`` characters,
    and turns each review into a short prompt:
    1. The review is tokenized.
    2. The ids are cut to a random number of leading tokens drawn from trl's
       ``LengthSampler`` (built from ``input_min_text_length`` /
       ``input_max_text_length``).
    3. The truncated ids are decoded back to text.

    Args:
        config: Object exposing ``model_name``. It is used to load a tokenizer
            only when ``tokenizer`` is not supplied.
        dataset_name: Hugging Face dataset id. It must provide a ``text`` column.
        input_min_text_length: Lower bound passed to ``LengthSampler``.
        input_max_text_length: Upper bound passed to ``LengthSampler``.
        tokenizer: Optional pre-loaded tokenizer. Pass the policy model's
            tokenizer so prompt lengths are measured in the tokens the policy
            actually sees. A supplied tokenizer is only read, never modified.
        min_review_chars: Reviews with at most this many characters are dropped.
        split: Dataset split to load.

    Returns:
        A ``datasets.Dataset`` in torch format with two added columns:
        ``input_ids`` (the truncated token ids) and ``query`` (the decoded
        prompt text).
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        tokenizer.pad_token = tokenizer.eos_token

    # load the dataset with `datasets`
    ds = load_dataset(dataset_name, split=split)
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > min_review_chars, batched=False)

    # One sampler shared by all rows: each call returns a fresh random prompt length.
    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        # Drop special tokens that encode() may have inserted (e.g. a BOS token)
        # so they do not end up inside the prompt text. Tokenizers that insert
        # none produce exactly the same string as before.
        sample["query"] = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

In [3]:
# Manual W&B initialisation is intentionally left disabled. The PPO config sets
# log_with="wandb", and the wandb login banner is printed while CustomPPOTrainer
# is being constructed, so the trainer appears to start the run itself.
# NOTE(review): verify this before re-enabling, to avoid creating two runs.
# import wandb

# wandb.init()

In [4]:
# Global seed for reproducibility. transformers.set_seed seeds Python `random`,
# NumPy and torch (CPU and CUDA), which covers:
# - the sampled child generations (do_sample=True below);
# - the value-head initialisation.
# NOTE(review): trl's LengthSampler (used in build_dataset and the training cell)
# is assumed to draw from NumPy's global RNG. Confirm this against the trl version in use.
SEED = 42
set_seed(SEED)

model_name = "meta-llama/Llama-3.2-1B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizers: the same checkpoint serves both roles. EOS is reused as the pad token.
child_tokenizer = AutoTokenizer.from_pretrained(model_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(model_name)
child_tokenizer.pad_token = child_tokenizer.eos_token
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

# Load models with value heads for PPO. The teacher starts from the same
# checkpoint as the child and is also passed to the trainer as ref_model.
child_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
teacher_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

# Move models to device
child_model.to(device)
teacher_model.to(device)

# Child (student): sampled decoding, up to 100 new tokens.
student_generation_args = {
    "max_new_tokens": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "do_sample": True,
    "pad_token_id": child_tokenizer.eos_token_id
}

# Teacher: greedy decoding, up to 150 new tokens.
teacher_generation_args = {
    "max_new_tokens": 150,
    "do_sample": False,
    "pad_token_id": teacher_tokenizer.eos_token_id
}




In [5]:
# PPO configuration.
# model_name must match the policy loaded above: build_dataset() loads its
# tokenizer from config.model_name. A different checkpoint name here (e.g. a
# GPT-2 model) would truncate prompts using GPT-2 tokens while the child and
# teacher models are Llama.
# NOTE(review): confirm CustomPPOConfig uses model_name only for tokenizer
# lookup / tracking and not to load weights.
config = CustomPPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    log_with="wandb",

    # Generation configs
    child_generation_args=student_generation_args,
    teacher_generation_args=teacher_generation_args,

    # custom interaction model
    # NOTE(review): "test" looks like a placeholder; confirm the intended value.
    custom_interaction_model="test",
)

# Initialize reward model.
# NOTE(review): sent_args look like Hugging Face pipeline call kwargs
# (top_k / function_to_apply / batch_size). Presumably SentimentRewardModel
# forwards them to a sentiment pipeline; verify.
sent_args = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
reward_model = SentimentRewardModel(device=device, sent_args=sent_args)

Device set to use cuda


In [6]:
 # Build dataset
dataset = build_dataset(config)

# Initialize PPO trainer
ppo_trainer = CustomPPOTrainer(
    config=config,
    child_model=child_model,
    teacher_model=teacher_model,
    ref_model=teacher_model,
    reward_model=reward_model,
    tokenizer=child_tokenizer,
    teacher_tokenizer=teacher_tokenizer,
    dataset=dataset,
)

[34m[1mwandb[0m: Currently logged in as: [33mhongyi-gu[0m ([33mhongyi-gu-netmind-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Training loop: one PPO step per dataloader batch.
output_min_length = 4
output_max_length = 16
# NOTE(review): output_length_sampler is not passed to ppo_trainer.step below.
# Wire it in or delete it.
output_length_sampler = LengthSampler(output_min_length, output_max_length)

error_count = 0
n_batches = 0
for step_idx, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    n_batches += 1
    try:
        # Run PPO step on the decoded text prompts.
        stats, scores = ppo_trainer.step(input_prompts=batch["query"])
    except Exception as e:
        # Best-effort: skip the failing batch, but report the error instead of
        # silently discarding it. tqdm.write keeps the progress bar intact.
        error_count += 1
        tqdm.write(f"[batch {step_idx}] PPO step failed, skipping: {type(e).__name__}: {e}")
        continue

    # Log stats
    ppo_trainer.log_stats(stats, batch, scores, columns_to_log=["query"])

print(f"PPO training finished: {error_count}/{n_batches} batches failed.")

# NOTE(review): generation_kwargs is not used anywhere in this notebook yet.
# It refers to child_tokenizer because no notebook-level `tokenizer` exists
# (the name used previously raised NameError).
generation_kwargs = {
    "min_length": -1,
    "top_k": 0,  # int; 0 disables top-k filtering
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": child_tokenizer.eos_token_id,
}

  0%|          | 0/194 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
