In [1]:
import torch
import evaluate
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
model_name = './finetuned-flan-t5'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset('squad_v2')


In [2]:
total_params = sum(p.numel() for p in model.parameters())
print(total_params * 4)
for key ,value in enumerate(dataset):
    print(key, value)

990311424
0 train
1 validation


In [3]:
def data_preprocessing(data):
    inputs = [ q for q in data['question']] 
    targets = [a['text'][0] if len(a['text']) >0 else '' for a in data['answers']]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True)
        
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_dataset = dataset.map(data_preprocessing, batched=True,  batch_size=32)

In [4]:
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)



accuracy_metric = evaluate.load("accuracy")
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions

    # Check if preds are logits or ids, convert if needed
    if preds.ndim == 3:  # When the predictions are logits (batch_size, seq_len, vocab_size)
        preds = preds.argmax(-1)

    # Ensure predictions and labels are arrays
    preds = np.array(preds)
    labels = np.array(labels)

    valid_preds = []
    valid_labels = []

    for i in range(len(labels)):
        label = labels[i]
        pred = preds[i]

        # Mask out padding tokens (-100 in labels)
        valid_indices = label != -100

        # Match the lengths of predictions and labels
        min_len = min(len(pred), len(label))
        pred = pred[:min_len]  # Trim the prediction to the length of the label
        label = label[:min_len]  # Trim the label if necessary

        # Apply the mask to filter out padding tokens
        filtered_pred = pred[valid_indices[:min_len]]
        filtered_label = label[valid_indices[:min_len]]

        valid_preds.extend(filtered_pred.tolist())
        valid_labels.extend(filtered_label.tolist())

    # Compute the accuracy on valid tokens (non-padding tokens)
    accuracy = accuracy_metric.compute(predictions=valid_preds, references=valid_labels)

    return {"accuracy": accuracy['accuracy']}

model.to(device)

training_args = Seq2SeqTrainingArguments(
    output_dir = "./results",
    eval_strategy = "epoch",
    learning_rate = 1e-5,
    per_device_train_batch_size = 12,
    per_device_eval_batch_size = 12,
    weight_decay = 0.01,
    save_total_limit = 3,
    num_train_epochs =3,
    predict_with_generate =True
)


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    data_collator= data_collator,
    compute_metrics= compute_metrics    
)

# trainer.train()

In [5]:

results = trainer.evaluate()
print(results)




  0%|          | 0/990 [00:00<?, ?it/s]

In [6]:
# model.save_pretrained("finetuned-flan-t5")
# tokenizer.save_pretrained("finetuned-flan-t5")