In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset

# Hub identifiers for the checkpoint being fine-tuned and the training data.
model_name = 'google-t5/t5-small'
dataset_name = 'aqua_rat'

# The tokenizer is loaded from the same checkpoint as the model created below,
# so token ids line up with the model's embedding table.
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Only the training split is used in this notebook.
dataset = load_dataset(dataset_name, split='train')

def preprocess_math_questions(examples, tok=None, max_input_length=512, max_target_length=128):
    """Tokenize a batch of multiple-choice math rows into T5 inputs and labels.

    Meant to be used as ``dataset.map(preprocess_math_questions, batched=True)``,
    so every value in ``examples`` is a list with one entry per row.

    Args:
        examples: Batch dict with keys ``"question"`` (list of str),
            ``"options"`` (list of lists of str, one entry per answer letter
            starting at "A") and ``"correct"`` (list of single answer letters,
            e.g. ``"C"``).
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.
        max_input_length: Pad/truncate length for the encoder inputs.
        max_target_length: Pad/truncate length for the labels.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``, each a
        list of lists of ints. Label positions that hold the tokenizer's pad
        token are replaced with -100 so they do not contribute to the loss.

    Raises:
        ValueError: If a ``"correct"`` letter does not name one of that row's
            options.
    """
    if tok is None:
        tok = tokenizer  # notebook-level tokenizer from the loading cell

    # Input format: "question: <q> options: A: <opt A> B: <opt B> ..."
    # Letters are assigned per option, so any number of options works; for the
    # usual five options this yields "A: ... B: ... C: ... D: ... E: ...".
    questions_and_options = [
        f"question: {q} options: "
        + " ".join(f"{chr(ord('A') + j)}: {opt}" for j, opt in enumerate(opts))
        for q, opts in zip(examples["question"], examples["options"])
    ]

    # The target is the full text of the correct option, not just its letter.
    correct_answers = []
    for letter, opts in zip(examples["correct"], examples["options"]):
        idx = ord(letter) - ord("A")
        # Guard explicitly: a negative index would silently pick the wrong option.
        if not 0 <= idx < len(opts):
            raise ValueError(
                f"answer letter {letter!r} does not match any of {len(opts)} options"
            )
        correct_answers.append(opts[idx])

    input_encodings = tok(
        questions_and_options, padding="max_length", truncation=True, max_length=max_input_length
    )
    target_encodings = tok(
        correct_answers, padding="max_length", truncation=True, max_length=max_target_length
    )

    # Padded label positions must not be learned as real targets: Transformers
    # seq2seq models ignore label ids equal to -100 when computing the loss.
    pad_id = tok.pad_token_id
    labels = [
        [token if token != pad_id else -100 for token in seq]
        for seq in target_encodings.input_ids
    ]

    return {
        "input_ids": input_encodings.input_ids,
        "attention_mask": input_encodings.attention_mask,
        "labels": labels,
    }

# Tokenize the whole training split batch-by-batch; the returned
# input_ids / attention_mask / labels columns are added to each row.
processed_dataset = dataset.map(
    preprocess_math_questions,
    batched=True,
)

# Pretrained weights from the same checkpoint the tokenizer was loaded from.
model = T5ForConditionalGeneration.from_pretrained(model_name)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
# Fine-tuning configuration.
# - logging_strategy="steps": report the training loss every `logging_steps`
#   optimizer steps so progress is visible during the long run.
# - save_total_limit: keep only the most recent checkpoints in `output_dir`
#   instead of letting them accumulate for the whole run.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,  # only used once an eval_dataset is passed below
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=100,
    save_total_limit=2,
)

# To evaluate during training, also pass eval_dataset=<tokenized eval split>.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
)

trainer.train()



  0%|          | 0/36552 [00:00<?, ?it/s]

KeyboardInterrupt: 