<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_TO_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
# NL → SQL Transformer (FINAL TOKEN-BASED SYSTEM)
# Perfect SELECT + AGG + WHERE + JOIN + PICARD
# ============================================================

import random, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

# ============================================================
# SCHEMAS
# ============================================================

# Toy database schemas used to synthesise training data.
#   "tables": table name -> list of column names (the first table is the
#             one single-table questions are asked about)
#   "join":   (left table, right table, left column, right column)
SCHEMAS = [
    {
        "tables": {
            "employees": ["id", "name", "salary", "dept_id"],
            "departments": ["id", "name"],
        },
        "join": ("employees", "departments", "dept_id", "id"),
    },
    {
        "tables": {
            "students": ["id", "name", "marks", "class_id"],
            "classes": ["id", "name"],
        },
        "join": ("students", "classes", "class_id", "id"),
    },
]

# Aggregate names in the lower-case form used in questions; the SQL side
# upper-cases them.
AGGS = ["sum", "avg", "count", "max", "min"]

In [2]:
# ============================================================
# DATA GENERATION
# ============================================================

def generate_example():
    """Build one synthetic (question, schema, SQL) example.

    A schema is drawn at random from SCHEMAS, then one of six query
    shapes (SELECT, AGG, WHERE, JOIN, AGG_WHERE, JOIN_WHERE) is drawn
    with fixed weights. Single-table shapes always query the first table
    of the chosen schema.

    Returns:
        dict with keys "question" (natural-language text), "schema" (the
        table -> columns mapping of the chosen database) and "sql" (the
        target query string).
    """
    db = random.choice(SCHEMAS)
    tables = db["tables"]
    main = list(tables)[0]
    columns = tables[main]

    (intent,) = random.choices(
        ["SELECT", "AGG", "WHERE", "JOIN", "AGG_WHERE", "JOIN_WHERE"],
        weights=[0.25, 0.2, 0.2, 0.15, 0.1, 0.1],
    )

    if intent in ("JOIN", "JOIN_WHERE"):
        left, right, left_col, right_col = db["join"]
        question = f"show {left} and {right} names"
        sql = (
            f"SELECT {left}.name, {right}.name FROM {left} "
            f"JOIN {right} ON {left}.{left_col} = {right}.{right_col}"
        )
        if intent == "JOIN_WHERE":
            threshold = random.choice([10, 20, 50])
            question += f" where {left}.{left_col} > {threshold}"
            sql += f" WHERE {left}.{left_col} > {threshold}"

    elif intent in ("AGG", "AGG_WHERE"):
        # Draw the aggregate before the column so the sequence of random
        # draws stays the same for a given seed.
        agg = random.choice(AGGS)
        col = random.choice(columns)
        question = f"show {agg} of {col} from {main}"
        sql = f"SELECT {agg.upper()}({col}) FROM {main}"
        if intent == "AGG_WHERE":
            threshold = random.choice([10, 20, 50])
            question += f" where {col} > {threshold}"
            sql += f" WHERE {col} > {threshold}"

    elif intent == "WHERE":
        col = random.choice(columns)
        threshold = random.choice([10, 20, 50, 100])
        question = f"get {col} from {main} where {col} > {threshold}"
        sql = f"SELECT {col} FROM {main} WHERE {col} > {threshold}"

    else:  # plain SELECT
        col = random.choice(columns)
        question = f"show {col} of {main}"
        sql = f"SELECT {col} FROM {main}"

    return {"question": question, "schema": tables, "sql": sql}

# Synthetic corpus: 60,000 independently generated examples.
DATA = [generate_example() for _ in range(60_000)]

In [3]:
# ============================================================
# VOCAB
# ============================================================

# Special tokens take the lowest ids; <PAD> is 0 and <UNK> is 1 in both
# vocabularies.
ENC_VOCAB = {tok: idx for idx, tok in enumerate(("<PAD>", "<UNK>", "<SEP>"))}
DEC_VOCAB = {tok: idx for idx, tok in enumerate(("<PAD>", "<UNK>", "<BOS>", "<EOS>"))}

def add(v,t):
    """Give token ``t`` the next free id in vocabulary ``v`` if it is new.

    Tokens already present keep their id. The dict is modified in place
    and nothing is returned.
    """
    v.setdefault(t, len(v))

# Grow both vocabularies from the generated corpus in first-seen order.
for ex in DATA:
    # Encoder side: question words, bare table names and "table.column" tokens.
    for word in ex["question"].lower().split():
        add(ENC_VOCAB, word)
    for table, columns in ex["schema"].items():
        add(ENC_VOCAB, table)
        for column in columns:
            add(ENC_VOCAB, f"{table}.{column}")
    # Decoder side: whitespace-separated, lower-cased SQL tokens.
    for sql_token in ex["sql"].lower().split():
        add(DEC_VOCAB, sql_token)

