<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_2_SQL_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install --upgrade --force-reinstall --no-cache-dir \
transformers==4.44.2 \
accelerate==0.34.2 \
sentencepiece \
sqlglot


Collecting transformers==4.44.2
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m148.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (10 kB)
Collecting sqlglot
  Downloading sqlglot-28.10.1-py3-none-any.whl.metadata (22 kB)
Collecting filelock (from transformers==4.44.2)
  Downloading filelock-3.20.3-py3-none-any.whl.metadata (2.1 kB)
Collecting huggingface-hub<1.0,>=0.23.2 (from transformers==4.44.2)
  Downloading huggingface_hub-0.36.2-py3-none-any.whl.metadata (15 kB)
Collecting numpy>=1.17 (from transformers==4.44.2)
  Downloading numpy-2.4.2-cp312-cp312-man

In [1]:
import transformers
import accelerate

# Report the installed versions so the run can be matched against the pins above.
for _pkg in (transformers, accelerate):
    print(_pkg.__version__)


4.44.2
0.34.2


In [2]:
# Smoke test: make sure Trainer and TrainingArguments import cleanly before
# any data generation or model loading is attempted.
from transformers import Trainer, TrainingArguments
print("Trainer imported successfully ✅")


Trainer imported successfully ✅


In [3]:
import random
import json
import torch
import sqlglot

from torch.utils.data import Dataset

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments
)


In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Device:", device)


Device: cuda


In [5]:
COLUMN_TYPES = ["number", "text", "date"]

def generate_schema():
    """Generate a random synthetic database schema.

    The schema has between 1 and 4 tables named ``table_0`` .. ``table_3``.
    Each table starts with an ``("id", "number")`` column followed by 3 to 6
    extra columns named ``col_<n>`` (``n`` in 1..50) with a random type taken
    from ``COLUMN_TYPES``.

    Returns:
        tuple: ``(schema, foreign_keys)`` where ``schema`` maps a table name to
        a list of ``(column_name, column_type)`` tuples and ``foreign_keys`` is
        a list of ``(table, column, referenced_table, referenced_column)``
        tuples. When more than one table exists, the single relationship
        ``table_1.id -> table_0.id`` is emitted; otherwise the list is empty.
    """

    num_tables = random.randint(1, 4)

    schema = {}
    foreign_keys = []

    for t in range(num_tables):

        table = f"table_{t}"

        cols = [("id", "number")]

        # Draw the column numbers WITHOUT replacement: picking each number
        # independently could give one table two columns with the same name,
        # which makes the schema (and any SQL built from it) ambiguous.
        num_extra_cols = random.randint(3, 6)
        for col_number in random.sample(range(1, 51), num_extra_cols):
            col_type = random.choice(COLUMN_TYPES)
            cols.append((f"col_{col_number}", col_type))

        schema[table] = cols

    # create foreign key if multiple tables
    if num_tables > 1:
        foreign_keys.append(
            ("table_1","id","table_0","id")
        )

    return schema, foreign_keys


In [6]:
def serialize_schema(schema, fks):
    """Render a schema and its relationships as the text block used in prompts.

    Args:
        schema: Mapping of table name to a list of ``(column, type)`` pairs.
        fks: List of ``(table, column, ref_table, ref_column)`` tuples; when it
            is empty no "Relationships" section is written.

    Returns:
        str: The serialized schema, starting with a ``DATABASE SCHEMA:`` header.
    """
    parts = ["DATABASE SCHEMA:\n\n"]

    for table_name, columns in schema.items():
        parts.append(f"Table: {table_name}\nColumns:\n")
        parts.extend(f"- {name} ({kind})\n" for name, kind in columns)
        parts.append("\n")

    if fks:
        parts.append("Relationships:\n")
        parts.extend(
            f"{src_table}.{src_col} references {dst_table}.{dst_col}\n"
            for src_table, src_col, dst_table, dst_col in fks
        )

    return "".join(parts)


In [7]:
AGGS = ["SUM","AVG","COUNT","MAX","MIN"]

