diff --git a/temporalio/client.py b/temporalio/client.py index 71f43d6d9..770a51392 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -112,8 +112,9 @@ class Client: Clients do not work across forks since runtimes do not work across forks. """ - @staticmethod + @classmethod async def connect( + cls, target_host: str, *, namespace: str = "default", @@ -133,7 +134,7 @@ async def connect( runtime: Optional[temporalio.runtime.Runtime] = None, http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None, header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC, - ) -> Client: + ) -> Self: """Connect to a Temporal server. Args: @@ -209,7 +210,7 @@ def make_lambda(plugin, next): service_client = await next_function(connect_config) - return Client( + return cls( service_client, namespace=namespace, data_converter=data_converter, diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index 0dc14be65..6f016dc75 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Dict, Literal, Mapping, Optional, Union, cast -from typing_extensions import TypeAlias, TypedDict +from typing_extensions import Self, TypeAlias, TypedDict import temporalio.service from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig @@ -148,12 +148,12 @@ def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: client_private_key=_read_source(self.client_private_key), ) - @staticmethod - def from_dict(d: Optional[ClientConfigTLSDict]) -> Optional[ClientConfigTLS]: + @classmethod + def from_dict(cls, d: Optional[ClientConfigTLSDict]) -> Optional[Self]: """Create a ClientConfigTLS from a dictionary.""" if not d: return None - return ClientConfigTLS( + return cls( disabled=d.get("disabled"), server_name=d.get("server_name"), # Note: Bridge uses snake_case, but TOML uses kebab-case which is @@ -202,10 +202,10 @@ class ClientConfigProfile: grpc_meta: Mapping[str, str] = field(default_factory=dict) """gRPC metadata.""" - @staticmethod - def from_dict(d: ClientConfigProfileDict) -> ClientConfigProfile: + @classmethod + def from_dict(cls, d: ClientConfigProfileDict) -> Self: """Create a ClientConfigProfile from a dictionary.""" - return ClientConfigProfile( + return cls( address=d.get("address"), namespace=d.get("namespace"), api_key=d.get("api_key"), @@ -318,14 +318,15 @@ def to_dict(self) -> Mapping[str, ClientConfigProfileDict]: """Convert to a dictionary that can be used for TOML serialization.""" return {k: v.to_dict() for k, v in self.profiles.items()} - @staticmethod + @classmethod def from_dict( + cls, d: Mapping[str, Mapping[str, Any]], - ) -> ClientConfig: + ) -> Self: """Create a ClientConfig from a dictionary.""" # We must cast the inner dictionary because the source is often a plain # Mapping[str, Any] from the bridge or other sources. - return ClientConfig( + return cls( profiles={ k: ClientConfigProfile.from_dict(cast(ClientConfigProfileDict, v)) for k, v in d.items() diff --git a/temporalio/runtime.py b/temporalio/runtime.py index 84b683941..64fa12192 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -37,8 +37,8 @@ class Runtime: Runtimes do not work across forks. """ - @staticmethod - def default() -> Runtime: + @classmethod + def default(cls) -> Runtime: """Get the default runtime, creating if not already created. If the default runtime needs to be different, it should be done with @@ -49,7 +49,7 @@ def default() -> Runtime: """ global _default_runtime if not _default_runtime: - _default_runtime = Runtime(telemetry=TelemetryConfig()) + _default_runtime = cls(telemetry=TelemetryConfig()) return _default_runtime @staticmethod diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 7c74f5cc6..8e054e56b 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -20,6 +20,7 @@ ) import google.protobuf.empty_pb2 +from typing_extensions import Self import temporalio.api.testservice.v1 import temporalio.bridge.testing @@ -54,8 +55,8 @@ class WorkflowEnvironment: to have ``assert`` failures fail the workflow with the assertion error. """ - @staticmethod - def from_client(client: temporalio.client.Client) -> WorkflowEnvironment: + @classmethod + def from_client(cls, client: temporalio.client.Client) -> Self: """Create a workflow environment from the given client. :py:attr:`supports_time_skipping` will always return ``False`` for this @@ -69,12 +70,11 @@ def from_client(client: temporalio.client.Client) -> WorkflowEnvironment: The workflow environment that runs against the given client. """ # Add the assertion interceptor - return WorkflowEnvironment( - _client_with_interceptors(client, _AssertionErrorInterceptor()) - ) + return cls(_client_with_interceptors(client, _AssertionErrorInterceptor())) - @staticmethod + @classmethod async def start_local( + cls, *, namespace: str = "default", data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, @@ -234,8 +234,9 @@ async def start_local( ) raise - @staticmethod + @classmethod async def start_time_skipping( + cls, *, data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default, interceptors: Sequence[temporalio.client.Interceptor] = [], @@ -357,7 +358,8 @@ async def start_time_skipping( def __init__(self, client: temporalio.client.Client) -> None: """Create a workflow environment from a client. - Most users would use a static method instead. + Most users would use a factory methods instead. + """ self._client = client diff --git a/temporalio/worker/_tuning.py b/temporalio/worker/_tuning.py index cfe95aac4..84abd87ac 100644 --- a/temporalio/worker/_tuning.py +++ b/temporalio/worker/_tuning.py @@ -7,7 +7,7 @@ from datetime import timedelta from typing import Any, Callable, Literal, Optional, Protocol, Union, runtime_checkable -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias import temporalio.bridge.worker from temporalio.common import WorkerDeploymentVersion @@ -310,8 +310,9 @@ def _to_bridge_slot_supplier( class WorkerTuner(ABC): """WorkerTuners allow for the dynamic customization of some aspects of worker configuration""" - @staticmethod + @classmethod def create_resource_based( + cls, *, target_memory_usage: float, target_cpu_usage: float, @@ -341,8 +342,9 @@ def create_resource_based( nexus, ) - @staticmethod + @classmethod def create_fixed( + cls, *, workflow_slots: Optional[int] = None, activity_slots: Optional[int] = None, @@ -362,8 +364,9 @@ def create_fixed( FixedSizeSlotSupplier(nexus_slots if nexus_slots else 100), ) - @staticmethod + @classmethod def create_composite( + cls, *, workflow_supplier: SlotSupplier, activity_supplier: SlotSupplier, diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index baad22fcb..69050c476 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -34,6 +34,8 @@ cast, ) +from typing_extensions import Self + try: import pydantic import pydantic_core @@ -182,8 +184,8 @@ class SandboxMatcher: instances. """ - @staticmethod - def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher: + @classmethod + def nested_child(cls, path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher: """Create a matcher where the given child is put at the given path. Args: @@ -195,12 +197,12 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher: """ ret = child for key in reversed(path): - ret = SandboxMatcher(children={key: ret}) + ret = cls(children={key: ret}) return ret access: Set[str] = frozenset() # type: ignore """Immutable set of names to match access. - + This is often only used for pass through checks and not member restrictions. If this is used for member restrictions, even importing/accessing the value will fail as opposed to :py:attr:`use` which is for when it is used. @@ -210,7 +212,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher: use: Set[str] = frozenset() # type: ignore """Immutable set of names to match use. - + This is best used for member restrictions on functions/classes because the restriction will not apply to referencing/importing the item, just when it is used. @@ -246,7 +248,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher: exclude: Set[str] = frozenset() # type: ignore """Immutable set of names to exclude. - + These override anything that may have been matched elsewhere. """ diff --git a/tests/nexus/test_handler.py b/tests/nexus/test_handler.py index c805a967c..1f3420da3 100644 --- a/tests/nexus/test_handler.py +++ b/tests/nexus/test_handler.py @@ -20,7 +20,7 @@ import uuid from collections.abc import Mapping from concurrent.futures.thread import ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from types import MappingProxyType from typing import Any, Callable, Optional, Union @@ -313,7 +313,9 @@ async def non_serializable_output( class SuccessfulResponse: status_code: int body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None - headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS + headers: Mapping[str, str] = field( + default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS + ) @dataclass @@ -325,7 +327,9 @@ class UnsuccessfulResponse: # Expected value of inverse of non_retryable attribute of exception. retryable_exception: bool = True body_json: Optional[Callable[[dict[str, Any]], bool]] = None - headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS + headers: Mapping[str, str] = field( + default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS + ) class _TestCase: