Skip to content

Commit

Permalink
Merge pull request #82 from xadrianzetx/follow-pep-585
Browse files Browse the repository at this point in the history
Follow PEP 585 – Type Hinting Generics In Standard Collections
  • Loading branch information
xadrianzetx committed Jul 26, 2023
2 parents f1cce60 + bc9ac14 commit 8a6ea3e
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 80 deletions.
5 changes: 3 additions & 2 deletions optuna_distributed/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import logging
from typing import Optional

from rich.logging import RichHandler


_default_handler: Optional[logging.Handler] = None
_default_handler: logging.Handler | None = None


def _get_library_logger() -> logging.Logger:
Expand Down
9 changes: 4 additions & 5 deletions optuna_distributed/eventloop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import Optional
from typing import Tuple
from typing import Type

from optuna.study import Study
from optuna.trial import TrialState
Expand Down Expand Up @@ -45,8 +44,8 @@ def __init__(
def run(
self,
terminal: Terminal,
timeout: Optional[float] = None,
catch: Tuple[Type[Exception], ...] = (),
timeout: float | None = None,
catch: tuple[type[Exception], ...] = (),
) -> None:
"""Starts the event loop.
Expand Down
13 changes: 7 additions & 6 deletions optuna_distributed/ipc/queue.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import pickle
from typing import Optional

from dask.distributed import Queue as DaskQueue

Expand Down Expand Up @@ -31,9 +32,9 @@ class Queue(IPCPrimitive):
def __init__(
self,
publishing: str,
recieving: Optional[str] = None,
timeout: Optional[int] = None,
max_retries: Optional[int] = None,
recieving: str | None = None,
timeout: int | None = None,
max_retries: int | None = None,
) -> None:
self._publishing = publishing
self._recieving = recieving
Expand All @@ -43,8 +44,8 @@ def __init__(

self._timeout = timeout
self._max_retries = max_retries
self._publisher: Optional[DaskQueue] = None
self._subscriber: Optional[DaskQueue] = None
self._publisher: DaskQueue | None = None
self._subscriber: DaskQueue | None = None
self._initialized = False

def _initialize(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion optuna_distributed/managers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import abc
from abc import ABC
from collections.abc import Generator
from typing import Callable
from typing import Generator
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union
Expand Down
18 changes: 9 additions & 9 deletions optuna_distributed/managers/distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Generator
import ctypes
from dataclasses import dataclass
from enum import IntEnum
Expand All @@ -7,9 +10,6 @@
from threading import Thread
import time
from typing import Callable
from typing import Dict
from typing import Generator
from typing import List
from typing import TYPE_CHECKING
import uuid

Expand Down Expand Up @@ -59,7 +59,7 @@ class _StateSynchronizer:
def __init__(self) -> None:
self._optimization_enabled = Variable()
self._optimization_enabled.set(True)
self._task_states: List[Variable] = []
self._task_states: list[Variable] = []

@property
def stop_flag(self) -> str:
Expand Down Expand Up @@ -111,8 +111,8 @@ def __init__(self, client: Client, n_trials: int, heartbeat_interval: int = 60)
recieving=self._public_channel,
timeout=heartbeat_interval,
)
self._private_channels: Dict[int, str] = {}
self._futures: List[Future] = []
self._private_channels: dict[int, str] = {}
self._futures: list[Future] = []

def _ensure_safe_exit(self, future: Future) -> None:
if future.status in ["error", "cancelled"]:
Expand All @@ -126,14 +126,14 @@ def _assign_private_channel(self, trial_id: int) -> "Queue":
self._private_channels[trial_id] = private_channel
return Queue(self._public_channel, private_channel, max_retries=5)

def _create_trials(self, study: Study) -> List[DistributedTrial]:
def _create_trials(self, study: Study) -> list[DistributedTrial]:
# HACK: It's kinda naughty to access _trial_id, but this is gonna make
# our lifes much easier in messaging system.
trial_ids = [study.ask()._trial_id for _ in range(self._n_trials)]
return [DistributedTrial(tid, self._assign_private_channel(tid)) for tid in trial_ids]

def _add_task_context(self, trials: List[DistributedTrial]) -> List[_TaskContext]:
trials_with_context: List[_TaskContext] = []
def _add_task_context(self, trials: list[DistributedTrial]) -> list[_TaskContext]:
trials_with_context: list[_TaskContext] = []
for trial in trials:
trials_with_context.append(
_TaskContext(
Expand Down
12 changes: 6 additions & 6 deletions optuna_distributed/managers/local.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from collections.abc import Generator
import multiprocessing
from multiprocessing import Pipe as MultiprocessingPipe
from multiprocessing import Process
from multiprocessing.connection import Connection
from multiprocessing.connection import wait
import sys
from typing import Dict
from typing import Generator
from typing import List
from typing import TYPE_CHECKING

from optuna import Study
Expand Down Expand Up @@ -49,8 +49,8 @@ def __init__(self, n_trials: int, n_jobs: int) -> None:

self._workers_to_spawn = min(self._n_jobs, n_trials)
self._trials_remaining = n_trials - self._workers_to_spawn
self._pool: Dict[int, Connection] = {}
self._processes: List[Process] = []
self._pool: dict[int, Connection] = {}
self._processes: list[Process] = []

def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None:
trial_ids = [study.ask()._trial_id for _ in range(self._workers_to_spawn)]
Expand All @@ -65,7 +65,7 @@ def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None:

def get_message(self) -> Generator[Message, None, None]:
while True:
messages: List[Message] = []
messages: list[Message] = []
for incoming in wait(self._pool.values(), timeout=10):
# FIXME: This assertion is true only for Unix systems.
# Some refactoring is needed to support Windows as well.
Expand Down
7 changes: 4 additions & 3 deletions optuna_distributed/messages/completed.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from collections.abc import Sequence
import io
import logging
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union

from optuna.study import Study
from optuna.trial import FrozenTrial
Expand Down Expand Up @@ -35,7 +36,7 @@ class CompletedMessage(Message):

closing = True

def __init__(self, trial_id: int, value_or_values: Union[Sequence[float], float]) -> None:
def __init__(self, trial_id: int, value_or_values: Sequence[float] | float) -> None:
self._trial_id = trial_id
self._value_or_values = value_or_values

Expand Down
5 changes: 3 additions & 2 deletions optuna_distributed/messages/suggest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Union

from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self, trial_id: int, name: str, distribution: BaseDistribution) ->

def process(self, study: Study, manager: "OptimizationManager") -> None:
trial = Trial(study, self._trial_id)
value: Union[float, int, CategoricalChoiceType]
value: float | int | CategoricalChoiceType
if isinstance(self._distribution, FloatDistribution):
value = trial.suggest_float(
name=self._name,
Expand Down
58 changes: 27 additions & 31 deletions optuna_distributed/study.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from __future__ import annotations

from collections.abc import Callable
from collections.abc import Container
from collections.abc import Iterable
from collections.abc import Sequence
import sys
from typing import Any
from typing import Callable
from typing import Container
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Tuple
from typing import Type
from typing import Union

from dask.distributed import Client
from dask.distributed import LocalCluster
Expand Down Expand Up @@ -64,12 +60,12 @@ class DistributedStudy:
process based parallelism.
"""

def __init__(self, study: Study, client: Optional[Client] = None) -> None:
def __init__(self, study: Study, client: Client | None = None) -> None:
self._study = study
self._client = client

@property
def best_params(self) -> Dict[str, Any]:
def best_params(self) -> dict[str, Any]:
"""Return parameters of the best trial in the study."""
return self._study.best_params

Expand All @@ -84,7 +80,7 @@ def best_trial(self) -> FrozenTrial:
return self._study.best_trial

@property
def best_trials(self) -> List[FrozenTrial]:
def best_trials(self) -> list[FrozenTrial]:
"""Return trials located at the Pareto front in the study."""
return self._study.best_trials

Expand All @@ -94,22 +90,22 @@ def direction(self) -> StudyDirection:
return self._study.direction

@property
def directions(self) -> List[StudyDirection]:
def directions(self) -> list[StudyDirection]:
"""Return the directions of the study."""
return self._study.directions

@property
def trials(self) -> List[FrozenTrial]:
def trials(self) -> list[FrozenTrial]:
"""Return all trials in the study."""
return self._study.trials

@property
def user_attrs(self) -> Dict[str, Any]:
def user_attrs(self) -> dict[str, Any]:
"""Return user attributes."""
return self._study.user_attrs

@property
def system_attrs(self) -> Dict[str, Any]:
def system_attrs(self) -> dict[str, Any]:
"""Return system attributes."""
return self._study.system_attrs

Expand All @@ -118,8 +114,8 @@ def into_study(self) -> Study:
return self._study

def get_trials(
self, deepcopy: bool = True, states: Optional[Container[TrialState]] = None
) -> List[FrozenTrial]:
self, deepcopy: bool = True, states: Container[TrialState] | None = None
) -> list[FrozenTrial]:
"""Return all trials in the study.
For complete documentation, please refer to:
Expand All @@ -136,11 +132,11 @@ def get_trials(
def optimize(
self,
func: ObjectiveFuncType,
n_trials: Optional[int] = None,
timeout: Optional[float] = None,
n_trials: int | None = None,
timeout: float | None = None,
n_jobs: int = -1,
catch: Union[Iterable[Type[Exception]], Type[Exception]] = (),
callbacks: Optional[List[Callable[["Study", FrozenTrial], None]]] = None,
catch: Iterable[type[Exception]] | type[Exception] = (),
callbacks: list[Callable[["Study", FrozenTrial], None]] | None = None,
show_progress_bar: bool = False,
*args: Any,
**kwargs: Any,
Expand Down Expand Up @@ -208,7 +204,7 @@ def optimize(
finally:
self._study._storage.remove_session()

def ask(self, fixed_distributions: Optional[Dict[str, BaseDistribution]] = None) -> Trial:
def ask(self, fixed_distributions: dict[str, BaseDistribution] | None = None) -> Trial:
"""Create a new trial from which hyperparameters can be suggested.
For complete documentation, please refer to:
Expand All @@ -222,9 +218,9 @@ def ask(self, fixed_distributions: Optional[Dict[str, BaseDistribution]] = None)

