Feature request: make it easier to supply custom model #457

Closed
ben-arnao opened this issue Aug 24, 2020 · 15 comments

@ben-arnao

I tried assigning my own layers to the post_processing variable within my categorical Q-network, but when I then create my categorical DQN agent I get a message that weights are shared. It would be nice if the main categorical Q-network constructor accepted a parameter for supplying a set of Keras layers, with the q_layer simply appended to the end the way it is in the encoding network scheme, and with the weights copied for you.

@summer-yue assigned kuanghuei and ebrevdo and unassigned kuanghuei Aug 26, 2020
@ebrevdo
Contributor

ebrevdo commented Aug 27, 2020

Hi @ben-arnao. We have a new way to declare Networks, including preprocessing layers, that is more flexible and may work better in your case. There is a new Sequential network that accepts both Keras layers and TF-Agents networks, as well as a new NestMap network that acts a bit like preprocessing_layers. You can combine them and feed the resulting Network into regular DQN. We can extend this ability to Categorical DQN pretty easily.

Here's an example of Sequential with regular DQN.

Here's an example of Sequential with NestMap.

NestMap can take the new Keras preprocessing layers as well.
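
For instance, here's a minimal sketch of how the two could be combined (the dict observation with 'position' and 'velocity' keys is hypothetical, and the merge step uses a plain Keras Lambda rather than any particular TF-Agents helper):

import tensorflow as tf
from tf_agents.networks import nest_map, sequential

num_actions = 2  # placeholder

# Per-entry preprocessing (similar in spirit to preprocessing_layers),
# followed by a merge and the usual Q-value head.
q_net = sequential.Sequential([
    nest_map.NestMap({
        'position': tf.keras.layers.Dense(16, activation='relu'),
        'velocity': tf.keras.layers.Dense(16, activation='relu'),
    }),
    # Concatenate the processed nest entries into a single tensor.
    tf.keras.layers.Lambda(
        lambda nest: tf.concat(tf.nest.flatten(nest), axis=-1)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(num_actions),  # one Q-value per action
])

The resulting q_net can then be passed as q_network to DqnAgent.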

Let me know if this is interesting to you; we can get a PR out that makes it possible to use these with Categorical DQN, including a unit test example.

@ebrevdo
Contributor

ebrevdo commented Aug 27, 2020

(In the case of your own Sequential network, you provide the final output layer, which projects to [num_atoms, num_actions]. We can put that in an example.)
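
For example, a rough sketch of that final projection (num_actions and num_atoms are placeholders, and whether CategoricalDqnAgent would accept such a Sequential directly is exactly what the proposed change would cover):

import tensorflow as tf
from tf_agents.networks import sequential

num_actions = 2   # placeholder
num_atoms = 51    # placeholder (C51 uses 51 atoms)

categorical_q_net = sequential.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    # User-supplied final projection: one logit per (atom, action) pair.
    tf.keras.layers.Dense(num_atoms * num_actions),
    tf.keras.layers.Reshape([num_atoms, num_actions]),
])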

@ben-arnao
Author

@ebrevdo Yes, that would be great! I could probably find a way to make it work; I just wanted to avoid writing too much custom code if possible.

@awestlake87

Are there any plans to unify tf.keras.Model and Network? I'm pretty new to TF-Agents, so I don't fully understand the difference between the two. It seems to me like you should be able to construct a DQN from any Keras model, but maybe I'm missing something.

@ebrevdo
Contributor

ebrevdo commented Mar 17, 2021

One fundamental issue with tf.keras.Model is that a Model object gives us no indication of which output corresponds to an RNN state; so if you're doing, e.g., DQN-RNN, we can't tell which part of the input is the RNN state and which part of the output is the RNN state. In contrast, Network explicitly separates state vs. non-state inputs and outputs.

If you want to use a Model with DQN without any RNN parts, you can easily just wrap it: tf_agents.networks.Sequential([my_model]) will do it.

@ebrevdo
Contributor

ebrevdo commented Mar 17, 2021

@sguada we should see who can modify the existing CategoricalDQN unit tests / example to use the new Sequential networks, at which point we can link to that here and consider the issue resolved. Who would you suggest?

@awestlake87

@ebrevdo Thanks for clarifying (and for the quick response)! I guess I didn't know Model inherits from Layer, but that makes total sense now!

@awestlake87

awestlake87 commented Mar 18, 2021

I'm having some trouble passing a model to tf_agents.networks.Sequential([my_model]). I have a dead simple stateless DNN model for the cartpole agent that takes Tensor[None, 4] and outputs Tensor[None, 2] for left/right, but I keep getting the following AssertionError:

AssertionError: Could not compute output KerasTensor(type_spec=TensorSpec(shape=(None, 2), dtype=tf.float32, name=None), name='q_vals/BiasAdd:0', description="created by layer 'q_vals'")
  In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)

Here's the keras Model summary:

Model: "f32_logits_agent"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input (InputLayer)           [(None, 4)]               0         
_________________________________________________________________
q_vals (Dense)               (None, 2)                 10        
=================================================================

I have a feeling this has something to do with the network_state, but I'm not certain. I built this model with the functional keras API, not the sequential API, so maybe that could have something to do with it?

I'll keep digging around in the code to see if I can figure out where this error is coming from, but in the meantime if you guys have any examples or tips that would be much appreciated! (Also I'm not really sure where to go for help on this. If this is the wrong place feel free to let me know)

Update:
Here's how to reproduce it:

from tf_agents.environments import tf_py_environment
from tf_agents.environments import suite_gym
from tf_agents.specs import tensor_spec
import tensorflow as tf
import tf_agents

def create_keras_agent(env):
    input = tf.keras.layers.Input(name="input", shape=[4])

    output = tf.keras.layers.Dense(name="q_vals", units=2)(input)

    model = tf.keras.Model(
        name="f32_logits_qnet", inputs=input, outputs=output
    )

    # calling with constant works fine
    print(model(tf.constant([[0.0, 0.0, 0.0, 0.0]])))

    model.summary()

    # It kinda looks like Network.create_variables is doing something like this
    random_input = tensor_spec.sample_spec_nest(
        env.time_step_spec().observation, outer_dims=(1,)
    )
    print(random_input)
    # But it ends up working no problem
    print(model(random_input))

    qnet = tf_agents.networks.sequential.Sequential([model])
    
    # AssertionError here:
    tf_agents.agents.dqn.dqn_agent.DqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        name="f32_logits_agent",
        q_network=qnet,
        optimizer=tf.keras.optimizers.Adam(),
        td_errors_loss_fn=tf_agents.utils.common.element_wise_squared_loss
    )


env = suite_gym.load("CartPole-v0")
env = tf_py_environment.TFPyEnvironment(env)
create_keras_agent(env)

Update 2
Ok, now I'm getting somewhere. The DqnAgent constructor makes a copy of q_network in order to construct the target_q_network. The model works fine until it's copied, then the copy raises this AssertionError when it's called. My naive guess would be that this is because the copy doesn't preserve the relationships between the layers when the model is constructed via the functional API (I haven't actually tested sequential construction to be sure).

Assuming my guess is correct, I would prefer to keep using the functional API, but I understand if it won't be supported. Any ideas on how I could get this to work or am I stuck with the sequential API unless I write my own custom Network?

@ebrevdo
Contributor

ebrevdo commented Mar 18, 2021

That sucks. Looks like Model doesn't survive a copy.deepcopy operation. @fchollet this is maybe a bug in keras Functional models?

@awestlake87 For now you can work around this by creating the target_q_network yourself. Basically do:

def create_m():
    input = tf.keras.layers.Input(name="input", shape=[4])

    output = tf.keras.layers.Dense(name="q_vals", units=2)(input)

    model = tf.keras.Model(
        name="f32_logits_qnet", inputs=input, outputs=output
    )
    return model

...
q_network = create_m()
target_q_network = create_m()

The key thing is that the q_network and target cannot share the same tf.Variables. Doing this yourself should mean DQN will avoid performing the deepcopy for you. LMK if that works.
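
Roughly, building on the repro above (this sketch assumes create_m returns the Model, wraps each copy in Sequential as discussed earlier, and uses DqnAgent's target_q_network argument to pass in the second copy):

# Two independently constructed models, so no tf.Variables are shared
# and DqnAgent never needs to deepcopy the functional Model.
q_net = tf_agents.networks.sequential.Sequential([create_m()])
target_q_net = tf_agents.networks.sequential.Sequential([create_m()])

agent = tf_agents.agents.dqn.dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    target_q_network=target_q_net,
    optimizer=tf.keras.optimizers.Adam(),
    td_errors_loss_fn=tf_agents.utils.common.element_wise_squared_loss,
)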

@ebrevdo
Contributor

ebrevdo commented Mar 18, 2021

Another issue that will come up: for agent.train(), the network needs to be able to handle an additional time dimension, so the total input shape is e.g. [B, T, D]; and when the policy is used, it needs to handle just the batch dimension, e.g. shape [B, D]. Sadly this means it's hard to declare an Input, since that seems to require a predetermined, known rank.

In order to handle this, you'll need to wrap your model in a layer that knows to automatically flatten the [B, T] dimensions for you. For that I think you'll want to use tf_agents.keras_layers.SquashedOuterWrapper(model, inner_rank=1) as the thing you pass to DQNAgent.
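
Something along these lines (a sketch; model is the functional Model from the repro above, and the import path is assumed from the usage quoted here):

from tf_agents import keras_layers
from tf_agents.networks import sequential

# Wrap the model so it accepts both [B, D] inputs (when the policy runs)
# and [B, T, D] inputs (during agent.train()); inner_rank=1 says the
# model itself expects a single non-batch dimension.
wrapped = keras_layers.SquashedOuterWrapper(model, inner_rank=1)
q_net = sequential.Sequential([wrapped])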

@ebrevdo
Contributor

ebrevdo commented Mar 18, 2021

@sguada looks like we'll need to provide a README somewhere on how to use TF-Agents with tf.keras.Model; specifically, describing the need to use SquashedOuterWrapper and the current limitation of not being able to perform deepcopy (meaning the user has to manually create target networks).

@awestlake87

Thanks for the heads up on the input shape! That'll save me some time. I'll try creating the target_q_network manually as well and see where that gets me.

@awestlake87

That appears to have worked! It looks like it's working exactly the same as it was when I was just adding the dense layers to the tf_agents.networks.sequential.Sequential network.

@awestlake87

awestlake87 commented Mar 18, 2021

Just as an update, when I started trying to create an agent for Breakout, I had to remove the tf_agents.keras_layers.SquashedOuterWrapper(model, inner_rank=1) call because it was flattening the image from (210, 160, 3) to (33600, 3). I'm currently training it, but I haven't run into that [B, T, D] issue yet. That could just be because I might not have used the code that would trigger it yet.

Edit:
My bad, I just wasn't using the inner_rank right. I had to add it back in as soon as I started dealing with RNN state.

@ebrevdo
Contributor

ebrevdo commented Apr 18, 2021

For images it looks like you may need inner_rank=3? Anyway, glad it's working for you! I'll mark this as closed for now but reopen if anything comes up.
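
For reference, a rough sketch of that wrapping (conv_model stands in for a hypothetical Keras model that consumes a single [H, W, C] frame):

from tf_agents import keras_layers
from tf_agents.networks import sequential

# inner_rank=3 tells the wrapper that a single observation spans the last
# three dimensions (height, width, channels), so only the outer batch/time
# dimensions get squashed before calling the model.
wrapped = keras_layers.SquashedOuterWrapper(conv_model, inner_rank=3)
q_net = sequential.Sequential([wrapped])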
