[RLlib] Trainer.add_policy() not working for tf, if added policy is trained afterwards. #16927

Merged: 2 commits, Jul 11, 2021
Changes from 1 commit
19 changes: 10 additions & 9 deletions rllib/agents/tests/test_trainer.py
@@ -1,4 +1,5 @@
 import gym
+from random import choice
 import unittest
 
 import ray
@@ -38,15 +39,15 @@ def test_add_delete_policy(self):
             },
         })
 
-        # TODO: (sven): Fix TrainTFMultiGPU to be flexible wrt adding policies
-        # on-the-fly.
-        for _ in framework_iterator(config, frameworks=("tf2", "torch")):
+        for _ in framework_iterator(config):
             trainer = pg.PGTrainer(config=config)
-            # Given evaluation_interval=2, r0, r2, r4 should not contain
-            # evaluation metrics, while r1, r3 should.
-            r0 = trainer.train()
-            self.assertTrue("p0" in r0["policy_reward_min"])
+            r = trainer.train()
+            self.assertTrue("p0" in r["policy_reward_min"])
             for i in range(1, 4):
+
+                def new_mapping_fn(agent_id, episode, **kwargs):
+                    return f"p{choice([i, i - 1])}"
+
                 # Add a new policy.
                 new_pol = trainer.add_policy(
                     f"p{i}",
@@ -55,9 +56,9 @@ def test_add_delete_policy(self):
                     action_space=env.action_space,
                     config={},
                     # Test changing the mapping fn.
-                    policy_mapping_fn=lambda aid, eps, **kwargs: f"p{i}",
+                    policy_mapping_fn=new_mapping_fn,
                     # Change the list of policies to train.
-                    policies_to_train=[f"p{i}"],
+                    policies_to_train=[f"p{i}", f"p{i-1}"],
                 )
                 pol_map = trainer.workers.local_worker().policy_map
                 self.assertTrue(new_pol is not trainer.get_policy("p0"))
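For context, the user-facing pattern this test now exercises under the "tf" framework is: build a trainer with one policy, train, add a second policy on the fly, mark it trainable, and train again. A minimal sketch of that pattern follows; the spaces, config values, MultiAgentCartPole import, and mapping lambdas are illustrative assumptions, not code taken from this PR.

# Hypothetical usage sketch of Trainer.add_policy(); values are assumptions.
import gym
import ray
from ray.rllib.agents import pg
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

obs_space = gym.spaces.Box(float("-inf"), float("inf"), shape=(4,))
act_space = gym.spaces.Discrete(2)

ray.init()
trainer = pg.PGTrainer(config={
    "env": MultiAgentCartPole,
    "env_config": {"num_agents": 2},  # Illustrative env config.
    "framework": "tf",
    "multiagent": {
        "policies": {"p0": (None, obs_space, act_space, {})},
        "policy_mapping_fn": lambda agent_id, episode, **kwargs: "p0",
    },
})
trainer.train()  # "p0" gets its tf optimizer built at trainer init time.

# Add "p1" on the fly, route agents to it, and mark it trainable. With the
# fix in this PR, the next train() call can also optimize the new policy
# under the "tf" framework.
trainer.add_policy(
    "p1",
    policy_cls=type(trainer.get_policy("p0")),
    observation_space=obs_space,
    action_space=act_space,
    config={},
    policy_mapping_fn=lambda agent_id, episode, **kwargs: "p1",
    policies_to_train=["p0", "p1"],
)
trainer.train()
ray.shutdown()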
10 changes: 8 additions & 2 deletions rllib/evaluation/rollout_worker.py
@@ -1064,8 +1064,14 @@ def add_policy(
         policy_dict = {
             policy_id: (policy_cls, observation_space, action_space, config)
         }
-        add_map, add_prep = self._build_policy_map(policy_dict,
-                                                   self.policy_config)
+        if self.tf_sess is not None:
+            with self.tf_sess.graph.as_default():
+                with self.tf_sess.as_default():
+                    add_map, add_prep = self._build_policy_map(
+                        policy_dict, self.policy_config)
+        else:
+            add_map, add_prep = self._build_policy_map(policy_dict,
+                                                       self.policy_config)
         new_policy = add_map[policy_id]
 
         self.policy_map.update(add_map)
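The rollout-worker part of the fix is about tf1 graph placement: the new policy's ops and variables must be created inside the worker session's graph, otherwise the existing session cannot see or initialize them. A standalone sketch of that tf1 rule follows; it is illustrative only, not RLlib code, and the variable names are made up.

# Illustrative sketch (not RLlib code) of the tf1 rule the rollout-worker fix
# relies on: ops/variables are created in whichever graph is currently the
# default, so a policy built outside the worker session's graph would end up
# with variables that session can neither see nor initialize.
import tensorflow.compat.v1 as tf1

tf1.disable_eager_execution()

graph = tf1.Graph()
sess = tf1.Session(graph=graph)

# Mirror the fix: enter the session's graph (and the session) before
# creating any new tf ops for the added policy.
with sess.graph.as_default():
    with sess.as_default():
        v = tf1.get_variable(
            "added_policy_var", shape=(), initializer=tf1.zeros_initializer())
        sess.run(tf1.variables_initializer([v]))
        print(sess.run(v))  # 0.0 -- the variable lives in the session's graph.

# Built outside those contexts, the same variable would land in the global
# default graph instead, and sess.run(v) would fail with a
# "not an element of this graph" error.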
34 changes: 20 additions & 14 deletions rllib/execution/train_ops.py
@@ -156,20 +156,7 @@ def __init__(self,
             with self.workers.local_worker().tf_sess.as_default():
                 for policy_id in (self.policies
                                   or self.local_worker.policies_to_train):
-                    policy = self.workers.local_worker().get_policy(policy_id)
-                    with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
-                        if policy._state_inputs:
-                            rnn_inputs = policy._state_inputs + [
-                                policy._seq_lens
-                            ]
-                        else:
-                            rnn_inputs = []
-                        self.optimizers[policy_id] = (
-                            LocalSyncParallelOptimizer(
-                                policy._optimizer, self.devices,
-                                list(policy._loss_input_dict_no_rnn.values()),
-                                rnn_inputs, self.per_device_batch_size,
-                                policy.copy))
+                    self.add_optimizer(policy_id)
 
         self.sess = self.workers.local_worker().tf_sess
         self.sess.run(tf1.global_variables_initializer())
@@ -195,6 +182,13 @@ def __call__(self,
                 if policy_id not in (self.policies
                                      or self.local_worker.policies_to_train):
                     continue
+                # Policy seems to be new and doesn't have an optimizer yet.
+                # Add it here and continue.
+                elif policy_id not in self.optimizers:
+                    with self.workers.local_worker().tf_sess.graph.as_default(
+                    ):
+                        with self.workers.local_worker().tf_sess.as_default():
+                            self.add_optimizer(policy_id)
 
                 # Decompress SampleBatch, in case some columns are compressed.
                 batch.decompress_if_needed()
@@ -258,6 +252,18 @@ def __call__(self,
         self.workers.local_worker().set_global_vars(_get_global_vars())
         return samples, fetches
 
+    def add_optimizer(self, policy_id):
+        policy = self.workers.local_worker().get_policy(policy_id)
+        with tf1.variable_scope(policy_id, reuse=tf1.AUTO_REUSE):
+            if policy._state_inputs:
+                rnn_inputs = policy._state_inputs + [policy._seq_lens]
+            else:
+                rnn_inputs = []
+            self.optimizers[policy_id] = (LocalSyncParallelOptimizer(
+                policy._optimizer, self.devices,
+                list(policy._loss_input_dict_no_rnn.values()), rnn_inputs,
+                self.per_device_batch_size, policy.copy))
+
 
 def all_tower_reduce(path, *tower_data):
     """Reduces stats across towers based on their stats-dict paths."""
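The train-op part of the fix factors optimizer construction into add_optimizer() and calls it lazily from __call__() when a batch arrives for a policy that did not exist at __init__ time. A condensed sketch of that build-on-first-use pattern follows; PerPolicyOptimizers and build_fn are illustrative stand-ins, not RLlib classes.

# Illustrative build-on-first-use sketch; PerPolicyOptimizers and build_fn
# are assumed names, not RLlib code.
from typing import Callable, Dict, List


class PerPolicyOptimizers:
    def __init__(self, build_fn: Callable[[str], object],
                 initial_policy_ids: List[str]):
        self._build_fn = build_fn  # E.g. wraps LocalSyncParallelOptimizer(...).
        self._optimizers: Dict[str, object] = {}
        # Eagerly build optimizers for the policies known at init time.
        for policy_id in initial_policy_ids:
            self.add_optimizer(policy_id)

    def add_optimizer(self, policy_id: str) -> None:
        # One optimizer per policy, mirroring TrainTFMultiGPU.add_optimizer().
        self._optimizers[policy_id] = self._build_fn(policy_id)

    def get(self, policy_id: str):
        # Policy added via Trainer.add_policy() after init? Build it now.
        if policy_id not in self._optimizers:
            self.add_optimizer(policy_id)
        return self._optimizers[policy_id]


# Usage sketch:
store = PerPolicyOptimizers(lambda pid: f"<optimizer for {pid}>", ["p0"])
print(store.get("p0"))  # Built eagerly at init.
print(store.get("p3"))  # Built lazily, like a policy added on the fly.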