In [4]:
import re
import yaml
from typing import Dict, List, Set, Tuple
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

class SchemaParser:
    """Parse a YAML graph schema and extract node types, relationships and properties.

    Each top-level key of the schema is treated as an entity name whose value
    is a mapping that may contain ``represented_as`` (``'node'`` or ``'edge'``),
    ``properties``, ``source`` and ``target``.
    """

    def __init__(self, schema_yaml: str):
        """
        Initialize schema parser with YAML content
        Args:
            schema_yaml: String containing the YAML schema
        """
        # safe_load yields None for an empty document; normalise that to an
        # empty mapping so later iteration does not fail.
        self.schema = yaml.safe_load(schema_yaml) or {}

    def _is_node(self, entity_config: Dict) -> bool:
        """Check if an entity is represented as a node"""
        return entity_config.get('represented_as') == 'node'

    def _is_relationship(self, entity_config: Dict) -> bool:
        """Check if an entity is represented as an edge"""
        return entity_config.get('represented_as') == 'edge'

    @staticmethod
    def _as_names(value) -> List:
        """Normalise a scalar-or-sequence ``source``/``target`` value to a list of names"""
        if value is None:
            return []
        if isinstance(value, (list, tuple, set)):
            return [item for item in value if item is not None]
        return [value]

    def extract_schema_elements(self) -> Tuple[Set[str], Set[str], Set[str]]:
        """
        Extract nodes, relationships, and properties from schema
        Returns:
            Tuple of sets containing node types, relationships, and properties
        """
        node_types = set()
        relationships = set()
        properties = set()

        # A schema whose top level is not a mapping has no entities to report.
        schema = self.schema if isinstance(self.schema, dict) else {}

        for entity_name, config in schema.items():
            # Scalar top-level entries (e.g. a version number) are not entities.
            if not isinstance(config, dict):
                continue

            # Collect property names; a bare "properties:" key parses to None
            # and is treated as "no properties".
            entity_properties = config.get('properties')
            if isinstance(entity_properties, dict):
                properties.update(entity_properties)

            # Categorize entity
            if self._is_node(config):
                node_types.add(entity_name)
            elif self._is_relationship(config):
                relationships.add(entity_name)
                # Edge endpoints count as node types too; each endpoint may be
                # a single name or a list of names.
                for endpoint_key in ('source', 'target'):
                    node_types.update(self._as_names(config.get(endpoint_key)))

        return node_types, relationships, properties

