Skip to content

Commit

Permalink
Typed search attributes (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Nov 2, 2023
1 parent fd938c4 commit 7c0a464
Show file tree
Hide file tree
Showing 12 changed files with 1,417 additions and 261 deletions.
265 changes: 197 additions & 68 deletions temporalio/client.py

Large diffs are not rendered by default.

362 changes: 361 additions & 1 deletion temporalio/common.py
Expand Up @@ -4,29 +4,37 @@

import inspect
import types
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import IntEnum
from typing import (
Any,
Callable,
ClassVar,
Collection,
Generic,
Iterator,
List,
Mapping,
Optional,
Sequence,
Text,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
overload,
)

import google.protobuf.internal.containers
from typing_extensions import ClassVar, TypeAlias
from typing_extensions import ClassVar, NamedTuple, TypeAlias, get_origin

import temporalio.api.common.v1
import temporalio.api.enums.v1
import temporalio.types


@dataclass
Expand Down Expand Up @@ -176,6 +184,358 @@ def __setstate__(self, state: object) -> None:

SearchAttributes: TypeAlias = Mapping[str, SearchAttributeValues]

SearchAttributeValue: TypeAlias = Union[str, int, float, bool, datetime, Sequence[str]]

SearchAttributeValueType = TypeVar(
"SearchAttributeValueType", str, int, float, bool, datetime, Sequence[str]
)


class SearchAttributeIndexedValueType(IntEnum):
"""Server index type of a search attribute."""

TEXT = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_TEXT)
KEYWORD = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_KEYWORD)
INT = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_INT)
DOUBLE = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_DOUBLE)
BOOL = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_BOOL)
DATETIME = int(temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_DATETIME)
KEYWORD_LIST = int(
temporalio.api.enums.v1.IndexedValueType.INDEXED_VALUE_TYPE_KEYWORD_LIST
)


class SearchAttributeKey(ABC, Generic[SearchAttributeValueType]):
"""Typed search attribute key representation.
Use one of the ``for`` static methods here to create a key.
"""

@property
@abstractmethod
def name(self) -> str:
"""Get the name of the key."""
...

@property
@abstractmethod
def indexed_value_type(self) -> SearchAttributeIndexedValueType:
"""Get the server index typed of the key"""
...

@property
@abstractmethod
def value_type(self) -> Type[SearchAttributeValueType]:
"""Get the Python type of value for the key.
This may contain generics which cannot be used in ``isinstance``.
:py:attr:`origin_value_type` can be used instead.
"""
...

@property
def origin_value_type(self) -> Type:
"""Get the Python type of value for the key without generics."""
return get_origin(self.value_type) or self.value_type

@property
def _metadata_type(self) -> str:
index_type = self.indexed_value_type
if index_type == SearchAttributeIndexedValueType.TEXT:
return "Text"
elif index_type == SearchAttributeIndexedValueType.KEYWORD:
return "Keyword"
elif index_type == SearchAttributeIndexedValueType.INT:
return "Int"
elif index_type == SearchAttributeIndexedValueType.DOUBLE:
return "Double"
elif index_type == SearchAttributeIndexedValueType.BOOL:
return "Bool"
elif index_type == SearchAttributeIndexedValueType.DATETIME:
return "Datetime"
elif index_type == SearchAttributeIndexedValueType.KEYWORD_LIST:
return "KeywordList"
raise ValueError(f"Unrecognized type: {self}")

def value_set(
self, value: SearchAttributeValueType
) -> SearchAttributeUpdate[SearchAttributeValueType]:
"""Create a search attribute update to set the given value on this
key.
"""
return _SearchAttributeUpdate[SearchAttributeValueType](self, value)

def value_unset(self) -> SearchAttributeUpdate[SearchAttributeValueType]:
"""Create a search attribute update to unset the value on this key."""
return _SearchAttributeUpdate[SearchAttributeValueType](self, None)

@staticmethod
def for_text(name: str) -> SearchAttributeKey[str]:
"""Create a 'Text' search attribute type."""
return _SearchAttributeKey[str](name, SearchAttributeIndexedValueType.TEXT, str)

