<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
import random
import json

# ---------------- SCHEMAS ----------------
# Toy database descriptions sampled by generate_example(). Each entry has:
#   "tables"  - table name -> list of column names. generate_example() treats
#               the FIRST table listed here as the main table it queries, so
#               the key order matters.
#   "numeric" - main-table columns offered to SUM/AVG/MAX/MIN, to the WHERE
#               comparison and to the GROUP BY aggregate.
#   "text"    - main-table columns offered as GROUP BY keys.
#   "join"    - (left_table, right_table, left_column, right_column) used to
#               build the "JOIN ... ON left.left_column = right.right_column"
#               clause.
SCHEMAS = [
    {
        "tables": {
            "employees": ["id", "name", "salary", "dept_id"],
            "departments": ["id", "name"]
        },
        "numeric": {
            "employees": ["id", "salary", "dept_id"]
        },
        "text": {
            "employees": ["name"]
        },
        "join": ("employees", "departments", "dept_id", "id")
    },
    {
        "tables": {
            "students": ["id", "name", "marks", "class_id"],
            "classes": ["id", "name"]
        },
        "numeric": {
            "students": ["id", "marks", "class_id"]
        },
        "text": {
            "students": ["name"]
        },
        "join": ("students", "classes", "class_id", "id")
    }
]

# Aggregate function names sampled for the "AGG" intent; COUNT may be applied
# to any main-table column, the others only to the "numeric" columns.
AGGS = ["COUNT", "SUM", "AVG", "MAX", "MIN"]

# ---------------- GENERATOR ----------------
def generate_example():
    """Build one synthetic natural-language-to-SQL training example.

    A schema is drawn from ``SCHEMAS`` and a query intent is sampled with
    fixed weights (SELECT 0.30, AGG 0.25, WHERE 0.20, GROUP 0.15, JOIN 0.10).
    Each intent fills a SQL template and picks one of several English
    phrasings of the same request.

    Returns:
        dict: ``{"question": str, "schema": dict, "sql": str}`` where
        ``schema`` is the ``"tables"`` mapping of the chosen schema (the same
        dict object stored in ``SCHEMAS``, not a copy).
    """
    spec = random.choice(SCHEMAS)
    schema = spec["tables"]

    # The first table listed in the schema is the one being queried.
    main_table = tuple(schema)[0]
    cols = schema[main_table]

    (intent,) = random.choices(
        ["SELECT", "AGG", "WHERE", "GROUP", "JOIN"],
        weights=[0.30, 0.25, 0.20, 0.15, 0.10],
    )

    # Every branch prepares the candidate phrasings and the SQL text; the
    # phrasing itself is drawn once, after the branch, which keeps the order
    # of random draws the same as drawing it inside each branch.
    if intent == "SELECT":
        # Plain single-column projection.
        col = random.choice(cols)
        phrasings = [
            f"show {col} of {main_table}",
            f"get {col} from {main_table}",
            f"list {col} in {main_table}",
            f"display {col} for {main_table}",
        ]
        sql = f"SELECT {col} FROM {main_table}"

    elif intent == "AGG":
        # COUNT accepts any column; the other aggregates need a numeric one.
        agg = random.choice(AGGS)
        candidates = cols if agg == "COUNT" else spec["numeric"][main_table]
        col = random.choice(candidates)
        agg_word = agg.lower()
        phrasings = [
            f"show {agg_word} of {col} of {main_table}",
            f"what is the {agg_word} {col} in {main_table}",
            f"give me the {agg_word} {col} from {main_table}",
        ]
        sql = f"SELECT {agg}({col}) FROM {main_table}"

    elif intent == "WHERE":
        # Greater-than filter on a numeric column against a fixed threshold.
        col = random.choice(spec["numeric"][main_table])
        val = random.choice([10, 20, 50, 100])
        phrasings = [
            f"show {col} of {main_table} where {col} > {val}",
            f"list {main_table} with {col} greater than {val}",
            f"get {col} from {main_table} having {col} above {val}",
        ]
        sql = f"SELECT {col} FROM {main_table} WHERE {col} > {val}"

    elif intent == "GROUP":
        # Aggregate a numeric column per value of a text column.
        agg = random.choice(["COUNT", "AVG"])
        group_col = random.choice(spec["text"][main_table])
        num_col = random.choice(spec["numeric"][main_table])
        agg_word = agg.lower()
        phrasings = [
            f"show {agg_word} of {num_col} per {group_col}",
            f"get {agg_word} {num_col} grouped by {group_col}",
            f"list {group_col} wise {agg_word} {num_col}",
        ]
        sql = (
            f"SELECT {group_col}, {agg}({num_col}) "
            f"FROM {main_table} GROUP BY {group_col}"
        )

    else:
        # JOIN: the only remaining intent; pairs the two tables' name columns.
        t1, t2, c1, c2 = spec["join"]
        phrasings = [
            f"show {t1} name and {t2} name",
            f"get {t1} names with their {t2}",
            f"list {t1} and corresponding {t2}",
        ]
        sql = (
            f"SELECT {t1}.name, {t2}.name "
            f"FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"
        )

    return {
        "question": random.choice(phrasings),
        "schema": schema,
        "sql": sql,
    }

