<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_2_SQL_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import random, json, re, torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

SEED = 42

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cpu


In [None]:
import random, json

SCHEMAS = [

    {
        "tables":{
            "employees":["id","name","salary","dept_id"],
            "departments":["id","name"]
        },
        "join":("employees","departments","dept_id","id")
    },

    {
        "tables":{
            "students":["id","name","marks","class_id"],
            "classes":["id","name"]
        },
        "join":("students","classes","class_id","id")
    }
]

AGGS = ["SUM","AVG","COUNT","MAX","MIN"]

PREFIX = [
    "show","list","display","fetch","give","find","retrieve",
    "return","provide","get me","can you show"
]

COMPARE = [
    (">","greater than"),
    ("<","less than"),
    (">=","at least"),
    ("<=","at most"),
    ("!=","not equal to")
]


############################################
# Helper Functions
############################################

def rand_val():
    """Generate realistic numeric values"""
    return random.randint(1,10000)


def pick_table(schema):
    return random.choice(list(schema.keys()))


def pick_col(schema, table):
    return random.choice(schema[table])


############################################
# Generator
############################################

def generate_example():

    db = random.choice(SCHEMAS)
    schema = db["tables"]

    main = pick_table(schema)
    cols = schema[main]

    intent = random.choice([
        "SELECT","MULTI","WHERE","OR",
        "AGG","GROUP",
        "ORDER","LIMIT",
        "JOIN","JOIN_WHERE",
        "NESTED"
    ])

    prefix = random.choice(PREFIX)

    ##################################
    # SELECT
    ##################################

    if intent=="SELECT":

        col = pick_col(schema, main)

        q = f"{prefix} {col} from {main}"
        sql = f"SELECT {col} FROM {main}"


    ##################################
    # MULTI
    ##################################

    elif intent=="MULTI":

        c1, c2 = random.sample(cols,2)

        q = f"{prefix} {c1} and {c2} from {main}"
        sql = f"SELECT {c1}, {c2} FROM {main}"


    ##################################
    # WHERE
    ##################################

    elif intent=="WHERE":

        col = pick_col(schema, main)
        op, text = random.choice(COMPARE)
        val = rand_val()

        q = f"{prefix} {main} where {col} is {text} {val}"
        sql = f"SELECT * FROM {main} WHERE {col} {op} {val}"


    ##################################
    # OR
    ##################################

    elif intent=="OR":

        c1, c2 = random.sample(cols,2)
        v1, v2 = rand_val(), rand_val()

        q = f"{prefix} {main} where {c1} > {v1} or {c2} > {v2}"
        sql = f"SELECT * FROM {main} WHERE {c1} > {v1} OR {c2} > {v2}"


    ##################################
    # AGG
    ##################################

    elif intent=="AGG":

        col = pick_col(schema, main)
        agg = random.choice(AGGS)

        q = f"what is the {agg.lower()} of {col} in {main}"
        sql = f"SELECT {agg}({col}) FROM {main}"


    ##################################
    # GROUP BY
    ##################################

    elif intent=="GROUP":

        group_col = pick_col(schema, main)
        agg_col = pick_col(schema, main)
        agg = random.choice(AGGS)

        q = f"{prefix} {agg.lower()} {agg_col} grouped by {group_col}"
        sql = f"SELECT {group_col}, {agg}({agg_col}) FROM {main} GROUP BY {group_col}"


    ##################################
    # ORDER
    ##################################

    elif intent=="ORDER":

        col = pick_col(schema, main)

        q = f"{prefix} {main} ordered by {col}"
        sql = f"SELECT * FROM {main} ORDER BY {col} DESC"


    ##################################
    # LIMIT
    ##################################

    elif intent=="LIMIT":

        col = pick_col(schema, main)
        limit = random.randint(1,20)

        q = f"{prefix} top {limit} rows from {main} by {col}"
        sql = f"SELECT * FROM {main} ORDER BY {col} DESC LIMIT {limit}"


    ##################################
    # JOIN
    ##################################

    elif intent=="JOIN" and "join" in db:

        t1,t2,c1,c2 = db["join"]

        col1 = pick_col(schema, t1)
        col2 = pick_col(schema, t2)

        q = f"{prefix} {t1} with their {t2}"

        sql = f"""
        SELECT {t1}.{col1}, {t2}.{col2}
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        """


    ##################################
    # JOIN WHERE
    ##################################

    elif intent=="JOIN_WHERE" and "join" in db:

        t1,t2,c1,c2 = db["join"]

        col = pick_col(schema, t2)
        val = rand_val()

        q = f"{prefix} {t1} where {t2} {col} is greater than {val}"

        sql = f"""
        SELECT *
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        WHERE {t2}.{col} > {val}
        """


    ##################################
    # NESTED
    ##################################

    else:

        col = pick_col(schema, main)

        q = f"{prefix} {main} where {col} is greater than average"

        sql = f"""
        SELECT *
        FROM {main}
        WHERE {col} >
        (SELECT AVG({col}) FROM {main})
        """

    return {
        "question": q.lower(),
        "schema": schema,
        "sql": " ".join(sql.split())
    }


