<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/generic_tools.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
!pip --quiet install neo4j langchain-core langchain-community

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.5/49.5 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [14]:
import os
from langchain_community.graphs import Neo4jGraph

os.environ["NEO4J_URI"] = "neo4j+s://demo.neo4jlabs.com"
os.environ["NEO4J_USERNAME"] = "recommendations"
os.environ["NEO4J_PASSWORD"] = "recommendations"
os.environ["NEO4J_DATABASE"] = "recommendations"


graph = Neo4jGraph(refresh_schema=False)

In [40]:
from langchain_core.tools import StructuredTool

In [59]:
from typing import Any, Callable, Dict, List, Optional, Union
from functools import wraps
import inspect

def create_filter_function(
    node_label: str,
    properties: Dict[str, type],
    count_field: str = "count",
    grouping_allowed: bool = True
) -> Callable:
    """
    Dynamically creates a filter function based on node properties.

    Args:
        node_label: The Neo4j node label
        properties: Dictionary of property names and their types
        count_field: Name of the count field in the result
        grouping_allowed: Whether grouping by properties is allowed
    """

    def generate_type_hints() -> Dict[str, Any]:
        """Generate type hints for the function parameters"""
        hints = {}
        for prop_name, prop_type in properties.items():
            hints[f"min_{prop_name}"] = Optional[prop_type]
            hints[f"max_{prop_name}"] = Optional[prop_type]
        if grouping_allowed:
            hints["grouping_key"] = Optional[str]
        return hints

    def generate_parameters() -> List[inspect.Parameter]:
        """Generate function parameters"""
        params = []
        for prop_name, prop_type in properties.items():
            if prop_type in (int, float):  # Only numeric types get min/max filters
                params.extend([
                    inspect.Parameter(
                        f"min_{prop_name}",
                        inspect.Parameter.POSITIONAL_OR_KEYWORD,
                        annotation=Optional[prop_type],
                        default=None
                    ),
                    inspect.Parameter(
                        f"max_{prop_name}",
                        inspect.Parameter.POSITIONAL_OR_KEYWORD,
                        annotation=Optional[prop_type],
                        default=None
                    )
                ])
        if grouping_allowed:
            params.append(
                inspect.Parameter(
                    "grouping_key",
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=Optional[str],
                    default=None
                )
            )
        return params

    def create_filter_conditions(kwargs: Dict[str, Any]) -> List[tuple]:
        """Create filter conditions based on the provided arguments"""
        filters = []
        for prop_name, prop_type in properties.items():
            if prop_type in (int, float):
                min_val = kwargs.get(f"min_{prop_name}")
                max_val = kwargs.get(f"max_{prop_name}")
                if min_val is not None:
                    filters.append((f"n.{prop_name} >= $min_{prop_name}", min_val))
                if max_val is not None:
                    filters.append((f"n.{prop_name} <= $max_{prop_name}", max_val))
        return filters

    def filter_function(*args, **kwargs) -> List[Dict]:
        """The dynamically generated filter function"""
        filters = create_filter_conditions(kwargs)

        # Create parameters dictionary for Neo4j query
        params = {}
        for param_name, value in kwargs.items():
            if value is not None and param_name != "grouping_key":
                # Keep the full parameter name (min_year, max_year, etc.)
                params[param_name] = value

        # Build Cypher query
        where_clause = " AND ".join(condition for condition, _ in filters)

        cypher_statement = f"MATCH (n:{node_label}) "
        if where_clause:
            cypher_statement += f"WHERE {where_clause} "

        grouping_key = kwargs.get("grouping_key")
        return_clause = (
            f"n.`{grouping_key}`, count(n) AS {count_field}"
            if grouping_key
            else f"count(n) AS {count_field}"
        )

        cypher_statement += f"RETURN {return_clause}"

        if grouping_key:
            cypher_statement += f" ORDER BY n.`{grouping_key}`"

        # For debugging
        print(f"Cypher: {cypher_statement}")
        print(f"Params: {params}")

        # Execute query
        return graph.query(cypher_statement, params=params)

    # Create the function with the correct signature
    sig = inspect.Signature(parameters=generate_parameters())
    filter_function.__signature__ = sig
    filter_function.__annotations__ = generate_type_hints()

    return filter_function

# Example usage:
# Define node properties
movie_properties = {
    "year": int,
    "imdbRating": float,
    "title": str,
}

# Create the filter function
movie_count = create_filter_function(
    node_label="Movie",
    properties=movie_properties,
    count_field="movie_count"
)

# Use the generated function
results = movie_count(
    min_year=2000,
    max_year=2020,
    min_rating=7.5,
)

Cypher: MATCH (n:Movie) WHERE n.year >= $min_year AND n.year <= $max_year RETURN count(n) AS movie_count
Params: {'min_year': 2000, 'max_year': 2020, 'min_rating': 7.5}


In [60]:
results

[{'movie_count': 3898}]