In [1]:
import os
import yaml
import trlx
import torch

from datasets import load_dataset
from transformers import pipeline
from typing import List
from trlx.data.configs import TRLConfig
# Load the default PPO hyperparameters from the repo's config file.
# The context manager closes the file handle once parsing is done
# (a bare open() inside safe_load would leave it open).
with open("configs/ppo_config.yml") as config_file:
    default_config = yaml.safe_load(config_file)

NOTE: Redirects are currently not supported in Windows or MacOs.


Next, let's create a reward function. It takes the LM's outputs and returns a score for each one; this score is what the RL algorithm optimizes. In this case we want positive movie reviews, so we use a sentiment classifier trained on human annotations as our reward model.

In [2]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# If the HF_TOKEN environment variable is set, its token is used and no
# interactive step is needed, so "Restart & Run All" is not blocked.
# If it is unset, login() falls back to its interactive prompt. In a notebook
# that prompt renders as an ipywidgets form (`%pip install ipywidgets` if missing).
login(token=os.environ.get("HF_TOKEN"))

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
# Run the reward model on the GPU when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The `lvwerra/distilbert-imdb` sentiment classifier serves as the reward model.
# top_k=2 requests the top two labels per input (presumably both POSITIVE and
# NEGATIVE for this binary classifier), so a POSITIVE score is always available.
sentiment_fn = pipeline(
    task="sentiment-analysis",
    model="lvwerra/distilbert-imdb",
    top_k=2,
    truncation=True,
    batch_size=256,
    device=device,
)

def get_positive_score(scores):
    """Return the score that the sentiment pipeline assigned to the POSITIVE label.

    Args:
        scores: The pipeline output for a single input: a list of dicts, each
            with a "label" key and a "score" key.

    Returns:
        The "score" of the entry whose "label" is "POSITIVE".

    Raises:
        KeyError: If no entry is labelled "POSITIVE".
    """
    # Look fields up by key instead of unpacking `x.values()` positionally.
    # The positional version silently relied on each dict listing "label"
    # before "score".
    label_to_score = {entry["label"]: entry["score"] for entry in scores}
    return label_to_score["POSITIVE"]


def reward_fn(samples: List[str]) -> List[float]:
    """Score each generated sample by how positive the sentiment classifier finds it.

    Args:
        samples: Generated texts to be scored.

    Returns:
        One reward per sample, in input order. Each reward is the classifier's
        POSITIVE-label score for that sample.
    """
    classifier_outputs = sentiment_fn(samples)
    return [get_positive_score(output) for output in classifier_outputs]


Next, let's load the IMDB dataset and get some text for the language model to complete. For each review, only the first four words are given, and the LM must learn to complete the review.

In [4]:
# Each prompt is the opening words of one IMDB review (train and test splits combined).
N_PROMPT_WORDS = 4

imdb = load_dataset("imdb", split="train+test")

prompts = []
for review_text in imdb["text"]:
    opening_words = review_text.split()[:N_PROMPT_WORDS]
    prompts.append(" ".join(opening_words))

Found cached dataset imdb (~/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


With a config, reward model and the dataset ready, you are now set to train a language model with RL!

Some evaluation prompts are also passed in, and the reward model's score for your trained LM's generations on them will be reported.

Hopefully, at the end of RL training, it's high!

In [7]:
# Overrides for the values loaded from configs/ppo_config.yml.
# Left empty, so training uses those defaults unchanged (see the config display below).
hparams = dict()

config = TRLConfig.update(
    default_config,
    hparams,
)

In [11]:
# Display the fully resolved TRL configuration (method, model, optimizer, scheduler, train).
config

TRLConfig(method=PPOConfig(name='ppoconfig', ppo_epochs=4, num_rollouts=128, chunk_size=128, init_kl_coef=0.05, target=6, horizon=10000, gamma=1, lam=0.95, cliprange=0.2, cliprange_value=0.2, vf_coef=1, scale_reward=False, ref_mean=None, ref_std=None, cliprange_reward=10, gen_kwargs={'max_new_tokens': 40, 'top_k': 0, 'top_p': 1.0, 'do_sample': True}), model=ModelConfig(model_type='AcceleratePPOModel', model_path='lvwerra/gpt2-imdb', tokenizer_path='gpt2', num_layers_unfrozen=2), optimizer=OptimizerConfig(name='adamw', kwargs={'lr': 0.0001, 'betas': [0.9, 0.95], 'eps': 1e-08, 'weight_decay': 1e-06}), scheduler=SchedulerConfig(name='cosine_annealing', kwargs={'T_max': 10000, 'eta_min': 0.0001}), train=TrainConfig(total_steps=10000, seq_length=1024, epochs=100, batch_size=128, checkpoint_interval=10000, eval_interval=100, pipeline='PromptPipeline', orchestrator='PPOOrchestrator', project_name='trlx', entity_name=None, checkpoint_dir='ckpts', rollout_logging_dir=None, seed=42))

In [13]:
# Random seed for this run, read from the train section of the resolved config.
config.train.seed

42

In [14]:


# Launch PPO training. All IMDB prompts are used for rollouts.
# A single evaluation prompt is repeated 64 times, so the reward-model score of
# the sampled completions can be reported as training progresses.
# NOTE(review): the captured output shows this cell prompting for a Weights &
# Biases API key. Log in beforehand (or set WANDB_MODE=disabled) if the
# notebook must run non-interactively.
EVAL_PROMPTS = ["I don't know much about Hungarian underground"] * 64

model = trlx.train(
    reward_fn=reward_fn,
    prompts=prompts,
    eval_prompts=EVAL_PROMPTS,
    config=config,
)

Downloading:   0%|          | 0.00/577 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: