<a href="https://colab.research.google.com/github/vishal7379/Colab/blob/main/NL_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================
# MULTI-SCHEMA NL ‚Üí SQL DATASET GENERATOR (7000 EXAMPLES)
# ============================================================

import json, random
from google.colab import files

random.seed(42)

OUTPUT = "/content/custom_multischema_nl2sql_7000.json"
DATASET = []

NUM_SCHEMAS = 70
EXAMPLES_PER_SCHEMA = 100  # 70 √ó 100 = 7000

DEPARTMENTS = ["HR", "Sales", "IT", "Finance"]

def make_schema(i):
    return {
        f"employees_{i}": ["id", "name", "salary", "dept_id"],
        f"departments_{i}": ["id", "name"],
        f"projects_{i}": ["id", "budget", "dept_id"]
    }

def add(q, sql, schema):
    DATASET.append({
        "question": q,
        "schema": schema,
        "sql": sql
    })

for i in range(NUM_SCHEMAS):
    schema = make_schema(i)
    emp = f"employees_{i}"
    dep = f"departments_{i}"
    proj = f"projects_{i}"

    for _ in range(EXAMPLES_PER_SCHEMA // 8):
        add(
            f"show names of employees",
            f"SELECT name FROM {emp}",
            schema
        )

        dept = random.choice(DEPARTMENTS)
        add(
            f"show employees in {dept.lower()} department",
            f"""SELECT {emp}.name
FROM {emp}
JOIN {dep} ON {emp}.dept_id = {dep}.id
WHERE {dep}.name = '{dept}'""",
            schema
        )

        val = random.choice([30000, 40000, 50000])
        add(
            f"show employees with salary greater than {val}",
            f"SELECT name FROM {emp} WHERE salary > {val}",
            schema
        )

        add(
            f"count employees",
            f"SELECT COUNT(*) FROM {emp}",
            schema
        )

        add(
            f"count employees in each department",
            f"""SELECT {dep}.name, COUNT(*)
FROM {emp}
JOIN {dep} ON {emp}.dept_id = {dep}.id
GROUP BY {dep}.name""",
            schema
        )

        add(
            f"show top 5 highest paid employees",
            f"SELECT name FROM {emp} ORDER BY salary DESC LIMIT 5",
            schema
        )

        add(
            f"show employees in hr department with salary greater than 50000",
            f"""SELECT {emp}.name
FROM {emp}
JOIN {dep} ON {emp}.dept_id = {dep}.id
WHERE {dep}.name = 'HR' AND {emp}.salary > 50000""",
            schema
        )

        add(
            f"show total project budget by department",
            f"""SELECT {dep}.name, SUM({proj}.budget)
FROM {proj}
JOIN {dep} ON {proj}.dept_id = {dep}.id
GROUP BY {dep}.name""",
            schema
        )

# Trim exactly 7000
DATASET = DATASET[:7000]

with open(OUTPUT, "w") as f:
    json.dump(DATASET, f, indent=2)

print("‚úÖ Dataset generated")
print("Total examples:", len(DATASET))
print("Schemas:", NUM_SCHEMAS)

files.download(OUTPUT)


‚úÖ Dataset generated
Total examples: 6720
Schemas: 70


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [2]:
import json
import re
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Load your dataset
with open('custom_multischema_nl2sql_7000.json', 'r') as f:
    raw_data = json.load(f)

def preprocess_schema_aware(ex):
    schema = ex["schema"]
    # 1. Create Mapping
    t_map = {t: f"T{i}" for i, t in enumerate(schema.keys())}
    c_map = {}
    cid = 0
    for t, cols in schema.items():
        for c in cols:
            c_map[f"{t}.{c}"] = f"C{cid}"  # Handle table.column
            if c not in c_map: c_map[c] = f"C{cid}" # Handle just column
            cid += 1

    # 2. Map the SQL to Placeholders
    sql = ex["sql"].lower()
    # Replace long strings first to avoid partial matches
    for full_name, placeholder in sorted(c_map.items(), key=lambda x: -len(x[0])):
        sql = re.sub(rf"\b{re.escape(full_name)}\b", placeholder, sql)
    for full_name, placeholder in sorted(t_map.items(), key=lambda x: -len(x[0])):
        sql = re.sub(rf"\b{re.escape(full_name)}\b", placeholder, sql)

    # 3. Build Encoder Input: [Question] <SEP> [Schema]
    q_toks = ex["question"].lower().split()
    s_toks, s_types = [], []
    for t, cols in schema.items():
        s_toks.append(t.split('_')[0]) # Use 'employees' instead of 'employees_0'
        s_types.append(1) # Type 1 = Table
        for c in cols:
            s_toks.append(c)
            s_types.append(2) # Type 2 = Column

    return {
        "enc_input": q_toks + ["<SEP>"] + s_toks,
        "token_types": [0]*(len(q_toks)+1) + s_types,
        "dec_target": sql.upper().replace('(', ' ( ').replace(')', ' ) ').split(),
        "original_schema": {"t_map": t_map, "c_map": c_map}
    }

# Process sample
processed_sample = preprocess_schema_aware(raw_data[0])
print("Encoder Input:", processed_sample["enc_input"])
print("Target SQL:", processed_sample["dec_target"])

Encoder Input: ['show', 'names', 'of', 'employees', '<SEP>', 'employees', 'id', 'name', 'salary', 'dept_id', 'departments', 'id', 'name', 'projects', 'id', 'budget', 'dept_id']
Target SQL: ['SELECT', 'C1', 'FROM', 'T0']


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import json
import re
from torch.utils.data import Dataset, DataLoader

# ============================================================
# 1Ô∏è‚É£ DEVICE & VOCAB SETUP
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SQL Vocabulary components
SQL_KEYWORDS = ["SELECT","FROM","WHERE","JOIN","ON","AS","DISTINCT"]
LOGICAL_OPS = ["AND","OR","NOT"]
COMPARISON_OPS = ["=","!=","<>",">","<",">=","<="]
AGG_FUNCS = ["COUNT","SUM","AVG","MIN","MAX"]
GROUPING = ["GROUP","BY","HAVING"]
ORDERING = ["ORDER","ASC","DESC","LIMIT"]
PUNCT = [",","(",")"]
SPECIAL = ["<PAD>","<EOS>","<BOS>","VALUE"] # Added <BOS>

MAX_TABLES, MAX_COLUMNS = 128, 1024
TABLE_TOKENS = [f"T{i}" for i in range(MAX_TABLES)]
COLUMN_TOKENS = [f"C{i}" for i in range(MAX_COLUMNS)]

DECODER_VOCAB = SQL_KEYWORDS + LOGICAL_OPS + COMPARISON_OPS + AGG_FUNCS + \
                GROUPING + ORDERING + PUNCT + SPECIAL + TABLE_TOKENS + COLUMN_TOKENS

token_to_id = {tok: i for i, tok in enumerate(DECODER_VOCAB)}
id_to_token = {i: tok for tok, i in token_to_id.items()}
VOCAB_SIZE = len(DECODER_VOCAB)

# ============================================================
# 2Ô∏è‚É£ DATASET & PLACEHOLDER MAPPING
# ============================================================
def build_schema_maps(schema):
    table_map = {table: f"T{i}" for i, table in enumerate(schema.keys())}
    column_map = {}
    col_id = 0
    for table, cols in schema.items():
        for col in cols:
            column_map[f"{table}.{col}"] = f"C{col_id}"
            column_map[col] = f"C{col_id}" # Fallback for naked columns
            col_id += 1
    return table_map, column_map

def sql_to_placeholder(sql, table_map, column_map):
    sql_out = sql.lower()
    # Sort by length descending to prevent partial replacements (e.g., 'id' in 'dept_id')
    for full_col, cid in sorted(column_map.items(), key=lambda x: -len(x[0])):
        sql_out = re.sub(rf"\b{re.escape(full_col.lower())}\b", cid, sql_out)
    for table, tid in table_map.items():
        sql_out = re.sub(rf"\b{re.escape(table.lower())}\b", tid, sql_out)
    sql_out = re.sub(r"'\w+'|\b\d+\b", "VALUE", sql_out)
    return sql_out.upper()

class SpiderDataset(Dataset):
    def __init__(self, data, encoder_vocab):
        self.data = data
        self.encoder_vocab = encoder_vocab

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        input_ids = [self.encoder_vocab.get(tok, 1) for tok in ex["tokens"]]
        table_map, col_map = build_schema_maps(ex["schema"])
        sql_ph = sql_to_placeholder(ex["sql"], table_map, col_map)

        # Target: <BOS> SELECT ... <EOS>
        tgt_ids = [token_to_id["<BOS>"]] + [token_to_id.get(t, token_to_id["VALUE"]) for t in sql_ph.split()] + [token_to_id["<EOS>"]]

        return torch.tensor(input_ids), torch.tensor(ex["token_types"]), \
               torch.tensor(ex["schema_labels"], dtype=torch.float), torch.tensor(tgt_ids)

def collate_fn(batch):
    ids, types, labels, tgts = zip(*batch)
    def pad(s, v=0): return nn.utils.rnn.pad_sequence(s, batch_first=True, padding_value=v)
    return pad(ids), pad(types), pad(labels), pad(tgts)

# ============================================================
# 3Ô∏è‚É£ SCHEMA-AWARE TRANSFORMER MODEL
# ============================================================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads, self.d_k = num_heads, d_model // num_heads
        self.Wq, self.Wk, self.Wv, self.Wo = [nn.Linear(d_model, d_model) for _ in range(4)]

    def forward(self, q, k, v, mask=None):
        B, Tq, D = q.size()
        Q = self.Wq(q).view(B, Tq, self.num_heads, self.d_k).transpose(1,2)
        K = self.Wk(k).view(B, k.size(1), self.num_heads, self.d_k).transpose(1,2)
        V = self.Wv(v).view(B, k.size(1), self.num_heads, self.d_k).transpose(1,2)
        scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k)
        if mask is not None: scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        return self.Wo(torch.matmul(attn, V).transpose(1,2).contiguous().view(B, Tq, D))

class SchemaAwareEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.typ_emb = nn.Embedding(3, d_model)
        self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model, num_heads, d_ff, batch_first=True) for _ in range(num_layers)])
        self.classifier = nn.Linear(d_model, 1)

    def forward(self, x, types, mask=None):
        x = self.tok_emb(x) + self.typ_emb(types)
        for layer in self.layers: x = layer(x, src_key_padding_mask=(mask==0))
        return x, self.classifier(x).squeeze(-1)

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([nn.TransformerDecoderLayer(d_model, num_heads, d_ff, batch_first=True) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, enc_out, tgt_mask=None, memory_mask=None):
        x = self.emb(tgt)
        for layer in self.layers: x = layer(x, enc_out, tgt_mask=tgt_mask, memory_key_padding_mask=(memory_mask==0))
        return self.fc(x)

# ============================================================
# 4Ô∏è‚É£ TRAINING & PICARD INFERENCE
# ============================================================
def picard_filter(valid_prefix):
    if len(valid_prefix) <= 1: return ["SELECT"]
    last = valid_prefix[-1]
    if last == "SELECT": return AGG_FUNCS + COLUMN_TOKENS + ["DISTINCT"]
    if last in AGG_FUNCS or last in COLUMN_TOKENS: return [",", "FROM"]
    if last == "FROM": return TABLE_TOKENS
    if last in TABLE_TOKENS: return ["WHERE", "JOIN", "<EOS>"]
    if last == "WHERE": return COLUMN_TOKENS
    if last in COMPARISON_OPS: return ["VALUE"]
    return ["<EOS>", "AND", "OR", "LIMIT"]

def generate_sql_picard(encoder, decoder, input_ids, token_types, device):
    encoder.eval(); decoder.eval()
    with torch.no_grad():
        mask = (input_ids != 0).long()
        enc_out, _ = encoder(input_ids, token_types, mask)

        cur = torch.tensor([[token_to_id["<BOS>"]]], device=device)
        generated = ["<BOS>"]

        for _ in range(30):
            tgt_mask = torch.triu(torch.ones(cur.size(1), cur.size(1), device=device), diagonal=1).bool()
            logits = decoder(cur, enc_out, tgt_mask=tgt_mask)[0, -1]

            allowed = picard_filter(generated)
            mask_logits = torch.full_like(logits, float("-inf"))
            allowed_ids = [token_to_id[t] for t in allowed if t in token_to_id]
            mask_logits[allowed_ids] = logits[allowed_ids]

            next_id = mask_logits.argmax().item()
            next_tok = id_to_token[next_id]
            if next_tok == "<EOS>": break

            generated.append(next_tok)
            cur = torch.cat([cur, torch.tensor([[next_id]], device=device)], dim=1)

    return " ".join(generated[1:])

# Initialize (Assuming encoder_vocab exists from your build_encoder_vocab function)
encoder = SchemaAwareEncoder(enc_vocab_size, d_model=256, num_heads=8, d_ff=1024, num_layers=4).to(device)
decoder = TransformerDecoder(num_layers=4, d_model=256, num_heads=8, d_ff=1024, vocab_size=VOCAB_SIZE).to(device)

# Training logic remains similar but ensures target shifts:
# logits = decoder(tgt[:, :-1], enc_out)
# loss = criterion(logits, tgt[:, 1:])

NameError: name 'enc_vocab_size' is not defined

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import json
import re
from torch.utils.data import Dataset, DataLoader

# ============================================================
# 1Ô∏è‚É£ DEVICE & DATA LOADING
# ============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open('custom_multischema_nl2sql_7000.json', 'r') as f:
    raw_data = json.load(f)

# ============================================================
# 2Ô∏è‚É£ BUILD VOCABULARIES
# ============================================================
def build_vocabs(data):
    enc_vocab = {"<PAD>": 0, "<UNK>": 1, "<SEP>": 2}
    idx = 3
    for ex in data:
        # Tokenize question
        for tok in ex["question"].lower().split():
            if tok not in enc_vocab:
                enc_vocab[tok] = idx
                idx += 1
        # Tokenize schema (normalized)
        for t, cols in ex["schema"].items():
            t_norm = t.split('_')[0]
            if t_norm not in enc_vocab:
                enc_vocab[t_norm] = idx
                idx += 1
            for c in cols:
                if c not in enc_vocab:
                    enc_vocab[c] = idx
                    idx += 1
    return enc_vocab

encoder_vocab = build_vocabs(raw_data)
enc_vocab_size = len(encoder_vocab)

# Decoder Vocab
SQL_KEYWORDS = ["SELECT","FROM","WHERE","JOIN","ON","AS","DISTINCT","ORDER","BY","DESC","ASC","LIMIT","GROUP","AND"]
SPECIAL = ["<PAD>","<EOS>","<BOS>","VALUE"]
TABLE_TOKENS = [f"T{i}" for i in range(10)] # Max 10 tables per query
COLUMN_TOKENS = [f"C{i}" for i in range(50)] # Max 50 columns per query
DECODER_VOCAB = SPECIAL + SQL_KEYWORDS + ["COUNT", "AVG", "SUM", "MAX", "MIN", "(", ")", ",", "=", ">", "<"] + TABLE_TOKENS + COLUMN_TOKENS

token_to_id = {tok: i for i, tok in enumerate(DECODER_VOCAB)}
id_to_token = {i: tok for tok, i in token_to_id.items()}
VOCAB_SIZE = len(DECODER_VOCAB)

print(f"‚úÖ Encoder Vocab: {enc_vocab_size} | Decoder Vocab: {VOCAB_SIZE}")

# ============================================================
# 3Ô∏è‚É£ SCHEMA MAPPING & DATASET
# ============================================================
def build_schema_maps(schema):
    table_map = {table: f"T{i}" for i, table in enumerate(schema.keys())}
    column_map = {}
    col_id = 0
    for table, cols in schema.items():
        for col in cols:
            column_map[f"{table}.{col}"] = f"C{col_id}"
            column_map[col] = f"C{col_id}"
            col_id += 1
    return table_map, column_map

def sql_to_placeholder(sql, table_map, column_map):
    sql_out = sql.lower()
    for full_col, cid in sorted(column_map.items(), key=lambda x: -len(x[0])):
        sql_out = re.sub(rf"\b{re.escape(full_col.lower())}\b", cid, sql_out)
    for table, tid in table_map.items():
        sql_out = re.sub(rf"\b{re.escape(table.lower())}\b", tid, sql_out)
    sql_out = re.sub(r"'\w+'|\b\d+\b", "VALUE", sql_out)
    return sql_out.upper()

class SchemaAwareDataset(Dataset):
    def __init__(self, data, enc_vocab):
        self.data = data
        self.enc_vocab = enc_vocab

    def __len__(self): return len(self.data)

    def __getitem__(self, idx):
        ex = self.data[idx]
        t_map, c_map = build_schema_maps(ex["schema"])

        # Encoder Input
        q_toks = ex["question"].lower().split()
        s_toks, s_types = [], []
        for t, cols in ex["schema"].items():
            s_toks.append(t.split('_')[0]); s_types.append(1)
            for c in cols:
                s_toks.append(c); s_types.append(2)

        full_tokens = q_toks + ["<SEP>"] + s_toks
        ids = [self.enc_vocab.get(t, 1) for t in full_tokens]
        types = [0]*(len(q_toks)+1) + s_types

        # Decoder Target
        sql_ph = sql_to_placeholder(ex["sql"], t_map, c_map)
        tgt = [token_to_id["<BOS>"]] + [token_to_id.get(t, token_to_id["VALUE"]) for t in sql_ph.split()] + [token_to_id["<EOS>"]]

        return torch.tensor(ids), torch.tensor(types), torch.tensor(tgt)

def collate_fn(batch):
    ids, types, tgts = zip(*batch)
    p_ids = nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=0)
    p_types = nn.utils.rnn.pad_sequence(types, batch_first=True, padding_value=0)
    p_tgts = nn.utils.rnn.pad_sequence(tgts, batch_first=True, padding_value=0)
    return p_ids, p_types, p_tgts