def pick_number_col(cols):
    """Pick a random ``(name, type)`` column, preferring numeric ones.

    Args:
        cols: List of ``(column_name, column_type)`` tuples.

    Returns:
        A random tuple whose type is ``"number"`` when at least one exists,
        otherwise a random tuple from the whole list.
    """
    numeric = [entry for entry in cols if entry[1] == "number"]
    # An empty list is falsy, so this falls back to every column.
    pool = numeric or cols
    return random.choice(pool)

def generate_query(schema, fks):
    """Generate one random (question, SQL) training pair for a schema.

    A table, a column and one of eight query patterns (select, where, multi,
    agg, group, having, join, nested) are chosen at random, and a templated
    natural-language question plus the matching SQL are produced.

    Args:
        schema: Mapping of table name to a list of ``(column, type)`` pairs.
        fks: List of ``(table, column, ref_table, ref_column)`` tuples. Only
            the first entry is used, and only by the join pattern.

    Returns:
        tuple: ``(question, sql)`` where ``question`` is lower-cased and
        ``sql`` is the target query string.
    """

    table = random.choice(list(schema.keys()))
    cols = schema[table]

    col, _ = random.choice(cols)

    pattern = random.choice([
        "select",
        "where",
        "multi",
        "agg",
        "group",
        "having",
        "join",
        "nested"
    ])

    # A join needs a relationship. Without one the nested pattern is used
    # instead, which is what previously happened implicitly by falling through.
    if pattern == "join" and not fks:
        pattern = "nested"

    # SELECT
    if pattern == "select":
        q = f"list {col} from {table}"
        sql = f"SELECT {col} FROM {table}"

    # WHERE
    elif pattern == "where":
        num_col, _ = pick_number_col(cols)
        val = random.randint(10, 1000)

        q = f"show records from {table} where {num_col} is greater than {val}"
        sql = f"SELECT * FROM {table} WHERE {num_col} > {val}"

    # MULTI
    elif pattern == "multi":
        # The second column must differ from the first, otherwise the pair
        # degenerates into "SELECT c, c". dict.fromkeys de-duplicates by name
        # while keeping a deterministic order for a seeded RNG.
        other_names = list(dict.fromkeys(name for name, _ in cols if name != col))

        if other_names:
            c2 = random.choice(other_names)
            q = f"show {col} and {c2} from {table}"
            sql = f"SELECT {col}, {c2} FROM {table}"
        else:
            # Single-column table: fall back to a plain one-column select.
            q = f"list {col} from {table}"
            sql = f"SELECT {col} FROM {table}"

    # AGG
    elif pattern == "agg":
        num_col, _ = pick_number_col(cols)
        agg = random.choice(AGGS)

        q = f"what is the {agg.lower()} of {num_col} in {table}"
        sql = f"SELECT {agg}({num_col}) FROM {table}"

    # GROUP
    elif pattern == "group":
        num_col, _ = pick_number_col(cols)
        agg = random.choice(AGGS)

        # The question names the table: column names such as "id" occur in
        # every table, so without it the target table cannot be inferred.
        q = f"show {col} from {table} grouped by {col} with {agg.lower()} of {num_col}"
        sql = f"SELECT {col}, {agg}({num_col}) FROM {table} GROUP BY {col}"

    # HAVING
    elif pattern == "having":
        num_col, _ = pick_number_col(cols)
        val = random.randint(10, 500)
        agg = random.choice(AGGS)

        # Same reasoning as the group pattern: keep the table in the question.
        q = f"show {col} from {table} where {agg.lower()} {num_col} is greater than {val}"
        sql = f"SELECT {col}, {agg}({num_col}) FROM {table} GROUP BY {col} HAVING {agg}({num_col}) > {val}"

    # JOIN
    elif pattern == "join":
        t1, c1, t2, c2 = fks[0]

        q = f"join {t1} with {t2}"
        sql = f"SELECT * FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

    # NESTED
    else:
        num_col, _ = pick_number_col(cols)

        q = f"show rows from {table} where {num_col} is above average"
        sql = f"SELECT * FROM {table} WHERE {num_col} > (SELECT AVG({num_col}) FROM {table})"

    return q.lower(), sql


