# Schema Linking Agent Test

This notebook implements and tests the Schema Linking Agent that identifies and ranks the most relevant database schema elements for processed query parts.

In [ ]:
from dotenv import load_dotenv
import json
import re
from typing import Dict, Any, List, Optional, Tuple
from difflib import SequenceMatcher
import logging

load_dotenv()

# Import unified schemas from our centralized location
from schemas import (
    # Schema Linking types
    SchemaLinkingInput,
    SchemaLinkingOutput,
    SchemaElement,
    JoinPath,
    ExtractedEntitiesAndIntent,
    
    # Error types
    SchemaLinkingError,
    
    # Database schemas
    SCHEMAS,
    QUERY_PATTERNS
)

logger = logging.getLogger(__name__)

In [None]:
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from schema_manager import SchemaManager

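## Data Structures

The data structures used by the agent (`SchemaLinkingInput`, `SchemaLinkingOutput`, `SchemaElement`, `JoinPath`, `ExtractedEntitiesAndIntent`) are imported from the shared `schemas` module in the first cell, so none are defined here.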

## Define Schema Linking System Prompt

In [None]:
# System prompt sent with every LLM request made by
# SchemaLinkingAgent._llm_schema_matching (the fallback used when the heuristic
# pass finds too few candidates).
# The scoring bands below only guide the model; the heuristic pass assigns its
# own scores in code and does not read this text.
SCHEMA_LINKING_SYSTEM_PROMPT = """You are an expert Schema Linking Agent that maps natural language queries to database schema elements.

Your responsibilities:
1. Identify relevant tables and columns for query parts
2. Rank schema elements by relevance
3. Suggest join paths when multiple tables are needed
4. Provide clear rationale for each mapping
5. Note any unresolved elements

Consider:
- Exact matches: Query mentions match schema names
- Semantic matches: Query concepts match schema descriptions
- Domain knowledge: Use provided business rules and synonyms
- Data types: Match query intent with column data types
- Foreign keys: Identify relationships between tables
- Value examples: Match query values with sample data

Scoring guidelines:
- 0.9-1.0: Exact match or very high confidence
- 0.7-0.9: Strong semantic match with domain knowledge
- 0.5-0.7: Moderate match, likely relevant
- Below 0.5: Weak match, possibly relevant

Always prefer specific column matches over table-only matches."""

## Implement the Schema Linking Agent

In [ ]:
class SchemaLinkingAgent:
    """Agent that links natural language queries to database schema elements.

    Linking starts with a heuristic pass: table and column names are fuzzily
    matched against the query text and against the extracted metrics,
    dimensions and filter fields. When that pass yields fewer than three
    candidates, an LLM pass is consulted and its suggestions are merged in.

    Recognized ``config`` keys (all optional):
        model: model name for ``OpenAIChatCompletionClient`` (default ``'gpt-4o'``).
        schema_manager: an existing ``SchemaManager`` instance to reuse.
        data_path, tables_json_path, dataset_name: used to build a default
            ``SchemaManager`` when ``schema_manager`` is not supplied.
    """
    
    def __init__(self, config: Optional[Dict] = None):
        """Initialize the agent, its model client and its schema manager."""
        self.config = config or {}
        self.model = self.config.get('model', 'gpt-4o')
        self.model_client = OpenAIChatCompletionClient(model=self.model)
        
        # Reuse a caller-supplied schema manager; otherwise build a default one.
        if 'schema_manager' in self.config:
            self.schema_manager = self.config['schema_manager']
        else:
            self.schema_manager = self._create_default_schema_manager()
    
    def _create_default_schema_manager(self):
        """Create a SchemaManager (``lazy=True``) from config, defaulting to BIRD dev paths."""
        return SchemaManager(
            data_path=self.config.get('data_path', '../data/bird/dev_databases'),
            tables_json_path=self.config.get('tables_json_path', '../data/bird/dev_tables.json'),
            dataset_name=self.config.get('dataset_name', 'bird'),
            lazy=True
        )
    
    async def close(self) -> None:
        """Close the underlying model client, if it exposes ``close()``.

        The result is awaited when ``close()`` returns an awaitable, so both
        sync and async client implementations are handled.
        """
        import inspect  # local import: only needed here
        
        closer = getattr(self.model_client, "close", None)
        if closer is None:
            return
        outcome = closer()
        if inspect.isawaitable(outcome):
            await outcome
    
    async def link_schema(self, input_data: SchemaLinkingInput) -> SchemaLinkingOutput:
        """
        Main entry point for schema linking.
        
        Args:
            input_data: SchemaLinkingInput containing query part and context.
                If ``database_schema`` is falsy, the schema is looked up by
                ``database_id``.
            
        Returns:
            SchemaLinkingOutput with linked schema elements, join paths,
            unresolved-element notes and an overall confidence.
            
        Raises:
            SchemaLinkingError: If linking fails (the original exception is
                chained as ``__cause__``).
        """
        try:
            # Get schema information
            database_schema = input_data.database_schema or await self._get_schema_description(input_data.database_id)
            
            # Perform schema linking
            relevant_elements = await self._identify_relevant_elements(
                input_data.processed_natural_language,
                input_data.extracted_entities_and_intent,
                database_schema,
                input_data.domain_knowledge_context
            )
            
            # Identify join paths if multiple tables
            join_paths = await self._identify_join_paths(
                relevant_elements,
                database_schema
            )
            
            # Identify unresolved elements
            unresolved = await self._identify_unresolved_elements(
                input_data.processed_natural_language,
                input_data.extracted_entities_and_intent,
                relevant_elements
            )
            
            # Weighted toward the top-ranked elements (list is sorted by score).
            overall_confidence = self._calculate_overall_confidence(relevant_elements)
            
            return SchemaLinkingOutput(
                relevant_schema_elements=relevant_elements,
                proposed_join_paths=join_paths,
                overall_linking_confidence=overall_confidence,
                unresolved_elements_notes=unresolved
            )
            
        except SchemaLinkingError:
            # Already a domain error with a meaningful message; don't wrap twice.
            raise
        except Exception as e:
            logger.error(f"Schema linking failed: {str(e)}")
            raise SchemaLinkingError(f"Failed to link schema: {str(e)}") from e
    
    async def _get_schema_description(self, database_id: str) -> Dict:
        """Build a schema description dict for ``database_id``.

        Uses the unified ``SCHEMAS`` entry when one exists; otherwise reads the
        schema manager's per-database JSON and detailed info. The result has the
        keys ``tables`` (table -> list of column dicts), ``foreign_keys`` (list
        of from/to dicts) and ``primary_keys`` (table -> pk info).

        Best effort: on any error it logs and returns an empty description with
        an extra ``error`` key instead of raising.
        """
        try:
            # Use the unified SCHEMAS if available
            if database_id in SCHEMAS:
                return self._format_unified_schema(SCHEMAS[database_id])
            
            # Otherwise, use schema manager
            db_info = self.schema_manager.db2dbjsons.get(database_id, {})
            
            # Load detailed info on first use and cache it on the schema manager.
            if database_id not in self.schema_manager.db2infos:
                self.schema_manager.db2infos[database_id] = self.schema_manager._load_single_db_info(database_id)
            
            detailed_info = self.schema_manager.db2infos.get(database_id, {})
            
            schema_desc = {
                "tables": {},
                "foreign_keys": [],
                "primary_keys": {}
            }
            
            # Add table and column information
            for table_name in db_info.get('table_names_original', []):
                columns_info = []
                
                if table_name in detailed_info.get('desc_dict', {}):
                    # desc_dict entries are unpacked as (name, description, type).
                    for col_tuple in detailed_info['desc_dict'][table_name]:
                        col_name, description, data_type = col_tuple
                        
                        # First value_dict entry for this column supplies the examples.
                        values = ""
                        if table_name in detailed_info.get('value_dict', {}):
                            for val_tuple in detailed_info['value_dict'][table_name]:
                                if val_tuple[0] == col_name:
                                    values = val_tuple[1]
                                    break
                        
                        columns_info.append({
                            "name": col_name,
                            "description": description,
                            "data_type": data_type,
                            "value_examples": values
                        })
                
                schema_desc["tables"][table_name] = columns_info
                
                if table_name in detailed_info.get('pk_dict', {}):
                    schema_desc["primary_keys"][table_name] = detailed_info['pk_dict'][table_name]
            
            # fk_dict entries are unpacked as (from_col, to_table, to_col).
            for table_name, fks in detailed_info.get('fk_dict', {}).items():
                for fk in fks:
                    from_col, to_table, to_col = fk
                    schema_desc["foreign_keys"].append({
                        "from_table": table_name,
                        "from_column": from_col,
                        "to_table": to_table,
                        "to_column": to_col
                    })
            
            return schema_desc
            
        except Exception as e:
            logger.error(f"Failed to get schema description: {e}")
            return {"error": str(e), "tables": {}, "foreign_keys": [], "primary_keys": {}}
    
    def _format_unified_schema(self, schema_data: Dict) -> Dict:
        """Convert a unified ``SCHEMAS`` entry to the schema description format.

        Column ``type``/``description``/``examples`` map to ``data_type``/
        ``description``/``value_examples``. Columns flagged ``primary_key`` are
        collected per table, and ``foreign_keys`` is passed through unchanged.
        """
        formatted = {
            "tables": {},
            "foreign_keys": schema_data.get("foreign_keys", []),
            "primary_keys": {}
        }
        
        for table_name, table_info in schema_data.get("tables", {}).items():
            columns = []
            for col_name, col_info in table_info.get("columns", {}).items():
                columns.append({
                    "name": col_name,
                    "data_type": col_info.get("type", ""),
                    "description": col_info.get("description", ""),
                    "value_examples": col_info.get("examples", "")
                })
                
                if col_info.get("primary_key"):
                    formatted["primary_keys"].setdefault(table_name, []).append(col_name)
            
            formatted["tables"][table_name] = columns
        
        return formatted
    
    async def _identify_relevant_elements(
        self,
        query_part: str,
        entities_and_intent: ExtractedEntitiesAndIntent,
        schema_desc: Dict,
        domain_knowledge: Optional[Dict]
    ) -> List[SchemaElement]:
        """Score tables and columns against the query part and extracted entities.

        A column is kept when its score exceeds 0.5. Its table is then kept with
        a score of at least 0.9 x the column score. The result is sorted by
        descending score. If fewer than three elements are found, LLM
        suggestions are merged in, keeping one (highest-scoring) element per
        ``element_name``.
        """
        relevant_elements = []
        query_lower = query_part.lower()
        
        metrics = entities_and_intent.metrics
        dimensions = entities_and_intent.dimensions
        filters = entities_and_intent.filters
        
        for table_name, columns in schema_desc.get("tables", {}).items():
            table_score = 0.0
            table_rationale = []
            
            if self._fuzzy_match(table_name, query_lower):
                table_score = 0.8
                table_rationale.append(f"Table name '{table_name}' matches query")
            
            column_matches = []
            for col_info in columns:
                col_name = col_info["name"]
                col_type = col_info.get("data_type", "")
                
                col_score = 0.0
                col_rationale = []
                
                if self._fuzzy_match(col_name, query_lower):
                    col_score = 0.9
                    col_rationale.append(f"Column name '{col_name}' matches query")
                
                for metric in metrics:
                    if self._fuzzy_match(metric, col_name):
                        col_score = max(col_score, 0.85)
                        col_rationale.append(f"Column matches metric '{metric}'")
                
                for dimension in dimensions:
                    if self._fuzzy_match(dimension, col_name):
                        col_score = max(col_score, 0.85)
                        col_rationale.append(f"Column matches dimension '{dimension}'")
                
                for filter_item in filters:
                    field = filter_item.get("field", "")
                    if self._fuzzy_match(field, col_name):
                        col_score = max(col_score, 0.9)
                        col_rationale.append(f"Column matches filter field '{field}'")
                
                if col_score > 0.5:
                    relevant_elements.append(SchemaElement(
                        element_name=f"{table_name}.{col_name}",
                        element_type="Column",
                        table_name=table_name,
                        column_name=col_name,
                        relevance_score=col_score,
                        mapping_rationale="; ".join(col_rationale),
                        data_type=col_type
                    ))
                    column_matches.append(col_name)
                    # A table is slightly less relevant than its best column.
                    table_score = max(table_score, col_score * 0.9)
            
            if table_score > 0.5 or column_matches:
                if column_matches:
                    table_rationale.append(f"Contains relevant columns: {', '.join(column_matches)}")
                
                relevant_elements.append(SchemaElement(
                    element_name=table_name,
                    element_type="Table",
                    table_name=table_name,
                    relevance_score=table_score,
                    mapping_rationale="; ".join(table_rationale)
                ))
        
        relevant_elements.sort(key=lambda x: x.relevance_score, reverse=True)
        
        # Fall back to the LLM when the heuristics found too little.
        if len(relevant_elements) < 3:
            llm_elements = await self._llm_schema_matching(
                query_part, entities_and_intent, schema_desc, domain_knowledge
            )
            # The LLM may re-propose elements the heuristics already found;
            # keep a single, highest-scoring entry per element name.
            best: Dict[str, SchemaElement] = {}
            for elem in relevant_elements + llm_elements:
                current = best.get(elem.element_name)
                if current is None or elem.relevance_score > current.relevance_score:
                    best[elem.element_name] = elem
            relevant_elements = sorted(best.values(), key=lambda x: x.relevance_score, reverse=True)
        
        return relevant_elements
    
    def _fuzzy_match(self, term1: str, term2: str) -> bool:
        """Return True when two terms plausibly refer to the same concept.

        Case-insensitive. The terms match if either contains the other, or if
        their ``SequenceMatcher`` ratio exceeds 0.7. An empty or missing term
        never matches: the empty string is a substring of everything, so
        without this guard a filter with no ``field`` would match every column.
        """
        if not term1 or not term2:
            return False
        term1_lower = term1.lower()
        term2_lower = term2.lower()
        
        if term1_lower in term2_lower or term2_lower in term1_lower:
            return True
        
        ratio = SequenceMatcher(None, term1_lower, term2_lower).ratio()
        return ratio > 0.7
    
    @staticmethod
    def _extract_response_text(response: Any) -> Optional[str]:
        """Return the text of a chat-completion response, or None if absent.

        autogen's ``CreateResult`` exposes the text as ``.content``, while
        OpenAI SDK responses use ``.choices[0].message.content``. Both shapes
        are accepted. Non-string content (e.g. function calls) yields None.
        """
        content = getattr(response, "content", None)
        if content is None:
            choices = getattr(response, "choices", None)
            if choices:
                content = choices[0].message.content
        return content if isinstance(content, str) else None
    
    @staticmethod
    def _element_from_llm_match(match: Dict[str, Any]) -> SchemaElement:
        """Build a SchemaElement from one item of the LLM's JSON reply.

        ``table`` and ``column`` fall back to the parts of ``element``
        (``"table.column"``). Missing required keys raise KeyError.
        """
        element = match["element"]
        return SchemaElement(
            element_name=element,
            element_type=match["type"],
            table_name=match.get("table", element.split(".")[0]),
            column_name=match.get("column", element.split(".")[-1] if "." in element else None),
            relevance_score=float(match["score"]),
            mapping_rationale=match["reason"]
        )
    
    async def _llm_schema_matching(
        self,
        query_part: str,
        entities_and_intent: ExtractedEntitiesAndIntent,
        schema_desc: Dict,
        domain_knowledge: Optional[Dict]
    ) -> List[SchemaElement]:
        """Ask the LLM to propose schema elements for the query part.

        Best effort: a failed model call or unparsable reply is logged as a
        warning and yields an empty list instead of raising. A malformed item
        inside an otherwise valid reply is skipped on its own; the rest of the
        list is kept.
        """
        # Defined before any fallible work so every early return has a value.
        elements: List[SchemaElement] = []
        
        entities_dict = {
            "metrics": entities_and_intent.metrics,
            "dimensions": entities_and_intent.dimensions,
            "filters": entities_and_intent.filters,
            "primary_goal": entities_and_intent.primary_goal
        }
        
        prompt = f"""
        Match this query part to relevant schema elements:
        
        Query: {query_part}
        Intent: {json.dumps(entities_dict, indent=2)}
        
        Schema tables and columns:
        {json.dumps(schema_desc.get('tables', {}), indent=2)}
        
        Domain knowledge:
        {json.dumps(domain_knowledge, indent=2) if domain_knowledge else 'None'}
        
        Identify the most relevant tables and columns with scores (0-1).
        Return as JSON list of {{"element": "table.column", "type": "Column", "table": "table_name", "column": "col_name", "score": 0.8, "reason": "why"}}.
        """
        
        # NOTE(review): autogen_ext's OpenAIChatCompletionClient.create() is
        # documented to take autogen_core.models messages (SystemMessage /
        # UserMessage), not plain dicts. With dicts the call presumably fails
        # and this method returns []. Confirm, then switch the message types.
        messages = [
            {"role": "system", "content": SCHEMA_LINKING_SYSTEM_PROMPT},
            {"role": "user", "content": prompt}
        ]
        
        try:
            response = await self.model_client.create(messages=messages)
            content = self._extract_response_text(response)
        except Exception as e:
            logger.warning(f"LLM matching failed: {e}")
            return elements
        
        if not content:
            return elements
        
        # Greedy match: from the first '[' to the last ']' in the reply.
        json_match = re.search(r'\[.*\]', content, re.DOTALL)
        if not json_match:
            return elements
        
        try:
            matches = json.loads(json_match.group())
        except json.JSONDecodeError as e:
            logger.warning(f"LLM matching failed: {e}")
            return elements
        
        for match in matches:
            try:
                elements.append(self._element_from_llm_match(match))
            except (KeyError, TypeError, ValueError, AttributeError) as e:
                logger.warning(f"Skipping malformed LLM match {match!r}: {e}")
        
        return elements
    
    async def _identify_join_paths(
        self,
        relevant_elements: List[SchemaElement],
        schema_desc: Dict
    ) -> List[JoinPath]:
        """Propose INNER joins between the tables referenced by the elements.

        Declared foreign keys whose two endpoints are both relevant tables
        become joins with confidence 0.95. If several tables are involved but no
        such FK exists, joins are inferred from shared id-like column names.
        """
        join_paths = []
        
        # Elements without a table name (possible from the LLM) can't be joined.
        tables = {elem.table_name for elem in relevant_elements if elem.table_name}
        
        for fk in schema_desc.get("foreign_keys", []):
            from_table = fk["from_table"]
            to_table = fk["to_table"]
            
            if from_table in tables and to_table in tables:
                join_paths.append(JoinPath(
                    from_table=from_table,
                    to_table=to_table,
                    from_column=fk["from_column"],
                    to_column=fk["to_column"],
                    join_condition=f"{from_table}.{fk['from_column']} = {to_table}.{fk['to_column']}",
                    join_type="INNER",
                    confidence=0.95
                ))
        
        if len(tables) > 1 and not join_paths:
            await self._infer_join_paths(tables, schema_desc, join_paths)
        
        return join_paths
    
    async def _infer_join_paths(
        self,
        tables: set,
        schema_desc: Dict,
        join_paths: List[JoinPath]
    ):
        """Append inferred joins to ``join_paths`` (in place) when FKs are missing.

        For every pair of tables, each shared column whose lowercased name
        contains "id", "key" or "code" yields an INNER join with confidence
        0.7. Tables and columns are visited in sorted order so the output is
        deterministic across runs (set order depends on string hashing).
        """
        table_list = sorted(t for t in tables if t)
        table_columns = schema_desc.get("tables", {})
        for i, table1 in enumerate(table_list):
            cols1 = {col["name"] for col in table_columns.get(table1, [])}
            for table2 in table_list[i + 1:]:
                cols2 = {col["name"] for col in table_columns.get(table2, [])}
                
                for col in sorted(cols1 & cols2):
                    if any(keyword in col.lower() for keyword in ["id", "key", "code"]):
                        join_paths.append(JoinPath(
                            from_table=table1,
                            to_table=table2,
                            from_column=col,
                            to_column=col,
                            join_condition=f"{table1}.{col} = {table2}.{col}",
                            join_type="INNER",
                            confidence=0.7
                        ))
    
    async def _identify_unresolved_elements(
        self,
        query_part: str,
        entities_and_intent: ExtractedEntitiesAndIntent,
        relevant_elements: List[SchemaElement]
    ) -> List[str]:
        """List extracted metrics/dimensions/filter fields with no schema match.

        A term counts as resolved when it fuzzy-matches (via ``_fuzzy_match``,
        the same test the linking pass uses) any linked table or column name.
        Generic aggregate words and the placeholder filter field "condition" are
        never reported.
        """
        unresolved = []
        
        mapped_terms = set()
        for elem in relevant_elements:
            if elem.column_name:
                mapped_terms.add(elem.column_name.lower())
            if elem.table_name:
                mapped_terms.add(elem.table_name.lower())
        
        def is_mapped(term: str) -> bool:
            # Same notion of "match" as linking, so a fuzzily linked term
            # is not simultaneously reported as unresolved.
            return any(self._fuzzy_match(term, mapped) for mapped in mapped_terms)
        
        for metric in entities_and_intent.metrics:
            if metric.lower() not in ["count", "sum", "average", "max", "min"] and not is_mapped(metric):
                unresolved.append(f"Metric '{metric}' not found in schema")
        
        for dimension in entities_and_intent.dimensions:
            if not is_mapped(dimension):
                unresolved.append(f"Dimension '{dimension}' not found in schema")
        
        for filter_item in entities_and_intent.filters:
            field = filter_item.get("field") or ""
            if field != "condition" and not is_mapped(field):
                unresolved.append(f"Filter field '{field}' not found in schema")
        
        return unresolved
    
    def _calculate_overall_confidence(self, relevant_elements: List[SchemaElement]) -> float:
        """Return a rank-weighted mean of the top five relevance scores.

        The element at rank r (1-based) gets weight 1/r. Callers pass the list
        already sorted by descending score. An empty list gives 0.0.
        """
        if not relevant_elements:
            return 0.0
        
        total_weight = 0.0
        weighted_sum = 0.0
        for rank, elem in enumerate(relevant_elements[:5], start=1):
            weight = 1 / rank
            weighted_sum += elem.relevance_score * weight
            total_weight += weight
        
        return weighted_sum / total_weight

## Initialize Components and Test

In [ ]:
# Agent configuration for the BIRD dev split. The relative paths are resolved
# against the notebook's working directory.
config = dict(
    model='gpt-4o',
    data_path='../data/bird/dev_databases',
    tables_json_path='../data/bird/dev_tables.json',
    dataset_name='bird',
)

schema_linker = SchemaLinkingAgent(config=config)

### Test with Simple Query

In [ ]:
# Test with a simple query part using the new interface
async def test_simple_schema_linking():
    # Create entities and intent
    entities_and_intent = ExtractedEntitiesAndIntent(
        metrics=["list"],
        dimensions=["schools", "county"],
        filters=[{"field": "county", "operator": "=", "value": "Alameda"}],
        primary_goal="retrieve",
        confidence=0.8
    )
    
    # Create input
    input_data = SchemaLinkingInput(
        processed_natural_language="Show all schools in Alameda county",
        extracted_entities_and_intent=entities_and_intent,
        database_id="california_schools",
        database_schema=None  # Will use default
    )
    
    try:
        result = await schema_linker.link_schema(input_data)
        
        print(f"Overall Confidence: {result.overall_linking_confidence:.2f}")
        print(f"\nRelevant Schema Elements:")
        for elem in result.relevant_schema_elements[:5]:  # Show top 5
            print(f"  {elem.element_type}: {elem.element_name}")
            print(f"    Score: {elem.relevance_score:.2f}")
            print(f"    Rationale: {elem.mapping_rationale}")
        
        print(f"\nProposed Join Paths:")
        for join in result.proposed_join_paths:
            print(f"  {join.from_table} → {join.to_table}")
            print(f"    Condition: {join.join_condition}")
            print(f"    Confidence: {join.confidence:.2f}")
        
        if result.unresolved_elements_notes:
            print(f"\nUnresolved Elements:")
            for note in result.unresolved_elements_notes:
                print(f"  - {note}")
        
        return result
    except SchemaLinkingError as e:
        print(f"Schema linking failed: {e}")
        return None

# Run the test
result = await test_simple_schema_linking()

### Test with Complex Query Requiring Joins

In [ ]:
# Test with a query requiring joins
async def test_join_schema_linking():
    # Create entities for a join query
    entities_and_intent = ExtractedEntitiesAndIntent(
        metrics=["average", "SAT math scores"],
        dimensions=["charter schools"],
        filters=[{"field": "charter status", "operator": "=", "value": 1}],
        primary_goal="aggregation",
        confidence=0.85
    )
    
    input_data = SchemaLinkingInput(
        processed_natural_language="Calculate average SAT math scores for charter schools",
        extracted_entities_and_intent=entities_and_intent,
        database_id="california_schools",
        database_schema=None
    )
    
    try:
        result = await schema_linker.link_schema(input_data)
        
        print(f"Overall Confidence: {result.overall_linking_confidence:.2f}")
        
        # Group elements by table
        tables = {}
        for elem in result.relevant_schema_elements:
            if elem.element_type == "Table":
                if elem.table_name not in tables:
                    tables[elem.table_name] = {"table": elem, "columns": []}
            elif elem.element_type == "Column":
                if elem.table_name not in tables:
                    tables[elem.table_name] = {"table": None, "columns": []}
                tables[elem.table_name]["columns"].append(elem)
        
        print("\nRelevant Tables and Columns:")
        for table_name, info in tables.items():
            table_elem = info["table"]
            if table_elem:
                print(f"\n{table_name} (Score: {table_elem.relevance_score:.2f})")
            else:
                print(f"\n{table_name}")
            
            for col in info["columns"]:
                print(f"  - {col.column_name} ({col.data_type}) - Score: {col.relevance_score:.2f}")
        
        print("\nJoin Strategy:")
        if result.proposed_join_paths:
            for join in result.proposed_join_paths:
                print(f"  JOIN {join.from_table} AND {join.to_table}")
                print(f"    ON {join.join_condition}")
                print(f"    Type: {join.join_type}")
                print(f"    Confidence: {join.confidence:.2f}")
        else:
            print("  No joins needed or couldn't identify join paths")
        
        return result
    except SchemaLinkingError as e:
        print(f"Schema linking failed: {e}")
        return None

# Run the test
join_result = await test_join_schema_linking()

### Test with Value-Based Matching

In [None]:
# Test with value-based matching
await schema_linker.query("""
Link schema for this query part:
Query: "Find schools with excellence rate over 0.8"
Database: california_schools
Entities and Intent: {
    "metrics": ["excellence rate"],
    "dimensions": ["schools"],
    "filters": [{"field": "excellence rate", "operator": ">", "value": 0.8}],
    "primary_goal": "filtering"
}
Domain Knowledge: {
    "business_rules": ["Excellence rate = NumGE1500 / NumTstTakr"]
}
""")

### Test Direct Function Call

In [ ]:
# Direct function call example with new interface
async def test_direct_linking():
    # Create structured input
    entities_and_intent = ExtractedEntitiesAndIntent(
        metrics=["count"],
        dimensions=["department", "salary range"],
        filters=[{"field": "year", "operator": "=", "value": 2023}],
        primary_goal="aggregation",
        confidence=0.8
    )
    
    input_data = SchemaLinkingInput(
        processed_natural_language="Count employees by department and salary range for year 2023",
        extracted_entities_and_intent=entities_and_intent,
        database_id="financial",
        database_schema=None,
        domain_knowledge_context={
            "synonyms": {
                "employee": ["staff", "worker"],
                "department": ["dept", "division"]
            }
        }
    )
    
    try:
        result = await schema_linker.link_schema(input_data)
        
        # Convert to JSON-serializable format
        result_dict = {
            "overall_confidence": result.overall_linking_confidence,
            "relevant_elements": [
                {
                    "name": elem.element_name,
                    "type": elem.element_type,
                    "score": elem.relevance_score,
                    "rationale": elem.mapping_rationale,
                    "table": elem.table_name,
                    "column": elem.column_name
                }
                for elem in result.relevant_schema_elements[:5]
            ],
            "join_paths": [
                {
                    "from": join.from_table,
                    "to": join.to_table,
                    "condition": join.join_condition,
                    "type": join.join_type,
                    "confidence": join.confidence
                }
                for join in result.proposed_join_paths
            ],
            "unresolved": result.unresolved_elements_notes
        }
        
        print(json.dumps(result_dict, indent=2))
        return result
    except SchemaLinkingError as e:
        print(f"Schema linking failed: {e}")
        return None

# Run the test
direct_result = await test_direct_linking()

### Test with Unresolved Elements

In [None]:
# Test with query containing unresolvable elements
await schema_linker.query("""
Link schema for this query part:
Query: "Show student performance metrics by teacher quality index"
Database: california_schools
Entities and Intent: {
    "metrics": ["student performance", "teacher quality index"],
    "dimensions": ["schools"],
    "filters": [],
    "primary_goal": "analysis"
}
""")

## Advanced Example: Complex Multi-Table Query

In [None]:
# Complex multi-table query
complex_entities_and_intent = {
    "metrics": ["average salary", "transaction count"],
    "dimensions": ["district", "account type"],
    "filters": [
        {"field": "transaction amount", "operator": ">", "value": 1000},
        {"field": "account age", "operator": ">", "value": "1 year"}
    ],
    "primary_goal": "aggregation"
}

await schema_linker.query(f"""
Link schema for this complex query part:
Query: "Average district salary and transaction count for accounts older than 1 year with transactions over $1000"
Database: financial
Entities and Intent: {json.dumps(complex_entities_and_intent, indent=2)}
Domain Knowledge: {{
    "business_rules": [
        "Average salary is A11 in district table",
        "Accounts link to districts via district_id",
        "Transactions link to accounts via account_id"
    ]
}}
""")

In [None]:
# Close connections
await schema_linker.close()