<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/Welcome_To_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!unzip /content/archive.zip -d /content/

Archive:  /content/archive.zip
  inflating: /content/spider/README.txt  
  inflating: /content/spider/database/academic/academic.sqlite  
  inflating: /content/spider/database/academic/schema.sql  
  inflating: /content/spider/database/activity_1/activity_1.sqlite  
  inflating: /content/spider/database/activity_1/schema.sql  
  inflating: /content/spider/database/aircraft/aircraft.sqlite  
  inflating: /content/spider/database/aircraft/schema.sql  
  inflating: /content/spider/database/allergy_1/allergy_1.sqlite  
  inflating: /content/spider/database/allergy_1/schema.sql  
  inflating: /content/spider/database/apartment_rentals/apartment_rentals.sqlite  
  inflating: /content/spider/database/apartment_rentals/schema.sql  
  inflating: /content/spider/database/architecture/architecture.sqlite  
  inflating: /content/spider/database/architecture/schema.sql  
  inflating: /content/spider/database/assets_maintenance/assets_maintenance.sqlite  
  inflating: /content/spider/database/assets

In [2]:
# ============================================================
# SPIDER DATASET PREPROCESSING ‚Üí SCHEMA-AWARE FORMAT
# OUTPUT: spider_schema_aware.json (7000 examples)
# ============================================================

import json
import re
import os

# ============================================================
# 1Ô∏è‚É£ LOAD RAW SPIDER FILES
# ============================================================
TRAIN_PATH = "/content/spider/train_spider.json"
TABLES_PATH = "/content/spider/tables.json"

assert os.path.exists(TRAIN_PATH), "‚ùå train_spider.json not found"
assert os.path.exists(TABLES_PATH), "‚ùå tables.json not found"

with open(TRAIN_PATH, "r", encoding="utf-8") as f:
    train_data = json.load(f)

with open(TABLES_PATH, "r", encoding="utf-8") as f:
    tables_data = json.load(f)

print("‚úÖ Loaded questions:", len(train_data))
print("‚úÖ Loaded schemas:", len(tables_data))

# ============================================================
# 2Ô∏è‚É£ BUILD DB ‚Üí SCHEMA MAP
# ============================================================
db_schemas = {}

for db in tables_data:
    db_id = db["db_id"]
    tables = db["table_names_original"]
    columns = db["column_names_original"]

    schema = {}
    for tid, col in columns:
        if tid == -1:
            continue
        table = tables[tid]
        schema.setdefault(table, []).append(col)

    db_schemas[db_id] = schema

print("‚úÖ Schema map built")

# ============================================================
# 3Ô∏è‚É£ TOKENIZE QUESTION
# ============================================================
def tokenize_question(text):
    text = text.lower()
    text = re.sub(r"[^a-zA-Z0-9_ ]", " ", text)
    return text.split()

# ============================================================
# 4Ô∏è‚É£ SAFE SCHEMA MATCHING FROM SQL
# ============================================================
def find_schema_mentions(sql, schema):
    sql = sql.lower()
    used_tables = set()
    used_columns = set()

    for table, cols in schema.items():
        if re.search(rf"\b{re.escape(table.lower())}\b", sql):
            used_tables.add(table)

        for col in cols:
            if re.search(rf"\b{re.escape(col.lower())}\b", sql):
                used_columns.add(f"{table}.{col}")

    return used_tables, used_columns

# ============================================================
# 5Ô∏è‚É£ BUILD INPUT TOKENS + TOKEN TYPES
# ============================================================
def build_input_with_schema(question_tokens, schema):
    tokens = []
    token_types = []

    # Question tokens ‚Üí type 0
    for tok in question_tokens:
        tokens.append(tok)
        token_types.append(0)

    # Tables ‚Üí type 1, Columns ‚Üí type 2
    for table, cols in schema.items():
        tokens.append(table)
        token_types.append(1)
        for col in cols:
            tokens.append(col)
            token_types.append(2)

    return tokens, token_types

