In [1]:
import psycopg2
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain_community.llms import Ollama
import os

Transform catalog to be used in a RAG context

## RAG documentation

In [32]:
import re

def clean_sql_dump(sql_dump: str) -> str:
    """
    Removes SQL-style comments, flattens newlines, and condenses extra spaces.
    """
    # Remove line-based SQL comments (e.g. -- comment)
    sql_dump = re.sub(r'--.*?$', '', sql_dump, flags=re.MULTILINE)
    # Remove /* */ block comments if present
    sql_dump = re.sub(r'/\*.*?\*/', '', sql_dump, flags=re.DOTALL)
    # Replace newlines with spaces
    sql_dump = re.sub(r'\n', ' ', sql_dump)
    # Replace multiple spaces with single space
    sql_dump = re.sub(r'\s{2,}', ' ', sql_dump)
    # Trim leading/trailing whitespace
    sql_dump = sql_dump.strip()
    return sql_dump


def extract_schema_info(sql_dump: str):
    """
    Extracts table definitions, primary keys, foreign keys, and
    a simple 'joins' mapping from a cleaned SQL dump.
    """
    # Regex patterns
    # This pattern handles optional schema name (e.g. CREATE TABLE public.actor ...)
    # and captures everything inside the parentheses up to the matching semicolon.
    table_pattern = re.compile(
        r'CREATE TABLE\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name (group 1)
        r'["]?(\w+)["]?'
        r'\s*\((.*?)\)\s*;', 
        re.IGNORECASE | re.DOTALL
    )

    pkey_pattern = re.compile(
        r'ALTER TABLE ONLY\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name
        r'["]?(\w+)["]?'
        r'\s+ADD CONSTRAINT\s+(\w+)\s+PRIMARY KEY\s*\((.*?)\);',
        re.IGNORECASE | re.DOTALL
    )

    fkey_pattern = re.compile(
        r'ALTER TABLE ONLY\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional schema name
        r'["]?(\w+)["]?'
        r'\s+ADD CONSTRAINT\s+(\w+)\s+FOREIGN KEY\s*\((.*?)\)\s+REFERENCES\s+'
        r'(?:["]?(\w+)["]?\.)?'   # optional ref schema
        r'["]?(\w+)["]?\s*\((.*?)\)'
        r'(?:\s+ON UPDATE \w+\s+ON DELETE \w+)?;', 
        re.IGNORECASE | re.DOTALL
    )

    # Data structures to fill
    tables = {}
    primary_keys = {}
    foreign_keys = {}
    joins = {}

    # Extract tables
    for match in table_pattern.finditer(sql_dump):
        schema_name = match.group(1)  # Might be None if no schema specified
        table_name = match.group(2)
        columns_str = match.group(3)
        
        # Key for referencing the table will just be table_name
        # but you could store "schema.table" if needed:
        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        
        tables[full_table_name] = columns_str

    # Extract primary keys
    for match in pkey_pattern.finditer(sql_dump):
        schema_name = match.group(1)
        table_name = match.group(2)
        pkey_name = match.group(3)
        pkey_columns = match.group(4).strip()

        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        primary_keys[full_table_name] = {
            'pkey_name': pkey_name,
            'pkey_columns': pkey_columns
        }

    # Extract foreign keys
    for match in fkey_pattern.finditer(sql_dump):
        schema_name = match.group(1)
        table_name = match.group(2)
        fkey_name = match.group(3)
        fkey_columns = match.group(4).strip()
        ref_schema = match.group(5)
        ref_table_name = match.group(6)
        ref_columns = match.group(7).strip()

        full_table_name = table_name if not schema_name else f"{schema_name}.{table_name}"
        full_ref_table_name = ref_table_name if not ref_schema else f"{ref_schema}.{ref_table_name}"

        if full_table_name not in foreign_keys:
            foreign_keys[full_table_name] = []

        foreign_keys[full_table_name].append({
            'fkey_name': fkey_name,
            'fkey_columns': fkey_columns,
            'ref_table': full_ref_table_name,
            'ref_columns': ref_columns
        })

    # For simplicity, let's say "joins" just replicate foreign_keys
    joins = foreign_keys

    return tables, primary_keys, foreign_keys, joins