# ---------------- DATASET CREATION ----------------
def generate_dataset(n=80_000, outfile="nl2sql_varied.json"):
    """Generate ``n`` synthetic examples and write them to ``outfile`` as JSON.

    Args:
        n: Number of examples to produce with ``generate_example()``.
        outfile: Path of the JSON file to (over)write; the examples are
            stored as one indented JSON list.

    Prints a progress line after every 20,000 examples and a final line with
    the dataset size. Returns ``None``.
    """
    examples = []
    for produced in range(1, n + 1):
        examples.append(generate_example())
        # Progress heartbeat for long runs.
        if produced % 20_000 == 0:
            print(f"Generated {produced} examples")

    with open(outfile, "w") as out:
        json.dump(examples, out, indent=2)

    print("âœ… Dataset size:", len(examples))


# Build the dataset file used by the following cells. Raise n (for example to
# 300_000) if a larger dataset is wanted.
generate_dataset(n=80_000)


Generated 20000 examples
Generated 40000 examples
Generated 60000 examples
Generated 80000 examples
âœ… Dataset size: 80000


In [19]:
# import json

# with open("nl2sql_varied.json", "r") as f:
#     dataset = json.load(f)

# print("Loaded dataset size:", len(dataset))

# from collections import Counter

# ENC_VOCAB = {"<PAD>":0, "<UNK>":1, "<SEP>":2}
# DEC_VOCAB = {"<PAD>":0, "<UNK>":1, "<BOS>":2, "<EOS>":3}


# def add(vocab, tok):
#     if tok not in vocab:
#         vocab[tok] = len(vocab)

# for ex in dataset:
#     # NL tokens
#     for t in ex["question"].lower().split():
#         add(ENC_VOCAB, t)

#     # Schema tokens
#     for t, cols in ex["schema"].items():
#         add(ENC_VOCAB, t)
#         for c in cols:
#             add(ENC_VOCAB, f"{t}.{c}")

#     # SQL tokens
#     for ex in dataset:
#       for t in ex["sql"].lower().split():
#           if t not in DEC_VOCAB:
#               DEC_VOCAB[t] = len(DEC_VOCAB)


# print("Encoder vocab:", len(ENC_VOCAB))
# print("Decoder vocab:", len(DEC_VOCAB))
import json

# Read back the examples written by generate_dataset().
with open("nl2sql_varied.json", "r") as f:
    dataset = json.load(f)

print("Loaded dataset size:", len(dataset))

