New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DQN support of custom models #6254
Conversation
Can one of the admins verify this patch? |
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn | ||
from ray.rllib.utils import try_import_tf | ||
|
||
tf = try_import_tf() | ||
|
||
|
||
class FullyConnectedNetwork(TFModelV2): | ||
class FullyConnectedNetwork(DistributionalQModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is a generic model, and should not need to know about DistributionalQModel. The right place to do this is in the catalog.
We already wrap the model for the default models here:
Line 318 in 2a0225d
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface) |
and just need to do it for custom models too:
Line 260 in 2a0225d
if model_interface and not issubclass(model_cls, |
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork | ||
from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork | ||
|
||
class DQNCustomModelTest(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of adding a new test, can you add an case to https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py exercising this instead?
Test FAILed. |
Why are these changes needed?
Prevents circular dependency of custom models when running DQN - #6091.
I also added a test that checks internal custom models for DQN and makes sure it does not accidentally break for other policies.
Related issue number
Closes #6091
Checks
scripts/format.sh
to lint the changes in this PR.