@staticmethod
def for_keyword(name: str) -> SearchAttributeKey[str]:
"""Create a 'Keyword' search attribute type."""
return _SearchAttributeKey[str](
name, SearchAttributeIndexedValueType.KEYWORD, str
)

@staticmethod
def for_int(name: str) -> SearchAttributeKey[int]:
"""Create an 'Int' search attribute type."""
return _SearchAttributeKey[int](name, SearchAttributeIndexedValueType.INT, int)

@staticmethod
def for_float(name: str) -> SearchAttributeKey[float]:
"""Create a 'Double' search attribute type."""
return _SearchAttributeKey[float](
name, SearchAttributeIndexedValueType.DOUBLE, float
)

@staticmethod
def for_bool(name: str) -> SearchAttributeKey[bool]:
"""Create a 'Bool' search attribute type."""
return _SearchAttributeKey[bool](
name, SearchAttributeIndexedValueType.BOOL, bool
)

@staticmethod
def for_datetime(name: str) -> SearchAttributeKey[datetime]:
"""Create a 'Datetime' search attribute type."""
return _SearchAttributeKey[datetime](
name, SearchAttributeIndexedValueType.DATETIME, datetime
)

@staticmethod
def for_keyword_list(name: str) -> SearchAttributeKey[Sequence[str]]:
"""Create a 'KeywordList' search attribute type."""
return _SearchAttributeKey[Sequence[str]](
name,
SearchAttributeIndexedValueType.KEYWORD_LIST,
# Generic types not supported yet like this: https://github.com/python/mypy/issues/4717
Sequence[str], # type: ignore
)

@staticmethod
def _from_metadata_type(
name: str, metadata_type: str
) -> Optional[SearchAttributeKey]:
if metadata_type == "Text":
return SearchAttributeKey.for_text(name)
elif metadata_type == "Keyword":
return SearchAttributeKey.for_keyword(name)
elif metadata_type == "Int":
return SearchAttributeKey.for_int(name)
elif metadata_type == "Double":
return SearchAttributeKey.for_float(name)
elif metadata_type == "Bool":
return SearchAttributeKey.for_bool(name)
elif metadata_type == "Datetime":
return SearchAttributeKey.for_datetime(name)
elif metadata_type == "KeywordList":
return SearchAttributeKey.for_keyword_list(name)
return None

@staticmethod
def _guess_from_untyped_values(
name: str, vals: SearchAttributeValues
) -> Optional[SearchAttributeKey]:
if not vals:
return None
elif len(vals) > 1:
if isinstance(vals[0], str):
return SearchAttributeKey.for_keyword_list(name)
elif isinstance(vals[0], str):
return SearchAttributeKey.for_keyword(name)
elif isinstance(vals[0], int):
return SearchAttributeKey.for_int(name)
elif isinstance(vals[0], float):
return SearchAttributeKey.for_float(name)
elif isinstance(vals[0], bool):
return SearchAttributeKey.for_bool(name)
elif isinstance(vals[0], datetime):
return SearchAttributeKey.for_datetime(name)
return None


@dataclass(frozen=True)
class _SearchAttributeKey(SearchAttributeKey[SearchAttributeValueType]):
_name: str
_indexed_value_type: SearchAttributeIndexedValueType
# No supported way in Python to derive this, so we're setting manually
_value_type: Type[SearchAttributeValueType]

@property
def name(self) -> str:
return self._name

@property
def indexed_value_type(self) -> SearchAttributeIndexedValueType:
return self._indexed_value_type

@property
def value_type(self) -> Type[SearchAttributeValueType]:
return self._value_type


class SearchAttributePair(NamedTuple, Generic[SearchAttributeValueType]):
"""A named tuple representing a key/value search attribute pair."""

key: SearchAttributeKey[SearchAttributeValueType]
value: SearchAttributeValueType


class SearchAttributeUpdate(ABC, Generic[SearchAttributeValueType]):
"""Representation of a search attribute update."""

@property
@abstractmethod
def key(self) -> SearchAttributeKey[SearchAttributeValueType]:
"""Key that is being set."""
...

