<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [38]:
!pip install faker sqlparse

import random
import os
import torch
import torch.nn as nn
from faker import Faker
from torch.utils.data import Dataset, DataLoader
from collections import Counter

fake = Faker()

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [29]:
SQL_PATTERNS = (
    ["join"] * 6 +          # MOST important
    ["group_by"] * 3 +
    ["having"] * 2 +
    ["where"] * 2 +
    ["aggregation"] * 2 +
    ["multi_select"] * 1 +
    ["simple_select"] * 1
)

In [30]:
COLUMN_POOL = [
    "id","user_id","order_id","product_id",
    "name","email","age","salary",
    "department","city","country",
    "price","amount","quantity",
    "created_at","updated_at"
]

In [31]:
def generate_schema():

    num_tables = random.randint(2,6)
    schema = {}

    fake.unique.clear()

    for _ in range(num_tables):

        table = fake.unique.word()

        cols = random.sample(
            COLUMN_POOL,
            random.randint(4,10)
        )

        if "id" not in cols:
            cols.append("id")

        schema[table] = cols

    return schema

In [32]:
def schema_to_text(schema):

    parts = []

    for table, cols in schema.items():

        col_tokens = " ".join([f"[COL] {c}" for c in cols])
        parts.append(f"[TABLE] {table} {col_tokens}")

    return " ".join(parts)

In [35]:
SELECT_TEMPLATES = [
    "show {col} from {table}",
    "list {col} in {table}",
    "display {col} from {table}",
    "what are the {col} in {table}",
    "give me the {col} of {table}"
]

In [36]:
def generate_sql(schema):

    pattern = random.choice(SQL_PATTERNS)

    table = random.choice(list(schema.keys()))
    cols = schema[table]
    col1 = random.choice(cols)

    # SIMPLE
    if pattern == "simple_select":

        question = random.choice(SELECT_TEMPLATES).format(
            col=col1, table=table
        )

        sql = f"SELECT {col1} FROM {table}"

    # WHERE
    elif pattern == "where":

        question = f"{random.choice(SELECT_TEMPLATES).format(col=col1, table=table)} where {col1} is not null"

        sql = f"SELECT {col1} FROM {table} WHERE {col1} IS NOT NULL"

    # MULTI
    elif pattern == "multi_select":

        selected = random.sample(cols, min(2,len(cols)))

        question = f"show {', '.join(selected)} from {table}"

        sql = f"SELECT {', '.join(selected)} FROM {table}"

    # AGG
    elif pattern == "aggregation":

        question = f"what is the average {col1} in {table}"

        sql = f"SELECT AVG({col1}) FROM {table}"

    # GROUP
    elif pattern == "group_by":

        col2 = random.choice(cols)

        question = f"count {col1} grouped by {col2} in {table}"

        sql = f"SELECT {col2}, COUNT({col1}) FROM {table} GROUP BY {col2}"

    # HAVING
    elif pattern == "having":

        col2 = random.choice(cols)

        question = f"show {col2} having count of {col1} greater than 5"

        sql = f"""
SELECT {col2}, COUNT({col1})
FROM {table}
GROUP BY {col2}
HAVING COUNT({col1}) > 5
"""

    # JOIN (Most Important)
    elif pattern == "join":

        tables = list(schema.keys())

        if len(tables) < 2:
            return generate_sql(schema)

        t2 = random.choice([t for t in tables if t != table])

        if "id" not in schema[table] or "id" not in schema[t2]:
            return generate_sql(schema)

        question = f"join {table} and {t2} and show {table}.{col1}"

        sql = f"""
SELECT {table}.{col1}
FROM {table}
JOIN {t2}
ON {table}.id = {t2}.id
"""

    return question.strip(), sql.strip()

In [37]:
def build_dataset_by_schema(num_schemas=10000, queries_per_schema=6):

    schemas = [generate_schema() for _ in range(num_schemas)]

    split = int(0.9 * num_schemas)

    train_schemas = schemas[:split]
    test_schemas = schemas[split:]

    def build_examples(schema_list):

        data = []

        for schema in schema_list:

            schema_text = schema_to_text(schema)

            for _ in range(queries_per_schema):

                question, sql = generate_sql(schema)

                model_input = f"""
Schema:
{schema_text}

Question:
{question}
"""

                data.append({
                    "input": model_input.strip(),
                    "output": sql.strip()
                })

        return data

    return build_examples(train_schemas), build_examples(test_schemas)


train, test = build_dataset_by_schema()

print("Train size:", len(train))
print("Test size:", len(test))

Train size: 54000
Test size: 6000


In [39]:
def build_vocab(data):

    counter = Counter()

    for row in data:
        counter.update(row["input"].split())
        counter.update(row["output"].split())

    vocab = {w:i+2 for i,(w,_) in enumerate(counter.items())}
    vocab["<pad>"] = 0
    vocab["<unk>"] = 1

    inv_vocab = {i:w for w,i in vocab.items()}

    return vocab, inv_vocab

vocab, inv_vocab = build_vocab(train)

torch.save(vocab, "vocab.pt")

In [40]:
MAX_LEN = 160

def encode(text):
    tokens = text.split()
    ids = [vocab.get(t,1) for t in tokens][:MAX_LEN]
    ids += [0]*(MAX_LEN-len(ids))
    return torch.tensor(ids)

class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        row = self.data[idx]
        return encode(row["input"]), encode(row["output"])

In [41]:
class NL2SQLModel(nn.Module):

    def __init__(self, vocab_size, d_model=256):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=8,
            num_encoder_layers=3,
            num_decoder_layers=3,
            batch_first=True
        )

        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):

        src_mask = (src == 0)
        tgt_mask = (tgt == 0)

        src = self.embed(src)
        tgt = self.embed(tgt)

        out = self.transformer(
            src,
            tgt,
            src_key_padding_mask=src_mask,
            tgt_key_padding_mask=tgt_mask
        )

        return self.fc(out)

In [42]:
train_loader = DataLoader(
    NL2SQLDataset(train),
    batch_size=32,
    shuffle=True
)

model = NL2SQLModel(len(vocab)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)

In [43]:
CHECKPOINT_PATH = "nl2sql_checkpoint.pt"

def save_checkpoint(epoch, model, optimizer, loss):

    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": loss
    }, CHECKPOINT_PATH)

def load_checkpoint(model, optimizer):

    if os.path.exists(CHECKPOINT_PATH):

        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)

        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])

        print("Resuming from epoch:", ckpt["epoch"]+1)
        return ckpt["epoch"] + 1

    return 0

In [None]:
EPOCHS = 20
best_loss = float("inf")

start_epoch = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, EPOCHS):

    model.train()
    total_loss = 0

    for x,y in train_loader:

        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        output = model(x, y[:,:-1])

        loss = loss_fn(
            output.reshape(-1, len(vocab)),
            y[:,1:].reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)

    print(f"Epoch {epoch} | Loss {avg_loss}")

    save_checkpoint(epoch, model, optimizer, avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), "best_model.pt")
        print("⭐ BEST MODEL SAVED")

Epoch 0 | Loss 1.9909309304156009
⭐ BEST MODEL SAVED
Epoch 1 | Loss 0.7649803956026008
⭐ BEST MODEL SAVED
Epoch 2 | Loss 0.6061771394711394
⭐ BEST MODEL SAVED
Epoch 3 | Loss 0.4908064574144463
⭐ BEST MODEL SAVED
Epoch 4 | Loss 0.40722163312878656
⭐ BEST MODEL SAVED
Epoch 5 | Loss 0.35079582590804
⭐ BEST MODEL SAVED
Epoch 6 | Loss 0.311383879015231
⭐ BEST MODEL SAVED