# Both vocabularies start with their reserved special tokens; the position of
# a token in the tuple is its id, so <PAD> is id 0 in both.
ENC_VOCAB = {
    tok: idx for idx, tok in enumerate(("<PAD>", "<UNK>", "<SEP>"))
}
DEC_VOCAB = {
    tok: idx for idx, tok in enumerate(("<PAD>", "<UNK>", "<BOS>", "<EOS>"))
}

def add(vocab, tok):
    """Register ``tok`` in ``vocab`` under the next free id if it is new.

    The id assigned is the vocabulary size before insertion; tokens that are
    already present keep their id. Mutates ``vocab`` and returns ``None``.
    """
    vocab.setdefault(tok, len(vocab))

# Token ids are assigned in first-seen order, so the order in which tokens
# are visited below determines the vocabularies.
for ex in dataset:
    # -------- Encoder vocab --------
    # Question words first, then for every table its bare name followed by
    # its qualified "table.column" tokens.
    enc_tokens = ex["question"].lower().split()
    for table, cols in ex["schema"].items():
        enc_tokens.append(table)
        enc_tokens.extend(f"{table}.{col}" for col in cols)
    for t in enc_tokens:
        add(ENC_VOCAB, t)

    # -------- Decoder vocab --------
    # SQL is split on whitespace only, so "count(salary)" or "name," is a
    # single decoder token.
    for t in ex["sql"].lower().split():
        add(DEC_VOCAB, t)

print("Encoder vocab:", len(ENC_VOCAB))
print("Decoder vocab:", len(DEC_VOCAB))


Loaded dataset size: 80000
Encoder vocab: 62
Decoder vocab: 62


In [20]:
import torch
from torch.utils.data import Dataset

class NL2SQLDataset(Dataset):
    """Map-style dataset turning NL2SQL examples into fixed-length id tensors.

    Each element of ``data`` is a dict with the keys ``"question"`` (str),
    ``"schema"`` (table name -> list of column names) and ``"sql"`` (str).

    Args:
        data: Sequence of example dicts.
        enc_vocab: Token -> id mapping for the encoder side; must contain
            ``"<PAD>"`` and ``"<UNK>"``.
        dec_vocab: Token -> id mapping for the decoder side; must contain
            ``"<PAD>"``, ``"<UNK>"``, ``"<BOS>"`` and ``"<EOS>"``.
        max_src_len: Fixed length of every encoded source sequence.
        max_tgt_len: Fixed length of every encoded target sequence.
    """

    def __init__(self, data, enc_vocab, dec_vocab,
                 max_src_len=80, max_tgt_len=60):
        self.data = data
        self.enc_vocab = enc_vocab
        self.dec_vocab = dec_vocab
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    @staticmethod
    def _fit(ids, length, pad_id):
        """Cut ``ids`` to ``length`` items, right-padding with ``pad_id``."""
        ids = ids[:length]
        return ids + [pad_id] * (length - len(ids))

    def encode_src(self, tokens):
        """Convert source tokens to exactly ``max_src_len`` encoder ids.

        Unknown tokens map to ``<UNK>``; sequences that are too long are cut
        at the end and short ones are right-padded with ``<PAD>``.
        """
        unk = self.enc_vocab["<UNK>"]
        ids = [self.enc_vocab.get(t, unk) for t in tokens]
        return self._fit(ids, self.max_src_len, self.enc_vocab["<PAD>"])

    def encode_tgt(self, tokens):
        """Convert SQL tokens to exactly ``max_tgt_len`` decoder ids.

        The result is ``<BOS>`` + token ids + ``<EOS>``, right-padded with
        ``<PAD>``. When the SQL is too long, the SQL tokens themselves are
        shortened so that ``<EOS>`` is kept; cutting the assembled sequence
        instead would silently drop the end-of-sequence marker.
        """
        unk = self.dec_vocab["<UNK>"]
        # Reserve two positions for <BOS> and <EOS>.
        budget = max(self.max_tgt_len - 2, 0)
        body = [self.dec_vocab.get(t, unk) for t in list(tokens)[:budget]]
        ids = [self.dec_vocab["<BOS>"]] + body + [self.dec_vocab["<EOS>"]]
        # _fit still guards the degenerate case max_tgt_len < 2.
        return self._fit(ids, self.max_tgt_len, self.dec_vocab["<PAD>"])

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)`` as ``torch.long`` tensors.

        The source is the lower-cased, whitespace-split question, a
        ``<SEP>`` token, then one ``table.column`` token per schema column.
        The target is the lower-cased, whitespace-split SQL string.
        """
        ex = self.data[idx]

        # -------- Encoder input (NL + schema) --------
        question_tokens = ex["question"].lower().split()

        schema_tokens = [
            f"{table}.{col}"
            for table, cols in ex["schema"].items()
            for col in cols
        ]

        src_tokens = question_tokens + ["<SEP>"] + schema_tokens
        src_ids = self.encode_src(src_tokens)

        # -------- Decoder target (SQL) --------
        sql_tokens = ex["sql"].lower().split()
        tgt_ids = self.encode_tgt(sql_tokens)

        return (
            torch.tensor(src_ids, dtype=torch.long),
            torch.tensor(tgt_ids, dtype=torch.long)
        )


In [21]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Token embedding followed by a stack of Transformer encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / model width.
        nhead: Attention heads per layer.
        dim_ff: Hidden width of each layer's feed-forward sub-block.
        num_layers: Number of stacked encoder layers.

    NOTE(review): no positional encoding is added to the embeddings before
    they enter the encoder stack.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, dim_ff=1024, num_layers=4):
        super().__init__()
        # Token id 0 is the padding id.
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a ``[B, T]`` batch of token ids.

        Positions whose id is 0 are passed as ``src_key_padding_mask`` so
        attention ignores them. Returns the encoder stack's output.
        """
        padding_positions = x.eq(0)  # True where the token is PAD
        embedded = self.emb(x)
        return self.encoder(embedded, src_key_padding_mask=padding_positions)

