In [1]:

# Python >= 3.8


In [2]:

!pip install transformers
!pip install wandb
!pip install trl
!pip install pandas
!pip install datasets
!pip install accelerate
!pip install tyro
!pip install nltk -U




In [3]:

import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler


In [4]:

config = PPOConfig(
    model_name    = "gpt2",
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True,
         "function_to_apply": "none",
         "batch_size": 16
}


In [5]:

## wandb.init()

wandb.init(mode="disabled")
os.environ['WANDB_DISABLED'] = 'true'



## Load IMDB dataset

The IMDB dataset contains 50k movie review annotated with "positive"/"negative" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the LengthSampler.



## Visualize details of dataset


In [6]:

dataset_name="go_emotions"


In [7]:

ds = load_dataset(dataset_name, split="train")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [8]:

ds


Dataset({
    features: ['text', 'labels', 'id'],
    num_rows: 43410
})

In [9]:

ds[15:18]


{'text': ['Shit, I guess I accidentally bought a Pay-Per-View boxing match',
  'Thank you friend',
  'Fucking coward.'],
 'labels': [[3, 12], [15], [2]],
 'id': ['edivtm3', 'eeqd04y', 'edk0z9k']}

In [10]:

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML


In [11]:

def show_random_elements(dataset, num_examples=20):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    picks = []

    for _ in range( num_examples ):

        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame( dataset[picks] )        ## indexing 10 picks

    print(df)
    print(dataset.features.items())

    for column, typ in dataset.features.items():
        print(column)
        print(typ)
        print(ClassLabel)
        ## The isinstance() function returns True if the specified object
        ## is of the specified type, otherwise False
        if isinstance(typ, ClassLabel):
            print("Hello")
            df[column] = df[column].transform(lambda i: typ.names[i])
            ## print(typ.names[i])

    display(HTML(df.to_html()))


In [12]:

show_random_elements(ds)


                                                 text    labels       id
0   Does it contain like 5 members like the real o...       [1]  ednshdu
1   You are alone and touching her...that is a gla...      [27]  eera0wb
2   except for call your mother. The undisputed be...      [27]  ee5fldn
3   I don’t think I’ve ever seen that many downvotes.       [9]  ee7i3ux
4                 What if your only job was recovery?       [7]  ed0swqf
5                           Hijab, keeps my ears warm      [27]  eeostkh
6   Big words coming from a guy who only talks abo...      [27]  eewicfv
7               How cute, the JV players are fighting       [0]  ee2hi7v
8   I swear of all the people that dislike artifac...      [11]  eer9s5f
9   True, but it all started under [NAME] and will...      [11]  eeng9p9
10                               hey uhhhh I love you      [18]  ednzjb6
11                 Haha "like a snake on a toddler".        [1]  edtsjzy
12                        This literally made my da

Unnamed: 0,text,labels,id
0,Does it contain like 5 members like the real one? lol.,[1],ednshdu
1,You are alone and touching her...that is a glaring sign!,[27],eera0wb
2,except for call your mother. The undisputed best bagels in dc.,[27],ee5fldn
3,I don’t think I’ve ever seen that many downvotes.,[9],ee7i3ux
4,What if your only job was recovery?,[7],ed0swqf
5,"Hijab, keeps my ears warm",[27],eeostkh
6,Big words coming from a guy who only talks about [NAME],[27],eewicfv
7,"How cute, the JV players are fighting",[0],ee2hi7v
8,"I swear of all the people that dislike artifact, the gwent crew is by far the most obnoxious.",[11],eer9s5f
9,"True, but it all started under [NAME] and will get even worse if Labour manages to get into government.",[11],eeng9p9


In [13]:

ds = ds.rename_columns({"text": "review"})



In [14]:

ds


Dataset({
    features: ['review', 'labels', 'id'],
    num_rows: 43410
})

In [15]:
ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

In [16]:
ds

Dataset({
    features: ['review', 'labels', 'id'],
    num_rows: 5
})

In [17]:

tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token


In [18]:

def tokenize( sample ):
    sample["input_ids"] = tokenizer.encode( sample["review"]    )[: 20]
    sample["query"]     = tokenizer.decode( sample["input_ids"] )
    return sample


ds = ds.map(tokenize, batched=False)
ds


Dataset({
    features: ['review', 'labels', 'id', 'input_ids', 'query'],
    num_rows: 5
})

In [19]:

ds[15:18]


{'review': [], 'labels': [], 'id': [], 'input_ids': [], 'query': []}



    
## Now this for actual RLHF  



In [20]:

def build_dataset(
         config,
         dataset_name="go_emotions",
         input_min_text_length=2,
         input_max_text_length=8
):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # load imdb with datasets

    ds = load_dataset(dataset_name, split="train")

    ds = ds.rename_columns({"text": "review"})
    # ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode( sample["review"]    )[: input_size()]
        sample["query"]     = tokenizer.decode( sample["input_ids"] )
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


In [21]:

dataset = build_dataset(config)


In [22]:

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])



## Load pre-trained GPT2 language models
We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model.


In [23]:

model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token


In [24]:

ppo_trainer = PPOTrainer(
                 config,
                 model,
                 ref_model,
                 tokenizer,
                 dataset=dataset,
                 data_collator=collator
)



