<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_2_SQL_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random, json, re, torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

SEED = 42

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [2]:
SCHEMAS = [

    {
        "tables":{
            "employees":["id","name","salary","dept_id"],
            "departments":["id","name"]
        },
        "join":("employees","departments","dept_id","id")
    },

    {
        "tables":{
            "students":["id","name","marks","class_id"],
            "classes":["id","name"]
        },
        "join":("students","classes","class_id","id")
    }
]

AGGS = ["SUM","AVG","COUNT","MAX","MIN"]

PREFIX = [
    "show","list","display","fetch","give","find","retrieve"
]

COMPARE = [
    (">","greater than"),
    ("<","less than")
]


def generate_example():

    db = random.choice(SCHEMAS)
    schema = db["tables"]

    main = list(schema.keys())[0]
    cols = schema[main]

    intent = random.choice([
        "SELECT","MULTI","WHERE","OR",
        "AGG","GROUP","HAVING",
        "ORDER","LIMIT",
        "JOIN","JOIN_WHERE",
        "NESTED","DOUBLE_NESTED"
    ])

    prefix = random.choice(PREFIX)

    col = random.choice(cols)

    # ---------------- SELECT ----------------
    if intent=="SELECT":

        q = f"{prefix} {col} from {main}"
        sql = f"SELECT {col} FROM {main}"


    # ---------------- MULTI ----------------
    elif intent=="MULTI":

        c1,c2 = random.sample(cols,2)

        q = f"{prefix} {c1} and {c2} from {main}"
        sql = f"SELECT {c1}, {c2} FROM {main}"


    # ---------------- WHERE ----------------
    elif intent=="WHERE":

        op,text = random.choice(COMPARE)
        val = random.choice([10,50,100,500])

        q = f"{prefix} employees where {col} is {text} {val}"
        sql = f"SELECT name FROM {main} WHERE {col} {op} {val}"


    # ---------------- OR ----------------
    elif intent=="OR":

        c1,c2 = random.sample(cols,2)
        v1,v2 = random.choice([10,50]), random.choice([100,200])

        q = f"{prefix} employees where {c1} > {v1} or {c2} > {v2}"
        sql = f"SELECT name FROM {main} WHERE {c1} > {v1} OR {c2} > {v2}"


    # ---------------- AGG ----------------
    elif intent=="AGG":

        agg = random.choice(AGGS)

        q = f"what is the {agg.lower()} of {col} in {main}"
        sql = f"SELECT {agg}({col}) FROM {main}"


    # ---------------- GROUP ----------------
    elif intent=="GROUP":

        group = random.choice(cols)
        agg = random.choice(AGGS)

        q = f"{prefix} {agg.lower()} salary grouped by {group}"
        sql = f"SELECT {group}, {agg}(salary) FROM {main} GROUP BY {group}"


    # ---------------- HAVING ----------------
    elif intent=="HAVING":

        group = random.choice(cols)

        q = f"{prefix} departments having more than 5 employees"

        sql = f"""
        SELECT dept_id, COUNT(*)
        FROM employees
        GROUP BY dept_id
        HAVING COUNT(*) > 5
        """


    # ---------------- ORDER ----------------
    elif intent=="ORDER":

        q = f"{prefix} employees ordered by {col}"
        sql = f"SELECT name FROM {main} ORDER BY {col} DESC"


    # ---------------- LIMIT ----------------
    elif intent=="LIMIT":

        q = f"{prefix} top 5 employees by {col}"
        sql = f"SELECT name FROM {main} ORDER BY {col} DESC LIMIT 5"


    # ---------------- JOIN ----------------
    elif intent=="JOIN":

        t1,t2,c1,c2 = db["join"]

        q = f"{prefix} {t1} with their {t2}"

        sql = f"""
        SELECT {t1}.name, {t2}.name
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        """


    # ---------------- JOIN WHERE ----------------
    elif intent=="JOIN_WHERE":

        t1,t2,c1,c2 = db["join"]

        q = f"{prefix} employees in departments where id > 3"

        sql = f"""
        SELECT {t1}.name, {t2}.name
        FROM {t1}
        JOIN {t2}
        ON {t1}.{c1} = {t2}.{c2}
        WHERE {t2}.id > 3
        """


    # ---------------- NESTED ----------------
    elif intent=="NESTED":

        q = f"{prefix} employees earning more than average salary"

        sql = f"""
        SELECT name
        FROM {main}
        WHERE salary >
        (SELECT AVG(salary) FROM {main})
        """


    # ---------------- DOUBLE NESTED ----------------
    else:

        q = f"{prefix} employees earning more than average salary in dept 2"

        sql = f"""
        SELECT name
        FROM employees
        WHERE salary >
        (
            SELECT AVG(salary)
            FROM employees
            WHERE dept_id = 2
        )
        """

    return {
        "question": q.lower(),
        "schema": schema,
        "sql": " ".join(sql.split())
    }


