In [1]:
import psycopg2
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain_community.llms import Ollama
from protected import pwd_db
import os

Transform catalog to be used in a RAG context

## RAG documentation

In [2]:
import re

def clean_sql_dump(sql_dump: str) -> str:
    """
    Removes SQL-style comments, flattens newlines, and condenses extra spaces.
    """
    # Remove line-based SQL comments (e.g. -- comment)
    sql_dump = re.sub(r'--.*?$', '', sql_dump, flags=re.MULTILINE)
    # Remove /* */ block comments if present
    sql_dump = re.sub(r'/\*.*?\*/', '', sql_dump, flags=re.DOTALL)
    # Replace newlines with spaces
    sql_dump = re.sub(r'\n', ' ', sql_dump)
    # Replace multiple spaces with single space
    sql_dump = re.sub(r'\s{2,}', ' ', sql_dump)
    # Trim leading/trailing whitespace
    sql_dump = sql_dump.strip()
    return sql_dump


def extract_schema_info(sql_dump: str):
    """
    Extracts table definitions, primary keys, foreign keys, and
    a simple 'joins' mapping from a cleaned SQL dump.
    """
    # Regex patterns
    # This pattern handles optional schema name (e.g. CREATE TABLE public.actor ...)
    # and captures everything inside the parentheses up to the matching semicolon.
    table_pattern = re.compile(
        r'CREATE TABLE\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name (group 1)
        r'["]?(\w+)["]?'
        r'\s*\((.*?)\)\s*;', 
        re.IGNORECASE | re.DOTALL
    )

    pkey_pattern = re.compile(
        r'ALTER TABLE ONLY\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name
        r'["]?(\w+)["]?'
        r'\s+ADD CONSTRAINT\s+(\w+)\s+PRIMARY KEY\s*\((.*?)\);',
        re.IGNORECASE | re.DOTALL
    )

    fkey_pattern = re.compile(
        r'ALTER TABLE ONLY\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name
        r'["]?(\w+)["]?'
        r'\s+ADD CONSTRAINT\s+(\w+)\s+FOREIGN KEY\s*\((.*?)\)\s+REFERENCES\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional ref schema
        r'["]?(\w+)["]?\s*\((.*?)\)'
        r'(?:\s+ON UPDATE \w+\s+ON DELETE \w+)?;', 
        re.IGNORECASE | re.DOTALL
    )

    # Data structures to fill
    tables = {}
    primary_keys = {}
    foreign_keys = {}
    joins = {}

    # Extract tables
    for match in table_pattern.finditer(sql_dump):
        schema_name = match.group(1)  # Might be None if no schema specified
        table_name = match.group(2)
        columns_str = match.group(3)
        
        # Key for referencing the table will just be table_name
        # but you could store "schema.table" if needed:
        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        
        tables[full_table_name] = columns_str

    # Extract primary keys
    for match in pkey_pattern.finditer(sql_dump):
        schema_name = match.group(1)
        table_name = match.group(2)
        pkey_name = match.group(3)
        pkey_columns = match.group(4).strip()

        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        primary_keys[full_table_name] = {
            'pkey_name': pkey_name,
            'pkey_columns': pkey_columns
        }

    # Extract foreign keys
    for match in fkey_pattern.finditer(sql_dump):
        schema_name = match.group(1)
        table_name = match.group(2)
        fkey_name = match.group(3)
        fkey_columns = match.group(4).strip()
        ref_schema = match.group(5)
        ref_table_name = match.group(6)
        ref_columns = match.group(7).strip()

        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        full_ref_table_name = ref_table_name if not ref_schema else f"{ref_schema}.{ref_table_name}"

        if full_table_name not in foreign_keys:
            foreign_keys[full_table_name] = []

        foreign_keys[full_table_name].append({
            'fkey_name': fkey_name,
            'fkey_columns': fkey_columns,
            'ref_table': full_ref_table_name,
            'ref_columns': ref_columns
        })

    # For simplicity, let's say "joins" just replicate foreign_keys
    joins = foreign_keys

    return tables, primary_keys, foreign_keys, joins