## Load BERT classifier (Reward Function)

We load a BERT classifier fine-tuned on the IMDB dataset.


In [25]:

device = ppo_trainer.accelerator.device
device


device(type='cuda')

In [26]:

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

device

0

In [27]:

sentiment_pipe = pipeline("sentiment-analysis", model="SamLowe/roberta-base-go_emotions", device=device)



The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model.


In [28]:

text = "this movie was really bad!!"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'admiration', 'score': -4.414551258087158},
  {'label': 'amusement', 'score': -5.570497512817383},
  {'label': 'anger', 'score': -2.9189398288726807},
  {'label': 'annoyance', 'score': -1.3585652112960815},
  {'label': 'approval', 'score': -5.215704917907715},
  {'label': 'caring', 'score': -6.811775207519531},
  {'label': 'confusion', 'score': -4.697816848754883},
  {'label': 'curiosity', 'score': -5.354123592376709},
  {'label': 'desire', 'score': -5.924103260040283},
  {'label': 'disappointment', 'score': -0.5507287383079529},
  {'label': 'disapproval', 'score': -2.1392874717712402},
  {'label': 'disgust', 'score': -1.1301006078720093},
  {'label': 'embarrassment', 'score': -4.194786071777344},
  {'label': 'excitement', 'score': -6.156727313995361},
  {'label': 'fear', 'score': -4.558877944946289},
  {'label': 'gratitude', 'score': -6.890646457672119},
  {'label': 'grief', 'score': -6.52256965637207},
  {'label': 'joy', 'score': -6.479082107543945},
  {'label': 'love', '

In [29]:

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'admiration', 'score': 3.072791576385498},
  {'label': 'amusement', 'score': -5.656264305114746},
  {'label': 'anger', 'score': -6.556077480316162},
  {'label': 'annoyance', 'score': -5.903409957885742},
  {'label': 'approval', 'score': -3.417325258255005},
  {'label': 'caring', 'score': -6.516039848327637},
  {'label': 'confusion', 'score': -6.159371852874756},
  {'label': 'curiosity', 'score': -5.4741902351379395},
  {'label': 'desire', 'score': -6.721272945404053},
  {'label': 'disappointment', 'score': -6.090070724487305},
  {'label': 'disapproval', 'score': -5.804598331451416},
  {'label': 'disgust', 'score': -7.164876461029053},
  {'label': 'embarrassment', 'score': -8.019838333129883},
  {'label': 'excitement', 'score': -4.54598331451416},
  {'label': 'fear', 'score': -7.5312018394470215},
  {'label': 'gratitude', 'score': -5.093955993652344},
  {'label': 'grief', 'score': -7.936448574066162},
  {'label': 'joy', 'score': -4.590585708618164},
  {'label': 'love', 'scor



## Generation settings

For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length.


In [30]:

gen_kwargs = {
         "min_length":   -1,
         "top_k":       0.0,
         "top_p":       1.0,
         "do_sample":  True,
         "pad_token_id": tokenizer.eos_token_id
}



## Optimize model

### Training loop

The training loop consists of the following main steps:

* Get the query and responses from the policy network (GPT-2)
* Get sentiments for query/responses from BERT
* Optimize policy with PPO using the (query, response, reward) triplet


In [31]:

output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


In [32]:

generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [33]:

## ppo_trainer.config.steps = 100    ## 20,000
ppo_trainer.config.steps


20000

In [34]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    print(query_tensors)
    print(len(query_tensors))
    if epoch == 1:
        break


1it [00:00,  7.89it/s]

[tensor([2396, 4451,  340, 3073,  588,  339, 4423], device='cuda:0'), tensor([  45,  993,  345, 7684,  423], device='cuda:0'), tensor([  40,  561, 1826,  510], device='cuda:0'), tensor([   40,   466,   326,   351,   616, 13850], device='cuda:0'), tensor([   1, 3347], device='cuda:0'), tensor([   40,  1101,   523,  9675,   685, 20608], device='cuda:0'), tensor([ 464, 1109,  428, 2125,  470,  257,  685], device='cuda:0'), tensor([2504,  338,  281], device='cuda:0'), tensor([2504,  338,  644,  314, 2982, 2406], device='cuda:0'), tensor([361, 340, 338], device='cuda:0'), tensor([9275,  340, 1541,  475, 5875,  345], device='cuda:0'), tensor([22017,   326,   447,   247,    82,   257], device='cuda:0'), tensor([5779,  379, 1551], device='cuda:0'), tensor([42322,    11,   314,  7048], device='cuda:0'), tensor([ 1639,   389, 26329, 12214,    13,  6363], device='cuda:0'), tensor([28211,   508], device='cuda:0'), tensor([21129,  3505,  1612,   685], device='cuda:0'), tensor([2514, 1577,  683,  25




In [35]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)
    print(batch)
    print('*********************')
    print('*********************')
    print('*********************')
    print('*********************')
    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]
    print(batch)
    if epoch == 1:
        break


0it [00:00, ?it/s]

0
{'input_ids': [tensor([5195,  340, 1392, 4615,   13], device='cuda:0'), tensor([25082,   329,   683,    11,  1049], device='cuda:0'), tensor([10995,   314,  1101,  9675], device='cuda:0'), tensor([1219,   11,  339,  373,  287], device='cuda:0'), tensor([  40, 1654, 2911,  428], device='cuda:0'), tensor([   58, 20608,    60,   561], device='cuda:0'), tensor([11028, 12248,  1320,   447], device='cuda:0'), tensor([ 2396,    11,   345,   651, 27406,  1394], device='cuda:0'), tensor([ 1212,   318, 12876,   355,  3131], device='cuda:0'), tensor([33336, 22574], device='cuda:0'), tensor([10995,  1377,   257], device='cuda:0'), tensor([ 40, 836], device='cuda:0'), tensor([34784,    13,  5070,  8592,   389,  1541,  7818], device='cuda:0'), tensor([  40, 1101, 1297,  314,  804], device='cuda:0'), tensor([  40, 4236,  351, 2279,  345,  531,   13], device='cuda:0'), tensor([2061,  338,  510,  351,  326,  289], device='cuda:0'), tensor([1639,  804,  588,  685], device='cuda:0'), tensor([  47, 2433

1it [00:27, 27.39s/it]

1
{'input_ids': [tensor([ 817, 1381,  257,  922,  835,  284], device='cuda:0'), tensor([10262,   631,  3228], device='cuda:0'), tensor([ 43, 349,  13], device='cuda:0'), tensor([ 3987,   470,  6044,   546,   262, 17448], device='cuda:0'), tensor([  40,  991, 5465, 5181], device='cuda:0'), tensor([  32, 5007,  318, 5626], device='cuda:0'), tensor([9690,  582], device='cuda:0'), tensor([ 40, 765, 284], device='cuda:0'), tensor([1026,  468,  284, 2291,  257], device='cuda:0'), tensor([23307,  1310,  1165], device='cuda:0'), tensor([1212,  318,  523,   11], device='cuda:0'), tensor([   40,  1816, 14380,   329], device='cuda:0'), tensor([5297,  475,  644], device='cuda:0'), tensor([18243,   278], device='cuda:0'), tensor([5308, 1165,   13], device='cuda:0'), tensor([ 40, 481], device='cuda:0'), tensor([2990,  711, 7812], device='cuda:0'), tensor([10919,   257, 10195], device='cuda:0'), tensor([1026,  857], device='cuda:0'), tensor([1544, 3751,  703, 2562,  340,  318], device='cuda:0'), tens

1it [00:42, 42.96s/it]

{'input_ids': [tensor([ 817, 1381,  257,  922,  835,  284], device='cuda:0'), tensor([10262,   631,  3228], device='cuda:0'), tensor([ 43, 349,  13], device='cuda:0'), tensor([ 3987,   470,  6044,   546,   262, 17448], device='cuda:0'), tensor([  40,  991, 5465, 5181], device='cuda:0'), tensor([  32, 5007,  318, 5626], device='cuda:0'), tensor([9690,  582], device='cuda:0'), tensor([ 40, 765, 284], device='cuda:0'), tensor([1026,  468,  284, 2291,  257], device='cuda:0'), tensor([23307,  1310,  1165], device='cuda:0'), tensor([1212,  318,  523,   11], device='cuda:0'), tensor([   40,  1816, 14380,   329], device='cuda:0'), tensor([5297,  475,  644], device='cuda:0'), tensor([18243,   278], device='cuda:0'), tensor([5308, 1165,   13], device='cuda:0'), tensor([ 40, 481], device='cuda:0'), tensor([2990,  711, 7812], device='cuda:0'), tensor([10919,   257, 10195], device='cuda:0'), tensor([1026,  857], device='cuda:0'), tensor([1544, 3751,  703, 2562,  340,  318], device='cuda:0'), tensor




In [36]:

batch.keys()


dict_keys(['input_ids', 'query', 'response'])


#### Compute sentiment score


In [37]:

batch["query"]


['Thats a good way to',
 'Agree!!',
 'Lol.',
 "Don't forget about the flooding",
 'I still hate Cat',
 'A Wall is NOT',
 'Thanks man',
 'I want to',
 'It has to include a',
 'Too little too',
 'This is so,',
 'I went nuts for',
 'Yes but what',
 'Searching',
 'Me too.',
 'I will',
 'They play loud',
 'what a shame',
 'It does',
 'He showed how easy it is',
 'Ok now I’',
 'Hey, did you know',
 'Idk',
 'Sounds nicer',
 'Ugh, he',
 '[NAME] bless you,',
 'Thank you.',
 'Because cognitive dissonance.',
 'Man it was that long',
 'What the',
 'Crush.',
 "Don't worry... I'm",
 'This is why',
 'Its the same AS to',
 'If its your first',
 "Nope, you weren't",
 'Keep on going',
 'I totally',
 '[NAME] got',
 "That's a",
 'i can also post pics of',
 'Oh man, I didn',
 'If that’',
 'That poor snek',
 'Silence and For Greater Glory',
 'First and foremost you',
 'I want a Dorit',
 'Is this the place that had',
 'If by ‘needs',
 'Wow... Hope you',
 'So basically the greatest person on',
 'After awhile'

In [38]:

batch["response"]


[' send me cloned mef',
 " If I'm not mistaken, some",
 ' A coffee bean is a food.)\nCompleted Mackay County 2010\nInd',
 " and you'll be staying in a",
 ': "This is for all',
 ' the end of the world if a',
 ', thank you (another little one sound',
 ' agree that democracy is a',
 ' cited list of "promotional',
 ' much is at stake when it comes to the economy.\n\nAnd Alberta',
 ' and I have "rebasa" such that it\'s dead',
 ' this complete set the full time, I was expecting this party night,"',
 ' do Scots believe are the',
 ' Memories of Human Reality\n\nThe founding',
 " I don't have many needs. This was my fit for my military",
 ' not accept your idea...see how it is."',
 " music and usually don't sound too humbled. Jones' takes on the",
 ') or blame for your own selection of hacker tools. We',
 ' keep satellites and air traffic neatly separated: once in Moscow, at Thess',
 " when you're trying to learn",
 'will heroes remember, and mortality will',
 ' that Minnie Tak',
 'bn:1759',
 "

In [39]:

texts = [ q + r for q, r in zip(batch["query"], batch["response"]) ]


In [40]:

texts


['Thats a good way to send me cloned mef',
 "Agree!! If I'm not mistaken, some",
 'Lol. A coffee bean is a food.)\nCompleted Mackay County 2010\nInd',
 "Don't forget about the flooding and you'll be staying in a",
 'I still hate Cat: "This is for all',
 'A Wall is NOT the end of the world if a',
 'Thanks man, thank you (another little one sound',
 'I want to agree that democracy is a',
 'It has to include a cited list of "promotional',
 'Too little too much is at stake when it comes to the economy.\n\nAnd Alberta',
 'This is so, and I have "rebasa" such that it\'s dead',
 'I went nuts for this complete set the full time, I was expecting this party night,"',
 'Yes but what do Scots believe are the',
 'Searching Memories of Human Reality\n\nThe founding',
 "Me too. I don't have many needs. This was my fit for my military",
 'I will not accept your idea...see how it is."',
 "They play loud music and usually don't sound too humbled. Jones' takes on the",
 'what a shame) or blame for your o

In [41]:

pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
pipe_outputs


[[{'label': 'admiration', 'score': -0.7318404912948608},
  {'label': 'amusement', 'score': -7.137917518615723},
  {'label': 'anger', 'score': -7.918138027191162},
  {'label': 'annoyance', 'score': -6.302253723144531},
  {'label': 'approval', 'score': -0.0831017717719078},
  {'label': 'caring', 'score': -5.575819492340088},
  {'label': 'confusion', 'score': -6.07122802734375},
  {'label': 'curiosity', 'score': -6.546262741088867},
  {'label': 'desire', 'score': -6.391024112701416},
  {'label': 'disappointment', 'score': -6.888919830322266},
  {'label': 'disapproval', 'score': -5.453632354736328},
  {'label': 'disgust', 'score': -7.6543288230896},
  {'label': 'embarrassment', 'score': -9.055230140686035},
  {'label': 'excitement', 'score': -5.794053554534912},
  {'label': 'fear', 'score': -8.355690956115723},
  {'label': 'gratitude', 'score': -5.052929878234863},
  {'label': 'grief', 'score': -8.862860679626465},
  {'label': 'joy', 'score': -5.421574592590332},
  {'label': 'love', 'score

In [42]:

rewards = [ torch.tensor(output[5]["score"]) for output in pipe_outputs]
rewards


[tensor(-5.5758),
 tensor(-5.0259),
 tensor(-6.8954),
 tensor(-2.8499),
 tensor(-5.6117),
 tensor(-5.1342),
 tensor(-6.4502),
 tensor(-5.1224),
 tensor(-6.5572),
 tensor(-6.2088),
 tensor(-6.9306),
 tensor(-6.2309),
 tensor(-7.0525),
 tensor(-7.3272),
 tensor(-6.6360),
 tensor(-5.0267),
 tensor(-6.1178),
 tensor(-6.4169),
 tensor(-6.3429),
 tensor(-4.0027),
 tensor(-5.8673),
 tensor(-6.2277),
 tensor(-7.0327),
 tensor(-6.9987),
 tensor(-6.8226),
 tensor(1.7083),
 tensor(-5.9201),
 tensor(-7.4698),
 tensor(-7.4069),
 tensor(-6.9978),
 tensor(-6.9641),
 tensor(0.3223),
 tensor(-7.3059),
 tensor(-7.0415),
 tensor(-6.0250),
 tensor(-6.4353),
 tensor(-3.0460),
 tensor(-6.4678),
 tensor(-6.5772),
 tensor(-7.7240),
 tensor(-6.5498),
 tensor(-7.1231),
 tensor(-6.7379),
 tensor(-5.0862),
 tensor(-6.5630),
 tensor(-0.6877),
 tensor(-5.7862),
 tensor(-6.5284),
 tensor(-6.1076),
 tensor(-1.6212),
 tensor(-6.3775),
 tensor(-6.1322),
 tensor(-5.7536),
 tensor(-5.9134),
 tensor(-4.4899),
 tensor(-6.0

In [43]:

len(rewards)


128

In [44]:
count = 0
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    count += 1

count

339it [00:02, 159.70it/s]


339

In [45]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[5]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors,
                     response_tensors,
                     rewards
    )
    ppo_trainer.log_stats(stats, batch, rewards)



0it [00:00, ?it/s]

0


1it [00:18, 18.03s/it]

1


2it [00:37, 19.11s/it]

2


3it [00:57, 19.12s/it]

3


4it [01:19, 20.30s/it]

4


5it [01:39, 20.30s/it]

5


6it [01:58, 19.91s/it]

6


7it [02:18, 19.75s/it]

7


8it [02:39, 20.32s/it]

8


9it [02:59, 20.34s/it]

9


10it [03:19, 20.02s/it]

10


11it [03:39, 19.97s/it]

11


12it [03:58, 19.84s/it]

12


13it [04:18, 19.86s/it]

13


14it [04:39, 20.07s/it]

14


15it [04:59, 20.19s/it]

15


16it [05:19, 20.26s/it]

16


17it [05:38, 19.83s/it]

17


18it [05:58, 19.90s/it]

18


19it [06:19, 19.99s/it]

19


20it [06:39, 20.09s/it]

20


21it [07:01, 20.61s/it]

21


22it [07:20, 20.31s/it]

22


23it [07:40, 20.10s/it]

23


24it [08:01, 20.26s/it]

24


25it [08:20, 19.98s/it]

25


26it [08:41, 20.25s/it]

26


27it [09:01, 20.19s/it]

27


28it [09:20, 19.95s/it]

28


29it [09:39, 19.74s/it]

29


30it [10:00, 20.09s/it]

30


31it [10:21, 20.29s/it]

31


32it [10:41, 20.28s/it]

32


33it [11:01, 20.04s/it]

33


34it [11:20, 19.91s/it]

34


35it [11:41, 20.03s/it]

35


36it [12:01, 19.97s/it]

36


37it [12:21, 19.97s/it]

37


38it [12:40, 19.68s/it]

38


39it [13:00, 19.84s/it]

39


40it [13:19, 19.70s/it]

40


41it [13:39, 19.75s/it]

41


42it [13:59, 19.71s/it]

42


43it [14:19, 19.85s/it]

43


44it [14:38, 19.79s/it]

44


45it [14:58, 19.84s/it]

45


46it [15:17, 19.60s/it]

46


47it [15:36, 19.39s/it]

47


48it [15:57, 19.85s/it]

48


49it [16:18, 20.04s/it]

49


50it [16:37, 19.87s/it]

50


51it [16:57, 19.81s/it]

51


52it [17:17, 19.79s/it]

52


53it [17:35, 19.46s/it]

53


54it [17:55, 19.65s/it]

54


55it [18:15, 19.68s/it]

55


56it [18:35, 19.75s/it]

56


57it [18:56, 20.10s/it]

57


58it [19:15, 19.70s/it]

58


59it [19:34, 19.67s/it]

59


60it [19:53, 19.39s/it]

60


61it [20:14, 19.76s/it]

61


62it [20:35, 20.07s/it]

62


63it [20:55, 20.23s/it]

63


64it [21:15, 20.08s/it]

64


65it [21:35, 20.10s/it]

65


66it [21:55, 20.15s/it]

66


67it [22:15, 19.88s/it]

67


68it [22:34, 19.70s/it]

68


69it [22:54, 19.75s/it]

69


70it [23:13, 19.69s/it]

70


71it [23:33, 19.77s/it]

71


72it [23:54, 19.96s/it]

72


73it [24:14, 20.11s/it]

73


74it [24:33, 19.80s/it]

74


75it [24:53, 19.96s/it]

75


76it [25:13, 19.80s/it]

76


77it [25:32, 19.72s/it]

77


78it [25:53, 19.89s/it]

78


79it [26:12, 19.85s/it]

79


80it [26:32, 19.78s/it]

80


81it [26:53, 20.08s/it]

81


82it [27:13, 19.99s/it]

82


83it [27:32, 19.73s/it]

83


84it [27:52, 19.81s/it]

84


85it [28:11, 19.73s/it]

85


86it [28:32, 19.91s/it]

86


87it [28:51, 19.76s/it]

87


88it [29:10, 19.58s/it]

88


89it [29:31, 20.04s/it]

89


90it [29:53, 20.61s/it]

90


91it [30:14, 20.55s/it]

91


92it [30:35, 20.75s/it]

92


93it [30:56, 20.79s/it]

93


94it [31:16, 20.63s/it]

94


95it [31:36, 20.44s/it]

95


96it [31:57, 20.53s/it]

96


97it [32:17, 20.47s/it]

97


98it [32:37, 20.38s/it]

98


99it [32:58, 20.63s/it]

99


100it [33:20, 20.95s/it]

100


101it [33:41, 21.02s/it]

101


102it [34:03, 21.07s/it]

102


103it [34:23, 20.89s/it]

103


104it [34:43, 20.78s/it]

104


105it [35:04, 20.59s/it]

105


106it [35:23, 20.36s/it]

106


107it [35:44, 20.31s/it]

107


108it [36:05, 20.62s/it]

108


109it [36:24, 20.28s/it]

109


110it [36:44, 20.12s/it]

110


111it [37:05, 20.40s/it]

111


112it [37:25, 20.26s/it]

112


113it [37:45, 20.20s/it]

113


114it [38:05, 20.18s/it]

114


115it [38:25, 19.96s/it]

115


116it [38:44, 19.71s/it]

116


117it [39:04, 19.75s/it]

117


118it [39:22, 19.36s/it]

118


119it [39:42, 19.40s/it]

119


120it [40:03, 19.85s/it]

120


121it [40:22, 19.80s/it]

121


122it [40:41, 19.51s/it]

122


123it [41:00, 19.42s/it]

123


124it [41:20, 19.35s/it]

124


125it [41:39, 19.50s/it]

125


126it [41:59, 19.50s/it]

126


127it [42:18, 19.41s/it]

127


128it [42:39, 19.84s/it]

128


129it [42:59, 19.75s/it]

129


130it [43:17, 19.43s/it]

130


131it [43:35, 19.05s/it]

131


132it [43:55, 19.24s/it]

132


133it [44:15, 19.44s/it]

133


134it [44:35, 19.59s/it]

134


135it [44:54, 19.47s/it]

135


136it [45:14, 19.74s/it]

136


137it [45:34, 19.66s/it]

137


138it [45:53, 19.55s/it]

138


139it [46:13, 19.74s/it]

139


140it [46:33, 19.68s/it]

140


141it [46:53, 19.71s/it]

141


142it [47:12, 19.72s/it]

142


143it [47:32, 19.79s/it]

143


144it [47:53, 20.01s/it]

144


145it [48:12, 19.59s/it]

145


146it [48:32, 19.72s/it]

146


147it [48:52, 19.90s/it]

147


148it [49:13, 20.37s/it]

148


149it [49:34, 20.36s/it]

149


150it [49:53, 19.92s/it]

150


151it [50:13, 19.98s/it]

151


152it [50:33, 20.16s/it]

152


153it [50:52, 19.67s/it]

153


154it [51:12, 19.73s/it]

154


155it [51:31, 19.57s/it]

155


156it [51:51, 19.67s/it]

156


157it [52:12, 20.02s/it]

157


158it [52:31, 19.87s/it]

158


159it [52:51, 19.74s/it]

159


160it [53:11, 20.01s/it]

160


161it [53:31, 19.90s/it]

161


162it [53:50, 19.71s/it]

162


163it [54:11, 20.05s/it]

163


164it [54:31, 20.18s/it]

164


165it [54:51, 20.12s/it]

165


166it [55:12, 20.11s/it]

166


167it [55:31, 19.96s/it]

167


168it [55:50, 19.74s/it]

168


169it [56:11, 19.94s/it]

169


170it [56:30, 19.68s/it]

170


171it [56:49, 19.39s/it]

171


172it [57:09, 19.80s/it]

172


173it [57:29, 19.86s/it]

173


174it [57:48, 19.47s/it]

174


175it [58:09, 19.83s/it]

175


176it [58:27, 19.51s/it]

176


177it [58:47, 19.69s/it]

177


178it [59:06, 19.48s/it]

178


179it [59:26, 19.48s/it]

179


180it [59:46, 19.52s/it]

180


181it [1:00:07, 19.98s/it]

181


182it [1:00:26, 19.88s/it]

182


183it [1:00:46, 19.92s/it]

183


184it [1:01:05, 19.60s/it]

184


185it [1:01:24, 19.48s/it]

185


186it [1:01:44, 19.47s/it]

186


187it [1:02:04, 19.63s/it]

187


188it [1:02:24, 19.87s/it]

188


189it [1:02:44, 19.85s/it]

189


190it [1:03:04, 19.83s/it]

190


191it [1:03:24, 19.84s/it]

191


192it [1:03:44, 19.93s/it]

192


193it [1:04:04, 19.91s/it]

193


194it [1:04:23, 19.76s/it]

194


195it [1:04:43, 19.96s/it]

195


196it [1:05:03, 19.82s/it]

196


197it [1:05:22, 19.65s/it]

197


198it [1:05:42, 19.80s/it]

198


199it [1:06:02, 19.78s/it]

199


200it [1:06:22, 19.77s/it]

200


201it [1:06:42, 20.00s/it]

201


202it [1:07:02, 19.85s/it]

202


203it [1:07:21, 19.55s/it]

203


204it [1:07:41, 19.70s/it]

204


205it [1:08:01, 19.84s/it]

205


206it [1:08:20, 19.54s/it]

206


207it [1:08:40, 19.84s/it]

207


208it [1:09:00, 19.70s/it]

208


209it [1:09:19, 19.56s/it]

209


210it [1:09:39, 19.68s/it]

210


211it [1:09:58, 19.46s/it]

211


212it [1:10:17, 19.41s/it]

212


213it [1:10:37, 19.40s/it]

213


214it [1:10:55, 19.10s/it]

214


215it [1:11:15, 19.39s/it]

215


216it [1:11:35, 19.48s/it]

216


217it [1:11:54, 19.49s/it]

217


218it [1:12:14, 19.64s/it]

218


219it [1:12:34, 19.84s/it]

219


220it [1:12:54, 19.87s/it]

220


221it [1:13:15, 20.13s/it]

221


222it [1:13:34, 19.89s/it]

222


223it [1:13:54, 19.86s/it]

223


224it [1:14:15, 20.11s/it]

224


225it [1:14:36, 20.25s/it]

225


226it [1:14:55, 20.15s/it]

226


227it [1:15:16, 20.28s/it]

227


228it [1:15:35, 19.90s/it]

228


229it [1:15:55, 19.85s/it]

229


230it [1:16:14, 19.81s/it]

230


231it [1:16:34, 19.75s/it]

231


232it [1:16:54, 19.90s/it]

232


233it [1:17:14, 19.83s/it]

233


234it [1:17:35, 20.11s/it]

234


235it [1:17:54, 19.75s/it]

235


236it [1:18:13, 19.77s/it]

236


237it [1:18:33, 19.61s/it]

237


238it [1:18:52, 19.64s/it]

238


239it [1:19:11, 19.40s/it]

239


240it [1:19:31, 19.42s/it]

240


241it [1:19:51, 19.65s/it]

241


242it [1:20:11, 19.68s/it]

242


243it [1:20:30, 19.55s/it]

243


244it [1:20:49, 19.37s/it]

244


245it [1:21:09, 19.55s/it]

245


246it [1:21:28, 19.42s/it]

246


247it [1:21:47, 19.38s/it]

247


248it [1:22:07, 19.56s/it]

248


249it [1:22:27, 19.49s/it]

249


250it [1:22:46, 19.60s/it]

250


251it [1:23:06, 19.66s/it]

251


252it [1:23:26, 19.61s/it]

252


253it [1:23:46, 19.81s/it]

253


254it [1:24:05, 19.44s/it]

254


255it [1:24:25, 19.71s/it]

255


256it [1:24:44, 19.65s/it]

256


257it [1:25:04, 19.53s/it]

257


258it [1:25:24, 19.75s/it]

258


259it [1:25:43, 19.61s/it]

259


260it [1:26:04, 19.87s/it]

260


261it [1:26:24, 19.97s/it]

261


262it [1:26:45, 20.23s/it]

262


263it [1:27:04, 19.92s/it]

263


264it [1:27:23, 19.81s/it]

264


265it [1:27:43, 19.68s/it]

265


266it [1:28:02, 19.65s/it]

266


267it [1:28:23, 19.89s/it]

267


268it [1:28:43, 19.81s/it]

268


269it [1:29:02, 19.62s/it]

269


270it [1:29:22, 19.70s/it]

270


271it [1:29:43, 20.15s/it]

271


272it [1:30:04, 20.35s/it]

272


273it [1:30:23, 20.19s/it]

273


274it [1:30:43, 19.94s/it]

274


275it [1:31:03, 20.02s/it]

275


276it [1:31:23, 20.00s/it]

276


277it [1:31:43, 20.15s/it]

277


278it [1:32:02, 19.79s/it]

278


279it [1:32:22, 19.85s/it]

279


280it [1:32:43, 19.97s/it]

280


281it [1:33:02, 19.75s/it]

281


282it [1:33:22, 19.86s/it]

282


283it [1:33:43, 20.17s/it]

283


284it [1:34:02, 19.93s/it]

284


285it [1:34:22, 19.93s/it]

285


286it [1:34:42, 19.95s/it]

286


287it [1:35:03, 20.09s/it]

287


288it [1:35:22, 20.00s/it]

288


289it [1:35:43, 20.28s/it]

289


290it [1:36:03, 20.17s/it]

290


291it [1:36:24, 20.28s/it]

291


292it [1:36:45, 20.44s/it]

292


293it [1:37:06, 20.67s/it]

293


294it [1:37:26, 20.68s/it]

294


295it [1:37:46, 20.47s/it]

295


296it [1:38:07, 20.48s/it]

296


297it [1:38:28, 20.55s/it]

297


298it [1:38:48, 20.40s/it]

298


299it [1:39:09, 20.59s/it]

299


300it [1:39:29, 20.55s/it]

300


301it [1:39:49, 20.31s/it]

301


302it [1:40:10, 20.42s/it]

302


303it [1:40:31, 20.57s/it]

303


304it [1:40:50, 20.13s/it]

304


305it [1:41:12, 20.68s/it]

305


306it [1:41:30, 20.15s/it]

306


307it [1:41:52, 20.45s/it]

307


308it [1:42:12, 20.54s/it]

308


309it [1:42:33, 20.47s/it]

309


310it [1:42:54, 20.61s/it]

310


311it [1:43:13, 20.36s/it]

311


312it [1:43:34, 20.39s/it]

312


313it [1:43:55, 20.51s/it]

313


314it [1:44:13, 19.95s/it]

314


315it [1:44:34, 20.22s/it]

315


316it [1:44:56, 20.61s/it]

316


317it [1:45:15, 20.24s/it]

317


318it [1:45:34, 19.98s/it]

318


319it [1:45:54, 19.82s/it]

319


320it [1:46:14, 19.84s/it]

320


321it [1:46:33, 19.61s/it]

321


322it [1:46:54, 19.96s/it]

322


323it [1:47:13, 19.91s/it]

323


324it [1:47:33, 19.80s/it]

324


325it [1:47:53, 19.74s/it]

325


326it [1:48:12, 19.65s/it]

326


327it [1:48:32, 19.79s/it]

327


328it [1:48:51, 19.61s/it]

328


329it [1:49:11, 19.66s/it]

329


330it [1:49:30, 19.53s/it]

330


331it [1:49:49, 19.38s/it]

331


332it [1:50:10, 19.75s/it]

332


333it [1:50:30, 19.85s/it]

333


334it [1:50:49, 19.62s/it]

334


335it [1:51:10, 19.88s/it]

335


336it [1:51:30, 19.90s/it]

336


337it [1:51:49, 19.74s/it]

337


338it [1:52:08, 19.60s/it]

338


339it [1:52:28, 19.91s/it]


In [46]:

torch.cuda.get_device_name(0)


'Tesla T4'


One can observe how the model starts to generate more positive outputs after a few optimisation steps.

Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient.



Let's inspect some examples from the IMDB dataset. We can use model_ref to compare the tuned model model against the model before optimisation.


In [47]:

#### get a batch from the dataset
bs                 = 16
game_data          = dict()


In [48]:

game_data


{}

In [49]:

dataset.set_format("pandas")


In [50]:

df_batch           = dataset[:].sample(bs)
df_batch


Unnamed: 0,review,labels,id,input_ids,query
12330,I dated a girl from Broomfield Colorado. She'd...,[4],ed5d4hm,"[40, 14567, 257, 2576, 422]",I dated a girl from
9215,You really need to stop hanging out with those...,"[10, 27]",eeqm77h,"[1639, 1107, 761, 284, 2245]",You really need to stop
778,Looks deliriously boring.,[3],eeyg7qy,"[41102, 1619, 343]",Looks delir
3843,TYT is a gossip shitpost panel for race baitin...,"[3, 10, 11]",edg81eo,"[9936, 51, 318, 257, 30914, 7510, 7353]",TYT is a gossip shitpost
5245,“You bring home gifts in bag.”,[27],edxv5xg,"[447, 250, 1639, 2222, 1363, 13201]",“You bring home gifts
6191,"So..the head [NAME], is afraid of being rammed?",[6],eesiyw1,"[2396, 492, 1169, 1182, 685, 20608]",So..the head [NAME
32019,"I’ll be honest, I didn’t expect that to be real",[4],eerzcr2,"[40, 447, 247, 297, 307, 5508, 11]","I’ll be honest,"
36276,Congratulatioins!,[15],edm66kk,"[18649, 10366, 377]",Congratul
26729,Yes it does. I kept doing this for 2 mins. But...,[4],eevt50a,"[5297, 340, 857, 13]",Yes it does.
6199,Yea ur right on that but that doesn't really n...,[4],eeb4yu8,"[56, 18213, 2956, 826, 319]",Yea ur right on


In [51]:

game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()


In [52]:

response_tensors_ref, response_tensors = [], []


In [53]:

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()

    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)


    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)


In [54]:

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]


In [55]:

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[5]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]




In [56]:

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[5]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]


In [57]:

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results


Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,I dated a girl from,my past BUHHHHHHHHHH!! THEY LIVE IN NEW YORK ...,"now, take care in last game thing, put care o...",-6.185028,1.899247
1,You really need to stop,-off before you meet your Brazilian fiancé wit...,"and be careful of vulnerable or destitute, sh...",-4.064278,2.243831
2,Looks delir,ious in its admission the next time you retake...,". Be careful and ""\n\n\nHope\n\nBut",-4.949306,2.261854
3,TYT is a gossip shitpost,. A strange sight is happening. It looks,. Be careful. This is otherwise. Start,-7.584946,2.122829
4,“You bring home gifts,", might as well make a",. We'll take care of,-6.229325,2.059892
5,So..the head [NAME,]. begins to turn with her head,"],.. to take care what you",-6.758536,1.449007
6,"I’ll be honest,",I'm pretty sure that the biggest because what...,"be careful, as if of all those who are above are",-5.05409,1.521199
7,Congratul,"aced, thus making him a law unto them, became ...",ently. Be safe and you won't be\n\n\nused.,-5.657579,1.000404
8,Yes it does.,It's a good idea to visit the,Be careful.\n\n\nEVA plastic\n,-5.60134,1.835512
9,Yea ur right on,i8 vin deed blit x,it. DON. Be careful. You,-6.647352,1.638231


In [58]:

print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())



mean:


rewards (before)   -6.115143
rewards (after)     0.936721
dtype: float64


median:


rewards (before)   -6.207177
rewards (after)     1.867379
dtype: float64

In [59]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=False)
tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=False)



('gpt2-imdb-pos-v2/tokenizer_config.json',
 'gpt2-imdb-pos-v2/special_tokens_map.json',
 'gpt2-imdb-pos-v2/vocab.json',
 'gpt2-imdb-pos-v2/merges.txt',
 'gpt2-imdb-pos-v2/added_tokens.json',
 'gpt2-imdb-pos-v2/tokenizer.json')