@property
@abstractmethod
def value(self) -> Optional[SearchAttributeValueType]:
"""Value that is being set or ``None`` if being unset."""
...


@dataclass(frozen=True)
class _SearchAttributeUpdate(SearchAttributeUpdate[SearchAttributeValueType]):
_key: SearchAttributeKey[SearchAttributeValueType]
_value: Optional[SearchAttributeValueType]

@property
def key(self) -> SearchAttributeKey[SearchAttributeValueType]:
return self._key

@property
def value(self) -> Optional[SearchAttributeValueType]:
return self._value


@dataclass(frozen=True)
class TypedSearchAttributes(Collection[SearchAttributePair]):
"""Collection of typed search attributes.
This is represented as an immutable collection of
:py:class:`SearchAttributePair`. This can be created passing a sequence of
pairs to the constructor.
"""

search_attributes: Sequence[SearchAttributePair]
"""Underlying sequence of search attribute pairs. Do not mutate this, only
create new ``TypedSearchAttribute`` instances.
These are sorted by key name during construction. Duplicates cannot exist.
"""

empty: ClassVar[TypedSearchAttributes]
"""Class variable representing an empty set of attributes."""

def __post_init__(self):
"""Post-init initialization."""
# Sort
object.__setattr__(
self,
"search_attributes",
sorted(self.search_attributes, key=lambda pair: pair.key.name),
)
# Ensure no duplicates
for i, pair in enumerate(self.search_attributes):
if i > 0 and self.search_attributes[i - 1].key.name == pair.key.name:
raise ValueError(
f"Duplicate search attribute entries found for key {pair.key.name}"
)

def __len__(self) -> int:
"""Get the number of search attributes."""
return len(self.search_attributes)

def __getitem__(
self, key: SearchAttributeKey[SearchAttributeValueType]
) -> SearchAttributeValueType:
"""Get a single search attribute value by key or fail with
``KeyError``.
"""
ret = next((v for k, v in self if k == key), None)
if ret is None:
raise KeyError()
return ret

def __iter__(self) -> Iterator[SearchAttributePair]:
"""Get an iterator over search attribute key/value pairs."""
return iter(self.search_attributes)

def __contains__(self, key: object) -> bool:
"""Check whether this search attribute contains the given key.
This uses key equality so the key must be the same name and type.
"""
return any(v for k, v in self if k == key)

@overload
def get(
self, key: SearchAttributeKey[SearchAttributeValueType]
) -> Optional[SearchAttributeValueType]:
...

@overload
def get(
self,
key: SearchAttributeKey[SearchAttributeValueType],
default: temporalio.types.AnyType,
) -> Union[SearchAttributeValueType, temporalio.types.AnyType]:
...

def get(
self,
key: SearchAttributeKey[SearchAttributeValueType],
default: Optional[Any] = None,
) -> Any:
"""Get an attribute value for a key (or default). This is similar to
dict.get.
"""
try:
return self.__getitem__(key)
except KeyError:
return default

def updated(self, *search_attributes: SearchAttributePair) -> TypedSearchAttributes:
"""Copy this collection, replacing attributes with matching key names or
adding if key name not present.
"""
attrs = list(self.search_attributes)
# Go over each update, replacing matching keys by index or adding
for attr in search_attributes:
existing_index = next(
(i for i, attr in enumerate(attrs) if attr.key.name == attr.key.name),
None,
)
if existing_index is None:
attrs.append(attr)
else:
attrs[existing_index] = attr
return TypedSearchAttributes(attrs)


TypedSearchAttributes.empty = TypedSearchAttributes(search_attributes=[])


def _warn_on_deprecated_search_attributes(
attributes: Optional[Union[SearchAttributes, Any]],
stack_level: int = 2,
) -> None:
if attributes and isinstance(attributes, Mapping):
warnings.warn(
"Dictionary-based search attributes are deprecated",
DeprecationWarning,
stacklevel=1 + stack_level,
)


MetricAttributes: TypeAlias = Mapping[str, Union[str, int, float, bool]]


Expand Down

0 comments on commit 7c0a464

Please sign in to comment.