In [1]:

import datasets
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from ray.rllib.examples.rlhf.ppo_ft.rlhf_env import RLHFEnv, generate_response

  from .autonotebook import tqdm as notebook_tqdm


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


In [2]:
import time
from pprint import pprint

import torch
import numpy as np

In [3]:
# Load the pretrained "gpt2" checkpoint as a GPT2LMHeadModel; this object is
# later passed to generate_response() as the model that produces responses.
model = GPT2LMHeadModel.from_pretrained("gpt2")

In [4]:
# Bare expression: the notebook renders the model's nested module tree.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [5]:
# Load the "train" split of the "yizhongw/self_instruct" dataset via
# datasets.load_dataset; the trailing expression displays its features and row count.
ds = datasets.load_dataset("yizhongw/self_instruct", split="train")
ds

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 82612
})

In [6]:
# Peek at the first record: a dict with 'prompt' and 'completion' strings.
ds[0]

{'prompt': 'Make a list of 10 ways to help students improve their study skills.\n\nOutput:',
 'completion': " 1. Make a schedule for studying and stick to it.\n2. Study in the same place every time.\n3. Set goals for yourself.\n4. Take breaks when you need them.\n5. Don't cram before an exam.\n6. Get enough sleep.\n7. Eat healthy food.\n8. Exercise regularly.\n9. Find a study partner.\n10. Reward yourself after completing a task."}

In [7]:
# Tokenizer for the same "gpt2" checkpoint name used for the model; it is also
# the source of eos_token_id when generating a response later on.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

In [8]:
# Tokenize the second dataset prompt into NumPy arrays (return_tensors="np").
# The result holds 'input_ids' and 'attention_mask', each with a leading axis of size 1.
prompt_tokens = tokenizer(ds[1]["prompt"], return_tensors="np")
prompt_tokens

{'input_ids': array([[25714,    25,  9938,   503,   644,   389,   262,  1994, 10233,
          287,   262,  3188,    30,  5072,   366, 26652,   352,  1600,
          366, 26652,   362,  1600,  2644,   837,   366, 26652,   299,
         1911,   198,   198,   464,  1578,  1829,   468, 23019,   422,
          262,  6342, 13963, 12729,    13,   628]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [9]:
# Build the RLHF environment from a plain config dict. The keyword form of
# dict() yields the same mapping, in the same key order, as a dict literal.
env = RLHFEnv(
    dict(
        # Tokenizer and SFT model paths both name the stock "gpt2" checkpoint.
        tokenizer_path="gpt2",
        sft_model_path="gpt2",
        # Same dataset name and split that this notebook loads and inspects directly.
        prompt_dataset_path="yizhongw/self_instruct",
        prompt_dataset_split="train",
        # NOTE(review): the reward printed by the step cell is consistent with
        # r_align - kl_coeff * r_kl; confirm against the RLHFEnv implementation.
        kl_coeff=0.2,
        # Read back as env.max_generation_length to cap response generation.
        max_generation_length=50,
    )
)

In [10]:
# Reset the environment with an explicit seed and print the observation: a dict
# of NumPy 'input_ids' / 'attention_mask' arrays. The second return value of
# reset() is ignored.
obs, _ = env.reset(seed=52)
print(obs)

{'input_ids': array([[15056,   281,  7177,   286, 37014,    11,  3551,   257,  2163,
          326,  5860,  2081,   611,   612,  7160,   281,  6376,  1312,
          884,   326,   790,  5002,   287,   262,  7177,   468,   663,
         1988,  3220,   416,   530,   618,   356,  1445,   340,  2651,
         1312,  4113,    13,   628]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [11]:
# Convert the NumPy observation arrays into torch tensors (torch.tensor copies
# the data). The trailing expression displays the prompt token ids.
input_ids = torch.tensor(obs['input_ids'])
attention_mask = torch.tensor(obs['attention_mask'])
input_ids

tensor([[15056,   281,  7177,   286, 37014,    11,  3551,   257,  2163,   326,
          5860,  2081,   611,   612,  7160,   281,  6376,  1312,   884,   326,
           790,  5002,   287,   262,  7177,   468,   663,  1988,  3220,   416,
           530,   618,   356,  1445,   340,  2651,  1312,  4113,    13,   628]])

In [12]:
# One full environment interaction: reset, generate a response with the model,
# package it as an action dict, and step the environment.
# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring intervals (time.time() can jump if the system clock is adjusted).
# The timed span covers reset + generation + step.
s = time.perf_counter()
# reset (explicit seed; the second return value is unused)
obs, _ = env.reset(seed=62)
# generate response from the actor LLM
out = generate_response(
    model,
    input_ids=torch.tensor(obs["input_ids"]),
    max_length=env.max_generation_length,
    eos_token_id=tokenizer.eos_token_id,
)

# construct the action
n_generated_tokens = out["n_generated_tokens"]
n_input_tokens = out["n_input_tokens"]
# NOTE(review): this slices the FIRST axis of out["sequence"]. If that tensor
# carries a leading batch axis of size 1, the slice keeps the whole
# (prompt + response) sequence, which would match the masks below that span
# n_input_tokens + n_generated_tokens positions. Confirm against
# generate_response before relying on the name "generated_tokens".
generated_tokens = out["sequence"][-n_generated_tokens:]
action = {
    "sequence": generated_tokens.numpy(),
    # 0 over the prompt positions, 1 over the generated positions.
    "response_mask": np.array([[0] * n_input_tokens + [1] * n_generated_tokens]),
    "probs": out["probs"].numpy(),
    # Every position (prompt and response) is marked 1.
    "attention_mask": np.array([[1] * (n_input_tokens + n_generated_tokens)]),
}

# pass the action in
next_obs, reward, terminated, truncated, info = env.step(action)
print(f"1 step took {time.perf_counter() - s} seconds")
print(f"Reward: {reward}")
print(f"Info: {info}")

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]
1 step took 7.08188009262085 seconds
Reward: -2.531736660003662
Info: {'r_align': -1.0, 'r_kl': 7.6586833000183105, 'n_response_tokens': 50}


In [13]:
# Draw a random action from the environment's action space and print each of
# its four fields on its own line. Note that this rebinds `action`, replacing
# the hand-built action from the previous cell.
action = env.action_space.sample()
print(
    action["attention_mask"],
    action["response_mask"],
    action["probs"],
    action["sequence"],
    sep="\n",
)

[1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 