# Tool DSL Parsing

In [11]:
import weaviate
from weaviate.classes.query import Filter
from weaviate.classes.query import Metrics # used for aggregation filter (not in this code cell block)
from weaviate.classes.aggregate import GroupByAggregate # used for aggregation filter (not in this code cell block)

def _build_weaviate_filter(filter_string: str) -> Filter:
    """Translate a filter DSL string into a Weaviate ``Filter``.

    Grammar::

        condition := property ":" operator ":" value
        group     := condition | "(" group ")" | group " AND " group | group " OR " group

    Supported operators: ``=``, ``!=``, ``>``, ``<``, ``>=``, ``<=``, ``LIKE``,
    ``CONTAINS_ANY`` / ``CONTAINS_ALL`` (comma-separated values) and ``WITHIN``
    (``latitude,longitude,distance``).

    Notes:
        * ``AND`` / ``OR`` are only recognised as whole, whitespace-delimited
          words, so names or values that merely contain those letters
          (``ORANGE``, ``SANDALS``) are left intact. A value containing the
          standalone word (``Pride AND Prejudice``) is still split.
        * ``AND`` is split before ``OR``, so without parentheses
          ``a AND b OR c`` is read as ``a AND (b OR c)``. Use parentheses to
          make the grouping explicit.
        * Everything after the second colon is the value, so values may
          contain colons.
        * Comparison values written as integer literals (``300``) are sent as
          ``int``; anything else (``4.0``) is sent as ``float``.
        * Values for ``=``, ``!=`` and ``LIKE`` are always passed as strings.

    Args:
        filter_string: The filter expression to parse.

    Returns:
        The filter object built from the expression.

    Raises:
        ValueError: If a condition is malformed, a numeric value cannot be
            parsed, or the operator is not supported.
    """
    import re  # local import: the cell defining this function does not import ``re``

    def _is_balanced(text: str) -> bool:
        """Return True when every ')' closes an earlier '(' and none stay open."""
        depth = 0
        for char in text:
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _unwrap(text: str) -> str:
        """Strip parentheses that enclose the whole expression."""
        text = text.strip()
        # "(a) OR (b)" starts with '(' and ends with ')' but is not wrapped:
        # only strip when the inner text is balanced on its own.
        while text.startswith('(') and text.endswith(')') and _is_balanced(text[1:-1]):
            text = text[1:-1].strip()
        return text

    def _split_top_level(text: str, keyword: str) -> list:
        """Split on the whitespace-delimited keyword outside of parentheses."""
        # If the parentheses do not pair up (e.g. a stray ')' inside a value),
        # nesting depth is meaningless, so fall back to a flat split.
        respect_parens = _is_balanced(text)
        pieces = []
        depth = 0
        start = 0
        for token in re.finditer(rf'[()]|\s+{keyword}\s+', text):
            symbol = token.group()
            if symbol == '(':
                depth += 1
            elif symbol == ')':
                depth -= 1
            elif depth == 0 or not respect_parens:
                pieces.append(text[start:token.start()].strip())
                start = token.end()
        pieces.append(text[start:].strip())
        return pieces

    def _parse_number(raw: str, condition: str):
        """Parse a comparison value as int when possible, otherwise float."""
        text = raw.strip()
        try:
            return int(text)
        except ValueError:
            pass
        try:
            return float(text)
        except ValueError as err:
            raise ValueError(f"Invalid numeric value in condition: {condition}") from err

    def _parse_condition(condition: str) -> Filter:
        parts = condition.split(':')
        if len(parts) < 3:
            raise ValueError(f"Invalid condition: {condition}")

        prop_name, operator = parts[0].strip(), parts[1].strip()
        if not prop_name:
            raise ValueError(f"Invalid condition: {condition}")
        # Re-join so that values containing ':' survive the split
        value = ':'.join(parts[2:])
        target = Filter.by_property(prop_name)

        if operator == '=':
            return target.equal(value)
        if operator == '!=':
            return target.not_equal(value)
        if operator == 'LIKE':
            return target.like(value)
        if operator == '>':
            return target.greater_than(_parse_number(value, condition))
        if operator == '<':
            return target.less_than(_parse_number(value, condition))
        if operator == '>=':
            return target.greater_or_equal(_parse_number(value, condition))
        if operator == '<=':
            return target.less_or_equal(_parse_number(value, condition))
        if operator in ('CONTAINS_ANY', 'CONTAINS_ALL'):
            items = [item.strip() for item in value.split(',')]
            if operator == 'CONTAINS_ANY':
                return target.contains_any(items)
            return target.contains_all(items)
        if operator == 'WITHIN':
            try:
                lat, lon, dist = map(float, value.split(','))
            except ValueError as err:
                raise ValueError(
                    f"WITHIN expects 'latitude,longitude,distance': {condition}"
                ) from err
            # The v4 client takes the point as a GeoCoordinate object
            from weaviate.classes.query import GeoCoordinate
            return target.within_geo_range(
                coordinate=GeoCoordinate(latitude=lat, longitude=lon),
                distance=dist,
            )
        raise ValueError(f"Unsupported operator: {operator}")

    def _parse_group(group: str) -> Filter:
        group = _unwrap(group)

        and_terms = _split_top_level(group, 'AND')
        if len(and_terms) > 1:
            return Filter.all_of([_parse_group(term) for term in and_terms])

        or_terms = _split_top_level(group, 'OR')
        if len(or_terms) > 1:
            return Filter.any_of([_parse_group(term) for term in or_terms])

        return _parse_condition(group)

    return _parse_group(filter_string)

