Skip to content

Commit

Permalink
[train] Fix local storage path for windows (ray-project#39951)
Browse files Browse the repository at this point in the history
This PR fixes setting `storage_path` on windows by using `Path(...).as_posix()` to join paths, so that it's always of the format `C:/a/b/c/exp_dir/trial_dir` rather than a combination of `/` and `\`.

This PR also adds a minimal Windows test to make sure it works and no longer raises this error. Ray Tune/Train support for Windows is not very comprehensive, since the Train/Tune test suites have not been ported to the Windows CI.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Victor <vctr.y.m@example.com>
  • Loading branch information
justinvyu authored and Victor committed Oct 11, 2023
1 parent a44da37 commit 1d9a5f9
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 6 deletions.
1 change: 1 addition & 0 deletions ci/ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ test_python() {
args+=(
python/ray/serve/...
python/ray/tests/...
python/ray/train:test_windows
-python/ray/serve:test_cross_language # Ray java not built on Windows yet.
-python/ray/serve:test_gcs_failure # Fork not supported in windows
-python/ray/serve:test_standalone_2 # Multinode not supported on Windows
Expand Down
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,14 @@ py_test(
deps = [":train_lib"]
)

# Minimal Windows smoke test for Train/Tune; also registered in the Windows
# CI shard (see ci/ci.sh), unlike the rest of the Train/Tune test suite.
py_test(
    name = "test_windows",
    size = "small",
    srcs = ["tests/test_windows.py"],
    tags = ["team:ml", "exclusive", "minimal"],
    deps = [":train_lib"]
)

py_test(
name = "test_xgboost_predictor",
size = "small",
Expand Down
11 changes: 6 additions & 5 deletions python/ray/train/_internal/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def __init__(
self.storage_filesystem, self.storage_fs_path = get_fs_and_path(
self.storage_path, storage_filesystem
)
self.storage_fs_path = Path(self.storage_fs_path).as_posix()

# Syncing is always needed if a custom `storage_filesystem` is provided.
# Otherwise, syncing is only needed if storage_local_path
Expand Down Expand Up @@ -608,7 +609,7 @@ def experiment_fs_path(self) -> str:
by pyarrow.fs.FileSystem.from_uri already. The URI scheme information is
kept in `storage_filesystem` instead.
"""
return os.path.join(self.storage_fs_path, self.experiment_dir_name)
return Path(self.storage_fs_path, self.experiment_dir_name).as_posix()

@property
def experiment_local_path(self) -> str:
Expand All @@ -617,7 +618,7 @@ def experiment_local_path(self) -> str:
This local "cache" path refers to location where files are dumped before
syncing them to the `storage_path` on the `storage_filesystem`.
"""
return os.path.join(self.storage_local_path, self.experiment_dir_name)
return Path(self.storage_local_path, self.experiment_dir_name).as_posix()

@property
def trial_local_path(self) -> str:
Expand All @@ -629,7 +630,7 @@ def trial_local_path(self) -> str:
raise RuntimeError(
"Should not access `trial_local_path` without setting `trial_dir_name`"
)
return os.path.join(self.experiment_local_path, self.trial_dir_name)
return Path(self.experiment_local_path, self.trial_dir_name).as_posix()

@property
def trial_fs_path(self) -> str:
Expand All @@ -641,7 +642,7 @@ def trial_fs_path(self) -> str:
raise RuntimeError(
"Should not access `trial_fs_path` without setting `trial_dir_name`"
)
return os.path.join(self.experiment_fs_path, self.trial_dir_name)
return Path(self.experiment_fs_path, self.trial_dir_name).as_posix()

@property
def checkpoint_fs_path(self) -> str:
Expand All @@ -651,7 +652,7 @@ def checkpoint_fs_path(self) -> str:
The user of this class is responsible for setting the `current_checkpoint_index`
(e.g., incrementing when needed).
"""
return os.path.join(self.trial_fs_path, self.checkpoint_dir_name)
return Path(self.trial_fs_path, self.checkpoint_dir_name).as_posix()

@property
def checkpoint_dir_name(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion python/ray/train/tests/test_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def test_run(ray_start_4_cpus):
def train_func():
    # Load the checkpoint made available to this worker and echo its
    # contents back as the reported metrics.
    checkpoint = train.get_checkpoint()
    checkpoint_dict = load_dict_checkpoint(checkpoint)
    # Only the rank-0 worker re-reports the checkpoint; all other ranks
    # report metrics only.
    if train.get_context().get_world_rank() == 0:
        train.report(metrics=checkpoint_dict, checkpoint=checkpoint)
    else:
        train.report(metrics=checkpoint_dict)
    # NOTE(review): `key` is captured from the enclosing test's scope.
    return checkpoint_dict[key]

with create_dict_checkpoint({key: value}) as checkpoint:
Expand Down
56 changes: 56 additions & 0 deletions python/ray/train/tests/test_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""This is a very minimal set of windows tests for Train/Tune."""

import os

import pytest

import ray
from ray import train, tune
from ray.train.data_parallel_trainer import DataParallelTrainer

from ray.train.tests.util import create_dict_checkpoint


@pytest.fixture
def ray_start_4_cpus():
    """Start a local 4-CPU Ray cluster for a test and shut it down afterwards."""
    info = ray.init(num_cpus=4)
    yield info
    # Everything after the yield is fixture teardown.
    ray.shutdown()


@pytest.fixture
def chdir_tmpdir(tmp_path):
    """Run the wrapped test with the cwd switched to a fresh temp directory."""
    saved_cwd = os.getcwd()
    os.chdir(tmp_path)
    yield
    # Restore the original working directory once the test is done.
    os.chdir(saved_cwd)


def test_storage_path(ray_start_4_cpus, chdir_tmpdir):
    """Tests that Train/Tune with a local storage path works on Windows."""

    def train_fn(config):
        # Report 5 iterations of dummy metrics. Only rank 0 attaches a
        # checkpoint — presumably to avoid other ranks writing duplicate
        # checkpoint data (TODO confirm against Train's reporting contract).
        for i in range(5):
            if train.get_context().get_world_rank() == 0:
                with create_dict_checkpoint({"dummy": "data"}) as checkpoint:
                    train.report({"loss": i}, checkpoint=checkpoint)
            else:
                train.report({"loss": i})

    # Tune path: os.getcwd() is the temp dir from the `chdir_tmpdir` fixture,
    # so storage_path is a local (on Windows, drive-letter) path.
    tuner = tune.Tuner(train_fn, run_config=train.RunConfig(storage_path=os.getcwd()))
    results = tuner.fit()
    assert not results.errors

    # Train path: same local storage path, exercised with 2 workers.
    trainer = DataParallelTrainer(
        train_fn,
        scaling_config=train.ScalingConfig(num_workers=2),
        run_config=train.RunConfig(storage_path=os.getcwd()),
    )
    trainer.fit()


if __name__ == "__main__":
    import sys

    # Allow running this test file directly: verbose output, stop at the
    # first failure, and propagate pytest's exit code.
    sys.exit(pytest.main(["-v", "-x", __file__]))

0 comments on commit 1d9a5f9

Please sign in to comment.