<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_TO_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import random, json

# =====================================================
# SCHEMAS
# =====================================================

SCHEMAS = [

    {
        "tables":{
            "employees":["id","name","salary","dept_id","age","experience"],
            "departments":["id","name","location"]
        },
        "join":("employees","departments","dept_id","id")
    },

    {
        "tables":{
            "students":["id","name","marks","class_id","age"],
            "classes":["id","name","teacher"]
        },
        "join":("students","classes","class_id","id")
    }
]


# =====================================================
# LANGUAGE POOLS (VOCAB BOOSTERS üöÄ)
# =====================================================

PREFIX = [
    "show","list","display","fetch","give","retrieve",
    "provide","return","output","present","find",
    "identify","tell me","can you show","i want",
    "help me find","get me","let me know"
]

QUESTION_STYLE = [
    "{} {} from {}",
    "{} all {} from {}",
    "{} the {} available in {}",
    "what are the {} in {}",
    "which {} exist in {}"
]

COLUMN_SYNONYMS = {
    "salary":["salary","income","pay","earnings","compensation"],
    "name":["name","full name","employee name","student name"],
    "age":["age","years","age value"],
    "marks":["marks","score","grades","result"],
    "experience":["experience","years of experience"]
}

COMPARE = [
    (">","greater than"),
    ("<","less than"),
    (">=","at least"),
    ("<=","at most"),
    ("!=","not equal to")
]

AGGS = ["SUM","AVG","COUNT","MAX","MIN"]

ORDER_WORDS = [
    "ordered by","sorted by","arranged by",
    "ranked by","organized by"
]

# ‚≠ê Names for nested / equality queries
EMP_NAMES = [
    "ravi","rahul","amit","neha",
    "arjun","vikram","john","sara"
]


# =====================================================
# HELPER FUNCTIONS
# =====================================================

def col_word(col):
    if col in COLUMN_SYNONYMS:
        return random.choice(COLUMN_SYNONYMS[col])
    return col


def question(template, prefix, col, table):
    return template.format(prefix, col_word(col), table)


# =====================================================
# GENERATOR
# =====================================================

