# Tool DSL Parsing

In [5]:
import weaviate
from weaviate.classes.query import Filter
from weaviate.classes.query import Metrics # used for aggregation filter (not in this code cell block)
from weaviate.classes.aggregate import GroupByAggregate # used for aggregation filter (not in this code cell block)

def _build_weaviate_filter(filter_string: str) -> Filter:
    """Translate a compact filter DSL string into a ``Filter`` object.

    A condition has the form ``property:operator:value``.  Everything after
    the second ``:`` is the value, so values may themselves contain ``:``.
    Conditions can be combined with ``AND`` / ``OR`` and grouped with
    parentheses, e.g.::

        (genre:=:jazz OR genre:=:blues) AND year:>=:1960

    ``AND`` binds tighter than ``OR`` (as in SQL), so ``a OR b AND c`` means
    ``a OR (b AND c)``.  The keywords are recognised only as whole words with
    whitespace on both sides and outside parentheses.  Property names or
    values such as ``BRAND`` or ``ORANGE`` are therefore never split.  The
    trade-off is that a value cannot contain a whitespace-delimited
    ``AND``/``OR`` or unbalanced parentheses.

    Supported operators:

    * ``=``, ``!=``, ``LIKE`` -- the value is passed through as a string.
    * ``>``, ``<``, ``>=``, ``<=`` -- the value is converted with ``float``.
    * ``CONTAINS_ANY``, ``CONTAINS_ALL`` -- a comma-separated list; each item
      is stripped of surrounding whitespace.
    * ``WITHIN`` -- ``lat,lon,distance`` for a geo-range filter.

    Args:
        filter_string: The DSL expression to parse.

    Returns:
        The filter built from ``Filter.by_property`` / ``Filter.all_of`` /
        ``Filter.any_of``.

    Raises:
        ValueError: On a malformed condition, an unsupported operator,
            unbalanced parentheses, or a non-numeric value where a number is
            required.
    """
    import re

    def _depths(expr: str) -> list:
        """Return paren depth before each char of *expr*, plus the final depth."""
        depths, depth = [0], 0
        for ch in expr:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            depths.append(depth)
        return depths

    def _strip_outer_parens(expr: str) -> str:
        """Remove parentheses that wrap the *whole* expression, repeatedly."""
        expr = expr.strip()
        # "(a) AND (b)" starts with '(' and ends with ')' but is not wrapped:
        # its depth returns to 0 before the last character.
        while expr.startswith('(') and expr.endswith(')') and 0 not in _depths(expr)[1:-1]:
            expr = expr[1:-1].strip()
        return expr

    def _split_top_level(expr: str, keyword: str) -> list:
        """Split *expr* on whitespace-delimited *keyword* at paren depth 0."""
        depths = _depths(expr)
        pieces, start = [], 0
        for match in re.finditer(rf'\s+{keyword}\s+', expr):
            if depths[match.start()] == 0:
                pieces.append(expr[start:match.start()])
                start = match.end()
        pieces.append(expr[start:])
        return pieces

    def _to_float(value: str, condition: str) -> float:
        """Convert *value* to float, naming the offending condition on failure."""
        try:
            return float(value)
        except ValueError:
            raise ValueError(f"Expected a numeric value in condition: {condition}") from None

    def _parse_condition(condition: str):
        """Build the filter for a single ``property:operator:value`` condition."""
        condition = condition.strip()
        parts = condition.split(':')
        if len(parts) < 3:
            raise ValueError(f"Invalid condition: {condition}")

        # Re-join the tail so values containing ':' (e.g. timestamps) survive.
        prop_name, operator, value = parts[0].strip(), parts[1].strip(), ':'.join(parts[2:])
        prop = Filter.by_property(prop_name)

        if operator == '=':
            return prop.equal(value)
        if operator == '!=':
            return prop.not_equal(value)
        if operator == '>':
            return prop.greater_than(_to_float(value, condition))
        if operator == '<':
            return prop.less_than(_to_float(value, condition))
        if operator == '>=':
            return prop.greater_than_equal(_to_float(value, condition))
        if operator == '<=':
            return prop.less_than_equal(_to_float(value, condition))
        if operator == 'LIKE':
            return prop.like(value)
        if operator == 'CONTAINS_ANY':
            return prop.contains_any([item.strip() for item in value.split(',')])
        if operator == 'CONTAINS_ALL':
            return prop.contains_all([item.strip() for item in value.split(',')])
        if operator == 'WITHIN':
            # The v4 builder takes a GeoCoordinate plus a distance, not three floats.
            from weaviate.classes.data import GeoCoordinate
            try:
                lat, lon, dist = (float(v) for v in value.split(','))
            except ValueError:
                raise ValueError(
                    f"WITHIN expects 'lat,lon,distance' in condition: {condition}"
                ) from None
            return prop.within_geo_range(
                coordinate=GeoCoordinate(latitude=lat, longitude=lon),
                distance=dist,
            )
        raise ValueError(f"Unsupported operator: {operator}")

    def _parse_expr(expr: str):
        """Parse an OR/AND expression, recursing into parenthesised groups."""
        expr = _strip_outer_parens(expr)
        # Split on the lowest-precedence operator first so that AND binds tighter.
        for keyword, combine in (('OR', Filter.any_of), ('AND', Filter.all_of)):
            pieces = _split_top_level(expr, keyword)
            if len(pieces) > 1:
                return combine([_parse_expr(piece) for piece in pieces])
        return _parse_condition(expr)

    depths = _depths(filter_string)
    if min(depths) < 0 or depths[-1] != 0:
        raise ValueError(f"Unbalanced parentheses in filter: {filter_string}")

    return _parse_expr(filter_string)