class QueryClassifier:
    """Classify a natural-language query as 'database' or 'embedding'.

    The classifier is a TF-IDF + random-forest pipeline fitted on synthetic
    queries generated from the schema's node types, relationships and
    properties. ``train()`` must be called before ``classify_query()``.
    The model output is additionally nudged towards 'database' when the query
    contains several database keywords or schema terms.

    Term matching throughout is case-insensitive *substring* matching.
    """

    # Amount added to the 'database' probability when the heuristic fires.
    _BOOST_STEP = 0.2
    # The heuristic alone never pushes the 'database' probability above this.
    _BOOST_CAP = 0.9

    def __init__(self, schema_yaml: str):
        """
        Initialize the classifier with a schema
        Args:
            schema_yaml: String containing the YAML schema
        """
        # Parse schema
        parser = SchemaParser(schema_yaml)
        self.node_types, self.relationships, self.properties = parser.extract_schema_elements()

        # Add common variations and synonyms. The variations are computed
        # before the update, so they are derived from the schema names only.
        self.node_types.update(self._generate_variations(self.node_types))
        self.relationships.update(self._generate_variations(self.relationships))

        # Keywords that strongly indicate database queries
        self.db_indicators = {
            'relationship', 'connected', 'linked', 'between', 'path', 'associated',
            'network', 'graph', 'node', 'property', 'query',
            'find', 'search', 'match', 'return', 'where', 'connect'
        }

        # Create the classifier pipeline
        self.pipeline = Pipeline([
            ('vectorizer', TfidfVectorizer(
                ngram_range=(1, 2),
                stop_words='english'
            )),
            ('classifier', RandomForestClassifier(n_estimators=100))
        ])

    @staticmethod
    def _pluralize(term: str) -> str:
        """Return a simple English plural of ``term``"""
        lowered = term.lower()
        # Only consonant + 'y' becomes 'ies' ("ontology" -> "ontologies");
        # vowel + 'y' just takes an 's' ("pathway" -> "pathways").
        if len(lowered) > 1 and lowered.endswith('y') and lowered[-2] not in 'aeiou':
            return term[:-1] + 'ies'
        if lowered.endswith(('s', 'x', 'ch', 'sh')):
            return term + 'es'
        return term + 's'

    @staticmethod
    def _count_matches(terms: Set[str], query_lower: str) -> int:
        """Count the terms that occur as a substring of the lower-cased query"""
        return sum(1 for term in terms if term.lower() in query_lower)

    @classmethod
    def _boost_database_probability(cls, probability: float) -> float:
        """Raise a 'database' probability by the boost step, capped at the boost cap.

        A probability that already exceeds the cap is returned unchanged, so
        the heuristic can never make a confident prediction less confident.
        """
        return max(probability, min(cls._BOOST_CAP, probability + cls._BOOST_STEP))

    def _generate_variations(self, terms: Set[str]) -> Set[str]:
        """Generate common variations of terms (singular/plural, underscores/spaces)"""
        variations = set()
        for term in terms:
            # An empty term would match every query as a substring.
            if not term:
                continue

            # Replace underscores with spaces
            variations.add(term.replace('_', ' '))

            # Add plural form
            variations.add(self._pluralize(term))

            # Add common abbreviations (initials of the underscore-separated
            # words); empty parts from leading/doubled underscores are skipped.
            words = [word for word in term.split('_') if word]
            if len(words) > 1:
                variations.add(''.join(word[0] for word in words).upper())

        return variations

    def generate_training_data(self, max_db_queries: int = 200) -> Tuple[List[str], List[str]]:
        """
        Generate synthetic training data based on schema patterns
        Args:
            max_db_queries: Upper bound on the number of database queries returned
        Returns:
            Tuple of (database queries, embedding queries)
        """
        # Iterate in sorted order so the generated data does not depend on
        # per-process set ordering.
        nodes = sorted(self.node_types)
        rels = sorted(self.relationships)
        props = sorted(self.properties)

        # Database queries
        db_queries = []

        # Generate queries based on node types and relationships
        for node in nodes:
            for rel in rels:
                db_queries.extend([
                    f"Find {rel} between {node} and other nodes",
                    f"Show me all {node}s with {rel}",
                    f"What {node}s are {rel} to X?",
                    f"Query {node}s based on {rel}"
                ])

        # Generate property-based queries
        for node in nodes:
            for prop in props:
                db_queries.extend([
                    f"What is the {prop} of {node}?",
                    f"Find {node}s where {prop} equals X",
                    f"Search {node}s by {prop}"
                ])

        # Limit to prevent explosion: de-duplicate, then take an evenly spaced
        # subset of the sorted queries so every template stays represented.
        db_queries = sorted(set(db_queries))
        if max_db_queries <= 0:
            db_queries = []
        elif len(db_queries) > max_db_queries:
            step = len(db_queries) / max_db_queries
            db_queries = [db_queries[int(i * step)] for i in range(max_db_queries)]

        # Embedding space queries (more semantic questions)
        embedding_queries = [
            f"What do the documents say about {node}?"
            for node in nodes
        ]

        embedding_queries.extend([
            f"Explain the role of {node} in the context of {rel}"
            for node, rel in zip(nodes[:10], rels[:10])
        ])

        embedding_queries.extend([
            "Summarize the information about this topic",
            "What are the key findings in these documents?",
            "Explain the significance of these results",
            "What conclusions can be drawn from this data?",
            "Describe the main concepts discussed",
            "What is the relationship between these elements?",
            "How do these components interact?",
            "What is known about this subject?",
            "Analyze the patterns in this data",
            "What insights can be gained from these documents?"
        ])

        return db_queries, embedding_queries

    def train(self):
        """Train the classifier on synthetic data"""
        db_queries, embedding_queries = self.generate_training_data()
        X = db_queries + embedding_queries
        y = ['database'] * len(db_queries) + ['embedding'] * len(embedding_queries)
        self.pipeline.fit(X, y)

    def classify_query(self, query: str) -> Dict[str, float]:
        """
        Classify a query as database-related or embedding-space-related
        Returns probability scores for each class
        """
        query_lower = query.lower()

        # Check for strong database indicators
        db_indicator_count = self._count_matches(self.db_indicators, query_lower)

        # Check for schema-specific terms (a term present in several
        # categories is counted once per category)
        schema_term_count = (
            self._count_matches(self.node_types, query_lower) +
            self._count_matches(self.relationships, query_lower) +
            self._count_matches(self.properties, query_lower)
        )

        # Get model probabilities, keyed by class label rather than by
        # position so a model fitted on a single class still works.
        raw_probs = self.pipeline.predict_proba([query])[0]
        by_label = {
            str(label): float(prob)
            for label, prob in zip(self.pipeline.classes_, raw_probs)
        }
        db_prob = by_label.get('database', 0.0)
        embedding_prob = by_label.get('embedding', 1.0 - db_prob)

        # Adjust probabilities based on indicators
        if db_indicator_count > 2 or schema_term_count > 2:
            db_prob = self._boost_database_probability(db_prob)
            embedding_prob = 1 - db_prob

        return {
            'database': db_prob,
            'embedding': embedding_prob
        }

    def get_features(self, query: str) -> Dict[str, object]:
        """
        Analyze query features for explanation purposes
        Returns:
            Dict of boolean 'contains_*' flags plus an integer
            'schema_term_count' of distinct schema terms found in the query
        """
        query_lower = query.lower()
        all_schema_terms = self.node_types | self.relationships | self.properties
        return {
            'contains_node_type': self._count_matches(self.node_types, query_lower) > 0,
            'contains_relationship': self._count_matches(self.relationships, query_lower) > 0,
            'contains_property': self._count_matches(self.properties, query_lower) > 0,
            'contains_db_indicator': self._count_matches(self.db_indicators, query_lower) > 0,
            'schema_term_count': self._count_matches(all_schema_terms, query_lower)
        }