class Decoder(nn.Module):
    """Transformer decoder stack with its own embedding and output projection.

    Args:
        vocab_size: Size of the decoder vocabulary (embedding rows and
            output logits).
        d_model: Embedding / model width.
        nhead: Attention heads per layer.
        dim_ff: Hidden width of each layer's feed-forward sub-block.
        num_layers: Number of stacked decoder layers.
        dropout: Dropout probability used both on the embeddings and inside
            every decoder layer.

    NOTE(review): no positional encoding is added to the embeddings, and no
    ``tgt_key_padding_mask`` is passed to the decoder stack.
    """

    def __init__(
        self,
        vocab_size,
        d_model=256,
        nhead=8,
        dim_ff=1024,
        num_layers=4,
        dropout=0.2
    ):
        super().__init__()

        # Target-token embedding, regularised with dropout in forward().
        self.emb = nn.Embedding(vocab_size, d_model)
        self.emb_dropout = nn.Dropout(dropout)

        # The same dropout rate is handed to the decoder layers themselves.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Projects decoder states to vocabulary logits.
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, y, memory, memory_pad_mask):
        """Return logits for every target position.

        Args:
            y: Target token ids; the second dimension is the target length.
            memory: Encoder output attended to by the decoder.
            memory_pad_mask: Boolean mask forwarded as
                ``memory_key_padding_mask`` (True marks encoder positions to
                ignore).
        """
        steps = y.size(1)

        # Strictly-upper-triangular mask: position i may not attend to any
        # position j > i.
        future_mask = torch.ones(steps, steps, device=y.device).triu(diagonal=1).bool()

        hidden = self.decoder(
            self.emb_dropout(self.emb(y)),
            memory,
            tgt_mask=future_mask,
            memory_key_padding_mask=memory_pad_mask,
        )
        return self.fc(hidden)


