Hindsight Experience Replay as a replay buffer #753

Merged (23 commits into thu-ml:master) on Oct 30, 2022

Conversation

@Juno-T (Contributor) commented Oct 2, 2022

  • I have marked all applicable categories:
    • exception-raising fix
    • algorithm implementation fix
    • documentation modification
    • new feature
  • I have reformatted the code using make format (required)
    • I have checked the code using make commit-checks and make check-codestyle (required)
  • If applicable, I have mentioned the relevant/related issue(s)
  • If applicable, I have listed every item in this Pull Request below

Related thread: pr#510

Hi, I've implemented Hindsight Experience Replay (HER) as a replay buffer and would like to contribute to tianshou.
I saw the discussion in the mentioned pr#510, but it seems to have been inactive for quite a while, so I decided to make a new, independent pull request.

implementation

I implemented HER solely as a replay buffer. It works by temporarily re-writing the transition storage (self._meta) in place during the sample_indices() call. The original transitions are cached and restored at the beginning of the next sampling, or when other methods are called. This makes sure that, for example, n-step return calculation can be done without altering the policy.
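
To make this concrete, here is a minimal sketch of the cache-and-restore idea (illustrative only, not the actual HERReplayBuffer code; it assumes tianshou's ReplayBuffer and Batch indexing semantics, and the helper method names below are made up):

import numpy as np
from tianshou.data import ReplayBuffer

class CachingBufferSketch(ReplayBuffer):
    """Sketch: temporarily re-write self._meta while sampling, then restore it."""

    def sample_indices(self, batch_size: int) -> np.ndarray:
        self._restore_cache()  # undo any re-write left over from the previous sampling
        indices = super().sample_indices(batch_size)
        self._rewrite_transitions(indices)  # alter self._meta in place
        return indices

    def _rewrite_transitions(self, indices: np.ndarray) -> None:
        # Cache the untouched rows so they can be put back later.
        self._altered_indices = indices.copy()
        self._original_meta = self._meta[indices]
        # ... re-write desired_goal / rew of those rows here ...

    def _restore_cache(self) -> None:
        if getattr(self, "_altered_indices", None) is None:
            return
        self._meta[self._altered_indices] = self._original_meta
        self._altered_indices = None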

There is also a problem with the original index sampling: the sampled indices are not guaranteed to come from different episodes. So I decided to perform the re-writing per episode. This guarantees that sampled transitions from the same episode share the same re-written goal. It also makes the re-writing ratio calculation differ slightly from the paper, but the difference is small when there are many episodes in the buffer.

In the current commit, the HER replay buffer only supports the 'future' strategy with online sampling. This is the best HER variant in terms of performance and memory efficiency.
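
As a rough illustration, the 'future' strategy applied to a single episode looks roughly like this (hypothetical helper, not code from this PR):

import numpy as np

def her_future_rewrite(achieved, compute_reward_fn):
    """Re-write one episode's goals with the 'future' strategy (illustrative).

    achieved: array of shape (ep_len, goal_dim) of achieved goals for one episode.
    Returns the re-written desired goals and the recomputed rewards.
    """
    ep_len = len(achieved)
    # For each step t, pick a random step from the remainder of the episode
    # (t <= future_t < ep_len) and use its achieved goal as the new target.
    future_t = np.array([np.random.randint(t, ep_len) for t in range(ep_len)])
    new_desired = achieved[future_t]
    # Recompute the rewards (batched) against the substituted goals.
    new_rewards = compute_reward_fn(achieved, new_desired)
    return new_desired, new_rewards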

I also added a few convenience replay buffers (HERVectorReplayBuffer, HERReplayBufferManager), a test env (MyGoalEnv), a gym wrapper (TruncatedAsTerminated), unit tests, and a simple example (examples/offline/fetch_her_ddpg.py).
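
For reference, a rough usage sketch (the argument names follow the docstring added in this PR; the buffer size, horizon, and the sparse-reward threshold below are placeholders, and the exact constructor may differ):

import numpy as np
from tianshou.data import HERReplayBuffer

def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # Batched sparse reward: inputs of shape (batch_size, goal_dim), output (batch_size,).
    # The 0.05 threshold mimics the Fetch convention and is only a placeholder.
    dist = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
    return -(dist > 0.05).astype(np.float32)

buf = HERReplayBuffer(
    size=100_000,
    compute_reward_fn=compute_reward_fn,
    horizon=50,  # maximum number of steps in an episode
    future_k=8,  # the 'k' re-writing ratio from the paper
)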

verification

I have added unit tests for almost everything I have implemented.
The HER replay buffer was also tested using DDPG on the FetchReach-v3 env. I used the default DDPG parameters from the mujoco example and didn't tune anything further to get this good result! (train script: examples/offline/fetch_her_ddpg.py)

Screen Shot 2022-10-02 at 19 22 53 (https://user-images.githubusercontent.com/42699114/193454066-0dd0c65c-fd5f-4587-8912-b441d39de88a.png)

Todo

  • Compile doc with make doc

@codecov-commenter commented Oct 2, 2022

Codecov Report

Merging #753 (aaf0f73) into master (41ae346) will decrease coverage by 0.52%.
The diff coverage is 76.97%.

@@            Coverage Diff             @@
##           master     #753      +/-   ##
==========================================
- Coverage   91.77%   91.24%   -0.53%     
==========================================
  Files          70       71       +1     
  Lines        4934     5082     +148     
==========================================
+ Hits         4528     4637     +109     
- Misses        406      445      +39     
Flag Coverage Δ
unittests 91.24% <76.97%> (-0.53%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
tianshou/utils/net/continuous.py 97.91% <ø> (ø)
tianshou/utils/net/common.py 82.95% <9.09%> (-10.56%) ⬇️
tianshou/data/buffer/manager.py 82.05% <70.00%> (-2.01%) ⬇️
tianshou/utils/logger/base.py 91.83% <83.33%> (-3.52%) ⬇️
tianshou/env/gym_wrappers.py 97.72% <90.90%> (-2.28%) ⬇️
tianshou/data/buffer/her.py 91.76% <91.76%> (ø)
tianshou/data/__init__.py 100.00% <100.00%> (ø)
tianshou/data/buffer/vecbuf.py 100.00% <100.00%> (ø)
tianshou/env/__init__.py 75.00% <100.00%> (ø)
tianshou/env/venv_wrappers.py 80.32% <0.00%> (-3.28%) ⬇️
... and 2 more


@Trinkle23897 (Collaborator)

The ubuntu GPU test is run on gym==0.25.2. I'll give some feedback later today.

@Juno-T (Contributor, Author) commented Oct 5, 2022

I see. I'll fix it after the feedback then.

@@ -39,6 +39,7 @@
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
Collaborator:

add in docs/index.rst

@Juno-T (Contributor, Author), Oct 5, 2022:

updated: 9d52936

Comment on lines 15 to 17
`acheived_goal`, `desired_goal`, `info` and returns the reward(s).
Note that the goal arguments can have extra batch_size dimension and in that \
case, the rewards of size batch_size should be returned
Collaborator:

Can you mention the size/shape here? Is it for single or batch input?

@Juno-T (Contributor, Author), Oct 5, 2022:

updated: 63e65c0

Contributor (Author):

Here I also removed the third argument, info. Previously this argument was manually set to an empty dict {} during sampling because its shape was complicated. Originally I had it there because the Fetch env's compute_reward function also takes an info argument, although it was never used.

"""
if indices.size == 0:
return
# Construct episode trajectories
Collaborator:

I do think you should sort indices here. An example:

indices = [9, 7, 8, 6, 4, 5, 3, 1, 2, 0]
done = [0, 0, 1, 0, 0, 1, 0, 0, 1, 1]

After some processing with horizon=8, unique_ep_open_indices = [7, 4, 1, 0] and unique_ep_close_indices = [3, 0, -1, 9], which makes no sense.

Contributor (Author):

Previously I sorted the indices in the sample_indices method (since it is the only one that calls rewrite_transitions), but yeah, I agree that the sorting should be inside rewrite_transitions.

fixed: 63e65c0

@Trinkle23897 (Collaborator), Oct 5, 2022:

My bad. I didn't see L92 previously, but it also has an issue: if the replay buffer's start index is not 0 (i.e., it overwrites previous data points in a circular buffer), the sort approach will cause some bugs at some point. It should start with x, where x is the smallest index that is greater than or equal to the start index.

ref:

        elif batch_size == 0:  # construct current available indices
            return np.concatenate(
                [np.arange(self._index, self._size),
                 np.arange(self._index)]
            )

Contributor (Author):

You are right. This case somehow slipped my mind.
In rewrite_transitions, sorting the indices serves to group indices from the same episode together and order them chronologically. So to handle the circular property, I'm thinking of sorting the indices and then shifting them so that the oldest index is at the front.

Will fix this later, maybe today or tomorrow.

Collaborator:

Just a few lines, I think:

mask = indice >= self._index
indice[~mask] += self._length
indice.sort()
indice[indice >= self._length] -= self._length
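
A quick standalone check of that idea with toy numbers (buffer attributes replaced by plain variables):

import numpy as np

maxsize, start = 6, 2                    # buffer size; the oldest data begins at index 2
indices = np.array([5, 0, 3, 1, 4, 2])   # sampled in arbitrary order

indices[indices < start] += maxsize      # push wrapped-around indices past the end
indices.sort()                           # now in chronological order
indices[indices >= maxsize] -= maxsize   # map back to real buffer positions
print(indices)                           # [2 3 4 5 0 1] -> oldest transition first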

Contributor (Author):

Updated & added some tests: 6ee7d08

        # episode indices that will be altered
        her_ep_indices = np.random.choice(
            len(unique_ep_open_indices),
            size=int(len(unique_ep_open_indices) * self.future_p),
@Trinkle23897 (Collaborator), Oct 5, 2022:

Could you please briefly explain why we need self.future_p here, instead of all indices?

If I understand correctly, these few lines do a random HER update on part of the trajectories instead of all of them.

Contributor (Author):

Yes, you are correct. HER only performs goal re-writing on some random episodes (the amount depends on the ratio parameter future_k, or here, the probability future_p).
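
(For context, HER implementations commonly derive this probability from the ratio along the following lines; the exact formula used in this PR may differ slightly.)

future_k = 8                       # 'k' re-written goals per original goal
future_p = 1 - 1 / (1 + future_k)  # ~0.89: fraction of sampled episodes that get re-written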

            return
        self._meta[self._altered_indices] = self._original_meta
        # Clean
        del self._original_meta, self._altered_indices
Collaborator:

No need to manually delete them, because they have been overwritten in the next line?

Contributor (Author):

fixed: 63e65c0

@Juno-T (Contributor, Author) commented Oct 5, 2022

Btw, I tried verifying the documentation page using make doc but I kept getting a version mismatch error. The doc test seems to be OK, so I think it shouldn't be a problem.

The error, just in case:

WARNING: while setting up extension sphinx.addnodes: node class 'meta' is already registered, its visitors will be overridden
/home/tianshou/venv/lib/python3.8/site-packages/sphinx/jinja2glue.py:106: DeprecationWarning: 'contextfunction' is renamed to 'pass_context', the old name will be removed in Jinja 3.1.
  def warning(context: Dict, message: str, *args: Any, **kwargs: Any) -> str:

Exception occurred:
  File "/home/tianshou/venv/lib/python3.8/site-packages/pkg_resources/__init__.py", line 791, in resolve
    raise VersionConflict(dist, req).with_context(dependent_req)
pkg_resources.ContextualVersionConflict: (docutils 0.19 (/home/tianshou/venv/lib/python3.8/site-packages), Requirement.parse('docutils<0.17,>=0.12'), {'Sphinx'})
The full traceback has been saved in /tmp/sphinx-err-_u9sq96y.log, if you want to report the issue to the developers.
Please also report this if it was a user error, so that a better error message can be provided next time.
A bug report can be filed in the tracker at <https://github.com/sphinx-doc/sphinx/issues>. Thanks!
make[1]: *** [Makefile:20: html] Error 2
make[1]: Leaving directory '/home/tianshou/docs'
make: *** [Makefile:47: doc] Error 2

Comment on lines 11 to 13
`HERReplayBuffer` is to be used with goal-based environment where the \
observation is a dictionary with keys `observation`, `achieved_goal` and \
`desired_goal`. Currently support only HER's future strategy, online sampling.
Collaborator:

Suggested change
`HERReplayBuffer` is to be used with goal-based environment where the \
observation is a dictionary with keys `observation`, `achieved_goal` and \
`desired_goal`. Currently support only HER's future strategy, online sampling.
HERReplayBuffer is to be used with goal-based environment where the
observation is a dictionary with keys ``observation``, ``achieved_goal`` and
``desired_goal``. Currently support only HER's future strategy, online sampling.

Contributor (Author):

updated: 7802451

Comment on lines 16 to 24
:param compute_reward_fn: a function that takes 2 `np.array` arguments, \
`acheived_goal` and `desired_goal`, and returns rewards as `np.array`.
The two arguments are of shape (batch_size, *original_shape) and the returned \
rewards must be of shape (batch_size,).
:param int horizon: the maximum number of steps in an episode.
:param int future_k: the 'k' parameter introduced in the paper. In short, there \
will be at most k episodes that are re-written for every 1 unaltered episode \
during the sampling.

Collaborator:

Suggested change
:param compute_reward_fn: a function that takes 2 `np.array` arguments, \
`acheived_goal` and `desired_goal`, and returns rewards as `np.array`.
The two arguments are of shape (batch_size, *original_shape) and the returned \
rewards must be of shape (batch_size,).
:param int horizon: the maximum number of steps in an episode.
:param int future_k: the 'k' parameter introduced in the paper. In short, there \
will be at most k episodes that are re-written for every 1 unaltered episode \
during the sampling.
:param compute_reward_fn: a function that takes 2 ``np.array`` arguments,
``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``.
The two arguments are of shape (batch_size, *original_shape) and the returned
rewards must be of shape (batch_size,).
:param int horizon: the maximum number of steps in an episode.
:param int future_k: the 'k' parameter introduced in the paper. In short, there
will be at most k episodes that are re-written for every 1 unaltered episode
during the sampling.

Contributor (Author):

updated: 7802451

@Trinkle23897 (Collaborator)

> Btw, I tried verifying the documentation page using make doc but I kept getting a version mismatch error. The doc test seems to be OK, so I think it shouldn't be a problem.

Have you tried reinstalling the deps with pip install -e ".[dev]"?

@Juno-T (Contributor, Author) commented Oct 6, 2022

I have tried pip install -e ".[dev]" and manually downgrading to docutils==0.16. Neither worked.
I wonder if the latest versions of sphinx and docutils (installed during make doc) aren't compatible.
Could you share your sphinx and docutils versions, please?

@Trinkle23897 (Collaborator)

pip3 install "sphinx<4" sphinxcontrib-bibtex "jinja2<3.1"

Comment on lines 110 to 111
mask = indices >= self._index
indices[~mask] += self.maxsize
Collaborator:

indices[indices < self._index] += self.maxsize

examples/offline/fetch_her_ddpg.py: four review comments (outdated, resolved)
@Juno-T (Contributor, Author) commented Oct 27, 2022

Sorry for going missing for a few weeks; I had to clear up some work at hand.

In the latest update, 72b4ab4, I have:

@nuance1979 (Collaborator)

LGTM. Great work, @Juno-T!

@nuance1979 (Collaborator)

@Trinkle23897 I have fixed the test failures. Feel free to approve or suggest more changes.

@Trinkle23897 (Collaborator) left a comment:

LGTM! Thanks for the great work!

@Trinkle23897 Trinkle23897 merged commit d42a5fb into thu-ml:master Oct 30, 2022