# ============================================================
# 6Ô∏è‚É£ BUILD SCHEMA LABELS
# ============================================================
def build_schema_labels(tokens, token_types, used_tables, used_columns):
    labels = []

    for tok, ttype in zip(tokens, token_types):
        if ttype == 1:  # table
            labels.append(1 if tok in used_tables else 0)
        elif ttype == 2:  # column
            labels.append(
                1 if any(tok == c.split(".")[1] for c in used_columns) else 0
            )
        else:
            labels.append(0)

    return labels

# ============================================================
# 7Ô∏è‚É£ CONVERT ONE EXAMPLE
# ============================================================
def convert_example(ex):
    schema = db_schemas[ex["db_id"]]

    q_tokens = tokenize_question(ex["question"])
    used_tables, used_columns = find_schema_mentions(ex["query"], schema)

    tokens, token_types = build_input_with_schema(q_tokens, schema)
    schema_labels = build_schema_labels(
        tokens, token_types, used_tables, used_columns
    )

    return {
        "tokens": tokens,
        "token_types": token_types,
        "schema_labels": schema_labels,
        "schema": schema,
        "sql": ex["query"]
    }

# ============================================================
# 8Ô∏è‚É£ PROCESS DATASET (7000 EXAMPLES)
# ============================================================
MAX_EXAMPLES = 7000
processed = []

for i, ex in enumerate(train_data[:MAX_EXAMPLES]):
    processed.append(convert_example(ex))
    if i > 0 and i % 500 == 0:
        print(f"‚è≥ Processed {i} examples")

print("‚úÖ Total processed:", len(processed))

# ============================================================
# 9Ô∏è‚É£ SAVE OUTPUT
# ============================================================
OUTPUT_PATH = "/content/spider_schema_aware.json"

with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(processed, f, indent=2)

print("‚úÖ Saved to:", OUTPUT_PATH)

# ============================================================
# üîü SANITY CHECK
# ============================================================
sample = processed[0]
print("\n--- SAMPLE ---")
print("Tokens:", sample["tokens"][:20])
print("Token types:", sample["token_types"][:20])
print("Schema labels:", sample["schema_labels"][:20])
print("SQL:", sample["sql"])


‚úÖ Loaded questions: 7000
‚úÖ Loaded schemas: 166
‚úÖ Schema map built
‚è≥ Processed 500 examples
‚è≥ Processed 1000 examples
‚è≥ Processed 1500 examples
‚è≥ Processed 2000 examples
‚è≥ Processed 2500 examples
‚è≥ Processed 3000 examples
‚è≥ Processed 3500 examples
‚è≥ Processed 4000 examples
‚è≥ Processed 4500 examples
‚è≥ Processed 5000 examples
‚è≥ Processed 5500 examples
‚è≥ Processed 6000 examples
‚è≥ Processed 6500 examples
‚úÖ Total processed: 7000
‚úÖ Saved to: /content/spider_schema_aware.json

--- SAMPLE ---
Tokens: ['how', 'many', 'heads', 'of', 'the', 'departments', 'are', 'older', 'than', '56', 'department', 'Department_ID', 'Name', 'Creation', 'Ranking', 'Budget_in_Billions', 'Num_Employees', 'head', 'head_ID', 'name']
Token types: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2]
Schema labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
SQL: SELECT count(*) FROM head WHERE age  >  56


In [10]:
# ============================================================
# NL2SQL ‚Äì SCHEMA-AWARE ENCODER DECODER (7000 EXAMPLES)
# ============================================================

import torch
import torch.nn as nn
import torch.optim as optim
import math
import json
import re
import os
from torch.utils.data import Dataset, DataLoader

# ============================================================
# 1Ô∏è‚É£ DEVICE
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ============================================================
# 2Ô∏è‚É£ LOAD PREPROCESSED DATA
# ============================================================
DATA_PATH = "/content/spider_schema_aware.json"
assert os.path.exists(DATA_PATH), "‚ùå spider_schema_aware.json not found"

with open(DATA_PATH) as f:
    data = json.load(f)

print("Total examples in JSON:", len(data))

# ============================================================
# 3Ô∏è‚É£ BUILD ENCODER VOCAB
# ============================================================
def build_encoder_vocab(data):
    vocab = {"<PAD>": 0, "<UNK>": 1}
    idx = 2
    for ex in data:
        for t in ex["tokens"]:
            if t not in vocab:
                vocab[t] = idx
                idx += 1
    return vocab

