In [1]:
# install the TRL library 
!pip install trl==0.0.3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting trl==0.0.3
  Downloading trl-0.0.3-py3-none-any.whl (15 kB)
Collecting transformers==4.3.2
  Downloading transformers-4.3.2-py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 4.7 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 65.4 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 84.0 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=895260 sha256=78f5f26d853df52133d8759c2e38940b8840fe6f02a73928360dc8c5aa4700c3
  Stored in directory: /root/.cache/pip/

In [2]:
!pip install datasets transformers wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 4.6 MB/s 
Collecting wandb
  Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 92.3 MB/s 
Collecting xxhash
  Downloading xxhash-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 81.7 MB/s 
[?25hCollecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 89.6 MB/s 
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 86.5 MB/s 
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1
  Downlo

In [3]:
# auto-reload external modules that might've changed 
%load_ext autoreload
%autoreload 2

In [4]:
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
from random import choices
import matplotlib.pyplot as plt
tqdm.pandas()

from datasets import load_dataset

from transformers import GPT2Tokenizer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import build_bert_batch_from_txt

In [5]:
config = {
    "lm_name": "lvwerra/gpt2-imdb",       # LLM
    "ref_lm_name": "lvwerra/gpt2-imdb",   # reference LLM (same as above)
    "cls_model_name": "lvwerra/distilbert-imdb",  # BERT classification model 
    "tk_name": "gpt2",  # tokenizer name
    "steps": 51200,
    "batch_size": 256,
    "forward_batch_size": 16,
    "ppo_epochs": 4,   
    "txt_in_len": 5,
    "txt_out_len": 20,
    "lr": 1.41e-5,
    "init_kl_coef": 0.2,
    "target": 6,
    "horizon": 10000,
    "gamma": 1,
    "lam": 0.95,
    "cliprange": .2,
    "cliprange_value": .2,
    "vf_coef": .1, 
    "seed": 1,
}

In [6]:
np.random.seed(config['seed'])

The `gpt2_imdb` model was fine-tuned on the IMDB dataset for 1 epoch with the huggingface script (no special settings). The other parameters are mostly taken from the original paper ["Fine-Tuning Language Models from Human Preferences"](https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models).

In [7]:
# log all metrics during training 
wandb.init(name='long-response', project='gpt2-ctrl', config=config)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [8]:
# load imdb with datasets
dataset = load_dataset('imdb', split='train')
dataset = dataset.rename_columns({'text': 'review', 'label': 'sentiment'})
dataset.set_format('pandas')
df = dataset[:]

# make sure the reviews are long enough (>500)
df = df.loc[df['review'].str.len() > 500]

# but still less than 1000
df['review'] = df['review'].apply(lambda x: x[:1000])

df.tail()

Downloading builder script:   0%|          | 0.00/4.31k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.59k [00:00<?, ?B/s]

Downloading and preparing dataset imdb/plain_text to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1...


Downloading data:   0%|          | 0.00/84.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Dataset imdb downloaded and prepared to /root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1. Subsequent calls will reuse this data.


Unnamed: 0,review,sentiment
24991,I'd always wanted David Duchovney to go into t...,1
24995,A hit at the time but now better categorised a...,1
24996,I love this movie like no other. Another time ...,1
24997,This film and it's sequel Barry Mckenzie holds...,1
24998,'The Adventures Of Barry McKenzie' started lif...,1


In [9]:
# load in the BERT sentiment classifier & tokenizer
sentiment_model = AutoModelForSequenceClassification.from_pretrained(config['cls_model_name'])
sentiment_tokenizer = AutoTokenizer.from_pretrained(config['cls_model_name'])

Downloading:   0%|          | 0.00/735 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/333 [00:00<?, ?B/s]

The model outputs are the logits for the negative and positive class. We will use the logits for the positive class as the reward signal for fine-tuning the GPT2 LLM.

In [10]:
text = 'this movie gave me crippling depression and pushed me ever so slightly to the precipice.'
output = sentiment_model.forward(sentiment_tokenizer.encode(text, return_tensors="pt")) # pt returns pytorch tensors
output

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.5268, -0.8633]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [11]:
text = 'movies like this should not be made. what a colossal waste of time!'
output = sentiment_model.forward(sentiment_tokenizer.encode(text, return_tensors="pt"))
output

