<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install faker

import random
import os
import torch
import torch.nn as nn
from faker import Faker
from torch.utils.data import Dataset, DataLoader
from collections import Counter

fake = Faker()

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Collecting faker
  Downloading faker-40.4.0-py3-none-any.whl.metadata (16 kB)
Downloading faker-40.4.0-py3-none-any.whl (2.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m87.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faker
Successfully installed faker-40.4.0
Device: cuda


In [2]:
SQL_PATTERNS = (
    ["join"] * 5 +
    ["join_where"] * 4 +        # ⭐ NEW
    ["group_by"] * 3 +
    ["having"] * 2 +
    ["where"] * 2 +
    ["aggregation"] * 2 +
    ["order_by"] * 3 +         # ⭐ NEW
    ["limit"] * 2 +            # ⭐ NEW
    ["order_by_limit"] * 3 +   # ⭐ VERY IMPORTANT
    ["multi_select"] +
    ["simple_select"]
)


In [3]:
BASE_COLUMNS = [
    "name","email","age","salary",
    "department","city","country",
    "price","amount","quantity",
    "created_at","updated_at"
]

FK_COLUMNS = [
    "user_id",
    "order_id",
    "product_id",
    "customer_id"
]

In [4]:
COLUMN_TYPES = {
    "numeric":[
        "age","salary","price",
        "amount","quantity"
    ],
    "text":[
        "name","email",
        "department","city","country"
    ],
    "date":[
        "created_at","updated_at"
    ]
}

In [5]:
def generate_relational_schema():

    num_tables = random.randint(2,5)

    fake.unique.clear()

    tables = {}
    relationships = {}

    table_names = [fake.unique.word() for _ in range(num_tables)]

    for t in table_names:

        cols = {
            "numeric": random.sample(COLUMN_TYPES["numeric"],
                                     random.randint(1,3)),

            "text": random.sample(COLUMN_TYPES["text"],
                                  random.randint(1,3)),

            "date": random.sample(COLUMN_TYPES["date"],
                                  random.randint(0,1))
        }

        # flatten
        flat_cols = ["id"]

        for v in cols.values():
            flat_cols.extend(v)

        tables[t] = {
            "all": flat_cols,
            "numeric": cols["numeric"],
            "text": cols["text"],
            "date": cols["date"]
        }

    relationships_list = []

    for i in range(1, num_tables):

        parent = table_names[i-1]
        child = table_names[i]

        fk = f"{parent}_id"

        tables[child]["all"].append(fk)

        relationships_list.append(
            (child, fk, parent, "id")
        )

    return {
        "tables": tables,
        "relationships": relationships_list
    }

In [6]:
def schema_to_text(schema):

    parts = []

    for table, col_dict in schema["tables"].items():

        cols = col_dict["all"]

        col_tokens = " ".join([f"[COL] {c}" for c in cols])
        parts.append(f"[TABLE] {table} {col_tokens}")

    # relationships
    for child, fk, parent, pk in schema["relationships"]:
        parts.append(
            f"[REL] {child}.{fk} -> {parent}.{pk}"
        )

    return " ".join(parts)

In [7]:
SELECT_TEMPLATES = [
    "show {col} from {table}",
    "list {col} in {table}",
    "display {col} from {table}",
    "what are the {col} in {table}",
]

In [8]:
def generate_sql(schema):

    pattern = random.choice(SQL_PATTERNS)

    tables = schema["tables"]
    rels = schema["relationships"]

    table = random.choice(list(tables.keys()))

    columns = tables[table]["all"]
    numeric_cols = tables[table]["numeric"]
    text_cols = tables[table]["text"]

    # ---------- SIMPLE SELECT ----------
    if pattern == "simple_select":

        selected = random.sample(
            columns,
            random.randint(1, min(3, len(columns)))
        )

        question = f"show {', '.join(selected)} from {table}"
        sql = f"SELECT {', '.join(selected)} FROM {table}"


    # ---------- MULTI SELECT ----------
    elif pattern == "multi_select":

        selected = random.sample(
            columns,
            random.randint(2, min(4, len(columns)))
        )

        question = f"show {', '.join(selected)} from {table}"
        sql = f"SELECT {', '.join(selected)} FROM {table}"


    # ---------- WHERE ----------
    elif pattern == "where":

        # Prefer numeric filters
        if numeric_cols and random.random() < 0.7:

            col = random.choice(numeric_cols)
            operator = random.choice(["=", ">", "<"])
            value = random.randint(1, 100)

        else:

            col = random.choice(text_cols if text_cols else columns)
            operator = "="
            value = f"'{fake.word()}'"

        question = f"show {col} from {table} where {col} {operator} {value}"
        sql = f"SELECT {col} FROM {table} WHERE {col} {operator} {value}"


    # ---------- AGGREGATION ----------
    elif pattern == "aggregation":

        agg = random.choice(["COUNT","SUM","AVG","MIN","MAX"])

        if agg in ["SUM","AVG"]:

            if not numeric_cols:
                return generate_sql(schema)

            col = random.choice(numeric_cols)

        elif agg in ["MIN","MAX"]:

            candidates = numeric_cols + text_cols
            if not candidates:
                return generate_sql(schema)

            col = random.choice(candidates)

        else:  # COUNT
            col = random.choice(columns)

        question = f"what is the {agg.lower()} of {col} in {table}"
        sql = f"SELECT {agg}({col}) FROM {table}"
        # ---------- ORDER BY ----------
    elif pattern == "order_by":

        # prefer numeric columns for sorting
        if numeric_cols:
            col = random.choice(numeric_cols)
        else:
            col = random.choice(columns)

        direction = random.choice(["ASC","DESC"])

        question = f"show all records from {table} ordered by {col} {direction.lower()}"

        sql = f"""
        SELECT *
        FROM {table}
        ORDER BY {col} {direction}
        """
        # ---------- LIMIT ----------
    elif pattern == "limit":

        limit_val = random.randint(3,20)

        question = f"show first {limit_val} rows from {table}"

        sql = f"""
SELECT *
FROM {table}
LIMIT {limit_val}
"""
        # ---------- ORDER BY + LIMIT ----------
    elif pattern == "order_by_limit":

        if numeric_cols:
            col = random.choice(numeric_cols)
        else:
            col = random.choice(columns)

        direction = random.choice(["ASC","DESC"])
        limit_val = random.randint(3,15)

        question = f"show top {limit_val} records from {table} ordered by {col}"

        sql = f"""
SELECT *
FROM {table}
ORDER BY {col} {direction}
LIMIT {limit_val}
"""


    # ---------- GROUP BY ----------
    elif pattern == "group_by":

        group_candidates = text_cols if text_cols else columns
        group_col = random.choice(group_candidates)

        if numeric_cols:
            agg_col = random.choice(numeric_cols)
        else:
            return generate_sql(schema)

        agg = random.choice(["COUNT","SUM","AVG","MIN","MAX"])

        question = f"{agg.lower()} of {agg_col} grouped by {group_col}"

        sql = f"""
SELECT {group_col}, {agg}({agg_col})
FROM {table}
GROUP BY {group_col}
"""


    # ---------- HAVING ----------
    elif pattern == "having":

        group_candidates = text_cols if text_cols else columns
        group_col = random.choice(group_candidates)

        if not numeric_cols:
            return generate_sql(schema)

        agg_col = random.choice(numeric_cols)
        agg = random.choice(["COUNT","SUM","AVG","MIN","MAX"])

        operator = random.choice([">", "<", "="])
        value = random.randint(1,50)

        question = f"{agg.lower()} of {agg_col} grouped by {group_col} having value {operator} {value}"

        sql = f"""
SELECT {group_col}, {agg}({agg_col})
FROM {table}
GROUP BY {group_col}
HAVING {agg}({agg_col}) {operator} {value}
"""
    # ---------- JOIN + WHERE ----------
    elif pattern == "join_where" and rels:

        child, fk, parent, pk = random.choice(rels)

        if random.random() < 0.5:
            left, right = parent, child
            left_key, right_key = pk, fk
        else:
            left, right = child, parent
            left_key, right_key = fk, pk

        right_numeric = tables[right]["numeric"]
        right_text = tables[right]["text"]

        if right_numeric and random.random() < 0.7:

            where_col = random.choice(right_numeric)
            operator = random.choice([">","<","="])
            value = random.randint(1,100)

        else:

            candidates = right_text if right_text else right_numeric
            if not candidates:
                return generate_sql(schema)

            where_col = random.choice(candidates)
            operator = "="
            value = f"'{fake.word()}'"

        select_col = random.choice(tables[left]["all"])

        question = f"show {select_col} from {left} joined with {right} where {where_col} {operator} {value}"

        sql = f"""
SELECT {left}.{select_col}
FROM {left}
JOIN {right}
ON {left}.{left_key} = {right}.{right_key}
WHERE {right}.{where_col} {operator} {value}
"""

    # ---------- JOIN ----------
    elif pattern == "join" and rels:

        child, fk, parent, pk = random.choice(rels)

        if random.random() < 0.5:
            left, right = parent, child
            left_key, right_key = pk, fk
        else:
            left, right = child, parent
            left_key, right_key = fk, pk

        select_col = random.choice(tables[left]["all"])

        question = f"show {select_col} from {left} joined with {right}"

        sql = f"""
SELECT {left}.{select_col}
FROM {left}
JOIN {right}
ON {left}.{left_key} = {right}.{right_key}
"""

    else:
        return generate_sql(schema)

    return question.strip(), sql.strip()

In [36]:
def build_dataset_by_schema(num_schemas=8000,
                            queries_per_schema=8):

    schemas = [generate_relational_schema()
               for _ in range(num_schemas)]

    train_split = int(0.8 * num_schemas)
    val_split   = int(0.9 * num_schemas)

    train_s = schemas[:train_split]
    val_s   = schemas[train_split:val_split]
    test_s  = schemas[val_split:]

    def build_examples(schema_list):

        data = []

        for schema in schema_list:

            schema_text = schema_to_text(schema)

            for _ in range(queries_per_schema):

                q, sql = generate_sql(schema)

                model_input = f"""
Schema:
{schema_text}

Question:
{q}
"""

                data.append({
                    "input": model_input.strip(),
                    "output": sql.strip()
                })

        return data

    return build_examples(train_s), build_examples(val_s), build_examples(test_s)


train, val, test = build_dataset_by_schema()


print("Train:", len(train))
print("Val:", len(val))
print("Test:", len(test))

Train: 51200
Val: 6400
Test: 6400


In [37]:
import json

with open("train.json", "w") as f:
    json.dump(train, f)

with open("val.json", "w") as f:
    json.dump(val, f)

with open("test.json", "w") as f:
    json.dump(test, f)

print("✅ Train/Val/Test datasets saved as JSON")


✅ Train/Val/Test datasets saved as JSON


In [38]:
def build_vocab(data):

    counter = Counter()

    for row in data:
        counter.update(row["input"].split())
        counter.update(row["output"].split())

    vocab = {w:i+2 for i,(w,_) in enumerate(counter.items())}
    vocab["<pad>"] = 0
    vocab["<unk>"] = 1

    return vocab

vocab = build_vocab(train)

torch.save(vocab, "vocab.pt")

In [39]:
MAX_LEN = 160

def encode(text):

    tokens = text.split()

    ids = [vocab.get(t,1) for t in tokens][:MAX_LEN]

    ids += [0]*(MAX_LEN-len(ids))

    return torch.tensor(ids)

class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):

        row = self.data[idx]

        return encode(row["input"]), encode(row["output"])