def create_rag_documents(tables, primary_keys, foreign_keys, joins):
    """
    Converts the extracted schema info into a list of JSON-like documents
    suitable for RAG ingestion. Each document contains:
      - table_name
      - columns (list of columns)
      - primary_key
      - foreign_keys
      - joins
    """
    documents = []

    for table_name, columns_str in tables.items():
        # Split columns by commas that are not inside parentheses
        # (to avoid splitting on something like numeric(4,2)).
        # A simpler approach is to split by lines or semicolons, but let's do a naive approach:
        raw_cols = split_columns(columns_str)

        # Clean each column definition
        cleaned_cols = [clean_column_definition(c) for c in raw_cols if c.strip()]

        doc = {
            'table_name': table_name,
            'columns': cleaned_cols,
            'primary_key': primary_keys.get(table_name, {}),
            'foreign_keys': foreign_keys.get(table_name, []),
            'joins': joins.get(table_name, [])
        }
        documents.append(doc)

    return documents


def split_columns(columns_block: str):
    """
    Split a CREATE TABLE column block into individual column/constraint lines,
    ignoring commas found inside parentheses (like numeric(4,2)).
    """
    # A very common approach is to do a manual parse counting parentheses.
    # For brevity, here’s a quick version:
    results = []
    current = []
    paren_depth = 0

    for char in columns_block:
        if char == '(':
            paren_depth += 1
            current.append(char)
        elif char == ')':
            paren_depth -= 1
            current.append(char)
        elif char == ',' and paren_depth == 0:
            # We reached a top-level comma -> new column
            results.append("".join(current).strip())
            current = []
        else:
            current.append(char)

    # Add the last accumulated column
    if current:
        results.append("".join(current).strip())

    return results


def clean_column_definition(col_def: str) -> str:
    """
    Cleans up one column/constraint definition line by removing extra
    semicolons, repeating spaces, etc.
    """
    # Remove trailing semicolons if any
    col_def = col_def.rstrip(';')
    # Convert multiple spaces to single
    col_def = re.sub(r'\s{2,}', ' ', col_def)
    # Trim
    col_def = col_def.strip()
    return col_def


In [39]:
sql_dump_path = "./context/pagila-schema.sql"
with open(sql_dump_path, 'r') as file:
    sql_dump = file.read()
sql_dump = clean_sql_dump(sql_dump)

tables, primary_keys, foreign_keys, joins = extract_schema_info(sql_dump)
rag_docs = create_rag_documents(tables, primary_keys, foreign_keys, joins)


## Embeddings and vectorizing

In [40]:
import json
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms import Ollama
from langchain.docstore.document import Document
# Convert each dict to a text string or store as JSON
documents = []
for doc in rag_docs:
    # Convert dict to a JSON string or any text representation you prefer
    text_content = json.dumps(doc, ensure_ascii=False, indent=2)
    documents.append(Document(page_content=text_content))

# ------------------------------------------------------
# 2) Create embeddings with a Hugging Face model + FAISS vector store
# ------------------------------------------------------
# Example: Using the MiniLM model from SentenceTransformers
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(documents, embedding=embedding_model)


### prompt template

In [41]:
prompt_template = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "Here is the data context :\n"
        "{context}\n\n"
        "Given this context, generate a sql query answering the following question:\n"
        "{question}"
    )
)

## Test

In [42]:
retriever = vectorstore.as_retriever()

qa_chain = RetrievalQA.from_chain_type(
    llm=Ollama(model="llama3"),  # Replace with your LLM
    retriever=retriever,
    # Optionally, pass a custom prompt if you want to strictly control formatting
    # chain_type_kwargs={"prompt": prompt_template}
)

# ------------------------------------------------------
# 5) Test with a sample question
# ------------------------------------------------------
question = "Quelle est la liste des films disponibles avec leur catégorie ?"
response = qa_chain.run(question)

print("Generated SQL Query:", response)

Generated SQL Query: To answer this question, we need to join the `film` table with the `film_category` table.

From the context, we can see that the `film_category` table has a foreign key `film_id` that references the `film_id` in the `film` table. This means that for each row in the `film_category` table, there is a corresponding film in the `film` table.

To get the list of films with their categories, we can perform an inner join between the two tables on the `film_id` column. The resulting query would be:

```
SELECT f.title, fc.category_name
FROM film f
INNER JOIN film_category fc ON f.film_id = fc.film_id;
```

However, this assumes that there is a `category_name` column in the `category` table, which does not exist in the provided context. To get the category name, we would need to join the `film_category` table with the `category` table on the `category_id` column.

Here's an updated query:

```
SELECT f.title, c.name AS category_name
FROM film f
INNER JOIN film_category fc O