Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,6 @@ cython_debug/

# MacOS
.DS_Store

# Optuna database files
*.db
17 changes: 16 additions & 1 deletion plugboard/tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


try:
import optuna.storages
import ray.tune
import ray.tune.search
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -110,11 +111,25 @@ def _build_algorithm(
"mode": self._mode,
"metric": self._metric,
}

# Convert storage URI string to optuna storage object if needed
# TODO: Make this more general to support other algorithms, e.g. use a builder class
if "storage" in _algo_kwargs and isinstance(_algo_kwargs["storage"], str):
_algo_kwargs["storage"] = optuna.storages.RDBStorage(url=_algo_kwargs["storage"])
self._logger.info(
"Converted storage URI to Optuna RDBStorage object",
storage_uri=algorithm.storage,
)

algo_cls: _t.Optional[_t.Any] = locate(algorithm.type)
if not algo_cls or not issubclass(algo_cls, ray.tune.search.searcher.Searcher):
raise ValueError(f"Could not locate `Searcher` class {algorithm.type}")
self._logger.info(
"Using custom search algorithm", algorithm=algorithm.type, params=_algo_kwargs
"Using custom search algorithm",
algorithm=algorithm.type,
params={
k: v if k != "storage" else f"<{type(v).__name__}>" for k, v in _algo_kwargs.items()
},
)
return algo_cls(**_algo_kwargs)

Expand Down
46 changes: 45 additions & 1 deletion tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
"""Provides unit tests for the Tuner class."""

from tempfile import TemporaryDirectory
import typing as _t
from unittest.mock import MagicMock, patch

import msgspec
import pytest
import ray.tune

from plugboard.schemas import ConfigSpec, ObjectiveSpec
from plugboard.schemas.tune import CategoricalParameterSpec, FloatParameterSpec, IntParameterSpec
from plugboard.schemas.tune import (
CategoricalParameterSpec,
FloatParameterSpec,
IntParameterSpec,
OptunaSpec,
)
from plugboard.tune import Tuner
from tests.integration.test_process_with_components_run import A, B, C # noqa: F401

Expand All @@ -18,6 +26,13 @@ def config() -> dict:
return msgspec.yaml.decode(f.read())


@pytest.fixture
def temp_dir() -> _t.Iterator[str]:
"""Creates a temporary directory."""
with TemporaryDirectory() as tmpdir:
yield tmpdir


@patch("ray.tune.Tuner")
def test_tuner(mock_tuner_cls: MagicMock, config: dict) -> None:
"""Test the Tuner class."""
Expand Down Expand Up @@ -84,3 +99,32 @@ def test_tuner(mock_tuner_cls: MagicMock, config: dict) -> None:
assert tune_config.search_alg.searcher.__class__.__name__ == "OptunaSearch"
# Must call fit method on the Tuner object
mock_tuner.fit.assert_called_once()


def test_optuna_storage_uri_conversion(temp_dir: str) -> None:
"""Test that storage URI gets converted to Optuna storage object."""
# Create a tuner with minimal configuration
tuner = Tuner(
objective=ObjectiveSpec(
object_type="component", object_name="test", field_type="field", field_name="value"
),
parameters=[
FloatParameterSpec(
object_type="component",
object_name="test",
field_type="arg",
field_name="param",
lower=0.0,
upper=1.0,
)
],
num_samples=1,
mode="max",
algorithm=OptunaSpec(
type="ray.tune.search.optuna.OptunaSearch",
study_name="test-study",
storage=f"sqlite:///{temp_dir}/test_conversion.db",
),
)
algo = tuner._config.search_alg
assert isinstance(algo, ray.tune.search.optuna.OptunaSearch)
Loading