encoder_vocab = build_encoder_vocab(data)
enc_vocab_size = len(encoder_vocab)
print("Encoder vocab size:", enc_vocab_size)

# ============================================================
# 4Ô∏è‚É£ DECODER SQL VOCAB (PICARD-STYLE)
# ============================================================
SQL_KEYWORDS = ["SELECT","FROM","WHERE","JOIN","ON","AS","DISTINCT"]
LOGICAL_OPS = ["AND","OR","NOT"]
COMPARISON_OPS = ["=","!=","<>",">","<",">=","<="]
AGG_FUNCS = ["COUNT","SUM","AVG","MIN","MAX"]
GROUPING = ["GROUP","BY","HAVING"]
ORDERING = ["ORDER","ASC","DESC","LIMIT"]
PUNCT = [",","(",")"]
SPECIAL = ["<PAD>","<EOS>","VALUE"]

MAX_TABLES = 128
MAX_COLUMNS = 1024
TABLE_TOKENS = [f"T{i}" for i in range(MAX_TABLES)]
COLUMN_TOKENS = [f"C{i}" for i in range(MAX_COLUMNS)]

DECODER_VOCAB = (
    SQL_KEYWORDS + LOGICAL_OPS + COMPARISON_OPS + AGG_FUNCS +
    GROUPING + ORDERING + PUNCT + SPECIAL +
    TABLE_TOKENS + COLUMN_TOKENS
)

token_to_id = {t:i for i,t in enumerate(DECODER_VOCAB)}
VOCAB_SIZE = len(token_to_id)
print("Decoder vocab size:", VOCAB_SIZE)

# ============================================================
# 5Ô∏è‚É£ SCHEMA ‚Üí PLACEHOLDER MAPPING
# ============================================================
def build_schema_maps(schema):
    table_map = {t: f"T{i}" for i, t in enumerate(schema.keys())}
    col_map = {}
    cid = 0
    for t, cols in schema.items():
        for c in cols:
            col_map[f"{t}.{c}"] = f"C{cid}"
            cid += 1
    return table_map, col_map

def sql_to_placeholder(sql, tmap, cmap):
    sql = sql.lower()
    for k, v in sorted(cmap.items(), key=lambda x: -len(x[0])):
        sql = re.sub(rf"\b{re.escape(k)}\b", v, sql)
    for k, v in tmap.items():
        sql = re.sub(rf"\b{re.escape(k)}\b", v, sql)
    sql = re.sub(r"\b\d+\b", "VALUE", sql)
    return sql.upper()

# ============================================================
# 6Ô∏è‚É£ DATASET
# ============================================================
class SpiderDataset(Dataset):
    def __init__(self, data, vocab, limit=7000):
        self.data = data[:limit]
        self.vocab = vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        tmap, cmap = build_schema_maps(ex["schema"])
        sql = sql_to_placeholder(ex["sql"], tmap, cmap)

        x = [self.vocab.get(t, 1) for t in ex["tokens"]]
        y = [token_to_id.get(t, token_to_id["<PAD>"]) for t in sql.split()]
        y.append(token_to_id["<EOS>"])

        return (
            torch.tensor(x),
            torch.tensor(ex["token_types"]),
            torch.tensor(ex["schema_labels"], dtype=torch.float),
            torch.tensor(y)
        )

def collate_fn(batch):
    def pad(seqs, val=0):
        m = max(len(s) for s in seqs)
        return torch.stack([
            torch.cat([s, torch.full((m - len(s),), val)]) for s in seqs
        ])
    x, t, s, y = zip(*batch)
    return pad(x), pad(t), pad(s), pad(y)

dataset = SpiderDataset(data, encoder_vocab, limit=7000)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

print("Training examples:", len(dataset))

