[RLlib] Slate-Q +GPU torch bug fix. #23464
Conversation
Branch: …eq_torch_gpu_fix
Looks fine to me. Did you get a chance to run it on prod?

Yes, this was tested on a GPU machine in production.
# action.shape: [B, S]
actions = train_batch[SampleBatch.ACTIONS]

observation = convert_to_torch_tensor(
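A minimal sketch of what this change boils down to, assuming RLlib's convert_to_torch_tensor helper (ray.rllib.utils.torch_utils), which recursively converts nested dicts of arrays to tensors on a given device; the loss skeleton around it is illustrative, not the literal diff:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

def loss(policy, model, dist_class, train_batch):
    # Flat columns such as actions already arrive as torch tensors on the
    # correct device via the train batch's get-interceptor.
    # action.shape: [B, S] (batch size x slate size).
    actions = train_batch[SampleBatch.ACTIONS]
    # The nested dict obs (preprocessor is off for SlateQ) bypasses that
    # interceptor, so convert it to device tensors explicitly.
    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=policy.device
    )
    ...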
Just a thought: should we fix this in a more generic way for dict obs spaces? Not sure if this is also happening to other agents and they just don't use strange-looking obs by default.
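A hypothetical sketch of what a more generic fix could look like, assuming dm-tree (already an RLlib dependency) to walk arbitrarily nested obs; the helper name to_device is made up for illustration:

import torch
import tree  # dm-tree

def to_device(column, device):
    # Hypothetical helper: walk any nested dict/tuple/list and move every
    # leaf (numpy array or torch tensor) onto the given device.
    return tree.map_structure(lambda x: torch.as_tensor(x).to(device), column)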
@@ -487,7 +487,7 @@ slateq-interest-evolution-recsim-env:
convert_to_discrete_action_space: false
seed: 0
The framework was changed from torch to tf. Do you want to change it back, or run both?
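One way to exercise both, sketched with Ray Tune's grid search; note the env name below is a placeholder, not the literal string the tuned example registers:

from ray import tune

tune.run(
    "SlateQ",
    config={
        # Grid over both frameworks so neither code path regresses.
        "framework": tune.grid_search(["torch", "tf"]),
        "env": "interest-evolution-recsim-env",  # placeholder env name
        "num_gpus": 1,
    },
)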
Why are these changes needed?

The observations in SlateQ's torch loss do not pass through the SampleBatch's get_interceptor and therefore never get converted to (GPU) torch tensors. This is probably due to the nested dict structure (preprocessor off for SlateQ) arriving in the loss function.
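A rough illustration of the mechanism described above (close to, but not literally, RLlib's internals): TorchPolicy installs a get-interceptor on the train batch so that columns are lazily converted to device tensors on access, which a nested dict column can effectively bypass.

import functools
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor

def illustrate(policy, train_batch):
    # TorchPolicy sets this interceptor before calling the loss:
    train_batch.set_get_interceptor(
        functools.partial(convert_to_torch_tensor, device=policy.device)
    )
    # Flat columns then come back as device tensors on access:
    actions = train_batch[SampleBatch.ACTIONS]
    # A nested dict obs, however, may come back unconverted (raw numpy);
    # mixing that with GPU tensors inside the loss is what failed here.
    return actions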
Related issue number
Checks

I've run scripts/format.sh to lint the changes in this PR.