In [None]:
with open("train.json") as f:
    train = json.load(f)

with open("val.json") as f:
    val = json.load(f)

with open("test.json") as f:
    test = json.load(f)

print("✅ Dataset loaded!")
print("Train size:", len(train))
print("Val size:", len(val))
print("Test size:", len(test))


In [40]:
class NL2SQLModel(nn.Module):

    def __init__(self, vocab_size, d_model=256):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            batch_first=True
        )

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):

        src_padding_mask = (src == 0)
        tgt_padding_mask = (tgt == 0)

        # ⭐⭐⭐ CAUSAL MASK (VERY IMPORTANT)
        tgt_seq_len = tgt.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt_seq_len
        ).to(tgt.device)

        src = self.embed(src)
        tgt = self.embed(tgt)

        out = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask
        )

        return self.fc(out)


In [41]:
for i in range(20):
    print(train[i]["input"])
    print(train[i]["output"])
    print("-----")

Schema:
[TABLE] environmental [COL] id [COL] quantity [COL] city [COL] department [COL] country [COL] updated_at [TABLE] voice [COL] id [COL] quantity [COL] age [COL] amount [COL] country [COL] updated_at [COL] environmental_id [TABLE] today [COL] id [COL] salary [COL] country [COL] email [COL] name [COL] voice_id [TABLE] not [COL] id [COL] price [COL] age [COL] department [COL] country [COL] updated_at [COL] today_id [TABLE] with [COL] id [COL] salary [COL] age [COL] city [COL] updated_at [COL] not_id [REL] voice.environmental_id -> environmental.id [REL] today.voice_id -> voice.id [REL] not.today_id -> today.id [REL] with.not_id -> not.id