In [23]:
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import torch

# Train on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build both models with their default hyper-parameters; only the vocabulary
# sizes are passed explicitly.
enc = Encoder(vocab_size=len(ENC_VOCAB)).to(device)
dec = Decoder(vocab_size=len(DEC_VOCAB)).to(device)

# No collate_fn is needed: NL2SQLDataset already returns fixed-length
# (max_src_len / max_tgt_len) tensors, so default batching can stack them.
loader = DataLoader(
    NL2SQLDataset(
        data=dataset,
        enc_vocab=ENC_VOCAB,
        dec_vocab=DEC_VOCAB,
        max_src_len=80,
        max_tgt_len=60
    ),
    batch_size=128,
    shuffle=True
)

# One Adam optimizer updates encoder and decoder parameters together.
opt = optim.Adam(
    list(enc.parameters()) + list(dec.parameters()),
    lr=1e-4
)

# Cross-entropy over decoder tokens; padded target positions are excluded
# from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=DEC_VOCAB["<PAD>"])

print("ðŸš€ Training...")
for epoch in range(3):
    enc.train()
    dec.train()
    total = 0  # sum of per-batch losses for this epoch

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()

        # Encode the source and rebuild the same PAD mask (id 0) so the
        # decoder's cross-attention can ignore padded source positions.
        mem = enc(x)
        src_pad_mask = (x == 0)

        # Teacher forcing: the decoder is fed the target without its last
        # position...
        out = dec(
            y=y[:, :-1],
            memory=mem,
            memory_pad_mask=src_pad_mask
        )

        # ...and is scored against the target shifted left by one, i.e. it
        # predicts token t+1 from tokens up to t.
        loss = loss_fn(
            out.reshape(-1, len(DEC_VOCAB)),
            y[:, 1:].reshape(-1)
        )

        loss.backward()
        opt.step()

        total += loss.item()

    # Mean per-batch loss over the epoch.
    print(f"Epoch {epoch+1} | Loss {total / len(loader):.4f}")


ðŸš€ Training...
Epoch 1 | Loss 0.2043
Epoch 2 | Loss 0.0037
Epoch 3 | Loss 0.0016


In [1]:
import torch

# Destination file for the checkpoint.
SAVE_PATH = "nl2sql_transformer.pt"

# Bundle everything needed to rebuild the models: both state dicts, both
# vocabularies and the architecture hyper-parameters.
# Requires `enc`, `dec`, `ENC_VOCAB` and `DEC_VOCAB` to exist in the current
# session, i.e. the vocabulary and training cells must have been run first.
# NOTE: "model_config" is written by hand and repeats the constructor
# defaults of Encoder/Decoder ("dropout" is the Decoder default); it is not
# read from the live modules, so keep it in sync if those arguments change.
torch.save({
    "encoder_state_dict": enc.state_dict(),
    "decoder_state_dict": dec.state_dict(),
    "ENC_VOCAB": ENC_VOCAB,
    "DEC_VOCAB": DEC_VOCAB,
    "model_config": {
        "d_model": 256,
        "nhead": 8,
        "dim_ff": 1024,
        "num_layers": 4,
        "dropout": 0.2
    }
}, SAVE_PATH)

print(f"âœ… Model successfully saved to {SAVE_PATH}")


NameError: name 'enc' is not defined

In [29]:
SQL_KEYWORDS = {"select", "from", "where", "join", "group", "by"}