def create_rag_documents(tables, primary_keys, foreign_keys, joins):
    """
    Converts the extracted schema info into a list of JSON-like documents
    suitable for RAG ingestion. Each document contains:
      - table_name
      - columns (list of columns)
      - primary_key
      - foreign_keys
      - joins
    """
    documents = []

    for table_name, columns_str in tables.items():
        # Split columns by commas that are not inside parentheses
        # (to avoid splitting on something like numeric(4,2)).
        # A simpler approach is to split by lines or semicolons, but let's do a naive approach:
        raw_cols = split_columns(columns_str)

        # Clean each column definition
        cleaned_cols = [clean_column_definition(c) for c in raw_cols if c.strip()]

        doc = {
            'table_name': table_name,
            'columns': cleaned_cols,
            'primary_key': primary_keys.get(table_name, {}),
            'foreign_keys': foreign_keys.get(table_name, []),
            'joins': joins.get(table_name, [])
        }
        documents.append(doc)

    return documents


def split_columns(columns_block: str):
    """
    Split a CREATE TABLE column block into individual column/constraint lines,
    ignoring commas found inside parentheses (like numeric(4,2)).
    """
    # A very common approach is to do a manual parse counting parentheses.
    # For brevity, here’s a quick version:
    results = []
    current = []
    paren_depth = 0

    for char in columns_block:
        if char == '(':
            paren_depth += 1
            current.append(char)
        elif char == ')':
            paren_depth -= 1
            current.append(char)
        elif char == ',' and paren_depth == 0:
            # We reached a top-level comma -> new column
            results.append("".join(current).strip())
            current = []
        else:
            current.append(char)

    # Add the last accumulated column
    if current:
        results.append("".join(current).strip())

    return results


def clean_column_definition(col_def: str) -> str:
    """
    Cleans up one column/constraint definition line by removing extra
    semicolons, repeating spaces, etc.
    """
    # Remove trailing semicolons if any
    col_def = col_def.rstrip(';')
    # Convert multiple spaces to single
    col_def = re.sub(r'\s{2,}', ' ', col_def)
    # Trim
    col_def = col_def.strip()
    return col_def


In [3]:
sql_dump_path = "./context/pagila-schema.sql"
with open(sql_dump_path, 'r') as file:
    sql_dump = file.read()
sql_dump = clean_sql_dump(sql_dump)

tables, primary_keys, foreign_keys, joins = extract_schema_info(sql_dump)
rag_docs = create_rag_documents(tables, primary_keys, foreign_keys, joins)


## Embeddings and vectorizing

In [4]:
import json
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import Ollama
from langchain.docstore.document import Document
# Convert each dict to a text string or store as JSON
documents = []
for doc in rag_docs:
    # Convert dict to a JSON string or any text representation you prefer
    text_content = json.dumps(doc, ensure_ascii=False, indent=2)
    documents.append(Document(page_content=text_content))

# ------------------------------------------------------
# 2) Create embeddings with a Hugging Face model + FAISS vector store
# ------------------------------------------------------
# Example: Using the MiniLM model from SentenceTransformers
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(documents, embedding=embedding_model)


  embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
  from .autonotebook import tqdm as notebook_tqdm


### prompt template

In [5]:
prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Here is the data context :\n"
        "{context}\n\n"
        "Given this context, generate a sql query answering the following question:\n"
        "{question}"
    )
)

## Test

In [6]:
retriever = vectorstore.as_retriever()

qa_chain = RetrievalQA.from_chain_type(
    llm=Ollama(model="llama3"),  # Replace with your LLM
    retriever=retriever,
    # Optionally, pass a custom prompt if you want to strictly control formatting
    # chain_type_kwargs={"prompt": prompt_template}
)

