<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_2_SQL_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!unzip Spider_dataset.zip


Archive:  Spider_dataset.zip
  inflating: spider/README.txt       
  inflating: spider/database/academic/academic.sqlite  
  inflating: spider/database/academic/schema.sql  
  inflating: spider/database/activity_1/activity_1.sqlite  
  inflating: spider/database/activity_1/schema.sql  
  inflating: spider/database/aircraft/aircraft.sqlite  
  inflating: spider/database/aircraft/schema.sql  
  inflating: spider/database/allergy_1/allergy_1.sqlite  
  inflating: spider/database/allergy_1/schema.sql  
  inflating: spider/database/apartment_rentals/apartment_rentals.sqlite  
  inflating: spider/database/apartment_rentals/schema.sql  
  inflating: spider/database/architecture/architecture.sqlite  
  inflating: spider/database/architecture/schema.sql  
  inflating: spider/database/assets_maintenance/assets_maintenance.sqlite  
  inflating: spider/database/assets_maintenance/schema.sql  
  inflating: spider/database/baseball_1/baseball_1.sqlite  
  inflating: spider/database/baseball_1/schema

In [2]:
!mkdir -p data
!cd data

!wget https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2

!tar -xvf data.tar.bz2

!ls


--2026-01-27 04:25:15--  https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
Resolving github.com (github.com)... 140.82.121.4
Connecting to github.com (github.com)|140.82.121.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2 [following]
--2026-01-27 04:25:15--  https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26164664 (25M) [application/octet-stream]
Saving to: â€˜data.tar.bz2â€™


2026-01-27 04:25:16 (331 MB/s) - â€˜data.tar.bz2â€™ saved [26164664/26164664]

data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/tra

In [3]:
!pip install torch transformers networkx tqdm sqlparse




In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import json, random, math, os
import numpy as np
import networkx as nx
from tqdm import tqdm


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


Using device: cuda


In [6]:
# Spider
with open("spider/train_spider.json") as f:
    spider_train = json.load(f)

with open("spider/dev.json") as f:
    spider_dev = json.load(f)

with open("spider/tables.json") as f:
    spider_tables = json.load(f)

print("Loaded spider schemas:", len(spider_tables))

# WikiSQL
with open("data/train.jsonl") as f:
    wikisql_train = [json.loads(x) for x in f]

with open("data/dev.jsonl") as f:
    wikisql_dev = [json.loads(x) for x in f]

print("Spider:", len(spider_train), len(spider_dev))
print("WikiSQL:", len(wikisql_train), len(wikisql_dev))


Loaded spider schemas: 166
Spider: 7000 1034
WikiSQL: 56355 8421


In [7]:
import sqlparse
def sql_tokenize(sql):
    sql = sql.lower()
    tokens = [t.value for t in sqlparse.parse(sql)[0].flatten()]
    tokens = [t for t in tokens if not t.isspace()]
    return tokens
print(sql_tokenize("SELECT name FROM student WHERE age > 18"))


['select', 'name', 'from', 'student', 'where', 'age', '>', '18']


In [8]:
SRC_VOCAB = {"<PAD>":0,"<UNK>":1}
TGT_VOCAB = {"<PAD>":0,"<BOS>":1,"<EOS>":2,"<UNK>":3}


In [9]:
def add_src(tok):
    if tok not in SRC_VOCAB:
        SRC_VOCAB[tok] = len(SRC_VOCAB)

def add_tgt(tok):
    if tok not in TGT_VOCAB:
        TGT_VOCAB[tok] = len(TGT_VOCAB)


In [10]:
def wikisql_to_sql(ex):
    """
    Convert WikiSQL structured SQL into SQL string
    """
    sql = ex["sql"]
    table = ex["table_id"]

    col = sql["sel"]
    agg = sql["agg"]

    agg_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]

    select = f"{agg_ops[agg]}(col{col})" if agg != 0 else f"col{col}"

    where = ""
    if len(sql["conds"]) > 0:
        conds = []
        for c in sql["conds"]:
            conds.append(f"col{c[0]} {['=','>','<','!='][c[1]]} '{c[2]}'")
        where = " WHERE " + " AND ".join(conds)

    return f"SELECT {select} FROM {table}{where}"