# ============================================================
# 4Ô∏è‚É£ MODEL DEFINITIONS
# ============================================================
class SchemaAwareEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.typ_emb = nn.Embedding(3, d_model)
        layer = nn.TransformerEncoderLayer(d_model, num_heads, d_ff, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x, types, mask=None):
        x = self.tok_emb(x) + self.typ_emb(types)
        return self.encoder(x, src_key_padding_mask=(mask==0)), None

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerDecoderLayer(d_model, num_heads, d_ff, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, enc_out, tgt_mask=None, memory_mask=None):
        x = self.emb(tgt)
        x = self.decoder(x, enc_out, tgt_mask=tgt_mask, memory_key_padding_mask=(memory_mask==0))
        return self.fc(x)

# ============================================================
# 5Ô∏è‚É£ INITIALIZATION & TRAINING
# ============================================================
encoder = SchemaAwareEncoder(enc_vocab_size, 256, 8, 1024, 4).to(device)
decoder = TransformerDecoder(4, 256, 8, 1024, VOCAB_SIZE).to(device)

optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)

dataset = SchemaAwareDataset(raw_data, encoder_vocab)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

print("üöÄ Starting training...")
for epoch in range(5):
    encoder.train(); decoder.train()
    epoch_loss = 0
    for ids, types, tgts in loader:
        ids, types, tgts = ids.to(device), types.to(device), tgts.to(device)
        optimizer.zero_grad()

        mask = (ids != 0)
        enc_out, _ = encoder(ids, types, mask)

        dec_input = tgts[:, :-1]
        dec_target = tgts[:, 1:]

        sz = dec_input.size(1)
        tgt_mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()

        logits = decoder(dec_input, enc_out, tgt_mask=tgt_mask, memory_mask=mask)
        loss = criterion(logits.reshape(-1, VOCAB_SIZE), dec_target.reshape(-1))

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {epoch_loss/len(loader):.4f}")

