In [1]:
import json
import networkx as nx
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Tuple
import spacy
from spacy.matcher import Matcher, PhraseMatcher

class AdvancedSchemaToGraphConverter:
    def __init__(self):
        self.graph = nx.Graph()
        self.nlp = spacy.load("en_core_web_sm")
        self.matcher = Matcher(self.nlp.vocab)
        self.phrase_matcher = PhraseMatcher(self.nlp.vocab)
        
        # Define patterns for common SQL operations
        self.matcher.add("AGGREGATION", [[{"LOWER": {"IN": ["average", "avg", "sum", "count", "max", "min"]}}]])
        self.matcher.add("GROUPING", [[{"LOWER": "group"}, {"LOWER": "by"}]])
        self.matcher.add("ORDERING", [[{"LOWER": {"IN": ["order", "sort"]}}, {"LOWER": "by"}]])
        self.matcher.add("LIMIT", [[{"LOWER": {"IN": ["top", "bottom"]}}, {"POS": "NUM"}]])
        self.matcher.add("TIME_RANGE", [[{"LOWER": {"IN": ["last", "past"]}}, {"POS": "NUM"}, {"LOWER": {"IN": ["day", "week", "month", "year"]}}]])
        
        # Domain-specific enrichments
        self.matcher.add("CUSTOMER_RETENTION", [[{"LOWER": "customer"}, {"LOWER": "retention"}]])
        self.matcher.add("POPULAR_PRODUCTS", [[{"LOWER": "popular"}, {"LOWER": "products"}]])
        self.matcher.add("BOUGHT_TOGETHER", [[{"LOWER": "bought"}, {"LOWER": "together"}]])

    def load_schema(self, schema_json: str):
        schema = json.loads(schema_json)
        self._process_schema(schema)
        self._update_phrase_matcher()

    def _process_schema(self, schema: Dict[str, Any]):
        for table_name, table_info in schema.items():
            self._add_table_node(table_name, table_info)
            self._process_columns(table_name, table_info['columns'])
            self._process_relationships(table_name, table_info.get('relationships', []))

    def _add_table_node(self, table_name: str, table_info: Dict[str, Any]):
        # Adding description and enriched metadata for the table
        self.graph.add_node(table_name, type='table', 
                            description=table_info.get('description', ''),
                            domain_keywords=self._generate_domain_keywords(table_name),
                            temporal_columns=self._detect_temporal_columns(table_info['columns']))

    def _process_columns(self, table_name: str, columns: Dict[str, Any]):
        for column_name, column_info in columns.items():
            column_node = f"{table_name}.{column_name}"
            # Enriching the column with data type, constraints, and statistical metadata
            self.graph.add_node(column_node, type='column', 
                                data_type=column_info['data_type'],
                                constraints=column_info.get('constraints', []),
                                common_values=self._get_common_values(column_name),
                                functional_dependency=self._detect_functional_dependency(table_name, column_name))
            self.graph.add_edge(table_name, column_node, type='has_column')

    def _process_relationships(self, table_name: str, relationships: list):
        for relationship in relationships:
            related_table = relationship['related_table']
            # Adding relationship type and cardinality for richer context
            self.graph.add_edge(table_name, related_table, type='related_to',
                                relationship_type=relationship['type'],
                                cardinality=self._infer_cardinality(relationship['type']))

    def _update_phrase_matcher(self):
        patterns = []
        for node in self.graph.nodes():
            if self.graph.nodes[node]['type'] in ['table', 'column']:
                patterns.append(self.nlp(node.lower()))
        self.phrase_matcher.add("SCHEMA_ELEMENTS", patterns)

    def advanced_question_processing(self, question: str) -> Dict[str, Any]:
        doc = self.nlp(question.lower())
        
        # Extract entities and operations
        entities = [ent.text for ent in doc.ents]
        operations = []
        for match_id, start, end in self.matcher(doc):
            operations.append(doc[start:end].text)
        
        # Get relevant tables and columns
        matches = self.phrase_matcher(doc)
        relevant_elements = [doc[start:end].text for _, start, end in matches]
        
        tables, columns = self._classify_relevant_elements(relevant_elements)
        
        # Expand tables and columns based on the question context
        expanded_tables, expanded_columns = self._expand_relevant_elements(tables, columns, doc)
        
        # Identify potential joins and related tables
        potential_joins = self.find_potential_joins(expanded_tables)
        
        return {
            "entities": entities,
            "operations": operations,
            "tables": expanded_tables,
            "columns": expanded_columns,
            "potential_joins": potential_joins
        }

    def _classify_relevant_elements(self, elements: List[str]) -> Tuple[List[str], List[str]]:
        tables = []
        columns = []
        for element in elements:
            if '.' in element:
                columns.append(element)
            else:
                tables.append(element)
        return tables, columns

    def _expand_relevant_elements(self, tables: List[str], columns: List[str], doc) -> Tuple[List[str], List[str]]:
        expanded_tables = set(tables)
        expanded_columns = set(columns)
        
        # Expand based on relationships
        for table in tables:
            expanded_tables.update(self.get_related_tables(table))
        
        # Expand based on common query patterns
        if any(token.text in ['total', 'sum', 'amount'] for token in doc):
            expanded_columns.add('orders.total_amount')
            expanded_tables.add('orders')
        
        if any(token.text in ['customer', 'user'] for token in doc):
            expanded_tables.add('users')
        
        if any(token.text in ['product', 'item'] for token in doc):
            expanded_tables.add('products')
            expanded_tables.add('order_items')
        
        if any(token.text in ['category'] for token in doc):
            expanded_columns.add('products.category')
        
        if any(token.text in ['bought', 'purchased', 'together'] for token in doc):
            expanded_tables.update(['orders', 'order_items', 'products'])
        
        # Expand columns for all tables
        for table in expanded_tables:
            expanded_columns.update(self.get_table_columns(table))
        
        return list(expanded_tables), list(expanded_columns)

    def get_related_tables(self, table: str) -> List[str]:
        if table not in self.graph:
            return []
        return [node for node in nx.neighbors(self.graph, table) 
                if self.graph.nodes[node]['type'] == 'table']

    def get_table_columns(self, table: str) -> List[str]:
        if table not in self.graph:
            return []
        return [node for node in nx.neighbors(self.graph, table) 
                if self.graph.nodes[node]['type'] == 'column']

    def find_potential_joins(self, tables: List[str]) -> List[tuple]:
        joins = []
        for i in range(len(tables)):
            for j in range(i+1, len(tables)):
                if tables[i] in self.graph and tables[j] in self.graph:
                    try:
                        path = nx.shortest_path(self.graph, tables[i], tables[j])
                        joins.append((tables[i], tables[j], path))
                    except nx.NetworkXNoPath:
                        continue
        return joins

    # Helper methods to enrich schema with richer metadata
    def _generate_domain_keywords(self, table_name: str) -> List[str]:
        # Generating domain-specific keywords
        keywords = {
            'users': ['customer', 'client', 'user'],
            'orders': ['purchase', 'transaction', 'sale'],
            'products': ['item', 'goods', 'merchandise'],
        }
        return keywords.get(table_name, [])

    def _detect_temporal_columns(self, columns: Dict[str, Any]) -> List[str]:
        # Detect columns related to time/date
        temporal_columns = [col for col, info in columns.items() if 'date' in col or 'time' in info['data_type']]
        return temporal_columns

    def _get_common_values(self, column_name: str) -> List[Any]:
        # Simulate common values for demonstration
        common_values = {
            'users.age': [25, 30, 35, 40],
            'products.category': ['Electronics', 'Clothing', 'Home'],
            'orders.total_amount': [100.0, 250.5, 500.0],
        }
        return common_values.get(column_name, [])

    def _detect_functional_dependency(self, table_name: str, column_name: str) -> str:
        # Simulate functional dependency for example
        if table_name == 'orders' and column_name == 'total_amount':
            return 'total_amount = quantity * price'
        return ''

    def _infer_cardinality(self, relationship_type: str) -> str:
        if relationship_type == 'one-to-many':
            return '1-to-N'
        if relationship_type == 'many-to-one':
            return 'N-to-1'
        return 'N-to-N'

