# Tune T5 Paraphrase model to generate better Jeopardy question prompts

## Setup experiment

### Import dependencies

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np

import re
import hashlib
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from datasets import load_dataset, load_metric, concatenate_datasets
from trl.t5 import T5HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import pad_to_size

### Configuration

In [None]:
# Experiment settings. The whole mapping is logged to W&B and later unpacked
# into PPOTrainer as keyword arguments.
config = dict(
    # Hugging Face identifiers: policy, reference policy, QA model, tokenizer
    lm_name="Vamsi/T5_Paraphrase_Paws",
    ref_lm_name="Vamsi/T5_Paraphrase_Paws",
    cls_model_name="vblagoje/bert-base-searchqa",
    tk_name="t5-base",
    # Rollout / batching sizes
    steps=25600,
    batch_size=4,
    forward_batch_size=4,
    ppo_epochs=4,
    txt_in_len=5,
    txt_out_len=128,
    # PPO optimisation hyper-parameters
    lr=1.41e-5,
    init_kl_coef=0.2,
    target=6,
    horizon=10000,
    gamma=1,
    lam=0.95,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
)

You can see that we load the T5 paraphrase model `Vamsi/T5_Paraphrase_Paws` twice: once as the policy to be tuned (`lm_name`) and once as the reference model (`ref_lm_name`). The BERT model `vblagoje/bert-base-searchqa` answers the reformulated questions. The other parameters are mostly taken from the original paper ["Fine-Tuning Language Models from Human Preferences"](https://arxiv.org/pdf/1909.08593.pdf). The T5 model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models.

### Initialize W&B logger
We use `wandb` to log all the metrics during training.

In [None]:
# Open a W&B run for this experiment and attach the full config to it.
wandb.init(
    project='t5-aqa',
    name='run-42',
    config=config,
)

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Offline run mode, not syncing to the cloud.
[34m[1mwandb[0m: W&B syncing is set to `offline` in this directory.  Run `wandb online` to enable cloud syncing.


## Load data and models

### Load SearchQA and convert it to SQuAD format
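Each SearchQA example is converted into a SQuAD-style record: up to the first ten search-result snippets are concatenated to form the context, occurrences of the answer inside that context provide the answer spans, and the context text is cleaned of special characters. Questions whose answer cannot be located in the context are filtered out afterwards.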

In [None]:
def strip_html(txt):
  """Return ``txt`` with every ``<...>`` span replaced by a single space.

  The match is non-greedy and ``.`` does not cross line breaks, so a tag
  that is split over several lines is left as it is.
  """
  return re.sub(r"<.*?>", " ", txt)

def remove_special_chars(txt):
  """Drop every character that is not an ASCII letter, digit, whitespace or
  one of ``. , ! ? / : ; " '`` and then remove all newline characters.

  Args:
    txt: the text to filter.

  Returns:
    The filtered text as a new string.
  """
  # The upper-case range must be "A-Z": the former "A-z" also covered the
  # ASCII characters between "Z" and "a" ([ \ ] ^ _ `), which therefore
  # survived the filter.
  pat = r'[^a-zA-Z0-9.,!?/:;\"\'\s]'
  kept = re.sub(pat, "", txt)
  # "\s" keeps newlines through the first pass; they are removed here.
  return kept.replace("\n", "")


def clean(txt):
  """Strip HTML tags from ``txt`` and then remove special characters.

  Tags have to be removed first: ``remove_special_chars`` deletes ``<`` and
  ``>``, so running it before ``strip_html`` leaves no tag to match and the
  tag names end up in the text.
  """
  return remove_special_chars(strip_html(txt))

def convertSearchQAExampleToSquadExample(example):
  """Convert one SearchQA record into a SQuAD-style record.

  The context is built from the first ten search snippets (``None`` entries
  skipped), concatenated without a separator and passed through ``clean``.
  The answer is cleaned the same way and located in the *cleaned* context, so
  each ``answer_start`` indexes into the returned ``context`` and each
  ``text`` entry is the substring found at that offset.

  Args:
    example: mapping with the keys ``"question"``, ``"answer"`` and
      ``"search_results"`` (the latter holding a ``"snippets"`` list).

  Returns:
    A dict with ``id`` (SHA-1 hex digest of the question), ``title``,
    ``question``, ``answers`` and ``context``. ``answers`` is an empty dict
    when the answer does not occur in the context surrounded by single
    spaces.
  """
  snippets = [s for s in example["search_results"]["snippets"][:10] if s is not None]

  # Clean before searching: offsets computed on the raw text no longer line up
  # once characters have been removed from the context.
  context = clean("".join(snippets))
  answer_text = clean(example["answer"])

  answers = {}
  # Skip answers that clean away to nothing; the pattern would then match
  # arbitrary runs of spaces.
  if answer_text.strip():
    answer_for_match = ' ' + re.escape(answer_text) + ' '
    # +1 skips the leading space that is part of the pattern.
    answer_start = [m.start() + 1 for m in re.finditer(answer_for_match, context)]
    if answer_start:
      answers = {
          'answer_start': answer_start,
          'text': [answer_text] * len(answer_start)
      }

  example_id = hashlib.sha1(example["question"].encode()).hexdigest()

  return {"id": example_id,
          "title": example["question"],
          "question": example["question"],
          "answers": answers,
          "context": context}

## Load SearchQA and convert it to SQuAD

In [None]:
# Fetch the SearchQA splits.
search_qa = load_dataset("search_qa", "train_test_val")

# Convert every example to the SQuAD layout (dropping the original SearchQA
# columns), then keep only the examples for which an answer span was found.
squad_qa = search_qa.map(
    convertSearchQAExampleToSquadExample,
    remove_columns=search_qa["train"].column_names,
).filter(lambda example: example["answers"]["answer_start"] is not None)

Reusing dataset search_qa (/Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63)
Loading cached processed dataset at /Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63/cache-1324614813ad4b0d.arrow
Loading cached processed dataset at /Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63/cache-73b41b39d9de9e01.arrow
Loading cached processed dataset at /Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63/cache-bce669f6e324a00c.arrow
Loading cached processed dataset at /Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63/cache-84d3f8dffcef1123.arrow
Loading cach

### Show an example entry item in the dataset

In [None]:
# Display one converted training example to sanity-check the SQuAD-style fields.
squad_qa["train"][6]

{'answers': {'answer_start': [18], 'text': ['Jesse James']},
 'context': "In the Wild West, Jesse James was legendary  a Robin Hoodlike figure who the public loved and lawmakers hated. The outlaw's notorious bank robbing spree...Jesse James, one of America's most notorious outlaws, is shot to death by Robert Ford, a member of his gang who hoped to collect the bounty on Jesse's head.",
 'id': 'b64e19f081a8b60c3a2b6742bd66a7275d46b5a3',
 'question': 'Outlaw: "Murdered by a traitor and a coward whose name is not worthy to appear here"',
 'title': 'Outlaw: "Murdered by a traitor and a coward whose name is not worthy to appear here"'}

### Load QA pipeline


In [None]:
# Build the extractive QA pipeline from the checkpoint named in the config.
q_tokenizer = AutoTokenizer.from_pretrained(config["cls_model_name"])
qa_model = AutoModelForQuestionAnswering.from_pretrained(config["cls_model_name"])
qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=q_tokenizer,
    device=-1,  # -1 selects the CPU in the transformers pipeline API
)

The QA pipeline returns the extracted answer span together with its start/end offsets and a confidence score. We will use the F1 overlap between the extracted answer and the gold answer as a reward signal for the language model.

In [None]:
# Run the QA pipeline on a single training example with its original question.
training_example_id = 6
training_example = squad_qa["train"][training_example_id]
context = training_example["context"]
q = training_example["question"]
pipeline_answer = qa_pipeline(question=q, context=context)
pipeline_answer


{'score': 0.8468492031097412, 'start': 18, 'end': 29, 'answer': 'Jesse James'}

The resulting reward signal:

In [None]:
squad = load_metric("squad")

def f1_score(prediction, reference):
  """Score one predicted answer string against a SQuAD-style ``answers`` entry.

  Args:
    prediction: the answer text produced by the QA model.
    reference: the gold ``answers`` value of the example.

  Returns:
    The ``"f1"`` value reported by the SQuAD metric, or 0.0 if the metric
    result has no such key.
  """
  # A single prediction/reference pair; both sides must share the same id.
  scores = squad.compute(
      predictions=[{'id': '1', 'prediction_text': prediction}],
      references=[{'id': '1', 'answers': reference}],
  )
  return scores.get("f1", 0.0)

In [None]:
# Sanity check: score the pipeline's answer for the sample example against its gold answers.
f1_score(pipeline_answer["answer"], squad_qa["train"][training_example_id]["answers"])

100.0

### Load pre-trained language models

We load the model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model.

In [None]:
# Policy that PPO will optimise (T5 wrapped with a value head).
paraphrase_model = T5HeadWithValueModel.from_pretrained(config['lm_name'])
# Second copy, loaded from `ref_lm_name`; handed to PPOTrainer as the reference model.
paraphrase_model_ref = T5HeadWithValueModel.from_pretrained(config['ref_lm_name'])
# Tokenizer used both to encode the questions and to decode the generated paraphrases.
tokenizer = AutoTokenizer.from_pretrained(config['tk_name'])

Some weights of T5HeadWithValueModel were not initialized from the model checkpoint at Vamsi/T5_Paraphrase_Paws and are newly initialized: ['v_head.state_representation.weight', 'v_head.state_representation.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5HeadWithValueModel were not initialized from the model checkpoint at Vamsi/T5_Paraphrase_Paws and are newly initialized: ['v_head.state_representation.weight', 'v_head.state_representation.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Watch model with wandb
This wandb magic logs the gradients and weights of the model during training.

In [None]:
# Hook the policy into W&B; log='all' asks for both gradients and weights to be recorded.
wandb.watch(paraphrase_model, log='all')

[<wandb.wandb_torch.TorchGraph at 0x7fe7d4c4e4f0>]

### Move models to GPU

If `cuda` is available move the computations to the GPU.

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Move both models to the selected device; assigning the result to `_` keeps
# the cell from echoing the module representation.
_ = paraphrase_model.to(device)
_ = paraphrase_model_ref.to(device)

### Tokenize questions

We tokenize all training questions in advance to avoid tokenizing twice. Each question is prefixed with the `paraphrase: ` task prompt, padded to the maximum length and truncated if it is too long. The resulting `input_ids` and `attention_mask` columns are then exposed as torch tensors on the selected device.

In [None]:
# Wrap each training question as "paraphrase: <question> </s>" and tokenize it one example
# at a time; padding="max_length" and truncation=True give every example the same length.
squad_qa_encoded = squad_qa["train"].map(lambda x: tokenizer("paraphrase: " + x['question'] + " </s>", padding="max_length", truncation=True), batched=False)
# Serve `input_ids` / `attention_mask` as torch tensors placed on `device`.
squad_qa_encoded.set_format(type='torch', columns=['input_ids', 'attention_mask'], device=device)

Loading cached processed dataset at /Users/vblagoje/.cache/huggingface/datasets/search_qa/train_test_val/1.0.0/a2a9f2281af3826aaca532a2214573f11c1979499ac14b5639c7f02ac3ff0c63/cache-487c7b47ce1067e4.arrow


## Optimize model

**Steps**

The training loop consists of the following steps:
1. Get a batch of queries
2. Get the query responses from the policy
3. Decode the reformulated questions and send them, together with the original contexts, to the QA model
4. Compute the F1 score of each returned answer against the gold answer and use it as the reward
5. Optimize policy with PPO using the (query, response, reward) triplet
6. Log all the training statistics

**Forward batching**

Since the models can be fairly big and we want to roll out large PPO batches, this can lead to out-of-memory errors when doing the forward passes for text generation. We introduce the parameter `forward_batch_size` to split the forward passes into smaller batches. Although this hurts performance a little, this is negligible compared to the computations of the backward passes when optimizing the model. The same parameter is used in the `PPOTrainer` when doing forward passes. The `batch_size` should be a multiple of `forward_batch_size`.

**Training time**

This step takes **~2h** on a P6000 GPU with the above specified settings.

In [None]:
# One rollout: sample a batch of questions, let the policy reformulate them,
# answer the reformulations with the QA pipeline and turn the answers' F1
# scores into rewards. The outer loop over steps is currently disabled, so
# running this cell performs exactly one rollout.
fbs = config['forward_batch_size']
train_ds = squad_qa_encoded
pop_size = range(len(train_ds))

torch.cuda.empty_cache()
logs = dict()
game_data = dict()
timing = dict()
t0 = time.time()

#### get a batch of questions from the dataset
indices = np.random.choice(pop_size, size=config['batch_size'], replace=False)
sample = train_ds.select(indices=indices)

game_data['question'] = sample['question']
question_tensors = sample['input_ids'] ##stacked automatically
question_masks = sample['attention_mask'] ##stacked automatically

#### reformulate questions
t = time.time()
total_length = config['txt_in_len']+config['txt_out_len']
response_tensors = []
# Step through the batch in chunks of `fbs`. Stepping by offset (instead of
# int(batch_size / fbs) iterations) also processes a trailing partial chunk,
# so there is always one response per query.
for start in range(0, config['batch_size'], fbs):
    qt = question_tensors[start:start+fbs]
    qm = question_masks[start:start+fbs]
    response = respond_to_batch(paraphrase_model, qt, txt_len=config['txt_out_len'], attention_mask=qm)
    response_tensors.append(response)

# Chunks may differ in generated length; pad all of them to the longest one
# so they can be concatenated into a single tensor.
max_response_length = max((r.shape[1] for r in response_tensors), default=0)
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
response_tensors = [pad_to_size(r,
                                size=max_response_length,
                                padding=pad_id) for r in response_tensors]

response_tensors = torch.cat(response_tensors)
game_data['question_reformulated'] = [
    tokenizer.decode(r.squeeze(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for r in response_tensors
]

timing['time/get_response'] = time.time()-t

#### send reformulated questions to QA system
t = time.time()
qa_results = qa_pipeline(question=game_data["question_reformulated"], context=sample["context"])
# The pipeline returns a bare dict (not a list) when given a single example.
if isinstance(qa_results, dict):
    qa_results = [qa_results]
timing['time/build_input_sentiment'] = time.time()-t

#### get f1 score for reformulated questions
t = time.time()
# Read the gold answers once instead of re-fetching the column per example.
gold_answers = sample["answers"]
rewards = torch.tensor(
    [f1_score(result["answer"], gold) for result, gold in zip(qa_results, gold_answers)],
    dtype=torch.float32,
    device=device,
)
timing['time/get_sentiment_preds'] = time.time()-t

In [None]:
#### Run PPO training
# Build the trainer from the policy and the reference model (every entry of
# `config` is forwarded as a keyword argument) and run one optimisation step
# on the (query, response, reward) batch produced by the rollout cell.
t = time.time()
ppo_trainer = PPOTrainer(paraphrase_model, paraphrase_model_ref, **config)
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
# NOTE(review): the timer starts before the trainer is constructed, so this
# value covers construction as well as the step itself.
timing['time/optimization'] = time.time()-t