In [8]:
# Prompt wrapped around every serialized schema/question pair.
_PROMPT_TEMPLATE = """
You are an expert SQL generator.

Use ONLY tables provided.

{schema_text}

Question:
{question}
"""


def _examples_for_schemas(schemas, queries_per_schema):
    """Build ``(prompt, sql)`` pairs for every ``(schema, fks)`` entry given."""
    examples = []

    for schema, fks in schemas:
        # Serialize once per schema; the text is shared by all its questions.
        schema_text = serialize_schema(schema, fks)

        for _ in range(queries_per_schema):
            q, sql = generate_query(schema, fks)
            prompt = _PROMPT_TEMPLATE.format(schema_text=schema_text, question=q)
            examples.append((prompt, sql))

    return examples


def build_dataset(num_schemas=250, queries_per_schema=20, train_fraction=0.8):
    """Create the synthetic train/validation sets of (prompt, SQL) pairs.

    Schemas are split BEFORE questions are generated, so no schema contributes
    examples to both sets.

    Args:
        num_schemas: Number of random schemas to generate.
        queries_per_schema: Number of (question, SQL) pairs drawn per schema.
        train_fraction: Share of the schemas assigned to the training set; the
            remaining schemas form the validation set.

    Returns:
        tuple: ``(train_data, val_data)``, each a list of ``(prompt, sql)``
        string tuples.
    """
    schemas = [generate_schema() for _ in range(num_schemas)]

    train_cut = int(train_fraction * num_schemas)

    train_data = _examples_for_schemas(schemas[:train_cut], queries_per_schema)
    val_data = _examples_for_schemas(schemas[train_cut:], queries_per_schema)

    return train_data, val_data


In [9]:
# Seed the RNGs first: schema and query generation draw from `random`, so
# without a fixed seed every run would train on a different dataset.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

train_data, val_data = build_dataset()

print(len(train_data), len(val_data))


4000 1000


In [20]:
# One checkpoint name feeds both loaders so tokenizer and model always match.
MODEL_NAME = "google/flan-t5-base"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(device)