# ------------------------------------------------------
# 5) Test with a sample question
# ------------------------------------------------------
question = "Quelle est la liste des films disponibles avec leur catégorie ?"
response = qa_chain.run(question)

print("Generated SQL Query:", response)

  llm=Ollama(model="llama3"),  # Replace with your LLM
  response = qa_chain.run(question)


Generated SQL Query: A database question!

To answer this, we need to analyze the relationships between tables. We can see that there are two relationships:

1. `film` has a foreign key `language_id` referencing the `language` table.
2. `film_category` has foreign keys `film_id` and `category_id` referencing the `film` and `category` tables, respectively.

This suggests that we need to join the `film` table with the `film_category` table on the `film_id` column, and then join the result with the `category` table on the `category_id` column.

The final query would be:
```sql
SELECT f.title, fc.category_name
FROM film f
JOIN film_category fc ON f.film_id = fc.film_id
JOIN category c ON fc.category_id = c.category_id;
```
This should give us a list of films with their corresponding categories.


# Adding logging context

In [7]:
import psycopg2

def get_logging_context(
):
    """
    Connect to PostgreSQL and retrieve table structures and some rows
    from queries_log and generated_views for an LLM 'logging context'.
    """
    connection_params = {
    'dbname': 'pagila',
    'user': 'postgres',
    'password': pwd_db,
    'host': 'localhost',
    'port': 5432
    }
    # Connect to the database
    conn = psycopg2.connect(**connection_params)
    cursor = conn.cursor()

    # Helper: fetch columns from information_schema
    def fetch_table_structure(table_name):
        cursor.execute("""
            SELECT column_name, data_type, is_nullable
            FROM information_schema.columns
            WHERE table_name = %s
            ORDER BY ordinal_position
        """, (table_name,))
        columns = cursor.fetchall()
        # Build a small text block describing each column
        lines = []
        for col_name, data_type, is_nullable in columns:
            lines.append(f"  - {col_name} {data_type} {'(nullable)' if is_nullable == 'YES' else '(not null)'}")
        structure_str = "Columns:\n" + "\n".join(lines)
        return structure_str

    # Helper: fetch sample rows from each table
    def fetch_sample_rows(table_name):
        cursor.execute(f"SELECT * FROM {table_name} ORDER BY created_at")
        rows = cursor.fetchall()
        # If you want column names, re-fetch from cursor.description:
        col_names = [desc.name for desc in cursor.description]

        lines = []
        for row in rows:
            # Zip column names with the row values
            row_data = ", ".join(f"{col}: {val}" for col, val in zip(col_names, row))
            lines.append("    " + row_data)
        if not lines:
            lines = ["    No rows found."]
        return "\n".join(lines)

    # Build the logging context
    logging_context_lines = ["-- LOGGING SCHEMA METADATA --"]

    for table_name in ["queries_log", "generated_views"]:
        logging_context_lines.append(f"Table: {table_name}")
        # structure
        table_structure = fetch_table_structure(table_name)
        logging_context_lines.append(table_structure)
        # sample rows
        sample_rows = fetch_sample_rows(table_name)
        logging_context_lines.append("Sample rows (up to 5):\n" + sample_rows)
        logging_context_lines.append("")  # blank line

    # Combine into a single text block
    logging_context = "\n".join(logging_context_lines)

    # Clean up
    cursor.close()
    conn.close()

    return logging_context


In [8]:
log_context = get_logging_context()
print(log_context)

-- LOGGING SCHEMA METADATA --
Table: queries_log
Columns:
  - query_id integer (not null)
  - question text (not null)
  - description text (nullable)
  - sql_text text (not null)
  - created_at timestamp without time zone (not null)
Sample rows (up to 5):
    No rows found.

Table: generated_views
Columns:
  - view_id integer (not null)
  - query_id integer (not null)
  - view_name text (not null)
  - created_at timestamp without time zone (not null)
Sample rows (up to 5):
    No rows found.



