<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_2_SQL_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!unzip Spider_kaggle.zip


Archive:  Spider_kaggle.zip
  inflating: spider/README.txt       
  inflating: spider/database/academic/academic.sqlite  
  inflating: spider/database/academic/schema.sql  
  inflating: spider/database/activity_1/activity_1.sqlite  
  inflating: spider/database/activity_1/schema.sql  
  inflating: spider/database/aircraft/aircraft.sqlite  
  inflating: spider/database/aircraft/schema.sql  
  inflating: spider/database/allergy_1/allergy_1.sqlite  
  inflating: spider/database/allergy_1/schema.sql  
  inflating: spider/database/apartment_rentals/apartment_rentals.sqlite  
  inflating: spider/database/apartment_rentals/schema.sql  
  inflating: spider/database/architecture/architecture.sqlite  
  inflating: spider/database/architecture/schema.sql  
  inflating: spider/database/assets_maintenance/assets_maintenance.sqlite  
  inflating: spider/database/assets_maintenance/schema.sql  
  inflating: spider/database/baseball_1/baseball_1.sqlite  
  inflating: spider/database/baseball_1/schema.

In [2]:
!mkdir -p data
!cd data

!wget https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2

!tar -xvf data.tar.bz2

!ls


--2026-01-26 15:49:38--  https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2 [following]
--2026-01-26 15:49:39--  https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 26164664 (25M) [application/octet-stream]
Saving to: â€˜data.tar.bz2â€™


2026-01-26 15:49:41 (52.7 MB/s) - â€˜data.tar.bz2â€™ saved [26164664/26164664]

data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/

In [3]:
!pip install torch transformers networkx tqdm sqlparse




In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import json, random, math, os
import numpy as np
import networkx as nx
from tqdm import tqdm


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


Using device: cpu


In [20]:
# Spider
with open("spider/train_spider.json") as f:
    spider_train = json.load(f)

with open("spider/dev.json") as f:
    spider_dev = json.load(f)

with open("spider/tables.json") as f:
    spider_tables = json.load(f)

print("Loaded spider schemas:", len(spider_tables))

# WikiSQL
with open("data/train.jsonl") as f:
    wikisql_train = [json.loads(x) for x in f]

with open("data/dev.jsonl") as f:
    wikisql_dev = [json.loads(x) for x in f]

print("Spider:", len(spider_train), len(spider_dev))
print("WikiSQL:", len(wikisql_train), len(wikisql_dev))


Loaded spider schemas: 166
Spider: 7000 1034
WikiSQL: 56355 8421


In [21]:
import sqlparse
def sql_tokenize(sql):
    sql = sql.lower()
    tokens = [t.value for t in sqlparse.parse(sql)[0].flatten()]
    tokens = [t for t in tokens if not t.isspace()]
    return tokens
print(sql_tokenize("SELECT name FROM student WHERE age > 18"))


['select', 'name', 'from', 'student', 'where', 'age', '>', '18']


In [22]:
SRC_VOCAB = {"<PAD>":0,"<UNK>":1}
TGT_VOCAB = {"<PAD>":0,"<BOS>":1,"<EOS>":2,"<UNK>":3}


In [23]:
def add_src(tok):
    if tok not in SRC_VOCAB:
        SRC_VOCAB[tok] = len(SRC_VOCAB)

def add_tgt(tok):
    if tok not in TGT_VOCAB:
        TGT_VOCAB[tok] = len(TGT_VOCAB)


In [24]:
def wikisql_to_sql(ex):
    """
    Convert WikiSQL structured SQL into SQL string
    """
    sql = ex["sql"]
    table = ex["table_id"]

    col = sql["sel"]
    agg = sql["agg"]

    agg_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]

    select = f"{agg_ops[agg]}(col{col})" if agg != 0 else f"col{col}"

    where = ""
    if len(sql["conds"]) > 0:
        conds = []
        for c in sql["conds"]:
            conds.append(f"col{c[0]} {['=','>','<','!='][c[1]]} '{c[2]}'")
        where = " WHERE " + " AND ".join(conds)

    return f"SELECT {select} FROM {table}{where}"


In [25]:
def spider_sql_to_string(sql, schema):
    try:
        select_clause = []
        for item in sql["select"][1]:
            agg, col = item
            if col[1][1] == 0:
                col_name = "*"
            else:
                t = schema["table_names_original"][
                    schema["column_names_original"][col[1][1]][0]
                ]
                c = schema["column_names_original"][col[1][1]][1]
                col_name = f"{t}.{c}"

            agg_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
            if agg > 0:
                select_clause.append(f"{agg_ops[agg]}({col_name})")
            else:
                select_clause.append(col_name)

        select_sql = "SELECT " + ", ".join(select_clause)

        from_sql = " FROM " + schema["table_names_original"][
            sql["from"]["table_units"][0][1]
        ]

        where_sql = ""
        if sql["where"]:
            conds = []
            for cond in sql["where"]:
                if isinstance(cond, list):
                    col_id = cond[2][1]
                    op = ["=", ">", "<", "!=", ">=", "<=", "LIKE", "IN", "BETWEEN"][cond[1]]
                    val = str(cond[3])
                    t = schema["table_names_original"][
                        schema["column_names_original"][col_id][0]
                    ]
                    c = schema["column_names_original"][col_id][1]
                    conds.append(f"{t}.{c} {op} {val}")
            where_sql = " WHERE " + " AND ".join(conds)

        return select_sql + from_sql + where_sql

    except:
        return None