def is_valid_next_token(prev_tokens, next_token):
    """Check whether ``next_token`` may follow the tokens generated so far.

    The comparison is case-insensitive and enforces only keyword ordering:
    the first token must be ``select``; ``from`` needs an earlier ``select``;
    ``where``, ``group`` and ``join`` need an earlier ``from``. Every other
    token is accepted once the sequence is non-empty.

    Args:
        prev_tokens: Tokens already emitted, in order.
        next_token: Candidate token.

    Returns:
        bool: ``True`` if the candidate is allowed.
    """
    seen = {tok.lower() for tok in prev_tokens}
    candidate = next_token.lower()

    # An empty prefix can only be started by SELECT.
    if not seen:
        return candidate == "select"

    # Keyword -> keyword that has to appear somewhere before it.
    must_follow = {
        "from": "select",
        "where": "from",
        "group": "from",
        "join": "from",
    }
    prerequisite = must_follow.get(candidate)
    return prerequisite is None or prerequisite in seen

# Reverse decoder vocabulary: token id -> token string.
inv_dec_vocab = {idx: tok for tok, idx in DEC_VOCAB.items()}


def infer_sql(question, schema, max_len=50):
    """Greedily decode a SQL string for ``question`` over ``schema``.

    Uses the module-level ``enc``/``dec`` models (switched to eval mode),
    ``ENC_VOCAB``/``DEC_VOCAB``, ``inv_dec_vocab`` and ``device``.

    Args:
        question: Natural-language request; lower-cased and split on
            whitespace.
        schema: Mapping of table name -> list of column names, appended to
            the question as ``table.column`` tokens after ``<SEP>``.
        max_len: Maximum number of decoding steps.

    Returns:
        str: The generated tokens joined by spaces, without ``<EOS>`` and
        ``<PAD>``.
    """
    enc.eval()
    dec.eval()

    schema_tokens = [f"{t}.{c}" for t, cols in schema.items() for c in cols]
    src_tokens = question.lower().split() + ["<SEP>"] + schema_tokens

    unk_id = ENC_VOCAB["<UNK>"]
    # Single-example batch of shape [1, T]; no padding is added here.
    x = torch.tensor([[ENC_VOCAB.get(tok, unk_id) for tok in src_tokens]]).to(device)

    generated_tokens = []
    never_emitted = {"<PAD>", "<BOS>"}

    def _acceptable(token_id):
        # A candidate is usable if it is not a reserved filler token and it
        # respects the keyword-ordering rules given what was generated so far.
        token = inv_dec_vocab[token_id.item()]
        return token not in never_emitted and is_valid_next_token(generated_tokens, token)

    with torch.no_grad():
        memory = enc(x)
        memory_pad_mask = x.eq(ENC_VOCAB["<PAD>"])

        y = torch.tensor([[DEC_VOCAB["<BOS>"]]], device=device)

        for _ in range(max_len):
            logits = dec(y=y, memory=memory, memory_pad_mask=memory_pad_mask)

            # Candidates for the last position, most probable first.
            ranked = torch.argsort(logits[:, -1].softmax(dim=-1), descending=True)[0]

            # PICARD-style filtering: take the best-ranked acceptable token,
            # falling back to the overall top token if none qualifies.
            next_token_id = next(filter(_acceptable, ranked), ranked[0])

            y = torch.cat([y, next_token_id.view(1, 1)], dim=1)
            generated_tokens.append(inv_dec_vocab[next_token_id.item()])

            if next_token_id.item() == DEC_VOCAB["<EOS>"]:
                break

    visible = [t for t in generated_tokens if t not in {"<EOS>", "<PAD>"}]
    return " ".join(visible)


In [34]:
# Schema handed to every demo query below (employees / departments only).
schema = {
    "employees": ["id","name","salary","dept_id"],
    "departments": ["id","name"]
}

# Smoke-test the decoder on a few hand-written questions.
# NOTE: the last two questions mention "students" / "classes", which are not
# tables in the schema passed here, so question and schema disagree for them.
print(infer_sql("give names of employees", schema))
print(infer_sql("get count salary grouped by name", schema))
print(infer_sql("get marks from students having marks above 20", schema))
print(infer_sql("show students name and classes name", schema))

select name from employees
select name, count(salary) from employees group by name
select marks from employees where marks > 20
select students.name, classes.name from employees join classes on students.class_id = departments.id
