In [2]:
import json
import pandas as pd
import torch
from transformers import MarianTokenizer, MarianMTModel, BartTokenizer, BartForConditionalGeneration
from pathlib import Path
import logging
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

In [3]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress sacremoses warning
try:
    import sacremoses
except ImportError:
    logger.warning("sacremoses not found, but proceeding as it’s optional for MarianMT.")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Preprocessing

**Splits the dataset into training (80%) and test (20%) sets using random.shuffle for randomization.**

In [4]:
# train-test split function
def custom_train_test_split(data, test_size=0.2, random_state=42):
    random.seed(random_state)
    random.shuffle(data)
    split_idx = int(len(data) * (1 - test_size))
    train_data = data[:split_idx]
    test_data = data[split_idx:]
    return train_data, test_data

**Calculates the exact match accuracy between true and predicted SQL queries.**

In [5]:
# accuracy score function
def custom_accuracy_score(true_list, pred_list):
    matches = sum(1 for t, p in zip(true_list, pred_list) if t == p)
    return matches / len(true_list) if true_list else 0.0

# Model Loading
##### Loads pre-trained models for translation (MarianMT) and text-to-SQL generation (BART).

* Loads Helsinki-NLP/opus-mt-ar-en (MarianMT) for Arabic-to-English translation using MarianTokenizer and MarianMTModel.
* Loads facebook/bart-base (BART) for text-to-SQL with BartTokenizer and BartForConditionalGeneration, using trust_remote_code=True to enable generation capabilities

In [6]:
# Load MarianMT for Arabic-to-English translation
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translator_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translator_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# Load BART for text-to-SQL with trust_remote_code
sql_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
sql_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", trust_remote_code=True).to(device)

# Defninning classes and methods

## Dataset class
to prepare question-SQL pairs for training the BART model.

__getitem__: Converts a question-SQL pair into tokenized input (input_ids, attention_mask) and target (labels) tensors, padded/truncated to 512 tokens.