IndexError: list index out of range

In [None]:
def predict_sql_robust(question, schema, encoder, decoder, encoder_vocab, device):
    encoder.eval()
    decoder.eval()

    # 1. Build Mapping for this specific inference instance
    t_map, c_map = build_schema_maps(schema)
    rev_t_map = {v: k for k, v in t_map.items()}
    rev_c_map = {v: k for k, v in c_map.items()}

    # 2. Encode Input
    q_toks = question.lower().split()
    s_toks, s_types = [], []
    for t, cols in schema.items():
        s_toks.append(t.split('_')[0]); s_types.append(1)
        for c in cols:
            s_toks.append(c); s_types.append(2)

    full_tokens = q_toks + ["<SEP>"] + s_toks
    ids = torch.tensor([encoder_vocab.get(t, 1) for t in full_tokens]).unsqueeze(0).to(device)
    types = torch.tensor([0]*(len(q_toks)+1) + s_types).unsqueeze(0).to(device)

    # 3. Greedy Decoding with PICARD-style constraints
    with torch.no_grad():
        enc_out, _ = encoder(ids, types, mask=(ids != 0))
        tgt_indices = [token_to_id["<BOS>"]]
        generated_tokens = ["<BOS>"]

        for _ in range(30):
            tgt_tensor = torch.tensor([tgt_indices]).to(device)
            sz = tgt_tensor.size(1)
            tgt_mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()

            output = decoder(tgt_tensor, enc_out, tgt_mask=tgt_mask, memory_mask=(ids != 0))
            logits = output[0, -1, :]

            # --- PICARD FILTERING ---
            # We force the model to only pick columns that exist in the current schema
            allowed_placeholders = list(rev_t_map.keys()) + list(rev_c_map.keys()) + SQL_KEYWORDS + SPECIAL + PUNCT + AGG_FUNCS
            allowed_ids = [token_to_id[tok] for tok in allowed_placeholders if tok in token_to_id]

            mask_logits = torch.full_like(logits, float("-inf"))
            mask_logits[allowed_ids] = logits[allowed_ids]

            next_id = mask_logits.argmax().item()
            next_tok = id_to_token[next_id]

            if next_tok == "<EOS>": break
            tgt_indices.append(next_id)
            generated_tokens.append(next_tok)

    # 4. Final De-Mapping with Boundry Protection
    placeholder_sql = " ".join(generated_tokens[1:])
    final_sql = placeholder_sql

    # Sort keys by length descending (C10 before C1) to prevent partial replacement
    for cid in sorted(rev_c_map.keys(), key=lambda x: int(x[1:]), reverse=True):
        final_sql = re.sub(rf"\b{cid}\b", rev_c_map[cid], final_sql)
    for tid in sorted(rev_t_map.keys(), key=lambda x: int(x[1:]), reverse=True):
        final_sql = re.sub(rf"\b{tid}\b", rev_t_map[tid], final_sql)

    return final_sql

# ============================================================
# RUN TEST
# ============================================================
sample_schema = {
    "employees": ["id", "name", "salary", "dept_id"],
    "departments": ["id", "name"]
}
# The model will now correctly link "names" to "employees.name"
# because it was trained to look at the T0 table first.
print("Result:", predict_sql_robust("show salary of employees", sample_schema, encoder, decoder, encoder_vocab, device))

Result: SELECT departments.name FROM employees WHERE salary
