## Setup experiment

### Import dependencies

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

## %pip install transformers trl wandb


In [3]:

## Python >= 3.8  


In [4]:

# !pip install transformers
## !pip install wandb
# !pip install pandas
#!pip install datasets
#!pip install accelerate
# !pip install tyro
#!pip install trl

In [5]:
import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
comet_ml is installed but `COMET_API_KEY` is not set.


### Configuration

In [6]:

config = PPOConfig(
    model_name    = "lvwerra/gpt2-imdb", 
    learning_rate = 1.41e-5,
    log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}


In [7]:
#enable wandb
wandb.init()




[34m[1mwandb[0m: Currently logged in as: [33mturne292[0m ([33mdanielle_turner[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Load data and models

### Load IMDB dataset
The IMDB dataset contains 50k movie review annotated with "positive"/"negative" feedback indicating the sentiment.  We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`.

In [8]:
def build_dataset(config, dataset_name="turne292/tldr_sentiment_data_small_v2", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # load imdb with datasets
    
    ds = load_dataset(dataset_name, split="train")
    print(len(ds))
    ds = ds.rename_columns({"text": "prompt"})
    print(len(ds))
    ds = ds.filter(lambda x: len(x["label"]), batched=False)
    print(len(ds))

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode( sample["label"]    )[: input_size()]
        sample["query"]     = tokenizer.decode( sample["input_ids"] )
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds
    

In [9]:

dataset = build_dataset(config)
print(len(dataset))

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
    

Found cached dataset csv (C:/Users/danda/.cache/huggingface/datasets/turne292___csv/turne292--tldr_sentiment_data_small_v2-4916c9194fdcb323/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Loading cached processed dataset at C:\Users\danda\.cache\huggingface\datasets\turne292___csv\turne292--tldr_sentiment_data_small_v2-4916c9194fdcb323\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1\cache-8ff1f257cc15679d.arrow
Loading cached processed dataset at C:\Users\danda\.cache\huggingface\datasets\turne292___csv\turne292--tldr_sentiment_data_small_v2-4916c9194fdcb323\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1\cache-0d1bd76e00b6b4fb.arrow


47217
47217
47217
47217


In [10]:
dataset[55]

{'prompt': ' r/relationships\nTITLE: How do I (F22) tell an oblivious roommate (F22) that she is gross?\nPOST: Not sure if this is the right place to post. My roommate just moved in a few weeks ago and I am having trouble...\n\nI already told her directly "Let\'s please make sure all the food is rinsed off the dishes and utensils before we put them in the rack." and "I cleaned the counters today, so let\'s clean up after ourselves every time we cook so we can maintain it."\n\nBut every day there is meat juice, pieces of chicke',
 'label': 'negative',
 'input_ids': tensor([31591]),
 'query': 'negative'}

### Load pre-trained GPT2 language models

We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model.

In [11]:
model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token


  torch.utils._pytree._register_pytree_node(


### Initialize PPOTrainer
The `PPOTrainer` takes care of device placement and optimization later on:

In [12]:

ppo_trainer = PPOTrainer(
                   config, 
                   model, 
                   ref_model, 
                   tokenizer, 
                   dataset       = dataset, 
                   data_collator = collator
)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

### Load BERT classifier
We load a BERT classifier fine-tuned on the IMDB dataset.

In [13]:

device = ppo_trainer.accelerator.device

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    
sentiment_pipe = pipeline("sentiment-analysis", model="turne292/distilbert-tldr-fine-small", device=device)

Downloading config.json:   0%|          | 0.00/478 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading pytorch_model.bin:   0%|          | 0.00/266M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at turne292/distilbert-tldr-fine-small and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading tokenizer_config.json:   0%|          | 0.00/372 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model.

In [14]:

text = "this movie was really bad!!"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'LABEL_0', 'score': -0.05962694436311722},
  {'label': 'LABEL_1', 'score': 0.0535535104572773}]]

In [15]:

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'LABEL_0', 'score': -0.052025485783815384},
  {'label': 'LABEL_1', 'score': 0.04946254938840866}]]

### Generation settings
For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length.

In [16]:

gen_kwargs = {
         "min_length":   -1, 
         "top_k":       0.0, 
         "top_p":       1.0, 
         "do_sample":  True, 
         "pad_token_id": tokenizer.eos_token_id
}


## Optimize model

### Training loop

The training loop consists of the following main steps:
1. Get the query responses from the policy network (GPT-2)
2. Get sentiments for query/responses from BERT
3. Optimize policy with PPO using the (query, response, reward) triplet

**Training time**

This step takes **~2h** on a V100 GPU with the above specified settings.

In [17]:

output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

wandb.watch(model, log_freq=100)
wandb.watch(ref_model, log_freq=100)

[]

In [18]:

generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [19]:

## ppo_trainer.config.steps = 100    ## 20,000
ppo_trainer.config.steps


20000

In [20]:
'''
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    print(batch)
    input()
'''

'\nfor epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n    print(batch)\n    input()\n'

In [21]:


for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors, 
                     response_tensors, 
                     rewards
    )
    ppo_trainer.log_stats(stats, batch, rewards)
    

0it [00:00, ?it/s]

0


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [00:22, 22.40s/it]

1


2it [00:44, 22.19s/it]

2


3it [01:04, 21.33s/it]

3


4it [01:27, 21.87s/it]

4


5it [01:52, 22.91s/it]

5


6it [02:14, 22.77s/it]

6


7it [02:37, 22.69s/it]

7


8it [02:58, 22.14s/it]

8


9it [03:21, 22.62s/it]

9


10it [03:42, 22.15s/it]

10


11it [04:05, 22.17s/it]

11


12it [04:25, 21.66s/it]

12


13it [04:46, 21.33s/it]

13


14it [05:07, 21.18s/it]

14


15it [05:29, 21.43s/it]

15


16it [05:49, 21.17s/it]

16


17it [06:11, 21.51s/it]

17


18it [06:33, 21.53s/it]

18


19it [06:55, 21.55s/it]

19


20it [07:15, 21.22s/it]

20


21it [07:36, 21.01s/it]

21


22it [07:58, 21.30s/it]

22


23it [08:19, 21.22s/it]

23


24it [08:39, 21.03s/it]

24


25it [09:00, 21.03s/it]

25


26it [09:22, 21.37s/it]

26


27it [09:43, 21.08s/it]

27


28it [10:05, 21.49s/it]

28


29it [10:26, 21.23s/it]

29


30it [10:46, 20.87s/it]

30


31it [11:06, 20.64s/it]

31


32it [11:27, 20.80s/it]

32


33it [11:48, 20.82s/it]

33


34it [12:09, 20.96s/it]

34


35it [12:30, 20.85s/it]

35


36it [12:50, 20.59s/it]

36


37it [13:11, 20.76s/it]

37


38it [13:31, 20.56s/it]

38


39it [13:52, 20.62s/it]

39


40it [14:13, 20.61s/it]

40


41it [14:33, 20.59s/it]

41


42it [14:53, 20.51s/it]

42


43it [15:15, 20.86s/it]

43


44it [15:35, 20.59s/it]

44


45it [15:56, 20.65s/it]

45


46it [16:16, 20.60s/it]

46


47it [16:38, 20.83s/it]

47


48it [16:59, 20.90s/it]

48


49it [17:18, 20.43s/it]

49


50it [17:38, 20.28s/it]

50


51it [17:58, 20.32s/it]

51


52it [18:22, 21.23s/it]

52


53it [18:43, 21.14s/it]

53


54it [19:05, 21.35s/it]

54


55it [19:26, 21.31s/it]

55


56it [19:47, 21.22s/it]

56


57it [20:07, 20.91s/it]

57


58it [20:29, 21.12s/it]

58


59it [20:49, 20.99s/it]

59


60it [21:10, 20.81s/it]

60


61it [21:31, 20.95s/it]

61


62it [21:52, 21.04s/it]

62


63it [22:12, 20.75s/it]

63


64it [22:34, 21.03s/it]

64


65it [22:54, 20.86s/it]

65


66it [23:15, 20.73s/it]

66


67it [23:36, 20.75s/it]

67


68it [23:56, 20.71s/it]

68


69it [24:17, 20.61s/it]

69


70it [24:38, 20.75s/it]

70


71it [24:58, 20.71s/it]

71


72it [25:20, 20.97s/it]

72


73it [25:42, 21.29s/it]

73


74it [26:03, 21.15s/it]

74


75it [26:24, 21.09s/it]

75


76it [26:44, 20.97s/it]

76


77it [27:06, 21.08s/it]

77


78it [27:26, 20.98s/it]

78


79it [27:52, 22.28s/it]

79


80it [28:12, 21.78s/it]

80


81it [28:33, 21.58s/it]

81


82it [28:54, 21.39s/it]

82


83it [29:17, 21.77s/it]

83


84it [29:38, 21.40s/it]

84


85it [29:58, 21.26s/it]

85


86it [30:21, 21.71s/it]

86


87it [30:43, 21.67s/it]

87


88it [31:04, 21.46s/it]

88


89it [31:23, 20.90s/it]

89


90it [31:45, 21.18s/it]

90


91it [32:07, 21.26s/it]

91


92it [32:28, 21.20s/it]

92


93it [32:49, 21.16s/it]

93


94it [33:10, 21.16s/it]

94


95it [33:30, 20.83s/it]

95


96it [33:52, 21.14s/it]

96


97it [34:13, 21.25s/it]

97


98it [34:34, 21.20s/it]

98


99it [34:55, 21.08s/it]

99


100it [35:17, 21.34s/it]

100


101it [35:37, 20.92s/it]

101


102it [35:58, 20.90s/it]

102


103it [36:18, 20.74s/it]

103


104it [36:39, 20.75s/it]

104


105it [37:01, 21.07s/it]

105


106it [37:23, 21.27s/it]

106


107it [37:44, 21.15s/it]

107


108it [38:05, 21.09s/it]

108


109it [38:26, 21.14s/it]

109


110it [38:48, 21.39s/it]

110


111it [39:09, 21.45s/it]

111


112it [39:29, 21.05s/it]

112


113it [39:51, 21.10s/it]

113


114it [40:11, 20.92s/it]

114


115it [40:31, 20.73s/it]

115


116it [40:51, 20.34s/it]

116


117it [41:12, 20.55s/it]

117


118it [41:32, 20.45s/it]

118


119it [41:53, 20.47s/it]

119


120it [42:14, 20.72s/it]

120


121it [42:35, 20.69s/it]

121


122it [42:54, 20.42s/it]

122


123it [43:16, 20.69s/it]

123


124it [43:36, 20.72s/it]

124


125it [43:57, 20.79s/it]

125


126it [44:19, 20.88s/it]

126


127it [44:39, 20.71s/it]

127


128it [45:00, 20.86s/it]

128


129it [45:21, 20.98s/it]

129


130it [45:42, 20.76s/it]

130


131it [46:02, 20.74s/it]

131


132it [46:24, 21.04s/it]

132


133it [46:44, 20.63s/it]

133


134it [47:05, 20.74s/it]

134


135it [47:24, 20.25s/it]

135


136it [47:46, 20.81s/it]

136


137it [48:07, 20.79s/it]

137


138it [48:24, 19.90s/it]

138


139it [48:53, 22.50s/it]

139


140it [49:13, 21.79s/it]

140


141it [49:32, 20.94s/it]

141


142it [49:53, 20.94s/it]

142


143it [50:13, 20.59s/it]

143


144it [50:32, 20.16s/it]

144


145it [50:51, 19.75s/it]

145


146it [51:10, 19.63s/it]

146


147it [51:30, 19.62s/it]

147


148it [51:49, 19.64s/it]

148


149it [52:09, 19.55s/it]

149


150it [52:28, 19.58s/it]

150


151it [52:48, 19.63s/it]

151


152it [53:07, 19.32s/it]

152


153it [53:26, 19.27s/it]

153


154it [53:46, 19.38s/it]

154


155it [54:04, 18.98s/it]

155


156it [54:23, 19.17s/it]

156


157it [54:43, 19.21s/it]

157


158it [55:02, 19.24s/it]

158


159it [55:22, 19.42s/it]

159


160it [55:42, 19.59s/it]

160


161it [56:00, 19.27s/it]

161


162it [56:22, 19.96s/it]

162


163it [56:42, 20.06s/it]

163


164it [57:03, 20.20s/it]

164


165it [57:23, 20.16s/it]

165


166it [57:42, 19.98s/it]

166


167it [58:02, 19.95s/it]

167


168it [58:22, 19.95s/it]

168


169it [58:41, 19.69s/it]

169


170it [58:59, 19.16s/it]

170


171it [59:18, 19.16s/it]

171


172it [59:39, 19.54s/it]

172


173it [59:58, 19.50s/it]

173


174it [1:00:18, 19.50s/it]

174


175it [1:00:36, 19.34s/it]

175


176it [1:00:56, 19.31s/it]

176


177it [1:01:21, 21.06s/it]

177


178it [1:01:41, 20.66s/it]

178


179it [1:02:00, 20.22s/it]

179


180it [1:02:19, 19.81s/it]

180


181it [1:02:38, 19.74s/it]

181


182it [1:02:58, 19.80s/it]

182


183it [1:03:17, 19.59s/it]

183


184it [1:03:36, 19.21s/it]

184


185it [1:03:55, 19.26s/it]

185


186it [1:04:14, 19.33s/it]

186


187it [1:04:36, 19.90s/it]

187


188it [1:04:57, 20.18s/it]

188


189it [1:05:16, 20.09s/it]

189


190it [1:05:36, 20.03s/it]

190


191it [1:05:56, 19.84s/it]

191


192it [1:06:15, 19.71s/it]

192


193it [1:06:35, 19.72s/it]

193


194it [1:06:55, 19.73s/it]

194


195it [1:07:15, 19.79s/it]

195


196it [1:07:34, 19.72s/it]

196


197it [1:07:53, 19.52s/it]

197


198it [1:08:13, 19.57s/it]

198


199it [1:08:32, 19.53s/it]

199


200it [1:08:51, 19.29s/it]

200


201it [1:09:12, 19.66s/it]

201


202it [1:09:31, 19.62s/it]

202


203it [1:09:50, 19.33s/it]

203


204it [1:10:09, 19.31s/it]

204


205it [1:10:28, 19.18s/it]

205


206it [1:10:52, 20.71s/it]

206


207it [1:11:12, 20.59s/it]

207


208it [1:11:32, 20.24s/it]

208


209it [1:11:52, 20.12s/it]

209


210it [1:12:11, 19.86s/it]

210


211it [1:12:29, 19.31s/it]

211


212it [1:12:46, 18.68s/it]

212


213it [1:13:06, 18.90s/it]

213


214it [1:13:24, 18.80s/it]

214


215it [1:13:44, 19.05s/it]

215


216it [1:14:04, 19.26s/it]

216


217it [1:14:23, 19.23s/it]

217


218it [1:14:42, 19.25s/it]

218


219it [1:15:02, 19.52s/it]

219


220it [1:15:22, 19.58s/it]

220


221it [1:15:41, 19.58s/it]

221


222it [1:16:00, 19.36s/it]

222


223it [1:16:22, 20.18s/it]

223


224it [1:16:43, 20.23s/it]

224


225it [1:17:03, 20.32s/it]

225


226it [1:17:23, 20.12s/it]

226


227it [1:17:42, 19.94s/it]

227


228it [1:18:01, 19.52s/it]

228


229it [1:18:21, 19.77s/it]

229


230it [1:18:40, 19.58s/it]

230


231it [1:19:00, 19.47s/it]

231


232it [1:19:21, 20.13s/it]

232


233it [1:19:41, 19.92s/it]

233


234it [1:20:02, 20.34s/it]

234


235it [1:20:22, 20.13s/it]

235


236it [1:20:41, 19.82s/it]

236


237it [1:21:00, 19.69s/it]

237


238it [1:21:20, 19.81s/it]

238


239it [1:21:39, 19.47s/it]

239


240it [1:21:58, 19.42s/it]

240


241it [1:22:18, 19.57s/it]

241


242it [1:22:38, 19.59s/it]

242


243it [1:22:57, 19.41s/it]

243


244it [1:23:16, 19.31s/it]

244


245it [1:23:35, 19.24s/it]

245


246it [1:23:54, 19.18s/it]

246


247it [1:24:13, 19.02s/it]

247


248it [1:24:32, 18.97s/it]

248


249it [1:24:50, 18.82s/it]

249


250it [1:25:10, 19.18s/it]

250


251it [1:25:30, 19.36s/it]

251


252it [1:25:52, 20.22s/it]

252


253it [1:26:12, 20.01s/it]

253


254it [1:26:30, 19.46s/it]

254


255it [1:26:51, 20.16s/it]

255


256it [1:27:11, 19.97s/it]

256


257it [1:27:29, 19.39s/it]

257


258it [1:27:52, 20.55s/it]

258


259it [1:28:12, 20.28s/it]

259


260it [1:28:32, 20.14s/it]

260


261it [1:28:52, 20.05s/it]

261


261it [1:29:07, 20.49s/it]


KeyboardInterrupt: 

## Model inspection
Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation.

In [22]:
#### get a batch from the dataset
bs                 = 16
game_data          = dict()
dataset.set_format("pandas")
df_batch           = dataset[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,negative,»»<br /><br />This,"that is owned, it's predecessor in offering",0.047334,0.055765
1,negative,e is of the,hour on the set,0.019111,0.090814
2,negative,", but wouldn't waste",negative pickings resembles the,0.052664,0.03038
3,negative,", Harold Reicher",in general. The first,0.029283,0.037445
4,positive,"–toler, relatives ironic, music friendly chara...","characters, from Saddam to Henry; Henry being...",0.019365,0.009292
5,negative,Lod Plot's shuddering end.,oppressed people who trample on the,0.038886,0.065381
6,positive,I wouldn't think of this,reservation which breaks down in the,0.071503,0.059537
7,negative,manage to manage to ignore all,and a pleasure. ;-),0.052433,0.070395
8,positive,<br /><,rather than my rating,0.044724,0.040793
9,negative,", not to mention highly boring. This series ca...","and will pirate it, but at regular intervals ...",0.040651,0.055431


Looking at the reward mean/median of the generated sequences we observe a significant difference.

In [23]:
print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)    0.044329
rewards (after)     0.044345
dtype: float64


median:


rewards (before)    0.042688
rewards (after)     0.045970
dtype: float64

## Save model
Finally, we save the model and push it to the Hugging Face for later usage.

In [None]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

model.save_pretrained(    "./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)
tokenizer.save_pretrained("./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)


In [None]:

print(tqdm)