In [26]:
# Build vocab
for ex in wikisql_train:
    for t in ex["question"].lower().split():
        add_src(t)

    sql = wikisql_to_sql(ex)
    for t in sql_tokenize(sql):
        add_tgt(t)

for ex in spider_train:
    for t in ex["query"].lower().split():
        add_src(t)

    for t in sql_tokenize(ex["query"]):
        add_tgt(t)

print("SRC vocab:", len(SRC_VOCAB))
print("TGT vocab:", len(TGT_VOCAB))


SRC vocab: 57658
TGT vocab: 52432


In [27]:
# Special tokens
PAD = "<PAD>"
BOS = "<BOS>"
EOS = "<EOS>"
UNK = "<UNK>"

# Add to target vocab
for t in [PAD, BOS, EOS, UNK]:
    if t not in TGT_VOCAB:
        TGT_VOCAB[t] = len(TGT_VOCAB)

INV_TGT_VOCAB = {v:k for k,v in TGT_VOCAB.items()}

PAD_ID = TGT_VOCAB[PAD]
BOS_ID = TGT_VOCAB[BOS]
EOS_ID = TGT_VOCAB[EOS]
UNK_ID = TGT_VOCAB[UNK]

# Encode source sentence
def encode_src(text, max_len=64):
    ids = [SRC_VOCAB.get(t, SRC_VOCAB[UNK]) for t in text.lower().split()]
    ids = ids[:max_len]
    return ids + [SRC_VOCAB[PAD]] * (max_len - len(ids))

# Encode SQL string
def encode_tgt(sql, max_len=128):
    tokens = sql_tokenize(sql)
    ids = [TGT_VOCAB.get(t, UNK_ID) for t in tokens]
    ids = [BOS_ID] + ids[:max_len-2] + [EOS_ID]
    return ids + [PAD_ID] * (max_len - len(ids))


class Seq2SeqDataset(Dataset):
    def __init__(self, spider_data, wiki_data):
        self.samples = []

        # WikiSQL
        for ex in wiki_data:
            nl = ex["question"]
            sql = wikisql_to_sql(ex)
            self.samples.append((nl, sql))

        # Spider
        # Spider
        for ex in spider_data:
            schema = next(db for db in spider_tables if db["db_id"] == ex["db_id"])
            sql = spider_sql_to_string(ex["sql"], schema)
            if sql:
                self.samples.append((ex["query"], sql))

        print("Total training samples:", len(self.samples))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        nl, sql = self.samples[idx]

        src = torch.tensor(encode_src(nl))
        tgt = torch.tensor(encode_tgt(sql))

        return src, tgt


train_loader = DataLoader(
    Seq2SeqDataset(spider_train, wikisql_train),
    batch_size=32,
    shuffle=True
)


Total training samples: 59937


In [28]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)

        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class TransformerSeq2Seq(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, hidden=512, layers=6, heads=8, dropout=0.1):
        super().__init__()

        self.src_emb = nn.Embedding(src_vocab, hidden)
        self.tgt_emb = nn.Embedding(tgt_vocab, hidden)

        self.pos_enc = PositionalEncoding(hidden)

        self.transformer = nn.Transformer(
            d_model=hidden,
            nhead=heads,
            num_encoder_layers=layers,
            num_decoder_layers=layers,
            dim_feedforward=hidden * 4,
            dropout=dropout,
            batch_first=True
        )

        self.fc_out = nn.Linear(hidden, tgt_vocab)

    def forward(self, src, tgt):
        src = self.pos_enc(self.src_emb(src))
        tgt = self.pos_enc(self.tgt_emb(tgt))

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc_out(out)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = TransformerSeq2Seq(
    src_vocab=len(SRC_VOCAB),
    tgt_vocab=len(TGT_VOCAB),
    hidden=512,
    layers=6,
    heads=8,
    dropout=0.2
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

print("Model parameters:", sum(p.numel() for p in model.parameters()) // 1e6, "M")


def train_epoch():
    model.train()
    total_loss = 0

    for src, tgt in tqdm(train_loader):
        src = src.to(device)
        tgt = tgt.to(device)

        out = model(src, tgt[:, :-1])
        loss = criterion(out.reshape(-1, out.size(-1)), tgt[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


print("\nðŸ”¥ Training started\n")

best_loss = float("inf")

for epoch in range(1, 31):
    loss = train_epoch()
    print(f"Epoch {epoch:02d} | Train Loss: {loss:.4f}")

    if loss < best_loss:
        best_loss = loss
        torch.save(model.state_dict(), "nl2sql_seq2seq_best.pt")
        print("âœ… Saved BEST model")


Model parameters: 127.0 M

ðŸ”¥ Training started



  0%|          | 1/1874 [00:39<20:22:32, 39.16s/it]