In [22]:
def spider_sql_to_full_string(sql, schema):

    def col_to_str(cid):
        if cid == 0:
            return "*"
        t = schema["table_names_original"][
            schema["column_names_original"][cid][0]
        ]
        c = schema["column_names_original"][cid][1]
        return f"{t}.{c}"

    AGG = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
    OP  = ["=", ">", "<", "!=", ">=", "<=", "LIKE", "IN", "BETWEEN"]

    def parse_select(sel):
        cols = []
        for agg, col in sel[1]:
            cid = col[1][1]
            col_str = col_to_str(cid)
            if agg > 0:
                cols.append(f"{AGG[agg]}({col_str})")
            else:
                cols.append(col_str)
        return "SELECT " + ", ".join(cols)

    def parse_from(fr):
        base = schema["table_names_original"][fr["table_units"][0][1]]
        sql = f" FROM {base}"

        for cond in fr["conds"]:
            if cond[1] == 2:  # join
                c1 = col_to_str(cond[2][1])
                c2 = col_to_str(cond[3][1])
                t2 = c2.split(".")[0]
                sql += f" JOIN {t2} ON {c1} = {c2}"

        return sql

    def parse_where(w):
        if not w:
            return ""
        conds = []
        for cond in w:
            if isinstance(cond, list):
                col = col_to_str(cond[2][1])
                op  = OP[cond[1]]
                val = str(cond[3])
                conds.append(f"{col} {op} {val}")
        return " WHERE " + " AND ".join(conds)

    def parse_group(g):
        if not g:
            return ""
        cols = [col_to_str(c[1]) for c in g]
        return " GROUP BY " + ", ".join(cols)

    def parse_having(h):
        if not h:
            return ""
        conds = []
        for cond in h:
            col = col_to_str(cond[2][1])
            op  = OP[cond[1]]
            val = str(cond[3])
            conds.append(f"{col} {op} {val}")
        return " HAVING " + " AND ".join(conds)

    def parse_order(o):
        if not o:
            return ""
        cols = [col_to_str(c[1]) for c in o[1]]
        order = "DESC" if o[0] == "desc" else "ASC"
        return " ORDER BY " + ", ".join(cols) + " " + order

    def parse_sql(s):
        q = ""
        q += parse_select(s["select"])
        q += parse_from(s["from"])
        q += parse_where(s["where"])
        q += parse_group(s["groupBy"])
        q += parse_having(s["having"])
        q += parse_order(s["orderBy"])

        if s["intersect"]:
            q += " INTERSECT " + parse_sql(s["intersect"])
        if s["union"]:
            q += " UNION " + parse_sql(s["union"])
        if s["except"]:
            q += " EXCEPT " + parse_sql(s["except"])

        return q

    try:
        return parse_sql(sql)
    except:
        return None


In [23]:
# Build vocab
for ex in wikisql_train:
    for t in ex["question"].lower().split():
        add_src(t)

    sql = wikisql_to_sql(ex)
    for t in sql_tokenize(sql):
        add_tgt(t)

for ex in spider_train:
    for t in ex["query"].lower().split():
        add_src(t)

    for t in sql_tokenize(ex["query"]):
        add_tgt(t)

print("SRC vocab:", len(SRC_VOCAB))
print("TGT vocab:", len(TGT_VOCAB))


SRC vocab: 57658
TGT vocab: 52432


In [24]:
# Special tokens
PAD = "<PAD>"
BOS = "<BOS>"
EOS = "<EOS>"
UNK = "<UNK>"

# Add to target vocab
for t in [PAD, BOS, EOS, UNK]:
    if t not in TGT_VOCAB:
        TGT_VOCAB[t] = len(TGT_VOCAB)

INV_TGT_VOCAB = {v:k for k,v in TGT_VOCAB.items()}

PAD_ID = TGT_VOCAB[PAD]
BOS_ID = TGT_VOCAB[BOS]
EOS_ID = TGT_VOCAB[EOS]
UNK_ID = TGT_VOCAB[UNK]