In [None]:

# Example usage:

# Read schema from file. The encoding is given explicitly so the YAML is
# decoded the same way regardless of the platform's locale default.
with open('schema_config.yaml', 'r', encoding='utf-8') as f:
    schema_yaml = f.read()

# Initialize and train classifier
classifier = QueryClassifier(schema_yaml)
classifier.train()

In [11]:
# A few sample questions to run through the trained classifier
queries = [
    "What genes are connected to BRCA1?",
    "Explain what the documents say about cancer pathways",
    "Find all proteins that interact with TP53",
]
# Report the class scores and the matched features for each question
for query in queries:
    result = classifier.classify_query(query)
    report = (
        f"\nQuery: {query}",
        f"Classification: {result}",
        f"Features: {classifier.get_features(query)}",
    )
    print(*report, sep="\n")


Query: What genes are connected to BRCA1?
Classification: {'database': 0.9, 'embedding': 0.09999999999999998}
Features: {'contains_node_type': True, 'contains_relationship': False, 'contains_property': True, 'contains_db_indicator': True, 'schema_term_count': 2}

Query: Explain what the documents say about cancer pathways
Classification: {'database': 0.0, 'embedding': 1.0}
Features: {'contains_node_type': True, 'contains_relationship': False, 'contains_property': False, 'contains_db_indicator': True, 'schema_term_count': 1}

Query: Find all proteins that interact with TP53
Classification: {'database': 0.82, 'embedding': 0.18}
Features: {'contains_node_type': True, 'contains_relationship': False, 'contains_property': False, 'contains_db_indicator': True, 'schema_term_count': 2}


In [12]:
# Annotation-style questions used to exercise the classifier
annotation_queries = [
    "What genes are associated with the GO term \"apoptosis\"?",
    "Which proteins are part of the \"cell cycle\" pathway?",
    "What are the child terms of \"neuron\" in the Cell Ontology?",
    "Which genes are expressed in the liver?",
    "What SNPs are associated with the BRCA1 gene?",
    "Which transcripts are produced by the TP53 gene?",
    "What are the protein products of the insulin gene?",
    "Which enhancers are active in stem cells?",
    "What are the target genes of the miRNA let-7?",
    "Which proteins interact with the epidermal growth factor receptor?",
]

# Print the class scores and matched features for every question
for query in annotation_queries:
    result = classifier.classify_query(query)
    lines = [
        f"\nQuery: {query}",
        f"Classification: {result}",
        f"Features: {classifier.get_features(query)}",
    ]
    print("\n".join(lines))


Query: What genes are associated with the GO term "apoptosis"?
Classification: {'database': 0.9, 'embedding': 0.09999999999999998}
Features: {'contains_node_type': True, 'contains_relationship': False, 'contains_property': True, 'contains_db_indicator': True, 'schema_term_count': 3}

Query: Which proteins are part of the "cell cycle" pathway?
Classification: {'database': 0.9, 'embedding': 0.09999999999999998}
Features: {'contains_node_type': True, 'contains_relationship': True, 'contains_property': True, 'contains_db_indicator': True, 'schema_term_count': 6}

Query: What are the child terms of "neuron" in the Cell Ontology?
Classification: {'database': 0.96, 'embedding': 0.04}
Features: {'contains_node_type': False, 'contains_relationship': False, 'contains_property': True, 'contains_db_indicator': False, 'schema_term_count': 1}

Query: Which genes are expressed in the liver?
Classification: {'database': 0.9, 'embedding': 0.09999999999999998}
Features: {'contains_node_type': True, 'co