[RLlib] Metrics do-over 03: Switch over Learner to new MetricsLogger API. #44995

Merged
Changes from all commits
128 commits
6f1b505
wip
sven1977 Mar 13, 2024
b6e2714
wip
sven1977 Mar 13, 2024
4dfb2ce
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 2, 2024
e6402c6
wip
sven1977 Apr 3, 2024
33487cc
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 4, 2024
a02abbd
doctest fix
sven1977 Apr 4, 2024
d9f3e6e
wip
sven1977 Apr 4, 2024
e909a73
wip
sven1977 Apr 4, 2024
bdaa04c
wip
sven1977 Apr 6, 2024
52d9e12
wip
sven1977 Apr 6, 2024
f77ffdb
wip
sven1977 Apr 8, 2024
81c4c79
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 Apr 8, 2024
1672675
wip
sven1977 Apr 8, 2024
adf9e8c
wip
sven1977 Apr 8, 2024
c9e5c2f
wip
sven1977 Apr 8, 2024
683bc4b
wip
sven1977 Apr 8, 2024
5bd220f
wip
sven1977 Apr 8, 2024
d931945
LINT
sven1977 Apr 8, 2024
e9888de
wip
sven1977 Apr 8, 2024
0e97d8f
fixes
sven1977 Apr 8, 2024
7584cce
fixes
sven1977 Apr 8, 2024
5ba69af
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 8, 2024
36bfa57
wip
sven1977 Apr 8, 2024
5565d4f
fixes
sven1977 Apr 8, 2024
27c793d
fixes
sven1977 Apr 8, 2024
d75b31a
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 Apr 8, 2024
bf9cef0
fixes
sven1977 Apr 9, 2024
872b49b
Apply suggestions from code review
sven1977 Apr 9, 2024
743dabd
fixes
sven1977 Apr 9, 2024
d50a39c
Merge remote-tracking branch 'origin/cleanup_examples_folder_03' into…
sven1977 Apr 9, 2024
d1a18c6
fix
sven1977 Apr 9, 2024
a092ea3
Merge branch 'master' of https://github.com/ray-project/ray into clea…
sven1977 Apr 9, 2024
2780a3b
Merge branch 'cleanup_examples_folder_03' into complete_metrics_and_s…
sven1977 Apr 9, 2024
97cb5c3
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 9, 2024
96daa91
fix
sven1977 Apr 9, 2024
56f6bd6
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 15, 2024
e7105c3
wip
sven1977 Apr 15, 2024
a782d24
wip
sven1977 Apr 15, 2024
dcfbaa0
wip
sven1977 Apr 16, 2024
f014844
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 16, 2024
e8957f5
wip
sven1977 Apr 16, 2024
83ff35b
wip
sven1977 Apr 17, 2024
f89a796
wip
sven1977 Apr 17, 2024
429bfab
wip
sven1977 Apr 17, 2024
f6aad0c
wip
sven1977 Apr 17, 2024
96fefb7
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 17, 2024
72b6cd3
WandB logging of videos working!
sven1977 Apr 17, 2024
8133447
wip
sven1977 Apr 18, 2024
2d5b0dd
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 19, 2024
5f0b776
Merge branch 'master' of https://github.com/ray-project/ray into comp…
sven1977 Apr 20, 2024
5d746db
wip
sven1977 Apr 20, 2024
4d910a5
wip
sven1977 Apr 20, 2024
856cb18
Merge branch 'master' into metrics_do_over_02_algo_and_ppo_training_step
sven1977 Apr 20, 2024
10bc567
wip
sven1977 Apr 20, 2024
ae4d83d
Merge remote-tracking branch 'origin/metrics_do_over_02_algo_and_ppo_…
sven1977 Apr 20, 2024
b446c67
test_ppo_w_envrunner passing
sven1977 Apr 21, 2024
7274850
wip
sven1977 Apr 21, 2024
a7871fe
Merge branches 'master' and 'master' of https://github.com/ray-projec…
sven1977 Apr 21, 2024
b93ec75
wip
sven1977 Apr 21, 2024
f47ff61
wip
sven1977 Apr 22, 2024
da91a28
wip
sven1977 Apr 22, 2024
ef8ff4d
wip
sven1977 Apr 22, 2024
49d9cde
wip
sven1977 Apr 22, 2024
7c8647e
fixes
sven1977 Apr 22, 2024
3fe1f99
LINT
sven1977 Apr 22, 2024
709c281
wip
sven1977 Apr 23, 2024
cc8737c
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 23, 2024
0fd8ab9
wip
sven1977 Apr 23, 2024
cd70956
wip
sven1977 Apr 23, 2024
afb01bc
wip
sven1977 Apr 23, 2024
21cd445
wip
sven1977 Apr 23, 2024
a6939f6
wip
sven1977 Apr 23, 2024
7f23206
wip
sven1977 Apr 23, 2024
12df29a
wip
sven1977 Apr 23, 2024
6a01602
wip
sven1977 Apr 23, 2024
481043f
wip
sven1977 Apr 23, 2024
772a7c0
wip
sven1977 Apr 23, 2024
fad02f4
SAC Learning Pendulum (careful about the prio replay buffer setting, …
sven1977 Apr 23, 2024
1713ac6
SAC Learning Pendulum (after sharing its training_step code with DQN,…
sven1977 Apr 23, 2024
16283d7
fixes
sven1977 Apr 23, 2024
6807c1c
fixes
sven1977 Apr 23, 2024
6a07183
fixes
sven1977 Apr 24, 2024
4ca0e44
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 24, 2024
7d6f098
fixes
sven1977 Apr 24, 2024
76fe68b
fixes
sven1977 Apr 24, 2024
9b96ce7
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 24, 2024
81b600e
Merge branch 'metrics_do_over_02_algo_and_ppo_training_step' into met…
sven1977 Apr 24, 2024
9df4c3b
wip
sven1977 Apr 24, 2024
6a4dcb1
wip
sven1977 Apr 24, 2024
812479f
wip
sven1977 Apr 25, 2024
b91f819
wip
sven1977 Apr 25, 2024
36f47c0
merge
sven1977 Apr 26, 2024
38fd4bc
LINT
sven1977 Apr 26, 2024
a2bdd3a
wip
sven1977 Apr 26, 2024
820fe92
wip
sven1977 Apr 26, 2024
1c9da35
wip
sven1977 Apr 27, 2024
52a93de
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 27, 2024
c14cb25
wip
sven1977 Apr 27, 2024
fd5f2cf
wip
sven1977 Apr 27, 2024
7a925b0
learns SAC pendulum in 5000ts ~60sec upto better than -300
sven1977 Apr 28, 2024
5ae4218
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 28, 2024
a152d62
still learns SAC pendulum in 5000ts ~60sec upto better than -300
sven1977 Apr 28, 2024
637a398
still learns SAC pendulum in 5000ts ~60sec upto better than -300
sven1977 Apr 28, 2024
6d7c013
still learns SAC pendulum in 7000ts ~60sec upto better than -200
sven1977 Apr 28, 2024
0485ee7
LINT
sven1977 Apr 28, 2024
e96efdc
fixes
sven1977 Apr 28, 2024
6c5627d
doctest fix
sven1977 Apr 28, 2024
dea6a34
wip
sven1977 Apr 28, 2024
f7ae9ab
wip
sven1977 Apr 29, 2024
e3ff44f
DQN and SAC learn again
sven1977 Apr 29, 2024
1c8cbea
doctest fix
sven1977 Apr 29, 2024
e21b5f2
fix
sven1977 Apr 29, 2024
ea38d03
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 29, 2024
b4e5a79
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 29, 2024
6089c10
wip
sven1977 Apr 29, 2024
de50b67
fix
sven1977 Apr 30, 2024
6ba560d
fixes most test cases
sven1977 Apr 30, 2024
245f9bd
merge
sven1977 Apr 30, 2024
ac7470c
LINT
sven1977 Apr 30, 2024
902818d
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 30, 2024
0f19aa1
wip
sven1977 Apr 30, 2024
3a677c5
fixes
sven1977 Apr 30, 2024
2fae0ca
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 Apr 30, 2024
720340c
Merge branch 'master' of https://github.com/ray-project/ray into metr…
sven1977 May 1, 2024
9242c24
fixes
sven1977 May 1, 2024
4fa2edb
fixes
sven1977 May 1, 2024
3110a56
fixed multi-gpu test cases.
sven1977 May 1, 2024
9ae2204
fixes
sven1977 May 1, 2024
12 changes: 0 additions & 12 deletions doc/source/rllib/package_ref/learner.rst
@@ -158,15 +158,3 @@ Adding and Removing Modules

Learner.add_module
Learner.remove_module

Managing Results
----------------

.. autosummary::
:nosignatures:
:toctree: doc/

Learner.compile_results
Learner.register_metric
Learner.register_metrics
Learner._check_result
40 changes: 5 additions & 35 deletions doc/source/rllib/rllib-learner.rst
@@ -244,12 +244,10 @@ Updates
results = learner_group.update_from_batch(
batch=DUMMY_BATCH, async_update=True
)
# `results` is a list of results dict. The items in the list represent the different
# remote results from the different calls to
# `update_from_batch(..., async_update=True)`.
assert len(results) > 0
# Each item is a results dict, already reduced over the n Learner workers.
assert isinstance(results[0], dict), results[0]
# `results` is an already reduced dict, which is the result of
# reducing over the individual async `update_from_batch(..., async_update=True)`
# calls.
assert isinstance(results, dict), results

# This is an additional non-gradient based update.
learner_group.additional_update(**ADDITIONAL_UPDATE_KWARGS)
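
A minimal usage sketch of the new return behavior documented above (assuming a `learner_group` and `DUMMY_BATCH` set up as in the surrounding docs example; iterating the dict like this is illustrative, not a documented guarantee):

    # Async update now returns one dict, already reduced over all remote Learners
    # (previously: a list of per-call result dicts).
    results = learner_group.update_from_batch(batch=DUMMY_BATCH, async_update=True)
    assert isinstance(results, dict)
    # Per-module entries can be read directly from the single dict.
    for module_id, module_results in results.items():
        print(module_id, module_results)
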
@@ -373,9 +371,7 @@ Implementation
- calculate the loss for gradient based update to a module.
* - :py:meth:`~ray.rllib.core.learner.learner.Learner.additional_update_for_module()`
- do any non gradient based updates to a RLModule, e.g. target network updates.
* - :py:meth:`~ray.rllib.core.learner.learner.Learner.compile_results()`
- compute training statistics and format them for downstream use.


Starter Example
---------------

@@ -417,30 +413,4 @@ A :py:class:`~ray.rllib.core.learner.learner.Learner` that implements behavior c

return loss

@override(Learner)
def compile_results(
self,
*,
batch: MultiAgentBatch,
fwd_out: Dict[str, Any],
loss_per_module: Dict[str, TensorType],
metrics_per_module: DefaultDict[ModuleID, Dict[str, Any]],
) -> Dict[str, Any]:

results = super().compile_results(
batch=batch,
fwd_out=fwd_out,
loss_per_module=loss_per_module,
metrics_per_module=metrics_per_module,
)
# report the mean weight of each
mean_ws = {}
for module_id in self.module.keys():
m = self.module[module_id]
parameters = convert_to_numpy(self.get_parameters(m))
mean_ws[module_id] = np.mean([w.mean() for w in parameters])
results[module_id]["mean_weight"] = mean_ws[module_id]

return results
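
Since `compile_results()` is removed, custom per-module stats such as the "mean_weight" above would instead be logged through the Learner's MetricsLogger. A hedged sketch of the equivalent (the class and helper names are illustrative; `get_parameters()` and `convert_to_numpy()` are taken from the removed code):

    import numpy as np
    from ray.rllib.core.learner.learner import Learner
    from ray.rllib.utils.numpy import convert_to_numpy

    class MyBCLearner(Learner):  # name is illustrative
        def _log_mean_weights(self) -> None:
            # Log the mean weight of each RLModule via the MetricsLogger,
            # instead of returning it from a `compile_results()` override.
            for module_id in self.module.keys():
                parameters = convert_to_numpy(self.get_parameters(self.module[module_id]))
                mean_w = float(np.mean([w.mean() for w in parameters]))
                # window=1: log as a single latest value (no mean/ema over time).
                self.metrics.log_value((module_id, "mean_weight"), mean_w, window=1)
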


21 changes: 12 additions & 9 deletions rllib/algorithms/algorithm.py
@@ -91,6 +91,7 @@
from ray.rllib.utils.metrics import (
ALL_MODULES,
ENV_RUNNER_RESULTS,
ENV_RUNNER_SAMPLING_TIMER,
EVALUATION_ITERATION_TIMER,
EVALUATION_RESULTS,
FAULT_TOLERANCE_STATS,
@@ -117,7 +118,6 @@
TIMERS,
TRAINING_ITERATION_TIMER,
TRAINING_STEP_TIMER,
SAMPLE_TIMER,
STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
@@ -1608,9 +1608,9 @@ def training_step(self) -> ResultDict:
)

# Collect SampleBatches from sample workers until we have a full batch.
with self.metrics.log_time((TIMERS, SAMPLE_TIMER)):
with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
if self.config.count_steps_by == "agent_steps":
train_batch, env_runner_metrics = synchronous_parallel_sample(
train_batch, env_runner_results = synchronous_parallel_sample(
worker_set=self.workers,
max_agent_steps=self.config.train_batch_size,
sample_timeout_s=self.config.sample_timeout_s,
@@ -1620,7 +1620,7 @@
_return_metrics=True,
)
else:
train_batch, env_runner_metrics = synchronous_parallel_sample(
train_batch, env_runner_results = synchronous_parallel_sample(
worker_set=self.workers,
max_env_steps=self.config.train_batch_size,
sample_timeout_s=self.config.sample_timeout_s,
@@ -1632,7 +1632,7 @@
train_batch = train_batch.as_multi_agent()

# Reduce EnvRunner metrics over the n EnvRunners.
self.metrics.log_n_dicts(env_runner_metrics, key=ENV_RUNNER_RESULTS)
self.metrics.log_n_dicts(env_runner_results, key=ENV_RUNNER_RESULTS)

# Only train if train_batch is not empty.
# In an extreme situation, all rollout workers die during the
@@ -3115,13 +3115,16 @@ def _create_local_replay_buffer_if_necessary(
return from_config(ReplayBuffer, config["replay_buffer_config"])

def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
"""Runs one training iteration (self.iteration will be +1 after this).
"""Runs one training iteration (`self.iteration` will be +1 after this).

Calls `self.training_step()` repeatedly until the minimum time (sec),
sample- or training steps have been reached.
Calls `self.training_step()` repeatedly until the configured minimum time (sec),
minimum sample- or minimum training steps have been reached.

Returns:
The results dict from the training iteration.
The ResultDict from the last call to `training_step()`. Note that even
though we only return the last ResultDict, the user still has full control
over the history and reduce behavior of individual metrics at the time these
metrics are logged with `self.metrics.log_...()`.
"""
with self._timers[TRAINING_ITERATION_TIMER]:
# In case we are training (in a thread) parallel to evaluation,
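
The `log_time()` / `log_n_dicts()` calls above are part of the new MetricsLogger API this PR switches the Learner stack over to. A small standalone sketch of the core calls (the import path, direct construction, and `reduce()` returning the nested result dict are assumptions; inside RLlib the logger is owned by the Algorithm/Learner as `self.metrics`):

    import time
    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    logger = MetricsLogger()
    # Time a code block under a nested (tuple) key.
    with logger.log_time(("timers", "sampling")):
        time.sleep(0.01)
    # Scalar value; by default reduced as a (windowed/EMA) mean over time.
    logger.log_value("loss", 0.25)
    # Lifetime counter: accumulate via reduce="sum".
    logger.log_value(("counters", "env_steps"), 4000, reduce="sum")
    # Reduce everything logged so far into a nested results dict.
    results = logger.reduce()
    print(results)
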
32 changes: 16 additions & 16 deletions rllib/algorithms/appo/appo_learner.py
@@ -1,5 +1,4 @@
import abc
from typing import Any, Dict

from ray.rllib.algorithms.appo.appo import APPOConfig
from ray.rllib.algorithms.impala.impala_learner import ImpalaLearner
@@ -46,7 +45,7 @@ def additional_update_for_module(
last_update: int,
mean_kl_loss_per_module: dict,
**kwargs,
) -> Dict[str, Any]:
) -> None:
"""Updates the target networks and KL loss coefficients (per module).

Args:
@@ -63,27 +62,28 @@
# updates.
# We should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.
results = super().additional_update_for_module(
super().additional_update_for_module(
module_id=module_id, config=config, timestep=timestep
)

if (timestep - last_update) >= config.target_update_frequency:
# TODO (Sven): DQN uses `config.target_network_update_freq`. Can we
# choose a standard here?
last_update_ts_key = (module_id, LAST_TARGET_UPDATE_TS)
if (
timestep - self.metrics.peek(last_update_ts_key, default=0)
>= config.target_update_frequency
):
self._update_module_target_networks(module_id, config)
results[NUM_TARGET_UPDATES] = 1
results[LAST_TARGET_UPDATE_TS] = timestep
else:
results[NUM_TARGET_UPDATES] = 0
results[LAST_TARGET_UPDATE_TS] = last_update
# Increase lifetime target network update counter by one.
self.metrics.log_value((module_id, NUM_TARGET_UPDATES), 1, reduce="sum")
# Update the (single-value -> window=1) last updated timestep metric.
self.metrics.log_value(last_update_ts_key, timestep, window=1)

if config.use_kl_loss and module_id in mean_kl_loss_per_module:
results.update(
self._update_module_kl_coeff(
module_id, config, mean_kl_loss_per_module[module_id]
)
self._update_module_kl_coeff(
module_id, config, mean_kl_loss_per_module[module_id]
)

return results

@abc.abstractmethod
def _update_module_target_networks(
self, module_id: ModuleID, config: APPOConfig
@@ -100,7 +100,7 @@ def _update_module_target_networks(
@abc.abstractmethod
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> Dict[str, Any]:
) -> None:
"""Dynamically update the KL loss coefficients of each module with.

The update is completed using the mean KL divergence between the action
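
The `peek()` / `log_value()` combination above turns the MetricsLogger into the Learner's bookkeeping store for "when did I last update the target network". A tiny sketch of just that round trip (key names are illustrative; the `default=` fallback and `window=1` semantics are as used in the diff, the direct construction is an assumption):

    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    metrics = MetricsLogger()
    last_update_ts_key = ("default_policy", "last_target_update_ts")

    # Nothing logged yet -> `default` is returned.
    print(metrics.peek(last_update_ts_key, default=0))  # expected: 0
    # Log the current timestep; window=1 keeps only the most recent value.
    metrics.log_value(last_update_ts_key, 2000, window=1)
    # A later call (e.g. in the next `additional_update_for_module()`) can read
    # it back to decide whether another target update is due.
    print(metrics.peek(last_update_ts_key, default=0))  # expected: 2000
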
20 changes: 14 additions & 6 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Dict

from ray.rllib.algorithms.appo.appo import (
APPOConfig,
@@ -159,9 +159,11 @@ def compute_loss_for_module(
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

# Register important loss stats.
self.register_metrics(
module_id,
# Register all important loss stats.
# Note that our MetricsLogger (self.metrics) is currently in tensor-mode,
# meaning that it allows us to even log in-graph/compiled tensors through
# its `log_...()` APIs.
self.metrics.log_dict(
{
Collaborator review comment: Nice. This is now much cleaner. Learner only learns and metrics logger logs.

POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
@@ -171,6 +173,8 @@
self.curr_kl_coeffs_per_module[module_id]
),
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)
# Return the total loss.
return total_loss
Expand All @@ -194,7 +198,7 @@ def _update_module_target_networks(
@override(AppoLearner)
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> Dict[str, Any]:
) -> None:
# Update the current KL value based on the recently measured value.
# Increase.
kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]
@@ -206,4 +210,8 @@
elif sampled_kl < 0.5 * config.kl_target:
kl_coeff_var.assign(kl_coeff_var * 0.5)

return {LEARNER_RESULTS_CURR_KL_COEFF_KEY: kl_coeff_var.numpy()}
self.metrics.log_value(
(module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY),
kl_coeff_var.numpy(),
window=1,
)
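
The `log_dict()` call in `compute_loss_for_module()` above groups several related stats under a common `key=` prefix. A hedged, standalone sketch of the same pattern (stat and module names are illustrative; the import path and direct construction are assumptions):

    from ray.rllib.utils.metrics.metrics_logger import MetricsLogger

    metrics = MetricsLogger()
    # Log several loss stats at once under the module's key.
    metrics.log_dict(
        {"policy_loss": 0.12, "vf_loss": 0.34},
        key="default_policy",
        window=1,  # single items; don't mean/ema-reduce over time
    )
    # Nested (tuple) keys address the individual values.
    print(metrics.peek(("default_policy", "policy_loss")))  # expected: 0.12
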
17 changes: 11 additions & 6 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Dict

from ray.rllib.algorithms.appo.appo import (
APPOConfig,
@@ -163,9 +163,8 @@ def compute_loss_for_module(
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

# Register important loss stats.
self.register_metrics(
module_id,
# Log important loss stats.
self.metrics.log_dict(
{
POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
@@ -175,6 +174,8 @@
self.curr_kl_coeffs_per_module[module_id]
),
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)
# Return the total loss.
return total_loss
@@ -231,7 +232,7 @@ def _update_module_target_networks(
@override(AppoLearner)
def _update_module_kl_coeff(
self, module_id: ModuleID, config: APPOConfig, sampled_kl: float
) -> Dict[str, Any]:
) -> None:
# Update the current KL value based on the recently measured value.
# Increase.
kl_coeff_var = self.curr_kl_coeffs_per_module[module_id]
@@ -243,4 +244,8 @@
elif sampled_kl < 0.5 * config.kl_target:
kl_coeff_var.data *= 0.5

return {LEARNER_RESULTS_CURR_KL_COEFF_KEY: kl_coeff_var.item()}
self.metrics.log_value(
(module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY),
kl_coeff_var.item(),
window=1,
)
6 changes: 4 additions & 2 deletions rllib/algorithms/bc/tf/bc_tf_learner.py
@@ -57,11 +57,13 @@ def possibly_masked_mean(t):

policy_loss = -possibly_masked_mean(log_probs)

self.register_metrics(
module_id,
# Log important loss stats.
self.metrics.log_dict(
{
POLICY_LOSS_KEY: policy_loss,
},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)

# Return total loss which is for BC simply the policy loss.
7 changes: 6 additions & 1 deletion rllib/algorithms/bc/torch/bc_torch_learner.py
@@ -57,7 +57,12 @@ def possibly_masked_mean(t):

policy_loss = -possibly_masked_mean(log_probs)

self.register_metrics(module_id, {POLICY_LOSS_KEY: policy_loss})
# Log important loss stats.
self.metrics.log_dict(
{POLICY_LOSS_KEY: policy_loss},
key=module_id,
window=1, # <- single items (should not be mean/ema-reduced over time).
)

# Return the total loss which is for BC simply the policy loss.
return policy_loss