In [1]:
import sys
sys.path.append("..")

import torch
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer

from core.trainers.ppo_trainer.custom_ppo import CustomPPOTrainer
from core.trainers.ppo_trainer.config import CustomPPOConfig
from core.custom_components.custom_reward_model.sentiment_reward_model import SentimentRewardModel
from trl.core import LengthSampler
from trl import AutoModelForCausalLMWithValueHead

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def build_dataset(
    config,
    dataset_name="stanfordnlp/imdb",
    input_min_text_length=2,
    input_max_text_length=8,
    split="train",
    min_review_chars=200,
):
    """Build the prompt dataset for PPO training.

    Loads ``dataset_name``, keeps only reviews longer than ``min_review_chars``
    characters, and truncates each review's token ids to a randomly sampled
    length to form a short prompt (``query``).

    Args:
        config: Object exposing ``model_name``; used to load the tokenizer.
        dataset_name: Hugging Face dataset id. The split must contain a
            ``"text"`` column (it is renamed to ``"review"``).
        input_min_text_length: Lower bound on the prompt length in tokens,
            passed to ``LengthSampler``.
        input_max_text_length: Upper bound on the prompt length in tokens,
            passed to ``LengthSampler``. NOTE(review): in TRL this bound is
            exclusive — confirm against the installed version.
        split: Dataset split to load (previously hard-coded to ``"train"``).
        min_review_chars: Reviews with this many characters or fewer are
            dropped (previously hard-coded to 200).

    Returns:
        A ``datasets.Dataset`` with the source split's columns (``text``
        renamed to ``review``) plus ``input_ids`` and ``query``, with its
        output format set to torch.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # Reuse EOS as the pad token; the tokenizer may not define one.
    tokenizer.pad_token = tokenizer.eos_token

    # Load the requested split and expose the raw text as "review".
    ds = load_dataset(dataset_name, split=split)
    ds = ds.rename_columns({"text": "review"})
    # Drop short reviews (presumably so each prompt is cut from a substantive review).
    ds = ds.filter(lambda x: len(x["review"]) > min_review_chars, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # input_size() is called once per sample, presumably drawing a new
        # random prompt length each time.
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

In [3]:
# import wandb

# wandb.init()

In [4]:

# Sampling settings shared by the child and teacher models. They are defined
# once and each model receives its own copy: the trainer fills in
# "pad_token_id" (possibly in place), so the two dicts must not alias.
SAMPLING_GENERATION_ARGS = {
    "min_length": -1,  # non-positive: no minimum-length constraint on generation
    "top_k": 0,  # 0 disables top-k filtering (int, as transformers expects)
    "top_p": 1.0,  # 1.0 disables nucleus (top-p) filtering
    "do_sample": True,  # sample from the distribution instead of greedy decoding
    "pad_token_id": None,  # Will be set in the trainer
}

config = CustomPPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
    log_with="wandb",
    # Generation configs
    child_generation_args=dict(SAMPLING_GENERATION_ARGS),
    teacher_generation_args=dict(SAMPLING_GENERATION_ARGS),
)

# Initialize models and tokenizers
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two independent copies of the same checkpoint: `model` is passed to the
# trainer as the child, `ref_model` as the teacher.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Reuse EOS as the pad token; the tokenizer may not define one.
tokenizer.pad_token = tokenizer.eos_token

# Initialize reward model.
# NOTE(review): sent_args look like transformers pipeline kwargs (all labels,
# raw logits, batch size 16) — confirm how SentimentRewardModel consumes them.
sent_args = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
reward_model = SentimentRewardModel(device=device, sent_args=sent_args)

Device set to use cuda


In [5]:
# Show the configured PPO batch size as this cell's output.
getattr(config, "batch_size")

128

In [6]:
 # Build dataset
dataset = build_dataset(config)

# Wire everything into the custom PPO trainer. Child and teacher are loaded
# from the same checkpoint, so a single tokenizer serves both.
ppo_trainer = CustomPPOTrainer(
    config=config,
    dataset=dataset,
    reward_model=reward_model,
    child_model=model,
    tokenizer=tokenizer,
    teacher_model=ref_model,
    teacher_tokenizer=tokenizer,
)

[34m[1mwandb[0m: Currently logged in as: [33mhongyi-gu[0m ([33mhongyi-gu-netmind-ai[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Training loop: one PPO step per dataloader batch.
output_min_length = 4
output_max_length = 16
# NOTE(review): output_length_sampler is built here but never passed to
# `ppo_trainer.step`, so it does not currently affect response lengths.
# TODO: wire it into the generation args or remove it.
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Batches whose PPO step raises are skipped (best-effort training), but every
# failure is reported instead of being silently swallowed. Only the repr of
# each exception is kept: storing the exception object would also keep its
# traceback frames, and any tensors they reference, alive.
error_count = 0
step_errors = []  # (batch index, repr(exception)) for each skipped batch
for batch_idx, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    try:
        # Run PPO step
        stats, scores = ppo_trainer.step(input_prompts=batch["query"])
    except Exception as exc:
        error_count += 1
        step_errors.append((batch_idx, repr(exc)))
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(f"Skipping batch {batch_idx}: {exc!r}")
        continue

    # Log stats
    ppo_trainer.log_stats(stats, batch, scores, columns_to_log=["query"])

if error_count:
    print(
        f"Skipped {error_count} of {len(ppo_trainer.dataloader)} batches "
        "due to errors; see `step_errors` for details."
    )


  0%|          | 0/194 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_to