SequenceClassifierOutput(loss=None, logits=tensor([[ 2.5744, -2.9026]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [12]:
text = "watching morbius has been the single greatest event in my life so far. I was in a state of pure bliss from start to finish."
output = sentiment_model.forward(sentiment_tokenizer.encode(text, return_tensors="pt"))
output

SequenceClassifierOutput(loss=None, logits=tensor([[-2.1735,  2.4532]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [13]:
output[0]

tensor([[-2.1735,  2.4532]], grad_fn=<AddmmBackward0>)

In [14]:
# the positive logit will be the reward signal
output[0][0, 1]

tensor(2.4532, grad_fn=<SelectBackward0>)

Load the pre-trained GPT2 models, one for fine-tuning and another for reference. This is so that we can calculate KL-divergence between the models as we fine-tune. This score will be incorporated into the reward signal in the PPO training so that the LLM being fine-tuned doesn't deviate too much from the referene LLM. 

In [15]:
gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['lm_name'])
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['ref_lm_name'])
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(config['tk_name'])

Downloading:   0%|          | 0.00/577 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548M [00:00<?, ?B/s]

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at lvwerra/gpt2-imdb and are newly initialized: ['transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.7.attn.masked_bias', 'transformer.h.8.attn.masked_bias', 'transformer.h.9.attn.masked_bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.masked_bias', 'v_head.summary.weight', 'v_head.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at lvwerra/gpt2-imdb and are newly initialized: ['transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.3.attn.masked_

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [17]:
# move models to device 
_ = gpt2_model.to(device)
_ = gpt2_model_ref.to(device)
_ = sentiment_model.to(device)

In [18]:
# log the gradients and weights of the model during training
wandb.watch(gpt2_model, log='all')

[]

In [19]:
# tokenize reviews and clilp to max text len
df['tokens'] = df['review'].progress_apply(lambda x: gpt2_tokenizer.encode(' '+x, return_tensors="pt").to(device)[0, :config['txt_in_len']])




  0%|          | 0/22578 [00:00<?, ?it/s][A[A[A


  0%|          | 41/22578 [00:00<00:55, 402.74it/s][A[A[A


  0%|          | 83/22578 [00:00<00:54, 411.38it/s][A[A[A


  1%|          | 133/22578 [00:00<00:49, 451.58it/s][A[A[A


  1%|          | 184/22578 [00:00<00:47, 473.04it/s][A[A[A


  1%|          | 242/22578 [00:00<00:43, 510.25it/s][A[A[A


  1%|▏         | 296/22578 [00:00<00:42, 518.87it/s][A[A[A


  2%|▏         | 351/22578 [00:00<00:42, 528.73it/s][A[A[A


  2%|▏         | 410/22578 [00:00<00:40, 546.43it/s][A[A[A


  2%|▏         | 466/22578 [00:00<00:40, 549.95it/s][A[A[A


  2%|▏         | 529/22578 [00:01<00:38, 572.04it/s][A[A[A


  3%|▎         | 592/22578 [00:01<00:37, 586.53it/s][A[A[A


  3%|▎         | 656/22578 [00:01<00:36, 601.47it/s][A[A[A


  3%|▎         | 718/22578 [00:01<00:36, 604.96it/s][A[A[A


  3%|▎         | 782/22578 [00:01<00:35, 614.58it/s][A[A[A


  4%|▍         | 847/22578 [00:01<00:34, 624.09it/

In [20]:
# as well as detokenize into queries for display
df['query'] = df['tokens'].progress_apply(lambda x: gpt2_tokenizer.decode(x))

100%|██████████| 22578/22578 [00:04<00:00, 4565.70it/s] 


In [21]:
df['query'].head()

0                     I rented I AM C
1                      "I Am Curious:
2             If only to avoid making
3     This film was probably inspired
4                 Oh, brother...after
Name: query, dtype: object

Each query needs to be appended with the control token to signal to the model what target sentiment we aim to generate. 

In [22]:
ctrl_str = ['[negative]', '[neutral]', '[positive]']
ctrl_tokens = dict((s, gpt2_tokenizer.encode(s, return_tensors="pt").squeeze().to(device)) for s in ctrl_str)

In [23]:
ctrl_tokens

{'[negative]': tensor([   58, 31591,    60], device='cuda:0'),
 '[neutral]': tensor([   58, 29797,    60], device='cuda:0'),
 '[positive]': tensor([   58, 24561,    60], device='cuda:0')}

We define a reward function that takes in logits and scales it according to

In [29]:
def positive_logit_to_reward(logit, task):
  """
  Take the positive sentiment logit and scale it for the task.
    task [negative]: reward = -logit
    task [neutral]: reward = -2*abs(logit)+4
    task [positive]: reward = logit
  """
  for i in range(len(logit)):
    if task[i] == '[negative]':
      logit[i] = -logit[i]
    elif task[i] == '[neutral]':
      logit[i] = -2 * torch.abs(logit[i]) + 4
    elif task[i] == '[positive]':
      pass 
    
    else:
      raise ValueError('task has to be in [0, 1, 2]')
    
    return logit

In [30]:
# the positive logit is passed in 3 times, for each sentiment 
positive_logit_to_reward(torch.Tensor([4,4,4]), ctrl_str)

tensor([-4.,  4.,  4.])

The training loop consists of the following steps: 
*   Get a batch of queries and create random controls 
*   Get query responses from the LLM 
*   Join query and responses and tokenize for BERT input 
*   Get sentiments for query / response pairs from BERT
*   Optimize policy with PPO (query, response, reward) triplet
*   Log all training statistics 


In [None]:
# setup PPO trainer 
ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, **config)
fbs = config['forward_batch_size']

for epoch in tqdm(range(int(np.ceil(config['steps']/config['batch_size'])))):
  # empty cache for every run 
  torch.cuda.empty_cache()

  logs = dict()
  game_data = dict()
  timing = dict()
  t0 = time.time()

  # get a batch from the dataset and annotate tasks 
  df_batch = df.sample(config['batch_size'])
  task_list = choices(ctrl_str, k=config['batch_size']) # pick list of batch_size ctr_str indices at random
  task_tensors = torch.stack([ctrl_tokens[t] for t in task_list])
  query_list = df_batch['query'].tolist()
  game_data['query'] = [t + q for t, q in zip(task_list, query_list)]

  query_tensors = torch.stack(df_batch['tokens'].tolist())  # tokenized queries 
  query_tensors = torch.cat((task_tensors, query_tensors), axis=1)  # concatenate tokenized control string with tokenized queries

  # get response from GPT2
  t = time.time()
  response_tensors = []
  # feed in queries with lower batch size (fbs), so as to avoid out of memory error 
  for i in range(int(config['batch_size']/fbs)):
    response = respond_to_batch(gpt2_model, query_tensors[i*fbs:(i+1)*fbs], txt_len=config['txt_out_len'])
    response_tensors.append(response)
  response_tensors = torch.cat(response_tensors)  # list is concatenated to tensors 
  
  game_data['response'] = [gpt2_tokenizer.decode(response_tensors[i, :]) for i in range(config['batch_size'])]
  timing['time/gpt2_response'] = time.time() - t

  # tokenize text for sentiment analysis 
  t = time.time()
  texts = [q + r for q,r in zip(query_list, game_data['response'])] # query + response for BERT input 
  sentiment_inputs, attention_masks = build_bert_batch_from_txt(texts, sentiment_tokenizer, device)
  timing['time/build_bert_input_sentiment'] = time.time() - t

  # get sentiment score 
  t = time.time()
  positive_logits = []
  for i in range(int(config['batch_size']/fbs)):
    res = sentiment_model.forward(sentiment_inputs[i*fbs:(i+1)*fbs], 
                                  attention_masks[i*fbs:(i+1)*fbs])[0][:, 1].detach()
    positive_logits.append(res)

  rewards = positive_logit_to_reward(torch.cat(positive_logits), task_list)
  timing['time/get_sentiment_preds'] = time.time() - t

  # run PPO training 
  t = time.time()
  stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
  timing['time/ppo_optimization'] = time.time() - t

  # log everything 
  timing['time/epoch'] = time.time()-t0
  table_rows = [list(r) for r in zip(game_data['query'], game_data['response'], rewards.cpu().tolist())]
  logs.update({'game_log':wandb.Table(
      columns=['query', 'response', 'reward'],
      rows=table_rows)})
  
  logs.update(timing)
  logs.update(stats)

  logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy()
  logs['env/reward_std'] = torch.std(rewards).cpu().numpy()
  logs['env/reward_dist'] = rewards.cpu().numpy()

  for ctrl_s in ctrl_str:
      key = 'env/reward_'+ctrl_s.strip('[]')
      logs[key] = np.mean([r for r, t in zip(logs['env/reward_dist'], task_list) if t==ctrl_s])

  wandb.log(logs)

  0%|          | 0/200 [00:00<?, ?it/s]