DQN Agent Issue With Custom Environment #258

Closed
IbraheemNofal opened this issue Nov 28, 2019 · 16 comments

IbraheemNofal commented Nov 28, 2019

So I've been following the DQN agent example / tutorial and set it up as in the example; the only difference is that I built my own custom Python environment, which I then wrapped as a TensorFlow environment. However, no matter how I shape my observation and action specs, I can't get it to work when I give it an observation and request an action. Here's the error I get:

tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] is not a matrix. Instead it has shape [10] [Op:MatMul]

Here's how I'm setting up my agent:

    layer_parameters = (10,) # one fully connected hidden layer with 10 units
    
    #placeholders 
    learning_rate = 1e-3  # @param {type:"number"}
    train_step_counter = tf.Variable(0)

    #instantiate agent

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    
    env = SumoEnvironment(self._num_actions,self._num_states)
    env2 = tf_py_environment.TFPyEnvironment(env)
    q_net= q_network.QNetwork(env2.observation_spec(),env2.action_spec(),fc_layer_params = layer_parameters)
    
    print("Time step spec")
    print(env2.time_step_spec())

    agent = dqn_agent.DqnAgent(env2.time_step_spec(),
    env2.action_spec(),
    q_network=q_net,
    optimizer = optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

And here's how I'm setting up my environment:

class SumoEnvironment(py_environment.PyEnvironment):

    def __init__(self, no_of_Actions, no_of_Observations):

        # the observation is a single float32 vector of length 16
        self._observation_spec = specs.TensorSpec(shape=(16,), dtype=np.float32, name='observation')
        # action spec: shape (1,), minimum 0, maximum no_of_Actions - 1
        self._action_spec = specs.BoundedArraySpec(shape=(1,), dtype=np.int32, minimum=0, maximum=no_of_Actions-1, name='action')

        self._state = 0
        self._episode_ended = False

And here is what my input / observations look like:

tf.Tensor([ 0. 0. 0. 0. 0. 0. 0. 0. -1. -1. -1. -1. 0. 0. 0. -1.], shape=(16,), dtype=float32)

I've tried experimenting with the shape and depth of my Q_Net and it seems to me that the [10] in the error is related to the shape of my q network. Setting its layer parameters to (4,) yields an error of:

tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] is not a matrix. Instead it has shape [4] [Op:MatMul]


tagomatech commented Dec 7, 2019

FWIW, isn't it some sort of conversion issue between ArraySpec (tf_agents\specs\array_spec.py) and TensorSpec (tensorflow_core\python\framework\tensor_spec.py)?

Is there a specific reason you are using TensorSpec instead of ArraySpec or BoundedArraySpec to define the observation spec in your custom environment, inconsistently with the examples provided by the tf-agents team?


IbraheemNofal commented Dec 7, 2019

> FWIW, isn't it some sort of conversion issue between ArraySpec (tf_agents\specs\array_spec.py) and TensorSpec (tensorflow_core\python\framework\tensor_spec.py)?
>
> Is there a specific reason you are using TensorSpec instead of ArraySpec or BoundedArraySpec to define the observation spec in your custom environment, inconsistently with the examples provided by the tf-agents team?

It is an ArraySpec in the provided examples, but as far as I understand, the spec doesn't matter much in this case as long as the observations I pass fit that spec, no? I experimented with both TensorSpec and BoundedArraySpec and with different shapes for those specs; all yield the same error. I'm not certain whether it's a conversion issue or not.


ebrevdo commented Dec 10, 2019

I would like to see a full trace of your code, but I believe the problem is that you're performing training (correct me if I'm wrong), which requires getting batches of data from your environment. For example, if the observation spec is TensorSpec(..., shape=(16,), dtype=tf.float32), then at training time you should be passing tensors of shape [batch_size, 2, 16], where the first dimension is the batch size (whatever it is) and the second corresponds to the fact that DQN requires seeing 2 time steps. The inner dimension matches the spec.
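
A minimal sketch of those shapes, assuming a replay_buffer (TFUniformReplayBuffer), batch_size, and agent set up as in the 1_dqn_tutorial:

    # Sketch only -- `replay_buffer`, `batch_size`, and `agent` are assumed to
    # exist, configured as in the tutorial.
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,  # e.g. 64
        num_steps=2)                   # DQN trains on pairs of adjacent time steps
    iterator = iter(dataset)

    experience, _ = next(iterator)
    # experience.observation has shape [batch_size, 2, 16]:
    # batch dimension, 2 time steps, then the (16,) observation spec.
    train_loss = agent.train(experience).loss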


ebrevdo commented Dec 10, 2019

For a better answer, we'd need to look at your copy of the repo so we can understand the diff. Also ensure that you can run the original example before you made any changes, and that it doesn't lead to an error.

@IbraheemNofal (Author)

> I would like to see a full trace of your code, but I believe the problem is that you're performing training (correct me if I'm wrong), which requires getting batches of data from your environment. For example, if the observation spec is TensorSpec(..., shape=(16,), dtype=tf.float32), then at training time you should be passing tensors of shape [batch_size, 2, 16], where the first dimension is the batch size (whatever it is) and the second corresponds to the fact that DQN requires seeing 2 time steps. The inner dimension matches the spec.

The error occurs whenever I request an action from the DQN agent's collect_policy, passing one observation as you stated. As I understood from the 1_dqn_tutorial Colab, I have to pass only one observation, the current one, from the environment to get an action. I haven't yet gotten to training the agent. This isn't a copy of this repo; it's me attempting to adapt the solution to an entirely different environment.

Attached below is the link to the repo; the code is in "Traffic Program ML.py". Please note: I'm handling the "step" function of the environment manually, in terms of getting the next TimeStep and saving it in memory, since I'm attempting to run multiple agents at the same time in one custom environment, and in my case it wouldn't be possible to get the next state immediately after requesting an action from the agent.

I'll be attempting to ensure that I can run the original example and will get back to you with results. Thank you for reaching out and helping me resolve the issue.

Repo Link


ebrevdo commented Dec 11, 2019

Have you tried ensuring that your observations are properly batched? If right now you're using a single py environment, and wrapping it in a TFPyEnvironment, try adding a batch dimension: first wrap the py environment in a batched env: py_env = BatchedPyEnvironment([py_env]).
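
A minimal sketch of that wrapping, assuming the SumoEnvironment constructor from the original post (the constructor arguments here are placeholders):

    from tf_agents.environments import batched_py_environment, tf_py_environment

    # `num_actions` and `num_states` stand in for the original constructor args.
    py_env = SumoEnvironment(num_actions, num_states)
    batched_env = batched_py_environment.BatchedPyEnvironment([py_env])
    tf_env = tf_py_environment.TFPyEnvironment(batched_env)

    time_step = tf_env.reset()
    # time_step.observation now has shape (1, 16): the leading batch dimension
    # is what the Q-network's MatMul expects.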


IbraheemNofal commented Dec 11, 2019

> Have you tried ensuring that your observations are properly batched? If right now you're using a single py environment, and wrapping it in a TFPyEnvironment, try adding a batch dimension: first wrap the py environment in a batched env: py_env = BatchedPyEnvironment([py_env]).

Batching my observations? So for instance, if my batch size is 32, the shape of my observations would become (32, 16), correct? Can't I just pass a single observation to the agent when requesting an action and collecting data? I think I wasn't very clear about what I'm doing in the previous comment. In the repo above, what I'm attempting to do is pass a single observation to each DQN agent, each with their own TFPyEnvironment to keep track of their variables, but all of which interact with a single simulation run. A good analogy would be multiple players (the agents) playing a single co-op game run, each taking in their own set of observations every X milliseconds of gameplay time. I may have to consider altering my architecture if batched observations are a requirement.


sguada commented Dec 20, 2019

For collect or evaluation, in general your PyEnvironment should generate numpy arrays without a batch dimension and define ArraySpecs; the TFPyEnvironment will then do the appropriate conversion, including adding batch_size = 1.
You can pass a single observation, but it needs to have batch_size = 1. This is different from the batch_size used for training.
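
A minimal sketch of that setup, using a toy stand-in for SumoEnvironment (unbatched numpy observations, ArraySpecs; the environment logic here is dummy):

    import numpy as np
    from tf_agents.environments import py_environment, tf_py_environment
    from tf_agents.specs import array_spec
    from tf_agents.trajectories import time_step as ts

    class MinimalSumoEnv(py_environment.PyEnvironment):
        """Toy stand-in: returns unbatched numpy observations matching ArraySpecs."""

        def __init__(self, num_actions=4):
            self._observation_spec = array_spec.ArraySpec(
                shape=(16,), dtype=np.float32, name='observation')
            self._action_spec = array_spec.BoundedArraySpec(
                shape=(), dtype=np.int32, minimum=0, maximum=num_actions - 1, name='action')
            self._episode_ended = False

        def observation_spec(self):
            return self._observation_spec

        def action_spec(self):
            return self._action_spec

        def _reset(self):
            self._episode_ended = False
            return ts.restart(np.zeros(16, dtype=np.float32))

        def _step(self, action):
            if self._episode_ended:
                return self.reset()
            # dummy transition: zero observation, zero reward
            return ts.transition(np.zeros(16, dtype=np.float32), reward=0.0)

    tf_env = tf_py_environment.TFPyEnvironment(MinimalSumoEnv())
    time_step = tf_env.reset()
    print(time_step.observation.shape)  # (1, 16): TFPyEnvironment adds batch_size = 1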


ebrevdo commented Jan 22, 2020

Does that resolve your issue?

@IbraheemNofal (Author)

Oh, I apologize. Unfortunately, after messing around with it for a bit, I still couldn't get it to work. I ended up migrating to Tensorforce.

@onurcanbektas

Have you been able to solve the issue, and if so, how?


IbraheemNofal commented Jul 1, 2020

I unfortunately wasn't able to. I ended up using a different reinforcement learning library with a different DQN implementation. Though looking back at it now, I probably wasn't shaping my observations correctly at the time. It'd help if you posted your configuration, showing what your observations look like and the specific error message you're getting, as that might shed more light on the issue you're having.


onurcanbektas commented Jul 1, 2020

In my case, this is the shape of the observation_spec

BoundedTensorSpec((10,10), np.int32, minimum=0, maximum=2)

and this is my custom agent class

class singleAgent(DqnAgent):

    def __init__(self,
                 env: TFPyEnvironment,
                 init_pos,
                 reward_fn: Callable = lambda time_step: time_step.reward,
                 action_fn: Callable = lambda action: action,
                 name: str = 'IMAgent',
                 #q_network=None,
                 **dqn_kwargs):
        self._env = env
        self._observation_spec = self._env.observation_spec()
        self._action_spec = self._env.action_spec()
        self._qNetwork = self._buildQNetwork()
        self._currentPos = init_pos
        self._init_pos = init_pos
        self._name = name

        self._action_fn = action_fn
        self._reward_fn = reward_fn

        baseEnv_ts_spec = self._env.time_step_spec()
        time_step_spec = TimeStep(
            step_type=baseEnv_ts_spec.step_type,
            reward=baseEnv_ts_spec.reward,
            discount=baseEnv_ts_spec.discount,
            observation=self._qNetwork.input_tensor_spec)

        optimizer = tf.keras.optimizers.Adam(learning_rate=self._env._learning_rate)
        super(singleAgent, self).__init__(time_step_spec,
                                          self._action_spec,
                                          self._qNetwork,
                                          optimizer,
                                          name=self._name,
                                          **dqn_kwargs)

        self.initialize()
        self._policy_state = self.policy.get_initial_state(
            batch_size=self._env._training_batch_size)
        self._rewards = []

        self._replay_buffer = TFUniformReplayBuffer(
            data_spec=self.collect_data_spec,
            batch_size=self._env._training_batch_size,
            max_length=self._env._replay_buffer_max_length)

        self.train = common.function(self.train)

    def _buildQNetwork(self):
        # the shape of the NN
        fc_layer_params = (100,)

        # => QNetwork(
        #    input_tensor_spec, action_spec, preprocessing_layers=None,
        #    preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=(75, 40),
        #    dropout_layer_params=None, activation_fn=tf.keras.activations.relu,
        #    kernel_initializer=None, batch_squash=True, dtype=tf.float32, name='QNetwork')
        q_net = QNetwork(
            self._observation_spec,
            self._action_spec,
            fc_layer_params=fc_layer_params)
        # TODO: add dtype = tf.int32?
        q_net.create_variables()
        #q_net.summary()
        return q_net
...

and in the end, I am calling them like this

multi = baseEnvironement()
multiEnv = TFPyEnvironment(multi)
multi._reset()
p1 = singleAgent(multi, (2,2))
p1._reset()
p1.act(collect=False)

where the act method is implemented like

def act(self, collect=False) -> Trajectory:
    time_step = self._env.current_time_step()
    print("---")
    print(self._policy_state)    
    if collect:
      policy_step = self.collect_policy.action(time_step, policy_state=self._policy_state)
    else:
      policy_step = self.policy.action(time_step, policy_state=self._policy_state)

    self._policy_state = policy_step.state
    next_time_step = self._step_environment(policy_step.action)
    traj = trajectory.from_transition(time_step, policy_step, next_time_step)

    self._rewards.append(next_time_step.reward)
    if collect:
      self._replay_buffer.add_batch(traj)

    return traj

The exact error message is the following

-> p1.act(collect=False)
...
else
-> policy_step = self.policy.action(time_step, policy_state=self._policy_state)
...
"In[0] is not a matrix. Instead it has shape [100] [Op:MatMul]"

So basically it is failing in the policy.action(time_step, policy_state=self._policy_state) call.
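
One quick check worth trying (a sketch only, using the attributes defined above): print the observation shape inside act and, if the leading batch dimension is missing, add it before calling the policy, since the policy expects batch_size = 1 as noted earlier in the thread.

    time_step = self._env.current_time_step()
    print(time_step.observation.shape)  # expected: (1, 10, 10) from a TFPyEnvironment

    # If there is no leading batch dimension (e.g. shape (10, 10)), add one to
    # every field of the TimeStep before handing it to the policy:
    if len(time_step.observation.shape) == len(self._observation_spec.shape):
        time_step = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0), time_step)

    policy_step = self.policy.action(time_step, policy_state=self._policy_state)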


IbraheemNofal commented Jul 1, 2020

I haven't dealt a lot with TF-Agents or TF in general, as I use a different library, but I think the issue lies in your observation spec. Your observations are a rank-2 tensor (two dimensions), while the network may expect a three-dimensional shape, probably because a 2D convolutional layer takes input tensors of rank 3 (rank 4 if you count the batch dimension). If your observations are indeed of shape (10, 10), try specifying your observation_spec as (1, 10, 10) so that it is three-dimensional, and of course make sure your input actually has that exact shape.

@onurcanbektas

@IbraheemNofal thank you very much for the response; there is no convolutional layer, and the input shape is indeed (10,10). I have just tried (1,10,10), but to no avail; the error is the same.

@maggawron

@onurcanbektas Try feeding (10, 10, 1) as the input shape. Images are represented as (height, width, num_channels); for example, a 512x512 RGB image will be (512, 512, 3), as RGB has 3 channels.
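
A sketch of that channel-last layout (the names here are illustrative, not taken from the repo):

    import numpy as np
    from tf_agents.specs import array_spec

    # 10x10 board with values in {0, 1, 2}, stored with an explicit channel axis.
    observation_spec = array_spec.BoundedArraySpec(
        shape=(10, 10, 1), dtype=np.int32, minimum=0, maximum=2, name='observation')

    board = np.random.randint(0, 3, size=(10, 10), dtype=np.int32)
    observation = board[..., np.newaxis]  # (10, 10) -> (10, 10, 1)
    assert observation.shape == observation_spec.shape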
