<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir -p data
!cd data

!wget https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2

!tar -xvf data.tar.bz2

!ls



--2026-01-27 14:16:35--  https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2 [following]
--2026-01-27 14:16:36--  https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26164664 (25M) [application/octet-stream]
Saving to: ‘data.tar.bz2’


2026-01-27 14:16:37 (163 MB/s) - ‘data.tar.bz2’ saved [26164664/26164664]

data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/train.t

In [3]:
!unzip Spider_kaggle.zip

Archive:  Spider_kaggle.zip
  inflating: spider/README.txt       
  inflating: spider/database/academic/academic.sqlite  
  inflating: spider/database/academic/schema.sql  
  inflating: spider/database/activity_1/activity_1.sqlite  
  inflating: spider/database/activity_1/schema.sql  
  inflating: spider/database/aircraft/aircraft.sqlite  
  inflating: spider/database/aircraft/schema.sql  
  inflating: spider/database/allergy_1/allergy_1.sqlite  
  inflating: spider/database/allergy_1/schema.sql  
  inflating: spider/database/apartment_rentals/apartment_rentals.sqlite  
  inflating: spider/database/apartment_rentals/schema.sql  
  inflating: spider/database/architecture/architecture.sqlite  
  inflating: spider/database/architecture/schema.sql  
  inflating: spider/database/assets_maintenance/assets_maintenance.sqlite  
  inflating: spider/database/assets_maintenance/schema.sql  
  inflating: spider/database/baseball_1/baseball_1.sqlite  
  inflating: spider/database/baseball_1/schema.

In [73]:
!pip install sqlparse transformers datasets torch nltk





In [74]:
import os
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import sqlparse
from tqdm import tqdm


In [127]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [128]:
def load_jsonl(path):
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data


wiki_train = load_jsonl("/content/data/train.jsonl")
wiki_val   = load_jsonl("/content/data/dev.jsonl")

wiki_train_tables = load_jsonl("/content/data/train.tables.jsonl")
wiki_dev_tables   = load_jsonl("/content/data/dev.tables.jsonl")

wiki_tables = wiki_train_tables + wiki_dev_tables

wiki_table_map = {t["id"]: t for t in wiki_tables}

print("Wiki train:", len(wiki_train))
print("Wiki val:", len(wiki_val))
print("Wiki tables:", len(wiki_table_map))


Wiki train: 56355
Wiki val: 8421
Wiki tables: 21301


In [129]:
with open("/content/spider/train_spider.json") as f:
    spider_train = json.load(f)

with open("/content/spider/dev.json") as f:
    spider_val = json.load(f)

with open("/content/spider/tables.json") as f:
    spider_tables = json.load(f)

spider_table_map = {t["db_id"]: t for t in spider_tables}

print("Spider train:", len(spider_train))
print("Spider val:", len(spider_val))


Spider train: 7000
Spider val: 1034


In [130]:
import re

def sql_tokenize(sql):
    sql = sql.lower()
    sql = re.sub(r"([(),=><])", r" \1 ", sql)
    return sql.split()


In [131]:
def sql_to_ast_tokens(sql):
    sql = sql.lower()

    tokens = ["<QUERY>"]
    for kw in ["select","from","join","where","group by","having",
               "union","intersect","except"]:
        if kw in sql:
            tokens.append(f"<{kw.replace(' ','_').upper()}>")

    if sql.count("select") > 1:
        tokens.append("<SUBQUERY>")

    tokens.append("</QUERY>")
    return tokens


In [132]:
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "!="]

def wiksql_to_sql(item, table):
    sel = item["sql"]["sel"]
    agg = item["sql"]["agg"]
    conds = item["sql"]["conds"]

    col = f"table.{table['header'][sel]}"

    if agg == 0:
        select_clause = f"SELECT {col}"
    else:
        select_clause = f"SELECT {AGG_OPS[agg]}({col})"

    from_clause = "FROM table"

    if not conds:
        return f"{select_clause} {from_clause}"

    where = []
    for c, o, v in conds:
        op = COND_OPS[o]
        col_name = f"table.{table['header'][c]}"

        if isinstance(v, str):
            v = f"'{v}'"
        else:
            v = str(v)

        where.append(f"{col_name} {op} {v}")

    return f"{select_clause} {from_clause} WHERE " + " AND ".join(where)



In [133]:
def build_sql_vocab(wiki, spider):
    vocab = {"<PAD>":0, "<BOS>":1, "<EOS>":2, "<UNK>":3}
    idx = 4

    def add(tokens):
        nonlocal idx
        for t in tokens:
            if t not in vocab:
                vocab[t] = idx
                idx += 1

    for x in wiki[:50000]:
        table = wiki_table_map[x["table_id"]]
        sql = wiksql_to_sql(x, table)
        add(sql_tokenize(sql))

    for x in spider[:20000]:
        add(sql_tokenize(x["query"]))

    return vocab


In [134]:
SQL_VOCAB = build_sql_vocab(wiki_train, spider_train)
INV_SQL_VOCAB = {v:k for k,v in SQL_VOCAB.items()}

