<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/Welcome_To_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
!pip install sympy --upgrade

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import random
import string
import json
import re
import math
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [12]:
def random_token(min_len=4, max_len=10):
    return ''.join(random.choices(string.ascii_lowercase,
                                  k=random.randint(min_len, max_len)))

def generate_schema():

    schema = {}

    num_tables = random.randint(1,5)

    for _ in range(num_tables):

        table = random_token()

        cols = {random_token() for _ in range(random.randint(3,7))}
        cols.add("id")  # anchor column

        schema[table] = list(cols)

    return schema


In [13]:
def serialize_schema(schema):

    parts = []

    for t, cols in schema.items():

        random.shuffle(cols)

        parts.append(
            f"{t} : " + " , ".join(cols)
        )

    return " <SCHEMA> " + " | ".join(parts) + " </SCHEMA> "


In [14]:
AGGS = ["SUM","AVG","COUNT","MAX","MIN"]
OPS = [">","<",">=","<=","!="]
JOIN_TYPES = ["JOIN","LEFT JOIN","RIGHT JOIN"]
SORT = ["ASC","DESC"]

def generate_example():

    schema = generate_schema()

    tables = list(schema.keys())
    main = random.choice(tables)
    cols = schema[main]

    intent = random.choice([
        "SELECT","WHERE","HAVING",
        "ORDER","LIMIT","JOIN","NESTED"
    ])

    ################ SELECT ################

    if intent == "SELECT":

        chosen = random.sample(cols, random.randint(1,min(3,len(cols))))

        q = f"get {', '.join(chosen)} from {main}"
        sql = f"SELECT {', '.join(chosen)} FROM {main}"

    ################ WHERE ################

    elif intent == "WHERE":

        c1, c2 = random.sample(cols,2)
        op = random.choice(OPS)
        val = random.randint(1,1000)

        q = f"find {c1} from {main} where {c2} {op} {val}"

        sql = f"""
        SELECT {c1}
        FROM {main}
        WHERE {c2} {op} {val}
        """

    ################ HAVING ################

    elif intent == "HAVING":

        group = random.choice(cols)
        agg_col = random.choice(cols)

        agg = random.choice(AGGS)
        op = random.choice(OPS)
        val = random.randint(1,500)

        q = f"group {main} by {group} having {agg.lower()} {agg_col} {op} {val}"

        sql = f"""
        SELECT {group}, {agg}({agg_col})
        FROM {main}
        GROUP BY {group}
        HAVING {agg}({agg_col}) {op} {val}
        """

    ################ ORDER ################

    elif intent == "ORDER":

        col = random.choice(cols)
        direction = random.choice(SORT)

        q = f"order {main} by {col} {direction.lower()}"

        sql = f"""
        SELECT {col}
        FROM {main}
        ORDER BY {col} {direction}
        """

    ################ LIMIT ################

    elif intent == "LIMIT":

        col = random.choice(cols)
        limit = random.randint(1,20)

        q = f"top {limit} rows of {col} from {main}"

        sql = f"""
        SELECT {col}
        FROM {main}
        LIMIT {limit}
        """

    ################ JOIN ################

    elif intent == "JOIN" and len(tables) > 1:

        t2 = random.choice([t for t in tables if t!=main])

        c1 = random.choice(schema[main])
        c2 = random.choice(schema[t2])

        join = random.choice(JOIN_TYPES)

        q = f"join {main} with {t2}"

        sql = f"""
        SELECT {main}.{c1}, {t2}.{c2}
        FROM {main}
        {join} {t2}
        ON {main}.{c1} = {t2}.{c2}
        """

    ################ NESTED ################

    else:

        col = random.choice(cols)
        agg = random.choice(AGGS)

        q = f"find {col} from {main} greater than average"

        sql = f"""
        SELECT {col}
        FROM {main}
        WHERE {col} >
        (SELECT {agg}({col}) FROM {main})
        """

    full_question = q + serialize_schema(schema)

    return {
        "question": full_question.lower(),
        "sql": " ".join(sql.split())
    }


In [15]:
DATA = [generate_example() for _ in range(60000)]

with open("nl2sql.json","w") as f:
    json.dump(DATA,f)

print("Dataset Ready üöÄ")


Dataset Ready üöÄ


In [16]:
def sql_tokenize(sql):

    return re.findall(
        r"[A-Za-z_]+\.[A-Za-z_]+"
        r"|>=|<=|!=|=|>|<"
        r"|\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bon\b"
        r"|\bgroup\b|\bby\b|\bhaving\b|\border\b|\blimit\b"
        r"|\bavg\b|\bsum\b|\bcount\b|\bmax\b|\bmin\b"
        r"|\(|\)|,"
        r"|[A-Za-z_]+"
        r"|\d+",
        sql.lower()
    )


ENC_VOCAB={"<PAD>":0,"<UNK>":1}
DEC_VOCAB={"<PAD>":0,"<UNK>":1,"<BOS>":2,"<EOS>":3}

def add(vocab,t):
    if t not in vocab:
        vocab[t]=len(vocab)

for ex in DATA:

    for t in ex["question"].split():
        add(ENC_VOCAB,t)

    for t in sql_tokenize(ex["sql"]):
        add(DEC_VOCAB,t)

