## Setup experiment

### Import dependencies

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

## %pip install transformers trl wandb


In [3]:

## Python >= 3.8  


In [4]:

# !pip install transformers
## !pip install wandb
# !pip install pandas
#!pip install datasets
#!pip install accelerate
# !pip install tyro
#!pip install trl

In [5]:
import torch
from tqdm import tqdm
import pandas as pd
import wandb
import os

tqdm.pandas()

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


### Configuration

In [6]:

config = PPOConfig(
    model_name    = "lvwerra/gpt2-imdb", 
    learning_rate = 1.41e-5,
    ## log_with      = "wandb",
)

sent_kwargs = {
         "return_all_scores": True, 
         "function_to_apply": "none", 
         "batch_size": 16
}


In [7]:

## wandb.init()

wandb.init(mode="disabled") 
os.environ['WANDB_DISABLED'] = 'true'



## Load data and models

### Load IMDB dataset
The IMDB dataset contains 50k movie review annotated with "positive"/"negative" feedback indicating the sentiment.  We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`.

In [10]:
def build_dataset(config, dataset_name="turne292/tldr-sentiment", input_min_text_length=2, input_max_text_length=8):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """
    tokenizer           = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # load imdb with datasets
    
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"prompt": "label"})
    ds = ds.filter(lambda x: len(x["label"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode( sample["label"]    )[: input_size()]
        sample["query"]     = tokenizer.decode( sample["input_ids"] )
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds
    

In [11]:

dataset = build_dataset(config)


def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])
    

Found cached dataset csv (C:/Users/danda/.cache/huggingface/datasets/turne292___csv/turne292--tldr-sentiment-ab0e8d3f664f5065/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


Filter:   0%|          | 0/116722 [00:00<?, ? examples/s]

Map:   0%|          | 0/116528 [00:00<?, ? examples/s]

### Load pre-trained GPT2 language models

We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model.

In [12]:
model     = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token


  torch.utils._pytree._register_pytree_node(


Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

### Initialize PPOTrainer
The `PPOTrainer` takes care of device placement and optimization later on:

In [13]:

ppo_trainer = PPOTrainer(
                   config, 
                   model, 
                   ref_model, 
                   tokenizer, 
                   dataset       = dataset, 
                   data_collator = collator
)


### Load BERT classifier
We load a BERT classifier fine-tuned on the IMDB dataset.

In [14]:

device = ppo_trainer.accelerator.device

if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
    
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

Downloading config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model.

In [15]:

text = "this movie was really bad!!"

sentiment_pipe(text, **sent_kwargs)




[[{'label': 'NEGATIVE', 'score': 2.3350486755371094},
  {'label': 'POSITIVE', 'score': -2.726576328277588}]]

In [16]:

text = "this movie was really good!!"
sentiment_pipe(text, **sent_kwargs)


[[{'label': 'NEGATIVE', 'score': -2.294790029525757},
  {'label': 'POSITIVE', 'score': 2.557039737701416}]]

### Generation settings
For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length.

In [17]:

gen_kwargs = {
         "min_length":   -1, 
         "top_k":       0.0, 
         "top_p":       1.0, 
         "do_sample":  True, 
         "pad_token_id": tokenizer.eos_token_id
}


## Optimize model

### Training loop

The training loop consists of the following main steps:
1. Get the query responses from the policy network (GPT-2)
2. Get sentiments for query/responses from BERT
3. Optimize policy with PPO using the (query, response, reward) triplet

**Training time**

This step takes **~2h** on a V100 GPU with the above specified settings.

In [18]:

output_min_length     = 4
output_max_length     = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)



In [19]:

generation_kwargs = {
    "min_length":     -1,
    "top_k":         0.0,
    "top_p":         1.0,
    "do_sample":    True,
    "pad_token_id": tokenizer.eos_token_id,
}


In [20]:

## ppo_trainer.config.steps = 100    ## 20,000
ppo_trainer.config.steps


20000

In [21]:
'''
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    print(batch)
    input()
'''

'\nfor epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n    print(batch)\n    input()\n'

In [23]:


for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    print(epoch)

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len                             = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response                            = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append( response.squeeze()[-gen_len:] )
    batch["response"] = [ tokenizer.decode(r.squeeze()) for r in response_tensors ]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [ torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(
                     query_tensors, 
                     response_tensors, 
                     rewards
    )
    ppo_trainer.log_stats(stats, batch, rewards)
    

0it [00:00, ?it/s]

0


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
1it [00:19, 19.46s/it]

1


2it [00:36, 18.09s/it]

2


3it [00:53, 17.74s/it]

3


4it [01:11, 17.84s/it]

4


5it [01:29, 17.89s/it]

5


6it [01:48, 18.07s/it]

6


7it [02:07, 18.39s/it]

7


8it [02:25, 18.41s/it]

8


9it [02:44, 18.47s/it]

9


10it [03:02, 18.28s/it]

10


11it [03:20, 18.28s/it]

11


12it [03:39, 18.34s/it]

12


13it [03:57, 18.42s/it]

13


14it [04:16, 18.57s/it]

14


15it [04:35, 18.61s/it]

15


16it [04:54, 18.84s/it]

16


17it [05:15, 19.30s/it]

17


18it [05:36, 19.95s/it]

18


19it [05:57, 20.19s/it]

19


20it [06:17, 20.08s/it]

20


21it [06:35, 19.70s/it]

21


22it [06:55, 19.71s/it]

22


23it [07:15, 19.75s/it]

23


24it [07:34, 19.69s/it]

24


25it [07:54, 19.62s/it]

25


26it [08:14, 19.90s/it]

26


27it [08:33, 19.35s/it]

27


28it [08:52, 19.49s/it]

28


29it [09:11, 19.32s/it]

29


30it [09:31, 19.32s/it]

30


31it [09:50, 19.24s/it]

31


32it [10:10, 19.59s/it]

32


33it [10:29, 19.54s/it]

33


34it [10:51, 20.04s/it]

34


35it [11:17, 21.87s/it]

35


36it [11:37, 21.23s/it]

36


37it [11:56, 20.80s/it]

37


38it [12:15, 20.21s/it]

38


39it [12:35, 20.02s/it]

39


40it [12:54, 19.88s/it]

40


41it [13:14, 19.77s/it]

41


42it [13:33, 19.46s/it]

42


43it [13:52, 19.57s/it]

43


44it [14:10, 19.09s/it]

44


45it [14:30, 19.24s/it]

45


46it [14:49, 19.18s/it]

46


47it [15:09, 19.39s/it]

47


48it [15:29, 19.66s/it]

48


49it [15:47, 19.20s/it]

49


50it [16:05, 18.89s/it]

50


51it [16:25, 18.98s/it]

51


52it [16:44, 19.18s/it]

52


53it [17:03, 19.13s/it]

53


54it [17:23, 19.28s/it]

54


55it [17:43, 19.58s/it]

55


56it [18:03, 19.73s/it]

56


57it [18:22, 19.56s/it]

57


58it [18:43, 19.71s/it]

58


59it [19:01, 19.47s/it]

59


60it [19:20, 19.22s/it]

60


61it [19:39, 19.28s/it]

61


62it [19:59, 19.30s/it]

62


63it [20:18, 19.15s/it]

63


64it [20:37, 19.34s/it]

64


65it [20:57, 19.40s/it]

65


66it [21:16, 19.31s/it]

66


67it [21:36, 19.37s/it]

67


68it [21:55, 19.40s/it]

68


69it [22:14, 19.34s/it]

69


70it [22:34, 19.36s/it]

70


71it [22:53, 19.49s/it]

71


72it [23:13, 19.63s/it]

72


73it [23:34, 19.93s/it]

73


74it [23:54, 19.96s/it]

74


75it [24:13, 19.74s/it]

75


76it [24:32, 19.55s/it]

76


77it [24:53, 19.80s/it]

77


78it [25:13, 19.94s/it]

78


79it [25:33, 19.81s/it]

79


80it [25:51, 19.51s/it]

80


81it [26:11, 19.55s/it]

81


82it [26:30, 19.39s/it]

82


83it [26:51, 19.97s/it]

83


84it [27:11, 19.80s/it]

84


85it [27:30, 19.76s/it]

85


86it [27:51, 19.88s/it]

86


87it [28:10, 19.84s/it]

87


88it [28:30, 19.83s/it]

88


89it [28:49, 19.52s/it]

89


90it [29:09, 19.72s/it]

90


91it [29:29, 19.86s/it]

91


92it [29:49, 19.79s/it]

92


93it [30:09, 19.85s/it]

93


94it [30:29, 19.87s/it]

94


95it [30:48, 19.71s/it]

95


96it [31:07, 19.59s/it]

96


97it [31:27, 19.64s/it]

97


98it [31:47, 19.63s/it]

98


99it [32:06, 19.59s/it]

99


100it [32:27, 19.92s/it]

100


101it [32:46, 19.64s/it]

101


102it [33:06, 19.69s/it]

102


103it [33:23, 19.07s/it]

103


104it [33:42, 18.82s/it]

104


105it [34:01, 18.98s/it]

105


106it [34:20, 18.91s/it]

106


107it [34:39, 19.00s/it]

107


108it [34:58, 19.01s/it]

108


109it [35:17, 19.11s/it]

109


110it [35:37, 19.30s/it]

110


111it [35:58, 19.78s/it]

111


112it [36:18, 19.84s/it]

112


113it [36:38, 19.88s/it]

113


114it [36:58, 19.91s/it]

114


115it [37:18, 20.05s/it]

115


116it [37:37, 19.64s/it]

116


117it [37:57, 19.68s/it]

117


118it [38:16, 19.47s/it]

118


119it [38:35, 19.39s/it]

119


120it [38:55, 19.63s/it]

120


121it [39:14, 19.49s/it]

121


122it [39:33, 19.14s/it]

122


123it [39:53, 19.41s/it]

123


124it [40:12, 19.34s/it]

124


125it [40:31, 19.42s/it]

125


126it [40:52, 19.77s/it]

126


127it [41:12, 19.86s/it]

127


128it [41:32, 19.90s/it]

128


129it [41:51, 19.52s/it]

129


130it [42:11, 19.60s/it]

130


131it [42:31, 19.89s/it]

131


132it [42:52, 20.08s/it]

132


133it [43:11, 19.75s/it]

133


134it [43:30, 19.63s/it]

134


135it [43:49, 19.42s/it]

135


136it [44:09, 19.59s/it]

136


137it [44:28, 19.51s/it]

137


138it [44:46, 18.93s/it]

138


139it [45:06, 19.21s/it]

139


140it [45:25, 19.13s/it]

140


141it [45:44, 19.10s/it]

141


142it [46:04, 19.59s/it]

142


143it [46:24, 19.62s/it]

143


144it [46:44, 19.66s/it]

144


145it [47:03, 19.65s/it]

145


146it [47:23, 19.59s/it]

146


147it [47:43, 19.68s/it]

147


148it [48:03, 19.78s/it]

148


149it [48:22, 19.58s/it]

149


150it [48:41, 19.49s/it]

150


151it [49:00, 19.41s/it]

151


152it [49:19, 19.30s/it]

152


153it [49:39, 19.31s/it]

153


154it [49:58, 19.44s/it]

154


155it [50:17, 19.15s/it]

155


156it [50:36, 19.20s/it]

156


157it [50:56, 19.38s/it]

157


158it [51:16, 19.40s/it]

158


159it [51:35, 19.40s/it]

159


160it [51:55, 19.70s/it]

160


161it [52:14, 19.36s/it]

161


162it [52:34, 19.56s/it]

162


163it [52:53, 19.57s/it]

163


164it [53:14, 19.82s/it]

164


165it [53:33, 19.73s/it]

165


166it [53:52, 19.48s/it]

166


167it [54:11, 19.36s/it]

167


168it [54:31, 19.44s/it]

168


169it [54:51, 19.47s/it]

169


170it [55:09, 19.27s/it]

170


171it [55:28, 19.03s/it]

171


172it [55:49, 19.62s/it]

172


173it [56:08, 19.58s/it]

173


174it [56:27, 19.32s/it]

174


175it [56:46, 19.20s/it]

175


176it [57:05, 19.12s/it]

176


177it [57:25, 19.53s/it]

177


178it [57:44, 19.24s/it]

178


179it [58:04, 19.41s/it]

179


180it [58:22, 18.98s/it]

180


181it [58:43, 19.55s/it]

181


182it [59:02, 19.54s/it]

182


183it [59:21, 19.34s/it]

183


184it [59:40, 19.17s/it]

184


185it [59:58, 18.99s/it]

185


186it [1:00:17, 18.86s/it]

186


187it [1:00:37, 19.16s/it]

187


188it [1:00:57, 19.52s/it]

188


189it [1:01:16, 19.31s/it]

189


190it [1:01:35, 19.17s/it]

190


191it [1:01:54, 19.29s/it]

191


192it [1:02:13, 19.19s/it]

192


193it [1:02:33, 19.25s/it]

193


194it [1:02:52, 19.18s/it]

194


195it [1:03:11, 19.34s/it]

195


196it [1:03:31, 19.44s/it]

196


197it [1:03:50, 19.24s/it]

197


198it [1:04:09, 19.27s/it]

198


199it [1:04:29, 19.38s/it]

199


200it [1:04:48, 19.27s/it]

200


201it [1:05:08, 19.48s/it]

201


202it [1:05:28, 19.57s/it]

202


203it [1:05:47, 19.46s/it]

203


204it [1:06:06, 19.43s/it]

204


205it [1:06:26, 19.54s/it]

205


206it [1:06:45, 19.39s/it]

206


207it [1:07:05, 19.69s/it]

207


208it [1:07:24, 19.49s/it]

208


209it [1:07:43, 19.29s/it]

209


210it [1:08:03, 19.55s/it]

210


211it [1:08:22, 19.25s/it]

211


212it [1:08:40, 18.93s/it]

212


213it [1:08:59, 18.97s/it]

213


214it [1:09:17, 18.53s/it]

214


215it [1:09:36, 18.70s/it]

215


216it [1:09:55, 18.84s/it]

216


217it [1:10:14, 18.89s/it]

217


218it [1:10:33, 19.04s/it]

218


219it [1:10:54, 19.59s/it]

219


220it [1:11:14, 19.70s/it]

220


221it [1:11:35, 19.98s/it]

221


222it [1:11:54, 19.68s/it]

222


223it [1:12:14, 19.69s/it]

223


224it [1:12:32, 19.44s/it]

224


225it [1:12:53, 19.77s/it]

225


226it [1:13:13, 19.74s/it]

226


227it [1:13:33, 19.97s/it]

227


228it [1:13:52, 19.77s/it]

228


229it [1:14:12, 19.84s/it]

229


230it [1:14:31, 19.61s/it]

230


231it [1:14:51, 19.48s/it]

231


232it [1:15:11, 19.71s/it]

232


233it [1:15:29, 19.36s/it]

233


234it [1:15:50, 19.80s/it]

234


235it [1:16:09, 19.51s/it]

235


236it [1:16:28, 19.35s/it]

236


237it [1:16:47, 19.36s/it]

237


238it [1:17:07, 19.43s/it]

238


239it [1:17:25, 19.06s/it]

239


240it [1:17:45, 19.14s/it]

240


241it [1:18:04, 19.35s/it]

241


242it [1:18:24, 19.48s/it]

242


243it [1:18:43, 19.41s/it]

243


244it [1:19:02, 19.15s/it]

244


245it [1:19:22, 19.40s/it]

245


246it [1:19:41, 19.23s/it]

246


247it [1:19:59, 18.92s/it]

247


248it [1:20:18, 18.85s/it]

248


249it [1:20:36, 18.79s/it]

249


250it [1:20:56, 18.95s/it]

250


251it [1:21:16, 19.32s/it]

251


252it [1:21:35, 19.30s/it]

252


253it [1:21:56, 19.70s/it]

253


254it [1:22:14, 19.26s/it]

254


255it [1:22:35, 19.71s/it]

255


256it [1:22:55, 19.80s/it]

256


257it [1:23:16, 20.20s/it]

257


258it [1:23:38, 20.65s/it]

258


259it [1:23:57, 20.31s/it]

259


260it [1:24:18, 20.43s/it]

260


261it [1:24:38, 20.50s/it]

261


262it [1:24:59, 20.60s/it]

262


263it [1:25:19, 20.24s/it]

263


264it [1:25:38, 20.02s/it]

264


265it [1:25:57, 19.72s/it]

265


266it [1:26:18, 19.95s/it]

266


267it [1:26:38, 20.18s/it]

267


268it [1:26:58, 20.02s/it]

268


269it [1:27:18, 19.95s/it]

269


270it [1:27:37, 19.72s/it]

270


271it [1:27:57, 19.91s/it]

271


272it [1:28:18, 20.21s/it]

272


273it [1:28:38, 20.02s/it]

273


274it [1:28:57, 19.67s/it]

274


275it [1:29:17, 19.86s/it]

275


276it [1:29:37, 19.90s/it]

276


277it [1:29:57, 19.86s/it]

277


278it [1:30:16, 19.54s/it]

278


279it [1:30:36, 19.69s/it]

279


280it [1:30:55, 19.69s/it]

280


281it [1:31:15, 19.56s/it]

281


282it [1:31:35, 19.92s/it]

282


283it [1:31:56, 20.01s/it]

283


284it [1:32:16, 20.08s/it]

284


285it [1:32:36, 20.11s/it]

285


286it [1:32:55, 19.83s/it]

286


287it [1:33:14, 19.45s/it]

287


288it [1:33:34, 19.56s/it]

288


289it [1:33:53, 19.52s/it]

289


290it [1:34:13, 19.60s/it]

290


291it [1:34:33, 19.83s/it]

291


292it [1:34:53, 19.79s/it]

292


293it [1:35:13, 19.93s/it]

293


294it [1:35:33, 19.91s/it]

294


295it [1:35:52, 19.71s/it]

295


296it [1:36:12, 19.60s/it]

296


297it [1:36:31, 19.65s/it]

297


298it [1:36:50, 19.45s/it]

298


299it [1:37:10, 19.63s/it]

299


300it [1:37:31, 19.82s/it]

300


301it [1:37:49, 19.46s/it]

301


302it [1:38:08, 19.35s/it]

302


303it [1:38:28, 19.56s/it]

303


304it [1:38:47, 19.24s/it]

304


305it [1:39:06, 19.28s/it]

305


306it [1:39:24, 18.93s/it]

306


307it [1:39:45, 19.45s/it]

307


308it [1:40:05, 19.50s/it]

308


309it [1:40:24, 19.34s/it]

309


310it [1:40:44, 19.53s/it]

310


311it [1:41:03, 19.61s/it]

311


312it [1:41:24, 19.87s/it]

312


313it [1:41:45, 20.39s/it]

313


314it [1:42:04, 19.96s/it]

314


315it [1:42:24, 20.00s/it]

315


316it [1:42:44, 19.86s/it]

316


317it [1:43:03, 19.74s/it]

317


318it [1:43:23, 19.79s/it]

318


319it [1:43:43, 19.76s/it]

319


320it [1:44:03, 19.78s/it]

320


321it [1:44:22, 19.54s/it]

321


322it [1:44:43, 20.01s/it]

322


323it [1:45:03, 19.98s/it]

323


324it [1:45:22, 19.58s/it]

324


325it [1:45:41, 19.61s/it]

325


326it [1:46:00, 19.40s/it]

326


327it [1:46:19, 19.21s/it]

327


328it [1:46:38, 19.08s/it]

328


329it [1:46:56, 18.95s/it]

329


330it [1:47:15, 18.73s/it]

330


331it [1:47:34, 18.80s/it]

331


332it [1:47:54, 19.32s/it]

332


333it [1:48:13, 19.29s/it]

333


334it [1:48:32, 19.21s/it]

334


335it [1:48:53, 19.67s/it]

335


336it [1:49:12, 19.46s/it]

336


337it [1:49:32, 19.56s/it]

337


338it [1:49:50, 19.21s/it]

338


339it [1:50:10, 19.29s/it]

339


340it [1:50:29, 19.31s/it]

340


341it [1:50:49, 19.53s/it]

341


342it [1:51:08, 19.29s/it]

342


343it [1:51:27, 19.28s/it]

343


344it [1:51:47, 19.36s/it]

344


345it [1:52:07, 19.63s/it]

345


346it [1:52:26, 19.59s/it]

346


347it [1:52:45, 19.44s/it]

347


348it [1:53:05, 19.43s/it]

348


349it [1:53:25, 19.69s/it]

349


350it [1:53:46, 19.91s/it]

350


351it [1:54:05, 19.82s/it]

351


352it [1:54:25, 19.81s/it]

352


353it [1:54:44, 19.52s/it]

353


354it [1:55:03, 19.47s/it]

354


355it [1:55:22, 19.20s/it]

355


356it [1:55:41, 19.22s/it]

356


357it [1:56:00, 19.19s/it]

357


358it [1:56:19, 19.23s/it]

358


359it [1:56:39, 19.43s/it]

359


360it [1:56:59, 19.41s/it]

360


361it [1:57:18, 19.43s/it]

361


362it [1:57:36, 19.03s/it]

362


363it [1:57:55, 18.83s/it]

363


364it [1:58:12, 18.34s/it]

364


365it [1:58:28, 17.83s/it]

365


366it [1:58:47, 17.94s/it]

366


367it [1:59:05, 18.13s/it]

367


368it [1:59:23, 18.14s/it]

368


369it [1:59:41, 18.07s/it]

369


370it [1:59:59, 17.92s/it]

370


371it [2:00:16, 17.82s/it]

371


372it [2:00:35, 18.00s/it]

372


373it [2:00:52, 17.62s/it]

373


374it [2:01:09, 17.61s/it]

374


375it [2:01:27, 17.67s/it]

375


376it [2:01:46, 17.95s/it]

376


377it [2:02:06, 18.63s/it]

377


378it [2:02:26, 19.03s/it]

378


379it [2:02:44, 18.85s/it]

379


380it [2:03:01, 18.27s/it]

380


381it [2:03:19, 18.01s/it]

381


382it [2:03:36, 17.90s/it]

382


383it [2:03:53, 17.69s/it]

383


384it [2:04:11, 17.72s/it]

384


385it [2:04:29, 17.79s/it]

385


386it [2:04:46, 17.61s/it]

386


387it [2:05:04, 17.69s/it]

387


388it [2:05:23, 17.94s/it]

388


389it [2:05:40, 17.86s/it]

389


390it [2:05:58, 17.90s/it]

390


391it [2:06:16, 17.80s/it]

391


392it [2:06:33, 17.57s/it]

392


393it [2:06:52, 17.87s/it]

393


394it [2:07:08, 17.60s/it]

394


395it [2:07:27, 17.78s/it]

395


396it [2:07:45, 17.96s/it]

396


397it [2:08:02, 17.78s/it]

397


398it [2:08:20, 17.75s/it]

398


399it [2:08:38, 17.93s/it]

399


400it [2:08:55, 17.58s/it]

400


401it [2:09:13, 17.58s/it]

401


402it [2:09:30, 17.60s/it]

402


403it [2:09:48, 17.72s/it]

403


404it [2:10:07, 18.10s/it]

404


405it [2:10:25, 17.95s/it]

405


406it [2:10:43, 17.88s/it]

406


407it [2:11:02, 18.16s/it]

407


408it [2:11:19, 18.08s/it]

408


409it [2:11:38, 18.36s/it]

409


410it [2:11:57, 18.43s/it]

410


411it [2:12:16, 18.58s/it]

411


412it [2:12:36, 18.94s/it]

412


413it [2:12:54, 18.85s/it]

413


414it [2:13:14, 19.13s/it]

414


415it [2:13:33, 19.04s/it]

415


416it [2:13:51, 18.62s/it]

416


417it [2:14:08, 18.30s/it]

417


418it [2:14:26, 18.19s/it]

418


419it [2:14:44, 18.08s/it]

419


420it [2:15:03, 18.35s/it]

420


421it [2:15:21, 18.36s/it]

421


422it [2:15:41, 18.61s/it]

422


423it [2:16:00, 18.80s/it]

423


424it [2:16:18, 18.60s/it]

424


425it [2:16:36, 18.50s/it]

425


426it [2:16:56, 18.89s/it]

426


427it [2:17:15, 18.92s/it]

427


428it [2:17:33, 18.64s/it]

428


429it [2:17:51, 18.32s/it]

429


430it [2:18:10, 18.54s/it]

430


431it [2:18:29, 18.77s/it]

431


432it [2:18:49, 19.04s/it]

432


433it [2:19:08, 19.12s/it]

433


434it [2:19:26, 18.92s/it]

434


435it [2:19:44, 18.66s/it]

435


436it [2:20:03, 18.61s/it]

436


437it [2:20:21, 18.56s/it]

437


438it [2:20:40, 18.67s/it]

438


439it [2:20:58, 18.49s/it]

439


440it [2:21:16, 18.40s/it]

440


441it [2:21:35, 18.44s/it]

441


442it [2:21:54, 18.51s/it]

442


443it [2:22:12, 18.53s/it]

443


444it [2:22:32, 18.78s/it]

444


445it [2:22:50, 18.78s/it]

445


446it [2:23:10, 19.08s/it]

446


447it [2:23:28, 18.82s/it]

447


448it [2:23:48, 19.06s/it]

448


449it [2:24:07, 19.11s/it]

449


450it [2:24:24, 18.52s/it]

450


451it [2:24:43, 18.49s/it]

451


452it [2:25:02, 18.59s/it]

452


453it [2:25:21, 18.89s/it]

453


454it [2:25:40, 18.98s/it]

454


455it [2:25:59, 18.82s/it]

455


456it [2:26:18, 18.79s/it]

456


457it [2:26:36, 18.73s/it]

457


458it [2:26:55, 18.75s/it]

458


459it [2:27:14, 18.95s/it]

459


460it [2:27:34, 19.02s/it]

460


461it [2:27:52, 18.78s/it]

461


462it [2:28:10, 18.58s/it]

462


463it [2:28:29, 18.60s/it]

463


464it [2:28:47, 18.49s/it]

464


465it [2:29:06, 18.68s/it]

465


466it [2:29:25, 18.76s/it]

466


467it [2:29:44, 18.73s/it]

467


468it [2:30:03, 19.00s/it]

468


469it [2:30:23, 19.11s/it]

469


470it [2:30:42, 19.17s/it]

470


471it [2:31:01, 19.26s/it]

471


472it [2:31:20, 19.10s/it]

472


473it [2:31:40, 19.27s/it]

473


474it [2:31:59, 19.18s/it]

474


475it [2:32:17, 18.84s/it]

475


476it [2:32:36, 19.10s/it]

476


477it [2:32:56, 19.20s/it]

477


478it [2:33:14, 18.94s/it]

478


479it [2:33:32, 18.62s/it]

479


480it [2:33:51, 18.74s/it]

480


481it [2:34:10, 18.83s/it]

481


482it [2:34:30, 19.02s/it]

482


483it [2:34:47, 18.46s/it]

483


484it [2:35:06, 18.76s/it]

484


485it [2:35:24, 18.40s/it]

485


486it [2:35:45, 19.11s/it]

486


487it [2:36:04, 19.21s/it]

487


488it [2:36:23, 19.14s/it]

488


489it [2:36:40, 18.62s/it]

489


490it [2:36:59, 18.62s/it]

490


491it [2:37:17, 18.55s/it]

491


492it [2:37:34, 18.09s/it]

492


493it [2:37:54, 18.46s/it]

493


494it [2:38:12, 18.33s/it]

494


495it [2:38:29, 18.05s/it]

495


496it [2:38:47, 18.11s/it]

496


497it [2:39:06, 18.18s/it]

497


498it [2:39:25, 18.54s/it]

498


499it [2:39:44, 18.73s/it]

499


500it [2:40:03, 18.59s/it]

500


501it [2:40:21, 18.61s/it]

501


502it [2:40:40, 18.69s/it]

502


503it [2:40:59, 18.64s/it]

503


504it [2:41:16, 18.39s/it]

504


505it [2:41:35, 18.49s/it]

505


506it [2:41:53, 18.28s/it]

506


507it [2:42:11, 18.30s/it]

507


508it [2:42:30, 18.32s/it]

508


509it [2:42:47, 18.18s/it]

509


510it [2:43:06, 18.30s/it]

510


511it [2:43:25, 18.49s/it]

511


512it [2:43:43, 18.23s/it]

512


513it [2:44:01, 18.16s/it]

513


514it [2:44:19, 18.35s/it]

514


515it [2:44:37, 18.02s/it]

515


516it [2:44:54, 17.89s/it]

516


517it [2:45:14, 18.33s/it]

517


518it [2:45:32, 18.45s/it]

518


519it [2:45:51, 18.51s/it]

519


520it [2:46:09, 18.38s/it]

520


521it [2:46:26, 18.03s/it]

521


522it [2:46:45, 18.17s/it]

522


523it [2:47:03, 18.13s/it]

523


523it [2:47:09, 19.18s/it]


KeyboardInterrupt: 

## Model inspection
Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation.

In [24]:
#### get a batch from the dataset
bs                 = 16
game_data          = dict()
dataset.set_format("pandas")
df_batch           = dataset[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors      = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output  = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
game_data["response (after)"]  = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,SUB,STITION. Girlsactive,) This is the best,-1.175855,2.260252
1,SUB,SPONDS! YES,is a very well made,0.226873,2.672822
2,SUBREDDIT,""". I was wrong.<br /><br",) was an influenceively amazing piece for me,-1.669313,2.633916
3,SUBRED,"TRISSPOT, movie,"" in which Neil Armstrong lea...",""". A wonderful episode of The Twilight Zone wi...",-0.631368,2.870436
4,SUBREDDIT: r,/newscientific_,/music and inspired by,-0.633002,1.876992
5,SUBREDDIT: r,/adiwub comments user Score: 63,"/ So witty, and totally hilarious by the",-0.560263,2.652614
6,SUBRED,"LETTER...... ADD TO YOUR JOINES BELOW,",", SO RAINDS is a MUST SEE DVD!!!!!!!!",0.03861,2.480814
7,SUBREDDIT,ANNOUNCES THIS FOOTBALL COVER UP EVEN THEN! T...,Maintenance. I really appreciate this movie. ...,-1.145346,2.700906
8,SUBREDDIT:,TWILIGHT is definitely the *worst,"I still remember the 1,200 hits",-2.69496,2.132594
9,SUBRED,""",""4R"" - 300MM)",""" is a great movie with excellent actors",-0.32457,2.826058


Looking at the reward mean/median of the generated sequences we observe a significant difference.

In [25]:
print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)   -0.769566
rewards (after)     2.333861
dtype: float64


median:


rewards (before)   -0.632185
rewards (after)     2.497479
dtype: float64

## Save model
Finally, we save the model and push it to the Hugging Face for later usage.

In [26]:

## model.save_pretrained(    "gpt2-imdb-pos-v2", push_to_hub=True)
## tokenizer.save_pretrained("gpt2-imdb-pos-v2", push_to_hub=True)

model.save_pretrained(    "./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)
tokenizer.save_pretrained("./RLHFmodel/gpt2-tldr-pos", push_to_hub=False)


('./RLHFmodel/gpt2-tldr-pos\\tokenizer_config.json',
 './RLHFmodel/gpt2-tldr-pos\\special_tokens_map.json',
 './RLHFmodel/gpt2-tldr-pos\\vocab.json',
 './RLHFmodel/gpt2-tldr-pos\\merges.txt',
 './RLHFmodel/gpt2-tldr-pos\\added_tokens.json',
 './RLHFmodel/gpt2-tldr-pos\\tokenizer.json')

In [None]:

print(tqdm)
