In [1]:

## Python >= 3.8  


In [2]:

## !pip install transformers
## !pip install wandb
## !pip install trl
## !pip install pandas
## !pip install datasets
## !pip install accelerate
## !pip install tyro
## !pip install nltk -U


In [3]:

import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler


In [4]:

config = PPOConfig(
    model_name    = "lvwerra/gpt2-imdb",
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}


In [5]:

## wandb.init()

wandb.init(mode="disabled") 
os.environ['WANDB_DISABLED'] = 'true'


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.



## Load IMDB dataset

The IMDB dataset contains 50k movie reviews annotated with "positive"/"negative" labels indicating their sentiment. We load the IMDB dataset with `load_dataset`, rename the `text` column to `review`, and keep only reviews longer than 200 characters. Then we tokenize each review and truncate it to a random length with `LengthSampler`.
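The random truncation step can be sketched without any library. This is a minimal illustration, assuming that trl's `LengthSampler` draws an integer uniformly from the half-open range `[min_value, max_value)`; the prompt lengths of 2 to 7 tokens in the batch printed later in this notebook are consistent with that for bounds 2 and 8. `make_length_sampler` and `truncate_prompt` are hypothetical helpers, and the token ids are placeholders rather than real GPT-2 ids.

```python
import random


def make_length_sampler(min_value, max_value, rng=random):
    """Hypothetical stand-in for trl's LengthSampler.

    Returns a callable that draws an integer uniformly from the half-open
    range [min_value, max_value).
    """
    values = list(range(min_value, max_value))

    def sample():
        return rng.choice(values)

    return sample


def truncate_prompt(token_ids, sampler):
    """Keep only the first sampler() tokens of a tokenized review."""
    return token_ids[: sampler()]


rng = random.Random(0)                      # seeded so the sketch is reproducible
sampler = make_length_sampler(2, 8, rng)    # same bounds as build_dataset's defaults
fake_ids = list(range(100, 130))            # placeholder ids, not real GPT-2 tokens
prompt = truncate_prompt(fake_ids, sampler)
print(len(prompt), prompt)
```

Sampling a fresh length per review gives the policy prompts of varying size, so it learns to continue both very short and somewhat longer openings.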



## Visualize details of dataset


In [6]:

dataset_name="imdb"


In [7]:

ds = load_dataset(dataset_name, split="train")


In [8]:

ds


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [9]:

ds[15:18]


{'text': ["This film is just plain horrible. John Ritter doing pratt falls, 75% of the actors delivering their lines as if they were reading them from cue cards, poor editing, horrible sound mixing (dialogue is tough to pick up in places over the background noise), and a plot that really goes nowhere. I didn't think I'd ever say this, but Dorothy Stratten is not the worst actress in this film. There are at least 3 others that suck more. Patti Hansen delivers her lines with the passion of Ben Stein. I started to wonder if she wasn't dead inside. Even Bogdanovich's kids are awful (the oldest one is definitely reading her lines from a cue card). This movie is seriously horrible. There's a reason Bogdanovich couldn't get another project until 4 years later. Please don't watch it. If you see it in your television listings, cancel your cable. If a friend suggests it to you, reconsider your friendship. If your spouse wants to watch it, you're better off finding another soulmate. I'd rather go

In [10]:

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML


In [11]:

def show_random_elements(dataset, num_examples=20):
    """Display `num_examples` randomly chosen rows of `dataset` as an HTML table."""
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    ## Draw distinct row indices without replacement
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame( dataset[picks] )        ## index the dataset with the list of picks
    print(df)

    ## Replace integer class ids with their names (e.g. 0 -> "neg", 1 -> "pos")
    for column, typ in dataset.features.items():
        ## isinstance() returns True if the object is of the specified type
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])

    display(HTML(df.to_html()))


In [12]:

show_random_elements(ds)


                                                 text  label
0   "Rich in Love" is a slice-of-life film which t...      1
1   This has to be one of the most beautiful, movi...      1
2   This, along with "Hare Tonic," ranks as one of...      1
3   This is a great story and was just the beginni...      0
4   I viewed the movie together with a homophobic ...      1
5   What happened to Peter Bogdanovich? Once a bri...      0
6   This movie maybe really bad, but it is alot of...      1
7   This movie was amusing at times, hell sometime...      0
8   "Radiofreccia" is still a good surprise in Ita...      1
9   Though not seen in too many films prior, you h...      1
10  I remember my parents not understanding Saturd...      0
11  And that's how the greatest comedy of TV start...      1
12  I watched the un-aired episodes online and I w...      1
13  Although I have not seen this mini-series in o...      1
14  I don't know why, but when I am asked about ba...      0
15  Let me start by sayi

Unnamed: 0,text,label
0,"""Rich in Love"" is a slice-of-life film which takes the viewer into the goings on of a somewhat quirky Charleston, SC family. Highly romanticized, beautifully shot, well written and acted, ""RIL"" washes over you like a summer breeze as its plotless meandering breathes life into the characters such that at film's end you'll feel like an old friend of the family.<br /><br />A wonderfully crafted character-driven film from the director of ""Driving Miss Daisy"", ""RIL"" is a somewhat obscure little ""sleeper"" which will appeal most to mature audiences.",pos
1,"This has to be one of the most beautiful, moving, thought provoking films around. It's good family entertainment and at the same time makes you think very hard about the issues involved. Every time I see the ""ghost of Zac riding the bike through the puddle at the end I can't help but cry my eyes out. John Thaw's performance is so touching and it is a shame he is no longer with us. Gone but not forgotten. A outstanding film. Full marks.",pos
2,"This, along with ""Hare Tonic,"" ranks as one of the best Bugs cartoons, indeed one of the best Bugs, ever. There are some comments about how Bugs in these cartoons is ""basic,"" meaning, I guess, that he is as yet not fully developed. I actually prefer this ""basic"" version from the mid-40s (Chuck Jones' was the best version) who is actually more rabbit-sized and far more amusing than the eventual long-legged version who towered over Yosemite Sam and Daffy Duck. The latter-day Bugs came to be too suave and sophisticated for my liking. Also check out ""Hair Raising Hare"" (1946) and ""Rabbit Punch"" (1948) for great examples of classic Bugs and classic Chuck Jones.",pos
3,"This is a great story and was just the beginning of equality in the United States. (We are still working on it too.) However despite the fact this is true, it's still a movie and this is a movie site. I realize independent films have a hard time getting good actors, but wow. The only one even mediocre is the excellent Ossie Davis. But even he couldn't make up for all the actors (including the one playing him as a young man) absolutely atrocious acting. Granted the script was terribly cliché, but even then you have got to get some decent actors! I wouldn't recommend this to anybody because it is so poorly done in every category. Read some books about the true story of the U.S.S Mason, because they give these men the respect they deserve.",neg
4,"I viewed the movie together with a homophobic friend, my wife and her female friend. So I had views from all kinds of directions. Mainly, the film made me laugh, the sexual tension was not really there and the only noticeable actors were Tudor Chirila and Maria Popistasu. Yes, I do think she played her role well, even if the script was not appropriate. There were good Romanian actors around, they just didn't have complex roles. I applaud Puya's entering the movie business. I don't know why, but I think he's a good guy, I just hope he'll be a good actor.<br /><br />The wife loved the movie, though, and I think there might have been chords being played and to which I had no ear for. If the film tried to present uncommon sexual behaviors and their consequences in todays Romania, then it failed miserably. There were no consequences. Just imagine that the girls are actually a boy and a girl, and the same story becomes just a boring, uninteresting plot.<br /><br />I have no idea why it got all those BAFTA awards. In my book, it should have gotten the ""Better luck next time"" award. (bafta=good luck in Romanian).",pos
5,"What happened to Peter Bogdanovich? Once a brilliant director, a trail blazer... is now scraping the very bottom... Is this the same man who directed ""The Last Picture Show""? Here, he takes a somewhat interesting (albeit farfetched) premise, and turns it into bubble gum that loses flavor the moment you take the first bite... Dunst is not bad, but Izzard is miscast as Chaplin, and all the other actors seem to have been cast for their ""looks"", and not because they were right for the part. Too bad. I'll go rent ""Paper Moon"" again.",neg
6,"This movie maybe really bad, but it is alot of fun. The bad acting and poor direction enhance the film's hystericalness. The twins are very funny in their Conanesque roles. If you go into this film expecting the first Conan or Excalibur, than you will hate it. If you watch it while in a good mood and accept it as good, dumb fun you will have a good time. Watch for the scene where they try to hang the brothers, its funniest scene in the film. I wish Mystery Science Theatre 3000 would have done this!!",pos
7,"This movie was amusing at times, hell sometimes it was even downright funny.<br /><br />The underlying message I got from the film though, was that women are responsible for all of the troubles of man. Every time a woman is depicted in the film, she is being lazy, being slutty or lambasting some poor guy for no apparent reason. I don't think the message involved is good for women or gay men.<br /><br />But, it is a comedy, and a piece of art, so it is simply someones point of view. Even if I don't agree with it, they are still entitled to it.<br /><br />An amusing film, but some of the comments others have made are just plain stupid. Best film ever my foot.",neg
8,"""Radiofreccia"" is still a good surprise in Italian cinema. The film is based on a book of Italian songwriter Luciano Ligabue, who also directs the movie and writes the music score -of course.<br /><br />The film is a portrait of north Italian province life, in the Emilia Romagna region. We're in 1975, the time of the first free radios -one of the boys of the movie creates ""Radioraptus"". Youth wishes, friendship, love, sex, individual dramas and unemployment are among the themes, but the film speaks also about drugs -Freccia, the main character, is a victim of heroin slavery.<br /><br />Without being boring and moralist, the story flows very well; the spontaneity of actors is strong and the way of directing as well. Obviously Luciano ""Liga"" Ligabue is neither Fellini nor a movie professional, first of all he's a musician. But he succeeds in making a good product. Unfortunately he'll not repeat the success with his second movie ""Da zero a dieci"" -not good at all.<br /><br />In ""Radiofreccia"" actors are generally not very famous, the only star is Stefano Accorsi -one of the most popular young Italian actors. See in a small role another Italian songwriter -Francesco Guccini, he's the nice communist barman and football trainer!",pos
9,"Though not seen in too many films prior, you have certainly seen the basic plot themes in too many films since. <br /><br />Not one of Grant's nor Loy's best films, they make an outstanding effort together. After all, with that much talent and very good supporting cast, you know the laughs will be there.<br /><br />The film is light, has some dramatic spotting but keeps the plot moving and gets you to smile the whole way through.<br /><br />A great example of classic American film fare that has stood the test of time.<br /><br />Definite Saturday afternoon fare, heavy on the popcorn.",pos


In [13]:

ds = ds.rename_columns({"text": "review"})
ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)


In [14]:

ds


Dataset({
    features: ['review', 'label'],
    num_rows: 24895
})

In [15]:

tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token


In [16]:

def tokenize( sample ):
    sample["input_ids"] = tokenizer.encode( sample["review"]    )[: 20]
    sample["query"]     = tokenizer.decode( sample["input_ids"] )
    return sample


ds = ds.map(tokenize, batched=False)
ds


Dataset({
    features: ['review', 'label', 'input_ids', 'query'],
    num_rows: 24895
})

In [17]:

ds[15:18]


{'review': ["This film is just plain horrible. John Ritter doing pratt falls, 75% of the actors delivering their lines as if they were reading them from cue cards, poor editing, horrible sound mixing (dialogue is tough to pick up in places over the background noise), and a plot that really goes nowhere. I didn't think I'd ever say this, but Dorothy Stratten is not the worst actress in this film. There are at least 3 others that suck more. Patti Hansen delivers her lines with the passion of Ben Stein. I started to wonder if she wasn't dead inside. Even Bogdanovich's kids are awful (the oldest one is definitely reading her lines from a cue card). This movie is seriously horrible. There's a reason Bogdanovich couldn't get another project until 4 years later. Please don't watch it. If you see it in your television listings, cancel your cable. If a friend suggests it to you, reconsider your friendship. If your spouse wants to watch it, you're better off finding another soulmate. I'd rather 


## My own data


In [18]:

my_own_datasets = load_dataset("text", data_files={ "train": "/home/rcalix/Desktop/rc_train.txt", "validation": "/home/rcalix/Desktop/rc_validation.txt"} )


In [19]:

my_own_datasets


DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 4
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 4
    })
})



    
## Now for the actual RLHF



In [20]:

def build_dataset(
         config, 
         dataset_name="imdb", 
         input_min_text_length=2, 
         input_max_text_length=8
):
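    """
    Build the dataset for training. This builds the dataset from `load_dataset`; one should
    customize this function to train the model on one's own dataset.

    Args:
        config (`PPOConfig`):
            Configuration whose `model_name` selects the tokenizer.
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`):
            Minimum prompt length, in tokens.
        input_max_text_length (`int`):
            Upper bound (exclusive) on the prompt length, in tokens, passed to `LengthSampler`.

    Returns:
        ds (`datasets.Dataset`):
            The filtered, tokenized dataset, with numeric columns formatted as torch tensors.
    """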
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # load imdb with datasets
    
    ds = load_dataset(dataset_name, split="train")
    
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode( sample["review"]    )[: input_size()]
        sample["query"]     = tokenizer.decode( sample["input_ids"] )
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


In [21]:

dataset = build_dataset(config)


In [22]:

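def collator(data):
    ## Turn a list of per-example dicts into a dict of lists, e.g.
    ## [{"query": "a"}, {"query": "b"}] -> {"query": ["a", "b"]}.
    ## Variable-length `input_ids` tensors stay in a Python list instead of being stacked.
    return dict((key, [d[key] for d in data]) for key in data[0])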
    


## Load pre-trained GPT2 language models
We load the GPT-2 model with a value head, together with its tokenizer. We load the model twice: the first copy is optimized, while the second serves as a reference for computing the KL divergence from the starting point. The resulting KL penalty acts as an additional reward signal during PPO training, discouraging the optimized model from drifting too far from the original language model.
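The KL term described above can be illustrated on plain lists of log-probabilities. As far as we know, TRL's `PPOTrainer` shapes rewards roughly this way: each response token is penalized by `kl_coef * (log pi - log pi_ref)`, and the classifier score is added on the final token. Treat this as a sketch under that assumption, not TRL's code: `kl_shaped_rewards` is a hypothetical name, `kl_coef=0.2` is only an illustrative value, and TRL can also adapt the coefficient during training.

```python
def kl_shaped_rewards(score, logprobs, ref_logprobs, kl_coef=0.2):
    """Per-token rewards for one response (illustrative sketch).

    Each token receives -kl_coef * (log pi(token) - log pi_ref(token)), a
    per-sample estimate of the KL divergence from the reference model; the
    classifier score is added on the final response token only.
    """
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("need equal-length, non-empty log-prob lists")
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


# Toy numbers: the policy is more confident than the reference on token 1,
# less confident on token 2, and unchanged on token 3.
print(kl_shaped_rewards(2.0, [-1.0, -2.0, -0.5], [-1.2, -1.5, -0.5]))
```

Tokens where the policy has become more confident than the reference get a small negative reward, which is what keeps the fine-tuned model close to its starting point.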


In [23]:

model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token


In [24]:

ppo_trainer = PPOTrainer(
                 config, 
                 model, 
                 ref_model, 
                 tokenizer, 
                 dataset=dataset, 
                 data_collator=collator
)


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.



## Load BERT classifier (Reward Function)

We load a DistilBERT classifier (`lvwerra/distilbert-imdb`) fine-tuned on the IMDB dataset.


In [25]:

device = ppo_trainer.accelerator.device
device


device(type='cuda')

In [26]:

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug

device

0

In [27]:

sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)



The model outputs the logits for the negative and positive classes (with `function_to_apply="none"`, no softmax or sigmoid is applied). We use the positive-class logit as the reward signal for the language model.
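The notebook reads the reward as `output[1]["score"]`, which relies on `POSITIVE` being the second entry. A minimal, more defensive alternative selects the entry by label; the sketch also converts the two logits to a probability, which can help when inspecting scores. The helper names are hypothetical, and the example input is the output printed below for "this movie was really bad!!".

```python
import math


def positive_logit(scores):
    """Return the POSITIVE logit from one pipeline result (a list of label/score dicts)."""
    for entry in scores:
        if entry["label"] == "POSITIVE":
            return entry["score"]
    raise KeyError("no POSITIVE label in pipeline output")


def positive_probability(scores):
    """Softmax over the class logits, returning P(POSITIVE)."""
    logits = {e["label"]: e["score"] for e in scores}
    m = max(logits.values())                      # subtract max for numerical stability
    exps = {k: math.exp(v - m) for k, v in logits.items()}
    return exps["POSITIVE"] / sum(exps.values())


example = [{"label": "NEGATIVE", "score": 2.3350484371185303},
           {"label": "POSITIVE", "score": -2.726576328277588}]
reward = positive_logit(example)
prob = positive_probability(example)
print(reward, prob)
```

The raw logit is unbounded, so it gives a stronger gradient signal for very positive text than a probability that saturates near 1.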


In [28]:

text = "this movie was really bad!!"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'NEGATIVE', 'score': 2.3350484371185303},
  {'label': 'POSITIVE', 'score': -2.726576328277588}]]

In [29]:

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'NEGATIVE', 'score': -2.294790267944336},
  {'label': 'POSITIVE', 'score': 2.557040214538574}]]



## Generation settings

For response generation we use plain sampling: top-k and nucleus (top-p) filtering are turned off (`top_k=0`, `top_p=1.0`), and `min_length=-1` means no minimum length is enforced, so the EOS token is never suppressed.
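To see why `top_k=0` and `top_p=1.0` mean "sample from the full distribution", here is a minimal pure-Python sketch of top-k and nucleus filtering over a probability list. It follows the usual definitions (keep the k most likely tokens, then the smallest set whose renormalized mass reaches p) but is not the Transformers implementation; `filter_probs` and `sample_token` are hypothetical names.

```python
import random


def filter_probs(probs, top_k=0, top_p=1.0):
    """Apply top-k then nucleus (top-p) filtering and renormalize.

    top_k=0 and top_p=1.0 leave the distribution untouched (pure sampling).
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]                      # keep the k most likely tokens
    if top_p < 1.0:
        mass = sum(probs[i] for i in order)
        cumulative, nucleus = 0.0, []
        for i in order:                            # smallest prefix reaching top_p
            nucleus.append(i)
            cumulative += probs[i] / mass
            if cumulative >= top_p:
                break
        order = nucleus
    keep = set(order)
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


def sample_token(probs, rng, top_k=0, top_p=1.0):
    """Draw one token index from the filtered distribution."""
    weights = filter_probs(probs, top_k, top_p)
    return rng.choices(range(len(probs)), weights=weights, k=1)[0]


rng = random.Random(0)
probs = [0.5, 0.3, 0.2]
print(filter_probs(probs), filter_probs(probs, top_k=1), filter_probs(probs, top_p=0.7))
print(sample_token(probs, rng))
```

Unfiltered sampling matters here because PPO needs log-probabilities of responses under the policy's actual distribution; truncating it would make the generated samples inconsistent with the probabilities being optimized.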


In [30]:

gen_kwargs = {
         "min_length":   -1, 
         "top_k":       0.0, 
         "top_p":       1.0, 
         "do_sample":  True, 
         "pad_token_id": tokenizer.eos_token_id
}



## Optimize model

### Training loop

The training loop consists of the following main steps:

* Get the query and responses from the policy network (GPT-2)
* Get sentiments for query/responses from BERT
* Optimize policy with PPO using the (query, response, reward) triplet
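The three steps above can be sketched as one library-free function. The callables `generate`, `score`, and `update` are hypothetical placeholders standing in for `ppo_trainer.generate`, the sentiment pipeline, and `ppo_trainer.step`; they are not real APIs, and the toy lambdas below only show the data flow.

```python
from typing import Callable, List, Sequence


def rlhf_step(
    queries: Sequence[str],
    generate: Callable[[str], str],
    score: Callable[[str], float],
    update: Callable[[Sequence[str], List[str], List[float]], dict],
) -> dict:
    """One loop iteration: generate responses, score query+response, run an update."""
    responses = [generate(q) for q in queries]                    # 1. policy rollout
    rewards = [score(q + r) for q, r in zip(queries, responses)]  # 2. reward model
    return update(queries, responses, rewards)                    # 3. PPO-style update


# Toy stubs: always answer " great" and reward texts containing "great".
stats = rlhf_step(
    ["This movie was"],
    generate=lambda q: " great",
    score=lambda t: 1.0 if "great" in t else -1.0,
    update=lambda q, r, rw: {"mean_reward": sum(rw) / len(rw), "responses": r},
)
print(stats)
```

Note that the reward is computed on the concatenated query and response, exactly as the notebook does with `q + r`, so the classifier judges the full text the policy produced.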


In [31]:

output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


In [32]:

generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [33]:

## ppo_trainer.config.steps = 100    ## 20,000
ppo_trainer.config.steps


20000

In [34]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    
    print(query_tensors)
    print(len(query_tensors))
    if epoch == 1:
        break


1it [00:00, 20.45it/s]

[tensor([19197,   645,  3241,   284,   262,  3651,  2157], device='cuda:0'), tensor([1212, 3807], device='cuda:0'), tensor([  34, 1501,  418, 2899], device='cuda:0'), tensor([  40, 1183,  307, 5508], device='cuda:0'), tensor([  40, 3505, 4379, 1552], device='cuda:0'), tensor([ 1212,   318,   262,   749, 14851], device='cuda:0'), tensor([6090,  691,  307], device='cuda:0'), tensor([41389,   417, 45622,   290,  1115], device='cuda:0'), tensor([  40, 2993,  340], device='cuda:0'), tensor([  40, 8288, 9827], device='cuda:0'), tensor([   39, 50107,   318,   407,   691,   262], device='cuda:0'), tensor([   40,  4398,   470,  1865,  1100, 20642], device='cuda:0'), tensor([26886, 39452,   258,   283,   357], device='cuda:0'), tensor([  40,  550,  284, 1577,  428], device='cuda:0'), tensor([ 3673,   530,   286,  3873, 13951,   338,  1266], device='cuda:0'), tensor([8241,  404,   72], device='cuda:0'), tensor([ 40, 760, 340, 338], device='cuda:0'), tensor([  818,  5751,  2750,  2185,    11, 2309




In [35]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)
    print(batch)
    print('*********************')
    print('*********************')
    print('*********************')
    print('*********************')
    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]
    print(batch)
    if epoch == 1:
        break


0it [00:00, ?it/s]

0
{'label': [tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(

1it [00:07,  7.73s/it]

{'label': [tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1,

1it [00:16, 16.19s/it]

{'label': [tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1, device='cuda:0'), tensor(0, device='cuda:0'), tensor(0, device='cuda:0'), tensor(1,




In [36]:

batch.keys()


dict_keys(['label', 'input_ids', 'query', 'response'])


#### Compute sentiment score


In [37]:

batch["query"]


['This movie is poorly conceived',
 'i liked this movie',
 'The first noticeable problem about',
 'When I spotted that Noah Wyle',
 'Modern viewers know this',
 'Evidently lots',
 'There are some redeeming qualities',
 'This is a very strange',
 'good lord!',
 'This is',
 'You know,',
 'This is part one of',
 'I mean of all the obscure',
 'Okay first of all',
 'Not that many films have truly',
 "I don't know what it",
 'Cary Grant, Douglas Fairbanks',
 'I rented this DVD having seen',
 "I'm watching this on the Star",
 'The title should have',
 "It's a",
 'OK, not possibly, honestly',
 'The folks at Disney',
 'Ugh',
 'Terrific film with',
 "I'm a pretty old",
 'I thoroughly enjoyed this',
 'A movie you start watching',
 'The Wayward Cloud is a',
 'Cavemen was by',
 'During the Civil War',
 'On the back burner',
 'THE O',
 'There are just',
 'Shazbot,',
 'I suppose JEDI',
 'Brian Yuzna',
 '1st watched 5/',
 'When I was',
 'I was having just',
 'Neatly skipping over',
 'Kudos',
 'There i

In [38]:

batch["response"]


['---Monster Hunter is supposed to',
 ', even though it was released in',
 " Tonto: he was fuelling the characters' personalities at a young age",
 ' had made an early cameo appearance at',
 " is a critic's I think",
 ' of good plot holes, with a',
 ' that make Paterno work in a',
 ' film which can be very',
 ' For the first time, even',
 ' because a compass and compass (another form of sound cancellation) is used to',
 ' it wasn\'t half bad, but it had it all."<|endoftext|>',
 ' the five sections of the film. There are several humorous remarks that border',
 ' French and Irish myths that',
 ' - I wanted to see the full length',
 ' had a "Stalker" moment," but Seidl does done',
 ' is about that feature, but I like Bens',
 ', Juan Diego, Becky Hegar, Dee Dee Richards Areha Fisher,',
 ' one episode and considered purchasing it, so I watched it one',
 " Buddy TV channel. Jones' production is absolutely fantastic. All of the actors",
 ' been "Home Alone 2"',
 ' bold move and seems to demon

In [39]:

texts = [ q + r for q, r in zip(batch["query"], batch["response"]) ]


In [40]:

texts


['This movie is poorly conceived---Monster Hunter is supposed to',
 'i liked this movie, even though it was released in',
 "The first noticeable problem about Tonto: he was fuelling the characters' personalities at a young age",
 'When I spotted that Noah Wyle had made an early cameo appearance at',
 "Modern viewers know this is a critic's I think",
 'Evidently lots of good plot holes, with a',
 'There are some redeeming qualities that make Paterno work in a',
 'This is a very strange film which can be very',
 'good lord! For the first time, even',
 'This is because a compass and compass (another form of sound cancellation) is used to',
 'You know, it wasn\'t half bad, but it had it all."<|endoftext|>',
 'This is part one of the five sections of the film. There are several humorous remarks that border',
 'I mean of all the obscure French and Irish myths that',
 'Okay first of all - I wanted to see the full length',
 'Not that many films have truly had a "Stalker" moment," but Seidl doe

In [41]:

pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
pipe_outputs


[[{'label': 'NEGATIVE', 'score': 2.2946932315826416},
  {'label': 'POSITIVE', 'score': -2.7317020893096924}],
 [{'label': 'NEGATIVE', 'score': -1.925742506980896},
  {'label': 'POSITIVE', 'score': 2.2019805908203125}],
 [{'label': 'NEGATIVE', 'score': -0.18878592550754547},
  {'label': 'POSITIVE', 'score': -0.011356303468346596}],
 [{'label': 'NEGATIVE', 'score': -0.01333677675575018},
  {'label': 'POSITIVE', 'score': -0.11825768649578094}],
 [{'label': 'NEGATIVE', 'score': -1.9098631143569946},
  {'label': 'POSITIVE', 'score': 2.153944730758667}],
 [{'label': 'NEGATIVE', 'score': -1.0707851648330688},
  {'label': 'POSITIVE', 'score': 1.1191564798355103}],
 [{'label': 'NEGATIVE', 'score': -1.4140279293060303},
  {'label': 'POSITIVE', 'score': 1.6278727054595947}],
 [{'label': 'NEGATIVE', 'score': -1.9853274822235107},
  {'label': 'POSITIVE', 'score': 2.2685697078704834}],
 [{'label': 'NEGATIVE', 'score': -1.6584359407424927},
  {'label': 'POSITIVE', 'score': 1.8835939168930054}],
 [{'l

In [42]:

rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]
rewards


[tensor(-2.7317),
 tensor(2.2020),
 tensor(-0.0114),
 tensor(-0.1183),
 tensor(2.1539),
 tensor(1.1192),
 tensor(1.6279),
 tensor(2.2686),
 tensor(1.8836),
 tensor(-0.2499),
 tensor(1.4789),
 tensor(1.3212),
 tensor(0.2753),
 tensor(1.0822),
 tensor(0.8769),
 tensor(1.6423),
 tensor(0.7947),
 tensor(1.6409),
 tensor(2.6443),
 tensor(-0.7279),
 tensor(2.3880),
 tensor(-1.0589),
 tensor(2.3231),
 tensor(-1.1874),
 tensor(2.7740),
 tensor(-0.5302),
 tensor(2.6618),
 tensor(-0.1971),
 tensor(-0.1593),
 tensor(1.9095),
 tensor(-0.0538),
 tensor(1.9543),
 tensor(-1.1134),
 tensor(0.5438),
 tensor(-1.5954),
 tensor(0.1579),
 tensor(-1.8283),
 tensor(1.7351),
 tensor(-0.1821),
 tensor(-1.6691),
 tensor(-1.7416),
 tensor(0.7545),
 tensor(-1.8997),
 tensor(1.0599),
 tensor(-2.6420),
 tensor(0.3579),
 tensor(1.1189),
 tensor(-0.8248),
 tensor(2.8067),
 tensor(-2.7147),
 tensor(-2.6384),
 tensor(-2.2516),
 tensor(0.0565),
 tensor(1.4536),
 tensor(1.7251),
 tensor(2.6113),
 tensor(-0.4088),
 tensor

In [43]:

len(rewards)


128

In [44]:

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors, 
                     response_tensors, 
                     rewards
    )
    ppo_trainer.log_stats(stats, batch, rewards)
    


0it [00:00, ?it/s]

0


1it [00:09,  9.49s/it]

1


[... progress output for batches 2 to 120 omitted; each step took roughly 10 to 10.7 s/it ...]


121it [21:00, 10.41s/it]

121


122it [21:10, 10.26s/it]

122


123it [21:20, 10.28s/it]

123


124it [21:30, 10.26s/it]

124


125it [21:41, 10.25s/it]

125


126it [21:51, 10.34s/it]

126


127it [22:02, 10.36s/it]

127


128it [22:13, 10.51s/it]

128


129it [22:23, 10.53s/it]

129


130it [22:33, 10.37s/it]

130


131it [22:43, 10.19s/it]

131


132it [22:53, 10.31s/it]

132


133it [23:04, 10.31s/it]

133


134it [23:14, 10.42s/it]

134


135it [23:25, 10.35s/it]

135


136it [23:35, 10.47s/it]

136


137it [23:46, 10.45s/it]

137


138it [23:56, 10.30s/it]

138


139it [24:06, 10.36s/it]

139


140it [24:17, 10.42s/it]

140


141it [24:27, 10.47s/it]

141


142it [24:38, 10.42s/it]

142


143it [24:48, 10.50s/it]

143


144it [24:59, 10.59s/it]

144


145it [25:09, 10.34s/it]

145


146it [25:20, 10.48s/it]

146


147it [25:30, 10.56s/it]

147


148it [25:41, 10.45s/it]

148


149it [25:51, 10.47s/it]

149


150it [26:01, 10.38s/it]

150


151it [26:12, 10.36s/it]

151


152it [26:23, 10.51s/it]

152


153it [26:32, 10.27s/it]

153


154it [26:43, 10.32s/it]

154


155it [26:53, 10.30s/it]

155


156it [27:03, 10.37s/it]

156


157it [27:14, 10.55s/it]

157


158it [27:25, 10.54s/it]

158


159it [27:35, 10.51s/it]

159


160it [27:46, 10.55s/it]

160


161it [27:56, 10.48s/it]

161


162it [28:07, 10.43s/it]

162


163it [28:17, 10.45s/it]

163


164it [28:28, 10.62s/it]

164


165it [28:39, 10.64s/it]

165


166it [28:49, 10.59s/it]

166


167it [29:00, 10.56s/it]

167


168it [29:10, 10.48s/it]

168


169it [29:21, 10.51s/it]

169


170it [29:31, 10.42s/it]

170


171it [29:41, 10.30s/it]

171


172it [29:52, 10.48s/it]

172


173it [30:03, 10.55s/it]

173


174it [30:13, 10.38s/it]

174


175it [30:23, 10.41s/it]

175


176it [30:33, 10.37s/it]

176


177it [30:44, 10.47s/it]

177


178it [30:54, 10.34s/it]

178


179it [31:05, 10.47s/it]

179


180it [31:15, 10.37s/it]

180


181it [31:26, 10.59s/it]

181


182it [31:37, 10.59s/it]

182


183it [31:47, 10.54s/it]

183


184it [31:57, 10.36s/it]

184


185it [32:07, 10.36s/it]

185


186it [32:18, 10.31s/it]

186


187it [32:28, 10.39s/it]

187


188it [32:39, 10.50s/it]

188


189it [32:49, 10.40s/it]

189


190it [32:59, 10.38s/it]

190


191it [33:10, 10.39s/it]

191


192it [33:20, 10.29s/it]

192


193it [33:30, 10.32s/it]

193


194it [33:40, 10.42s/it]


In [45]:

torch.cuda.get_device_name(0)


'NVIDIA A30'


One can observe how the model starts to generate more positive outputs after a few optimisation steps.

Note: Inspecting the KL-divergence will probably show that, at this point, the model has not yet converged to the target KL-divergence. Getting there would require longer training or a higher initial KL coefficient.
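The "target KL-divergence" and "initial coefficient" above refer to an adaptive KL penalty. The sketch below is an assumption modeled on the adaptive controller described by Ziegler et al. for RL fine-tuning of language models, not a copy of TRL's code. The coefficient grows when the measured KL exceeds the target and shrinks when it falls below. In the TRL versions this notebook targets, `PPOConfig` exposes comparable knobs (`init_kl_coef`, `target`, `horizon`); check your installed version. The class `AdaptiveKLController`, the helper `penalized_reward` and all numbers in the usage lines are illustrative.

```python
class AdaptiveKLController:
    """Hypothetical sketch of an adaptive KL penalty coefficient."""

    def __init__(self, init_kl_coef: float, target: float, horizon: float) -> None:
        self.value = init_kl_coef  # current KL penalty coefficient (beta)
        self.target = target       # desired KL(policy || reference)
        self.horizon = horizon     # larger horizon -> slower adaptation

    def update(self, current_kl: float, n_steps: int) -> float:
        # Relative deviation from the target, clipped so one update stays small
        proportional_error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        # KL above target -> multiplier > 1 (stronger penalty); below -> multiplier < 1
        self.value *= 1.0 + proportional_error * n_steps / self.horizon
        return self.value


def penalized_reward(score: float, kl: float, beta: float) -> float:
    # Sequence-level view of the optimised reward: sentiment score minus beta * KL,
    # where `kl` is the KL estimate summed over the response tokens
    return score - beta * kl


ctl = AdaptiveKLController(init_kl_coef=0.2, target=6.0, horizon=10000)
low = ctl.update(current_kl=2.0, n_steps=256)    # KL below target -> beta shrinks
high = ctl.update(current_kl=12.0, n_steps=256)  # KL above target -> beta grows
print(low, high)
```

With a small starting coefficient and a short run, beta moves only slightly per step, which is consistent with the note that reaching the target may need longer training or a higher initial coefficient.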



Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model against the model before optimisation.


In [46]:

#### get a batch from the dataset
bs                 = 16
game_data          = dict()


In [47]:

game_data 


{}

In [48]:

dataset.set_format("pandas")


In [49]:

df_batch           = dataset[:].sample(bs)
df_batch 


index,review,label,input_ids,query
9325,Well I guess I know the answer to that questio...,0,"[5779, 314]",Well I
23943,"This is an excellent, fast paced thriller by W...",1,"[1212, 318, 281, 6275, 11, 3049]","This is an excellent, fast"
10309,"Now, I flicked onto this just out of curiosity...",0,"[3844, 11, 314, 781]","Now, I fl"
14702,We tend to forget that the master/slave contex...,1,"[1135, 4327, 284, 6044, 326, 262]",We tend to forget that the
4466,"The proverb ""Never judge a book by it's cover""...",0,"[464, 36950]",The proverb
8982,I've never understood the appeal of Garbo. She...,0,"[40, 1053, 1239, 7247]",I've never understood
14943,"Hugh (Ed Harris) is a hotshot, bachelor senato...",1,"[39, 6724, 357, 7407, 10026]",Hugh (Ed Harris
16515,This particular Joe McDoakes short subject was...,1,"[1212, 1948, 5689, 1982]",This particular Joe Mc
13573,Sisters In Law is made by the same directors o...,1,"[50, 6223, 554, 3854, 318, 925, 416]",Sisters In Law is made by
16473,I was very fond of this film. It kept me guess...,1,"[40, 373, 845, 16245, 286]",I was very fond of


In [50]:

game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()


In [51]:

response_tensors_ref, response_tensors = [], []


In [52]:

#### get responses from the reference model (before) and the tuned model (after)
for i in range(bs):
    gen_len      = output_length_sampler()
    query_tensor = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)

    # generate() returns prompt + new tokens; keep only the last gen_len tokens
    output = ref_model.generate(query_tensor, max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]
    response_tensors_ref.append(output)

    output = model.generate(query_tensor, max_new_tokens=gen_len, **gen_kwargs).squeeze()[-gen_len:]
    response_tensors.append(output)
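The slice `[-gen_len:]` assumes `generate()` always adds exactly `gen_len` new tokens. With sampling, generation can stop earlier if the end-of-sequence token is drawn, and the last `gen_len` positions would then include prompt tokens. A minimal alternative sketch slices by prompt length instead. The helper `extract_response` is hypothetical, and plain Python lists of token ids stand in for the 1-D tensors (the same slicing applies to `output.squeeze()`); the new-token ids are made up.

```python
from typing import List


def extract_response(generated_ids: List[int], query_len: int) -> List[int]:
    """Drop the echoed prompt from a decoder-only generate() output.

    Hypothetical helper: slicing by prompt length stays correct even when
    generation stops early, unlike taking the last `gen_len` tokens.
    """
    return generated_ids[query_len:]


query = [5779, 314]                  # token ids for "Well I" (from the batch above)
full_output = query + [11, 22, 33]   # prompt echoed back, then only 3 new tokens (made-up ids)
gen_len = 4                          # number of new tokens that was requested

print(full_output[-gen_len:])                     # [314, 11, 22, 33] -> leaks a prompt token
print(extract_response(full_output, len(query)))  # [11, 22, 33]
```

In the loop this would correspond to `output.squeeze()[len(query_tensors[i]):]`.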


In [53]:

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]


In [54]:

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]




In [55]:

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]
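Both cells above use `output[1]["score"]`, which relies on the positive class being the second entry the sentiment pipeline returns for each text. Because `sent_kwargs` sets `function_to_apply="none"`, these scores are raw logits rather than probabilities, which is why rewards can be negative or larger than 1. A sketch of a more explicit lookup selects the score by label name. It assumes the classifier names its classes `"NEGATIVE"`/`"POSITIVE"` (check `sentiment_pipe.model.config.id2label` for your model); the helper `positive_scores` and the sample outputs are hypothetical.

```python
from typing import Any, Dict, List


def positive_scores(pipe_outputs: List[List[Dict[str, Any]]], label: str = "POSITIVE") -> List[float]:
    """Pick each text's score for `label` from text-classification outputs with all scores returned."""
    return [next(d["score"] for d in scores if d["label"] == label) for scores in pipe_outputs]


# Illustrative pipeline output for two texts (made-up logits, not from this run);
# the second entry lists the labels in the opposite order on purpose
fake_outputs = [
    [{"label": "NEGATIVE", "score": 1.1}, {"label": "POSITIVE", "score": -1.3}],
    [{"label": "POSITIVE", "score": 2.9}, {"label": "NEGATIVE", "score": -2.5}],
]
print(positive_scores(fake_outputs))  # [-1.3, 2.9]
```

In the notebook this would read `positive_scores(sentiment_pipe(texts, **sent_kwargs))`, under the label-name assumption above.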


In [56]:

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results


index,query,response (before),response (after),rewards (before),rewards (after)
0,Well I,don't know why they,"love THIS, wonderful film",-1.327288,2.887235
1,"This is an excellent, fast","-paced ride, with big action scenes",moving story with a nice idea of contemporary,2.836497,2.924836
2,"Now, I fl",inched. It was so wrong,"ocked to this hilarious show, and",-1.44626,2.764766
3,We tend to forget that the,picture itself was wish,"movie is thrilling,",-0.736429,2.090948
4,The proverb,on the Désir de Rathmeister's installation wa...,keeps a very pleasant note...well written.(19...,0.27414,2.469272
5,I've never understood,more by this point in your,this wonderful script & still encourage,1.372031,2.755003
6,Hugh (Ed Harris,", ""The Hurt Locker"") helps",) is marvelous and fictional. It's,0.467258,2.578309
7,This particular Joe Mc,Avoy came into her film as if,"Govern shone, with his Voyager light and",1.32904,2.701365
8,Sisters In Law is made by,Donna Hawley and Lauren Ridge. The two stars ...,a group of filmmakers who gave Lawrence Sim's...,1.828336,2.65118
9,I was very fond of,Korea. I love all of their eccentricities. My,"her throughout you really, now very fond of. She",2.317214,2.652236


In [57]:

print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())



mean:


rewards (before)    0.187238
rewards (after)     2.513685
dtype: float64


median:


rewards (before)    0.283274
rewards (after)     2.651708
dtype: float64
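Each row of `df_results` pairs the same query with one response from the reference model and one from the tuned model, so a paired view can complement the mean and median above. The mean of the per-query change equals the difference of the two means, but the median change and the share of queries whose reward went up add information. A small sketch with a hypothetical helper `paired_summary`, shown on a toy frame rather than this run's data:

```python
import pandas as pd


def paired_summary(df: pd.DataFrame,
                   before: str = "rewards (before)",
                   after: str = "rewards (after)") -> pd.Series:
    """Summarise per-query reward changes between two response columns."""
    delta = df[after] - df[before]  # positive -> tuned model scored higher on this query
    return pd.Series({
        "mean delta": delta.mean(),
        "median delta": delta.median(),
        "share improved": (delta > 0).mean(),
    })


# Toy frame for illustration; with the notebook's data, call paired_summary(df_results)
toy = pd.DataFrame({"rewards (before)": [-1.0, 0.5, 2.0],
                    "rewards (after)": [2.0, 2.5, 1.5]})
print(paired_summary(toy))
```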

In [58]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=False)
tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=False)



('gpt2-imdb-pos-v2/tokenizer_config.json',
 'gpt2-imdb-pos-v2/special_tokens_map.json',
 'gpt2-imdb-pos-v2/vocab.json',
 'gpt2-imdb-pos-v2/merges.txt',
 'gpt2-imdb-pos-v2/added_tokens.json',
 'gpt2-imdb-pos-v2/tokenizer.json')