# ============================================================
# 7Ô∏è‚É£ TRANSFORMER MODULES
# ============================================================
class MultiHeadAttention(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, d * 3)
        self.o = nn.Linear(d, d)

    def forward(self, q, k, v, mask=None):
        B, T, D = q.size()
        q, k, v = self.qkv(q).chunk(3, dim=-1)
        q = q.view(B, T, self.h, self.dk).transpose(1,2)
        k = k.view(B, -1, self.h, self.dk).transpose(1,2)
        v = v.view(B, -1, self.h, self.dk).transpose(1,2)
        scores = q @ k.transpose(-2,-1) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        out = (scores.softmax(-1) @ v).transpose(1,2).contiguous().view(B,T,D)
        return self.o(out)

class Encoder(nn.Module):
    def __init__(self, vocab, d, h, ff, layers):
        super().__init__()
        self.emb = nn.Embedding(vocab, d, padding_idx=0)
        self.type_emb = nn.Embedding(3, d)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d,h),
                nn.Sequential(nn.Linear(d,ff), nn.ReLU(), nn.Linear(ff,d)),
                nn.LayerNorm(d),
                nn.LayerNorm(d)
            ]) for _ in range(layers)
        ])
        self.schema_head = nn.Linear(d,1)

    def forward(self, x, t, mask):
        x = self.emb(x) + self.type_emb(t)
        for attn, ff, n1, n2 in self.layers:
            x = n1(x + attn(x,x,x,mask))
            x = n2(x + ff(x))
        return x, self.schema_head(x).squeeze(-1)

class Decoder(nn.Module):
    def __init__(self, vocab, d, h, ff, layers):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(512, d)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d,h),
                MultiHeadAttention(d,h),
                nn.Sequential(nn.Linear(d,ff), nn.ReLU(), nn.Linear(ff,d)),
                nn.LayerNorm(d),
                nn.LayerNorm(d),
                nn.LayerNorm(d)
            ]) for _ in range(layers)
        ])
        self.out = nn.Linear(d, vocab)

    def forward(self, y, enc, mask):
        B, T = y.size()
        x = self.emb(y) + self.pos(torch.arange(T, device=y.device))
        for sa, ca, ff, n1, n2, n3 in self.layers:
            x = n1(x + sa(x,x,x,mask))
            x = n2(x + ca(x,enc,enc,None))
            x = n3(x + ff(x))
        return self.out(x)

# ============================================================
# 8Ô∏è‚É£ TRAINING
# ============================================================
encoder = Encoder(enc_vocab_size, 256, 8, 1024, 4).to(device)
decoder = Decoder(VOCAB_SIZE, 256, 8, 1024, 4).to(device)

opt = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=3e-4)
sql_loss = nn.CrossEntropyLoss(ignore_index=token_to_id["<PAD>"])
schema_loss = nn.BCEWithLogitsLoss()

EPOCHS = 25

for ep in range(EPOCHS):
    total = 0
    for x, t, s, y in loader:
        x, t, s, y = x.to(device), t.to(device), s.to(device), y.to(device)
        mask = (x != 0).unsqueeze(1).unsqueeze(2)

        enc, sch = encoder(x, t, mask)
        dec_in, dec_tgt = y[:, :-1], y[:, 1:]

        T = dec_in.size(1)
        causal = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0).unsqueeze(0)

        logits = decoder(dec_in, enc, causal)

        loss = sql_loss(logits.reshape(-1, VOCAB_SIZE), dec_tgt.reshape(-1))
        loss += 0.7 * schema_loss(sch*(t!=0), s*(t!=0))

        opt.zero_grad()
        loss.backward()
        opt.step()
        total += loss.item()

    print(f"Epoch {ep+1} | Loss: {total/len(loader):.4f}")

# ============================================================
# 9Ô∏è‚É£ SAVE MODEL
# ============================================================
torch.save({
    "encoder": encoder.state_dict(),
    "decoder": decoder.state_dict(),
    "encoder_vocab": encoder_vocab,
    "decoder_vocab": token_to_id
}, "/content/nl2sql_schema_aware.pt")

print("‚úÖ Model saved successfully")


