|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from copy import deepcopy |
| 6 | +from typing import Any, Callable, Dict, List, Optional |
| 7 | + |
| 8 | +from haystack.dataclasses import ChatMessage |
| 9 | +from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema |
| 10 | +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable |
| 11 | +from haystack.utils.type_serialization import deserialize_type, serialize_type |
| 12 | + |
| 13 | +from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values |
| 14 | + |
| 15 | + |
| 16 | +def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]: |
| 17 | + """ |
| 18 | + Convert a schema dictionary to a serializable format. |
| 19 | +
|
| 20 | + Converts each parameter's type and optional handler function into a serializable |
| 21 | + format using type and callable serialization utilities. |
| 22 | +
|
| 23 | + :param schema: Dictionary mapping parameter names to their type and handler configs |
| 24 | + :returns: Dictionary with serialized type and handler information |
| 25 | + """ |
| 26 | + serialized_schema = {} |
| 27 | + for param, config in schema.items(): |
| 28 | + serialized_schema[param] = {"type": serialize_type(config["type"])} |
| 29 | + if config.get("handler"): |
| 30 | + serialized_schema[param]["handler"] = serialize_callable(config["handler"]) |
| 31 | + |
| 32 | + return serialized_schema |
| 33 | + |
| 34 | + |
| 35 | +def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]: |
| 36 | + """ |
| 37 | + Convert a serialized schema dictionary back to its original format. |
| 38 | +
|
| 39 | + Deserializes the type and optional handler function for each parameter from their |
| 40 | + serialized format back into Python types and callables. |
| 41 | +
|
| 42 | + :param schema: Dictionary containing serialized schema information |
| 43 | + :returns: Dictionary with deserialized type and handler configurations |
| 44 | + """ |
| 45 | + deserialized_schema = {} |
| 46 | + for param, config in schema.items(): |
| 47 | + deserialized_schema[param] = {"type": deserialize_type(config["type"])} |
| 48 | + |
| 49 | + if config.get("handler"): |
| 50 | + deserialized_schema[param]["handler"] = deserialize_callable(config["handler"]) |
| 51 | + |
| 52 | + return deserialized_schema |
| 53 | + |
| 54 | + |
| 55 | +def _validate_schema(schema: Dict[str, Any]) -> None: |
| 56 | + """ |
| 57 | + Validate that a schema dictionary meets all required constraints. |
| 58 | +
|
| 59 | + Checks that each parameter definition has a valid type field and that any handler |
| 60 | + specified is a callable function. |
| 61 | +
|
| 62 | + :param schema: Dictionary mapping parameter names to their type and handler configs |
| 63 | + :raises ValueError: If schema validation fails due to missing or invalid fields |
| 64 | + """ |
| 65 | + for param, definition in schema.items(): |
| 66 | + if "type" not in definition: |
| 67 | + raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.") |
| 68 | + if not _is_valid_type(definition["type"]): |
| 69 | + raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") |
| 70 | + if definition.get("handler") is not None and not callable(definition["handler"]): |
| 71 | + raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") |
| 72 | + if param == "messages" and definition["type"] is not List[ChatMessage]: |
| 73 | + raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}") |
| 74 | + |
| 75 | + |
| 76 | +class State: |
| 77 | + """ |
| 78 | + A class that wraps a StateSchema and maintains an internal _data dictionary. |
| 79 | +
|
| 80 | + Each schema entry has: |
| 81 | + "parameter_name": { |
| 82 | + "type": SomeType, |
| 83 | + "handler": Optional[Callable[[Any, Any], Any]] |
| 84 | + } |
| 85 | + """ |
| 86 | + |
| 87 | + def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None): |
| 88 | + """ |
| 89 | + Initialize a State object with a schema and optional data. |
| 90 | +
|
| 91 | + :param schema: Dictionary mapping parameter names to their type and handler configs. |
| 92 | + Type must be a valid Python type, and handler must be a callable function or None. |
| 93 | + If handler is None, the default handler for the type will be used. The default handlers are: |
| 94 | + - For list types: `haystack.agents.state.state_utils.merge_lists` |
| 95 | + - For all other types: `haystack.agents.state.state_utils.replace_values` |
| 96 | + :param data: Optional dictionary of initial data to populate the state |
| 97 | + """ |
| 98 | + _validate_schema(schema) |
| 99 | + self.schema = deepcopy(schema) |
| 100 | + if self.schema.get("messages") is None: |
| 101 | + self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists} |
| 102 | + self._data = data or {} |
| 103 | + |
| 104 | + # Set default handlers if not provided in schema |
| 105 | + for definition in self.schema.values(): |
| 106 | + # Skip if handler is already defined and not None |
| 107 | + if definition.get("handler") is not None: |
| 108 | + continue |
| 109 | + # Set default handler based on type |
| 110 | + if _is_list_type(definition["type"]): |
| 111 | + definition["handler"] = merge_lists |
| 112 | + else: |
| 113 | + definition["handler"] = replace_values |
| 114 | + |
| 115 | + def get(self, key: str, default: Any = None) -> Any: |
| 116 | + """ |
| 117 | + Retrieve a value from the state by key. |
| 118 | +
|
| 119 | + :param key: Key to look up in the state |
| 120 | + :param default: Value to return if key is not found |
| 121 | + :returns: Value associated with key or default if not found |
| 122 | + """ |
| 123 | + return deepcopy(self._data.get(key, default)) |
| 124 | + |
| 125 | + def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None: |
| 126 | + """ |
| 127 | + Set or merge a value in the state according to schema rules. |
| 128 | +
|
| 129 | + Value is merged or overwritten according to these rules: |
| 130 | + - if handler_override is given, use that |
| 131 | + - else use the handler defined in the schema for 'key' |
| 132 | +
|
| 133 | + :param key: Key to store the value under |
| 134 | + :param value: Value to store or merge |
| 135 | + :param handler_override: Optional function to override the default merge behavior |
| 136 | + """ |
| 137 | + # If key not in schema, we throw an error |
| 138 | + definition = self.schema.get(key, None) |
| 139 | + if definition is None: |
| 140 | + raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}") |
| 141 | + |
| 142 | + # Get current value from state and apply handler |
| 143 | + current_value = self._data.get(key, None) |
| 144 | + handler = handler_override or definition["handler"] |
| 145 | + self._data[key] = handler(current_value, value) |
| 146 | + |
| 147 | + @property |
| 148 | + def data(self): |
| 149 | + """ |
| 150 | + All current data of the state. |
| 151 | + """ |
| 152 | + return self._data |
| 153 | + |
| 154 | + def has(self, key: str) -> bool: |
| 155 | + """ |
| 156 | + Check if a key exists in the state. |
| 157 | +
|
| 158 | + :param key: Key to check for existence |
| 159 | + :returns: True if key exists in state, False otherwise |
| 160 | + """ |
| 161 | + return key in self._data |
| 162 | + |
| 163 | + def to_dict(self): |
| 164 | + """ |
| 165 | + Convert the State object to a dictionary. |
| 166 | + """ |
| 167 | + serialized = {} |
| 168 | + serialized["schema"] = _schema_to_dict(self.schema) |
| 169 | + serialized["data"] = _serialize_value_with_schema(self._data) |
| 170 | + return serialized |
| 171 | + |
| 172 | + @classmethod |
| 173 | + def from_dict(cls, data: Dict[str, Any]): |
| 174 | + """ |
| 175 | + Convert a dictionary back to a State object. |
| 176 | + """ |
| 177 | + schema = _schema_from_dict(data.get("schema", {})) |
| 178 | + deserialized_data = _deserialize_value_with_schema(data.get("data", {})) |
| 179 | + return State(schema, deserialized_data) |
0 commit comments