############################################
# Generate Dataset
############################################

DATA = [generate_example() for _ in range(150000)]

with open("nl2sql_data.json","w") as f:
    json.dump(DATA,f)

print("ðŸ”¥ ULTRA high-quality dataset generated!")


ðŸ”¥ ULTRA high-quality dataset generated!


In [None]:
with open("nl2sql_data.json") as f:
    DATA = json.load(f)


In [None]:
ENC_VOCAB={"<PAD>":0,"<UNK>":1,"<SEP>":2}
DEC_VOCAB={"<PAD>":0,"<UNK>":1,"<BOS>":2,"<EOS>":3}

def add(vocab,t):
    if t not in vocab:
        vocab[t]=len(vocab)


def sql_tokenize(sql):

    return re.findall(
        r"[A-Za-z_]+\.[A-Za-z_]+"     # table.column
        r"|>=|<=|!=|=|>|<"
        r"|\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bon\b"
        r"|\bgroup\b|\bby\b|\bhaving\b|\border\b|\blimit\b"
        r"|\bavg\b|\bsum\b|\bcount\b|\bmax\b|\bmin\b"
        r"|\*"
        r"|\(|\)|,"
        r"|[A-Za-z_]+"
        r"|\d+",
        sql.lower()
    )



for ex in DATA:

    for t in ex["question"].split():
        add(ENC_VOCAB,t)

    for t,cols in ex["schema"].items():
        add(ENC_VOCAB,t)
        for c in cols:
            add(ENC_VOCAB,f"{t}.{c}")

    for tok in sql_tokenize(ex["sql"]):
        add(DEC_VOCAB,tok)

print(len(ENC_VOCAB),len(DEC_VOCAB))


10028 10010


In [4]:
class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data=data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):

        ex=self.data[i]

        src = ex["question"].split()+["<SEP>"]+[
            f"{t}.{c}"
            for t,cs in ex["schema"].items()
            for c in cs
        ]

        src_ids=[ENC_VOCAB.get(t,1) for t in src][:160]
        src_ids+=[0]*(160-len(src_ids))


        tgt=[DEC_VOCAB["<BOS>"]] + \
            [DEC_VOCAB.get(t,1) for t in sql_tokenize(ex["sql"])] + \
            [DEC_VOCAB["<EOS>"]]

        tgt=tgt[:80]
        tgt+=[0]*(80-len(tgt))

        return torch.tensor(src_ids),torch.tensor(tgt)


In [5]:
D_MODEL=384