Using device: cpu
Total examples in JSON: 7000
Encoder vocab size: 5312
Decoder vocab size: 1187
Training examples: 7000
Epoch 1 | Loss: 0.5832
Epoch 2 | Loss: 0.3839
Epoch 3 | Loss: 0.3574
Epoch 4 | Loss: 0.3396
Epoch 5 | Loss: 0.3294
Epoch 6 | Loss: 0.3204
Epoch 7 | Loss: 0.3177
Epoch 8 | Loss: 0.3144
Epoch 9 | Loss: 0.3091
Epoch 10 | Loss: 0.3114
Epoch 11 | Loss: 0.3126
Epoch 12 | Loss: 0.3081
Epoch 13 | Loss: 0.3104
Epoch 14 | Loss: 0.3066
Epoch 15 | Loss: 0.3084
Epoch 16 | Loss: 0.3053
Epoch 17 | Loss: 0.3391
Epoch 18 | Loss: 0.3115
Epoch 19 | Loss: 0.3052
Epoch 20 | Loss: 0.3063
Epoch 21 | Loss: 0.3067
Epoch 22 | Loss: 0.3072
Epoch 23 | Loss: 0.3064
Epoch 24 | Loss: 0.3053
Epoch 25 | Loss: 0.3060
‚úÖ Model saved successfully


In [14]:
# ============================================================
# NL2SQL INFERENCE ‚Äì FINAL SINGLE CELL (WORKING & STABLE)
# ============================================================

import torch
import torch.nn as nn
import math
import re

# ------------------------------------------------------------
# DEVICE
# ------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------------------------------------------------
# LOAD CHECKPOINT
# ------------------------------------------------------------
ckpt = torch.load("/content/nl2sql_schema_aware.pt", map_location=device)

encoder_vocab = ckpt["encoder_vocab"]
decoder_vocab = ckpt["decoder_vocab"]
id_to_token = {v: k for k, v in decoder_vocab.items()}

# ------------------------------------------------------------
# SQL CONSTANTS (MUST MATCH TRAINING)
# ------------------------------------------------------------
AGG_FUNCS = ["COUNT","SUM","AVG","MIN","MAX"]
COMPARISON_OPS = ["=","!=","<>",">","<",">=","<="]

TABLE_TOKENS = [t for t in decoder_vocab if t.startswith("T")]
COLUMN_TOKENS = [t for t in decoder_vocab if t.startswith("C")]

# ------------------------------------------------------------
# MODEL DEFINITIONS (MATCH TRAINING)
# ------------------------------------------------------------
class MultiHeadAttention(nn.Module):
    def __init__(self, d, h):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.qkv = nn.Linear(d, d * 3)
        self.o = nn.Linear(d, d)

    def forward(self, q, k, v, mask=None):
        B, T, D = q.size()
        q, k, v = self.qkv(q).chunk(3, dim=-1)

        q = q.view(B, T, self.h, self.dk).transpose(1, 2)
        k = k.view(B, -1, self.h, self.dk).transpose(1, 2)
        v = v.view(B, -1, self.h, self.dk).transpose(1, 2)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        out = (scores.softmax(-1) @ v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.o(out)

class Encoder(nn.Module):
    def __init__(self, vocab, d=256, h=8, ff=1024, layers=4):
        super().__init__()
        self.emb = nn.Embedding(vocab, d, padding_idx=0)
        self.type_emb = nn.Embedding(3, d)

        self.layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d, h),
                nn.Sequential(nn.Linear(d, ff), nn.ReLU(), nn.Linear(ff, d)),
                nn.LayerNorm(d),
                nn.LayerNorm(d)
            ]) for _ in range(layers)
        ])

        # schema_head existed during training but NOT needed at inference

    def forward(self, x, t, mask):
        x = self.emb(x) + self.type_emb(t)
        for attn, ff, n1, n2 in self.layers:
            x = n1(x + attn(x, x, x, mask))
            x = n2(x + ff(x))
        return x

class Decoder(nn.Module):
    def __init__(self, vocab, d=256, h=8, ff=1024, layers=4):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(512, d)

        self.layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d, h),
                MultiHeadAttention(d, h),
                nn.Sequential(nn.Linear(d, ff), nn.ReLU(), nn.Linear(ff, d)),
                nn.LayerNorm(d),
                nn.LayerNorm(d),
                nn.LayerNorm(d)
            ]) for _ in range(layers)
        ])

        self.out = nn.Linear(d, vocab)

    def forward(self, y, enc, mask):
        B, T = y.size()
        x = self.emb(y) + self.pos(torch.arange(T, device=y.device))

        for sa, ca, ff, n1, n2, n3 in self.layers:
            x = n1(x + sa(x, x, x, mask))
            x = n2(x + ca(x, enc, enc, None))
            x = n3(x + ff(x))

        return self.out(x)