# Encode source sentence
def encode_src(text, max_len=64):
    ids = [SRC_VOCAB.get(t, SRC_VOCAB[UNK]) for t in text.lower().split()]
    ids = ids[:max_len]
    return ids + [SRC_VOCAB[PAD]] * (max_len - len(ids))

# Encode SQL string
def encode_tgt(sql, max_len=128):
    tokens = sql_tokenize(sql)
    ids = [TGT_VOCAB.get(t, UNK_ID) for t in tokens]
    ids = [BOS_ID] + ids[:max_len-2] + [EOS_ID]
    return ids + [PAD_ID] * (max_len - len(ids))


class Seq2SeqDataset(Dataset):
    def __init__(self, spider_data, wiki_data):
        self.samples = []

        # WikiSQL
        for ex in wiki_data:
            nl = ex["question"]
            sql = wikisql_to_sql(ex)
            self.samples.append((nl, sql))

        # Spider
        # Spider
        for ex in spider_data:
            schema = next(db for db in spider_tables if db["db_id"] == ex["db_id"])
            sql = spider_sql_to_full_string(ex["sql"], schema)
            if sql:
                self.samples.append((ex["query"], sql))

        print("Total training samples:", len(self.samples))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        nl, sql = self.samples[idx]

        src = torch.tensor(encode_src(nl))
        tgt = torch.tensor(encode_tgt(sql))

        return src, tgt


train_loader = DataLoader(
    Seq2SeqDataset(spider_train, wikisql_train),
    batch_size=32,
    shuffle=True
)

val_loader = DataLoader(
    Seq2SeqDataset(spider_dev, wikisql_dev),
    batch_size=32,
    shuffle=False
)


Total training samples: 57658
Total training samples: 8619


In [25]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)

        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class TransformerSeq2Seq(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, hidden=512, layers=6, heads=8, dropout=0.1):
        super().__init__()

        self.src_emb = nn.Embedding(src_vocab, hidden)
        self.tgt_emb = nn.Embedding(tgt_vocab, hidden)

        self.pos_enc = PositionalEncoding(hidden)

        self.transformer = nn.Transformer(
            d_model=hidden,
            nhead=heads,
            num_encoder_layers=layers,
            num_decoder_layers=layers,
            dim_feedforward=hidden * 4,
            dropout=dropout,
            batch_first=True
        )

        self.fc_out = nn.Linear(hidden, tgt_vocab)

    def forward(self, src, tgt):
        src = self.pos_enc(self.src_emb(src))
        tgt = self.pos_enc(self.tgt_emb(tgt))

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc_out(out)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)


In [29]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TransformerSeq2Seq(
    src_vocab=len(SRC_VOCAB),
    tgt_vocab=len(TGT_VOCAB),
    hidden=512,
    layers=6,
    heads=8,
    dropout=0.2
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

print("Model parameters:", sum(p.numel() for p in model.parameters()) // 1e6, "M")


def train_epoch():
    model.train()
    total_loss = 0

    for src, tgt in tqdm(train_loader):
        src = src.to(device)
        tgt = tgt.to(device)

        out = model(src, tgt[:, :-1])
        loss = criterion(out.reshape(-1, out.size(-1)), tgt[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def eval_epoch():
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src, tgt in tqdm(val_loader):
            src = src.to(device)
            tgt = tgt.to(device)

            out = model(src, tgt[:, :-1])
            loss = criterion(
                out.reshape(-1, out.size(-1)),
                tgt[:, 1:].reshape(-1)
            )

            total_loss += loss.item()

    return total_loss / len(val_loader)

print("\nðŸ”¥ Training started\n")




best_val = float("inf")


for epoch in range(1, 31):
    train_loss = train_epoch()
    val_loss = eval_epoch()

    print(f"Epoch {epoch:02d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), "nl2sql_seq2seq_best.pt")
        print("âœ… Saved BEST model")


Model parameters: 127.0 M

ðŸ”¥ Training started



  0%|          | 1/1802 [00:00<27:21,  1.10it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 814.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 242.12 MiB is free. Process 4052 has 14.50 GiB memory in use. Of the allocated memory 13.49 GiB is allocated by PyTorch, and 900.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)