def generate_example():

    db=random.choice(SCHEMAS)

    schema=db["tables"]
    t1,t2,c1,c2=db["join"]

    main=t1
    cols=schema[main]

    prefix=random.choice(PREFIX)

    intent=random.choice([
        "SELECT","MULTI","WHERE","BETWEEN",

        # ‚≠ê NEW
        "EQUALITY_VALUE",
        "EQUALITY_NAME",

        "AGG","GROUP","ORDER","LIMIT",

        "JOIN","LEFT_JOIN","JOIN_WHERE",
        "JOIN_GROUP",

        "NESTED",

        # ‚≠ê VERY IMPORTANT
        "NESTED_PERSON_COMPARE",

        "CORRELATED",
        "EXISTS"
    ])

    col=random.choice(cols)

    # =====================================================
    # SIMPLE SELECT
    # =====================================================

    if intent=="SELECT":

        q=question(
            random.choice(QUESTION_STYLE),
            prefix,col,main
        )

        sql=f"SELECT {col} FROM {main}"


    # =====================================================
    # MULTI COLUMN
    # =====================================================

    elif intent=="MULTI":

        c1_,c2_=random.sample(cols,2)

        q=f"{prefix} {col_word(c1_)} and {col_word(c2_)} from {main}"

        sql=f"SELECT {c1_}, {c2_} FROM {main}"


    # =====================================================
    # WHERE
    # =====================================================

    elif intent=="WHERE":

        op,text=random.choice(COMPARE)
        val=random.randint(10,100)

        q=f"{prefix} {main} where {col_word(col)} is {text} {val}"

        sql=f"SELECT name FROM {main} WHERE {col} {op} {val}"


    # =====================================================
    # BETWEEN
    # =====================================================

    elif intent=="BETWEEN":

        low=random.randint(10,40)
        high=random.randint(50,100)

        q=f"{prefix} {main} with {col_word(col)} between {low} and {high}"

        sql=f"SELECT name FROM {main} WHERE {col} BETWEEN {low} AND {high}"


    # =====================================================
    # ‚≠ê EQUALITY VALUE
    # =====================================================

    elif intent=="EQUALITY_VALUE":

        val=random.randint(10,100)

        q=f"{prefix} {main} where {col_word(col)} equals {val}"

        sql=f"SELECT name FROM {main} WHERE {col} = {val}"


    # =====================================================
    # ‚≠ê EQUALITY NAME
    # =====================================================

    elif intent=="EQUALITY_NAME":

        person=random.choice(EMP_NAMES)

        q=f"{prefix} {main} named {person}"

        sql=f"SELECT * FROM {main} WHERE name = '{person.capitalize()}'"


    # =====================================================
    # AGG
    # =====================================================

    elif intent=="AGG":

        agg=random.choice(AGGS)

        q=f"what is the {agg.lower()} of {col_word(col)} in {main}"

        sql=f"SELECT {agg}({col}) FROM {main}"


    # =====================================================
    # GROUP
    # =====================================================

    elif intent=="GROUP":

        agg=random.choice(AGGS)

        q=f"{prefix} {agg.lower()} salary grouped by dept"

        sql=f"SELECT dept_id, {agg}(salary) FROM {main} GROUP BY dept_id"


    # =====================================================
    # ORDER
    # =====================================================

    elif intent=="ORDER":

        q=f"{prefix} {main} {random.choice(ORDER_WORDS)} {col_word(col)}"

        sql=f"SELECT name FROM {main} ORDER BY {col} DESC"


    # =====================================================
    # LIMIT
    # =====================================================

    elif intent=="LIMIT":

        n=random.randint(3,10)

        q=f"{prefix} top {n} {main} by {col_word(col)}"

        sql=f"SELECT name FROM {main} ORDER BY {col} DESC LIMIT {n}"


    # =====================================================
    # JOIN
    # =====================================================

    elif intent=="JOIN":

        pattern=random.randint(1,4)

        if pattern==1:
            q=f"{prefix} {t1} with their {t2}"
            sql=f"SELECT {t1}.name, {t2}.name FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

        elif pattern==2:
            q=f"{prefix} {t2} along with related {t1}"
            sql=f"SELECT {t2}.name, {t1}.name FROM {t2} JOIN {t1} ON {t2}.{c2} = {t1}.{c1}"

        elif pattern==3:
            q=f"{prefix} records combining {t1} and {t2}"
            sql=f"SELECT * FROM {t1} INNER JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

        else:
            q=f"{prefix} {t1} mapped to their departments"
            sql=f"SELECT {t1}.name, {t2}.location FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"


    # =====================================================
    # LEFT JOIN
    # =====================================================

    elif intent=="LEFT_JOIN":

        q=f"{prefix} all {t1} even if department missing"

        sql=f"SELECT {t1}.name FROM {t1} LEFT JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"


    # =====================================================
    # JOIN WHERE
    # =====================================================

    elif intent=="JOIN_WHERE":

        val=random.randint(1,5)

        q=f"{prefix} {t1} working in departments where id > {val}"

        sql=f"""
        SELECT {t1}.name, {t2}.name
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        WHERE {t2}.id > {val}
        """.replace("\n"," ")


    # =====================================================
    # JOIN GROUP
    # =====================================================

    elif intent=="JOIN_GROUP":

        q=f"{prefix} number of {t1} per {t2}"

        sql=f"""
        SELECT {t2}.name, COUNT(*)
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        GROUP BY {t2}.name
        """.replace("\n"," ")


    # =====================================================
    # NESTED AVG/MAX/MIN
    # =====================================================

    elif intent=="NESTED":

        agg=random.choice(["AVG","MAX","MIN"])

        q=f"{prefix} {main} earning above {agg.lower()} salary"

        sql=f"""
        SELECT name
        FROM {main}
        WHERE salary >
        (SELECT {agg}(salary) FROM {main})
        """.replace("\n"," ")


    # =====================================================
    # ‚≠ê NESTED PERSON COMPARISON
    # =====================================================

    elif intent=="NESTED_PERSON_COMPARE":

        person=random.choice(EMP_NAMES)

        q=f"{prefix} employees earning more than {person}"

        sql=f"""
        SELECT name
        FROM employees
        WHERE salary >
        (
            SELECT salary
            FROM employees
            WHERE name = '{person.capitalize()}'
        )
        """.replace("\n"," ")


    # =====================================================
    # CORRELATED
    # =====================================================

    elif intent=="CORRELATED":

        q=f"{prefix} employees earning above their department average"

        sql=f"""
        SELECT name
        FROM employees e
        WHERE salary >
        (
            SELECT AVG(salary)
            FROM employees
            WHERE dept_id = e.dept_id
        )
        """.replace("\n"," ")


    # =====================================================
    # EXISTS
    # =====================================================

    else:

        q=f"{prefix} employees that belong to a department"

        sql=f"""
        SELECT name
        FROM employees e
        WHERE EXISTS
        (
            SELECT 1
            FROM departments d
            WHERE e.dept_id = d.id
        )
        """.replace("\n"," ")


    return {
        "question":q.lower(),
        "schema":schema,
        "sql":" ".join(sql.split())
    }


# =====================================================
# GENERATE DATASET
# =====================================================

DATA=[generate_example() for _ in range(150000)]

with open("nl2sql_data.json","w") as f:
    json.dump(DATA,f)

