In [1]:
import torch
from prepare_data import create_or_load
from collator import T2TDataCollator
from transformers import AdamW, get_scheduler, Trainer, TrainingArguments, Adafactor
from transformers import T5Tokenizer
from model import T5PromptTuningLM
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

In [2]:
model_name = 't5-small'
n_tokens = 10
batch_size = 28

tokenizer = T5Tokenizer.from_pretrained(model_name)
train_dataset, valid_dataset = create_or_load(tokenizer)

# Run the below cells if you want to train

In [9]:
# if you want to train
class Config:
    # Prompt-tuning
    n_prompt_tokens = 10
    init_from_vocab = True
    # random_range = 0.5
args = Config()

model = T5PromptTuningLM.from_pretrained(
    model_name,
    n_tokens=args.n_prompt_tokens,
    initialize_from_vocab=args.init_from_vocab)

Initializing soft prompt...


In [10]:
# Set up training arguments, optimizers, etc
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if n == "soft_prompt.weight"],
        "lr": 1e-2,
        "scale_parameter": False,
        "relative_step": False,
    }
]
optimizer = Adafactor(optimizer_grouped_parameters)
lr_scheduler = get_scheduler(
    name='cosine',
    num_warmup_steps=0,
    optimizer=optimizer,
    num_training_steps=3,
)

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    per_device_train_batch_size=batch_size,  # batch size per device during training
    per_device_eval_batch_size=batch_size*2,   # batch size for evaluation
    logging_dir='./logs',            # directory for storing logs
    logging_steps=100,
    save_steps=3000,
    report_to='tensorboard',
    prediction_loss_only=True,
    num_train_epochs=3,
)

In [11]:
# Initialize trainer
trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        data_collator=T2TDataCollator(),
        optimizers=(optimizer, lr_scheduler),
    )

In [47]:
# start training
trainer.train()

***** Running training *****
  Num examples = 87599
  Num Epochs = 3
  Instantaneous batch size per device = 28
  Total train batch size (w. parallel, distributed & accumulation) = 28
  Gradient Accumulation steps = 1
  Total optimization steps = 9387


Step,Training Loss


KeyboardInterrupt: 

In [None]:
# start evaluate, save prompt
trainer.evaluate()
model.save_soft_prompt('soft_prompt', filename=f'soft_prompt_{model_name}_{n_tokens}.model')

***** Running Evaluation *****
  Num examples = 10570
  Batch size = 56


# Run the below cells if you want to experiment

In [3]:
# load a new base model with trained soft prompt
model = T5PromptTuningLM.from_pretrained(model_name, 
                                          return_dict=False, 
                                          soft_prompt_path=f'soft_prompt/soft_prompt_t5-small_10.model')

Set soft prompt! (n_tokens: 10)


In [None]:
predictions = []
ans = []
for i in range(1000):
    print('------------------------------------')
    print(i)
    question, context = valid_dataset['question'][i], valid_dataset['context'][i]
    input_ids = tokenizer.encode('question: %s  context: %s' % (question, context), 
                             return_tensors='pt').to(model.device)
    answers = valid_dataset['answers'][i]['text']
    for i in range(len(answers)):
        answers[i] = answers[i].lower().strip()
    ans.append(answers)
    print(f'context: {context}')
    print()
    print(f'question: {question}')
    print()
    print(f'answers: {answers}')
    decoder_input_ids = torch.tensor([[tokenizer.encode(tokenizer.pad_token)[0]]]).to(input_ids.device)
    for i in range(10):
        idx = model(input_ids, decoder_input_ids=decoder_input_ids, return_dict=True).logits.argmax(-1)[0][-1]
        decoder_input_ids=torch.cat((decoder_input_ids,torch.tensor([[idx]]).to(decoder_input_ids.device)), dim=1)
    pred = ' '.join([tokenizer.decode(decoder_input_ids[0], skip_special_tokens=False)])
    pred = pred.replace('</s>','').replace('<pad>','')
#     print(f'indices: {indices}')
    predictions.append(pred.lower().strip())
    
    print(f'model prediction: {pred.lower().strip()}')

In [9]:
import collections
import re
import string
import numpy as np

from absl import logging

def _metric_max_over_ground_truths(metric_fn, ground_truths, prediction):
    """Computes the maximum of the metric over all ground truths."""
    return max(metric_fn(ground_truth, prediction) for ground_truth in ground_truths)


def _exact_match_score(target, prediction):
    return target == prediction

def _f1_score(target, prediction):
    """Computes token f1 score for a single target and prediction."""
    prediction_tokens = prediction.split()
    target_tokens = target.split()
    common = (collections.Counter(prediction_tokens) & collections.Counter(target_tokens))
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(target_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_metrics(targets, predictions):
    """Computes exact match and f1 QA scores, expecting pre-normalized text."""
    if len(targets) != len(predictions):
        raise ValueError("Number of targets and predictions must match.")
    em = np.mean([
        _metric_max_over_ground_truths(_exact_match_score, t, p)
        for p, t in zip(predictions, targets)
    ])
    f1 = np.mean([
        _metric_max_over_ground_truths(_f1_score, t, p)
        for p, t in zip(predictions, targets)
    ])
    em *= 100
    f1 *= 100
    logging.info("EM = %.2f, F1 = %.2f", em, f1)
    return {"em": em, "f1": f1}

In [10]:
qa_metrics(ans, predictions)

INFO:absl:EM = 63.60, F1 = 78.93


{'em': 63.6, 'f1': 78.92817904317904}