<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install faker

import random
import json
import os
import torch
import torch.nn as nn
from faker import Faker
from torch.utils.data import Dataset, DataLoader
from collections import Counter
fake = Faker()

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Collecting faker
  Downloading faker-40.4.0-py3-none-any.whl.metadata (16 kB)
Downloading faker-40.4.0-py3-none-any.whl (2.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faker
Successfully installed faker-40.4.0
Device: cuda


In [None]:
SQL_PATTERNS = (
    ["join"] * 5 +
    ["join_where"] * 4 +
    ["group_by"] * 3 +
    ["having"] * 2 +
    ["where"] * 2 +
    ["aggregation"] * 2 +
    ["order_by"] * 3 +
    ["limit"] * 2 +
    ["order_by_limit"] * 3 +
    ["multi_select"] +
    ["simple_select"]
)


In [None]:
BASE_COLUMNS = [
    "name","email","age","salary",
    "department","city","country",
    "price","amount","quantity",
    "created_at","updated_at"
]

FK_COLUMNS = [
    "user_id",
    "order_id",
    "product_id",
    "customer_id"
]

In [None]:
COLUMN_TYPES = {
    "numeric":["age","salary","price","amount","quantity"],
    "text":["name","email","department","city","country"],
    "date":["created_at","updated_at"]
}


In [None]:
def generate_relational_schema():

    num_tables = random.randint(2,5)
    fake.unique.clear()

    tables = {}
    table_names = [fake.unique.word() for _ in range(num_tables)]

    for t in table_names:

        cols = {
            "numeric": random.sample(COLUMN_TYPES["numeric"], random.randint(1,3)),
            "text": random.sample(COLUMN_TYPES["text"], random.randint(1,3)),
            "date": random.sample(COLUMN_TYPES["date"], random.randint(0,1))
        }

        flat_cols = ["id"]
        for v in cols.values():
            flat_cols.extend(v)

        tables[t] = {
            "all": flat_cols,
            "numeric": cols["numeric"],
            "text": cols["text"],
            "date": cols["date"]
        }

    relationships = []

    for i in range(1, num_tables):
        parent = table_names[i-1]
        child = table_names[i]
        fk = f"{parent}_id"
        tables[child]["all"].append(fk)
        relationships.append((child, fk, parent, "id"))

    return {"tables": tables, "relationships": relationships}


In [None]:
def schema_to_text(schema):

    parts = []

    for table, col_dict in schema["tables"].items():
        col_tokens = " ".join([f"[COL] {c}" for c in col_dict["all"]])
        parts.append(f"[TABLE] {table} {col_tokens}")

    for child, fk, parent, pk in schema["relationships"]:
        parts.append(f"[REL] {child}.{fk} -> {parent}.{pk}")

    return " ".join(parts)


In [None]:
SELECT_TEMPLATES = [
    "show {col} from {table}",
    "list {col} in {table}",
    "display {col} from {table}",
    "what are the {col} in {table}",
]

In [None]:
def generate_sql(schema):

    pattern = random.choice(SQL_PATTERNS)
    tables = schema["tables"]
    rels = schema["relationships"]

    table = random.choice(list(tables.keys()))
    columns = tables[table]["all"]
    numeric_cols = tables[table]["numeric"]
    text_cols = tables[table]["text"]

    # SIMPLE SELECT
    if pattern == "simple_select":

        selected = random.sample(columns, random.randint(1, min(3,len(columns))))
        question = f"show {', '.join(selected)} from {table}"
        sql = f"SELECT {', '.join(selected)} FROM {table}"

    # WHERE
    elif pattern == "where":

        col = random.choice(numeric_cols if numeric_cols else columns)
        operator = random.choice(["=",">","<"])
        value = random.randint(1,100)

        question = f"show {col} from {table} where {col} {operator} {value}"
        sql = f"SELECT {col} FROM {table} WHERE {col} {operator} {value}"

    # ORDER BY
    elif pattern == "order_by":

        col = random.choice(numeric_cols if numeric_cols else columns)
        direction = random.choice(["ASC","DESC"])

        question = f"show all records from {table} ordered by {col}"
        sql = f"SELECT * FROM {table} ORDER BY {col} {direction}"

    # LIMIT
    elif pattern == "limit":

        limit_val = random.randint(3,20)
        question = f"show first {limit_val} rows from {table}"
        sql = f"SELECT * FROM {table} LIMIT {limit_val}"

    # ORDER BY + LIMIT
    elif pattern == "order_by_limit":

        col = random.choice(numeric_cols if numeric_cols else columns)
        direction = random.choice(["ASC","DESC"])
        limit_val = random.randint(3,15)

        question = f"show top {limit_val} records from {table} ordered by {col}"
        sql = f"SELECT * FROM {table} ORDER BY {col} {direction} LIMIT {limit_val}"

    # JOIN
    elif pattern == "join" and rels:

        child, fk, parent, pk = random.choice(rels)
        select_col = random.choice(tables[parent]["all"])

        question = f"show {select_col} from {parent} joined with {child}"
        sql = f"SELECT {parent}.{select_col} FROM {parent} JOIN {child} ON {parent}.{pk} = {child}.{fk}"

    else:
        return generate_sql(schema)

    return question.strip(), sql.strip()


In [None]:
def build_dataset(num_schemas=8000, queries_per_schema=8):

    data = []

    for _ in range(num_schemas):

        schema = generate_relational_schema()
        schema_text = schema_to_text(schema)

        for _ in range(queries_per_schema):

            q, sql = generate_sql(schema)

            model_input = f"""
Schema:
{schema_text}

Question:
{q}
"""

            data.append({
                "input": model_input.strip(),
                "output": sql.strip()
            })

    return data


dataset = build_dataset()

print("Total examples:", len(dataset))

with open("dataset.json", "w") as f:
    json.dump(dataset, f)

print("✅ dataset.json saved")


Train: 51200
Val: 6400
Test: 6400


In [None]:
def build_vocab(data):

    counter = Counter()

    for row in data:
        counter.update(row["input"].split())
        counter.update(row["output"].split())

    vocab = {w:i+4 for i,(w,_) in enumerate(counter.items())}

    vocab["<pad>"] = 0
    vocab["<unk>"] = 1
    vocab["<sos>"] = 2
    vocab["<eos>"] = 3

    return vocab


In [None]:
with open("dataset.json") as f:
    dataset = json.load(f)

random.shuffle(dataset)

split = int(0.9 * len(dataset))

train = dataset[:split]
val   = dataset[split:]
vocab = build_vocab(train)
torch.save(vocab, "vocab.pt")


print("Train:", len(train))
print("Val:", len(val))


✅ Train/Val/Test datasets saved as JSON


In [None]:
MAX_LEN = 160

def encode(text, add_special_tokens=False):

    tokens = text.split()

    ids = [vocab.get(t,1) for t in tokens]

    if add_special_tokens:
        ids = [vocab["<sos>"]] + ids + [vocab["<eos>"]]

    ids = ids[:MAX_LEN]

    ids += [0]*(MAX_LEN-len(ids))

    return torch.tensor(ids)


class NL2SQLDataset(Dataset):
    def __init__(self,data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,idx):

        row = self.data[idx]

        src = encode(row["input"], add_special_tokens=False)
        tgt = encode(row["output"], add_special_tokens=True)

        return src, tgt



In [None]:
class NL2SQLModel(nn.Module):

    def __init__(self, vocab_size, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
          d_model=d_model,
          nhead=4,
          num_encoder_layers=2,
          num_decoder_layers=2,
          batch_first=True
        )

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):

        src_mask = (src==0)
        tgt_mask = (tgt==0)

        seq_len = tgt.size(1)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)

        src = self.embed(src)
        tgt = self.embed(tgt)

        out = self.transformer(
            src,
            tgt,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_mask,
            tgt_key_padding_mask=tgt_mask
        )

        return self.fc(out)