DATA = [generate_example() for _ in range(150000)]

with open("nl2sql_data.json","w") as f:
    json.dump(DATA,f,indent=2)

print("âœ… Massive high-quality dataset generated")


âœ… Massive high-quality dataset generated


In [3]:
ENC_VOCAB={"<PAD>":0,"<UNK>":1,"<SEP>":2}
DEC_VOCAB={"<PAD>":0,"<UNK>":1,"<BOS>":2,"<EOS>":3}

def add(vocab,t):
    if t not in vocab:
        vocab[t]=len(vocab)


def sql_tokenize(sql):
    return re.findall(
        r"[A-Za-z_]+\.[A-Za-z_]+|\w+|>=|<=|!=|=|>|<|\(|\)|,",
        sql.lower()
    )


for ex in DATA:

    for t in ex["question"].split():
        add(ENC_VOCAB,t)

    for t,cols in ex["schema"].items():
        add(ENC_VOCAB,t)
        for c in cols:
            add(ENC_VOCAB,f"{t}.{c}")

    for tok in sql_tokenize(ex["sql"]):
        add(DEC_VOCAB,tok)

print(len(ENC_VOCAB),len(DEC_VOCAB))


69 53


In [4]:
class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data=data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):

        ex=self.data[i]

        src = ex["question"].split()+["<SEP>"]+[
            f"{t}.{c}"
            for t,cs in ex["schema"].items()
            for c in cs
        ]

        src_ids=[ENC_VOCAB.get(t,1) for t in src][:120]
        src_ids+=[0]*(120-len(src_ids))

        tgt=[DEC_VOCAB["<BOS>"]] + \
            [DEC_VOCAB.get(t,1) for t in sql_tokenize(ex["sql"])] + \
            [DEC_VOCAB["<EOS>"]]

        tgt=tgt[:80]
        tgt+=[0]*(80-len(tgt))

        return torch.tensor(src_ids),torch.tensor(tgt)


In [5]:
D_MODEL=384

class PositionalEncoding(nn.Module):
    def __init__(self,d_model,max_len=512):
        super().__init__()

        pe=torch.zeros(max_len,d_model)

        pos=torch.arange(0,max_len).unsqueeze(1)

        div=torch.exp(
            torch.arange(0,d_model,2) *
            (-torch.log(torch.tensor(10000.0))/d_model)
        )

        pe[:,0::2]=torch.sin(pos*div)
        pe[:,1::2]=torch.cos(pos*div)

        self.pe=pe.unsqueeze(0)

    def forward(self,x):
        return x+self.pe[:,:x.size(1)].to(x.device)


In [6]:
class Encoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL,padding_idx=0)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerEncoderLayer(
            D_MODEL,8,1536,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.enc=nn.TransformerEncoder(layer,4)

    def forward(self,x):

        mask=(x==0)

        x=self.pos(self.emb(x))

        return self.enc(x,src_key_padding_mask=mask)



class Decoder(nn.Module):

    def __init__(self,vocab):
        super().__init__()

        self.emb=nn.Embedding(vocab,D_MODEL)
        self.pos=PositionalEncoding(D_MODEL)

        layer=nn.TransformerDecoderLayer(
            D_MODEL,8,1536,
            dropout=0.1,
            batch_first=True,
            norm_first=True
        )

        self.dec=nn.TransformerDecoder(layer,4)

        self.fc=nn.Linear(D_MODEL,vocab)

        self.fc.weight=self.emb.weight

    def forward(self,y,mem,mask):

        L=y.size(1)

        causal=torch.triu(
            torch.ones(L,L,device=y.device),1
        ).bool()

        y=self.pos(self.emb(y))

        return self.fc(
            self.dec(
                y,mem,
                tgt_mask=causal,
                memory_key_padding_mask=mask
            )
        )


In [7]:
train,val=train_test_split(DATA,test_size=0.1,random_state=42)

train_loader=DataLoader(
    NL2SQLDataset(train),
    batch_size=64,
    shuffle=True
)

val_loader=DataLoader(
    NL2SQLDataset(val),
    batch_size=64
)


In [8]:
enc=Encoder(len(ENC_VOCAB)).to(device)
dec=Decoder(len(DEC_VOCAB)).to(device)

optimizer=optim.AdamW(
    list(enc.parameters())+list(dec.parameters()),
    lr=1e-4
)

scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,T_max=15
)

loss_fn=nn.CrossEntropyLoss(
    ignore_index=0,
    label_smoothing=0.05
)

print("ðŸš€ Training...\n")

