<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir -p data
!cd data

!wget https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2

!tar -xvf data.tar.bz2

!ls



--2026-01-27 10:01:46--  https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2 [following]
--2026-01-27 10:01:46--  https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26164664 (25M) [application/octet-stream]
Saving to: ‘data.tar.bz2’


2026-01-27 10:01:47 (304 MB/s) - ‘data.tar.bz2’ saved [26164664/26164664]

data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/train.table

In [2]:
!unzip Spider_dataset.zip

Archive:  Spider_dataset.zip
  inflating: spider/README.txt       
  inflating: spider/database/academic/academic.sqlite  
  inflating: spider/database/academic/schema.sql  
  inflating: spider/database/activity_1/activity_1.sqlite  
  inflating: spider/database/activity_1/schema.sql  
  inflating: spider/database/aircraft/aircraft.sqlite  
  inflating: spider/database/aircraft/schema.sql  
  inflating: spider/database/allergy_1/allergy_1.sqlite  
  inflating: spider/database/allergy_1/schema.sql  
  inflating: spider/database/apartment_rentals/apartment_rentals.sqlite  
  inflating: spider/database/apartment_rentals/schema.sql  
  inflating: spider/database/architecture/architecture.sqlite  
  inflating: spider/database/architecture/schema.sql  
  inflating: spider/database/assets_maintenance/assets_maintenance.sqlite  
  inflating: spider/database/assets_maintenance/schema.sql  
  inflating: spider/database/baseball_1/baseball_1.sqlite  
  inflating: spider/database/baseball_1/schema

In [80]:
!pip install sqlparse transformers datasets torch nltk
MAX_SCHEMA_COLS = 128




In [4]:
import os
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import sqlparse
from tqdm import tqdm


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [35]:
def load_jsonl(path):
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data


wiki_train = load_jsonl("/content/data/train.jsonl")
wiki_val   = load_jsonl("/content/data/dev.jsonl")

wiki_train_tables = load_jsonl("/content/data/train.tables.jsonl")
wiki_dev_tables   = load_jsonl("/content/data/dev.tables.jsonl")

wiki_tables = wiki_train_tables + wiki_dev_tables

wiki_table_map = {t["id"]: t for t in wiki_tables}

print("Wiki train:", len(wiki_train))
print("Wiki val:", len(wiki_val))
print("Wiki tables:", len(wiki_table_map))


Wiki train: 56355
Wiki val: 8421
Wiki tables: 21301


In [18]:
with open("/content/spider/train_spider.json") as f:
    spider_train = json.load(f)

with open("/content/spider/dev.json") as f:
    spider_val = json.load(f)

with open("/content/spider/tables.json") as f:
    spider_tables = json.load(f)

spider_table_map = {t["db_id"]: t for t in spider_tables}

print("Spider train:", len(spider_train))
print("Spider val:", len(spider_val))


Spider train: 7000
Spider val: 1034


In [36]:
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "!="]

def wiksql_to_sql(item, table):
    sel = item["sql"]["sel"]
    agg = item["sql"]["agg"]
    conds = item["sql"]["conds"]

    col = table["header"][sel]

    if agg == 0:
        select_clause = f"SELECT {col}"
    else:
        select_clause = f"SELECT {AGG_OPS[agg]}({col})"

    from_clause = f"FROM {table['id']}"

    if not conds:
        return f"{select_clause} {from_clause}"

    where = []
    for c, o, v in conds:
        op = COND_OPS[o]
        col_name = table["header"][c]
        v = f"'{v}'" if isinstance(v, str) else str(v)
        where.append(f"{col_name} {op} {v}")

    return f"{select_clause} {from_clause} WHERE " + " AND ".join(where)


In [20]:
def sql_to_ast_tokens(sql):
    sql = sql.lower()

    tokens = ["<QUERY>"]

    if "select" in sql: tokens.append("<SELECT>")
    if "from" in sql: tokens.append("<FROM>")
    if "join" in sql: tokens.append("<JOIN>")
    if "where" in sql: tokens.append("<WHERE>")
    if "group by" in sql: tokens.append("<GROUP_BY>")
    if "having" in sql: tokens.append("<HAVING>")
    if "union" in sql: tokens.append("<UNION>")
    if "intersect" in sql: tokens.append("<INTERSECT>")
    if "except" in sql: tokens.append("<EXCEPT>")

    if sql.count("select") > 1:
        tokens.append("<SUBQUERY>")

    tokens.append("</QUERY>")

    return tokens


