In [1]:
from datasets import load_dataset, DatasetDict, Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration

# load baseline model and tokenizer

In [2]:
model_path = 'base/checkpoint-6111'
model = T5ForConditionalGeneration.from_pretrained(model_path)

In [3]:
tokenizer = T5Tokenizer.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
tokenizer = T5Tokenizer.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# load dataset

In [6]:
import json

# load Data
def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data

file_path = '../MISC/Dataset/squad_v2_with_explanations.json'
data = load_data(file_path)
train_data = data['data']

In [7]:
# Prepare data to be converted into the format required by the datasets library
def extract_qas(data):
    questions = []
    contexts = []
    answers = []
    explanations = []  
    is_impossible = []

    for article in data:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa in paragraph['qas']:
                # Check if an explanation exists, skip the current question if not
                explanation = qa.get('explanation')
                if explanation is None:
                    continue  

                questions.append(qa['question'])
                contexts.append(context)  
                explanations.append(explanation)  
                is_impossible.append(qa['is_impossible'])

                # answer
                if qa['is_impossible']:
                    
                    if 'plausible_answers' in qa:
                        answers.append(qa['plausible_answers'][0]['text'])
                    else:
                        answers.append("")  
                else:
                    if qa['answers']:
                        answers.append(qa['answers'][0]['text'])
                    else:
                        answers.append("")  

    return {
        'question': questions,
        'context': contexts,
        'answer': answers,
        'explanation': explanations,
        'is_impossible': is_impossible
    }

# processe data
processed_data = extract_qas(train_data)
dataset = Dataset.from_dict(processed_data)


In [8]:
dataset

Dataset({
    features: ['question', 'context', 'answer', 'explanation', 'is_impossible'],
    num_rows: 13401
})

In [9]:
dataset[2075]

{'question': 'What category of game is Legend of Zelda: Australia Twilight?',
 'context': 'The Legend of Zelda: Twilight Princess (Japanese: ゼルダの伝説 トワイライトプリンセス, Hepburn: Zeruda no Densetsu: Towairaito Purinsesu?) is an action-adventure game developed and published by Nintendo for the GameCube and Wii home video game consoles. It is the thirteenth installment in the The Legend of Zelda series. Originally planned for release on the GameCube in November 2005, Twilight Princess was delayed by Nintendo to allow its developers to refine the game, add more content, and port it to the Wii. The Wii version was released alongside the console in North America in November 2006, and in Japan, Europe, and Australia the following month. The GameCube version was released worldwide in December 2006.[b]',
 'answer': 'action-adventure',
 'explanation': "The category of game that The Legend of Zelda: Twilight Princess falls into is 'action-adventure.'",
 'is_impossible': True}

In [11]:
def preprocess_t5_explanations(examples):
    questions = examples['question']
    contexts = examples['context']
    explanations = examples['explanation']
    answers = examples['answer']

    # Predict and explain first
    source_texts = [f"question: {q} context: {c}" for q, c, a in zip(questions, contexts, answers)]
    target_texts = [f"explanations: {e}" for e, a in zip(explanations, answers)]
    # target_texts = [exp if exp else "no explanation" for exp in explanations]

    # Use tokenizer to process batch data
    model_inputs = tokenizer(source_texts, max_length=768, truncation=True, padding="max_length", return_tensors="pt")
    model_outputs = tokenizer(target_texts, max_length=128, truncation=True, padding="max_length", return_tensors="pt")

    return {
        "input_ids": model_inputs['input_ids'],
        "attention_mask": model_inputs['attention_mask'],
        "labels": model_outputs['input_ids']
    }


In [12]:
processed_dataset = dataset.map(preprocess_t5_explanations, batched=True, batch_size=64)

Map:   0%|          | 0/13401 [00:00<?, ? examples/s]

In [13]:
# processed_dataset[0]

In [16]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq, T5Tokenizer

# Define training parameters
training_args = Seq2SeqTrainingArguments(
    output_dir='./t5-two-stage-1',
    # evaluation_strategy='epoch',  # Evaluation strategy set to 'epoch'
    save_strategy = 'epoch',
    # evaluation_strategy='steps',  # Evaluation strategy set to 'epoch'
    # eval_steps=20,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    load_best_model_at_end=False,     
    logging_steps=1000,           
)

# data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

#Initialize the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)





In [17]:
trainer.train()

Step,Training Loss


TrainOutput(global_step=630, training_loss=0.6480337960379464, metrics={'train_runtime': 457.4106, 'train_samples_per_second': 87.893, 'train_steps_per_second': 1.377, 'total_flos': 8161719666868224.0, 'train_loss': 0.6480337960379464, 'epoch': 3.0})

In [18]:
def preprocess_t5_explanations(examples):
    questions = examples['question']
    contexts = examples['context']
    explanations = examples['explanation']
    answers = examples['answer']

    # Add explanation prediction answer
    source_texts = [f"question: {q} context: {c} explanations: {e}" for q, c, e in zip(questions, contexts, explanations)]
    target_texts = [f"{a}" for e, a in zip(explanations, answers)]
    # target_texts = [exp if exp else "no explanation" for exp in explanations]

    # Use tokenizer to process batch data
    model_inputs = tokenizer(source_texts, max_length=768, truncation=True, padding="max_length", return_tensors="pt")
    model_outputs = tokenizer(target_texts, max_length=128, truncation=True, padding="max_length", return_tensors="pt")

    return {
        "input_ids": model_inputs['input_ids'],
        "attention_mask": model_inputs['attention_mask'],
        "labels": model_outputs['input_ids']
    }


In [19]:
processed_dataset = dataset.map(preprocess_t5_explanations, batched=True, batch_size=64)

Map:   0%|          | 0/13401 [00:00<?, ? examples/s]

In [20]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq, T5Tokenizer

# Define training parameters
training_args = Seq2SeqTrainingArguments(
    output_dir='./t5-two-stage-2',
    # evaluation_strategy='epoch',  # Evaluation strategy set to 'epoch'
    save_strategy = 'epoch',
    # evaluation_strategy='steps',  # Evaluation strategy set to 'epoch'
    # eval_steps=20,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    load_best_model_at_end=False,     
    logging_steps=1000,           
)

# data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

#Initialize the trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)





In [21]:
trainer.train()

Step,Training Loss


TrainOutput(global_step=630, training_loss=0.015314376164996434, metrics={'train_runtime': 457.5371, 'train_samples_per_second': 87.868, 'train_steps_per_second': 1.377, 'total_flos': 8161719666868224.0, 'train_loss': 0.015314376164996434, 'epoch': 3.0})

In [22]:
import torch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define problem and context
context = "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries."

question = "What century did the Normans first gain their separate identity?"
truth = "10th"
# Encode the input text using a tokenizer and make sure the input data is also on the correct device
input_text = f"question: {question} context: {context}"
input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)

# Generate answer
outputs = model.generate(input_ids, max_length=40)

# Convert the output token IDs to text
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Generated Answer:", answer)
print("truth:", truth)

Generated Answer: 10th
truth: 10th


In [23]:
question = "Who was the Norse leader?"
truth = "Rollo"

input_text = f"question: {question} context: {context}"
input_ids = tokenizer(input_text, return_tensors='pt').input_ids.to(device)


outputs = model.generate(input_ids, max_length=40)


answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Generated Answer:", answer)
print("truth:", truth)

Generated Answer: Rollo
truth: Rollo
