Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Remove checkpoint_dir and reporter deprecation notices #42698

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def train_func(config):
# stdout messages and the results directory.
train_func.__name__ = trainer_cls.__name__

trainable_cls = wrap_function(train_func, warn=False)
trainable_cls = wrap_function(train_func)
has_base_dataset = bool(self.datasets)
if has_base_dataset:
from ray.data.context import DataContext
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool =
logger.debug("Detected class for trainable.")
elif isinstance(trainable, FunctionType) or isinstance(trainable, partial):
logger.debug("Detected function for trainable.")
trainable = wrap_function(trainable, warn=warn)
trainable = wrap_function(trainable)
elif callable(trainable):
logger.info("Detected unknown callable for trainable. Converting to class.")
trainable = wrap_function(trainable, warn=warn)
trainable = wrap_function(trainable)

if not issubclass(trainable, Trainable):
raise TypeError("Second argument must be convertable to Trainable", trainable)
Expand Down Expand Up @@ -246,7 +246,7 @@ def unregister(self, category, key):

def unregister_all(self, category: Optional[str] = None):
remaining = set()
for (cat, key) in self._registered:
for cat, key in self._registered:
if category and category == cat:
self.unregister(cat, key)
else:
Expand Down
77 changes: 3 additions & 74 deletions python/ray/tune/trainable/function_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
SHOULD_CHECKPOINT,
)
from ray.tune.trainable import Trainable
from ray.tune.utils import (
_detect_checkpoint_function,
_detect_config_single,
_detect_reporter,
)
from ray.tune.utils import _detect_config_single
from ray.util.annotations import DeveloperAPI


Expand All @@ -44,64 +40,6 @@
TEMP_MARKER = ".temp_marker"


# Error message raised (as DeprecationWarning) when a training function still
# declares the legacy `checkpoint_dir` parameter. Includes a before/after
# migration example pointing users at `ray.train.get_checkpoint()`.
# NOTE(review): the example code inside this string appears to have lost its
# indentation in this view — confirm against the original file before reuse.
_CHECKPOINT_DIR_ARG_DEPRECATION_MSG = """Accepting a `checkpoint_dir` argument in your training function is deprecated.
Please use `ray.train.get_checkpoint()` to access your checkpoint as a
`ray.train.Checkpoint` object instead. See below for an example:

Before
------

from ray import tune

def train_fn(config, checkpoint_dir=None):
if checkpoint_dir:
torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
...

tuner = tune.Tuner(train_fn)
tuner.fit()

After
-----

from ray import train, tune

def train_fn(config):
checkpoint: train.Checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
...

tuner = tune.Tuner(train_fn)
tuner.fit()""" # noqa: E501

# Error message raised (as DeprecationWarning) when a training function still
# declares the legacy `reporter` parameter. Includes a before/after migration
# example pointing users at `ray.train.report()`.
_REPORTER_ARG_DEPRECATION_MSG = """Accepting a `reporter` in your training function is deprecated.
Please use `ray.train.report()` to report results instead. See below for an example:

Before
------

from ray import tune

def train_fn(config, reporter):
reporter(metric=1)

tuner = tune.Tuner(train_fn)
tuner.fit()

After
-----

from ray import train, tune

def train_fn(config):
train.report({"metric": 1})

tuner = tune.Tuner(train_fn)
tuner.fit()""" # noqa: E501


@DeveloperAPI
class FunctionTrainable(Trainable):
"""Trainable that runs a user function reporting results.
Expand Down Expand Up @@ -271,29 +209,20 @@ def _report_thread_runner_error(self, block=False):

@DeveloperAPI
def wrap_function(
train_func: Callable[[Any], Any], warn: bool = True, name: Optional[str] = None
train_func: Callable[[Any], Any], name: Optional[str] = None
) -> Type["FunctionTrainable"]:
inherit_from = (FunctionTrainable,)

if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + inherit_from

func_args = inspect.getfullargspec(train_func).args
use_checkpoint = _detect_checkpoint_function(train_func)
use_config_single = _detect_config_single(train_func)
use_reporter = _detect_reporter(train_func)

if use_checkpoint:
raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

if use_reporter:
raise DeprecationWarning(_REPORTER_ARG_DEPRECATION_MSG)

if not use_config_single:
# use_reporter is hidden
raise ValueError(
"Unknown argument found in the Trainable function. "
"The function args must include a 'config' positional parameter."
"The function args must include a single 'config' positional parameter.\n"
"Found: {}".format(func_args)
)

Expand Down
14 changes: 0 additions & 14 deletions python/ray/tune/trainable/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from ray.air.config import ScalingConfig
from ray.tune.registry import _ParameterRegistry
from ray.tune.utils import _detect_checkpoint_function
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
Expand Down Expand Up @@ -124,12 +123,6 @@ def setup(self, config):
trainable_with_params = _Inner
else:
# Function trainable
if _detect_checkpoint_function(trainable, partial=True):
from ray.tune.trainable.function_trainable import (
_CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
)

raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

def inner(config):
fn_kwargs = {}
Expand Down Expand Up @@ -223,13 +216,6 @@ def train_fn(config):
if not inspect.isclass(trainable):
if isinstance(trainable, types.MethodType):
# Methods cannot set arbitrary attributes, so we have to wrap them
if _detect_checkpoint_function(trainable, partial=True):
from ray.tune.trainable.function_trainable import (
_CHECKPOINT_DIR_ARG_DEPRECATION_MSG,
)

raise DeprecationWarning(_CHECKPOINT_DIR_ARG_DEPRECATION_MSG)

def _trainable(config):
return trainable(config)

Expand Down
4 changes: 0 additions & 4 deletions python/ray/tune/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
validate_save_restore,
warn_if_slow,
diagnose_serialization,
_detect_checkpoint_function,
_detect_reporter,
_detect_config_single,
wait_for_gpu,
)
Expand All @@ -24,8 +22,6 @@
"validate_save_restore",
"warn_if_slow",
"diagnose_serialization",
"_detect_checkpoint_function",
"_detect_reporter",
"_detect_config_single",
"wait_for_gpu",
]
34 changes: 0 additions & 34 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,40 +584,6 @@ def validate_save_restore(
return True


def _detect_checkpoint_function(train_func, abort=False, partial=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
func_sig = inspect.signature(train_func)
validated = True
try:
# check if signature is func(config, checkpoint_dir=None)
if partial:
func_sig.bind_partial({}, checkpoint_dir="tmp/path")
else:
func_sig.bind({}, checkpoint_dir="tmp/path")
except Exception as e:
logger.debug(str(e))
validated = False
if abort and not validated:
func_args = inspect.getfullargspec(train_func).args
raise ValueError(
"Provided training function must have 1 `config` argument "
"`func(config)`. Got {}".format(func_args)
)
return validated


def _detect_reporter(func):
"""Use reporter if any arg has "reporter" and args = 2"""
func_sig = inspect.signature(func)
use_reporter = True
try:
func_sig.bind({}, reporter=None)
except Exception as e:
logger.debug(str(e))
use_reporter = False
return use_reporter


def _detect_config_single(func):
"""Check if func({}) works."""
func_sig = inspect.signature(func)
Expand Down