In [12]:
# Smoke check: one equality condition on `name` should yield a single Equal filter
_build_weaviate_filter("name:=:charles")

_FilterValue(value='charles', operator=<_Operator.EQUAL: 'Equal'>, target='name')

In [15]:
from typing import Tuple, Optional, List
from weaviate.classes.query import Metrics
from weaviate.classes.aggregate import GroupByAggregate
import re

def _build_weaviate_aggregation(agg_string: str) -> Tuple[Optional[GroupByAggregate], List[Metrics]]:
    """
    Parses an aggregation string into Weaviate GroupByAggregate and Metrics objects.

    Format:
    GROUP_BY(property) METRICS(property:type[metrics], property2:type[metrics])

    Both clauses are optional and may appear in either order; clauses are
    separated by whitespace. Whitespace is allowed around the commas inside
    ``[...]``.

    Supported types and their operations:
    - text: count, topOccurrences
    - int / num: count, min, max, mean, median, mode, sum
    - bool: count, totalTrue, totalFalse, percentageTrue, percentageFalse

    Examples:
    - "GROUP_BY(publication) METRICS(wordCount:int[count,mean,max])"
    - "METRICS(rating:num[mean,sum], title:text[count,topOccurrences])"
    - "GROUP_BY(author) METRICS(isPublished:bool[totalTrue,percentageFalse])"

    Returns:
    Tuple of (GroupByAggregate or None when no GROUP_BY clause is given,
    List[Metrics], which is empty when no METRICS clause is given)

    Raises:
    ValueError: on an unknown clause, a repeated GROUP_BY, a malformed clause,
    an unsupported data type, or an operation the data type does not offer.
    """
    numeric_ops = ('count', 'min', 'max', 'mean', 'median', 'mode', 'sum')
    # Operation names accepted for each declared data type
    allowed_ops = {
        'text': ('count', 'topOccurrences'),
        'int': numeric_ops,
        'num': numeric_ops,
        'bool': ('count', 'totalTrue', 'totalFalse', 'percentageTrue', 'percentageFalse'),
    }

    def _split_top_level(text: str, opener: str, closer: str, is_separator) -> List[str]:
        """Split on separator characters that are not nested inside opener/closer."""
        pieces = []
        buffer = []
        depth = 0
        for char in text:
            if char == opener:
                depth += 1
            elif char == closer:
                depth -= 1
            elif depth == 0 and is_separator(char):
                piece = ''.join(buffer).strip()
                if piece:
                    pieces.append(piece)
                buffer = []
                continue
            buffer.append(char)
        tail = ''.join(buffer).strip()
        if tail:
            pieces.append(tail)
        return pieces

    def _parse_metrics(metrics_str: str) -> List[Metrics]:
        """Parse the content of METRICS(...) into one Metrics object per property."""
        metrics_list = []

        # Commas inside [...] separate operations, not metric definitions
        for metric in _split_top_level(metrics_str, '[', ']', lambda c: c == ','):
            # property:type[operations]; anchored so trailing text is rejected
            match = re.fullmatch(r'(\w+)\s*:\s*(\w+)\s*\[([\w,\s]+)\]', metric)
            if not match:
                raise ValueError(f"Invalid metrics format: {metric}")

            prop_name, data_type, raw_operations = match.groups()
            if data_type not in allowed_ops:
                raise ValueError(f"Unsupported data type: {data_type}")

            operations = [op.strip() for op in raw_operations.split(',') if op.strip()]
            if not operations:
                raise ValueError(f"Invalid metrics format: {metric}")

            # A misspelt operation would otherwise be dropped without notice
            unknown = [op for op in operations if op not in allowed_ops[data_type]]
            if unknown:
                raise ValueError(
                    f"Unsupported operations for {data_type} property "
                    f"'{prop_name}': {', '.join(unknown)}"
                )

            if data_type == 'text':
                want_top = 'topOccurrences' in operations
                # The client exposes top occurrences as two separate flags
                metric_obj = Metrics(prop_name).text(
                    count='count' in operations,
                    top_occurrences_count=want_top,
                    top_occurrences_value=want_top
                )
            elif data_type in ('int', 'num'):
                # NOTE(review): 'int' also uses .number(); the client has a
                # separate .integer() builder - confirm which one the target
                # server version expects for int properties.
                metric_obj = Metrics(prop_name).number(
                    count='count' in operations,
                    minimum='min' in operations,
                    maximum='max' in operations,
                    mean='mean' in operations,
                    median='median' in operations,
                    mode='mode' in operations,
                    sum_='sum' in operations
                )
            else:  # 'bool'
                metric_obj = Metrics(prop_name).boolean(
                    count='count' in operations,
                    total_true='totalTrue' in operations,
                    total_false='totalFalse' in operations,
                    percentage_true='percentageTrue' in operations,
                    percentage_false='percentageFalse' in operations
                )

            metrics_list.append(metric_obj)

        return metrics_list

    def _parse_group_by(group_str: str) -> GroupByAggregate:
        """Build the grouping object from a GROUP_BY(property) clause."""
        match = re.fullmatch(r'GROUP_BY\(\s*(\w+)\s*\)', group_str)
        if not match:
            raise ValueError(f"Invalid GROUP_BY format: {group_str}")

        return GroupByAggregate(prop=match.group(1))

    group_by = None
    metrics = []

    # Whitespace separates clauses, except inside a clause's parentheses
    for part in _split_top_level(agg_string, '(', ')', str.isspace):
        if part.startswith('GROUP_BY'):
            if group_by is not None:
                raise ValueError("Multiple GROUP_BY clauses not allowed")
            group_by = _parse_group_by(part)
        elif part.startswith('METRICS'):
            match = re.fullmatch(r'METRICS\((.*)\)', part, re.DOTALL)
            if not match:
                raise ValueError(f"Invalid METRICS format: {part}")
            metrics.extend(_parse_metrics(match.group(1)))
        else:
            raise ValueError(f"Unsupported aggregation clause: {part}")

    return group_by, metrics