def tell(
self,
trial: Union[Trial, int],
values: Optional[Union[float, Sequence[float]]] = None,
state: Optional[TrialState] = None,
trial: Trial | int,
values: float | Sequence[float] | None = None,
state: TrialState | None = None,
skip_if_finished: bool = False,
) -> FrozenTrial:
"""Finish a trial created with :func:`~optuna_distributed.study.DistributedStudy.ask`.
Expand Down Expand Up @@ -272,7 +268,7 @@ def set_system_attr(self, key: str, value: Any) -> None:

def trials_dataframe(
self,
attrs: Tuple[str, ...] = (
attrs: tuple[str, ...] = (
"number",
"value",
"datetime_start",
Expand Down Expand Up @@ -309,8 +305,8 @@ def stop(self) -> None:

def enqueue_trial(
self,
params: Dict[str, Any],
user_attrs: Optional[Dict[str, Any]] = None,
params: dict[str, Any],
user_attrs: dict[str, Any] | None = None,
skip_if_exists: bool = False,
) -> None:
"""Enqueue a trial with given parameter values.
Expand Down Expand Up @@ -351,7 +347,7 @@ def add_trials(self, trials: Iterable[FrozenTrial]) -> None:
self._study.add_trials(trials)


def from_study(study: Study, client: Optional[Client] = None) -> DistributedStudy:
def from_study(study: Study, client: Client | None = None) -> DistributedStudy:
"""Takes regular Optuna study and extends it to :class:`~optuna_distributed.DistributedStudy`.
This creates an object which behaves like regular Optuna study, except trials
Expand Down
4 changes: 2 additions & 2 deletions optuna_distributed/terminal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from __future__ import annotations

from rich.progress import BarColumn
from rich.progress import Progress
Expand All @@ -22,7 +22,7 @@ class Terminal:
"""

def __init__(
self, show_progress_bar: bool, n_trials: int, timeout: Optional[float] = None
self, show_progress_bar: bool, n_trials: int, timeout: float | None = None
) -> None:
self._timeout = timeout
self._progbar = Progress(
Expand Down

0 comments on commit 8a6ea3e

Please sign in to comment.