In [None]:
pip install torch transformers tqdm scikit-learn nltk rouge-score hf_xet SPARQLWrapper matplotlib -q

In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoConfig
)
from tqdm import tqdm
import nltk
import numpy as np
import os
import re
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer

# NOTE(review): 'punkt' is not used by any code visible in this file (there is no
# NLTK tokenization here); presumably it is needed by evaluation code elsewhere.
nltk.download('punkt', quiet=True)

# --- Model and sequence lengths ---
MODEL_NAME = "google/flan-t5-base"  # Hugging Face id used for tokenizer, config and weights
MAX_SOURCE_LENGTH = 128  # encoded question length (padded/truncated to exactly this)
MAX_TARGET_LENGTH = 256  # encoded SPARQL length (padded/truncated to exactly this)
# --- Optimization ---
BATCH_SIZE = 16  # DataLoader batch size
ACCUMULATION_STEPS = 1  # batches whose gradients are accumulated per optimizer step
LEARNING_RATE = 2e-5  # AdamW learning rate (peak of the linear warmup/decay schedule)
WEIGHT_DECAY = 0.01  # AdamW weight decay
NUM_EPOCHS = 100  # training runs through this epoch number (inclusive)
EARLY_STOP_PATIENCE = 10  # NOTE(review): not referenced by the training code in this file
# --- Decoding settings ---
# NOTE(review): none of these three is referenced in this file; presumably they are
# generation kwargs for evaluation code elsewhere — confirm.
LENGTH_PENALTY = 1.1
NUM_BEAMS = 6
NO_REPEAT_NGRAM_SIZE = 3
# --- Paths and data options ---
DATASET_PATH = "/kaggle/input/lcquad2/train.json"  # JSON array of record objects read by load_dataset()
CHECKPOINT_DIR = "/kaggle/working/checkpoints"  # created by main(); contains WORKING_CHECKPOINT
PERSISTENT_LOAD_CHECKPOINT = "/kaggle/input/my_checkpoints/latest.pt"  # only read (for resuming), never written
WORKING_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "latest.pt")  # where training checkpoints are saved
AUGMENT_PARAPHRASE = True  # also train on 'paraphrased_question' when it differs from 'question'

