<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_TO_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random, json

# Toy databases sampled by generate_example(). Each entry holds:
#   "tables": table name -> ordered list of column names
#   "join":   (left table, right table, left join column, right join column)
SCHEMAS = [
    dict(
        tables={
            "employees": ["id", "name", "salary", "dept_id"],
            "departments": ["id", "name"],
        },
        join=("employees", "departments", "dept_id", "id"),
    ),
    dict(
        tables={
            "students": ["id", "name", "marks", "class_id"],
            "classes": ["id", "name"],
        },
        join=("students", "classes", "class_id", "id"),
    ),
]

# Request verbs used to open most generated questions.
PREFIX = [
    "show", "list", "display", "fetch", "give",
    "find", "retrieve", "get", "tell me",
]

# SQL aggregate function names.
AGGS = ["SUM", "AVG", "COUNT", "MAX", "MIN"]

# (SQL comparison operator, English phrasing) pairs.
COMPARE = list(
    {
        ">": "greater than",
        "<": "less than",
        ">=": "at least",
        "<=": "at most",
    }.items()
)


# Natural-language wording per main table:
#   (verb for "has a high measure", phrase linking the main table to the
#    joined table, singular noun for the joined table).
# Tables not listed here fall back to generic wording built from column names.
_WORDING = {
    "employees": ("earning", "working in", "department"),
    "students": ("scoring", "enrolled in", "class"),
}


def generate_example():
    """Generate one synthetic (question, schema, SQL) example.

    A schema is drawn at random from ``SCHEMAS`` and a query intent is drawn
    at random from eleven templates (plain select, multi-column select,
    filter, aggregate, order, group, three join flavours, nested subquery,
    correlated subquery). Every template is rendered against the *chosen*
    schema, so the SQL only references tables and columns that appear in the
    returned ``schema``.

    Returns:
        dict: ``{"question": str, "schema": dict, "sql": str}`` where the
        question is lower-cased, ``schema`` is the chosen schema's
        table -> columns mapping, and the SQL has all whitespace runs
        collapsed to single spaces.
    """
    db = random.choice(SCHEMAS)
    schema = db["tables"]

    # Join spec: left table, right table, left join column, right join column.
    t1, t2, c1, c2 = db["join"]

    main = t1
    cols = schema[main]

    # Measure column of the main table: the first column that is neither
    # "id", "name" nor the join column ("salary" / "marks" for the bundled
    # schemas). Falls back to the join column if no such column exists.
    metric = next((c for c in cols if c not in ("id", "name", c1)), c1)

    # Human-readable label for the join column ("dept_id" -> "dept").
    fk_label = c1[:-3] if c1.endswith("_id") else c1

    verb, membership, group_noun = _WORDING.get(
        main, (f"with {metric}", "belonging to", fk_label)
    )

    prefix = random.choice(PREFIX)

    intent = random.choice([
        "SELECT", "MULTI", "WHERE", "AGG",
        "ORDER", "GROUP",
        "JOIN", "JOIN_WHERE", "JOIN_AGG",
        "NESTED", "CORRELATED",
    ])

    col = random.choice(cols)

    # ---------- SELECT ----------
    if intent == "SELECT":
        q = f"{prefix} {col} from {main}"
        sql = f"SELECT {col} FROM {main}"

    # ---------- MULTI ----------
    elif intent == "MULTI":
        c1_, c2_ = random.sample(cols, 2)
        q = f"{prefix} {c1_} and {c2_} from {main}"
        sql = f"SELECT {c1_}, {c2_} FROM {main}"

    # ---------- WHERE ----------
    elif intent == "WHERE":
        op, text = random.choice(COMPARE)
        val = random.choice([10, 50, 100, 500])
        # NOTE(review): col is drawn from all columns, so this can compare
        # "name" against a number; kept as-is to preserve the distribution.
        q = f"{prefix} {main} where {col} is {text} {val}"
        sql = f"SELECT name FROM {main} WHERE {col} {op} {val}"

    # ---------- AGG ----------
    elif intent == "AGG":
        agg = random.choice(AGGS)
        q = f"what is the {agg.lower()} of {col} in {main}"
        sql = f"SELECT {agg}({col}) FROM {main}"

    # ---------- ORDER ----------
    elif intent == "ORDER":
        # NOTE(review): the SQL always sorts DESC although the question does
        # not state a direction.
        q = f"{prefix} {main} ordered by {col}"
        sql = f"SELECT name FROM {main} ORDER BY {col} DESC"

    # ---------- GROUP ----------
    elif intent == "GROUP":
        agg = random.choice(AGGS)
        # Group the chosen schema's measure by its own join column instead of
        # hard-coding salary / dept_id.
        q = f"{prefix} {agg.lower()} {metric} grouped by {fk_label}"
        sql = f"SELECT {c1}, {agg}({metric}) FROM {main} GROUP BY {c1}"

    # ---------- JOIN (four phrasing / SQL variants) ----------
    elif intent == "JOIN":
        pattern = random.randint(1, 4)

        if pattern == 1:
            q = f"{prefix} {t1} with their {t2}"
            sql = f"SELECT {t1}.name, {t2}.name FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

        elif pattern == 2:
            q = f"{prefix} {t2} and corresponding {t1}"
            sql = f"SELECT {t2}.name, {t1}.name FROM {t2} JOIN {t1} ON {t2}.{c2} = {t1}.{c1}"

        elif pattern == 3:
            q = f"{prefix} names from both {t1} and {t2}"
            sql = f"SELECT {t1}.name, {t2}.name FROM {t1} INNER JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

        else:
            q = f"{prefix} all {t1} mapped to {t2}"
            sql = f"SELECT * FROM {t1} JOIN {t2} ON {t1}.{c1} = {t2}.{c2}"

    # ---------- JOIN + WHERE ----------
    elif intent == "JOIN_WHERE":
        val = random.randint(1, 5)
        q = f"{prefix} {t1} {membership} {t2} where id > {val}"
        sql = f"""SELECT {t1}.name, {t2}.name
                FROM {t1}
                JOIN {t2}
                ON {t1}.{c1} = {t2}.{c2}
                WHERE {t2}.id > {val}"""

    # ---------- JOIN + AGG ----------
    elif intent == "JOIN_AGG":
        q = f"{prefix} number of {t1} per {t2}"
        sql = f"""SELECT {t2}.name, COUNT(*)
                FROM {t1}
                JOIN {t2}
                ON {t1}.{c1} = {t2}.{c2}
                GROUP BY {t2}.name"""

    # ---------- NESTED (uncorrelated subquery) ----------
    elif intent == "NESTED":
        agg = random.choice(["AVG", "MAX", "MIN"])
        q = f"{prefix} {main} {verb} above {agg.lower()} {metric}"
        sql = f"""SELECT name FROM {main}
                WHERE {metric} >
                (SELECT {agg}({metric}) FROM {main})"""

    # ---------- CORRELATED ----------
    else:
        q = f"{prefix} {main} {verb} above their {group_noun} average"
        # The subquery is correlated with the outer row through the join
        # column of the chosen schema.
        sql = f"""
        SELECT name
        FROM {main} e
        WHERE {metric} >
        (
            SELECT AVG({metric})
            FROM {main}
            WHERE {c1} = e.{c1}
        )
        """

    return {
        "question": q.lower(),
        "schema": schema,
        # split()/join collapses the newlines and indentation of the
        # multi-line templates into single spaces.
        "sql": " ".join(sql.split()),
    }