In [37]:
def build_action_vocab(wiki, spider):
    vocab = {"<PAD>":0, "<BOS>":1, "<EOS>":2}
    idx = 3

    def add(sql):
        nonlocal idx
        tokens = sql_to_ast_tokens(sql)
        for t in tokens:
            if t not in vocab:
                vocab[t] = idx
                idx += 1

    for x in wiki[:50000]:
        table = wiki_table_map[x["table_id"]]
        add(wiksql_to_sql(x, table))

    for x in spider[:20000]:
        add(x["query"])

    return vocab

ACTION_VOCAB = build_action_vocab(wiki_train, spider_train)
INV_ACTION_VOCAB = {v:k for k,v in ACTION_VOCAB.items()}

print("Action vocab size:", len(ACTION_VOCAB))


Action vocab size: 15


In [39]:
MODEL_NAME = "microsoft/MiniLM-L12-H384-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


In [51]:
class NL2SQLDataset(Dataset):
    def __init__(self, data, table_map, is_spider=False):
        self.data = data
        self.table_map = table_map
        self.is_spider = is_spider

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        if self.is_spider:
            question = item["question"]
            table = self.table_map[item["db_id"]]
            schema = " | ".join([col for _, col in table["column_names_original"]])
            sql = item["query"]

        else:
            question = item["question"]
            table = self.table_map[item["db_id"]]
            schema = " | ".join(sum(table["column_names_original"], []))
            sql = item["query"]

        text = f"question: {question} schema: {schema}"

        enc = tokenizer(
            text, padding="max_length", truncation=True,
            max_length=256, return_tensors="pt"
        )

        tgt = [ACTION_VOCAB["<BOS>"]] + \
              [ACTION_VOCAB[t] for t in sql_to_ast_tokens(sql)] + \
              [ACTION_VOCAB["<EOS>"]]

        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0), torch.tensor(tgt)


In [52]:
def collate_fn(batch):
    ids, masks, tgts = zip(*batch)

    ids = torch.stack(ids)
    masks = torch.stack(masks)

    max_len = max(len(t) for t in tgts)
    padded = torch.zeros(len(tgts), max_len, dtype=torch.long)

    for i,t in enumerate(tgts):
        padded[i,:len(t)] = t

    return ids, masks, padded


In [53]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained(MODEL_NAME)

    def forward(self, ids, mask):
        return self.model(ids, attention_mask=mask).last_hidden_state


In [54]:
class Decoder(nn.Module):
    def __init__(self, hidden, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden)
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, vocab)

    def forward(self, tgt):
        emb = self.emb(tgt)
        out,_ = self.lstm(emb)
        return self.fc(out)


In [55]:
class NL2SQL(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(384, vocab)

    def forward(self, ids, mask, tgt):
        enc = self.encoder(ids, mask)
        return self.decoder(tgt)


In [56]:
train_ds = NL2SQLDataset(wiki_train, wiki_table_map, False)
val_ds   = NL2SQLDataset(wiki_val, wiki_table_map, False)

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False, collate_fn=collate_fn)


In [76]:
model = NL2SQL(len(ACTION_VOCAB)).to(device)


criterion = nn.CrossEntropyLoss(ignore_index=0)


In [46]:
def train_epoch(loader):
    model.train()
    total_loss = 0

    for ids, mask, tgt in tqdm(loader):
        ids, mask, tgt = ids.to(device), mask.to(device), tgt.to(device)

        logits = model(ids, mask, tgt[:,:-1])

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt[:,1:].reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [47]:
def val_epoch(loader):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for ids, mask, tgt in loader:
            ids, mask, tgt = ids.to(device), mask.to(device), tgt.to(device)
            logits = model(ids, mask, tgt[:,:-1])

            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                tgt[:,1:].reshape(-1)
            )

            total_loss += loss.item()

    return total_loss / len(loader)


In [72]:
class FullNL2SQL(nn.Module):
    def __init__(self, base_model, hidden):
        super().__init__()
        self.encoder = base_model.encoder
        self.structure_decoder = base_model.decoder

        self.column_predictor = ColumnPredictor(hidden)
        self.table_predictor = TablePredictor(hidden)

    def forward(self, ids, mask, tgt):
        enc = self.encoder(ids, mask)
        struct_logits = self.structure_decoder(tgt)

        return enc, struct_logits