print(len(ENC_VOCAB), len(DEC_VOCAB))


1066647 144923


In [17]:
SRC_LEN = 96
TGT_LEN = 48

class NL2SQLDataset(Dataset):

    def __init__(self,data):
        self.data=data

    def __len__(self):
        return len(self.data)

    def __getitem__(self,i):

        ex=self.data[i]

        src=[ENC_VOCAB.get(t,1) for t in ex["question"].split()][:SRC_LEN]
        src+=[0]*(SRC_LEN-len(src))

        tgt=[DEC_VOCAB["<BOS>"]] + \
            [DEC_VOCAB.get(t,1) for t in sql_tokenize(ex["sql"])] + \
            [DEC_VOCAB["<EOS>"]]

        tgt=tgt[:TGT_LEN]
        tgt+=[0]*(TGT_LEN-len(tgt))

        return torch.tensor(src),torch.tensor(tgt)


In [18]:
############################################
# MODEL CONFIG  (VERY IMPORTANT ‚Äî DO NOT INCREASE)
############################################

D_MODEL = 256        # Sweet spot (fast + stable)
N_HEADS = 4
NUM_LAYERS = 2
FF_DIM = 512
DROPOUT = 0.1


############################################
# POSITIONAL ENCODING
############################################

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model)

        position = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, d_model, 2) *
            (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):

        return x + self.pe[:, :x.size(1)]


############################################
# ENCODER
############################################

class Encoder(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()

        self.embedding = nn.Embedding(
            vocab_size,
            D_MODEL,
            padding_idx=0
        )

        self.pos = PositionalEncoding(D_MODEL)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=FF_DIM,
            dropout=DROPOUT,
            batch_first=True,
            norm_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=NUM_LAYERS
        )

    def forward(self, x):

        padding_mask = (x == 0)

        x = self.embedding(x)
        x = self.pos(x)

        return self.encoder(
            x,
            src_key_padding_mask=padding_mask
        )


############################################
# DECODER
############################################

class Decoder(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()

        self.embedding = nn.Embedding(
            vocab_size,
            D_MODEL
        )

        self.pos = PositionalEncoding(D_MODEL)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=D_MODEL,
            nhead=N_HEADS,
            dim_feedforward=FF_DIM,
            dropout=DROPOUT,
            batch_first=True,
            norm_first=True
        )

        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=NUM_LAYERS
        )

        self.fc = nn.Linear(D_MODEL, vocab_size)

        # ‚≠ê weight tying (VERY IMPORTANT for NL‚ÜíSQL)
        self.fc.weight = self.embedding.weight


    def forward(self, y, memory, src_padding_mask):

        L = y.size(1)

        # causal mask prevents cheating
        causal_mask = torch.triu(
            torch.ones(L, L, device=y.device),
            diagonal=1
        ).bool()

        y = self.embedding(y)
        y = self.pos(y)

        output = self.decoder(
            y,
            memory,
            tgt_mask=causal_mask,
            memory_key_padding_mask=src_padding_mask
        )

        return self.fc(output)


In [19]:
train_data, val_data = train_test_split(DATA, test_size=0.1)

train_loader = DataLoader(
    NL2SQLDataset(train_data),
    batch_size=16,
    shuffle=True
)

val_loader = DataLoader(
    NL2SQLDataset(val_data),
    batch_size=16
)


In [20]:
enc = Encoder(len(ENC_VOCAB)).to(device)
dec = Decoder(len(DEC_VOCAB)).to(device)

optimizer = optim.AdamW(
    list(enc.parameters()) + list(dec.parameters()),
    lr=1e-4
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,T_max=12
)

loss_fn = nn.CrossEntropyLoss(ignore_index=0,label_smoothing=0.05)

scaler = torch.amp.GradScaler("cuda")




In [None]:
CHECKPOINT_PATH = "nl2sql_checkpoint.pt"
start_epoch = 0

if os.path.exists(CHECKPOINT_PATH):

    checkpoint = torch.load(CHECKPOINT_PATH,map_location=device)

    enc.load_state_dict(checkpoint["encoder"])
    dec.load_state_dict(checkpoint["decoder"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    start_epoch = checkpoint["epoch"]+1

EPOCHS = 12

for epoch in range(start_epoch,EPOCHS):

    enc.train()
    dec.train()

    total=0

    for x,y in train_loader:

        x,y=x.to(device),y.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda"):

            mem=enc(x)
            out=dec(y[:,:-1],mem,(x==0))

            loss=loss_fn(
                out.reshape(-1,len(DEC_VOCAB)),
                y[:,1:].reshape(-1)
            )

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total+=loss.item()

    print(f"\nEpoch {epoch+1} Train Loss:",total/len(train_loader))

    torch.save({
        "epoch":epoch,
        "encoder":enc.state_dict(),
        "decoder":dec.state_dict(),
        "optimizer":optimizer.state_dict(),
        "enc_vocab":ENC_VOCAB,
        "dec_vocab":DEC_VOCAB
    },CHECKPOINT_PATH)

    torch.cuda.empty_cache()



Epoch 1 Train Loss: 18.812972576282643
