
[RLlib] SlateQ training iteration function. #24151

Merged: 16 commits into ray-project:master from slateq_training_itr on Apr 29, 2022

Conversation

@sven1977 (Contributor) commented Apr 24, 2022:

SlateQ training iteration function.

  • Set _disable_execution_plan_api=True by default for SlateQ.
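For illustration, a minimal hedged sketch (hypothetical names, not the actual RLlib source) of the default being flipped here: with _disable_execution_plan_api set to True, the trainer runs its training_iteration() path instead of the legacy execution-plan API.

# Hypothetical sketch of the default-config change; the real default lives in
# RLlib's SlateQ config and may differ in surrounding keys.
SLATEQ_DEFAULT_CONFIG_OVERRIDES = {
    # Run the new training_iteration() path instead of the old execution plan.
    "_disable_execution_plan_api": True,
}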

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@avnishn (Contributor) left a comment:

I don't see a training iteration function in the diff here -- maybe GitHub is off, but maybe you forgot to push a commit? Let me know.

Resolved review threads:
  • rllib/agents/slateq/slateq_tf_policy.py (3 threads)
  • rllib/agents/slateq/slateq_torch_policy.py
@sven1977 (Contributor, Author) replied:
@avnishn, we use the exact same training_iteration function as DQN now. That's why I had to add the td_error stats. This makes SlateQ a little more powerful and the code base simpler.
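For context, a self-contained, hedged sketch of what a DQN-style training iteration does; every function and variable name below is an illustrative stand-in, not RLlib's actual API. The pattern: sample from workers, add to the replay buffer, learn on a sampled batch while recording TD-error stats, then broadcast weights and global_vars.

import random

# Illustrative stand-ins only -- not RLlib code.
replay_buffer = []
env_steps_sampled = 0

def sample_from_workers(n=100):
    # Stand-in for collecting n env steps from the rollout workers.
    return [{"obs": random.random(), "reward": random.random()} for _ in range(n)]

def train_policy(batch):
    # Stand-in for one SGD update; returns learner stats including td_error,
    # which is the kind of stat this PR adds to the SlateQ policies.
    td_error = [random.random() for _ in batch]
    return {"td_error": td_error, "mean_td_error": sum(td_error) / len(td_error)}

def training_iteration(train_batch_size=32):
    global env_steps_sampled
    # 1. Sample new experiences and store them in the replay buffer.
    samples = sample_from_workers()
    replay_buffer.extend(samples)
    env_steps_sampled += len(samples)

    # 2. Sample a train batch from the buffer and update the policy.
    batch = random.sample(replay_buffer, k=min(train_batch_size, len(replay_buffer)))
    stats = train_policy(batch)

    # 3. After learning on the local worker, broadcast the new weights and
    #    global_vars (e.g. the env-step counter) to all remote workers.
    global_vars = {"timestep": env_steps_sampled}
    return stats, global_vars

stats, global_vars = training_iteration()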

@avnishn (Contributor) left a comment:

This all looks pretty good to me except for the global vars that we would use for updating the rollout workers -- but since sampling is synchronous, my comments don't really matter, so I'm going to go ahead and approve.

# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
global_vars = {
    "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
@avnishn (Contributor) commented:
If this is using an LR schedule or entropy schedule, shouldn't this be the number of agent steps trained?

@sven1977 (Contributor, Author) replied on Apr 29, 2022:
True. The thing is that NUM_AGENT_STEPS_SAMPLED is always a sum over all (multi-)agents. Say you have 2 agents in your env and 1 policy that both agents map to. Then you would update this policy's timestep counter with the sum of the 2 agents' steps, which would be incorrect, as this count can be much larger than the env-step count (double it if the two agents always act at the same time).

@sven1977 (Contributor, Author) added:
So I would leave it for now to reflect our current behavior.
It is definitely worth looking into this and maybe providing a better per-policy fix.
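A toy calculation (not RLlib code) of the over-counting described above, assuming 2 agents that both map to one policy and act on every env step:

# Hypothetical numbers for illustration.
env_steps = 1000
agents_per_policy = 2                      # both agents act at every env step
agent_steps = env_steps * agents_per_policy

# Using agent steps as the policy's "timestep" would advance any LR or
# entropy schedule twice as fast as the env-step count intends.
assert agent_steps == 2 * env_steps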

# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
global_vars = {
    "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
@avnishn (Contributor) commented:
Same comment here as above.

"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
# Enable the new ReplayBuffer API.
"_enable_replay_buffer_api": True,
@avnishn (Contributor) commented:
Is there a reason that you have to set this and it's not just the case already?

Does this make SlateQ use another replay buffer API?

@sven1977 (Contributor, Author) replied:

Yeah, it'll use the new replay buffer API we are currently rolling out across RLlib.
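For reference, a hedged sketch of the config fragment under discussion, written as a user-facing override; it mirrors the reviewed snippet, and the key names may not match the final RLlib defaults exactly.

# Opting SlateQ into the new replay-buffer API that is being rolled out
# across RLlib (sketch based on the snippet above).
slateq_config_patch = {
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        # Enable the new ReplayBuffer API.
        "_enable_replay_buffer_api": True,
    },
}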

Resolved review thread: rllib/agents/slateq/slateq_tf_policy.py
@sven1977 merged commit 539832f into ray-project:master on Apr 29, 2022.
@sven1977 deleted the slateq_training_itr branch on June 2, 2023.