print("SQL vocab size:", len(SQL_VOCAB))


SQL vocab size: 49205


In [135]:
MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


In [136]:
class NL2SQLDataset(Dataset):
    def __init__(self, data, table_map, is_spider=False):
        self.data = data
        self.table_map = table_map
        self.is_spider = is_spider

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        if self.is_spider:
            question = item["question"]
            table = self.table_map[item["db_id"]]
            schema = " | ".join([c for _,c in table["column_names_original"]])
            sql = item["query"]
        else:
            question = item["question"]
            table = self.table_map[item["table_id"]]
            schema = " | ".join(table["header"])
            sql = wiksql_to_sql(item, table)

        text = f"question: {question} schema: {schema}"

        enc = tokenizer(
            text, padding="max_length", truncation=True,
            max_length=256, return_tensors="pt"
        )

        tgt = [SQL_VOCAB["<BOS>"]] + \
              [SQL_VOCAB.get(t, SQL_VOCAB["<UNK>"]) for t in sql_tokenize(sql)] + \
              [SQL_VOCAB["<EOS>"]]

        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0), torch.tensor(tgt)


In [137]:
def collate_fn(batch):
    ids, masks, tgts = zip(*batch)

    ids = torch.stack(ids)
    masks = torch.stack(masks)

    max_len = max(len(t) for t in tgts)
    padded = torch.zeros(len(tgts), max_len, dtype=torch.long)

    for i,t in enumerate(tgts):
        padded[i,:len(t)] = t

    return ids, masks, padded


In [138]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained(MODEL_NAME)

    def forward(self, ids, mask):
        return self.model(ids, attention_mask=mask).last_hidden_state


In [139]:
class Decoder(nn.Module):
    def __init__(self, hidden, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden)
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.attn = nn.MultiheadAttention(hidden, num_heads=8, batch_first=True)
        self.fc = nn.Linear(hidden, vocab)

    def forward(self, tgt, enc_out):
        emb = self.emb(tgt)
        out, _ = self.lstm(emb)

        # Cross attention (decoder attends to encoder)
        out, _ = self.attn(out, enc_out, enc_out)

        return self.fc(out)


In [140]:
class NL2SQL(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(384, vocab)

    def forward(self, ids, mask, tgt):
        enc_out = self.encoder(ids, mask)
        return self.decoder(tgt, enc_out)


In [141]:
train_ds = NL2SQLDataset(wiki_train, wiki_table_map, False)
val_ds   = NL2SQLDataset(wiki_val, wiki_table_map, False)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)


In [142]:
model = NL2SQL(len(SQL_VOCAB)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss(ignore_index=0)


In [143]:
def train_epoch(loader):
    model.train()
    total_loss = 0

    for ids, mask, tgt in tqdm(loader):
        ids, mask, tgt = ids.to(device), mask.to(device), tgt.to(device)

        logits = model(ids, mask, tgt[:,:-1])

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt[:,1:].reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [144]:
def val_epoch(loader):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for ids, mask, tgt in loader:
            ids, mask, tgt = ids.to(device), mask.to(device), tgt.to(device)
            logits = model(ids, mask, tgt[:,:-1])

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt[:,1:].reshape(-1)
            )
            total_loss += loss.item()

    return total_loss / len(loader)


In [145]:
for epoch in range(5):
    tr = train_epoch(train_loader)
    val = val_epoch(val_loader)
    print(f"Epoch {epoch+1} | Train {tr:.4f} | Val {val:.4f}")


 49%|████▉     | 1720/3523 [08:02<08:25,  3.57it/s]


KeyboardInterrupt: 

In [None]:
spider_train_ds = NL2SQLDataset(spider_train, spider_table_map, True)
spider_loader = DataLoader(spider_train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)


In [None]:
for epoch in range(4):
    loss = train_epoch(spider_loader)
    print(f"Spider Epoch {epoch+1} | Loss {loss:.4f}")


In [None]:
torch.save({
    "model": model.state_dict(),
    "SQL_VOCAB": SQL_VOCAB,
    "INV_SQL_VOCAB": INV_SQL_VOCAB
}, "/content/nl2sql_full_sql.pt")


In [None]:
def infer_sql(question, table):
    model.eval()

    schema = " | ".join(
        table["header"] if "header" in table
        else [c for _,c in table["column_names_original"]]
    )

    text = f"question: {question} schema: {schema}"
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)

    tgt = torch.tensor([[SQL_VOCAB["<BOS>"]]]).to(device)

    for _ in range(60):
        logits = model(enc["input_ids"], enc["attention_mask"], tgt)
        nxt = logits[:,-1].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, nxt], dim=1)

        if nxt.item() == SQL_VOCAB["<EOS>"]:
            break

    tokens = [INV_SQL_VOCAB[t.item()] for t in tgt[0]]
    return " ".join(tokens[1:-1])


In [None]:
db_id = list(spider_table_map.keys())[0]
table = spider_table_map[db_id]

print(infer_sql("Show employee names and their salaries", table))
print(infer_sql("Which department has highest salary", table))
print(infer_sql("List students whose marks > 70", table))