In [2]:
# Example usage
converter = AdvancedSchemaToGraphConverter()

In [3]:
schema_json = """
{
  "users": {
    "description": "Table containing user information",
    "columns": {
      "user_id": {"data_type": "int", "constraints": ["primary key"]},
      "name": {"data_type": "varchar"},
      "age": {"data_type": "int"},
      "registration_date": {"data_type": "date"}
    }
  },
  "orders": {
    "description": "Table containing order information",
    "columns": {
      "order_id": {"data_type": "int", "constraints": ["primary key"]},
      "user_id": {"data_type": "int", "constraints": ["foreign key"]},
      "total_amount": {"data_type": "float"},
      "order_date": {"data_type": "date"}
    },
    "relationships": [
      {"related_table": "users", "type": "many-to-one"}
    ]
  },
  "products": {
    "description": "Table containing product information",
    "columns": {
      "product_id": {"data_type": "int", "constraints": ["primary key"]},
      "name": {"data_type": "varchar"},
      "category": {"data_type": "varchar"}
    }
  },
  "order_items": {
    "description": "Table containing ordered products",
    "columns": {
      "order_id": {"data_type": "int", "constraints": ["foreign key"]},
      "product_id": {"data_type": "int", "constraints": ["foreign key"]},
      "quantity": {"data_type": "int"},
      "price": {"data_type": "float"}
    },
    "relationships": [
      {"related_table": "orders", "type": "many-to-one"},
      {"related_table": "products", "type": "many-to-one"}
    ]
  }
}
"""

In [9]:
question = "What is the month-over-month growth rate in total sales for each product category?"
query_context = converter.advanced_question_processing(question)
print(query_context)

{'entities': [], 'operations': [], 'tables': ['products', 'order_items', 'orders'], 'columns': ['orders.total_amount', 'products.category'], 'potential_joins': []}
