
## TRL - RLHF first example


In [1]:

## Python >= 3.8

## !pip install transformers
## !pip install accelerate
## !pip install twine
## !pip install datasets
## !pip install tyro


In [2]:

import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer



## 1. load a pretrained model


In [4]:


model     = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

tokenizer           = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token




## 2. initialize trainer


In [6]:


ppo_config  = {"batch_size": 1}
config      = PPOConfig(**ppo_config)

ppo_trainer = PPOTrainer(
                     config, 
                     model, 
                     model_ref, 
                     tokenizer
)




## 3. encode a query


In [7]:


query_txt    = "This morning I went to the "

query_tensor = tokenizer.encode(
                      query_txt, 
                      return_tensors="pt"
).to( model.pretrained_model.device )




## 4. generate model response


In [8]:


generation_kwargs = {
    "min_length":  -1,
    "top_k":      0.0,
    "top_p":      1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}


In [9]:

response_tensor = ppo_trainer.generate(
                     [item for item in query_tensor], 
                     return_prompt=False, 
                     **generation_kwargs
)

response_txt    = tokenizer.decode( response_tensor[0] )



In [10]:

response_txt


'\xa0The Mount administered service on the 100-meter Passositime which has been providing a hugely effect'


## 5. define a reward for response

* this could be any reward such as human feedback or output from another model)


In [11]:


reward = [ torch.tensor(1.0, device=model.pretrained_model.device) ]

reward



[tensor(1., device='mps:0')]


## 6. train model with ppo


In [12]:

train_stats = ppo_trainer.step(
                    [ query_tensor[0] ], 
                    [ response_tensor[0] ], 
                    reward
)


In [13]:

train_stats


{'objective/kl': 0.0,
 'objective/kl_dist': 0.0,
 'objective/logprobs': array([[ -8.146125  ,  -2.2490737 ,  -2.6608405 ,  -0.63042766,
          -1.6609763 ,  -9.227745  ,  -1.4445881 ,  -6.318651  ,
          -8.213805  , -14.99595   ,  -5.460775  ,  -3.3287663 ,
          -1.4201152 ,  -7.8896565 ,  -2.7351935 ,  -2.4551308 ,
          -8.23317   , -11.903272  ,  -6.710974  ,  -4.9565077 ,
          -3.4219246 ,  -1.2996787 ,  -6.0018473 ,  -2.785101  ,
          -9.123416  ,  -9.531232  ]], dtype=float32),
 'objective/ref_logprobs': array([[ -8.146125  ,  -2.2490737 ,  -2.6608405 ,  -0.63042766,
          -1.6609763 ,  -9.227745  ,  -1.4445881 ,  -6.318651  ,
          -8.213805  , -14.99595   ,  -5.460775  ,  -3.3287663 ,
          -1.4201152 ,  -7.8896565 ,  -2.7351935 ,  -2.4551308 ,
          -8.23317   , -11.903272  ,  -6.710974  ,  -4.9565077 ,
          -3.4219246 ,  -1.2996787 ,  -6.0018473 ,  -2.785101  ,
          -9.123416  ,  -9.531232  ]], dtype=float32),
 'objective/k