## schema context

In [9]:
import json
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document

sql_dump_path = "./context/pagila-schema.sql"
with open(sql_dump_path, 'r') as file:
    sql_dump = file.read()
sql_dump = clean_sql_dump(sql_dump)

tables, primary_keys, foreign_keys, joins = extract_schema_info(sql_dump)
rag_docs = create_rag_documents(tables, primary_keys, foreign_keys, joins)

documents = []
for doc in rag_docs:
    text_content = json.dumps(doc, ensure_ascii=False, indent=2)
    documents.append(Document(page_content=text_content))

# Create embeddings + FAISS vector store
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(documents, embedding=embedding_model)

# Create a retriever
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

In [10]:
from langchain.chains import LLMChain

In [11]:
from langchain.prompts import PromptTemplate

multi_context_prompt = PromptTemplate(
    input_variables=["schema_context", "logging_context", "question"],
    template=(
        "You have two separate pieces of context:\n\n"
        "=== SCHEMA CONTEXT (Main DB) ===\n"
        "{schema_context}\n\n"
        "=== LOGGING CONTEXT (queries_log & generated_views) ===\n"
        "{logging_context}\n\n"
        "The user wants two main outputs:\n\n"
        "1) A valid SQL query (SELECT, JOIN, etc.) that answers the question.\n"
        "2) One or more SQL statements to log this query:\n"
        "   - An INSERT INTO queries_log, with:\n"
        "       * question: the exact user question\n"
        "       * description: a short description of what the query does (NOT NULL)\n"
        "       * sql_text: the exact SQL query from (1)\n"
        "       * created_at: use NOW() or CURRENT_TIMESTAMP\n"
        "     (Do not set them to NULL. Provide real values.)\n\n"
        "   - a CREATE VIEW statement.\n"
        "   - an INSERT INTO generated_views referencing that new view.\n\n"
        "Important:\n"
        "- Do NOT output NULL for 'description' or 'sql_text'.\n"
        "- Use the exact same SQL text for 'sql_text' that you generated in step (1).\n"
        "- If query_id and view_id are SERIAL or auto-increment, use 'DEFAULT' or omit the column, or use RETURNING.\n"
        "- Please produce these statements in **four** separate code blocks, labeled:\n"
        "  1) -- Final Query\n"
        "  2) -- Insert into queries_log\n"
        "  3) -- Create or replace view\n"
        "  4) -- Insert into generated_views\n\n"
        "The question is:\n"
        "\"{question}\"\n\n"
        "Now provide the standardized output exactly as requested.\n"
    )
)


In [12]:
# 1) Retrieve from vectorstore
question = "Quelle est la liste des films disponibles avec leur catégorie ?"
retrieved_docs = retriever.get_relevant_documents(question)

# Combine them into a single string for the 'schema_context'
schema_context_str = "\n".join(doc.page_content for doc in retrieved_docs)

# 2) Retrieve logging context from Postgres
logging_context_str = get_logging_context()

# 3) Format the final prompt
final_prompt = multi_context_prompt.format(
    schema_context=schema_context_str,
    logging_context=logging_context_str,
    question=question
)

# 4) Send it to the LLM
# from langchain.llms import Ollama
from langchain_ollama import OllamaLLM

# llm = Ollama(model="llama3")
llm = OllamaLLM(model="llama3")
response = llm(final_prompt)

print("LLM Response:\n", response)

  retrieved_docs = retriever.get_relevant_documents(question)
  response = llm(final_prompt)


LLM Response:
 Here are the outputs:

**1) Final Query**
SELECT m.title, c.name
FROM film_category fc
JOIN category c ON fc.category_id = c.category_id
JOIN film m ON fc.film_id = m.film_id;