print("üî• ULTRA dataset generated successfully!")


üî• ULTRA dataset generated successfully!


In [13]:
import re, torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [14]:
ENC_VOCAB = {"<PAD>":0,"<UNK>":1,"<SEP>":2}
DEC_VOCAB = {"<PAD>":0,"<UNK>":1,"<BOS>":2,"<EOS>":3}

def add(vocab,token):
    if token not in vocab:
        vocab[token]=len(vocab)



# ‚≠ê VERY IMPORTANT TOKENIZER
def sql_tokenize(sql):

    return re.findall(
        r"[A-Za-z_]+\.[A-Za-z_]+"      # table.column
        r"|>=|<=|!=|=|>|<"
        r"|\bselect\b|\bfrom\b|\bjoin\b|\bon\b|\bwhere\b"
        r"|\bgroup\b|\bby\b|\border\b|\blimit\b"
        r"|\binner\b|\bleft\b|\bright\b|\bhaving\b"
        r"|\bavg\b|\bsum\b|\bcount\b|\bmax\b|\bmin\b"
        r"|\*"
        r"|\(|\)|,"
        r"|[A-Za-z_]+"
        r"|\d+",
        sql.lower()
    )

for ex in DATA:

    for t in ex["question"].split():
        add(ENC_VOCAB,t)

    for t,cols in ex["schema"].items():
        add(ENC_VOCAB,t)
        for c in cols:
            add(ENC_VOCAB,f"{t}.{c}")

    for tok in sql_tokenize(ex["sql"]):
        add(DEC_VOCAB,tok)

print("Encoder vocab:",len(ENC_VOCAB))
print("Decoder vocab:",len(DEC_VOCAB))


Encoder vocab: 235
Decoder vocab: 168


In [15]:
import re
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

SEED = 42

torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [16]:


def add(vocab,token):
    if token not in vocab:
        vocab[token]=len(vocab)





for ex in DATA:

    for t in ex["question"].split():
        add(ENC_VOCAB,t)

    for t,cols in ex["schema"].items():
        add(ENC_VOCAB,t)
        for c in cols:
            add(ENC_VOCAB,f"{t}.{c}")

    for tok in sql_tokenize(ex["sql"]):
        add(DEC_VOCAB,tok)

print("Encoder vocab:",len(ENC_VOCAB))
print("Decoder vocab:",len(DEC_VOCAB))


Encoder vocab: 235
Decoder vocab: 168


In [17]:
class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data=data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):

        ex=self.data[i]

        src = ex["question"].split()+["<SEP>"]+[
            f"{t}.{c}"
            for t,cs in ex["schema"].items()
            for c in cs
        ]

        src_ids=[ENC_VOCAB.get(t,1) for t in src][:140]
        src_ids+=[0]*(140-len(src_ids))

        tgt=[DEC_VOCAB["<BOS>"]] + \
            [DEC_VOCAB.get(t,1) for t in sql_tokenize(ex["sql"])] + \
            [DEC_VOCAB["<EOS>"]]

        tgt=tgt[:100]
        tgt+=[0]*(100-len(tgt))

        return torch.tensor(src_ids),torch.tensor(tgt)


In [18]:
D_MODEL = 512   # üî• bigger = better joins + nested

class PositionalEncoding(nn.Module):

    def __init__(self,d_model,max_len=600):
        super().__init__()

        pe=torch.zeros(max_len,d_model)

        pos=torch.arange(0,max_len).unsqueeze(1)

        div=torch.exp(
            torch.arange(0,d_model,2) *
            (-torch.log(torch.tensor(10000.0))/d_model)
        )

        pe[:,0::2]=torch.sin(pos*div)
        pe[:,1::2]=torch.cos(pos*div)

        self.pe=pe.unsqueeze(0)

    def forward(self,x):
        return x+self.pe[:,:x.size(1)].to(x.device)


