# tune_controller.py
import copy
import json
import time
import traceback
import uuid
import warnings
from collections import defaultdict, deque
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, Set
import logging
import os
import ray
from ray.air import Checkpoint, ResourceRequest
from ray.air._internal.uri_utils import URI
from ray.air.config import CheckpointConfig
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.air.constants import TIME_THIS_ITER_S
from ray.air.execution import ResourceManager, PlacementGroupResourceManager
from ray.air.execution._internal import RayActorManager, TrackedActor
from ray.train._internal.session import _FutureTrainingResult
from ray.train._internal.storage import StorageContext, _use_storage_context
from ray.train.constants import CHECKPOINT_DIR_NAME
from ray.exceptions import RayActorError, RayTaskError
from ray.tune.error import _AbortTrialExecution, _TuneStopTrialError, _TuneRestoreError
from ray.tune.execution.class_cache import _ActorClassCache
from ray.tune.execution.experiment_state import (
_ExperimentCheckpointManager,
_experiment_checkpoint_exists,
_find_newest_experiment_checkpoint,
)
from ray.tune.experiment.trial import (
_change_working_directory,
_noop_logger_creator,
_TrialInfo,
_Location,
_get_trainable_kwargs,
)
from ray.tune.experiment import Experiment
from ray.tune.execution.insufficient_resources_manager import (
_InsufficientResourcesManager,
)
from ray.tune.result import (
DEBUG_METRICS,
DEFAULT_METRIC,
DONE,
RESULT_DUPLICATE,
SHOULD_CHECKPOINT,
_get_defaults_results_dir,
DEFAULT_EXPERIMENT_NAME,
)
from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE
from ray.tune.trainable import TrainableUtil
from ray.tune import TuneError
from ray.tune.callback import Callback, CallbackList
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.stopper import NoopStopper, Stopper
from ray.tune.search import BasicVariantGenerator, SearchAlgorithm
from ray.train._internal.syncer import SyncConfig
from ray.tune.experiment import Trial
from ray.tune.utils.log import _dedup_logs
from ray.tune.utils.object_cache import _ObjectCache
from ray.tune.utils.resource_updater import _ResourceUpdater
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
from ray.tune.utils.util import _split_remote_local_path
from ray.tune.web_server import TuneServer
from ray.util.annotations import DeveloperAPI, Deprecated
from ray.util.debug import log_once
logger = logging.getLogger(__name__)


@DeveloperAPI
class TuneController:
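    """Controls the execution of a Tune experiment.

    The controller asks the search algorithm for new trials, schedules
    Ray actors for pending trials, processes actor events and trial
    results, applies scheduler decisions, and periodically checkpoints
    the experiment state so that interrupted runs can be resumed.
    """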
CKPT_FILE_TMPL = "experiment_state-{}.json"
RAISE = "RAISE"
def __init__(
self,
*,
search_alg: Optional[SearchAlgorithm] = None,
placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
scheduler: Optional[TrialScheduler] = None,
experiment_path: Optional[str] = None,
experiment_dir_name: Optional[str] = None,
sync_config: Optional[SyncConfig] = None,
stopper: Optional[Stopper] = None,
resume: Union[str, bool] = False,
server_port: Optional[int] = None,
fail_fast: bool = False,
        checkpoint_period: Optional[Union[str, int]] = None,
callbacks: Optional[List[Callback]] = None,
metric: Optional[str] = None,
trial_checkpoint_config: Optional[CheckpointConfig] = None,
storage: Optional[StorageContext] = None,
reuse_actors: bool = False,
resource_manager_factory: Optional[Callable[[], ResourceManager]] = None,
_trainer_api: bool = False,
):
if resource_manager_factory:
resource_manager = resource_manager_factory()
else:
resource_manager = PlacementGroupResourceManager()
self._actor_manager = RayActorManager(resource_manager=resource_manager)
self._class_cache = _ActorClassCache()
# Resource status
self._resource_updater = _ResourceUpdater(None)
# Actor <-> Trial mappings
self._actor_to_trial: Dict[TrackedActor, Trial] = {}
self._trial_to_actor: Dict[Trial, TrackedActor] = {}
# Resources <-> Trial
self._resources_to_pending_trials: Dict[
ResourceRequest, Set[Trial]
] = defaultdict(set)
# Keep track of actor states
self._pending_trials: Set[Trial] = set()
self._pending_trials_list: List[Trial] = []
self._running_trials: Set[Trial] = set()
self._paused_trials: Set[Trial] = set()
self._stopped_trials: Set[Trial] = set()
self._failed_trials: Set[Trial] = set()
self._resetting_trials: Set[Trial] = set()
self._staged_trials: Set[Trial] = set()
        # Keep track of actors that have actually been started
self._started_actors: Set[TrackedActor] = set()
# Map of tracked actors -> timestamp
# The timestamp is when we requested the stop.
# We track these actors here to force a
# cleanup after some time (as they might be hanging).
# Todo: This timeout logic should be moved into the actor manager.
# This map is populated whenever we request an actor stop:
# - Regular STOP decision
# - Removing an actor because its trial REUSEs a different trial's actor
# - Removing a cached actor because it's not needed anymore
# Actors are only tracked in this map if they actually started (not if they
# were only requested but never started).
# Actors are removed from this map:
# - When the STOP resolved and the actor actually stopped
# - When they are forcefully cleaned up after the timeout.
self._stopping_actors: Dict[TrackedActor, float] = {}
self._earliest_stopping_actor: float = float("inf")
self._actor_cleanup_timeout: int = int(
os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "600")
)
self._actor_force_cleanup_timeout: int = 10
# Reuse actors
self._reuse_actors = reuse_actors
self._actor_cache = _ObjectCache(may_keep_one=True)
# Trial metadata for experiment checkpoints
self._trials_to_cache: Set[Trial] = set()
self._trial_metadata: Dict[str, str] = {}
# TRAINING
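        # Results reported by trainables may be buffered and returned in
        # batches: `TUNE_RESULT_BUFFER_LENGTH` caps how many results are
        # buffered per training call, bounded by the min/max buffer times
        # configured below.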
self._buffer_length = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1))
self._buffer_min_time_s = float(os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.0))
self._buffer_max_time_s = float(
os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
)
# Legacy TrialRunner init
self._search_alg = search_alg or BasicVariantGenerator()
self._placeholder_resolvers = placeholder_resolvers
self._scheduler_alg = scheduler or FIFOScheduler()
self._callbacks = CallbackList(callbacks or [])
self._insufficient_resources_manager = _InsufficientResourcesManager(
for_train=_trainer_api
)
self._pending_trial_queue_times = {}
self._max_pending_trials = _get_max_pending_trials(self._search_alg)
self._storage = storage
self._legacy_sync_config = sync_config or SyncConfig()
if _use_storage_context():
assert storage
self._legacy_experiment_dir_name = None
self._legacy_local_experiment_path = None
self._legacy_remote_experiment_path = None
self._legacy_sync_config = None
else:
# Rename for better code readability
local_experiment_path, remote_experiment_path = _split_remote_local_path(
experiment_path, None
)
            # Derive the experiment dir name from the local path if needed
            if not experiment_dir_name and local_experiment_path:
                experiment_dir_name = Path(local_experiment_path).name
elif not experiment_dir_name:
experiment_dir_name = DEFAULT_EXPERIMENT_NAME
            # Set default local experiment path
if not local_experiment_path:
local_experiment_path = str(
Path(_get_defaults_results_dir()) / experiment_dir_name
)
os.makedirs(local_experiment_path, exist_ok=True)
self._legacy_experiment_dir_name = experiment_dir_name
if self._legacy_sync_config.upload_dir and self._legacy_experiment_dir_name:
if remote_experiment_path:
                    if not remote_experiment_path.startswith(
                        self._legacy_sync_config.upload_dir
                    ):
                        raise ValueError(
                            f"Both a `SyncConfig.upload_dir` and an `experiment_path` "
                            f"pointing to remote storage were passed, but they do not "
                            f"point to the same location. Got: "
                            f"`experiment_path={experiment_path}` and "
                            f"`SyncConfig.upload_dir="
                            f"{self._legacy_sync_config.upload_dir}`. "
                        )
warnings.warn(
"If `experiment_path` points to a remote storage location, "
"do not set `SyncConfig.upload_dir`. ",
DeprecationWarning,
)
else:
remote_experiment_path = str(
URI(self._legacy_sync_config.upload_dir)
/ self._legacy_experiment_dir_name
)
self._legacy_local_experiment_path = local_experiment_path
if self._legacy_local_experiment_path:
os.makedirs(self._legacy_local_experiment_path, exist_ok=True)
self._legacy_remote_experiment_path = remote_experiment_path
if (
self._legacy_local_experiment_path
and self._legacy_remote_experiment_path
and Path(self._legacy_local_experiment_path)
== Path(self._legacy_remote_experiment_path)
):
            warnings.warn(
                "The local experiment path is the same as the remote "
                "experiment path. Set a different `storage_path` or raise an "
                "issue on GitHub if this issue persists. Deactivating the "
                "remote experiment path."
            )
self._legacy_remote_experiment_path = None
self._metric = metric
self._total_time = 0
self._iteration = 0
self._has_errored = False
self._fail_fast = fail_fast
if isinstance(self._fail_fast, str):
self._fail_fast = self._fail_fast.upper()
if self._fail_fast == self.RAISE:
warnings.warn(
"fail_fast='raise' detected. Be careful when using this "
"mode as resources (such as Ray processes, "
"file descriptors, and temporary files) may not be "
"cleaned up properly. To use "
"a safer mode, use fail_fast=True."
)
else:
raise ValueError(
"fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}."
)
self._print_trial_errors = bool(
int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1"))
)
self._server = None
self._server_port = server_port
if server_port is not None:
self._server = TuneServer(self, self._server_port)
self._trials: List[Trial] = []
self._live_trials: Set[Trial] = set() # Set of non-terminated trials
self._cached_trial_decisions = {}
self._queued_trial_decisions = {}
self._stop_queue = []
self._should_stop_experiment = False # used by TuneServer
self._stopper = stopper or NoopStopper()
self._start_time = time.time()
self._last_checkpoint_time = -float("inf")
self._session_str = datetime.fromtimestamp(self._start_time).strftime(
"%Y-%m-%d_%H-%M-%S"
)
if checkpoint_period is None:
checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")
self._checkpoint_period = checkpoint_period
self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig()
self._checkpoint_manager = self._create_checkpoint_manager()
self._resumed = False
resume_config = self._checkpoint_manager.resume(resume_type=resume)
if resume_config:
try:
self.resume(
resume_unfinished=resume_config.resume_unfinished,
resume_errored=resume_config.resume_errored,
restart_errored=resume_config.restart_errored,
)
self._resumed = True
except Exception as e:
if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
logger.error(str(e))
logger.exception("Runner restore failed.")
if self._fail_fast:
raise
logger.info("Restarting experiment.")
else:
logger.debug("Starting a new experiment.")
def _wrapped(self):
"""Return wrapped tune controller to be passed to scheduler/searchers."""
return TrialRunnerWrapper(
self,
trial_executor=_FakeRayTrialExecutor(self),
runner_whitelist_attr={
"search_alg",
"get_trials",
"get_live_trials",
"_set_trial_status",
"pause_trial",
"stop_trial",
"_schedule_trial_save",
},
executor_whitelist_attr={
"has_resources_for_trial",
"pause_trial",
"save",
"_resource_updater",
},
)
@property
def resumed(self):
return self._resumed
@property
def search_alg(self):
return self._search_alg
@property
def scheduler_alg(self):
return self._scheduler_alg
def setup_experiments(
self, experiments: List[Experiment], total_num_samples: int
) -> None:
"""Obtains any necessary information from experiments.
        Mainly used to set up callbacks.
Args:
experiments: List of Experiments
to use.
total_num_samples: Total number of samples
factoring in grid search samplers.
"""
experiment = experiments[0]
spec = experiment.public_spec if experiment else {}
spec["total_num_samples"] = total_num_samples
self._callbacks.setup(**spec)
def end_experiment_callbacks(self) -> None:
"""Calls ``on_experiment_end`` method in callbacks."""
self._callbacks.on_experiment_end(trials=self._trials)
@Deprecated("Use `TrialRunner.experiment_state_path` instead.")
@property
def checkpoint_file(self) -> str:
return self.experiment_state_path
@property
def experiment_state_file_name(self) -> str:
return self.CKPT_FILE_TMPL.format(self._session_str)
@property
def experiment_state_path(self) -> str:
"""Returns the local experiment checkpoint path."""
if _use_storage_context():
return os.path.join(
self._storage.experiment_local_path, self.experiment_state_file_name
)
return os.path.join(
self._legacy_local_experiment_path, self.experiment_state_file_name
)
@property
def experiment_path(self) -> str:
if _use_storage_context():
return self._storage.experiment_fs_path
return self._legacy_remote_experiment_path or self._legacy_local_experiment_path
def _create_checkpoint_manager(self):
return _ExperimentCheckpointManager(
checkpoint_period=self._checkpoint_period,
sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep,
storage=self._storage,
# TODO(justinvyu): Remove these.
local_checkpoint_dir=self._legacy_local_experiment_path,
remote_checkpoint_dir=self._legacy_remote_experiment_path,
sync_config=self._legacy_sync_config,
)
@classmethod
def checkpoint_exists(cls, directory: str) -> bool:
if not os.path.exists(directory):
return False
return _experiment_checkpoint_exists(directory)
def save_to_dir(self, experiment_dir: Optional[str] = None):
"""Save TrialRunner state to experiment directory.
Accepts an ``experiment_dir`` argument which defaults to the
local checkpoint directory.
This method will save the trial runner state, the searcher state,
and the callback states into the experiment directory.
"""
if _use_storage_context():
assert not experiment_dir, "Remove the `experiment_dir` argument."
experiment_dir = self._storage.experiment_local_path
else:
experiment_dir = experiment_dir or self._legacy_local_experiment_path
# Get state from trial executor and runner
runner_state = {
# Trials
"trial_data": list(self._get_trial_checkpoints().values()),
# Experiment data
"runner_data": self.__getstate__(),
# Metadata
"stats": {
"start_time": self._start_time,
"timestamp": self._last_checkpoint_time,
},
}
tmp_file_name = os.path.join(
experiment_dir, f".tmp_experiment_state_{uuid.uuid4()}"
)
with open(tmp_file_name, "w") as f:
json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)
os.replace(
tmp_file_name,
os.path.join(experiment_dir, self.experiment_state_file_name),
)
self._search_alg.save_to_dir(experiment_dir, session_str=self._session_str)
self._callbacks.save_to_dir(experiment_dir, session_str=self._session_str)
def restore_from_dir(self, experiment_dir: Optional[str] = None) -> List[Trial]:
"""Restore TrialRunner state from experiment directory.
Accepts an ``experiment_dir`` argument which defaults to the
local checkpoint directory.
This method will restore the trial runner state, the searcher state,
and the callback states. It will then parse the trial states
and return them as a list of Trial objects.
"""
if _use_storage_context():
assert not experiment_dir, "Remove the `experiment_dir` argument."
experiment_dir = self._storage.experiment_local_path
else:
experiment_dir = experiment_dir or self._legacy_local_experiment_path
# Update local checkpoint dir
self._legacy_local_experiment_path = experiment_dir
# Find newest state file
newest_state_path = _find_newest_experiment_checkpoint(experiment_dir)
if not newest_state_path:
raise ValueError(
f"Tried to resume experiment from directory "
f"`{experiment_dir}`, but no "
f"experiment checkpoint data was found."
)
# Set checkpoint file to load
logger.warning(
f"Attempting to resume experiment from {experiment_dir}. "
"This will ignore any new changes to the specification."
)
logger.info(
"Using the newest experiment state file found within the "
f"experiment directory: {Path(newest_state_path).name}"
)
# Actually load data
with open(newest_state_path, "r") as f:
runner_state = json.load(f, cls=TuneFunctionDecoder)
# 1. Restore trial runner state
self.__setstate__(runner_state["runner_data"])
# 2. Restore search algorithm and callback state
if self._search_alg.has_checkpoint(experiment_dir):
self._search_alg.restore_from_dir(experiment_dir)
if self._callbacks.can_restore(experiment_dir):
self._callbacks.restore_from_dir(experiment_dir)
# 3. Load trials
trials = []
for trial_json_state, trial_runtime_metadata in runner_state["trial_data"]:
trial = Trial.from_json_state(trial_json_state)
trial.restore_run_metadata(trial_runtime_metadata)
# The following properties may be updated on restoration
# Ex: moved local/cloud experiment directory
if _use_storage_context():
# Propagate updated storage ctx properties to the trial's restored copy.
# TODO(justinvyu): [handle_moved_storage_path]
trial.storage.storage_path = self._storage.storage_path
trial.storage.experiment_dir_name = self._storage.experiment_dir_name
else:
# ATTN: Set `local_experiment_path` to update trial checkpoints!
trial.local_experiment_path = self._legacy_local_experiment_path
trial.remote_experiment_path = self._legacy_remote_experiment_path
trial.sync_config = self._legacy_sync_config
trial.experiment_dir_name = self._legacy_experiment_dir_name
# Avoid creating logdir in client mode for returned trial results,
# since the dir might not be creatable locally.
# TODO(ekl) this is kind of a hack.
if not ray.util.client.ray.is_connected():
trial.init_local_path() # Create logdir if it does not exist
trials.append(trial)
return trials
def checkpoint(self, force: bool = False, wait: bool = False):
"""Saves execution state to `self._legacy_local_experiment_path`.
Overwrites the current session checkpoint, which starts when self
is instantiated. Throttle depends on self._checkpoint_period.
Also automatically saves the search algorithm to the local
checkpoint dir.
Args:
force: Forces a checkpoint despite checkpoint_period.
wait: Wait until syncing to cloud has finished.
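
        Example:
            A minimal sketch that forces an immediate checkpoint and waits
            for any cloud sync to finish (assumes ``controller`` is an
            already-constructed ``TuneController``):

            >>> controller.checkpoint(force=True, wait=True)  # doctest: +SKIP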
"""
with warn_if_slow(
"experiment_checkpoint",
message="Checkpointing the experiment state took "
"{duration:.3f} s, which may be a performance "
"bottleneck. Please ensure the "
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
"something significantly higher than this duration "
"to ensure compute time is mostly spent on the main "
"training loop.",
# No backlog warning if forced checkpoint as we wait
# for previous sync to finish.
disable=self._checkpoint_manager.auto_checkpoint_enabled or force or wait,
):
self._checkpoint_manager.checkpoint(
save_fn=self.save_to_dir, force=force, wait=wait
)
def resume(
self,
resume_unfinished: bool = True,
resume_errored: bool = False,
restart_errored: bool = False,
):
"""Resumes all checkpointed trials from previous run.
Requires user to manually re-register their objects. Also stops
all ongoing trials.
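
        A hedged usage sketch (``resume`` is normally invoked from
        ``__init__`` via the experiment checkpoint manager rather than
        called directly):

        >>> controller.resume(  # doctest: +SKIP
        ...     resume_unfinished=True,
        ...     resume_errored=True,
        ...     restart_errored=False,
        ... )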
"""
trials = self.restore_from_dir()
# Set trial statuses according to the resume configuration
for trial in sorted(
trials, key=lambda t: t.run_metadata.last_result_time, reverse=True
):
trial_to_add = trial
if trial.status == Trial.ERROR:
if resume_errored:
# Keep trial ID on resume
trial_to_add.run_metadata.error_filename = None
trial_to_add.run_metadata.pickled_error_filename = None
trial_to_add.set_status(Trial.PENDING)
if not _use_storage_context():
# TODO(justinvyu): Remove this.
# Not needed since trial.checkpoint will be used anyways.
trial_to_add.restore_path = trial.checkpoint.dir_or_data
elif restart_errored:
trial_to_add = trial.reset()
trial_to_add.restore_path = None
elif trial.status != Trial.TERMINATED and not resume_unfinished:
trial_to_add.status = Trial.TERMINATED
self.add_trial(trial_to_add)
def update_max_pending_trials(self, max_pending_trials: Optional[int] = None):
self._max_pending_trials = max_pending_trials or _get_max_pending_trials(
self._search_alg
)
def update_pending_trial_resources(
self, resources: Union[dict, PlacementGroupFactory]
):
"""Update trial resources when resuming from checkpoint.
        Only pending trials are updated.
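
        Example:
            A minimal sketch that updates every pending trial to request
            two CPUs and no GPUs (assumes ``controller`` is a restored
            ``TuneController``):

            >>> controller.update_pending_trial_resources(  # doctest: +SKIP
            ...     {"cpu": 2, "gpu": 0}
            ... )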
"""
assert resources
if isinstance(resources, dict) and "gpu" not in resources:
resources["gpu"] = 0
for trial in self._trials:
if trial.status == Trial.PENDING:
trial.update_resources(resources=resources)
def is_finished(self):
"""Returns whether all trials have finished running."""
        # The checks here are partly redundant but optimized for quick
        # evaluation. Specifically, if there are live trials, we check
        # these live trials first. Only if all of the live trials are
        # finished do we loop over all trials for a final check.
trials_done = (
len(self._live_trials) == 0
or all(trial.is_finished() for trial in self._live_trials)
) and all(trial.is_finished() for trial in self._trials)
return trials_done and self._search_alg.is_finished()
def get_trial(self, tid):
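        """Returns the trial with the given trial ID, or None if no such
        trial is managed by this controller."""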
trial = [t for t in self._trials if t.trial_id == tid]
return trial[0] if trial else None
def get_trials(self):
"""Returns the list of trials managed by this TrialRunner.
Note that the caller usually should not mutate trial state directly.
"""
return self._trials
def get_live_trials(self):
"""Returns the set of trials that are not in Trial.TERMINATED state."""
return self._live_trials
def add_trial(self, trial: Trial):
"""Adds a new trial to this TrialRunner.
Trials may be added at any time.
Args:
trial: Trial to queue.
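
        Example:
            A hedged sketch (assumes ``controller`` is a running
            ``TuneController`` and ``"my_trainable"`` is a registered
            trainable):

            >>> controller.add_trial(Trial("my_trainable"))  # doctest: +SKIP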
"""
# If the config map has had all the references replaced with placeholders,
# resolve them before adding the trial.
if self._placeholder_resolvers:
trial.resolve_config_placeholders(self._placeholder_resolvers)
# With trial.config resolved, create placement group factory if needed.
trial.create_placement_group_factory()
self._trials.append(trial)
if trial.status != Trial.TERMINATED:
self._live_trials.add(trial)
with warn_if_slow("scheduler.on_trial_add"):
self._scheduler_alg.on_trial_add(self._wrapped(), trial)
self._mark_trial_to_checkpoint(trial)
logger.debug(f"Adding trial {trial} with status {trial.status}")
status_str_map = {
Trial.PENDING: self._pending_trials,
Trial.RUNNING: self._running_trials,
Trial.PAUSED: self._paused_trials,
Trial.TERMINATED: self._stopped_trials,
Trial.ERROR: self._failed_trials,
}
status_str_map[trial.status].add(trial)
if trial.status == Trial.PENDING:
self._pending_trials_list.append(trial)
self._resources_to_pending_trials[trial.placement_group_factory].add(trial)
def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool:
"""Adds next trials to queue if possible.
Note that the timeout is currently unexposed to the user.
Args:
blocking: Blocks until either a trial is available
or is_finished (timeout or search algorithm finishes).
timeout: Seconds before blocking times out.
Returns:
Boolean indicating if a new trial was created or not.
"""
trial = self._search_alg.next_trial()
if blocking and not trial:
start = time.time()
# Checking `is_finished` instead of _search_alg.is_finished
# is fine because blocking only occurs if all trials are
# finished and search_algorithm is not yet finished
while (
not trial and not self.is_finished() and time.time() - start < timeout
):
logger.debug("Blocking for next trial...")
trial = self._search_alg.next_trial()
time.sleep(1)
if trial:
self.add_trial(trial)
return True
return False
def _used_resources_string(self) -> str:
allocated_resources = self._actor_manager.get_live_actors_resources()
return self._resource_updater.debug_string(allocated_resources)
def on_step_begin(self):
self._resource_updater.update_avail_resources()
def on_step_end(self):
self._cleanup_cached_actors(force_all=False)
self._cleanup_stopping_actors(force_all=False)
def _cleanup_cached_actors(self, force_all: bool = False):
if (
self._search_alg.is_finished()
and not self._staged_trials
and self._actor_cache.total_max_objects == 0
):
# If there are no more trials coming in, no trials are pending execution,
# and we don't explicitly want to cache objects, we can evict the full
# cache.
force_all = True
for tracked_actor in self._actor_cache.flush_cached_objects(
force_all=force_all
):
logger.debug(f"Cleaning up cached actor: {tracked_actor}")
# Unset termination callbacks as no trial is associated
tracked_actor.set_on_stop(None)
tracked_actor.set_on_error(None)
self._remove_actor(tracked_actor=tracked_actor)
def _cleanup_stopping_actors(self, force_all: bool = False):
now = time.monotonic()
if (
not force_all
and now - self._earliest_stopping_actor <= self._actor_cleanup_timeout
):
# If the earliest actor to timeout has not reached the timeout, return
return
# This is a bit costly, so we want to avoid running it too often
times = deque(
sorted(
[
(timestamp, tracked_actor)
for tracked_actor, timestamp in self._stopping_actors.items()
],
key=lambda item: item[0],
)
)
while times and (
force_all or time.monotonic() - times[0][0] > self._actor_cleanup_timeout
):
if (
time.monotonic() - times[0][0] < self._actor_force_cleanup_timeout
) and self._actor_manager.is_actor_started(tracked_actor=times[0][1]):
# Even if force_all=True, we give the actors time to clean up
self._actor_manager.next(timeout=1)
continue
_, tracked_actor = times.popleft()
if tracked_actor not in self._stopping_actors:
# Actor stopping has been handled by the block above
continue
if self._actor_manager.is_actor_started(tracked_actor=tracked_actor):
logger.debug(f"Forcefully killing actor: {tracked_actor}")
self._actor_manager.remove_actor(tracked_actor=tracked_actor, kill=True)
self._stopping_actors.pop(tracked_actor)
if times:
self._earliest_stopping_actor = times[0][0]
else:
self._earliest_stopping_actor = float("inf")
def step(self):
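        """Runs one step of the trial event loop.

        A single step asks the searcher for new trials, starts actors for
        pending trials, processes at most one actor event, checks the
        experiment-level stopping criteria, and checkpoints the experiment
        state if needed.
        """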
if self.is_finished():
raise TuneError("Called step when all trials finished?")
with warn_if_slow("on_step_begin"):
self.on_step_begin()
with warn_if_slow("callbacks.on_step_begin"):
self._callbacks.on_step_begin(
iteration=self._iteration, trials=self._trials
)
# Ask searcher for more trials
self._maybe_update_trial_queue()
# Start actors for added trials
self._maybe_add_actors()
# Handle one event
if not self._actor_manager.next(timeout=0.1):
# If there are no actors running, warn about potentially
# insufficient resources
if not self._actor_manager.num_live_actors:
self._insufficient_resources_manager.on_no_available_trials(
self.get_trials()
)
# Maybe stop whole experiment
self._stop_experiment_if_needed()
# Maybe save experiment state
try:
self.checkpoint()
except Exception as e:
logger.warning(f"Trial controller checkpointing failed: {str(e)}")
            raise
self._iteration += 1
if self._server:
with warn_if_slow("server"):
self._process_stop_requests()
if self.is_finished():
self._server.shutdown()
with warn_if_slow("on_step_end"):
self.on_step_end()
with warn_if_slow("callbacks.on_step_end"):
self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)
def _set_trial_status(self, trial: Trial, status: str):
"""Set trial to a specific status.
This will keep track of trials with specific statuses in sets.
For PENDING and PAUSED trials we also keep a list of trials to be able
to retain FIFO ordering. See ``_maybe_add_actors`` for details.
Lastly we also keep a mapping from resources to pending/paused trials
to be able to efficiently start trials for cached actors.
"""
current_status = trial.status
if current_status == status:
logger.debug(f"Trial {trial} already has status {status}. Skipping update.")
return
status_str_map = {
Trial.PENDING: self._pending_trials,
Trial.RUNNING: self._running_trials,
Trial.PAUSED: self._paused_trials,
Trial.TERMINATED: self._stopped_trials,
Trial.ERROR: self._failed_trials,
}
logger.debug(
f"Setting status for trial {trial} from {current_status} to {status}"
)
assert trial in status_str_map[current_status], (trial, current_status)
assert trial not in status_str_map[status], (trial, status)
status_str_map[current_status].remove(trial)
status_str_map[status].add(trial)
# We keep a log for pending trials for FIFO scheduling.
# We do not need to remove from this list as we will just discard
# items that are in this list but not in the respective set.
if status == Trial.PENDING:
self._pending_trials_list.append(trial)
self._resources_to_pending_trials[trial.placement_group_factory].add(trial)
else:
self._resources_to_pending_trials[trial.placement_group_factory].discard(
trial
)
trial.set_status(status)
def _get_trial_checkpoints(self) -> Dict[str, str]:
for trial in self._trials_to_cache:
self._trial_metadata[trial.trial_id] = trial.get_json_state()
self._trials_to_cache.clear()
return self._trial_metadata
def _mark_trial_to_checkpoint(self, trial: Trial):
self._trials_to_cache.add(trial)
###
# UPDATE TRIALS
def _maybe_update_trial_queue(self):
"""Ask the searcher for more trials."""
if self._search_alg.is_finished():
return
dont_wait_for_trial = (
self._pending_trials or self._running_trials or self._paused_trials
)
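        # Block for the next trial only if nothing is pending, running, or
        # paused yet; otherwise the experiment would sit idle with zero
        # trials.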
while len(self._pending_trials) < self._max_pending_trials:
if not self._update_trial_queue(blocking=not dont_wait_for_trial):
break
dont_wait_for_trial = True
def _cleanup_trials(self):
logger.debug("CLEANING UP all trials")
for tracked_actor in list(self._actor_to_trial):
trial = self._actor_to_trial[tracked_actor]
logger.debug(
f"Scheduling trial stop at end of experiment (trial {trial}): "
f"{tracked_actor}"
)
self._schedule_trial_stop(trial)
# Clean up cached actors now
self._cleanup_cached_actors(force_all=True)
start = time.monotonic()
while time.monotonic() - start < 5 and self._actor_manager.num_total_actors:
if _dedup_logs("actor_manager_cleanup", str(start)):
logger.debug(
"Waiting for actor manager to clean up final state [dedup]"
)
self._actor_manager.next(timeout=1)
logger.debug("Force cleanup of remaining actors")
self._cleanup_stopping_actors(force_all=True)
self._actor_manager.cleanup()
def _remove_actor(self, tracked_actor: TrackedActor):
stop_future = self._actor_manager.schedule_actor_task(
tracked_actor, "stop", _return_future=True
)
now = time.monotonic()
if self._actor_manager.remove_actor(
tracked_actor, kill=False, stop_future=stop_future
):
# If the actor was previously alive, track
self._stopping_actors[tracked_actor] = now
self._earliest_stopping_actor = min(self._earliest_stopping_actor, now)
###
# ADD ACTORS
def _maybe_add_actors(self) -> None:
"""Add actors for pending and paused trials.
        For actors that have not been staged yet, we request an actor.