Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ class Client:
Clients do not work across forks since runtimes do not work across forks.
"""

@staticmethod
@classmethod
async def connect(
cls,
target_host: str,
*,
namespace: str = "default",
Expand All @@ -133,7 +134,7 @@ async def connect(
runtime: Optional[temporalio.runtime.Runtime] = None,
http_connect_proxy_config: Optional[HttpConnectProxyConfig] = None,
header_codec_behavior: HeaderCodecBehavior = HeaderCodecBehavior.NO_CODEC,
) -> Client:
) -> Self:
"""Connect to a Temporal server.

Args:
Expand Down Expand Up @@ -209,7 +210,7 @@ def make_lambda(plugin, next):

service_client = await next_function(connect_config)

return Client(
return cls(
service_client,
namespace=namespace,
data_converter=data_converter,
Expand Down
21 changes: 11 additions & 10 deletions temporalio/envconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from typing import Any, Dict, Literal, Mapping, Optional, Union, cast

from typing_extensions import TypeAlias, TypedDict
from typing_extensions import Self, TypeAlias, TypedDict

import temporalio.service
from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig
Expand Down Expand Up @@ -148,12 +148,12 @@ def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]:
client_private_key=_read_source(self.client_private_key),
)

@staticmethod
def from_dict(d: Optional[ClientConfigTLSDict]) -> Optional[ClientConfigTLS]:
@classmethod
def from_dict(cls, d: Optional[ClientConfigTLSDict]) -> Optional[Self]:
"""Create a ClientConfigTLS from a dictionary."""
if not d:
return None
return ClientConfigTLS(
return cls(
disabled=d.get("disabled"),
server_name=d.get("server_name"),
# Note: Bridge uses snake_case, but TOML uses kebab-case which is
Expand Down Expand Up @@ -202,10 +202,10 @@ class ClientConfigProfile:
grpc_meta: Mapping[str, str] = field(default_factory=dict)
"""gRPC metadata."""

@staticmethod
def from_dict(d: ClientConfigProfileDict) -> ClientConfigProfile:
@classmethod
def from_dict(cls, d: ClientConfigProfileDict) -> Self:
"""Create a ClientConfigProfile from a dictionary."""
return ClientConfigProfile(
return cls(
address=d.get("address"),
namespace=d.get("namespace"),
api_key=d.get("api_key"),
Expand Down Expand Up @@ -318,14 +318,15 @@ def to_dict(self) -> Mapping[str, ClientConfigProfileDict]:
"""Convert to a dictionary that can be used for TOML serialization."""
return {k: v.to_dict() for k, v in self.profiles.items()}

@staticmethod
@classmethod
def from_dict(
cls,
d: Mapping[str, Mapping[str, Any]],
) -> ClientConfig:
) -> Self:
"""Create a ClientConfig from a dictionary."""
# We must cast the inner dictionary because the source is often a plain
# Mapping[str, Any] from the bridge or other sources.
return ClientConfig(
return cls(
profiles={
k: ClientConfigProfile.from_dict(cast(ClientConfigProfileDict, v))
for k, v in d.items()
Expand Down
6 changes: 3 additions & 3 deletions temporalio/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class Runtime:
Runtimes do not work across forks.
"""

@staticmethod
def default() -> Runtime:
@classmethod
def default(cls) -> Runtime:
"""Get the default runtime, creating if not already created.

If the default runtime needs to be different, it should be done with
Expand All @@ -49,7 +49,7 @@ def default() -> Runtime:
"""
global _default_runtime
if not _default_runtime:
_default_runtime = Runtime(telemetry=TelemetryConfig())
_default_runtime = cls(telemetry=TelemetryConfig())
return _default_runtime

@staticmethod
Expand Down
18 changes: 10 additions & 8 deletions temporalio/testing/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import google.protobuf.empty_pb2
from typing_extensions import Self

import temporalio.api.testservice.v1
import temporalio.bridge.testing
Expand Down Expand Up @@ -54,8 +55,8 @@ class WorkflowEnvironment:
to have ``assert`` failures fail the workflow with the assertion error.
"""

@staticmethod
def from_client(client: temporalio.client.Client) -> WorkflowEnvironment:
@classmethod
def from_client(cls, client: temporalio.client.Client) -> Self:
"""Create a workflow environment from the given client.

:py:attr:`supports_time_skipping` will always return ``False`` for this
Expand All @@ -69,12 +70,11 @@ def from_client(client: temporalio.client.Client) -> WorkflowEnvironment:
The workflow environment that runs against the given client.
"""
# Add the assertion interceptor
return WorkflowEnvironment(
_client_with_interceptors(client, _AssertionErrorInterceptor())
)
return cls(_client_with_interceptors(client, _AssertionErrorInterceptor()))

@staticmethod
@classmethod
async def start_local(
cls,
*,
namespace: str = "default",
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
Expand Down Expand Up @@ -234,8 +234,9 @@ async def start_local(
)
raise

@staticmethod
@classmethod
async def start_time_skipping(
cls,
*,
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
interceptors: Sequence[temporalio.client.Interceptor] = [],
Expand Down Expand Up @@ -357,7 +358,8 @@ async def start_time_skipping(
def __init__(self, client: temporalio.client.Client) -> None:
"""Create a workflow environment from a client.

Most users would use a static method instead.
Most users would use a factory methods instead.

"""
self._client = client

Expand Down
11 changes: 7 additions & 4 deletions temporalio/worker/_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import timedelta
from typing import Any, Callable, Literal, Optional, Protocol, Union, runtime_checkable

from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias

import temporalio.bridge.worker
from temporalio.common import WorkerDeploymentVersion
Expand Down Expand Up @@ -310,8 +310,9 @@ def _to_bridge_slot_supplier(
class WorkerTuner(ABC):
"""WorkerTuners allow for the dynamic customization of some aspects of worker configuration"""

@staticmethod
@classmethod
def create_resource_based(
cls,
*,
target_memory_usage: float,
target_cpu_usage: float,
Expand Down Expand Up @@ -341,8 +342,9 @@ def create_resource_based(
nexus,
)

@staticmethod
@classmethod
def create_fixed(
cls,
*,
workflow_slots: Optional[int] = None,
activity_slots: Optional[int] = None,
Expand All @@ -362,8 +364,9 @@ def create_fixed(
FixedSizeSlotSupplier(nexus_slots if nexus_slots else 100),
)

@staticmethod
@classmethod
def create_composite(
cls,
*,
workflow_supplier: SlotSupplier,
activity_supplier: SlotSupplier,
Expand Down
14 changes: 8 additions & 6 deletions temporalio/worker/workflow_sandbox/_restrictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
cast,
)

from typing_extensions import Self

try:
import pydantic
import pydantic_core
Expand Down Expand Up @@ -182,8 +184,8 @@ class SandboxMatcher:
instances.
"""

@staticmethod
def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
@classmethod
def nested_child(cls, path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
"""Create a matcher where the given child is put at the given path.

Args:
Expand All @@ -195,12 +197,12 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
"""
ret = child
for key in reversed(path):
ret = SandboxMatcher(children={key: ret})
ret = cls(children={key: ret})
return ret

access: Set[str] = frozenset() # type: ignore
"""Immutable set of names to match access.

This is often only used for pass through checks and not member restrictions.
If this is used for member restrictions, even importing/accessing the value
will fail as opposed to :py:attr:`use` which is for when it is used.
Expand All @@ -210,7 +212,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:

use: Set[str] = frozenset() # type: ignore
"""Immutable set of names to match use.

This is best used for member restrictions on functions/classes because the
restriction will not apply to referencing/importing the item, just when it
is used.
Expand Down Expand Up @@ -246,7 +248,7 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:

exclude: Set[str] = frozenset() # type: ignore
"""Immutable set of names to exclude.

These override anything that may have been matched elsewhere.
"""

Expand Down
10 changes: 7 additions & 3 deletions tests/nexus/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import uuid
from collections.abc import Mapping
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Callable, Optional, Union

Expand Down Expand Up @@ -313,7 +313,9 @@ async def non_serializable_output(
class SuccessfulResponse:
status_code: int
body_json: Optional[Union[dict[str, Any], Callable[[dict[str, Any]], bool]]] = None
headers: Mapping[str, str] = SUCCESSFUL_RESPONSE_HEADERS
headers: Mapping[str, str] = field(
default_factory=lambda: SUCCESSFUL_RESPONSE_HEADERS
)


@dataclass
Expand All @@ -325,7 +327,9 @@ class UnsuccessfulResponse:
# Expected value of inverse of non_retryable attribute of exception.
retryable_exception: bool = True
body_json: Optional[Callable[[dict[str, Any]], bool]] = None
headers: Mapping[str, str] = UNSUCCESSFUL_RESPONSE_HEADERS
headers: Mapping[str, str] = field(
default_factory=lambda: UNSUCCESSFUL_RESPONSE_HEADERS
)


class _TestCase:
Expand Down