Question:
show department from not
SELECT department FROM not
-----
Schema:
[TABLE] environmental [COL] id [COL] quantity [COL] city [COL] department [COL] country [COL] updated_at [TABLE] voice [COL] id [COL] quantity [COL] age [COL] amount [COL] country [COL] updated_at [COL] environmental_id [TABLE] today [COL] id [COL] salary [COL] country [COL

In [42]:
train_loader = DataLoader(
    NL2SQLDataset(train),
    batch_size=16,
    shuffle=True
)

val_loader = DataLoader(
    NL2SQLDataset(val),
    batch_size=16,
    shuffle=False
)

test_loader = DataLoader(
    NL2SQLDataset(test),
    batch_size=16,
    shuffle=False
)


model = NL2SQLModel(len(vocab)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

In [43]:
CHECKPOINT_PATH = "nl2sql_checkpoint.pt"

def save_checkpoint(epoch, model, optimizer, loss):

    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": loss
    }, CHECKPOINT_PATH)

def load_checkpoint(model, optimizer):

    if os.path.exists(CHECKPOINT_PATH):

        ckpt = torch.load(CHECKPOINT_PATH,
                          map_location=device)

        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])

        print("Resuming from epoch:",
              ckpt["epoch"]+1)

        return ckpt["epoch"] + 1

    return 0