# ============================================================
# DATASET
# ============================================================

class NL2SQLDataset(Dataset):
    """Turns generated examples into fixed-length id tensors.

    Each item is a pair ``(src, tgt)``:
      * ``src`` - question words, ``<SEP>``, then the schema's
        ``table.column`` tokens in a freshly shuffled order, mapped through
        ENC_VOCAB and truncated / zero-padded to 80 ids.
      * ``tgt`` - ``<BOS>``, the lower-cased SQL tokens, ``<EOS>``, mapped
        through DEC_VOCAB and truncated / zero-padded to 70 ids.
    Unknown tokens map to id 1 (``<UNK>``).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]

        # Schema columns are reshuffled on every access.
        columns = []
        for table, cols in example["schema"].items():
            columns.extend(f"{table}.{col}" for col in cols)
        random.shuffle(columns)

        source = example["question"].lower().split()
        source.append("<SEP>")
        source.extend(columns)
        src_ids = [ENC_VOCAB.get(tok, 1) for tok in source[:80]]
        src_ids.extend([0] * (80 - len(src_ids)))

        sql_ids = [DEC_VOCAB.get(tok, 1) for tok in example["sql"].lower().split()]
        tgt = [DEC_VOCAB["<BOS>"], *sql_ids, DEC_VOCAB["<EOS>"]][:70]
        tgt.extend([0] * (70 - len(tgt)))

        return torch.tensor(src_ids), torch.tensor(tgt)

# Hold out 10% of the examples for validation; only the training loader
# shuffles.
train, val = train_test_split(DATA, test_size=0.1)
train_loader = DataLoader(NL2SQLDataset(train), batch_size=128, shuffle=True)
val_loader = DataLoader(NL2SQLDataset(val), batch_size=128, shuffle=False)

In [4]:
# ============================================================
# MODEL
# ============================================================

class Encoder(nn.Module):
    """Token embedding followed by a 4-layer Transformer encoder.

    Model width is 256 with 8 attention heads and a 1024-wide feed-forward
    layer. Token id 0 is treated as padding, both in the embedding and as
    the key-padding mask.

    NOTE(review): no positional encoding is added to the embeddings, so the
    encoder receives no explicit token-order signal.
    """

    def __init__(self, v):
        super().__init__()
        self.emb = nn.Embedding(v, 256, padding_idx=0)
        block = nn.TransformerEncoderLayer(
            d_model=256, nhead=8, dim_feedforward=1024, batch_first=True
        )
        self.enc = nn.TransformerEncoder(block, num_layers=4)

    def forward(self, x):
        """Encode a batch-first tensor of token ids ``x``."""
        pad_positions = x == 0
        embedded = self.emb(x)
        return self.enc(embedded, src_key_padding_mask=pad_positions)

class Decoder(nn.Module):
    """Token embedding, 4-layer Transformer decoder and a vocabulary projection.

    Uses the same dimensions as the encoder side: width 256, 8 heads,
    1024-wide feed-forward layer.
    """

    def __init__(self, v):
        super().__init__()
        self.emb = nn.Embedding(v, 256)
        block = nn.TransformerDecoderLayer(
            d_model=256, nhead=8, dim_feedforward=1024, batch_first=True
        )
        self.dec = nn.TransformerDecoder(block, num_layers=4)
        self.fc = nn.Linear(256, v)

    def forward(self, y, mem, mask):
        """Return per-position vocabulary logits.

        Args:
            y: batch-first tensor of target token ids generated so far.
            mem: encoder output to attend over.
            mask: boolean key-padding mask for ``mem`` (True = ignore).
        """
        steps = y.size(1)
        positions = torch.arange(steps, device=y.device)
        # True where the key position lies after the query position,
        # i.e. the strict upper triangle: no attending to future tokens.
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        hidden = self.dec(
            self.emb(y), mem, tgt_mask=future, memory_key_padding_mask=mask
        )
        return self.fc(hidden)

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the encoder first, then the decoder, each sized to its own vocabulary.
enc = Encoder(len(ENC_VOCAB)).to(device)
dec = Decoder(len(DEC_VOCAB)).to(device)

# A single optimiser updates both modules.
opt = optim.AdamW([*enc.parameters(), *dec.parameters()], lr=3e-4)
# Padding targets (id 0) do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)


In [5]:
# ============================================================
# TRAIN
# ============================================================

# The emoji in the two status messages below had been mis-decoded
# (UTF-8 bytes read as cp1252); they are restored here.
print("🚀 Training")

for epoch in range(5):
    enc.train()
    dec.train()
    tot = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        mem = enc(x)
        # Teacher forcing: the decoder sees every target token except the
        # last and is scored against the same sequence shifted left by one.
        out = dec(y[:, :-1], mem, (x == 0))
        loss = loss_fn(out.reshape(-1, len(DEC_VOCAB)), y[:, 1:].reshape(-1))
        loss.backward()
        opt.step()
        tot += loss.item()
    # Mean per-batch training loss for the epoch.
    print(f"Epoch {epoch+1} | Loss {tot/len(train_loader):.4f}")

# ============================================================
# SAVE MODEL
# ============================================================

# The checkpoint holds both state dicts plus the vocabularies needed to
# map tokens to ids and back.
torch.save({
    "enc": enc.state_dict(),
    "dec": dec.state_dict(),
    "ENC_VOCAB": ENC_VOCAB,
    "DEC_VOCAB": DEC_VOCAB,
}, "nl2sql_token_final.pt")

print("✅ Model saved: nl2sql_token_final.pt")


🚀 Training
Epoch 1 | Loss 0.8214
Epoch 2 | Loss 0.7302
Epoch 3 | Loss 0.7293
Epoch 4 | Loss 0.7289
Epoch 5 | Loss 0.7287
✅ Model saved: nl2sql_token_final.pt


In [12]:
# ============================================================
# INFERENCE WITH PICARD
# ============================================================

inv_dec_vocab={v:k for k,v in DEC_VOCAB.items()}

def valid_sql_step(prev, tok):
    """Decide whether ``tok`` may follow the tokens generated so far.

    Comparison is case-insensitive. The rules are:
      * the very first token must be ``select``;
      * ``from`` needs an earlier ``select``;
      * ``join`` and ``where`` need an earlier ``from``;
      * ``on`` needs an earlier ``join``;
      * any other token is accepted.

    Args:
        prev: tokens emitted so far.
        tok: candidate next token.

    Returns:
        True if the candidate is allowed, False otherwise.
    """
    emitted = {word.lower() for word in prev}
    candidate = tok.lower()

    if not emitted:
        return candidate == "select"

    # Clause keyword -> keyword that must already have been emitted.
    prerequisite = {
        "from": "select",
        "join": "from",
        "where": "from",
        "on": "join",
    }
    required = prerequisite.get(candidate)
    return required is None or required in emitted

def infer_sql(question,schema,max_len=70):
    """Greedily decode a SQL string for ``question`` over ``schema``.

    The encoder input is the lower-cased question words, ``<SEP>`` and one
    ``table.column`` token per schema column, truncated / zero-padded to 80
    ids. At each step the highest-scoring token that is neither ``<PAD>``
    nor ``<BOS>`` and that passes ``valid_sql_step`` is emitted.

    Decoding runs with both modules temporarily switched to eval mode and
    with autograd disabled, so dropout does not randomise the output and no
    graph is built. The modules' previous train/eval state is restored
    before returning.

    Args:
        question: natural-language question.
        schema: mapping of table name -> list of column names.
        max_len: maximum number of tokens to generate.

    Returns:
        The generated tokens joined by single spaces (without ``<EOS>``).
    """
    tokens = question.lower().split() + ["<SEP>"] + [
        f"{t}.{c}" for t, cs in schema.items() for c in cs
    ]
    ids = [ENC_VOCAB.get(t, 1) for t in tokens][:80]
    ids += [0] * (80 - len(ids))

    x = torch.tensor(ids).unsqueeze(0).to(device)
    pad_mask = x == 0  # loop-invariant, so computed once
    never_emit = {"<PAD>", "<BOS>"}
    gen = []

    enc_was_training, dec_was_training = enc.training, dec.training
    enc.eval()
    dec.eval()
    try:
        with torch.no_grad():
            mem = enc(x)
            y = torch.tensor([[DEC_VOCAB["<BOS>"]]], device=device)

            for _ in range(max_len):
                # Rank on raw logits; softmax is monotonic, so the order
                # is the same as ranking on probabilities.
                logits = dec(y, mem, pad_mask)[0, -1]
                for idx in torch.argsort(logits, descending=True):
                    tok = inv_dec_vocab[idx.item()]
                    if tok not in never_emit and valid_sql_step(gen, tok):
                        break
                else:
                    # No candidate is allowed here; stop rather than emit
                    # a token that was just rejected.
                    break
                if tok == "<EOS>":
                    break
                gen.append(tok)
                y = torch.cat([y, idx.view(1, 1)], dim=1)
    finally:
        # Leave the global modules in the mode the caller had them in.
        enc.train(enc_was_training)
        dec.train(dec_was_training)

    return " ".join(gen)

# ============================================================
# TEST
# ============================================================

# Smoke test on the employees/departments schema: one aggregate, one
# filtered select and one join question.
schema1 = {
    "employees": ["id", "name", "salary", "dept_id"],
    "departments": ["id", "name"],
}

for _question in (
    "min salary from employees",
    "get salary from employees where salary > 20",
    "show employees and departments names",
):
    print(infer_sql(_question, schema1))

select min(salary) from employees
select salary from employees where salary > 20
select employees.name, departments.name from employees join departments on employees.dept_id = departments.id