# Select the compute device once at import time. On GPU, also turn on cuDNN
# benchmark mode (autotuned convolution algorithms) and report the card's name.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    torch.backends.cudnn.benchmark = True
    tqdm.write(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    DEVICE = torch.device("cpu")
    tqdm.write("Using CPU")

def normalize_sparql(query):
    """Return a canonical form of a SPARQL string for comparison.

    Leading/trailing whitespace is removed, the text is lowercased, and every
    internal run of whitespace (spaces, tabs, newlines, other Unicode spaces)
    becomes a single space.
    """
    lowered = query.strip().lower()
    # str.split() with no argument splits on the same Unicode whitespace set that
    # the regex class \s matches, and drops empty pieces.
    return " ".join(lowered.split())

class KBQADataset(Dataset):
    """Question -> SPARQL training pairs, tokenized on the fly for seq2seq fine-tuning.

    Each record in ``data`` supplies a source text (``question``, prefixed with
    ``"generate query: "`` at encoding time) and a target (``sparql_wikidata``).
    With ``augment_paraphrase`` the record's ``paraphrased_question`` becomes an
    extra sample for the same query when it is a non-blank string that differs
    from the question.

    Records whose SPARQL is missing or blank are skipped, because they would teach
    the model to emit an empty query. A blank question is skipped too, because it
    would map an empty prompt to a query; the record's paraphrase can still be
    used. Non-string field values (``None``, lists, ...) are treated as blank.

    Args:
        data: iterable of dict records (e.g. the list returned by ``load_dataset``).
        tokenizer: Hugging Face-style tokenizer. It is called with ``max_length``,
            ``padding='max_length'``, ``truncation=True`` and ``return_tensors="pt"``,
            and must expose ``pad_token_id``.
        max_source_length: fixed (padded/truncated) length of the encoded question.
        max_target_length: fixed (padded/truncated) length of the encoded SPARQL.
        augment_paraphrase: also emit samples built from ``paraphrased_question``.
    """

    def __init__(self, data, tokenizer, max_source_length, max_target_length, augment_paraphrase=False):
        self.samples = []
        for item in data:
            question = self._text_or_empty(item.get('question'))
            sparql = self._text_or_empty(item.get('sparql_wikidata'))
            if not sparql.strip():
                # No target query: nothing useful to learn from this record.
                continue
            if question.strip():
                self.samples.append((question, sparql))
            if augment_paraphrase:
                paraphrased = self._text_or_empty(item.get('paraphrased_question'))
                if paraphrased.strip() and paraphrased.strip() != question.strip():
                    self.samples.append((paraphrased, sparql))
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    @staticmethod
    def _text_or_empty(value):
        """Return ``value`` if it is a string, otherwise ``""``."""
        return value if isinstance(value, str) else ""

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Encode sample ``idx``.

        Returns a dict with 1-D ``input_ids`` / ``attention_mask`` of length
        ``max_source_length``, 1-D ``labels`` of length ``max_target_length`` with
        padding replaced by -100, and the raw SPARQL string as ``target_text``.
        """
        question, sparql = self.samples[idx]
        src_enc = self.tokenizer(
            f"generate query: {question}",
            max_length=self.max_source_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        tgt_enc = self.tokenizer(
            sparql,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        # Drop only the batch dimension: (1, L) -> (L,). A bare squeeze() would
        # also collapse L when max_length == 1.
        labels = tgt_enc['input_ids'].squeeze(0).clone()
        # -100 is CrossEntropyLoss's default ignore_index, so padding adds no loss.
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            'input_ids': src_enc['input_ids'].squeeze(0),
            'attention_mask': src_enc['attention_mask'].squeeze(0),
            'labels': labels,
            'target_text': sparql
        }

def load_dataset(path=None):
    """Load the raw training records from a JSON file.

    Args:
        path: JSON file to read. Defaults to the module-level ``DATASET_PATH``,
            looked up at call time.

    Returns:
        The decoded JSON document (for this training script, the list of records
        consumed by ``KBQADataset``).

    Raises:
        FileNotFoundError: if ``path`` is not an existing regular file.
    """
    if path is None:
        path = DATASET_PATH
    # isfile (not exists) so a directory path gets the same clear error message.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset not found at {path}.")
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def train_epoch(model, dataloader, optimizer, scheduler, device,
                accumulation_steps=None, max_grad_norm=0.25):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch loss is scaled by ``1 / accumulation_steps`` before ``backward``.
    After every ``accumulation_steps`` batches the gradients are clipped to
    ``max_grad_norm``, then the optimizer and scheduler are stepped and the
    gradients are zeroed.

    A trailing partial group at the end of the epoch is applied as well, so its
    gradients are not thrown away by the next epoch's ``zero_grad``. That partial
    group is still scaled by ``1 / accumulation_steps``.

    Args:
        model: seq2seq model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``.loss``.
        dataloader: yields dicts with ``'input_ids'``, ``'attention_mask'`` and
            ``'labels'`` tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer step.
        device: device the batch tensors are moved to.
        accumulation_steps: batches per optimizer step. Defaults to the
            module-level ``ACCUMULATION_STEPS``.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        Mean per-batch training loss (accumulation scaling undone), or ``0.0``
        when ``dataloader`` is empty.

    Raises:
        ValueError: if ``accumulation_steps`` is less than 1.
    """
    if accumulation_steps is None:
        accumulation_steps = ACCUMULATION_STEPS
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        return 0.0

    total_loss = 0.0
    optimizer.zero_grad()
    for i, batch in enumerate(tqdm(dataloader, desc="Training", leave=True)):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        # .mean() is a no-op for a scalar loss. It reduces a per-replica loss
        # vector (e.g. from nn.DataParallel) to a scalar.
        loss = outputs.loss.mean() / accumulation_steps
        loss.backward()

        batches_seen = i + 1
        end_of_group = batches_seen % accumulation_steps == 0
        if end_of_group or batches_seen == num_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps
    return total_loss / num_batches

def _find_resume_checkpoint():
    """Return ``(label, path)`` of the checkpoint to resume from, or ``(None, None)``.

    ``WORKING_CHECKPOINT`` is tried first because it is the file this script's
    training loop writes. When it exists, it holds progress made after any resume
    from ``PERSISTENT_LOAD_CHECKPOINT``. Preferring the persistent copy would
    restart from stale weights when training is re-run in the same session, and
    then overwrite the newer working checkpoint.
    """
    candidates = (
        ("working", WORKING_CHECKPOINT),
        ("persistent", PERSISTENT_LOAD_CHECKPOINT),
    )
    for label, path in candidates:
        if os.path.exists(path):
            return label, path
    return None, None


def _save_checkpoint(state, path):
    """Write ``state`` to ``path`` atomically (temporary file + ``os.replace``).

    If the process is killed during ``torch.save``, the previous checkpoint stays
    intact instead of being left as a truncated file that breaks the next resume.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def main():
    """Fine-tune ``MODEL_NAME`` on question -> SPARQL pairs and save the result.

    Steps:
      1. Load the tokenizer and model.
      2. Build the dataset and dataloader.
      3. Set up AdamW with a linear warmup/decay schedule.
      4. Resume from a checkpoint if one exists.
      5. Train through ``NUM_EPOCHS``, checkpointing every 2nd epoch and the last.
      6. Save the final model and tokenizer to ``./t5_model``.
    """
    tqdm.write(f"Using device: {DEVICE}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model_config = AutoConfig.from_pretrained(MODEL_NAME)
    model_config.dropout_rate = 0.1
    # NOTE(review): T5-style configs expose a single dropout_rate; confirm the model
    # actually reads attention_dropout_rate (it may be an unused attribute).
    model_config.attention_dropout_rate = 0.1
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, config=model_config)
    # Trade compute for memory: activations are recomputed during backward.
    model.gradient_checkpointing_enable()
    model.to(DEVICE)

    data = load_dataset()
    train_ds = KBQADataset(data, tokenizer, MAX_SOURCE_LENGTH, MAX_TARGET_LENGTH, augment_paraphrase=AUGMENT_PARAPHRASE)
    # Pinned host memory only helps (and is only supported) when copying to CUDA.
    train_dl = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=DEVICE.type == "cuda",
    )

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # The scheduler is stepped once per optimizer step (every ACCUMULATION_STEPS
    # batches), so count optimizer steps, not batches. Round up so the schedule
    # never ends before training does.
    steps_per_epoch = (len(train_dl) + ACCUMULATION_STEPS - 1) // ACCUMULATION_STEPS
    total_steps = steps_per_epoch * NUM_EPOCHS
    warmup_steps = int(0.03 * total_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    start_epoch = 1
    # NOTE(review): best_bleu / stop_counter are persisted in checkpoints but never
    # updated here (there is no evaluation step), so early stopping via
    # EARLY_STOP_PATIENCE is not implemented in this loop.
    best_bleu = -1
    stop_counter = 0

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    label, resume_path = _find_resume_checkpoint()
    if resume_path is not None:
        tqdm.write(f"Resuming from {label} checkpoint {resume_path}")
        # torch.load unpickles the file: only resume from checkpoints you produced.
        checkpoint = torch.load(resume_path, map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])
        scheduler.load_state_dict(checkpoint['sched_state'])
        start_epoch = checkpoint['epoch'] + 1
        best_bleu = checkpoint.get('best_bleu', -1)
        stop_counter = checkpoint.get('stop_counter', 0)
        tqdm.write(f"Resumed at epoch {start_epoch}")
    else:
        tqdm.write("No checkpoint found, starting from scratch.")

    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        tqdm.write(f"Starting Epoch {epoch}/{NUM_EPOCHS}")
        tr_loss = train_epoch(model, train_dl, optimizer, scheduler, DEVICE)
        tqdm.write(f"Epoch {epoch}: Train Loss: {tr_loss:.4f}")

        if epoch % 2 == 0 or epoch == NUM_EPOCHS:
            _save_checkpoint({
                'model_state': model.state_dict(),
                'optim_state': optimizer.state_dict(),
                'sched_state': scheduler.state_dict(),
                'epoch': epoch,
                'best_bleu': best_bleu,
                'stop_counter': stop_counter
            }, WORKING_CHECKPOINT)
            tqdm.write(f"Checkpoint saved at epoch {epoch}")

    output_dir = "./t5_model"
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

# Entry point: run the full training pipeline when this file/cell is executed directly.
if __name__ == "__main__":
    main()