In [19]:
class Encoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL,padding_idx=0)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerEncoderLayer(
            D_MODEL,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.enc=nn.TransformerEncoder(layer,4)

    def forward(self,x):

        mask=(x==0)

        x=self.pos(self.emb(x))

        return self.enc(x,src_key_padding_mask=mask)



class Decoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerDecoderLayer(
            D_MODEL,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.dec=nn.TransformerDecoder(layer,4)

        self.fc=nn.Linear(D_MODEL,vocab)

        # ‚≠ê improves generation quality massively
        self.fc.weight=self.emb.weight

    def forward(self,y,mem,mask):

        L=y.size(1)

        causal=torch.triu(
            torch.ones(L,L,device=y.device),1
        ).bool()

        y=self.pos(self.emb(y))

        return self.fc(
            self.dec(
                y,mem,
                tgt_mask=causal,
                memory_key_padding_mask=mask
            )
        )


In [20]:
train,val=train_test_split(DATA,test_size=0.1,random_state=42)

train_loader=DataLoader(
    NL2SQLDataset(train),
    batch_size=48,   # ‚ö†Ô∏è 512 model needs smaller batch
    shuffle=True
)

val_loader=DataLoader(
    NL2SQLDataset(val),
    batch_size=48
)


In [21]:
enc=Encoder(len(ENC_VOCAB)).to(device)
dec=Decoder(len(DEC_VOCAB)).to(device)

optimizer=optim.AdamW(
    list(enc.parameters())+list(dec.parameters()),
    lr=8e-5,           # ‚≠ê SWEET SPOT
    weight_decay=1e-4
)

scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=12
)

loss_fn=nn.CrossEntropyLoss(
    ignore_index=0,
    label_smoothing=0.07
)




In [None]:
import os

CHECKPOINT_PATH = "checkpoint.pt"
BEST_MODEL_PATH = "best_model.pt"

start_epoch = 0
best_val_loss = float("inf")

# ‚≠ê Resume automatically if checkpoint exists
if os.path.exists(CHECKPOINT_PATH):

    print("‚úÖ Resuming from checkpoint...\n")

    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

    enc.load_state_dict(checkpoint["encoder"])
    dec.load_state_dict(checkpoint["decoder"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])

    start_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["best_val_loss"]

    print(f"Resuming from Epoch {start_epoch}\n")


print("üöÄ Training...\n")

for epoch in range(start_epoch, 14):

    # ================= TRAIN =================
    enc.train(); dec.train()
    train_loss = 0

    for x,y in train_loader:

        x,y = x.to(device),y.to(device)

        optimizer.zero_grad()

        mem = enc(x)

        out = dec(y[:,:-1],mem,(x==0))

        loss = loss_fn(
            out.reshape(-1,len(DEC_VOCAB)),
            y[:,1:].reshape(-1)
        )

        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(enc.parameters())+list(dec.parameters()),
            1.0
        )

        optimizer.step()

        train_loss += loss.item()


    # ================= VALIDATION =================
    enc.eval(); dec.eval()
    val_loss = 0

    with torch.no_grad():

        for x,y in val_loader:

            x,y = x.to(device),y.to(device)

            mem = enc(x)

            out = dec(y[:,:-1],mem,(x==0))

            val_loss += loss_fn(
                out.reshape(-1,len(DEC_VOCAB)),
                y[:,1:].reshape(-1)
            ).item()


    scheduler.step()

    train_loss /= len(train_loader)
    val_loss /= len(val_loader)

    print(f"""
Epoch {epoch+1}

Train Loss: {train_loss:.3f}
Val Loss:   {val_loss:.3f}
""")


    # ================= SAVE CHECKPOINT =================

    torch.save({
        "epoch": epoch,
        "encoder": enc.state_dict(),
        "decoder": dec.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "best_val_loss": best_val_loss
    }, CHECKPOINT_PATH)


    # ================= SAVE BEST MODEL =================

    if val_loss < best_val_loss:

        best_val_loss = val_loss

        torch.save({
            "encoder": enc.state_dict(),
            "decoder": dec.state_dict(),
            "enc_vocab": ENC_VOCAB,
            "dec_vocab": DEC_VOCAB
        }, BEST_MODEL_PATH)

        print("üî• BEST MODEL SAVED!\n")


üöÄ Training...


Epoch 1

Train Loss: 5.543
Val Loss:   0.856

üî• BEST MODEL SAVED!


Epoch 2

Train Loss: 0.799
Val Loss:   0.644

üî• BEST MODEL SAVED!


Epoch 3

Train Loss: 0.665
Val Loss:   0.625

üî• BEST MODEL SAVED!


Epoch 4

Train Loss: 0.635
Val Loss:   0.617

üî• BEST MODEL SAVED!


Epoch 5

Train Loss: 0.625
Val Loss:   0.616

üî• BEST MODEL SAVED!


Epoch 6

Train Loss: 0.620
Val Loss:   0.615

üî• BEST MODEL SAVED!


Epoch 7

Train Loss: 0.617
Val Loss:   0.613

üî• BEST MODEL SAVED!


Epoch 8

Train Loss: 0.615
Val Loss:   0.612

üî• BEST MODEL SAVED!


Epoch 9

Train Loss: 0.614
Val Loss:   0.611

üî• BEST MODEL SAVED!


Epoch 10

Train Loss: 0.613
Val Loss:   0.610

üî• BEST MODEL SAVED!


Epoch 11

Train Loss: 0.612
Val Loss:   0.610

üî• BEST MODEL SAVED!


Epoch 12

Train Loss: 0.612
Val Loss:   0.610


Epoch 13

Train Loss: 0.612
Val Loss:   0.610

