In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader
from torch import optim

from instruct_goose import Agent, RewardModel, RLHFTrainer, RLHFConfig, create_reference_model

### Step 1: Create RL-based Language model Agent and Reward Model

In [4]:
model_base = AutoModelForCausalLM.from_pretrained("gpt2")
reward_model = RewardModel("gpt2")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

In [6]:
model = Agent(model_base)
ref_model = create_reference_model(model)

### Step 1: Tokenize the dataset

In [7]:
dataset = load_dataset("imdb", split="train")
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

Found cached dataset imdb (/Users/education/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


### Step 2: Rollout

In [8]:
max_new_tokens = 20
generation_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": max_new_tokens
}
config = RLHFConfig()
trainer = RLHFTrainer(model, ref_model, config)
optimizer = optim.SGD(model.parameters(), lr=1e-3)

In [11]:
for batch in train_dataloader:
    inputs = tokenizer(batch["text"], padding=True, truncation=True, return_tensors="pt")
    responses = model.generate(
        inputs["input_ids"], attention_mask=inputs["attention_mask"],
        **generation_kwargs
    )
    # extract the generated text
    responses = responses[:, -max_new_tokens:]
    
    with torch.no_grad():
        text_input_ids = torch.stack([torch.concat([q, r]) for q, r in zip(inputs["input_ids"], responses)], dim=0)
        texts = tokenizer.batch_decode(text_input_ids, skip_special_tokens=True)
        # evaluate from the reward model
        rewards = reward_model(texts)
    
    loss = trainer.compute_loss(inputs["input_ids"], responses, rewards)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss={loss}")

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


KeyboardInterrupt: 