In [1]:
import ray
from ray import air, tune
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.rlhf.ppo_ft.rlhf_env import RLHFEnv
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

from ray.rllib.examples.rlhf.ppo_ft.rlhf_ppo_module import RLHFPPOTorchRLModule
from ray.rllib.examples.rlhf.ppo_ft.ppo_rlhf import PPORLHF
from ray.rllib.policy.sample_batch import SampleBatch


  from .autonotebook import tqdm as notebook_tqdm


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


In [2]:
env_config = {
    "tokenizer_path": "gpt2",
    "sft_model_path": "gpt2",
    "prompt_dataset_path": "yizhongw/self_instruct",
    "prompt_dataset_split": "train",
    "kl_coeff": 0.2,
    "max_generation_length": 50,
}

In [3]:

env_creator = lambda config: RLHFEnv(config)
tune.register_env("RLHFEnv", env_creator)
config = (
    PPOConfig(algo_class=PPORLHF)
    .framework("torch")
    .environment(
        "RLHFEnv", 
        env_config=env_config,
        # observation_space=env.observation_space,
        # action_space=env.action_space,
        disable_env_checking=True,
    )
    .rl_module(
        _enable_rl_module_api=True,
        rl_module_spec=SingleAgentRLModuleSpec(RLHFPPOTorchRLModule),
    )
    .training(
        num_sgd_iter=1,
        sgd_minibatch_size=1,
        train_batch_size=1,
        _enable_learner_api=True
    )
    .rollouts(
        num_rollout_workers=0
    )
    .experimental(
        _disable_preprocessor_api=True,
        _disable_initialize_loss_from_dummy_batch=True,
    )
)



In [4]:
algo = config.build()


No config specified, defaulting to: self_instruct/self_instruct
Found cached dataset self_instruct (/Users/kourosh/.cache/huggingface/datasets/yizhongw___self_instruct/self_instruct/1.0.0/11093735ceb03802310b2f412253585f7bd1cc0435787541e37d4d2b4cca4148)
[2023-04-10 14:08:06] [Ray Tune] INFO ray.tune.trainable.trainable::Trainable.setup took 21.056 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [5]:
module = algo.rlhf_module
module

RLHFPPOTorchRLModule(
  (actor): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
 

In [6]:
batch = algo.sampler.sample(batch_size=2)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


In [7]:
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
batch.set_get_interceptor(convert_to_torch_tensor)

In [8]:
from ray.rllib.policy.sample_batch import SampleBatch

In [9]:
batch[SampleBatch.ACTIONS].keys()

dict_keys(['sequence', 'response_mask', 'probs', 'attention_mask'])

In [10]:
batch[SampleBatch.ACTIONS]["response_mask"].size()

torch.Size([2, 68])

In [11]:
foo = module.critic.base(
    input_ids=batch[SampleBatch.ACTIONS]["sequence"],
    attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"],
)

In [12]:
foo["last_hidden_state"].size()


torch.Size([2, 68, 768])

In [13]:
values = module.critic.trunk(foo["last_hidden_state"][:, -1]).squeeze(-1)
values.shape

torch.Size([2])

In [15]:

vf_out = module.critic(
    input_ids=batch[SampleBatch.ACTIONS]["sequence"],
    attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"],
)

In [16]:
vf_out

tensor([-1.2531, -5.9997], grad_fn=<SqueezeBackward1>)

In [39]:
input_ids = batch[SampleBatch.ACTIONS]["sequence"].clone()
actor_out = module.actor(
    input_ids=input_ids,
    attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"],
)
actor_out.logits[0, -2, 0]

tensor(30.6448, grad_fn=<SelectBackward0>)

In [40]:
from ray.rllib.models.torch.torch_distributions import TorchCategorical


In [41]:
actor_logits = actor_out.logits # (batch_size, seq_len, vocab_size)
dist1 = TorchCategorical.from_logits(actor_logits)
dist1.sample().shape

torch.Size([2, 68])

In [42]:
input_ids = batch[SampleBatch.ACTIONS]["sequence"].clone()
input_ids[0, -1] = 1 # this should change the logit[0, -1, 0] value
actor_out = module.actor(
    input_ids=input_ids,
    attention_mask=batch[SampleBatch.ACTIONS]["attention_mask"],
)
actor_out.logits[0, -1, 0]

tensor(46.9457, grad_fn=<SelectBackward0>)

In [44]:
actor_logits = actor_out.logits # (batch_size, seq_len, vocab_size)
dist2 = TorchCategorical.from_logits(actor_logits)

foo = dist1.kl(dist2)
foo.shape

torch.Size([2, 68])

In [58]:
from transformers import GPT2Tokenizer, AutoModelForCausalLM

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Let's assume your two sequences are:
sequence_1 = "This is a longer sequence."
sequence_2 = "Short sequence."

tokenizer.pad_token = tokenizer.eos_token
# Tokenize and pad the sequences
tokens = tokenizer([sequence_1, sequence_2], padding=True, padding_side="left", return_tensors="pt")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

# Forward pass
output = model(input_ids, attention_mask=attention_mask)

# Get the logits
logits = output.logits

# Check the logits for both sequences
logits_sequence_2 = logits[1, :sum(attention_mask[1]), :]


Keyword arguments {'padding_side': 'left'} not recognized.
Keyword arguments {'padding_side': 'left'} not recognized.


In [59]:
input_ids

tensor([[ 1212,   318,   257,  2392,  8379,    13],
        [16438,  8379,    13, 50256, 50256, 50256]])

In [60]:
attention_mask

tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 0]])

In [56]:
logits_sequence_2

tensor([[-29.3949, -28.4463, -31.5550,  ..., -36.6213, -36.1419, -29.0593],
        [-83.4044, -80.8802, -86.6023,  ..., -92.9075, -92.1946, -83.7211],
        [-90.2337, -89.1662, -88.8223,  ..., -96.3760, -95.5851, -83.0652]],
       grad_fn=<SliceBackward0>)

In [57]:
# compare to when we pass the shorter sequence
tokens = tokenizer([sequence_2], return_tensors="pt")

input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

print("input_ids: ", input_ids)
print("attention_mask: ", attention_mask)

# Forward pass
output = model(input_ids, attention_mask=attention_mask)

# Get the logits
logits = output.logits

# Check the logits for both sequences
logits_sequence_2_2 = logits[0, :len(sequence_2), :]
print(logits_sequence_2_2)

input_ids:  tensor([[16438,  8379,    13]])
attention_mask:  tensor([[1, 1, 1]])
tensor([[-29.3949, -28.4463, -31.5550,  ..., -36.6213, -36.1419, -29.0593],
        [-83.4044, -80.8802, -86.6023,  ..., -92.9075, -92.1946, -83.7211],
        [-90.2338, -89.1662, -88.8224,  ..., -96.3761, -95.5852, -83.0652]],
       grad_fn=<SliceBackward0>)