In [7]:
# Dataset for text-to-SQL
class SQLDataset(Dataset):
    def __init__(self, questions, sql_queries, tokenizer, max_length=512):
        self.questions = questions
        self.sql_queries = sql_queries
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = str(self.questions[idx])
        sql = str(self.sql_queries[idx])
        input_encoding = self.tokenizer(
            question,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        target_encoding = self.tokenizer(
            sql,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        return {
            "input_ids": input_encoding["input_ids"].squeeze(),
            "attention_mask": input_encoding["attention_mask"].squeeze(),
            "labels": target_encoding["input_ids"].squeeze()
        }

## normalization 
Normalizes SQL strings by removing extra whitespace and standardizing commas (e.g., SELECT name , age to SELECT name,age).

In [None]:
# Function to normalize SQL strings
def normalize_sql(sql):
    import re
    sql = re.sub(r'\s+', ' ', sql.strip())
    sql = re.sub(r'\s*,\s*', ',', sql)
    return sql

## Translates Arabic text to English 
using MarianMT, logging the input and output.

Returns the translated text.

In [9]:
# Function to translate Arabic to English
def translate_arabic_to_english(text):
    inputs = translator_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    inputs = {k: v for k, v in inputs.items() if k != "token_type_ids"}
    translated = translator_model.generate(**inputs, max_length=512)
    translated_text = translator_tokenizer.decode(translated[0], skip_special_tokens=True)
    logger.info(f"Translated '{text}' to '{translated_text}'")
    return translated_text

## Generates an SQL query from English text 

using the BART model’s generate method.

Returns the decoded SQL string, logging the process.

In [10]:
# Function to convert English text to SQL
def text_to_sql(english_text, model, tokenizer):
    inputs = tokenizer(english_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)
    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generated SQL for '{english_text}': {sql_query}")
    return sql_query

## Combines translation and SQL generation for an Arabic query.
Returns the English translation and generated SQL.

In [11]:
# Function to process a single query
def process_query(arabic_question, model, tokenizer):
    english_question = translate_arabic_to_english(arabic_question)
    sql_query = text_to_sql(english_question, model, tokenizer)
    return english_question, sql_query


# Data Loading and Splitting

In [12]:
# Load dataset
dataset_path = "/kaggle/input/txttosql-nlp/AR_spider.jsonl"
data = []
if not Path(dataset_path).exists():
    raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please ensure the file exists.")
with open(dataset_path, 'r', encoding='utf-8') as f:
    for line in f:
        data.append(json.loads(line.strip()))

# Convert to DataFrame
df = pd.DataFrame(data)
logger.info(f"Loaded {len(df)} queries from dataset.")


In [13]:
# Split dataset into train and test
train_data, test_data = custom_train_test_split(data, test_size=0.2, random_state=42)
train_df = pd.DataFrame(train_data)
test_df = pd.DataFrame(test_data)
logger.info(f"Training set: {len(train_df)} samples, Test set: {len(test_df)} samples")

# Prepare training data
train_questions = train_df['question'].tolist()
train_sqls = train_df['query'].tolist()
train_dataset = SQLDataset(train_questions, train_sqls, sql_tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


# Model Training
Fine-tunes the BART model on the training data for 10 epochs.

* Uses AdamW optimizer with a learning rate of 2e-5.
* Trains for 10 epochs (changed from 3), processing batches and computing loss.
* Logs the average loss per epoch and switches the model to evaluation mode after training.

In [14]:
# Fine-tune BART model
optimizer = torch.optim.AdamW(sql_model.parameters(), lr=2e-5)
sql_model.train()
logger.info("Starting fine-tuning BART for text-to-SQL...")
for epoch in range(10):  # Changed from 3 to 10 epochs
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = sql_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    logger.info(f"Epoch {epoch+1} Loss: {total_loss / len(train_loader):.4f}")

sql_model.eval()
logger.info("Fine-tuning completed.")

Epoch 1: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 2: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 3: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 4: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 5: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 6: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 7: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 8: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 9: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 10: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]


# Model Evaluation

Evaluates the fine-tuned model on the test set.

* Defines evaluate_system(test_data, model, tokenizer, num_samples=None) to process up to 100 test samples.
* Generates SQL for each Arabic query, normalizes true and predicted SQL, and computes accuracy.
* Logs details and prints mismatches (up to 10).

In [15]:
# Evaluate system accuracy
def evaluate_system(test_data, model, tokenizer, num_samples=None):
    true_sqls = []
    pred_sqls = []
    
    samples = test_data[:num_samples] if num_samples else test_data
    
    for i, item in enumerate(tqdm(samples, desc="Evaluating")):
        arabic_q = item['arabic']
        true_sql = item['query']
        
        english_q, pred_sql = process_query(arabic_q, model, tokenizer)
        
        true_sql_norm = normalize_sql(true_sql)
        pred_sql_norm = normalize_sql(pred_sql)
        
        true_sqls.append(true_sql_norm)
        pred_sqls.append(pred_sql_norm)
        
        logger.info(f"Sample {i+1}: Arabic: {arabic_q}, English: {english_q}, True SQL: {true_sql_norm}, Predicted SQL: {pred_sql_norm}")
    
    accuracy = custom_accuracy_score(true_sqls, pred_sqls)
    return accuracy, true_sqls, pred_sqls

In [1]:
# Run evaluation on test set
try:
    accuracy, true_sqls, pred_sqls = evaluate_system(test_data, sql_model, sql_tokenizer, num_samples=min(len(test_data), 100))
    print(f"System Accuracy (on {min(len(test_data), 100)} test samples): {accuracy:.2f}")
    mismatches = [(i, t, p) for i, (t, p) in enumerate(zip(true_sqls, pred_sqls)) if t != p][:10]
    for i, true, pred in mismatches:
        print(f"Mismatch in sample {i+1}:")
        print(f"  True SQL: {true}")
        print(f"  Predicted SQL: {pred}")
except Exception as e:
    print(f"Error during evaluation: {e}")

# Test with example queries
test_queries = [
    "كم عدد رؤساء الأقسام الذين تزيد أعمارهم عن 56 سنة؟",
    "ما هي أسماء الأقسام التي لديها أكثر من 10 موظفين؟",
    "كم عدد الموظفين في قسم المبيعات؟"
]

print("\nTesting multiple queries:")
for q in test_queries:
    eng, sql = process_query(q, sql_model, sql_tokenizer)
    print(f"\nArabic: {q}")
    print(f"English: {eng}")
    print(f"SQL: {sql}")

print("\nNote: The text-to-SQL model is fine-tuned on the dataset using BART for 10 epochs. "
      "Accuracy depends on training data size and query diversity. Ensure scikit-learn>=1.5.2 is installed if sklearn is needed elsewhere.")

2025-05-21 04:27:38.329234: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747801658.521150      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747801658.576284      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/917k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.13M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Epoch 1: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 2: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 3: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 4: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 5: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 6: 100%|██████████| 640/640 [09:00<00:00,  1.19it/s]
Epoch 7: 100%|██████████| 640/640 [08:59<00:00,  1.19it/s]
Epoch 8: 100%|██████████| 640/640 [09:00<00:00,  1.19it/s]
Epoch 9: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Epoch 10: 100%|██████████| 640/640 [09:00<00:00,  1.18it/s]
Evaluating: 100%|██████████| 100/100 [00:48<00:00,  2.05it/s]


System Accuracy (on 100 test samples): 0.03
Mismatch in sample 2:
  True SQL: SELECT Planned_Delivery_Date,Actual_Delivery_Date FROM BOOKINGS
  Predicted SQL: SELECT T1.prereq_date,T2.date_of_delivery FROM Reservations AS T1 JOIN Delivery_Roles AS T2 ON T3.Reservation_ID = T4.Resident_ID
Mismatch in sample 3:
  True SQL: SELECT T2.lot_details FROM INVESTORS AS T1 JOIN LOTS AS T2 ON T1.investor_id = T2.investor_id WHERE T1.Investor_details = "l"
  Predicted SQL: SELECT T1.purchase_details FROM purchase_transactions AS T1 JOIN TRANSACTIONS AS T2 ON T2.transaction_id = T3.transactions_id WHERE T3,amount_purchased > 10000
Mismatch in sample 4:
  True SQL: SELECT T2.balance FROM accounts AS T1 JOIN checking AS T2 ON T1.custid = T2.custid WHERE T1.name IN (SELECT T1.name FROM accounts AS T1 JOIN savings AS T2 ON T1.custid = T2.custid WHERE T2.balance > (SELECT avg(balance) FROM savings))
  Predicted SQL: SELECT T2.balance FROM accounts AS T1 JOIN savings AS T2 ON T1.council_id = T2;
Mismatch