In [2]:
train_loader = DataLoader(NL2SQLDataset(train), batch_size=16, shuffle=True)
val_loader   = DataLoader(NL2SQLDataset(val), batch_size=16)

model = NL2SQLModel(len(vocab)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

CHECKPOINT_PATH = "nl2sql_checkpoint.pt"

def save_checkpoint(epoch, model, optimizer, val_loss):
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_loss": val_loss
    }, CHECKPOINT_PATH)

def load_checkpoint(model, optimizer):
    if os.path.exists(CHECKPOINT_PATH):
        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        print(f"✅ Resuming from epoch {ckpt['epoch']+1}")
        return ckpt["epoch"] + 1
    return 0


EPOCHS = 15
best_loss = float("inf")

start_epoch = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, EPOCHS):

    ###################
    # TRAIN
    ###################
    model.train()
    train_loss = 0

    for x,y in train_loader:

        x,y = x.to(device), y.to(device)

        optimizer.zero_grad()

        output = model(x, y[:,:-1])

        loss = loss_fn(
            output.reshape(-1,len(vocab)),
            y[:,1:].reshape(-1)
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)


    ###################
    # VALIDATION
    ###################
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)

            output = model(x, y[:,:-1])

            loss = loss_fn(
                output.reshape(-1,len(vocab)),
                y[:,1:].reshape(-1)
            )

            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch {epoch} | Train {train_loss:.4f} | Val {val_loss:.4f}")

    # save every epoch
    save_checkpoint(epoch, model, optimizer, val_loss)

    # save best model separately
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        print("⭐ BEST MODEL SAVED")


In [None]:
inv_vocab = {i:w for w,i in vocab.items()}

def generate_query(model, text):

    model.eval()

    src = encode(text, add_special_tokens=False).unsqueeze(0).to(device)

    tgt = torch.tensor([[vocab["<sos>"]]], dtype=torch.long).to(device)

    for _ in range(MAX_LEN):

        with torch.no_grad():
            output = model(src, tgt)

        next_token = output.argmax(-1)[:,-1]

        if next_token.item() == vocab["<eos>"]:
            break

        tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)

    tokens = [
        inv_vocab.get(i,"")
        for i in tgt.squeeze().tolist()
        if i not in [vocab["<sos>"], vocab["<pad>"]]
    ]

    return " ".join(tokens)


In [12]:
manual_input = """
Schema:
[TABLE] employees [COL] id [COL] name [COL] salary [COL] department_id
[TABLE] departments [COL] id [COL] dept_name
[REL] employees.department_id -> departments.id

Question:
show salary from employees where salary > 100
"""

print(generate_query(model, manual_input))


<pad> FROM salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary salary FROM professor WHERE salary > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 > 58 58 58 58 58 salary > 58 58 58 58 58 salary > 58 58 58 58 58 salary > 58 58 58 58 58 58 salary > 58 58 58 58 58 58 58 salary > 58 58 58 58 58 58 58 salary > 58 58 58 58 58 58 58 58 salary > 58 58 58 58 58 58 58 58 58 salary > 58 58 58 58 58 58 58 58 58 58 58 salary