In [None]:
EPOCHS = 15
best_loss = float("inf")

start_epoch = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, EPOCHS):

    ###################
    # TRAIN
    ###################
    model.train()
    train_loss = 0

    for x,y in train_loader:

        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        output = model(x, y[:,:-1])

        loss = loss_fn(
            output.reshape(-1, len(vocab)),
            y[:,1:].reshape(-1)
        )

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)


    ###################
    # VALIDATION
    ###################
    model.eval()
    val_loss = 0

    with torch.no_grad():

        for x,y in val_loader:

            x = x.to(device)
            y = y.to(device)

            output = model(x, y[:,:-1])

            loss = loss_fn(
                output.reshape(-1, len(vocab)),
                y[:,1:].reshape(-1)
            )

            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch {epoch} | Train Loss {train_loss:.4f} | Val Loss {val_loss:.4f}")

    save_checkpoint(epoch, model, optimizer, val_loss)

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        print("⭐ BEST MODEL SAVED")


Epoch 0 | Train Loss 3.1396 | Val Loss 2.9934
⭐ BEST MODEL SAVED
Epoch 1 | Train Loss 2.6255 | Val Loss 3.1910
Epoch 2 | Train Loss 1.9118 | Val Loss 3.2314
Epoch 3 | Train Loss 1.3053 | Val Loss 3.1079
Epoch 4 | Train Loss 0.9828 | Val Loss 3.0030
Epoch 5 | Train Loss 0.7949 | Val Loss 2.9828
⭐ BEST MODEL SAVED
Epoch 6 | Train Loss 0.6839 | Val Loss 3.0611
Epoch 7 | Train Loss 0.6103 | Val Loss 3.0056
Epoch 8 | Train Loss 0.5529 | Val Loss 3.0139


In [None]:
test_loader = DataLoader(
    NL2SQLDataset(test),
    batch_size=64,
    shuffle=False
)

model.eval()
test_loss = 0

with torch.no_grad():

    for x,y in test_loader:

        x = x.to(device)
        y = y.to(device)

        output = model(x, y[:,:-1])

        loss = loss_fn(
            output.reshape(-1, len(vocab)),
            y[:,1:].reshape(-1)
        )

        test_loss += loss.item()

test_loss /= len(test_loader)

print("✅ FINAL TEST LOSS:", test_loss)


In [None]:
# ---------- LOAD VOCAB ----------
vocab = torch.load("vocab.pt")
inv_vocab = {i:w for w,i in vocab.items()}

device = "cuda" if torch.cuda.is_available() else "cpu"


# ---------- REBUILD MODEL ----------
model = NL2SQLModel(len(vocab)).to(device)

model.load_state_dict(
    torch.load("best_model.pt", map_location=device)
)

model.eval()

print("✅ Best model loaded!")


In [None]:
MAX_LEN = 220

def encode(text):

    tokens = text.split()

    ids = [vocab.get(t,1) for t in tokens][:MAX_LEN]

    ids += [0]*(MAX_LEN-len(ids))

    return torch.tensor(ids)


In [None]:
def generate_query(model, text, max_len=220):

    model.eval()

    src = encode(text).unsqueeze(0).to(device)

    tgt = torch.zeros((1,1), dtype=torch.long).to(device)

    for _ in range(max_len):

        with torch.no_grad():

            output = model(src, tgt)
            logits = output

        next_token = logits.argmax(-1)[:,-1].unsqueeze(0)

        tgt = torch.cat([tgt, next_token], dim=1)

        if next_token.item() == 0:
            break

    tokens = [
        inv_vocab.get(i,"")
        for i in tgt.squeeze().tolist()
    ]

    return " ".join(tokens)