**2) Insert into queries_log**
INSERT INTO queries_log (question, description, sql_text, created_at)
VALUES ('Quelle est la liste des films disponibles avec leur catégorie ?', 'List of available movies with their categories', 
         'SELECT m.title, c.name
          FROM film_category fc
          JOIN category c ON fc.category_id = c.category_id
          JOIN film m ON fc.film_id = m.film_id;', NOW());

**3) Create or replace view**
CREATE OR REPLACE VIEW available_movies AS
SELECT m.title, c.name
FROM film_category fc
JOIN category c ON fc.category_id = c.category_id
JOIN film m ON fc.film_id = m.film_id;

**4) Insert into generated_views**
INSERT INTO generated_views (view_name, query_id)
VALUES ('available_movies', DEFAULT);

Please note that I've assumed the `query_id` is a serial column 

In [13]:
import os
import uuid

def parse_llm_output_to_sql_files(
    llm_output: str,
    output_dir: str = "sql_files"
) -> None:
    """
    Parse an LLM response containing up to four markers:
      **1) -- Final Query**
      **2) -- Insert into queries_log**
      **3) -- Create or replace view**
      **4) -- Insert into generated_views**
    
    and write each section that follows into a separate .sql file.

    Each file is saved in either 'logging' or 'views' subfolder within 'output_dir',
    and is appended with an automatically generated unique ID (UUID).

    Parameters
    ----------
    llm_output : str
        The LLM's generated text containing the four labeled sections.
    output_dir : str, optional
        The base directory where 'logging' and 'views' subfolders will be created.

    Example
    -------
    test_response = \"\"\"
    **1) -- Final Query**
    SELECT * FROM film;

    **2) -- Insert into queries_log**
    INSERT INTO queries_log(...);

    **3) -- Create or replace view**
    CREATE OR REPLACE VIEW...;

    **4) -- Insert into generated_views**
    INSERT INTO generated_views(...);
    \"\"\"

    parse_llm_output_to_sql_files(test_response)
    """

    # Automatically generate a unique ID for this parsing run
    unique_id = uuid.uuid4().hex  # or str(uuid.uuid4())

    # Define the markers (in the order we expect them)
    markers = [
        "**1) -- Final Query**",
        "**2) -- Insert into queries_log**",
        "**3) -- Create or replace view**",
        "**4) -- Insert into generated_views**"
    ]

    # Map each marker to (base_filename, subfolder)
    # We'll append the unique_id to base_filename later when writing
    marker_file_map = {
        "**1) -- Final Query**":           ("final_query", "views"),
        "**2) -- Insert into queries_log**": ("insert_into_queries_log", "logging"),
        "**3) -- Create or replace view**":  ("create_or_replace_view", "views"),
        "**4) -- Insert into generated_views**": ("insert_into_generated_views", "logging"),
    }

    # Prepare a dictionary to hold lines for each marker
    content_map = {m: [] for m in markers}

    # Split the LLM output by lines
    lines = llm_output.splitlines()
    current_marker = None

    # Collect lines under each marker
    for line in lines:
        line_stripped = line.strip()
        if line_stripped in markers:
            current_marker = line_stripped
        else:
            if current_marker:
                content_map[current_marker].append(line)

    # Ensure the base output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Write out each marker's content to the corresponding file & subfolder
    nonempty_count = 0
    for marker, (base_name, subfolder) in marker_file_map.items():
        block_content = "\n".join(content_map[marker]).strip()
        if block_content:
            subpath = os.path.join(output_dir, subfolder)
            os.makedirs(subpath, exist_ok=True)

            filename = f"{base_name}_{unique_id}.sql"
            file_path = os.path.join(subpath, filename)

            with open(file_path, "w", encoding="utf-8") as f:
                f.write(block_content + "\n")

            nonempty_count += 1

    print(f"Parsed {nonempty_count} section(s). Files written under '{output_dir}' with unique ID '{unique_id}'.")



parse_llm_output_to_sql_files(response, output_dir="sql_files")


Parsed 0 section(s). Files written under 'sql_files' with unique ID '814534bfe520410e9ff160e53278fdbf'.
