# **TEXT2SQL USING LLAMA 3.1 8B**

In [1]:
!pip install transformers sentence-transformers faiss-cpu accelerate numpy tqdm groq

Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Collecting groq
  Downloading groq-0.37.0-py3-none-any.whl.metadata (16 kB)
Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m108.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading groq-0.37.0-py3-none-any.whl (137 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu, groq
Successfully installed faiss-cpu-1.13.0 groq-0.37.0


In [2]:
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
from tqdm import tqdm

# Embeddings / LLM
from sentence_transformers import SentenceTransformer
from groq import Groq # Import Groq

# Vector DB
import faiss  # pip install faiss-cpu

# Neo4j
#from neo4j import GraphDatabase  # pip install neo4j

In [3]:
!pwd

/content


In [4]:
# ============================================================
# CONFIG
# ============================================================

TABLES_JSON_PATH = "./tables.json"  # path to Spider/SParC-style tables.json
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "password"

EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Flan-T5 model fine-tuned on Spider / SParC / CoSQL (SQL task)
# We'll reuse it with a TRC-style prompt.
FLAN_T5_MODEL_NAME = "alpecevit/flan-t5-base-text2sql"

In [5]:
# ============================================================
# SCHEMA DATA CLASSES
# ============================================================

@dataclass
class ColumnSchema:
    col_id: int
    table_id: int
    name: str
    orig_name: str
    col_type: str

@dataclass
class TableSchema:
    table_id: int
    name: str
    orig_name: str
    columns: List[ColumnSchema] = field(default_factory=list)
    primary_keys: List[int] = field(default_factory=list)      # list of col_ids
    foreign_keys: List[Tuple[int, int]] = field(default_factory=list)  # (src_col_id, dst_col_id)

@dataclass
class DatabaseSchema:
    db_id: str
    tables: Dict[int, TableSchema] = field(default_factory=dict)        # table_id -> TableSchema
    columns: Dict[int, ColumnSchema] = field(default_factory=dict)      # col_id   -> ColumnSchema

In [6]:
# ============================================================
# STEP 1: LOAD SCHEMA FROM SPARC/SPIDER TABLES.JSON
# ============================================================

def load_spider_like_tables_json(path: str) -> Dict[str, DatabaseSchema]:
    """
    Load Spider/SParC 'tables.json' into a structured dict of DatabaseSchema.

    tables.json format (per Spider/SParC docs):
      - db_id
      - table_names, table_names_original
      - column_names, column_names_original
      - column_types
      - primary_keys  (list of column indices)
      - foreign_keys  (list of [src_col_idx, dst_col_idx])
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    db_schemas: Dict[str, DatabaseSchema] = {}

    for db_meta in raw:
        db_id = db_meta["db_id"]
        table_names = db_meta["table_names"]
        table_names_original = db_meta["table_names_original"]
        column_names = db_meta["column_names"]
        column_names_original = db_meta["column_names_original"]
        column_types = db_meta["column_types"]
        primary_keys = db_meta["primary_keys"]
        foreign_keys = db_meta["foreign_keys"]

        db_schema = DatabaseSchema(db_id=db_id)

        # Create table schemas
        for table_id, (name, orig_name) in enumerate(zip(table_names, table_names_original)):
            db_schema.tables[table_id] = TableSchema(
                table_id=table_id,
                name=name,
                orig_name=orig_name,
            )

        # Create column schemas
        for col_id, ((table_id, col_name), (_, col_orig_name), col_type) in enumerate(
            zip(column_names, column_names_original, column_types)
        ):
            if table_id == -1:
                # This is the "dummy" column (for * in some preprocessing setups).
                continue
            col_schema = ColumnSchema(
                col_id=col_id,
                table_id=table_id,
                name=col_name,
                orig_name=col_orig_name,
                col_type=col_type,
            )
            db_schema.columns[col_id] = col_schema
            db_schema.tables[table_id].columns.append(col_schema)

        # Attach primary keys
        for pk_col_id in primary_keys:
            if pk_col_id in db_schema.columns:
                col = db_schema.columns[pk_col_id]
                db_schema.tables[col.table_id].primary_keys.append(pk_col_id)

        # Attach foreign keys
        for src_col_id, dst_col_id in foreign_keys:
            if src_col_id in db_schema.columns and dst_col_id in db_schema.columns:
                src_col = db_schema.columns[src_col_id]
                db_schema.tables[src_col.table_id].foreign_keys.append((src_col_id, dst_col_id))

        db_schemas[db_id] = db_schema

    return db_schemas


In [7]:
# ============================================================
# STEP 2: BUILD NEO4J GRAPH (TABLE, COLUMN, EDGES)
# ============================================================
'''
class SchemaGraphBuilder:
    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def clear_database(self):
        with self.driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")

    def build_graph(self, db_schemas: Dict[str, DatabaseSchema]):
        """
        Build:
          (:TABLE {db_id, table_id, name})
          (:COLUMN {db_id, col_id, table_id, name, col_type})
          (:TABLE)-[:HAS_COLUMN]->(:COLUMN)
          (:COLUMN)-[:PRIMARY_KEY_OF]->(:TABLE)
          (:COLUMN)-[:FOREIGN_KEY_TO]->(:COLUMN)
          (:TABLE)-[:JOINS_WITH]->(:TABLE)
        """
        with self.driver.session() as session:
            for db_id, db_schema in tqdm(db_schemas.items(), desc="Building Neo4j schema graph"):
                # TABLE and COLUMN nodes + HAS_COLUMN
                for table in db_schema.tables.values():
                    session.run(
                        """
                        MERGE (t:TABLE {db_id:$db_id, table_id:$table_id})
                        SET t.name = $name, t.orig_name = $orig_name
                        """,
                        db_id=db_id,
                        table_id=table.table_id,
                        name=table.name,
                        orig_name=table.orig_name,
                    )
                    for col in table.columns:
                        session.run(
                            """
                            MERGE (c:COLUMN {db_id:$db_id, col_id:$col_id})
                            SET c.name = $name,
                                c.orig_name = $orig_name,
                                c.col_type = $col_type,
                                c.table_id = $table_id
                            WITH c
                            MATCH (t:TABLE {db_id:$db_id, table_id:$table_id})
                            MERGE (t)-[:HAS_COLUMN]->(c)
                            """,
                            db_id=db_id,
                            table_id=table.table_id,
                            col_id=col.col_id,
                            name=col.name,
                            orig_name=col.orig_name,
                            col_type=col.col_type,
                        )

                # PRIMARY_KEY_OF edges
                for table in db_schema.tables.values():
                    for pk_col_id in table.primary_keys:
                        if pk_col_id not in db_schema.columns:
                            continue
                        session.run(
                            """
                            MATCH (c:COLUMN {db_id:$db_id, col_id:$col_id}),
                                  (t:TABLE {db_id:$db_id, table_id:$table_id})
                            MERGE (c)-[:PRIMARY_KEY_OF]->(t)
                            """,
                            db_id=db_id,
                            col_id=pk_col_id,
                            table_id=table.table_id,
                        )

                # FOREIGN_KEY_TO and JOINS_WITH edges
                for table in db_schema.tables.values():
                    for src_col_id, dst_col_id in table.foreign_keys:
                        if src_col_id not in db_schema.columns or dst_col_id not in db_schema.columns:
                            continue
                        src_col = db_schema.columns[src_col_id]
                        dst_col = db_schema.columns[dst_col_id]
                        src_table_id = src_col.table_id
                        dst_table_id = dst_col.table_id
                        session.run(
                            """
                            MATCH (src:COLUMN {db_id:$db_id, col_id:$src_col_id}),
                                  (dst:COLUMN {db_id:$db_id, col_id:$dst_col_id})
                            MERGE (src)-[:FOREIGN_KEY_TO]->(dst)
                            """,
                            db_id=db_id,
                            src_col_id=src_col_id,
                            dst_col_id=dst_col_id,
                        )
                        # JOINS_WITH in both directions
                        session.run(
                            """
                            MATCH (t1:TABLE {db_id:$db_id, table_id:$t1_id}),
                                  (t2:TABLE {db_id:$db_id, table_id:$t2_id})
                            MERGE (t1)-[:JOINS_WITH]->(t2)
                            MERGE (t2)-[:JOINS_WITH]->(t1)
                            """,
                            db_id=db_id,
                            t1_id=src_table_id,
                            t2_id=dst_table_id,
                        )
'''

'\nclass SchemaGraphBuilder:\n    def __init__(self, uri: str, user: str, password: str):\n        self.driver = GraphDatabase.driver(uri, auth=(user, password))\n\n    def close(self):\n        self.driver.close()\n\n    def clear_database(self):\n        with self.driver.session() as session:\n            session.run("MATCH (n) DETACH DELETE n")\n\n    def build_graph(self, db_schemas: Dict[str, DatabaseSchema]):\n        """\n        Build:\n          (:TABLE {db_id, table_id, name})\n          (:COLUMN {db_id, col_id, table_id, name, col_type})\n          (:TABLE)-[:HAS_COLUMN]->(:COLUMN)\n          (:COLUMN)-[:PRIMARY_KEY_OF]->(:TABLE)\n          (:COLUMN)-[:FOREIGN_KEY_TO]->(:COLUMN)\n          (:TABLE)-[:JOINS_WITH]->(:TABLE)\n        """\n        with self.driver.session() as session:\n            for db_id, db_schema in tqdm(db_schemas.items(), desc="Building Neo4j schema graph"):\n                # TABLE and COLUMN nodes + HAS_COLUMN\n                for table in db_schema.ta

In [8]:
# ============================================================
# STEP 3: BUILD VECTOR INDEX OVER SCHEMA TEXT
# ============================================================

@dataclass
class SchemaNodeMeta:
    node_type: str  # "TABLE" or "COLUMN"
    db_id: str
    table_id: Optional[int] = None
    col_id: Optional[int] = None
    table_name: Optional[str] = None
    column_name: Optional[str] = None

class SchemaVectorIndex:
    """
    Global vector index over:
        - table-level texts (TABLE nodes)
        - column-level texts (COLUMN nodes)

    For query-time, we can filter hits by db_id if needed.
    """

    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        self.model = SentenceTransformer(embedding_model_name)
        self.index: Optional[faiss.IndexFlatIP] = None
        self.metas: List[SchemaNodeMeta] = []
        self.embeddings: Optional[np.ndarray] = None

    def _build_node_texts(self, db_schemas: Dict[str, DatabaseSchema]) -> Tuple[List[str], List[SchemaNodeMeta]]:
        texts: List[str] = []
        metas: List[SchemaNodeMeta] = []

        for db_id, db_schema in db_schemas.items():
            for table in db_schema.tables.values():
                # TABLE text representation
                table_text = f"table {table.name}"
                texts.append(table_text)
                metas.append(
                    SchemaNodeMeta(
                        node_type="TABLE",
                        db_id=db_id,
                        table_id=table.table_id,
                        table_name=table.name,
                    )
                )

                # COLUMN text representation
                for col in table.columns:
                    col_text = f"column {table.name}.{col.name} of type {col.col_type}"
                    texts.append(col_text)
                    metas.append(
                        SchemaNodeMeta(
                            node_type="COLUMN",
                            db_id=db_id,
                            table_id=table.table_id,
                            col_id=col.col_id,
                            table_name=table.name,
                            column_name=col.name,
                        )
                    )

        return texts, metas

    def build_from_schemas(self, db_schemas: Dict[str, DatabaseSchema]):
        texts, metas = self._build_node_texts(db_schemas)
        print(f"Building embeddings for {len(texts)} schema nodes...")
        emb = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        # Normalize embeddings for cosine similarity via inner product
        faiss.normalize_L2(emb)
        dim = emb.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(emb.astype(np.float32))

        self.index = index
        self.embeddings = emb
        self.metas = metas

    def query(
        self,
        query_text: str,
        top_k: int = 20,
        db_id: Optional[str] = None,
        search_k: int = 100,
    ) -> List[Tuple[SchemaNodeMeta, float]]:
        """
        Perform semantic search over schema nodes.
        If db_id is given, we filter hits to that db.
        """
        if self.index is None:
            raise RuntimeError("Index not built. Call build_from_schemas() first.")

        q_emb = self.model.encode([query_text], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        # Search more than top_k so we can filter by db_id
        search_k = max(search_k, top_k)
        D, I = self.index.search(q_emb.astype(np.float32), search_k)
        scores = D[0]
        indices = I[0]

        results: List[Tuple[SchemaNodeMeta, float]] = []
        for idx, score in zip(indices, scores):
            if idx < 0:
                continue
            meta = self.metas[idx]
            if db_id is not None and meta.db_id != db_id:
                continue
            results.append((meta, float(score)))
            if len(results) >= top_k:
                break

        return results

In [9]:
def build_schema_context(
    db_schema: DatabaseSchema,
    matches: List[Tuple[SchemaNodeMeta, float]],
    max_tables: int = 4,
) -> str:
    """
    Build textual schema context for LLM from:
      - relevant tables/columns (from vector hits)
      - PK/FK info
      - join relations

    Uses ONLY canonical SQL names via orig_name.
    This ensures TRC generation is safe, consistent, and compiler-friendly.
    """

    # 1) Collect relevant table IDs
    table_ids = []
    for meta, score in matches:
        if meta.table_id is None:
            continue
        if meta.table_id not in table_ids:
            table_ids.append(meta.table_id)
        if len(table_ids) >= max_tables:
            break

    lines = []
    lines.append("SCHEMA CONTEXT START")
    lines.append(f"Database: {db_schema.db_id}")
    lines.append("")

    # 2) Build canonical schema context
    for table_id in table_ids:
        table = db_schema.tables[table_id]

        # Use canonical name
        table_name = table.orig_name
        lines.append(f"Table: {table_name}")

        # Columns (canonical names)
        col_parts = []
        for col in table.columns:
            col_parts.append(f"{col.orig_name} ({col.col_type})")
        lines.append("  Columns: " + ", ".join(col_parts))

        # Primary keys
        pk_cols = [db_schema.columns[pk_id].orig_name for pk_id in table.primary_keys]
        if pk_cols:
            lines.append("  Primary keys: " + ", ".join(pk_cols))

        # Foreign keys (canonical)
        fk_descs = []
        for src_col_id, dst_col_id in table.foreign_keys:
            src_col = db_schema.columns[src_col_id]
            dst_col = db_schema.columns[dst_col_id]
            src_table = db_schema.tables[src_col.table_id]
            dst_table = db_schema.tables[dst_col.table_id]

            fk_descs.append(
            f"[Foreign Key REFERENCES: {dst_table.orig_name}.{dst_col.orig_name}]"
            )

        if fk_descs:
            lines.append("  Foreign keys:")
            for fk in fk_descs:
                lines.append(f"    - {fk}")

        # Join partners (canonical)
        join_partners = set()

        # Outgoing FKs
        for src_col_id, dst_col_id in table.foreign_keys:
            dst_col = db_schema.columns[dst_col_id]
            dst_table = db_schema.tables[dst_col.table_id]
            join_partners.add(dst_table.orig_name)

        # Incoming FKs
        for other_table in db_schema.tables.values():
            if other_table.table_id == table.table_id:
                continue

            for src_col_id, dst_col_id in other_table.foreign_keys:
                dst_col = db_schema.columns[dst_col_id]
                dst_table = db_schema.tables[dst_col.table_id]

                if dst_table.table_id == table.table_id:
                    join_partners.add(other_table.orig_name)

        lines.append("")

    lines.append("SCHEMA CONTEXT END")
    return "\n".join(lines)

In [44]:
import os
import re
from google.colab import userdata
from groq import Groq

class TextToTRCGenerator:
    def __init__(self, model_name: str = "llama-3.1-8b-instant"):
        self.model_name = model_name

        groq_api_key = userdata.get('GROQ_API_KEY3')
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY not found in Colab secrets. Please add it.")

        self.client = Groq(api_key=groq_api_key)


    # ============================================================
    # Few Shot Examples (unchanged)
    # ============================================================
    @staticmethod
    def few_shot_examples() -> str:
        examples = []

        examples.append(
        """Example 1

    == Vector Matches in Schema ==
    VECTOR MATCHES BASED ON SEMANTICS:
    - table: stadium (db_id: concert_singer)
    - column: capacity (table: stadium, db_id: concert_singer)
    - column: highest (table: stadium, db_id: concert_singer)
    - column: lowest (table: stadium, db_id: concert_singer)

    SCHEMA CONTEXT START
    Database: concert_singer

    Table: stadium
      Columns: stadium_id (number), name (text), capacity (number), highest (number), lowest (number), average (number)
      Primary keys: stadium_id

    Table: singer
      Columns: singer_id (number), name (text), country (text), song_name (text), song_release_year (number), age (number), is_male (others)
      Primary keys: singer_id

    SCHEMA CONTEXT END

    Question:
    "List all stadium names with highest capacity above 35000 and lowest capacity below 8000."

    TRC:
    { stadium.name | stadium AND stadium.highest > 35000 AND stadium.lowest < 8000 }
    """
        )

        examples.append(
    """Example 2

== Vector Matches in Schema ==
VECTOR MATCHES BASED ON SEMANTICS:
- table: singer (db_id: concert_singer)
- column: name (table: singer, db_id: concert_singer)
- table: stadium (db_id: concert_singer)
- column: name (table: stadium, db_id: concert_singer)

SCHEMA CONTEXT START
Database: concert_singer

Table: singer
  Columns: singer_id (number), name (text), country (text), song_name (text), song_release_year (number), age (number), is_male (others)
  Primary keys: singer_id

Table: stadium
  Columns: stadium_id (number), name (text), capacity (number), highest (number), lowest (number), average (number)
  Primary keys: stadium_id

SCHEMA CONTEXT END

Question:
"List the names of singers and the names of stadiums."

TRC:
{ singer.name, stadium.name | singer AND stadium }
"""
        )

        return "\n".join(examples)


    # ============================================================
    # VERY STRICT PROMPT BUILDER
    # ============================================================
    def build_prompt(self, schema_context: str, question: str) -> str:
        """
        Construct an extremely strict prompt specifically designed
        to avoid hallucinations and produce ONLY canonical TRC.
        """
        prompt_parts = []

        # System-level safety rules
        prompt_parts.append(
            "You are a deterministic database logic engine that outputs ONLY Tuple Relational Calculus (TRC).\n"
            "IMPORTANT RULES FOR SQL GENERATION:\n"
            "1. The NL query decides which tables are needed — NOT the schema context.\n "
            "2. Use ONLY the tables that are logically required to answer the question.\n"
            "3. If all needed information comes from a single table, produce a single-table SQL query.\n"
            "4. Do NOT add joins unless the question requires linking information across tables.\n"
            "5. If multiple tables appear in the TRC but no linking condition is required, treat it as independent filtering, not a join.\n"
            "6. Use foreign keys ONLY when the question explicitly asks to connect entities.\n"
            "7. NEVER invent columns or tables not present in the schema.\n"
            "8. ABSOLUTELY NO explanations, English sentences, or commentary.\n"
            "\n"
        )

        # Provide few-shot examples
        prompt_parts.append("[FEW-SHOT EXAMPLES]")
        prompt_parts.append(self.few_shot_examples())

        # Provide schema context
        prompt_parts.append("\n[SCHEMA CONTEXT]")
        prompt_parts.append(schema_context)

        # Provide question
        prompt_parts.append(f"\n[QUESTION]\n{question}\n")

        return "\n".join(prompt_parts)


    # ============================================================
    # CORE LLM CALL
    # ============================================================
    def generate_trc(self, schema_context: str, question: str, max_new_tokens: int = 512) -> str:
        prompt = self.build_prompt(schema_context, question)

        print("\nComplete INPUT PROMPT\n\n",prompt)

        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You must output ONLY raw TRC. No other text allowed."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,      # deterministic
            max_tokens=max_new_tokens,
            top_p=1,
            stream=False
        )

        return completion.choices[0].message.content.strip()

In [45]:
# ============================================================
# STEP 6: GLUE IT TOGETHER – END-TO-END CALL
# ============================================================

class Text2TRCPipeline:
    def __init__( self, tables_json_path: str = TABLES_JSON_PATH, neo4j_uri: str = NEO4J_URI, neo4j_user: str = NEO4J_USER, neo4j_password: str = NEO4J_PASSWORD):
        """
        Loads Spider-style schema, builds vector index, and initializes TRC generator.
        No Neo4j functionality remains.
        """

        # 1) Load Spider/SParC schema JSON
        print("Loading Spider/SParC-style tables.json...")
        self.db_schemas = load_spider_like_tables_json(tables_json_path)

        # 2) SKIP Neo4j schema graph
        # (All Neo4j functions safely removed)

        # 3) Build vector index
        print("Building schema vector index (embeddings)...")
        self.vector_index = SchemaVectorIndex(EMBEDDING_MODEL_NAME)
        self.vector_index.build_from_schemas(self.db_schemas)

        # 4) LLM for Text → TRC conversion
        print("Initializing Groq client for TRC prompting...")
        self.trc_generator = TextToTRCGenerator("llama-3.1-8b-instant")

    def close(self):
        """No Neo4j, nothing to close."""
        pass

    # --------------------------------------------------------
    # Helper: Convert Spider tables.json schema → TRC-SQL schema dict
    # --------------------------------------------------------
    def convert_spider_schema(self, db_schema):
      """
      Convert DatabaseSchema object into the format needed by
      the TRC→SQL compiler:

          {
              "table_name": ["col1", "col2", ...],
              ...
          }
      """

      schema_dict = {}

      for table_id, table_obj in db_schema.tables.items():
          table_name = table_obj.orig_name.lower()   # ✔ snake_case canonical name

          column_names = [col.orig_name.lower() for col in table_obj.columns]

          schema_dict[table_name] = column_names

      return schema_dict

    def convert_spider_fk_map(self, db_schema):
      fk_map = []
      for table in db_schema.tables.values():
          for (src_col_id, dst_col_id) in table.foreign_keys:
              src_col = db_schema.columns[src_col_id]
              dst_col = db_schema.columns[dst_col_id]

              src_table = db_schema.tables[src_col.table_id]
              dst_table = db_schema.tables[dst_col.table_id]

              fk_map.append((
                  src_table.orig_name.lower(), src_col.orig_name.lower(),
                  dst_table.orig_name.lower(), dst_col.orig_name.lower()
              ))
      return fk_map

    # --------------------------------------------------------
    # MAIN INFERENCE
    # --------------------------------------------------------
    def infer_trc_for_query(
        self,
        question: str,
        db_id: str,
        top_k_schema_nodes: int = 20,
        max_tables_in_context: int = 4,
    ):
        """
        Main function: Text → TRC

        Returns:
            {
                "schema_matches": [...],
                "schema_context": "...",
                "trc": "...",
                "schema_for_sql_compiler": {...}   # ADDED
            }
        """

        # Validate db_id
        if db_id not in self.db_schemas:
            raise ValueError(f"Unknown db_id: {db_id}")

        db_schema = self.db_schemas[db_id]

        # ----------------------------------------------------
        # 1) Semantic retrieval over schema (vector search)
        # ----------------------------------------------------
        matches = self.vector_index.query(
            query_text=question,
            top_k=top_k_schema_nodes,
            db_id=db_id,
        )

        # ----------------------------------------------------
        # 2) Build schema context for the LLM prompt
        # ----------------------------------------------------
        schema_context = build_schema_context(
            db_schema=db_schema,
            matches=matches,
            max_tables=max_tables_in_context,
        )

        # Create a compact match summary for prompt
        match_summary_for_prompt = []
        for meta, score in matches:
            if meta.node_type == "TABLE":
                match_summary_for_prompt.append(
                    f"- table: {meta.table_name} (db_id: {meta.db_id})"
                )
            else:
                match_summary_for_prompt.append(
                    f"- column: {meta.column_name} (table: {meta.table_name}, db_id: {meta.db_id})"
                )

        vector_match_text = "VECTOR MATCHES BASED ON SEMANTICS:\n" + "\n".join(match_summary_for_prompt)

        # Full prompt context
        full_context = vector_match_text + "\n\n" +schema_context

        # ----------------------------------------------------
        # 3) LLM generates the TRC
        # ----------------------------------------------------
        trc = self.trc_generator.generate_trc(
            schema_context=full_context,
            question=question
        )

        # ----------------------------------------------------
        # 4) Build debugging metadata
        # ----------------------------------------------------
        match_summary = [
            {
                "node_type": meta.node_type,
                "db_id": meta.db_id,
                "table_id": meta.table_id,
                "col_id": meta.col_id,
                "table_name": meta.table_name,
                "column_name": meta.column_name,
                "score": score,
            }
            for meta, score in matches
        ]

        # ----------------------------------------------------
        # 5) Convert schema for TRC→SQL compiler
        # ----------------------------------------------------
        converted_schema = self.convert_spider_schema(db_schema)

        # ----------------------------------------------------
        # 6) FK Map
        # ----------------------------------------------------
        fk_map = self.convert_spider_fk_map(db_schema)
        # ----------------------------------------------------
        # 7) Final return
        # ----------------------------------------------------
        return {
            "schema_matches": match_summary,
            "schema_context": schema_context,
            "trc": trc,
            "schema_for_sql_compiler": converted_schema,  # <--- ADDED
            "vector_match_text": vector_match_text,
            "fk_map": fk_map,
            "full_context": full_context
        }

In [46]:
# ============================================================
# EXAMPLE USAGE
# ============================================================
from groq import Groq

trc2sql_client = Groq(api_key=userdata.get('GROQ_API_KEY3'))

"""
    Example: run a single Text->TRC inference on a SParC/Spider db.
    """

pipeline = Text2TRCPipeline(
        tables_json_path=TABLES_JSON_PATH,
    )


Loading Spider/SParC-style tables.json...
Building schema vector index (embeddings)...
Building embeddings for 5379 schema nodes...


Batches:   0%|          | 0/169 [00:00<?, ?it/s]

Initializing Groq client for TRC prompting...


In [48]:
try:
        # Example NL question and DB
        example_db_id = "flight_1"
        example_question = "Show the flight number and distance of all the flights."

        # --- 1. Run TEXT → TRC ---
        result = pipeline.infer_trc_for_query(
            question=example_question,
            db_id=example_db_id,
            top_k_schema_nodes=20,
            max_tables_in_context=4,
        )
        """
        print("\n=== Vector Matches in Schema ===")
        print(result["vector_match_text"])


        print("\n=== SCHEMA CONTEXT FED TO LLM ===")
        print(result["schema_context"])

        print("output")
        print(result["full_context"])
        """
        print("\n=== GENERATED TRC ===")
        print(result["trc"])


        # ============================================================
        # --- 2. Run TRC → SQL using Groq LLaMA-3.1-8B (NEW)
        # ============================================================
        system_prompt = """
You are a strict TRC-to-SQL compiler.
RULES:
1. Follow the TRC properly and answer with the right SQL. The TRC determines the SQL not the schema.
2. DO NOT invent tables or columns.
3.Use ONLY the tables necessary to express the meaning of the question.
4.If the question can be answered from a single table, use that table alone.
5.Do not include unrelated tables, even if they appear in the schema context.
6. Output ONLY valid SQL, no explanations.
"""

        user_prompt = f"""
Vector Matches in SCHEMA CONTEXT:
{result['vector_match_text']}

Below are examples showing how TRC must be converted into valid SQL.

        Example 1
        Schema:
          table stadium(stadium_id, name, capacity, highest, lowest, average)
          table singer(singer_id, name, country)
          table concert(concert_id, stadium_id, concert_name)

        TRC:
        {{ stadium.name | stadium AND stadium.highest > 35000 AND stadium.lowest < 8000 }}

        SQL OUTPUT:
        SELECT name
        FROM stadium
        WHERE highest > 35000
          AND lowest < 8000;

        EXAMPLE 2:
        Schema:
          table stadium(stadium_id, location, name, capacity, highest, lowest, average)
          table singer(singer_id, name, country, song_name, song_release_year, age, is_male)
          table concert(concert_id, concert_name, theme, stadium_id, year)
          table singer_in_concert(concert_id, singer_id)

        TRC:
          {{ name | singer
                  AND singer_in_concert
                  AND singer_in_concert.singer_id = singer.singer_id
                  AND singer_in_concert.concert_id = concert.concert_id
                  AND concert.stadium_id = stadium.stadium_id
                  AND stadium.capacity > 50000 }}

        SQL OUTPUT:
          SELECT singer.name
          FROM singer
          JOIN singer_in_concert
              ON singer_in_concert.singer_id = singer.singer_id
          JOIN concert
              ON singer_in_concert.concert_id = concert.concert_id
          JOIN stadium
              ON concert.stadium_id = stadium.stadium_id
          WHERE stadium.capacity > 50000;


        Now convert the following TRC (for the schema above) into SQL.

TRC:
{result['trc']}

Translate the TRC expression above into executable SQL.
Output ONLY the SQL query.
"""

        try:
            groq_response = trc2sql_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0,
                max_tokens=300,
            )
            predicted_sql = groq_response.choices[0].message.content.strip()

        except Exception as e:
            predicted_sql = f"-- GROQ SQL GENERATION ERROR: {str(e)}"

        print("\n=== GENERATED SQL (Groq LLaMA 3.1 8B) ===")
        print(predicted_sql)
        # ============================================================

finally:
        pipeline.close()



Complete INPUT PROMPT

 You are a deterministic database logic engine that outputs ONLY Tuple Relational Calculus (TRC).
IMPORTANT RULES FOR SQL GENERATION:
1. The NL query decides which tables are needed — NOT the schema context.
 2. Use ONLY the tables that are logically required to answer the question.
3. If all needed information comes from a single table, produce a single-table SQL query.
4. Do NOT add joins unless the question requires linking information across tables.
5. If multiple tables appear in the TRC but no linking condition is required, treat it as independent filtering, not a join.
6. Use foreign keys ONLY when the question explicitly asks to connect entities.
7. NEVER invent columns or tables not present in the schema.
8. ABSOLUTELY NO explanations, English sentences, or commentary.


[FEW-SHOT EXAMPLES]
Example 1

    == Vector Matches in Schema ==
    VECTOR MATCHES BASED ON SEMANTICS:
    - table: stadium (db_id: concert_singer)
    - column: capacity (table: stad

# **Evaluating SPARC dataset with dialogue history and visualizing the results**

In [49]:
import json
from tqdm import tqdm

# Path to SParC train/dev/test json
SPARC_TRAIN_PATH = "/content/train.json"   # change if needed

# Output file for TRC predictions
OUTPUT_PATH = "/content/sparc_trc_predictions.json"

# ------------------------------------------------------------
# LOAD DATASET
# ------------------------------------------------------------

with open(SPARC_TRAIN_PATH, "r") as f:
    sparc_data = json.load(f)

# Take only 5 interactions
import random

print(f"Using {len(sparc_data)} interactions for evaluation.")

# ------------------------------------------------------------
# INITIALIZE PIPELINE
# ------------------------------------------------------------

pipeline = Text2TRCPipeline(
    tables_json_path=TABLES_JSON_PATH,
    neo4j_uri=NEO4J_URI,
    neo4j_user=NEO4J_USER,
    neo4j_password=NEO4J_PASSWORD,
)

print("Pipeline initialized.")


Using 3034 interactions for evaluation.
Loading Spider/SParC-style tables.json...
Building schema vector index (embeddings)...
Building embeddings for 5379 schema nodes...


Batches:   0%|          | 0/169 [00:00<?, ?it/s]

Initializing Groq client for TRC prompting...
Pipeline initialized.


In [50]:
# ------------------------------------------------------------
# FUNCTION: BUILD HISTORY CONTEXT
# ------------------------------------------------------------

def build_history_context(interaction, current_turn_idx):
    """
    Builds history of all previous user utterances for multi-turn SParC dialogue.
    Returns a formatted string:

    HISTORY:
    User: ...
    User: ...
    """
    if current_turn_idx == 0:
        return ""  # No history for first turn

    history_lines = ["HISTORY CONTEXT:"]
    for i in range(current_turn_idx):
        prev_utt = interaction["interaction"][i]["utterance"]
        history_lines.append(f"User: {prev_utt}")

    return "\n".join(history_lines) + "\n\n"


In [None]:
# ------------------------------------------------------------
# RUN EVALUATION OVER ENTIRE SParC TRAIN SET
# ------------------------------------------------------------

predictions = []  # store everything

for interaction in tqdm(sparc_data, desc="Processing SParC interactions"):

    db_id = interaction["database_id"]
    turns = interaction["interaction"]

    for turn_idx, turn in enumerate(turns):

        utterance = turn["utterance"]

        # ---- Build history ----
        history_context = build_history_context(interaction, turn_idx)

        # ---- Merge history + current question ----
        llm_input_question = history_context + utterance

        # ---- Run Text->TRC pipeline ----
        result = pipeline.infer_trc_for_query(
            question=llm_input_question,
            db_id=db_id,
            top_k_schema_nodes=20,
            max_tables_in_context=4,
        )

        # ------------------------------------------------------------
        # NEW: SQL generation using Groq LLaMA-3.1-8B
        # ------------------------------------------------------------

        system_prompt = """
        You are a strict TRC-to-SQL compiler.
        RULES:
        1. Follow the TRC properly and answer with the right SQL. The TRC determines the SQL not the schema.
        2. DO NOT invent tables or columns.
        3. Use ONLY the tables necessary to express the meaning of the question.
        4. If the question can be answered from a single table, use that table alone.
        5. Do not include unrelated tables, even if they appear in the schema context.
        6. Output ONLY valid SQL, no explanations.
        """

        user_prompt = f"""
        Vector Matches in SCHEMA CONTEXT:
        {result['vector_match_text']}

        Below are examples showing how TRC must be converted into valid SQL.

        Example 1
        Schema:
          table stadium(stadium_id, name, capacity, highest, lowest, average)
          table singer(singer_id, name, country)
          table concert(concert_id, stadium_id, concert_name)

        TRC:
        {{ stadium.name | stadium AND stadium.highest > 35000 AND stadium.lowest < 8000 }}

        SQL:
        SELECT name
        FROM stadium
        WHERE highest > 35000
          AND lowest < 8000;

        EXAMPLE 2:
        Schema:
          table stadium(stadium_id, location, name, capacity, highest, lowest, average)
          table singer(singer_id, name, country, song_name, song_release_year, age, is_male)
          table concert(concert_id, concert_name, theme, stadium_id, year)
          table singer_in_concert(concert_id, singer_id)

        TRC:
          {{ name | singer
                  AND singer_in_concert
                  AND singer_in_concert.singer_id = singer.singer_id
                  AND singer_in_concert.concert_id = concert.concert_id
                  AND concert.stadium_id = stadium.stadium_id
                  AND stadium.capacity > 50000 }}

        SQL:
          SELECT singer.name
          FROM singer
          JOIN singer_in_concert
              ON singer_in_concert.singer_id = singer.singer_id
          JOIN concert
              ON singer_in_concert.concert_id = concert.concert_id
          JOIN stadium
              ON concert.stadium_id = stadium.stadium_id
          WHERE stadium.capacity > 50000;

        Tuple Relational Calculus (TRC):
        {result['trc']}

        Now convert the following TRC (for the schema above) into SQL. Output Only SQL.
        """

        try:
            groq_response = trc2sql_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0,
                max_tokens=300,
            )
            predicted_sql = groq_response.choices[0].message.content.strip()

        except Exception as e:
            predicted_sql = f"-- GROQ SQL GENERATION ERROR: {str(e)}"


        # ---- Store result ----
        predictions.append({
            "db_id": db_id,
            "turn_index": turn_idx,
            "utterance": utterance,
            "history_used": history_context,
            "schema_matches": result["schema_matches"],
            "schema_context": result["schema_context"],
            "predicted_trc": result["trc"],
            "predicted_sql": predicted_sql,     # <-- NEW KEY ADDED
            "gold_sql": turn["query"],          # Ground truth SQL
        })

# ------------------------------------------------------------
# SAVE RESULTS
# ------------------------------------------------------------

with open(OUTPUT_PATH, "w") as f:
    json.dump(predictions, f, indent=2)

print(f"Saved {len(predictions)} predictions to {OUTPUT_PATH}")


In [52]:
with open(OUTPUT_PATH, "w") as f:
    json.dump(predictions, f, indent=2)

print(f"Saved {len(predictions)} predictions to {OUTPUT_PATH}")

Saved 46 predictions to /content/sparc_trc_predictions.json