class PositionalEncoding(nn.Module):
    def __init__(self,d_model,max_len=512):
        super().__init__()

        pe=torch.zeros(max_len,d_model)

        pos=torch.arange(0,max_len).unsqueeze(1)

        div=torch.exp(
            torch.arange(0,d_model,2) *
            (-torch.log(torch.tensor(10000.0))/d_model)
        )

        pe[:,0::2]=torch.sin(pos*div)
        pe[:,1::2]=torch.cos(pos*div)

        self.pe=pe.unsqueeze(0)

    def forward(self,x):
        return x+self.pe[:,:x.size(1)].to(x.device)


In [6]:
class Encoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL,padding_idx=0)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerEncoderLayer(
            D_MODEL,8,1536,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.enc=nn.TransformerEncoder(layer,4)

    def forward(self,x):

        mask=(x==0)

        x=self.pos(self.emb(x))

        return self.enc(x,src_key_padding_mask=mask)



class Decoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerDecoderLayer(
            D_MODEL,8,1536,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.dec=nn.TransformerDecoder(layer,4)

        self.fc=nn.Linear(D_MODEL,vocab)

        self.fc.weight=self.emb.weight

    def forward(self,y,mem,mask):

        L=y.size(1)

        causal=torch.triu(
            torch.ones(L,L,device=y.device),1
        ).bool()

        y=self.pos(self.emb(y))

        return self.fc(
            self.dec(
                y,mem,
                tgt_mask=causal,
                memory_key_padding_mask=mask
            )
        )


In [None]:
train,val=train_test_split(DATA,test_size=0.1,random_state=42)

train_loader=DataLoader(
    NL2SQLDataset(train),
    batch_size=64,
    shuffle=True
)

val_loader=DataLoader(
    NL2SQLDataset(val),
    batch_size=64
)


In [None]:
enc=Encoder(len(ENC_VOCAB)).to(device)
dec=Decoder(len(DEC_VOCAB)).to(device)

optimizer=optim.AdamW(
    list(enc.parameters())+list(dec.parameters()),
    lr=1e-4
)

scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,T_max=15
)

loss_fn=nn.CrossEntropyLoss(
    ignore_index=0,
    label_smoothing=0.05
)

import os

CHECKPOINT_PATH = "checkpoint.pt"

start_epoch = 0

############################################
# âœ… RESUME IF CHECKPOINT EXISTS
############################################

if os.path.exists(CHECKPOINT_PATH):

    print("âœ… Resuming from checkpoint...\n")

    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

    enc.load_state_dict(checkpoint["encoder"])
    dec.load_state_dict(checkpoint["decoder"])

    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])

    start_epoch = checkpoint["epoch"] + 1

    print(f"Resuming from Epoch {start_epoch}\n")


print("ðŸš€ Training...\n")

EPOCHS = 15

for epoch in range(start_epoch, EPOCHS):

    #################################
    # TRAIN
    #################################
    enc.train()
    dec.train()

    train_loss = 0

    for x, y in train_loader:

        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        mem = enc(x)

        out = dec(y[:, :-1], mem, (x == 0))

        loss = loss_fn(
            out.reshape(-1, len(DEC_VOCAB)),
            y[:, 1:].reshape(-1)
        )

        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(enc.parameters()) + list(dec.parameters()),
            1.0
        )

        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    #################################
    # VALIDATION
    #################################
    enc.eval()
    dec.eval()

    val_loss = 0

    with torch.no_grad():

        for x, y in val_loader:

            x = x.to(device)
            y = y.to(device)

            mem = enc(x)

            out = dec(y[:, :-1], mem, (x == 0))

            val_loss += loss_fn(
                out.reshape(-1, len(DEC_VOCAB)),
                y[:, 1:].reshape(-1)
            ).item()

    val_loss /= len(val_loader)

    scheduler.step()

    #################################
    print(f"""
Epoch {epoch+1}

Train Loss: {train_loss:.3f}
Val Loss:   {val_loss:.3f}
""")
    #################################

    ################################################
    # âœ… SAVE CHECKPOINT AFTER EVERY EPOCH
    ################################################

    torch.save({
        "epoch": epoch,
        "encoder": enc.state_dict(),
        "decoder": dec.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "enc_vocab": ENC_VOCAB,
        "dec_vocab": DEC_VOCAB
    }, CHECKPOINT_PATH)

    print("âœ… Checkpoint Saved!\n")

