# Fine-Tuning FLAN-T5 for Ontology Explanation Generation

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import load_dataset, Dataset
import json

# Load training data
def _as_columns(raw):
    """Normalise the decoded JSON into ``{"input": [...], "output": [...]}``.

    Accepts a list of ``{"input": ..., "output": ...}`` records, a dict of
    parallel lists, or a dict holding one input/output string pair. The
    result always has one list element per training example, so the dataset
    built from it has one row per pair rather than a single row containing
    whole lists.
    """
    if isinstance(raw, dict):
        return {
            key: [raw[key]] if isinstance(raw[key], str) else list(raw[key])
            for key in ("input", "output")
        }
    # Otherwise treat it as a sequence of per-example records.
    return {
        "input": [record["input"] for record in raw],
        "output": [record["output"] for record in raw],
    }


# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open('../data/training_pairs.json', encoding='utf-8') as f:
    data = json.load(f)
dataset = Dataset.from_dict(_as_columns(data))

# Load model and tokenizer
# A single checkpoint id is used for both the tokenizer and the seq2seq model.
model_name = "google/flan-t5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

# Tokenize
def preprocess(examples):
    """Tokenize the ``input`` and ``output`` text and attach target ids as labels.

    Both sides are truncated to 128 tokens. No padding is applied here:
    padding the targets to ``max_length`` would put pad-token ids into
    ``labels``, and those positions would then count towards the loss.
    Padding is left to the ``DataCollatorForSeq2Seq`` handed to the trainer,
    which pads each batch and fills label padding with -100 (ignored by the
    loss).

    Args:
        examples: A dataset row (or batch) with ``input`` and ``output`` text.

    Returns:
        The tokenizer encoding of ``input`` with an added ``labels`` entry
        holding the token ids of ``output``.
    """
    inputs = tokenizer(examples['input'], truncation=True, max_length=128)
    outputs = tokenizer(examples['output'], truncation=True, max_length=128)
    inputs['labels'] = outputs['input_ids']
    return inputs


tokenized = dataset.map(preprocess)

# Training arguments
# Settings for the run: output and log directories, per-device batch size,
# number of epochs, logging interval, and how many checkpoints to keep.
_run_settings = {
    "output_dir": "../model/flan-t5-finetuned",
    "per_device_train_batch_size": 2,
    "num_train_epochs": 1,
    "logging_dir": "../logs",
    "logging_steps": 10,
    "save_total_limit": 1,
}
training_args = Seq2SeqTrainingArguments(**_run_settings)

# Trainer
# Build the collator first, then wire the model, arguments, tokenized data
# and tokenizer into the trainer and start the training run.
_batch_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    tokenizer=tokenizer,
    data_collator=_batch_collator,
)
trainer.train()