RNN support #19
Oh yes, in fact it already supports the RNN policy, but I will write a small demonstration script after finishing the documentation. Thanks for pushing me. |
Thanks! Do you mind quickly posting a snippet here and explaining how to use the RNN policy? I need to run something ASAP, thanks a lot! |
Actually, my experiments show that an RNN policy is nearly always inferior to an MLP policy in robotic control domains. |
It varies from case to case, but I agree that RNN might not be the best option. I do need it as a baseline though, so I would still appreciate it if someone could explain how to run an RNN policy with Tianshou. Thanks! |
@miriaford I uploaded a script in test/discrete/test_drqn.py.
How to use:
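(A condensed, hypothetical sketch of the idea follows; it is not a verbatim copy of test_drqn.py, and names such as Recurrent and the exact constructor arguments are assumptions that vary across Tianshou versions.)

```python
import gym
import torch
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Recurrent  # LSTM net whose forward(obs, state) returns (logits, hidden)

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# 1. Replace the MLP with a recurrent network.
net = Recurrent(layer_num=2, state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99, estimation_step=3, target_update_freq=320)

# 2. Let the replay buffer stack the last few frames, so the RNN is trained
#    on short observation sequences instead of single frames.
buffer = ReplayBuffer(size=20000, stack_num=4)
collector = Collector(policy, env, buffer)
```

During rollout, the Collector keeps feeding the hidden state returned by policy.forward back in at the next step, so nothing else needs to change.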
That's all. If you find anything wrong, please give me feedback. |
Cool, thanks a lot! Does it also work out of the box for PPO and SAC? Do you have any benchmark scripts/results for them? |
@miriaford I haven't tried RNN on PPO and SAC, but you can be the first. Follow the changes between test_dqn.py and test_drqn.py and you can construct your first RNN-style PPO and SAC! |
Hi, I might have missed something obvious. Could you please educate me on:
Thanks for your help! |
This is great! Would love to check out the new interface. Another quick question: Thanks again! |
Yes, you are right. The longer the length, the smaller the difference in RNN performance between training and testing. |
Please kindly let me know when you have finished RNN-critic. I'm a bit bottlenecked on this. Thank you so much! |
Sure! I'm currently fixing issue #38. Maybe today or tomorrow I'll resolve your issue. |
Thanks for reopening the issue. I'd be happy to test it when you finish, thanks! |
Any updates? Sorry for pinging again. Thanks! |
I'm coding it. |
The code has been pushed. |
I'm not sure whether or not to support stacked actions for the Q-network. This would require changing some lines of code in every policy that needs a Q-network (DDPG, TD3, SAC). Or you can create the Q-network as (stacked-s, single-a) -> Q value (this is in tianshou/test/continuous/net.py, lines 119 to 144 in c2a7caf).
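(For illustration, a hypothetical sketch of such a (stacked-s, single-a) -> Q network; it mirrors the shape of the snippet referenced above but is not a copy of it.)

```python
import torch
from torch import nn

class RecurrentCritic(nn.Module):
    """Sketch of a (stacked-s, single-a) -> Q critic: the stacked observations go
    through an LSTM, and the single action is concatenated to the last hidden
    output before the final linear layer."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 128, layer_num: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden, num_layers=layer_num, batch_first=True)
        self.fc = nn.Linear(hidden + act_dim, 1)

    def forward(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tensor:
        # obs: (batch, stack_num, obs_dim); act: (batch, act_dim)
        out, _ = self.lstm(obs)
        return self.fc(torch.cat([out[:, -1], act], dim=1))
```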
But for PPO, the current version is enough. |
Thanks!! Will take a look and let you know! |
Aha, I know how to support it (Q(stacked-s, stacked-a)) without modifying any part of the codebase. Try adding a gym.Wrapper which modifies the state to {'original_state': state, 'action': action}. That means that where we traditionally save (s, a, s', r, d) in the replay buffer, we now save ([s, a], a, [s', a'], r, d). |
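(A hypothetical sketch of such a wrapper, written against the old 4-tuple gym step API; the dict keys simply follow the naming above.)

```python
import gym
import numpy as np

class PrevActionWrapper(gym.Wrapper):
    """Attach the action that produced each observation to the observation itself,
    so frame stacking in the replay buffer stacks actions along with states."""

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        zero_act = np.zeros(np.shape(self.env.action_space.sample()), dtype=np.float32)
        return {'original_state': obs, 'action': zero_act}

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        return {'original_state': obs, 'action': np.asarray(action, dtype=np.float32)}, rew, done, info
```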
Should we merge the essential commit from dev to master at this point? |
What's your desired feature? Could you please provide a simple example? |
I am trying to reach O(n) complexity in policy.learn(). In the current implementation, when sample_avail is not set, O(n) can't be reached. When sample_avail is set, policy.learn() can only get part of the reward, which makes it impossible to calculate returns. |
It is possible to use the original ReplayBuffer without framestack and sample_avail and leave all the dirty work to the policy. I wonder if there is a better way? |
Do you mean sampling only available indices with an on-policy method for RNN usage? I haven't figured it out clearly, so let's take an example:
If I understand correctly, you want to only extract [0-2], [3-6], [7-a], [b-c] without overlap. What I'm not sure about is the data format that you expect. (and for me "O(n)" is ambiguous) |
Let me state my question more clearly and thoroughly.
By setting sample_avail to True, the first problem is solved to some extent. However, sequences like [b, c, c] may still exist, so the second problem is still there. Also, in the current implementation, only the reward at the end of an episode will be returned by buf.sample(), which makes training impossible. The data format I am expecting is something like PyTorch's PackedSequence, which includes the sequences and their lengths. It seems a bit complicated to implement this in the replay buffer. I am now implementing it in process_fn of my policy. |
So you want a new sample method: it points out the number of episodes you want and will return the sampled episodes, for example, [[7-a], [3-6], [7-a], [b-c]] with episode_size=4, am I right? |
Yes |
Okay, I see. I think the simplest way is to implement it on the buffer side. Since we can easily maintain the done flag, we can do something like:

```python
# begin_index stores all the start indices of the episodes (if available), e.g., [0, 3, 7, b]
episode_index = np.random.choice(begin_index, episode_size)  # assume we get [b, 7, 3, 7]
# do a sorting with key=-episode_length, since we want to mimic torch's PackedSequence;
# after that, episode_index is [7, 7, 3, b]
# (the buffer should also maintain a list of episode_length, in the above case: [3, 4, 4, 2])
data, seq_length = [], []
while not self._meta.done[episode_index][0]:  # the first episode's length is the longest
    # find out the episodes that haven't terminated yet
    available_episode_index = episode_index[self._meta.done[episode_index] == 0]
    data.append(self[available_episode_index])  # should set stack_num=0
    seq_length.append(len(available_episode_index))
    episode_index += self._meta.done[episode_index] == 0
data = Batch.cat(data)
```

For now, (data, seq_length) is something like a PackedSequence. I prefer to add a new argument in the initialization of Collector to indicate whether to use this mode. |
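(To make the PackedSequence analogy concrete, here is a small illustrative torch snippet; seq_length in the sketch above plays the same role as batch_sizes below.)

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# three episodes of lengths 5, 3, 2 (already sorted by decreasing length),
# zero-padded into a (batch, max_len, obs_dim) tensor
obs = torch.zeros(3, 5, 8)
lengths = torch.tensor([5, 3, 2])
packed = pack_padded_sequence(obs, lengths, batch_first=True)
print(packed.batch_sizes)  # tensor([3, 3, 2, 1, 1]) -- number of episodes still alive at each step
```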
Yes I think it is appropriate. |
But should we release a new version like |
It's up to you. It's really a personal choice. For a Python-only project, I tend to prefer releasing patches like |
Did anyone try a transformer before? In the NLP area, transformers seem to perform better than RNN units. |
But in RL, the bigger the model, the harder it is to converge to an optimal policy. :( |
Ok, I get it. |
Is there a reason why in the From Recurrent Experience Replay in Distributed Reinforcement Learning by Kapturowski et al. in 2018 (https://openreview.net/pdf?id=r1lyTjAqYX)
|
No, because as you mentioned |
Is it possible to use the recurrent network in the discrete soft actor critic? |
@Trinkle23897 Could you please point me to the exact workflow you based this RNN implementation on? I'm trying to figure out why it doesn't work (I have tested RNN with SAC and with DQN on 5 environments; it only worked with DQN on CartPole). |
Hello, @Trinkle23897, I am encountering a problem at lines 89-90 in

```python
v_s.append(self.critic(minibatch.obs))
v_s_.append(self.critic(minibatch.obs_next))
```

I am wondering why the critic does not support a 'state' input as the actor does. As it is common for the critic to share the same RNN input network as the actor, how can I pass the same state that the actor used to the critic? I noticed there is a state batch stored in the 'policy' key of the minibatch, but this should be the output state if I understand correctly, right? Or am I not supposed to pass any state to the critic during training? Did I miss anything here? Thanks a lot. |
Currently, yes |
Thanks. I finally put the input state into the model input dict to solve the problem. Another thing, I found the following gradient clipping code:

```python
if self._grad_value:  # clip large gradient
    nn.utils.clip_grad_value_(
        self._actor_critic.parameters(), clip_value=self._grad_value
    )
if self._grad_norm:  # clip large gradient
    nn.utils.clip_grad_norm_(
        self._actor_critic.parameters(), max_norm=self._grad_norm
    )
```

Experiment:

```python
>>> w.grad
tensor([[0.4000, 0.4000, 0.4000, 0.4000, 0.4000],
        [   inf,    inf,    inf,    inf,    inf],
        [2.3000, 2.3000, 2.3000, 2.3000, 2.3000]])
>>> torch.nn.utils.clip_grad_norm_([w], 0.5)  # this cannot deal with inf
tensor(inf)
>>> w.grad
tensor([[0., 0., 0., 0., 0.],
        [nan, nan, nan, nan, nan],
        [0., 0., 0., 0., 0.]])
>>> torch.nn.utils.clip_grad_value_([w], 0.5)  # and this cannot deal with nan
>>> w.grad
tensor([[0., 0., 0., 0., 0.],
        [nan, nan, nan, nan, nan],
        [0., 0., 0., 0., 0.]])
========
>>> w.grad
tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [   inf,    inf,    inf,    inf,    inf],
        [1.8000, 1.8000, 1.8000, 1.8000, 1.8000]])
>>> torch.nn.utils.clip_grad_value_([w], 0.5)  # but this can deal with inf
>>> w.grad
tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])
```
|
Sure! Feel free to submit a PR |
I see in the README that RNN support is on your TODO list. However, the module API seems to support RNN (the forward(obs, state) method). Could you please provide some examples of how to train an RNN policy? Thanks!