print("ðŸŽ‰ TRAINING COMPLETE")




âœ… Resuming from checkpoint...

Resuming from Epoch 8

ðŸš€ Training...


Epoch 9

Train Loss: 0.750
Val Loss:   0.733

âœ… Checkpoint Saved!


Epoch 10

Train Loss: 0.739
Val Loss:   0.728

âœ… Checkpoint Saved!


Epoch 11

Train Loss: 0.731
Val Loss:   0.728

âœ… Checkpoint Saved!


Epoch 12

Train Loss: 0.725
Val Loss:   0.722

âœ… Checkpoint Saved!


Epoch 13

Train Loss: 0.721
Val Loss:   0.720

âœ… Checkpoint Saved!


Epoch 14

Train Loss: 0.717
Val Loss:   0.720

âœ… Checkpoint Saved!



KeyboardInterrupt: 

In [7]:
import torch

CHECKPOINT_PATH = "checkpoint.pt"

checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Restore vocab
ENC_VOCAB = checkpoint["enc_vocab"]
DEC_VOCAB = checkpoint["dec_vocab"]

inv_dec_vocab = {v:k for k,v in DEC_VOCAB.items()}

# Rebuild models
enc = Encoder(len(ENC_VOCAB)).to(device)
dec = Decoder(len(DEC_VOCAB)).to(device)

enc.load_state_dict(checkpoint["encoder"])
dec.load_state_dict(checkpoint["decoder"])

enc.eval()
dec.eval()

print("âœ… Model restored. Ready for inference.")


âœ… Model restored. Ready for inference.




In [8]:
def infer_sql(question, schema, max_len=80):

    enc.eval()
    dec.eval()

    tokens = question.lower().split() + ["<SEP>"] + [
        f"{t}.{c}" for t,cs in schema.items() for c in cs
    ]

    ids = [ENC_VOCAB.get(t,1) for t in tokens][:120]
    ids += [0]*(120-len(ids))

    x = torch.tensor(ids).unsqueeze(0).to(device)

    with torch.no_grad():
        memory = enc(x)

        y = torch.tensor([[DEC_VOCAB["<BOS>"]]], device=device)

        generated = []

        for _ in range(max_len):

            out = dec(y, memory, (x==0))

            logits = out[:,-1,:]

            idx = torch.argmax(logits, dim=-1)

            token = inv_dec_vocab[idx.item()]

            if token == "<EOS>":
                break

            if token not in ["<PAD>", "<BOS>"]:
                generated.append(token)

            y = torch.cat([y, idx.view(1,1)], dim=1)

    return " ".join(generated)


In [12]:
schema1 = {
    "employees":["id","name","salary","dept_id"],
    "departments":["id","name"]
}



print(infer_sql(
    "list name and salary of employees",
    schema1
))
print(infer_sql(
    "show employees earning more than average salary",
    schema1
))
print(infer_sql(
    "what is count of dept_id of employees grouped by dept_id",
    schema1
))
print(infer_sql(
    "what is max of salary in employees",
    schema1
))
print(infer_sql(
    "show employees in departments where id > 8",
    schema1
))
print(infer_sql("show employees where salary is greater than 5000", schema1))

print(infer_sql("list employees where id is less than 30", schema1))

print(infer_sql("get employees ordered by salary", schema1))

print(infer_sql("show top 3 employees by salary", schema1))




select name , salary from employees
select * from employees where salary > ( select avg ( salary ) from employees )
select dept_id , count ( dept_id ) from employees group by dept_id
select max ( salary ) from employees
select * from employees join departments on employees.dept_id = departments.id where departments.id > 8
select * from employees where salary > 5000
select * from employees where id < 30
select * from employees order by salary desc
select * from employees order by salary desc limit 3
