# Pre-requisites

Upload the train, dev and test files as generated by the DPR script:
- `parrot-qa-ctx-train.json`
- `parrot-qa-ctx-dev.json`
- `parrot-qa-ctx-test.json`



In [1]:
!pip install transformers datasets sentencepiece rouge_score

Collecting rouge_score
  Downloading rouge_score-0.0.4-py2.py3-none-any.whl (22 kB)
Installing collected packages: rouge-score
Successfully installed rouge-score-0.0.4


# Step 2: UnifiedQA Fine-tuning

In [2]:
TOKENIZER_BATCH_SIZE = 16
TRAIN_BATCH_SIZE = 1
EVAL_BATCH_SIZE = 1

MODEL_NAME = 'allenai/unifiedqa-t5-small'
DEVICE = 'cuda'

### Reformat dataset

In [3]:
import json
from datasets import Dataset


def create_dataset(file_path):
    with open(file_path) as fp:
        dataset = json.load(fp)
    
    q, a, c = [], [], []
    for item in dataset:
        q.append(item['question'])
        a.append(item['answer'])
        c.append(' '.join(item['contexts']))
    
    data = {'question': q, 'answer': a, 'context': c}
    return Dataset.from_dict(data)


train = create_dataset('parrot-qa-ctx-train.json')
dev = create_dataset('parrot-qa-ctx-dev.json')
test = create_dataset('parrot-qa-ctx-test.json')

len(train), len(dev), len(test)

(1811, 226, 227)

### Load and perform tokenization

In [4]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)


In [5]:
def tokenize_all(samples):
    q, c, a = samples['question'], samples['context'], samples['answer']
    qc = [f'{qval} \\n {cval}' for (qval, cval) in zip(q, c)]
    inp = tokenizer(qc, padding=True, truncation=True, max_length=1024)
    outp = tokenizer(a, padding=True, truncation=True, max_length=1024)
    return {
        'input_ids': inp.input_ids,
        'attention_mask': inp.attention_mask,
        'labels': outp.input_ids,
        'decoder_attention_mask': outp.attention_mask
    }


train = train.map(tokenize_all, batched=True, batch_size=TOKENIZER_BATCH_SIZE)
dev = dev.map(tokenize_all, batched=True, batch_size=TOKENIZER_BATCH_SIZE)
test = test.map(tokenize_all, batched=True, batch_size=TOKENIZER_BATCH_SIZE)

  0%|          | 0/114 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

  0%|          | 0/15 [00:00<?, ?ba/s]

### Train model

In [6]:
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)


In [7]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments


args = Seq2SeqTrainingArguments(
    'output',
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    load_best_model_at_end=True,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train,
    eval_dataset=dev,
    tokenizer=tokenizer,
)

trainer.train()

The following columns in the training set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: question, context, answer. If question, context, answer are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1811
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 5433


Epoch,Training Loss,Validation Loss
1,1.2326,1.199466
2,1.1148,1.180304
3,1.219,1.175479


The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: question, context, answer. If question, context, answer are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 226
  Batch size = 1
Saving model checkpoint to output/checkpoint-1811
Configuration saved in output/checkpoint-1811/config.json
Model weights saved in output/checkpoint-1811/pytorch_model.bin
tokenizer config file saved in output/checkpoint-1811/tokenizer_config.json
Special tokens file saved in output/checkpoint-1811/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: question, context, answer. If question, context, answer are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Eva

TrainOutput(global_step=5433, training_loss=1.2526715378855202, metrics={'train_runtime': 573.5401, 'train_samples_per_second': 9.473, 'train_steps_per_second': 9.473, 'total_flos': 1447323549106176.0, 'train_loss': 1.2526715378855202, 'epoch': 3.0})

### Perform inference

In [8]:
pred = trainer.predict(dev)
pred_answers = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)

The following columns in the test set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: question, context, answer. If question, context, answer are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 226
  Batch size = 1


In [9]:
from datasets import load_metric
rouge = load_metric("rouge")

rouge.compute(predictions=pred_answers, references=dev['answer'])

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

{'rouge1': AggregateScore(low=Score(precision=0.25129996218426737, recall=0.08322148467305589, fmeasure=0.1101340502110918), mid=Score(precision=0.2793211495254704, recall=0.09588119246654433, fmeasure=0.12296142355828275), high=Score(precision=0.307842769036158, recall=0.11100298714239018, fmeasure=0.13711446692162765)),
 'rouge2': AggregateScore(low=Score(precision=0.03609320252019367, recall=0.010145783703056162, fmeasure=0.014480302002023224), mid=Score(precision=0.05012776063661017, recall=0.016336081255035908, fmeasure=0.021676614530712103), high=Score(precision=0.06826763391475779, recall=0.025170557845276127, fmeasure=0.03138202646301603)),
 'rougeL': AggregateScore(low=Score(precision=0.211136240984058, recall=0.07288789093890122, fmeasure=0.09555707183678025), mid=Score(precision=0.23505256745105793, recall=0.08521608477200736, fmeasure=0.10734579785885978), high=Score(precision=0.26083928889051555, recall=0.10041269427656588, fmeasure=0.12037235637932528)),
 'rougeLsum': Agg