for epoch in range(15):

    enc.train(); dec.train()
    train_loss=0

    for x,y in train_loader:

        x,y=x.to(device),y.to(device)

        optimizer.zero_grad()

        mem=enc(x)

        out=dec(y[:,:-1],mem,(x==0))

        loss=loss_fn(
            out.reshape(-1,len(DEC_VOCAB)),
            y[:,1:].reshape(-1)
        )

        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(enc.parameters())+list(dec.parameters()),
            1.0
        )

        optimizer.step()

        train_loss+=loss.item()


    enc.eval(); dec.eval()
    val_loss=0

    with torch.no_grad():
        for x,y in val_loader:

            x,y=x.to(device),y.to(device)

            mem=enc(x)
            out=dec(y[:,:-1],mem,(x==0))

            val_loss+=loss_fn(
                out.reshape(-1,len(DEC_VOCAB)),
                y[:,1:].reshape(-1)
            ).item()

    scheduler.step()

    print(f"""
Epoch {epoch+1}

Train Loss: {train_loss/len(train_loader):.3f}
Val Loss:   {val_loss/len(val_loader):.3f}
""")




ðŸš€ Training...


Epoch 1

Train Loss: 3.279
Val Loss:   0.446


Epoch 2

Train Loss: 0.449
Val Loss:   0.400


Epoch 3

Train Loss: 0.410
Val Loss:   0.395


Epoch 4

Train Loss: 0.400
Val Loss:   0.394


Epoch 5

Train Loss: 0.396
Val Loss:   0.392


Epoch 6

Train Loss: 0.394
Val Loss:   0.391


Epoch 7

Train Loss: 0.393
Val Loss:   0.391


Epoch 8

Train Loss: 0.392
Val Loss:   0.391


Epoch 9

Train Loss: 0.391
Val Loss:   0.390


Epoch 10

Train Loss: 0.391
Val Loss:   0.390


Epoch 11

Train Loss: 0.391
Val Loss:   0.390


Epoch 12

Train Loss: 0.391
Val Loss:   0.390


Epoch 13

Train Loss: 0.390
Val Loss:   0.390


Epoch 14

Train Loss: 0.390
Val Loss:   0.390


Epoch 15

Train Loss: 0.390
Val Loss:   0.390



In [9]:
torch.save({
    "encoder":enc.state_dict(),
    "decoder":dec.state_dict(),
    "enc_vocab":ENC_VOCAB,
    "dec_vocab":DEC_VOCAB
},"nl2sql_model.pt")

print("âœ… MODEL SAVED")


âœ… MODEL SAVED


In [15]:
checkpoint = torch.load("nl2sql_model.pt", map_location=device)

ENC_VOCAB = checkpoint["enc_vocab"]
DEC_VOCAB = checkpoint["dec_vocab"]

inv_dec_vocab = {v:k for k,v in DEC_VOCAB.items()}

enc = Encoder(len(ENC_VOCAB)).to(device)
dec = Decoder(len(DEC_VOCAB)).to(device)

enc.load_state_dict(checkpoint["encoder"])
dec.load_state_dict(checkpoint["decoder"])

enc.eval()
dec.eval()

print("âœ… Model Loaded")


âœ… Model Loaded


In [16]:
def top_k_sampling(logits, k=5, temperature=0.7):

    logits = logits / temperature

    vals, indices = torch.topk(logits, k)

    probs = torch.softmax(vals, dim=-1)

    sampled = torch.multinomial(probs, 1)

    return indices[0, sampled.item()]


In [17]:
def infer_sql(question, schema, max_len=80):

    enc.eval()
    dec.eval()

    tokens = question.lower().split() + ["<SEP>"] + [
        f"{t}.{c}" for t,cs in schema.items() for c in cs
    ]

    ids = [ENC_VOCAB.get(t,1) for t in tokens][:120]
    ids += [0]*(120-len(ids))

    x = torch.tensor(ids).unsqueeze(0).to(device)

    with torch.no_grad():
        mem = enc(x)

    y = torch.tensor([[DEC_VOCAB["<BOS>"]]], device=device)

    generated=[]

    for _ in range(max_len):

        with torch.no_grad():

            out = dec(y, mem, (x==0))

            logits = out[:,-1,:]

        idx = top_k_sampling(logits, k=5)

        token = inv_dec_vocab[idx.item()]

        if token == "<EOS>":
            break

        if token not in ["<PAD>","<BOS>"]:
            generated.append(token)

        y = torch.cat([y, idx.view(1,1)], dim=1)

    return " ".join(generated)


In [20]:
schema1 = {
    "employees":["id","name","salary","dept_id"],
    "departments":["id","name"]
}

print(infer_sql(
    "show salary of employees earning more than average salary",
    schema1
))

print(infer_sql(
    "list name and salary of employees",
    schema1
))

print(infer_sql(
    "show employees and their departments",
    schema1
))


select name from employees where salary > ( select avg ( salary ) from employees )
select name , salary from employees
select employees.name , departments.name from employees
