How to use replay_buffer.as_dataset() for minibatches #14

Open
seungjaeryanlee opened this issue Jul 29, 2019 · 1 comment

Comments

@seungjaeryanlee
Owner

I tried to use replay_buffer.as_dataset() the same way as the TD3 example:

dataset = replay_buffer.as_dataset(
    sample_batch_size=30,
    num_steps=64+1,
    num_parallel_calls=1
).prefetch(3)
iterator = iter(dataset)

def train_step():
  experience, _ = next(iterator)
  loss_info = tf_agent.train(experience)
  # TODO(seungjaeryanlee): Can't use for loop
  # AttributeError: Tensor.op is meaningless when eager execution is enabled.
  # for experience, _ in dataset:
  #   loss_info = tf_agent.train(experience)
  return loss_info

For LunarLander-v2, I thought that it would work if sample_batch_size = 30 and num_steps = 128+1, but it gives the following error:

python tf_agents/agents/ppo/examples/v2/train_eval_gym.py   --root_dir=$HOME/tmp/rndppo/gym/LunarLander-v2/   --logtostderr --use_rnd
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
2019-07-30 00:54:19.740202: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1696250000 Hz
2019-07-30 00:54:19.740708: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5637b089c420 executing computations on platform Host. Devices:
2019-07-30 00:54:19.740771: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
I0730 00:54:19.771466 140352013358912 parallel_py_environment.py:81] Spawning all processes.
I0730 00:54:20.329198 140352013358912 parallel_py_environment.py:88] All processes started.
W0730 00:54:20.984402 140352013358912 module_wrapper.py:136] From /home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/util/module_wrapper.py:163: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

W0730 00:54:22.374413 140352013358912 deprecation.py:323] From /home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py:317: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0730 00:54:31.984498 140352013358912 deprecation.py:323] From /home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/training/optimizer.py:172: BaseResourceVariable.constraint (from tensorflow.python.ops.resource_variable_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Apply a constraint manually following the optimizer update step.
2019-07-30 00:55:14.501165: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Invalid argument: {{function_node __inference_Dataset_map_get_next_347}} assertion failed: [TFUniformReplayBuffer is empty. Make sure to add items before sampling the buffer.] [Condition x > y did not hold element-wise:x (TFUniformReplayBuffer/get_next/Select_1:0) = ] [0] [y (TFUniformReplayBuffer/get_next/Select:0) = ] [0]
         [[{{node TFUniformReplayBuffer/get_next/assert_greater/Assert/AssertGuard/else/_1/Assert}}]]
         [[IteratorGetNext]]
Traceback (most recent call last):
  File "tf_agents/agents/ppo/examples/v2/train_eval_gym.py", line 346, in <module>
    app.run(main)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "tf_agents/agents/ppo/examples/v2/train_eval_gym.py", line 341, in main
    num_eval_episodes=FLAGS.num_eval_episodes)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/gin/config.py", line 1032, in wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
    six.raise_from(proxy.with_traceback(exception.__traceback__), None)
  File "<string>", line 3, in raise_from
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/gin/config.py", line 1009, in wrapper
    return fn(*new_args, **new_kwargs)
  File "tf_agents/agents/ppo/examples/v2/train_eval_gym.py", line 292, in train_eval
    total_loss, _ = train_step()
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 451, in __call__
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 665, in _filtered_call
    self.captured_inputs)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 778, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 471, in call
    ctx=ctx)
  File "/home/rlee/anaconda3/envs/gsoc/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError:   assertion failed: [TFUniformReplayBuffer is empty. Make sure to add items before sampling the buffer.] [Condition x > y did not hold element-wise:x (TFUniformReplayBuffer/get_next/Select_1:0) = ] [0] [y (TFUniformReplayBuffer/get_next/Select:0) = ] [0]
         [[{{node TFUniformReplayBuffer/get_next/assert_greater/Assert/AssertGuard/else/_1/Assert}}]]
         [[IteratorGetNext]] [Op:__inference_train_step_69111]

Function call stack:
train_step -> train_step

  In call to configurable 'train_eval' (<function train_eval at 0x7fa63fdd11e0>)

The error does not seem to appear when num_steps=64+1 or smaller.
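
One possible explanation, which is an assumption on my part and not verified against the TF-Agents source: the buffer can only return a window of num_steps consecutive frames if each row already holds at least that many frames. If it holds fewer, the set of valid start positions is empty, and the same "buffer is empty" assertion fires even though items were added. The sketch below illustrates that arithmetic in plain Python. `valid_start_range` and `can_sample` are hypothetical helpers, not TF-Agents functions, and the frame count of 100 is made up for illustration.

```python
from typing import Tuple


def valid_start_range(num_frames_stored: int, num_steps: int) -> Tuple[int, int]:
    """Half-open range [lo, hi) of start indices for a window of num_steps frames.

    Hypothetical helper: assumes a buffer row that has not wrapped around yet.
    """
    lo = 0
    # A window starting at index s covers frames s .. s + num_steps - 1,
    # so the last valid start is num_frames_stored - num_steps.
    hi = max(num_frames_stored - num_steps + 1, 0)
    return lo, hi


def can_sample(num_frames_stored: int, num_steps: int) -> bool:
    """True if at least one window of num_steps frames fits in the stored data."""
    lo, hi = valid_start_range(num_frames_stored, num_steps)
    return hi > lo


# Illustrative numbers only: suppose each buffer row holds 100 frames.
frames_per_row = 100
print(can_sample(frames_per_row, 64 + 1))   # window of 65 frames fits
print(can_sample(frames_per_row, 128 + 1))  # window of 129 frames does not fit
print(valid_start_range(frames_per_row, 128 + 1))
```

If this is what is happening, the empty range for num_steps=128+1 would match the `x = [0]`, `y = [0]` values in the assertion message, and collecting more frames per environment before sampling (or lowering num_steps) should avoid it.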

On a similar note, am I understanding correctly that the TD3 example only calls next(iterator) once, so it only uses one minibatch?

Thank you!

@seungjaeryanlee
Owner Author

For now, I'm copying the code here and reverting to the original code so I can try smaller environments.

    ## AFTER
    # Dataset generates trajectories with shape [Bx2x...]
    # dataset = replay_buffer.as_dataset(
    #     num_parallel_calls=3,
    #     sample_batch_size=batch_size,
    #     num_steps=2).prefetch(3)
    # tf.print(dataset)
    dataset = replay_buffer.as_dataset(
                 sample_batch_size=30,
                 num_steps=64+1,
                 num_parallel_calls=1
    ).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      loss_info = tf_agent.train(experience)
      # TODO(seungjaeryanlee): Can't use for loop?
      # AttributeError: Tensor.op is meaningless when eager execution is enabled.
      # for experience, _ in dataset:
      #   loss_info = tf_agent.train(experience)
      return loss_info

    ## BEFORE
    # def train_step():
    #   experience = replay_buffer.gather_all()
    #   loss_info = tf_agent.train(experience)
    #   # TODO(seungjaeryanlee): Can't use for loop?
    #   # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    #   # for experience, _ in dataset:
    #   #   loss_info = tf_agent.train(experience)
    #   return loss_info
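
A minimal sketch of drawing several minibatches per training call, assuming the iterator yields `(experience, info)` pairs as in the snippet above. `fake_dataset` and `fake_train` are hypothetical stand-ins for `replay_buffer.as_dataset(...)` and `tf_agent.train`, so the example runs without TF-Agents. Whether the same loop traces cleanly inside a `tf.function` is not something I have checked.

```python
import itertools


def fake_dataset(batch_size):
    """Hypothetical stand-in for the dataset: yields (experience, info) pairs forever."""
    for i in itertools.count():
        experience = [i * batch_size + j for j in range(batch_size)]
        yield experience, None


def fake_train(experience):
    """Hypothetical stand-in for tf_agent.train: returns a dummy loss (the mean)."""
    return sum(experience) / len(experience)


def train_step(iterator, num_minibatches):
    """Pull num_minibatches batches from the iterator and train on each one."""
    losses = []
    for _ in range(num_minibatches):
        # Each next() call consumes exactly one minibatch.
        experience, _info = next(iterator)
        losses.append(fake_train(experience))
    return losses


iterator = iter(fake_dataset(batch_size=4))
losses = train_step(iterator, num_minibatches=3)
print(losses)
```

The point is only that a single `next(iterator)` consumes one minibatch, so training on more data per step means calling it repeatedly on the same iterator.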