In [6]:
# Example: a single equality condition (property "name" equal to "charles").
_build_weaviate_filter(filter_string="name:=:charles")

_FilterValue(value='charles', operator=<_Operator.EQUAL: 'Equal'>, target='name')

In [7]:
import weaviate
from typing import Any, Dict, List

def _agg_type_to_graphql(agg_type: str) -> str:
    """Map an UPPER_SNAKE_CASE aggregation name to its GraphQL camelCase form.

    ``TOP_OCCURRENCES`` -> ``topOccurrences`` and ``MEAN`` -> ``mean``.  Names
    that are not all-uppercase (e.g. an already camelCase ``topOccurrences``
    or ``count``) are returned unchanged.
    """
    if not agg_type.isupper():
        return agg_type
    head, *tail = agg_type.lower().split('_')
    return head + ''.join(word.capitalize() for word in tail)


def _agg_split_group_by(agg_string: str) -> tuple:
    """Split an optional leading ``GROUP_BY(prop):`` clause off *agg_string*.

    Returns ``(group_by_property, aggregation_list)``.  ``group_by_property``
    is ``None`` when there is no ``GROUP_BY`` clause.  A ``GROUP_BY(...)``
    that is not followed by ``:`` means "no aggregations".  Whitespace
    between ``)`` and ``:`` is tolerated.
    """
    agg_string = agg_string.strip()
    prefix = 'GROUP_BY('
    if not agg_string.startswith(prefix):
        return None, agg_string
    close = agg_string.find(')')
    if close == -1:
        raise ValueError("Missing closing parenthesis in GROUP_BY")
    group_by_property = agg_string[len(prefix):close].strip()
    rest = agg_string[close + 1:].lstrip()
    if rest.startswith(':'):
        return group_by_property, rest[1:]
    return group_by_property, ''


def _agg_parse_specs(agg_list: str) -> list:
    """Parse ``prop:TYPE[param], ...`` into ``(prop, type, param)`` tuples.

    Empty entries (e.g. from a trailing comma) are skipped.  *param* is
    ``None`` when no ``[...]`` suffix is present.
    """
    specs = []
    for raw in agg_list.split(','):
        agg = raw.strip()
        if not agg:
            continue
        property_name, colon, rest = agg.partition(':')
        if not colon:
            raise ValueError(f"Invalid aggregation: {agg}")
        rest = rest.strip()
        param = None
        if '[' in rest:
            rest, _, param = rest.partition('[')
            param = param.rstrip(']').strip()
        specs.append((property_name.strip(), rest.strip(), param))
    return specs


def _agg_field(property_name: str, agg_type: str, param) -> str:
    """Render one aggregation as a GraphQL selection, validating its parts.

    Names are checked against the GraphQL identifier pattern, and *param*
    must be a non-negative integer.  This keeps DSL input from injecting
    arbitrary text into the query.
    """
    import re

    gql_type = _agg_type_to_graphql(agg_type)
    for name in (property_name, gql_type):
        if not re.fullmatch(r'[_A-Za-z][_0-9A-Za-z]*', name):
            raise ValueError(f"Invalid aggregation: {property_name}:{agg_type}")
    selection = gql_type
    if param:
        if not param.isdecimal():
            raise ValueError(f"Aggregation limit must be a non-negative integer: {param!r}")
        selection += f"(limit: {int(param)})"
    if gql_type == 'topOccurrences':
        # topOccurrences is an object type; GraphQL requires a sub-selection.
        selection += ' { value occurs }'
    return f"{property_name} {{ {selection} }}"


def _build_weaviate_aggregation(
    client: "weaviate.Client", class_name: str, agg_string: str
) -> "weaviate.gql.aggregate.AggregateBuilder":
    """Build an aggregate query for *class_name* from a compact DSL string.

    DSL: ``[GROUP_BY(prop):]prop:TYPE[param], prop:TYPE, ...``, for example
    ``GROUP_BY(publicationName):wordCount:MEAN, title:TOP_OCCURRENCES[5]``.
    ``TYPE`` may be written in UPPER_SNAKE_CASE or in the GraphQL spelling.

    The query is built with the client's GraphQL query-builder API
    (``client.query.aggregate``), which is the weaviate v3-style interface.
    When grouping, ``groupedBy { path value }`` is selected so that results
    can be attributed to their group.  When no aggregations are given,
    ``meta { count }`` is selected so that the selection set is never empty.

    Args:
        client: Client exposing ``query.aggregate(class_name)``.
        class_name: The collection/class to aggregate over.
        agg_string: The aggregation DSL string.

    Returns:
        The aggregate builder with group-by and fields applied (not executed).

    Raises:
        ValueError: On a malformed GROUP_BY clause, aggregation entry,
            identifier, or limit parameter.
    """
    group_by_property, agg_list = _agg_split_group_by(agg_string)
    # Render every field up front so parse errors surface before touching the client.
    agg_fields = [_agg_field(*spec) for spec in _agg_parse_specs(agg_list)]

    agg_query = client.query.aggregate(class_name)

    fields = []
    if group_by_property:
        # The v3 AggregateBuilder method is with_group_by_filter(list_of_props).
        agg_query = agg_query.with_group_by_filter([group_by_property])
        fields.append('groupedBy { path value }')
    fields.extend(agg_fields)
    if not agg_fields:
        fields.append('meta { count }')

    return agg_query.with_fields('\n'.join(fields))