# Schema Linking
1. extract_keywords
2. retrieve_context
3. retrieve_entity

In [1]:
# One BIRD example (question, hint, gold SQL) that drives the schema-linking steps below.
qn = dict(
    question_id=93,
    db_id="financial",
    question="How many male customers who are living in North Bohemia have average salary greater than 8000?",
    evidence="Male means that gender = 'M'; A3 refers to region; A11 pertains to average salary.",
    SQL="SELECT COUNT(T1.client_id) FROM client AS T1 INNER JOIN district AS T2 ON T1.district_id = T2.district_id WHERE T1.gender = 'M' AND T2.A3 = 'North Bohemia' AND T2.A11 > 8000",
    difficulty="moderate",
)

# Extract Keywords

In [None]:
import ast
from llm.gemini import model
from src.prompt_bank import extract_keywords_template

In [None]:
# Fill the keyword-extraction template with the question text and its evidence hint.
extract_keywords_prompt = extract_keywords_template.format(
    QUESTION=qn["question"],
    HINT=qn["evidence"],
)
response = model.invoke(extract_keywords_prompt)
# The model's reply text is parsed as a Python literal (presumably a list of
# keyword strings, as the cell output shows). literal_eval never executes code.
raw_reply = response.content
keywords = ast.literal_eval(raw_reply)
keywords

['male',
 'customers',
 'North Bohemia',
 'average salary',
 '8000',
 'gender',
 "'M'",
 'A3',
 'region',
 'A11']

# Retrieve Context

In [None]:
from src.database_utils.db_catalog.preprocess import EMBEDDING_FUNCTION, Chroma
from src.database_utils.db_catalog.search import query_vector_db


def add_description(tables_with_descriptions, retrieved_descriptions):
    """
    Merge newly retrieved column descriptions into an accumulator, in place.

    For each ``table -> column -> info`` entry in ``retrieved_descriptions``,
    the info is stored in ``tables_with_descriptions`` when that column is not
    present yet, or when its ``"score"`` is strictly greater than the stored
    entry's ``"score"`` (ties keep the existing entry). A table that appears
    in the retrieval is always created in the accumulator, even when it has no
    columns.

    Args:
        tables_with_descriptions (Dict[str, Dict[str, Dict]]): Accumulated
            descriptions keyed by table name, then column name.
        retrieved_descriptions (Dict[str, Dict[str, Dict]]): Newly retrieved
            descriptions with the same nesting; each column info must have a
            comparable "score" value.

    Returns:
        Dict[str, Dict[str, Dict]]: The same ``tables_with_descriptions``
        object, after merging.
    """
    for table, incoming_columns in retrieved_descriptions.items():
        stored_columns = tables_with_descriptions.setdefault(table, {})
        for column, candidate in incoming_columns.items():
            is_new = column not in stored_columns
            if is_new or candidate["score"] > stored_columns[column]["score"]:
                stored_columns[column] = candidate
    return tables_with_descriptions


def format_retrieved_descriptions(retrieved_columns):
    """
    Drop the retrieval ``"score"`` key from every column entry, in place.

    Args:
        retrieved_columns: Nested mapping ``table -> column -> info dict``.
            Info dicts without a "score" key are left unchanged.

    Returns:
        The same ``retrieved_columns`` object with all "score" keys removed.
    """
    for per_table in retrieved_columns.values():
        for info in per_table.values():
            if "score" in info:
                del info["score"]
    return retrieved_columns


import os

# Root of the BIRD data directory. Override with the BIRD_DATA_ROOT environment
# variable so the notebook runs on machines other than the original author's.
bird_data_root = os.environ.get(
    "BIRD_DATA_ROOT", "/home/chenjie/projects/chase-sql-on-finch/data/bird"
)
# Derive the per-database directory from the question's db_id so the context
# vector DB always matches the database the question targets (previously
# hard-coded to "financial").
vector_db_path = os.path.join(bird_data_root, qn["db_id"], "context_vector_db")
vector_db = Chroma(persist_directory=vector_db_path, embedding_function=EMBEDDING_FUNCTION)


def retrieve_context(question, evidence, keywords, vector_db, top_k=5):
    """
    Gather schema column descriptions relevant to a question.

    For every keyword, two searches are sent to ``vector_db`` through
    ``query_vector_db``: ``"<question> <keyword>"`` and then
    ``"<evidence> <keyword>"``. Every result is merged with
    ``add_description`` (the highest score wins per table/column). The scores
    are then removed with ``format_retrieved_descriptions``.

    Args:
        question: Natural-language question text.
        evidence: Evidence / hint text that accompanies the question.
        keywords: Iterable of keyword strings.
        vector_db: Vector store handed unchanged to ``query_vector_db``.
        top_k: Result count requested from each individual search.

    Returns:
        Nested dict ``table -> column -> description info`` without "score" keys.
    """
    merged = {}
    for keyword in keywords:
        # Question-based search first, then evidence-based, matching the merge order.
        for prefix in (question, evidence):
            hits = query_vector_db(vector_db, f"{prefix} {keyword}", top_k)
            merged = add_description(merged, hits)
    return format_retrieved_descriptions(merged)

In [None]:
# Run context retrieval for the example question, keeping the top 5 hits per search.
schema_with_descriptions = retrieve_context(
    question=qn['question'],
    evidence=qn['evidence'],
    keywords=keywords,
    vector_db=vector_db,
    top_k=5,
)