[BUG] action masking does not work with VecEnv and MultiDiscrete action space #74
Comments
It might be a bug, but it's hard to say from the description. Could you share the code to reproduce it? Could you also show an example of how the mask is being split incorrectly? My initial impression is that the (128, 360) shape is intended, because each row corresponds to one env in the VecEnv.
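To see why a (128, 360) array is expected, here is a minimal NumPy sketch. It assumes each of the 128 environments returns its own flat mask of length 360. Stacking those masks with one row per env, which is roughly what collecting masks from a VecEnv amounts to, gives exactly that shape. The random masks are placeholders, not real environment output.

```python
import numpy as np

N_ENVS = 128
NVEC = [5] * 72          # MultiDiscrete([5]*72) from the issue
MASK_LEN = sum(NVEC)     # 5 * 72 = 360

rng = np.random.default_rng(0)

# Placeholder: one flat boolean mask per environment, each of shape (360,)
per_env_masks = [rng.random(MASK_LEN) < 0.5 for _ in range(N_ENVS)]

# Collecting one mask per env stacks them along a new leading (env) axis
batched = np.stack(per_env_masks)
print(batched.shape)  # (128, 360)

# Row i is still exactly the mask returned by env i
assert np.array_equal(batched[3], per_env_masks[3])
```

The leading axis indexes environments, so the 360 mask entries for a single env are never mixed with those of another env by the stacking itself.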
This actually looks good to me: we need to retrieve one mask per env. (FYI: I think we expect a 1D mask from the env even for MultiDiscrete, see #80 (comment); the algorithm reshapes it afterward.)
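A 1D mask for a MultiDiscrete space can be read as one boolean block per sub-action, concatenated, so its length is `sum(nvec)`. The sketch below assumes that layout; `build_flat_mask` and `valid_choices` are hypothetical names for illustration only. In practice, a function like this would be called from the mask callable you pass to the `ActionMasker` wrapper (`sb3_contrib.common.wrappers`), which is applied per environment.

```python
import numpy as np

NVEC = [5] * 72  # MultiDiscrete([5]*72) from the issue

def build_flat_mask(valid_choices):
    """Hypothetical helper: turn per-sub-action valid values into one 1D mask.

    valid_choices[k] holds the allowed values for sub-action k.
    The result concatenates one boolean block of size NVEC[k] per sub-action,
    so its length is sum(NVEC) == 360.
    """
    assert len(valid_choices) == len(NVEC)
    blocks = []
    for n, valid in zip(NVEC, valid_choices):
        block = np.zeros(n, dtype=bool)
        block[sorted(valid)] = True   # mark the allowed values of this sub-action
        blocks.append(block)
    return np.concatenate(blocks)

# Example: every sub-action may only choose 0 or 1
mask = build_flat_mask([{0, 1}] * 72)
print(mask.shape)           # (360,)
print(mask[:5].tolist())    # [True, True, False, False, False]
```

Each env returns a mask of shape (360,), never (72, 5) or (1, 360); batching across envs then happens outside the env.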
Describe the bug
I am aware of #49 (comment), but it still does not work. I have investigated the code, and this is what I found:
With more than one environment, each wrapped in its own ActionMasker, the masks are collected in batch form, so splitting the masks across the distributions does not work. This feels like a VecEnv bug to me; however, I followed the advice in the documentation and comments on how to set up the action masker per environment.
See `stable-baselines3-contrib/sb3_contrib/common/maskable/distributions.py`, line 234 in commit 75b2de1.
My action space is, for example, `MultiDiscrete([5]*72)`, and I am spinning up 128 environments (FYI: 5*72 = 360). When investigating `MaskableMultiCategoricalDistribution`, I found that it creates 72 `MaskableCategorical` distributions, as it should. BUT the shape of the mask is not (360,) or (1, 360); instead it is (128, 360). As a result, the masks get split incorrectly, and as far as I can tell neither the line mentioned above nor the distributions are built for this shape. When tracking invalid actions taken in my environment, I see a large number of them instead of the expected 0.
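Whether a (128, 360) mask is split "weirdly" depends on the axis. If the distribution splits the batch along the action axis (the second dimension) at the sub-action boundaries, each of the 72 sub-distributions receives a (128, 5) block whose rows still line up with the envs. A minimal NumPy sketch of that split, using placeholder random masks rather than the library's own code:

```python
import numpy as np

N_ENVS = 128
NVEC = [5] * 72

rng = np.random.default_rng(0)
masks = rng.random((N_ENVS, sum(NVEC))) < 0.5   # placeholder (128, 360) batch

# Split along the last axis at the sub-action boundaries: 72 blocks of (128, 5)
boundaries = np.cumsum(NVEC)[:-1]
split_masks = np.split(masks, boundaries, axis=1)

print(len(split_masks), split_masks[0].shape)  # 72 (128, 5)

# Block k, row i is env i's mask for sub-action k
k, i = 10, 7
assert np.array_equal(split_masks[k][i], masks[i, 5 * k : 5 * k + 5])
```

Because every sub-action here has 5 choices, `masks.reshape(128, 72, 5)` yields the same blocks. A split along axis 0 instead would mix envs, which is the failure mode worth ruling out when debugging.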
System Info
Describe the characteristics of your environment:
Am I doing something wrong or are there further ways I can debug this?
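One way to check whether the masks are actually honored is to compare the sampled actions against the batch of masks from the same step. The helper below is hypothetical and assumes actions shaped (n_envs, 72) and masks shaped (n_envs, 360), matching the shapes described above. A non-zero count means the mask used when sampling differs from the one being checked (for example, a mask from a different step).

```python
import numpy as np

def count_invalid_actions(actions, masks, nvec):
    """Hypothetical debugging helper.

    actions: (n_envs, len(nvec)) integer array of sampled MultiDiscrete actions
    masks:   (n_envs, sum(nvec)) boolean array, one flat mask per env
    Returns how many sub-actions were chosen although the mask forbids them.
    """
    actions = np.asarray(actions)
    masks = np.asarray(masks, dtype=bool)
    offsets = np.concatenate(([0], np.cumsum(nvec)[:-1]))  # start of each block
    flat_idx = offsets + actions   # position of each chosen value in the flat mask
    chosen_allowed = np.take_along_axis(masks, flat_idx, axis=1)
    return int((~chosen_allowed).sum())

# Usage with 2 envs: only value 0 is allowed for every sub-action
nvec = [5] * 72
masks = np.zeros((2, 360), dtype=bool)
masks[:, ::5] = True
actions = np.zeros((2, 72), dtype=int)
print(count_invalid_actions(actions, masks, nvec))  # 0

actions[1, 3] = 2   # env 1 picks a forbidden value for sub-action 3
print(count_invalid_actions(actions, masks, nvec))  # 1
```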