In [48]:
for epoch in range(5):
    train_loss = train_epoch(train_loader)
    val_loss = val_epoch(val_loader)

    print(f"Epoch {epoch+1} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")


100%|██████████| 3523/3523 [04:30<00:00, 13.04it/s]


Epoch 1 | Train: 0.0116 | Val: 0.0120


100%|██████████| 3523/3523 [04:30<00:00, 13.04it/s]


Epoch 2 | Train: 0.0115 | Val: 0.0119


100%|██████████| 3523/3523 [04:29<00:00, 13.05it/s]


Epoch 3 | Train: 0.0115 | Val: 0.0119


100%|██████████| 3523/3523 [04:33<00:00, 12.89it/s]


Epoch 4 | Train: 0.0114 | Val: 0.0119


100%|██████████| 3523/3523 [04:30<00:00, 13.01it/s]


Epoch 5 | Train: 0.0114 | Val: 0.0118


In [57]:
spider_train_ds = NL2SQLDataset(spider_train, spider_table_map, True)
spider_loader   = DataLoader(spider_train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)

for epoch in range(3):
    loss = train_epoch(spider_loader)
    print(f"Spider Epoch {epoch+1} | Loss {loss:.4f}")


100%|██████████| 875/875 [00:36<00:00, 24.25it/s]


Spider Epoch 1 | Loss 0.4555


100%|██████████| 875/875 [00:36<00:00, 24.21it/s]


Spider Epoch 2 | Loss 0.3843


100%|██████████| 875/875 [00:35<00:00, 24.37it/s]

Spider Epoch 3 | Loss 0.3817





In [77]:
full_model = FullNL2SQL(model, hidden=384).to(device)
optimizer = torch.optim.AdamW(full_model.parameters(), lr=2e-5)

In [58]:
class ColumnPredictor(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 1)

    def forward(self, col_emb):
        # col_emb → (batch, num_cols, hidden)
        return self.linear(col_emb).squeeze(-1)


In [59]:
class TablePredictor(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 1)

    def forward(self, table_emb):
        return self.linear(table_emb).squeeze(-1)


In [60]:
import re

def extract_values(question):
    numbers = re.findall(r'\d+', question)
    strings = re.findall(r"'(.*?)'", question)
    return numbers + strings


In [62]:
def extract_schema_embeddings(enc_out, schema_token_positions):
    return enc_out[:, schema_token_positions, :]


In [63]:
def get_wiki_column_labels(item, table):
    labels = [0] * len(table["header"])

    labels[item["sql"]["sel"]] = 1

    for col, _, _ in item["sql"]["conds"]:
        labels[col] = 1

    return labels


In [81]:
def get_spider_column_labels(item, table):
    cols = [c for _, c in table["column_names_original"]][:MAX_SCHEMA_COLS]
    labels = [0] * len(cols)

    used_cols = item["query"].lower()

    for i, col in enumerate(cols):
        if col.lower() in used_cols:
            labels[i] = 1

    return labels


In [83]:
column_loss_fn = nn.BCEWithLogitsLoss()
def train_column_epoch(loader):
    model.train()
    total_loss = 0

    for ids, mask, tgt, col_labels in tqdm(loader):
        ids, mask = ids.to(device), mask.to(device)
        col_labels = col_labels.to(device)

        enc = model.encoder(ids, mask)

        num_cols = col_labels.shape[1]
        col_emb = enc[:, -num_cols:, :]


        logits = model.column_predictor(col_emb)

        loss = column_loss_fn(logits, col_labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [84]:
def generate_full_sql(question, table):
    model.eval()

    schema = " | ".join(table["header"] if "header" in table else
                         [c for _,c in table["column_names_original"]])

    text = f"question: {question} schema: {schema}"
    enc = tokenizer(text, return_tensors="pt").to(device)

    # Structure
    tgt = torch.tensor([[ACTION_VOCAB["<BOS>"]]]).to(device)
    for _ in range(50):
        logits = model(enc["input_ids"], enc["attention_mask"], tgt)
        next_tok = logits[:,-1].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, next_tok], dim=1)
        if next_tok.item() == ACTION_VOCAB["<EOS>"]:
            break

    struct_tokens = [INV_ACTION_VOCAB[t.item()] for t in tgt[0]]

    # Column Prediction
    enc_out = model.encoder(enc["input_ids"], enc["attention_mask"])

    col_emb = enc_out[:, -len(schema.split(" | ")): , :]
    col_scores = model.column_predictor(col_emb)
    col_ids = (torch.sigmoid(col_scores) > 0.5).squeeze(0)

    columns = [schema.split(" | ")[i] for i in range(len(col_ids)) if col_ids[i]]

    values = extract_values(question)

    return struct_tokens, columns, values


