## Setup experiment

### Import dependencies

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

## %pip install transformers trl wandb


In [3]:

## Python >= 3.8  


In [4]:

# !pip install transformers
# !pip install wandb
# !pip install pandas
# !pip install datasets
# !pip install accelerate
# !pip install tyro
# !pip install trl

In [5]:
import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler



### Configuration

In [6]:

config = PPOConfig(
    model_name    = "lvwerra/gpt2-imdb", 
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}


In [7]:
# enable Weights & Biases logging
wandb.init()




wandb: Currently logged in as: turne292 (danielle_turner). Use `wandb login --relogin` to force relogin


## Load data and models

### Load dataset
The `turne292/tldr_sentiment_data_small_v2` dataset contains 47,217 Reddit posts (subreddit, title and post body, as in the example below), each labelled "positive" or "negative". We load its `train` split with `load_dataset` and drop any empty texts. Then we tokenize each text and truncate it to a random length of a few tokens with `LengthSampler`; the resulting prefix is the query (prompt) used for generation.

In [8]:
def build_dataset(config, dataset_name="turne292/tldr_sentiment_data_small_v2", input_min_text_length=2, input_max_text_length=8):
    """
    Build the dataset for training. This loads `dataset_name` with `load_dataset`;
    customize this function to train the model on your own dataset.

    Args:
        config (`PPOConfig`):
            Training config; `config.model_name` selects the tokenizer.
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`), input_max_text_length (`int`):
            Bounds passed to `LengthSampler` for the random query length in tokens.

    Returns:
        dataset (`datasets.Dataset`):
            The tokenized dataset with added `input_ids` and `query` columns, in torch format.
    """
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # load the dataset from the Hugging Face Hub and drop empty texts
    ds = load_dataset(dataset_name, split="train")
    ds = ds.filter(lambda x: len(x["text"]) > 0, batched=False)
    print(len(ds))

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # keep a random-length prefix of the tokenized text as the query
        sample["input_ids"] = tokenizer.encode(sample["text"])[: input_size()]
        sample["query"]     = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds
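`LengthSampler` is what turns each post into a short prompt. The stdlib-only sketch below mimics that step with a hypothetical `LengthSamplerSketch` class and `truncate_to_prompt` helper; it assumes, as in the TRL versions we are aware of, that lengths are drawn uniformly from `range(min_value, max_value)`, so the upper bound is exclusive. The token ids are toy values.

```python
import random
from typing import List


class LengthSamplerSketch:
    """Hypothetical stand-in for trl.core.LengthSampler.

    Assumption: lengths are drawn uniformly from range(min_value, max_value),
    so max_value itself is never returned.
    """

    def __init__(self, min_value: int, max_value: int) -> None:
        self.values = list(range(min_value, max_value))

    def __call__(self) -> int:
        return random.choice(self.values)


def truncate_to_prompt(token_ids: List[int], sampler: LengthSamplerSketch) -> List[int]:
    """Keep a random-length prefix of the token ids as the query."""
    return token_ids[: sampler()]


sampler = LengthSamplerSketch(2, 8)      # same bounds as build_dataset's defaults
token_ids = list(range(100, 110))        # toy ids standing in for a tokenized post
query_ids = truncate_to_prompt(token_ids, sampler)
```

Because the cut is made on tokens, a query can end mid-word, as with `r/relations` or `r/tifu\nTIT` in the inspection table further down.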
    

In [9]:

dataset = build_dataset(config)
print(len(dataset))

def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
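The `collator` transposes a list of per-example dicts into a dict of lists, so variable-length queries stay as a list (one tensor per example) instead of being stacked into a padded tensor. A toy illustration of the same logic, with plain Python lists in place of tensors and made-up records (`collator_sketch` is a hypothetical name):

```python
def collator_sketch(data):
    """Same logic as `collator`: list of per-example dicts -> dict of lists."""
    return {key: [d[key] for d in data] for key in data[0]}


toy_batch = collator_sketch([
    {"query": " r/tifu", "input_ids": [1, 2, 3]},
    {"query": " r/legaladvice", "input_ids": [4, 5]},
])
# toy_batch["query"]     -> [" r/tifu", " r/legaladvice"]
# toy_batch["input_ids"] -> [[1, 2, 3], [4, 5]]  (ragged lengths are kept)
```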
    

Found cached dataset csv (C:/Users/danda/.cache/huggingface/datasets/turne292___csv/turne292--tldr_sentiment_data_small_v2-4916c9194fdcb323/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)



47217



47217


In [10]:
dataset[55]

{'text': ' r/relationships\nTITLE: How do I (F22) tell an oblivious roommate (F22) that she is gross?\nPOST: Not sure if this is the right place to post. My roommate just moved in a few weeks ago and I am having trouble...\n\nI already told her directly "Let\'s please make sure all the food is rinsed off the dishes and utensils before we put them in the rack." and "I cleaned the counters today, so let\'s clean up after ourselves every time we cook so we can maintain it."\n\nBut every day there is meat juice, pieces of chicke',
 'label': 'negative',
 'input_ids': tensor([  374,    14, 39468,  5748]),
 'query': ' r/relationships'}

### Load pre-trained GPT2 language models

We load the GPT2 model with a value head, and the tokenizer. The model is loaded twice: the first copy is optimized, while the second serves as a frozen reference for computing the KL divergence from the starting point. During PPO training this KL term acts as a penalty on the reward, which keeps the optimized model from drifting too far from the original language model.
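A minimal sketch of how the classifier score and the KL penalty can be combined per token, modelled on our understanding of TRL's `PPOTrainer`. The function name, the toy log-probabilities and `beta=0.2` are illustrative assumptions; as far as we know, TRL also adapts the coefficient during training by default (`adap_kl_ctrl`).

```python
from typing import List


def kl_penalized_rewards(
    score: float,
    logprobs: List[float],
    ref_logprobs: List[float],
    beta: float = 0.2,
) -> List[float]:
    """Per-token rewards for one response.

    Each generated token gets -beta * (log pi(token) - log pi_ref(token)),
    a per-token estimate of the KL term; the classifier score is added on
    the last token only.
    """
    if not logprobs or len(logprobs) != len(ref_logprobs):
        raise ValueError("need equal-length, non-empty log-prob lists")
    rewards = [-beta * (lp - ref) for lp, ref in zip(logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards


# Toy log-probs for a 3-token response: the policy is more confident than the
# reference model on tokens 1 and 3, so those tokens are penalized.
rewards = kl_penalized_rewards(
    score=2.5, logprobs=[-1.0, -0.5, -0.2], ref_logprobs=[-1.2, -0.5, -0.7]
)
```

The penalty grows with how far the policy's log-probabilities move above the reference model's, which is what discourages drift while the score pulls toward positive text.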

In [11]:
model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token




### Initialize PPOTrainer
The `PPOTrainer` takes care of device placement and optimization later on:

In [12]:

ppo_trainer = PPOTrainer(
                   config, 
                   model, 
                   ref_model, 
                   tokenizer, 
                   dataset       = dataset, 
                   data_collator = collator
)


### Load DistilBERT classifier
We load a DistilBERT classifier (`lvwerra/distilbert-imdb`) fine-tuned on the IMDB movie-review dataset; it serves as the reward model.

In [13]:

device = ppo_trainer.accelerator.device

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

The classifier outputs raw logits for the negative and positive classes (`"function_to_apply": "none"` in `sent_kwargs` turns off the softmax). We use the positive-class logit as the reward signal for the language model.
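The training loop reads the reward as `output[1]["score"]`, which assumes the POSITIVE entry is always second, as it is in the outputs shown here. A hedged, stdlib-only sketch of a more defensive variant looks the entry up by label instead; the helper names are hypothetical, and the scores are copied from the "really bad" example output that follows. The softmax helper is only for inspection, since the loop uses the raw logit.

```python
import math
from typing import Dict, List

# Pipeline output for "this movie was really bad!!" (scores copied from the cell output)
bad_review: List[Dict] = [
    {"label": "NEGATIVE", "score": 2.3350486755371094},
    {"label": "POSITIVE", "score": -2.726576328277588},
]


def positive_logit(scores: List[Dict]) -> float:
    """Return the POSITIVE-class logit, looked up by label rather than position."""
    for entry in scores:
        if entry["label"] == "POSITIVE":
            return float(entry["score"])
    raise KeyError("no POSITIVE entry in pipeline output")


def positive_probability(scores: List[Dict]) -> float:
    """Softmax over the raw logits, returning P(POSITIVE)."""
    logits = {entry["label"]: entry["score"] for entry in scores}
    z = max(logits.values())
    exps = {label: math.exp(value - z) for label, value in logits.items()}
    return exps["POSITIVE"] / sum(exps.values())


reward = positive_logit(bad_review)        # -2.726576328277588
prob = positive_probability(bad_review)    # roughly 0.006
```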

In [14]:

text = "this movie was really bad!!"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'NEGATIVE', 'score': 2.3350486755371094},
  {'label': 'POSITIVE', 'score': -2.726576328277588}]]

In [15]:

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'NEGATIVE', 'score': -2.294790029525757},
  {'label': 'POSITIVE', 'score': 2.557039737701416}]]

### Generation settings
For response generation we use plain sampling: top-k and nucleus sampling are switched off (`top_k=0`, `top_p=1.0`), and no minimum length is enforced (`min_length=-1`).

In [16]:

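gen_kwargs = {
    "min_length":   -1,     # no minimum length
    "top_k":        0,      # disable top-k filtering
    "top_p":        1.0,    # disable nucleus (top-p) filtering
    "do_sample":    True,   # sample instead of greedy decoding
    "pad_token_id": tokenizer.eos_token_id,
}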
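To see why `top_k=0` and `top_p=1.0` amount to sampling from the model's full next-token distribution, here is a simplified, stdlib-only version of top-k followed by nucleus filtering. It is modelled loosely on, not copied from, the logits processing in `transformers`; the function names are hypothetical.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    z = max(logits)
    exps = [math.exp(x - z) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filtered_distribution(logits: List[float], top_k: int = 0, top_p: float = 1.0) -> List[float]:
    """Apply top-k, then nucleus (top-p) filtering, and renormalise.

    top_k=0 and top_p=1.0 disable both filters, so the result equals softmax(logits).
    """
    probs = softmax(logits)
    # token indices sorted from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = order[:top_k] if top_k > 0 else order

    if top_p < 1.0:
        mass = sum(probs[i] for i in keep)
        nucleus, cumulative = [], 0.0
        for i in keep:
            nucleus.append(i)
            cumulative += probs[i] / mass  # probabilities renormalised after top-k
            if cumulative >= top_p:
                break  # smallest prefix whose mass reaches top_p
        keep = nucleus

    kept = set(keep)
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


def sample_token(logits: List[float], top_k: int = 0, top_p: float = 1.0) -> int:
    """Draw one token index from the filtered distribution."""
    dist = filtered_distribution(logits, top_k=top_k, top_p=top_p)
    return random.choices(range(len(dist)), weights=dist, k=1)[0]
```

With the defaults every token keeps its original probability, so low-probability continuations can still be sampled; this keeps the generated responses close to what the policy actually assigns probability to, which PPO's log-probability ratios rely on.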


## Optimize model

### Training loop

The training loop consists of the following main steps:
1. Get the query responses from the policy network (GPT-2)
2. Get sentiment scores for the query/response pairs from the DistilBERT classifier
3. Optimize the policy with PPO using the (query, response, reward) triplets
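Step 3 hides most of the machinery inside `ppo_trainer.step`, which, as we understand TRL's implementation, also fits the value head, estimates advantages and applies the KL penalty. The core of the policy update is PPO's clipped surrogate objective. A stdlib-only sketch of that single term follows; the function name is hypothetical, and `clip_range=0.2` is an illustrative value (we believe it matches TRL's `cliprange` default), not read from this notebook's config.

```python
import math
from typing import List


def ppo_clipped_loss(
    logprobs: List[float],
    old_logprobs: List[float],
    advantages: List[float],
    clip_range: float = 0.2,
) -> float:
    """Mean PPO clipped surrogate loss (lower is better) over a batch of tokens."""
    if not logprobs or not (len(logprobs) == len(old_logprobs) == len(advantages)):
        raise ValueError("inputs must be non-empty and of equal length")
    losses = []
    for lp, old_lp, adv in zip(logprobs, old_logprobs, advantages):
        ratio = math.exp(lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # PPO maximises min(ratio * A, clipped * A); as a loss that is the max of the negatives
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)


# Unchanged policy: ratio = 1, so the loss is just the mean negative advantage
print(ppo_clipped_loss([-1.0, -2.0], [-1.0, -2.0], [1.0, -0.5]))  # -0.25
```

Clipping caps how much a single update can gain from moving a token's probability, which complements the KL penalty in keeping the policy near where it started.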

**Training time**

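This step takes **~2h** on a V100 GPU with the settings above; in the run recorded below, one pass over the dataloader (368 batches) took 1 h 53 min, about 18.5 s per batch.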

In [17]:

output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

wandb.watch(model, log_freq=100)
wandb.watch(ref_model, log_freq=100)

[]

In [18]:

# same sampling settings as gen_kwargs; max_new_tokens is set per query in the training loop
generation_kwargs = {
    "min_length":   -1,
    "top_k":        0,
    "top_p":        1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [19]:

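# ppo_trainer.config.steps = 100    # default: 20,000
# note: the training loop below makes one pass over ppo_trainer.dataloader and does not read this value
ppo_trainer.config.steps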


20000


In [24]:


for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):  # `step` is the batch index
    query_tensors = batch["input_ids"]
    print(step)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        # keep the last gen_len tokens of the prompt + response sequence
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score: output[1] is the POSITIVE entry, whose logit is the reward
    texts        = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards      = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
    

368it [1:53:15, 18.47s/it]
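The loop keeps `response.squeeze()[-gen_len:]`, i.e. the last `gen_len` tokens of the sequence returned by `ppo_trainer.generate`, which (as that slice implies) contains the prompt followed by the response. If generation stops early at an end-of-sequence token, the output is shorter than `len(query) + gen_len`, and that slice also picks up prompt tokens. A hypothetical alternative is to cut off the prompt by its length; the sketch below shows the difference with plain lists and toy ids (50256 is GPT-2's `<|endoftext|>` id).

```python
from typing import List


def response_only(output_ids: List[int], query_len: int) -> List[int]:
    """Drop the prompt, assuming output_ids = prompt tokens + generated tokens."""
    return output_ids[query_len:]


query_ids = [11, 12, 13, 14]               # toy 4-token prompt
output_ids = query_ids + [21, 22, 50256]   # generation hit EOS after 3 new tokens
gen_len = 6                                # max_new_tokens that was requested

tail = output_ids[-gen_len:]               # [12, 13, 14, 21, 22, 50256]: includes prompt tokens
generated = response_only(output_ids, len(query_ids))  # [21, 22, 50256]
```

With the notebook's tensors the equivalent would be `response.squeeze()[len(query):]`; how much this matters depends on how often responses end before reaching `gen_len` tokens.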


## Model inspection
Let's inspect some examples from the dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimization.

In [22]:
#### get a batch from the dataset
bs                 = 16
game_data          = dict()
dataset.set_format("pandas")
df_batch           = dataset[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



| | query | response (before) | response (after) | rewards (before) | rewards (after) |
|---:|---|---|---|---:|---:|
| 0 | `r/legaladvice` | ), this film attempts to portray how marijuana... | rails through an extremely inventive and ente... | -2.338746 | 2.737933 |
| 1 | `r/relationships` | `between cornerstones.<br /><br />The` | : good, multiplayer, good games, fun together | 0.724979 | 2.726515 |
| 2 | `r/tifu\n` | HBO's Mr. Kavod | \`s superb storyline and excellent music. | -0.600709 | 2.803211 |
| 3 | `r/` | h number together (or so I | g films.) Bob Dylan, and | -0.072393 | 0.78945 |
| 4 | `r/relations` | of the same genius in an independent | with the original Square inhabited, fulfilling | 0.992978 | 0.455365 |
| 5 | `r/relationships\n` | The movie is about hazing and a guy commits su... | This is a great production, it is great and it... | -0.995547 | 2.894707 |
| 6 | `r/relationships\nTITLE` | `LE (10/2/13) * 18 pages<\|endoftext\|>` | 3 is epic, absorbing story that will allow yo... | -0.016736 | 2.740645 |
| 7 | `r/tifu\nTITLE` | REFERENCE: Spitfires, Spec Ops | Creature/titles are amazing, amazing and | -0.667457 | 2.728873 |
| 8 | `r/tifu\nTIT` | Y favorite likes fanduel 03/12 | stood out jerky and amazingly well presented ... | 1.412043 | 2.479338 |
| 9 | `r/relations` | hip among her boyfriend, dean, etc. also, in this | hips and put the whole movie to good use in it... | 0.547193 | 2.114477 |


On this batch of 16 samples, the mean reward rises from -0.14 before training to 2.39 after, and the median from -0.04 to 2.71.

In [23]:
print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)   -0.144184
rewards (after)     2.385100
dtype: float64


median:


rewards (before)   -0.037348
rewards (after)     2.712311
dtype: float64

## Save model
Finally, we save the model and tokenizer locally to `./RLHFmodel/gpt2-tldr-pos` for later use. The commented-out lines show how to push them to the Hugging Face Hub instead.

In [None]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

model.save_pretrained(    "./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)
tokenizer.save_pretrained("./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)