In [16]:
# Metrics-only aggregation string (no GROUP_BY clause)
agg_str1 = "METRICS(wordCount:int[count,mean,max])"
group_by1, metrics1 = _build_weaviate_aggregation(agg_str1)

In [17]:
# Displays nothing: without a GROUP_BY clause the builder returns None here
group_by1

In [18]:
# The single metrics object parsed for wordCount (count, mean and max enabled)
metrics1

[_MetricsNumber(property_name='wordCount', count=True, maximum=True, mean=True, median=False, minimum=False, mode=False, sum_=False)]

In [19]:
def test_aggregation_builder():
    """Smoke-test ``_build_weaviate_aggregation`` on three representative strings."""
    # Each row: (aggregation string, expected GROUP_BY property or None, metric count)
    scenarios = [
        # Just metrics
        ("METRICS(wordCount:int[count,mean,max])", None, 1),
        # Group by with metrics
        ("GROUP_BY(publication) METRICS(rating:num[mean,sum])", "publication", 1),
        # Multiple metrics
        ("METRICS(rating:num[mean,sum], isPublished:bool[totalTrue,percentageFalse])", None, 2),
    ]

    for agg_string, expected_prop, expected_count in scenarios:
        grouping, metric_list = _build_weaviate_aggregation(agg_string)
        if expected_prop is None:
            assert grouping is None
        else:
            assert grouping.prop == expected_prop
        assert len(metric_list) == expected_count

    print("All tests passed!")

In [20]:
# Run the smoke tests; prints "All tests passed!" when every assertion holds
test_aggregation_builder()

All tests passed!


In [21]:
from typing import Optional, Dict, Any, List, Union
import weaviate
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.collections import Collection

def _format_aggregation(agg_response: Any) -> Dict[str, Any]:
    """Convert an aggregate response into a plain dict.

    The client returns differently shaped objects depending on whether a
    GROUP_BY was requested: the ungrouped result carries ``properties`` and
    ``total_count``, the grouped result carries only ``groups``. Missing
    attributes are therefore read with a ``None`` default instead of being
    accessed directly.
    """
    groups = getattr(agg_response, 'groups', None)
    return {
        'total_count': getattr(agg_response, 'total_count', None),
        'properties': getattr(agg_response, 'properties', None),
        'groups': [{
            'group': g.grouped_by,
            'properties': g.properties,
            'total_count': getattr(g, 'total_count', None)
        } for g in groups] if groups else None
    }