In [85]:
class NL2SQLSlotDataset(Dataset):
    def __init__(self, data, table_map, is_spider=False):
        self.data = data
        self.table_map = table_map
        self.is_spider = is_spider

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        if not self.is_spider:
            question = item["question"]
            table = self.table_map[item["table_id"]]
            schema_cols = table["header"]
            sql = wiksql_to_sql(item, table)

            col_labels = get_wiki_column_labels(item, table)
            table_labels = [1]  # single table always

        else:
            question = item["question"]
            table = self.table_map[item["db_id"]]
            schema_cols = [c for _, c in table["column_names_original"]][:MAX_SCHEMA_COLS]

            sql = item["query"]

            col_labels = get_spider_column_labels(item, table)
            table_labels = [1] * len(set([t for t,_ in table["column_names_original"]]))

        schema = " | ".join(schema_cols)
        text = f"question: {question} schema: {schema}"

        enc = tokenizer(text, padding="max_length", truncation=True,
                        max_length=256, return_tensors="pt")

        tgt = [ACTION_VOCAB["<BOS>"]] + \
              [ACTION_VOCAB[t] for t in sql_to_ast_tokens(sql)] + \
              [ACTION_VOCAB["<EOS>"]]

        return (
            enc["input_ids"].squeeze(0),
            enc["attention_mask"].squeeze(0),
            torch.tensor(tgt),
            torch.tensor(col_labels, dtype=torch.float),
            torch.tensor(table_labels, dtype=torch.float)
        )


In [86]:
def collate_slot_fn(batch):
    ids, masks, tgts, col_labels, table_labels = zip(*batch)

    ids = torch.stack(ids)
    masks = torch.stack(masks)

    max_len = max(len(t) for t in tgts)
    tgt_pad = torch.zeros(len(tgts), max_len, dtype=torch.long)

    for i,t in enumerate(tgts):
        tgt_pad[i,:len(t)] = t

    col_labels = torch.nn.utils.rnn.pad_sequence(col_labels, batch_first=True)
    table_labels = torch.nn.utils.rnn.pad_sequence(table_labels, batch_first=True)

    return ids, masks, tgt_pad, col_labels, table_labels


In [69]:
slot_train_ds = NL2SQLSlotDataset(spider_train, spider_table_map, True)
slot_train_loader = DataLoader(
    slot_train_ds, batch_size=8, shuffle=True, collate_fn=collate_slot_fn
)


In [87]:
col_loss_fn = nn.BCEWithLogitsLoss()
table_loss_fn = nn.BCEWithLogitsLoss()

def train_slot_epoch(loader):
    full_model.train()
    total_loss = 0

    for ids, mask, tgt, col_labels, table_labels in tqdm(loader):
        ids = ids.to(device)
        mask = mask.to(device)
        col_labels = col_labels.to(device)

        enc = full_model.encoder(ids, mask)

        col_emb = enc[:, -col_labels.shape[1]:, :]
        col_logits = full_model.column_predictor(col_emb)

        loss = col_loss_fn(col_logits, col_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [None]:
for epoch in range(5):
    loss = train_slot_epoch(slot_train_loader)
    print(f"SLOT Epoch {epoch+1} | Loss {loss:.4f}")


100%|██████████| 875/875 [01:54<00:00,  7.65it/s]


SLOT Epoch 1 | Loss 0.2666


100%|██████████| 875/875 [01:52<00:00,  7.76it/s]


SLOT Epoch 2 | Loss 0.2580


100%|██████████| 875/875 [01:52<00:00,  7.77it/s]


SLOT Epoch 3 | Loss 0.2535


100%|██████████| 875/875 [01:52<00:00,  7.76it/s]


SLOT Epoch 4 | Loss 0.2503


  6%|▌         | 52/875 [00:06<01:45,  7.80it/s]

In [None]:
q = "Which employees earn more than 50000?"
table = spider_table_map["employee_db"]

struct, cols, values = generate_full_sql(q, table)

print("Structure:", struct)
print("Columns:", cols)
print("Values:", values)