# Materialise the corpus: 200,000 independently sampled examples.
DATA = [generate_example() for _ in range(200_000)]

# Persist the corpus as a single JSON array. The encoding is stated explicitly
# so the file does not depend on the platform's default text encoding.
with open("nl2sql_data.json", "w", encoding="utf-8") as f:
    json.dump(DATA, f)

# The status message previously showed a mis-decoded check mark; restored.
print("✅ NEXT-LEVEL dataset generated")


In [None]:
import re, torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)


In [None]:
# Token -> integer id tables, pre-seeded with special tokens that take the
# lowest ids (0, 1, 2, ...) in the order listed.
ENC_VOCAB = {
    token: index
    for index, token in enumerate(("<PAD>", "<UNK>", "<SEP>"))
}
DEC_VOCAB = {
    token: index
    for index, token in enumerate(("<PAD>", "<UNK>", "<BOS>", "<EOS>"))
}

def add(vocab, token):
    """Register ``token`` in ``vocab`` if it is not already present.

    A new token receives the next free id, i.e. the size of ``vocab`` before
    insertion. Existing tokens keep their id. The mapping is modified in
    place and nothing is returned.
    """
    vocab.setdefault(token, len(vocab))


def sql_tokenize(sql):
    """Split a SQL string into lower-cased tokens.

    Tokens, in match priority:
      * qualified names ``left.right`` kept as ONE token; each side starts
        with a letter or underscore and may then contain digits
        (``t1.col2``), which the earlier pattern split apart;
      * runs of word characters (keywords, identifiers, numbers);
      * the operators ``>=``, ``<=``, ``!=``, ``=``, ``>``, ``<``;
      * ``(``, ``)``, ``,`` and ``*``.
    Any other character (whitespace, quotes, ...) is skipped.

    Args:
        sql: SQL text to tokenize.

    Returns:
        list[str]: the tokens in source order; empty for an empty string.
    """
    return re.findall(
        r"[A-Za-z_]\w*\.[A-Za-z_]\w*|\w+|>=|<=|!=|=|>|<|\(|\)|,|\*",
        sql.lower(),
    )


# Populate both vocabularies from the generated examples. Ids are handed out
# in first-seen order, so the iteration order below determines the ids.
for example in DATA:

    # Encoder side, part 1: whitespace-separated words of the question.
    for word in example["question"].split():
        add(ENC_VOCAB, word)

    # Encoder side, part 2: each table name, followed by its columns written
    # as "table.column".
    for table, columns in example["schema"].items():
        for symbol in (table, *(f"{table}.{column}" for column in columns)):
            add(ENC_VOCAB, symbol)

    # Decoder side: tokens of the target SQL.
    for sql_token in sql_tokenize(example["sql"]):
        add(DEC_VOCAB, sql_token)

print("Encoder vocab:", len(ENC_VOCAB))
print("Decoder vocab:", len(DEC_VOCAB))