def query_database(
    weaviate_client: weaviate.WeaviateClient,
    collection_name: str,
    search_query: Optional[str] = None,
    filter_string: Optional[str] = None,
    aggregation_string: Optional[str] = None,
    limit: int = 10
) -> Dict[str, Any]:
    """
    Query a Weaviate database with optional search, filter, and aggregation parameters.

    Objects come from a hybrid search when ``search_query`` is given, otherwise
    from a plain filtered fetch when only ``filter_string`` is given. The
    aggregation, when requested, runs over the whole collection restricted by
    the filter only; the search query does not narrow it.

    Args:
        weaviate_client: The Weaviate client instance
        collection_name: Name of the collection to query
        search_query: Optional search query string for hybrid search
        filter_string: Optional filter string in format "property:operator:value" with AND/OR
        aggregation_string: Optional aggregation string in format "GROUP_BY(prop) METRICS(prop:type[metrics])"
        limit: Maximum number of objects to return (default 10)

    Returns:
        Dict containing query results with keys based on the query type:
        - 'objects': List of objects if search/filter was used; hybrid results
          also carry a 'score'
        - 'aggregations': Aggregation results if aggregation was used, with
          keys 'total_count', 'properties' (ungrouped metrics) and 'groups'
          (per-group results, or None when no GROUP_BY was used)
        The dict is empty when none of the three optional inputs is given.
    """
    collection = weaviate_client.collections.get(collection_name)

    # Parse both DSL strings up front so malformed input fails before any query runs
    filter_obj = None
    if filter_string:
        filter_obj = _build_weaviate_filter(filter_string)

    group_by, metrics = None, None
    if aggregation_string:
        group_by, metrics = _build_weaviate_aggregation(aggregation_string)

    objects = None
    if search_query:
        hybrid_response = collection.query.hybrid(
            query=search_query,
            filters=filter_obj,
            return_metadata=MetadataQuery(score=True),
            limit=limit
        )
        objects = [{
            'properties': obj.properties,
            'score': obj.metadata.score
        } for obj in hybrid_response.objects]
    elif filter_string:
        filter_response = collection.query.fetch_objects(
            filters=filter_obj,
            limit=limit
        )
        objects = [{
            'properties': obj.properties
        } for obj in filter_response.objects]

    result: Dict[str, Any] = {}

    if aggregation_string:
        agg_response = collection.aggregate.over_all(
            group_by=group_by,
            return_metrics=metrics,
            filters=filter_obj  # None when no filter was given
        )
        result['aggregations'] = _format_aggregation(agg_response)

    if objects is not None:
        result['objects'] = objects

    return result

def _handle_error(error_msg: str) -> Dict[str, Any]:
    """Helper function to handle errors uniformly"""
    return {
        'error': error_msg,
        'objects': None,
        'aggregations': None
    }

In [None]:
# Create schemas to test queries
# NOTE(review): this cell only opens a local connection; no schema or
# collection creation code is present yet, and the client is not closed here.
import weaviate
weaviate_client = weaviate.connect_to_local()


In [25]:
import weaviate
weaviate_client = weaviate.connect_to_local()

try:
    # Search with filter and aggregation
    results = query_database(
        weaviate_client=weaviate_client,
        collection_name="JeopardyQuestion",
        search_query="food",
        filter_string="round:=:Double Jeopardy! AND points:>:300",
        aggregation_string="GROUP_BY(round) METRICS(points:int[count,mean,max])"
    )

    # Just aggregation
    agg_results = query_database(
        weaviate_client,
        "Article",
        aggregation_string="METRICS(rating:num[mean,sum])"
    )

    # Search with filter
    search_results = query_database(
        weaviate_client,
        "Article",
        search_query="machine learning",
        filter_string="rating:>:4.0"
    )
finally:
    # Release the connection even when a query raises (e.g. a missing collection)
    weaviate_client.close()

WeaviateQueryError: Query call with protocol GRPC search failed with message <AioRpcError of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "could not find class JeopardyQuestion in schema"
	debug_error_string = "UNKNOWN:Error received from peer  {created_time:"2024-11-06T11:47:13.032285-05:00", grpc_status:2, grpc_message:"could not find class JeopardyQuestion in schema"}"
>.