In [21]:
class NL2SQLDataset(Dataset):
    """Torch dataset that tokenizes (prompt, SQL) pairs on access.

    Args:
        data: Sequence of ``(prompt, sql)`` string tuples.
        tokenizer: Tokenizer used for both prompt and target. When ``None``
            (the default) the module-level ``tokenizer`` is looked up at item
            access time, which keeps ``NL2SQLDataset(data)`` working as before.
        max_input_length: Length prompts are truncated/padded to.
        max_target_length: Length SQL targets are truncated/padded to.
    """

    def __init__(self, data, tokenizer=None, max_input_length=512, max_target_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def _resolve_tokenizer(self):
        """Return the injected tokenizer, or the notebook-level one as fallback."""
        if self.tokenizer is not None:
            return self.tokenizer
        # Inside a method this name refers to the module-level `tokenizer`.
        return tokenizer

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one pair."""
        tok = self._resolve_tokenizer()

        inp, sql = self.data[idx]

        enc = tok(
            inp,
            truncation=True,
            padding="max_length",
            max_length=self.max_input_length,
            return_tensors="pt"
        )

        dec = tok(
            sql,
            truncation=True,
            padding="max_length",
            max_length=self.max_target_length,
            return_tensors="pt"
        )

        # squeeze(0) removes only the batch dimension added by
        # return_tensors="pt"; a bare squeeze() would also collapse a
        # length-1 sequence dimension.
        labels = dec.input_ids.squeeze(0)
        # Positions set to -100 are padding that the loss should skip.
        labels[labels == tok.pad_token_id] = -100

        return {
            "input_ids": enc.input_ids.squeeze(0),
            "attention_mask": enc.attention_mask.squeeze(0),
            "labels": labels
        }


In [22]:
# Re-print the transformers version right before training; this repeats the
# version check already made after installation.
import transformers
print(transformers.__version__)


4.44.2


In [None]:
# Wrap the raw (prompt, sql) pairs so the Trainer receives tokenized tensors.
train_ds = NL2SQLDataset(train_data)
val_ds = NL2SQLDataset(val_data)

# Batch size 4 with 2 gradient-accumulation steps gives an effective batch of
# 8 samples per device per optimizer step. Evaluation and checkpointing both
# run once per epoch, and load_best_model_at_end asks the Trainer to reload
# the best checkpoint when training finishes.
# NOTE(review): the recorded output of this cell shows a GradScaler warning and
# a 0.0 training loss, which presumably comes from an earlier run with fp16
# enabled; it does not match fp16=False below. Re-run and confirm.
training_args = TrainingArguments(
    output_dir="nl2sql_model",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    learning_rate=2e-5,
    max_grad_norm=1.0,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=False,
    report_to="none"
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds
)

trainer.train()
# Also store the weights as a plain state dict next to the Trainer checkpoints.
torch.save(model.state_dict(), "nl2sql_model.pt")


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Epoch,Training Loss,Validation Loss
1,0.0,


In [27]:
# Sanity checks on the model precision, one encoded example and a few targets.
print(model.dtype)

sample = train_ds[0]
sample_labels = sample["labels"]
print(sample_labels[:50])

# Number of label positions that are not masked out with -100.
print(sum(sample_labels != -100))

# Target SQL of the first ten training pairs.
for row in range(10):
    print(train_data[row][1])


torch.float32
tensor([    3, 23143, 14196,  7632,   834,  2606,     6,  4800,     4,   599,
           23,    26,    61, 21680,   953,   834,   632,   350,  4630,  6880,
          272,   476,  7632,   834,  2606,     1,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100])
tensor(26)
SELECT col_18, MAX(id) FROM table_0 GROUP BY col_18
SELECT col_18 FROM table_0
SELECT AVG(col_32) FROM table_1
SELECT * FROM table_1 JOIN table_0 ON table_1.id = table_0.id
SELECT SUM(id) FROM table_0
SELECT col_2 FROM table_0
SELECT id FROM table_2
SELECT col_9, MIN(col_29) FROM table_1 GROUP BY col_9
SELECT col_33 FROM table_2
SELECT col_47, SUM(id) FROM table_2 GROUP BY col_47 HAVING SUM(id) > 425


In [9]:
def generate_sql(schema,fks,question):
    """Generate a SQL query for a question against the given schema.

    The prompt is built with the same wording as the training prompts, run
    through the notebook-level ``tokenizer`` and ``model``, and decoded with
    beam search.

    Args:
        schema: Mapping of table name to a list of ``(column, type)`` pairs.
        fks: List of ``(table, column, ref_table, ref_column)`` tuples.
        question: Natural-language question to translate.

    Returns:
        str: The decoded model output with special tokens removed.
    """

    schema_text = serialize_schema(schema,fks)

    prompt=f"""
You are an expert SQL generator.

Use ONLY tables provided.

{schema_text}

Question:
{question}
"""

    # Truncate to the same 512-token limit the training inputs were cut to,
    # so inference never feeds the model longer inputs than it was trained on.
    tokens=tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)

    # Switch to eval mode so dropout is disabled and generation is
    # deterministic for a given prompt.
    model.eval()

    outputs=model.generate(
        **tokens,
        max_length=150,
        num_beams=4,
        early_stopping=True
    )

    return tokenizer.decode(outputs[0],skip_special_tokens=True)


In [9]:
def validate_sql(query):
    """Return True when ``query`` parses as SQL with sqlglot, else False.

    Args:
        query: Candidate SQL text (typically model output).

    Returns:
        bool: ``True`` if ``sqlglot.parse_one`` accepts the query. ``False``
        for non-string or blank input and for anything the parser rejects.
    """

    # Nothing to parse: treat missing/blank input as invalid up front.
    if not isinstance(query, str) or not query.strip():
        return False

    try:
        sqlglot.parse_one(query)
        return True
    except Exception:
        # Deliberately broad, because model output can trip the parser in many
        # ways. Unlike a bare `except:` this lets KeyboardInterrupt and
        # SystemExit propagate, so the cell can still be interrupted.
        return False
