[RLlib] SlateQ training iteration function. #24151
Conversation
I don't see a training iteration function in the diff here -- maybe github is off, but maybe you forgot to upload a commit? LMK
@avnishn , we use the exact same …
This all looks pretty good to me except the global vars that we would use for updating rollout workers -- I guess since sampling is synchronous, my comments don't really matter, so Imma go ahead and approve.
# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
global_vars = {
    "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
If this is using an LR schedule or entropy schedule, shouldn't this be num agent steps trained?
True. The thing is that NUM_AGENT_STEPS_SAMPLED is always a sum over all (multi) agents. So let's say you have 2 agents in your env and 1 policy, which both these agents map to. Then you would update this policy's timestep counter with the sum of these 2 agents' steps, which would be incorrect (as this count is possibly much larger than env steps; double if the two agents always act at the same time).
So I would leave it for now to reflect our current behavior.
Definitely worth looking into this and maybe providing a better per-policy fix for it.
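To make the concern concrete, here is a toy illustration (hypothetical numbers, not taken from this PR) of why the agent-step counter can outpace the env-step counter when two agents map to a single policy:

# Toy example: 2 agents in the env, both mapped to one policy,
# both acting on every env step.
counters = {"num_env_steps_sampled": 0, "num_agent_steps_sampled": 0}

for _ in range(1000):
    counters["num_env_steps_sampled"] += 1    # one env step
    counters["num_agent_steps_sampled"] += 2  # both agents stepped

assert counters["num_agent_steps_sampled"] == 2 * counters["num_env_steps_sampled"]
# Using the agent-step count as the policy's "timestep" would advance an
# LR/entropy schedule twice as fast as the env-step count would.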
# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
global_vars = {
    "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
Same comment here as above.
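For context, a minimal sketch of what such a hunk typically leads into: building global_vars from the env-step counter and broadcasting it together with the freshly learned weights to the remote workers. The sync_weights(global_vars=...) call is an assumption based on common RLlib patterns, not copied from this diff.

# Sketch (assumed API): after learning on the local worker, push the new
# weights and the updated global_vars to all remote rollout workers so their
# policies (and any LR/entropy schedules) stay in sync.
global_vars = {
    "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
self.workers.sync_weights(global_vars=global_vars)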
"replay_buffer_config": { | ||
"type": "MultiAgentReplayBuffer", | ||
# Enable the new ReplayBuffer API. | ||
"_enable_replay_buffer_api": True, |
Is there a reason you have to set this explicitly, and it's not just the default already?
Does this make SlateQ use another replay buffer API?
Yeah, it'll use the new replay buffer API we are currently rolling out across RLlib.
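For reference, a minimal sketch of how this flag could be set in a SlateQ config (the env name and buffer capacity are placeholders, not values from this PR):

from ray.rllib.agents.slateq import SlateQTrainer

config = {
    # Placeholder: SlateQ expects a RecSim-style slate-recommendation env.
    "env": "my_recsim_env",
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        # Opt into the new ReplayBuffer API being rolled out across RLlib.
        "_enable_replay_buffer_api": True,
        "capacity": 100_000,
    },
}
trainer = SlateQTrainer(config=config)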
Why are these changes needed?
Adds a training iteration function for SlateQ and sets _disable_execution_plan_api=True by default for SlateQ.
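As a rough sketch of the shape such a training iteration function takes (helper names and signatures here are assumptions based on RLlib's training-iteration utilities of that era, not necessarily what this PR ships):

from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED


def training_iteration(self):
    # 1) Sample synchronously from all rollout workers.
    new_batch = synchronous_parallel_sample(worker_set=self.workers)
    self._counters[NUM_ENV_STEPS_SAMPLED] += new_batch.env_steps()
    self._counters[NUM_AGENT_STEPS_SAMPLED] += new_batch.agent_steps()

    # 2) Store the new experiences in the (new-API) replay buffer.
    self.local_replay_buffer.add(new_batch)

    # 3) Sample a train batch from the buffer and learn on the local worker.
    train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"])
    train_results = train_one_step(self, train_batch)

    # 4) Update weights and global_vars - after learning on the local worker -
    #    on all remote workers.
    global_vars = {"timestep": self._counters[NUM_ENV_STEPS_SAMPLED]}
    self.workers.sync_weights(global_vars=global_vars)

    return train_results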
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.