DQN does not allow custom models #6091

Closed
sytelus opened this issue Nov 5, 2019 · 6 comments · Fixed by #6258


sytelus commented Nov 5, 2019

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
  • Ray installed from (source or binary): source
  • Ray version: 0.8.0.dev6
  • Python version: 3.7.5
  • Exact command to reproduce:

The following code registers the built-in TF VisionNetwork as a custom model, and it errors out as described below. However, the code succeeds if no custom model is set, in which case the exact same VisionNetwork is selected automatically by _get_v2_model. The cause of this issue is explained below, but I'm not sure about the fix.

import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.visionnet_v2 import VisionNetwork

ModelCatalog.register_custom_model("my_model", VisionNetwork)

config = {'model': {
            "custom_model": "my_model",
            "custom_options": {},  # extra options to pass to your model
        }}
ray.init()

agent = DQNTrainer(config=config, env="BreakoutNoFrameskip-v4")

Describe the problem

Current code in master does not allow the use of custom models in DQN. When trying to use a custom model (either for TF or PyTorch), an error is thrown indicating that the model does not subclass DistributionalQModel. This happens even when the custom model is simply ray.rllib.models.tf.visionnet_v2.VisionNetwork.

Error message:

('The given model must subclass', <class 'ray.rllib.agents.dqn.distributional_q_model.DistributionalQModel'>)

Source code / logs

The cause of this issue is this check. Notice that the check is only applied when custom_model is set. Apparently the built-in models don't subclass DistributionalQModel either, but since the check is not applied to built-in models, they work fine.
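To illustrate the asymmetry, here is a minimal, self-contained sketch. The classes below are hypothetical stand-ins, not RLlib's actual code: the default-model path dynamically mixes the required interface into the chosen model class, while the custom-model path only runs the subclass check.

```python
class ModelV2:
    """Stand-in for RLlib's base model class."""

class DistributionalQModel(ModelV2):
    """Stand-in for the interface DQN requires."""

class VisionNetwork(ModelV2):
    """Stand-in for a plain model that does NOT subclass DistributionalQModel."""

def wrap_default_model(model_cls, interface):
    """Default-model path: dynamically mix the required interface in."""
    class Wrapped(model_cls, interface):
        pass
    return Wrapped

# Default-model path: VisionNetwork is auto-wrapped, so the subclass
# requirement is satisfied.
assert issubclass(wrap_default_model(VisionNetwork, DistributionalQModel),
                  DistributionalQModel)

# Custom-model path: the same class is checked directly and fails, which
# is what produces the "The given model must subclass ..." error above.
assert not issubclass(VisionNetwork, DistributionalQModel)
```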


ericl commented Nov 5, 2019

There's a bit of code below that check that auto-wraps the default model in the interface. I'm open to auto-wrapping custom models as well if you want to make a patch.

Why not instead subclass the right model class though? It makes the behaviour a bit more clear I think: https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py#L59

arunavo4 commented Nov 9, 2019

@sytelus Hey, I ran into this exact issue a few days back; all I did was subclass the right model class and everything works as expected.
Copy-paste the model code below.

# Imports below follow the Ray 0.8-era module layout; adjust paths for your version.
from ray.rllib.agents.dqn.distributional_q_model import DistributionalQModel
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
from ray.rllib.models.tf.visionnet_v1 import _get_filter_config
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

# ============== VisionNetwork Model ==================

class VisionNetwork(DistributionalQModel):
    """Generic vision network implemented in DistributionalQModel API."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kw):
        super(VisionNetwork, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)

        activation = get_activation_fn(model_config.get("conv_activation"))
        filters = model_config.get("conv_filters")
        if not filters:
            filters = _get_filter_config(obs_space.shape)
        no_final_linear = model_config.get("no_final_linear")
        vf_share_layers = model_config.get("vf_share_layers")

        inputs = tf.keras.layers.Input(
            shape=obs_space.shape, name="observations")
        last_layer = inputs

        # Build the action layers
        for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding="same",
                name="conv{}".format(i))(last_layer)
        out_size, kernel, stride = filters[-1]
        if no_final_linear:
            # the last layer is adjusted to be of size num_outputs
            last_layer = tf.keras.layers.Conv2D(
                num_outputs,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding="valid",
                name="conv_out")(last_layer)
            conv_out = last_layer
        else:
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding="valid",
                name="conv{}".format(i + 1))(last_layer)
            conv_out = tf.keras.layers.Conv2D(
                num_outputs, [1, 1],
                activation=None,
                padding="same",
                name="conv_out")(last_layer)

        # Build the value layers
        if vf_share_layers:
            last_layer = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)
            value_out = tf.keras.layers.Dense(
                1,
                name="value_out",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(last_layer)
        else:
            # build a parallel set of hidden layers for the value net
            last_layer = inputs
            for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1):
                last_layer = tf.keras.layers.Conv2D(
                    out_size,
                    kernel,
                    strides=(stride, stride),
                    activation=activation,
                    padding="same",
                    name="conv_value_{}".format(i))(last_layer)
            out_size, kernel, stride = filters[-1]
            last_layer = tf.keras.layers.Conv2D(
                out_size,
                kernel,
                strides=(stride, stride),
                activation=activation,
                padding="valid",
                name="conv_value_{}".format(i + 1))(last_layer)
            last_layer = tf.keras.layers.Conv2D(
                1, [1, 1],
                activation=None,
                padding="same",
                name="conv_value_out")(last_layer)
            value_out = tf.keras.layers.Lambda(
                lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer)

        self.base_model = tf.keras.Model(inputs, [conv_out, value_out])
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        # explicit cast to float32 needed in eager
        model_out, self._value_out = self.base_model(
            tf.cast(input_dict["obs"], tf.float32))
        return tf.squeeze(model_out, axis=[1, 2]), state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])


# ================== Register Custom Model ======================
ModelCatalog.register_custom_model("NatureCNN", VisionNetwork)

arunavo4 commented Nov 9, 2019

But I think this should be implemented in the library by default.


ericl commented Nov 9, 2019

I see, that makes sense as I guess it's the expected behaviour.

Cc @AmeerHajAli I think this would be a good issue to get started on if you're interested.


sytelus commented Nov 9, 2019

Yeah, I think rllib shouldn't make a distinction between built-in models and custom models. If it wraps internal models, then it should probably do so for custom models as well.
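The symmetric behavior being asked for could look something like this. This is a hypothetical sketch with stand-in classes, not RLlib's actual implementation or the eventual patch: wrap any model class that doesn't already satisfy the required interface, whether custom or built-in.

```python
class ModelV2:
    """Stand-in for RLlib's base model class."""

class DistributionalQModel(ModelV2):
    """Stand-in for the interface DQN requires."""

def ensure_interface(model_cls, interface):
    """Return model_cls unchanged if it already subclasses `interface`;
    otherwise return a dynamically created subclass mixing the interface in."""
    if issubclass(model_cls, interface):
        return model_cls
    return type(model_cls.__name__, (model_cls, interface), {})

class CustomVisionNetwork(ModelV2):
    """A custom model registered without subclassing DistributionalQModel."""

# A non-compliant custom model gets wrapped instead of rejected.
wrapped = ensure_interface(CustomVisionNetwork, DistributionalQModel)
assert issubclass(wrapped, DistributionalQModel)
assert wrapped.__name__ == "CustomVisionNetwork"

# A model that already subclasses the interface is returned unchanged.
class AlreadyCompliant(DistributionalQModel):
    pass

assert ensure_interface(AlreadyCompliant, DistributionalQModel) is AlreadyCompliant
```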

@AmeerHajAli

That sounds good!