# ------------------------------------------------------------
# LOAD MODEL WEIGHTS (KEY FIX HERE)
# ------------------------------------------------------------
encoder = Encoder(len(encoder_vocab)).to(device)
decoder = Decoder(len(decoder_vocab)).to(device)

encoder.load_state_dict(ckpt["encoder"], strict=False)  # ‚úÖ FIX
decoder.load_state_dict(ckpt["decoder"])

encoder.eval()
decoder.eval()

print("‚úÖ Model loaded successfully")

# ------------------------------------------------------------
# HELPERS
# ------------------------------------------------------------
def tokenize_question(text):
    text = text.lower()
    text = re.sub(r"[^a-zA-Z0-9_ ]", " ", text)
    return text.split()

def build_encoder_input(question, schema):
    tokens, types = [], []

    for t in tokenize_question(question):
        tokens.append(t)
        types.append(0)

    for table, cols in schema.items():
        tokens.append(table)
        types.append(1)
        for col in cols:
            tokens.append(col)
            types.append(2)

    ids = [encoder_vocab.get(t, encoder_vocab["<UNK>"]) for t in tokens]

    return (
        torch.tensor(ids).unsqueeze(0).to(device),
        torch.tensor(types).unsqueeze(0).to(device)
    )

# ------------------------------------------------------------
# SIMPLE PICARD GRAMMAR
# ------------------------------------------------------------
def picard_filter(prefix):
    if len(prefix) == 0:
        return ["SELECT"]

    last = prefix[-1]

    if last == "SELECT":
        return AGG_FUNCS + COLUMN_TOKENS
    if last in AGG_FUNCS or last in COLUMN_TOKENS:
        return ["FROM", ","]
    if last == ",":
        return COLUMN_TOKENS
    if last == "FROM":
        return TABLE_TOKENS
    if last in TABLE_TOKENS:
        return ["WHERE", "<EOS>"]
    if last == "WHERE":
        return COLUMN_TOKENS
    if last in COLUMN_TOKENS:
        return COMPARISON_OPS
    if last in COMPARISON_OPS:
        return ["VALUE"]
    if last == "VALUE":
        return ["AND", "<EOS>"]
    if last == "AND":
        return COLUMN_TOKENS

    return ["<EOS>"]

# ------------------------------------------------------------
# SQL GENERATION
# ------------------------------------------------------------
def generate_sql(question, schema, max_len=40):
    x, t = build_encoder_input(question, schema)
    mask = (x != 0).unsqueeze(1).unsqueeze(2)

    with torch.no_grad():
        enc = encoder(x, t, mask)

    cur = torch.tensor([[decoder_vocab["SELECT"]]], device=device)
    generated = ["SELECT"]

    for _ in range(max_len):
        T = cur.size(1)
        causal = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            logits = decoder(cur, enc, causal)[0, -1]

        allowed = picard_filter(generated)
        allowed_ids = [decoder_vocab[t] for t in allowed if t in decoder_vocab]

        masked = torch.full_like(logits, -1e9)
        masked[allowed_ids] = logits[allowed_ids]

        next_id = masked.argmax().item()
        token = id_to_token[next_id]

        if token == "<EOS>":
            break

        generated.append(token)
        cur = torch.cat([cur, torch.tensor([[next_id]], device=device)], dim=1)

    return " ".join(generated)

# ------------------------------------------------------------
# TEST
# ------------------------------------------------------------
schema = {
    "employees": ["id", "name", "salary", "department_id"],
    "departments": ["id", "department_name"]
}

question = "show salary of employees"

print("\nüü¢ Generated SQL:")
print(generate_sql(question, schema))


Using device: cpu
‚úÖ Model loaded successfully

üü¢ Generated SQL:
SELECT COUNT FROM T0
