Skip to content

Commit

Permalink
[air] New train.Checkpoint API: Update `Train + tune tests and examples` (ray-project#38770)
Browse files Browse the repository at this point in the history

This PR converts the following tests to use the new feature flag (FF):

  • tune_torch_regression_example
  • transformers_example_cpu
  • test_tune

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Kai Fricke <krfricke@users.noreply.github.com>
Co-authored-by: Kai Fricke <krfricke@users.noreply.github.com>
Signed-off-by: Victor <vctr.y.m@example.com>
  • Loading branch information
2 people authored and Victor committed Oct 11, 2023
1 parent 02388e4 commit a02a45b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 44 deletions.
2 changes: 1 addition & 1 deletion python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ py_test(
name = "test_tune",
size = "large",
srcs = ["tests/test_tune.py"],
tags = ["team:ml", "exclusive", "tune", "no_new_storage"],
tags = ["team:ml", "exclusive", "tune", "new_storage"],
deps = [":train_lib", ":conftest"]
)

Expand Down
79 changes: 36 additions & 43 deletions python/ray/train/tests/test_tune.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import logging

import pytest

import ray
from ray import train, tune
from ray.air.constants import TRAINING_ITERATION
from ray.train import Checkpoint, FailureConfig, RunConfig, ScalingConfig
from ray.train import FailureConfig, RunConfig, ScalingConfig
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
Expand All @@ -20,7 +19,8 @@
from ray.train.torch.torch_trainer import TorchTrainer
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
from ray.tune.impl.tuner_internal import _TUNER_PKL

from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -133,9 +133,8 @@ def test_tune_checkpoint(ray_start_4_cpus):
def train_func():
for i in range(9):
train.report(dict(test=i))
train.report(
dict(test=i + 1), checkpoint=Checkpoint.from_dict(dict(hello="world"))
)
with create_dict_checkpoint(dict(hello="world")) as checkpoint:
train.report(dict(test=i + 1), checkpoint=checkpoint)

trainer = DataParallelTrainer(
train_func,
Expand All @@ -147,26 +146,24 @@ def train_func():
param_space={"train_loop_config": {"max_iter": 5}},
)

[trial] = tuner.fit()._experiment_analysis.trials
checkpoint_path = trial.checkpoint.dir_or_data
assert os.path.exists(checkpoint_path)
checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict()
assert checkpoint["hello"] == "world"
result_grid = tuner.fit()
assert len(result_grid) == 1
result = result_grid[0]
assert result.checkpoint
assert load_dict_checkpoint(result.checkpoint)["hello"] == "world"


def test_reuse_checkpoint(ray_start_4_cpus):
def train_func(config):
itr = 0
ckpt = train.get_checkpoint()
if ckpt is not None:
ckpt = ckpt.to_dict()
ckpt = load_dict_checkpoint(ckpt)
itr = ckpt["iter"] + 1

for i in range(itr, config["max_iter"]):
train.report(
dict(test=i, training_iteration=i),
checkpoint=Checkpoint.from_dict(dict(iter=i)),
)
with create_dict_checkpoint(dict(iter=i)) as checkpoint:
train.report(dict(test=i, training_iteration=i), checkpoint=checkpoint)

trainer = DataParallelTrainer(
train_func,
Expand All @@ -177,18 +174,16 @@ def train_func(config):
trainer,
param_space={"train_loop_config": {"max_iter": 5}},
)
[trial] = tuner.fit()._experiment_analysis.trials
checkpoint_path = trial.checkpoint.dir_or_data
checkpoint = Checkpoint.from_directory(checkpoint_path).to_dict()
assert checkpoint["iter"] == 4
result_grid = tuner.fit()
assert len(result_grid) == 1
result = result_grid[0]
assert result.checkpoint
assert load_dict_checkpoint(result.checkpoint)["iter"] == 4

tuner = Tuner(
trainer,
param_space={"train_loop_config": {"max_iter": 10}},
).restore(trial.local_dir, trainable=trainer)
analysis = tuner.fit()._experiment_analysis
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 5
tuner = Tuner.restore(result_grid.experiment_path, trainable=trainer)
result_grid = tuner.fit()
assert len(result_grid) == 1
assert len(result_grid[0].metrics_dataframe) == 5


def test_retry_with_max_failures(ray_start_4_cpus):
Expand All @@ -199,16 +194,14 @@ def train_func():
restored = bool(ckpt) # Does a previous checkpoint exist?
itr = 0
if ckpt:
ckpt = ckpt.to_dict()
ckpt = load_dict_checkpoint(ckpt)
itr = ckpt["iter"] + 1

for i in range(itr, 4):
if i == 2 and not restored:
raise Exception("try to fail me")
train.report(
dict(test=i, training_iteration=i),
checkpoint=Checkpoint.from_dict(dict(iter=i)),
)
with create_dict_checkpoint(dict(iter=i)) as checkpoint:
train.report(dict(test=i, training_iteration=i), checkpoint=checkpoint)

trainer = DataParallelTrainer(
train_func,
Expand All @@ -220,7 +213,7 @@ def train_func():
)

result_grid = tuner.fit()
checkpoint = result_grid[0].checkpoint.to_dict()
checkpoint = load_dict_checkpoint(result_grid[0].checkpoint)
assert checkpoint["iter"] == 3
df = result_grid[0].metrics_dataframe
assert len(df[TRAINING_ITERATION]) == 4
Expand Down Expand Up @@ -275,14 +268,11 @@ def train_func(config):
@pytest.mark.parametrize("in_trainer", [True, False])
@pytest.mark.parametrize("in_tuner", [True, False])
def test_run_config_in_trainer_and_tuner(
propagate_logs, tmp_path, caplog, in_trainer, in_tuner
propagate_logs, tmp_path, monkeypatch, caplog, in_trainer, in_tuner
):
trainer_run_config = (
RunConfig(name="trainer", local_dir=str(tmp_path)) if in_trainer else None
)
tuner_run_config = (
RunConfig(name="tuner", local_dir=str(tmp_path)) if in_tuner else None
)
monkeypatch.setenv("RAY_AIR_LOCAL_CACHE_DIR", str(tmp_path))
trainer_run_config = RunConfig(name="trainer") if in_trainer else None
tuner_run_config = RunConfig(name="tuner") if in_tuner else None
trainer = DataParallelTrainer(
lambda config: None,
backend_config=TestConfig(),
Expand All @@ -296,13 +286,16 @@ def test_run_config_in_trainer_and_tuner(
"`RunConfig` was passed to both the `Tuner` and the `DataParallelTrainer`"
)
if in_trainer and in_tuner:
assert list((tmp_path / "tuner").glob(_TUNER_PKL))
assert (tmp_path / "tuner").exists()
assert not (tmp_path / "trainer").exists()
assert both_msg in caplog.text
elif in_trainer and not in_tuner:
assert list((tmp_path / "trainer").glob(_TUNER_PKL))
assert not (tmp_path / "tuner").exists()
assert (tmp_path / "trainer").exists()
assert both_msg not in caplog.text
elif not in_trainer and in_tuner:
assert list((tmp_path / "tuner").glob(_TUNER_PKL))
assert (tmp_path / "tuner").exists()
assert not (tmp_path / "trainer").exists()
assert both_msg not in caplog.text
else:
assert tuner._local_tuner.get_run_config() == RunConfig()
Expand Down

0 comments on commit a02a45b

Please sign in to comment.