GAE dev and test versions added. Updated README
qfettes committed Jul 16, 2018
1 parent 7e02d32 commit 924ec4a
Showing 8 changed files with 341 additions and 6 deletions.
320 changes: 320 additions & 0 deletions 13.GAE.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@ Relevant Papers:
10. Rainbow with Quantile Regression [[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/10.Quantile-Rainbow.ipynb)
11. Deep Recurrent Q-Learning for Partially Observable MDPs [[Publication]](https://arxiv.org/abs/1507.06527)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/11.DRQN.ipynb)
12. Advantage Actor Critic (A2C) [[Publication1]](https://arxiv.org/abs/1602.01783)[[Publication2]](https://blog.openai.com/baselines-acktr-a2c/)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/12.A2C.ipynb)
13. High-Dimensional Continuous Control Using Generalized Advantage Estimation [[Publication]](https://arxiv.org/abs/1506.02438)[[code]](https://github.com/qfettes/DeepRL-Tutorials/blob/master/13.GAE.ipynb)


Requirements:
2 changes: 2 additions & 0 deletions a2c_devel.py
@@ -41,6 +41,8 @@
#a2c control
config.num_agents=16
config.rollout=5
config.USE_GAE = True
config.gae_tau = 0.95

#misc agent variables
config.GAMMA=0.99
2 changes: 1 addition & 1 deletion agents/A2C.py
@@ -49,7 +49,7 @@ def __init__(self, static_policy=False, env=None, config=None):
self.model.train()

self.rollouts = RolloutStorage(self.rollout, self.num_agents,
-            self.num_feats, self.env.action_space, self.device)
+            self.num_feats, self.env.action_space, self.device, config.USE_GAE, config.gae_tau)

self.value_losses = []
self.entropy_losses = []
Binary file modified saved_agents/model.dump
Binary file modified saved_agents/optim.dump
20 changes: 15 additions & 5 deletions utils/RolloutStorage.py
@@ -1,7 +1,7 @@
import torch

class RolloutStorage(object):
-    def __init__(self, num_steps, num_processes, obs_shape, action_space, device):
+    def __init__(self, num_steps, num_processes, obs_shape, action_space, device, USE_GAE=True, gae_tau=0.95):
self.observations = torch.zeros(num_steps + 1, num_processes, *obs_shape).to(device)
self.rewards = torch.zeros(num_steps, num_processes, 1).to(device)
self.value_preds = torch.zeros(num_steps + 1, num_processes, 1).to(device)
@@ -12,6 +12,8 @@ def __init__(self, num_steps, num_processes, obs_shape, action_space, device):

self.num_steps = num_steps
self.step = 0
self.gae = USE_GAE
self.gae_tau = gae_tau

def insert(self, current_obs, action, action_log_prob, value_pred, reward, mask):
self.observations[self.step + 1].copy_(current_obs)
@@ -28,7 +30,15 @@ def after_update(self):
self.masks[0].copy_(self.masks[-1])

def compute_returns(self, next_value, gamma):
-        self.returns[-1] = next_value
-        for step in reversed(range(self.rewards.size(0))):
-            self.returns[step] = self.returns[step + 1] * \
-                gamma * self.masks[step + 1] + self.rewards[step]
+        if self.gae:
+            self.value_preds[-1] = next_value
+            gae = 0
+            for step in reversed(range(self.rewards.size(0))):
+                delta = self.rewards[step] + gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]
+                gae = delta + gamma * self.gae_tau * self.masks[step + 1] * gae
+                self.returns[step] = gae + self.value_preds[step]
+        else:
+            self.returns[-1] = next_value
+            for step in reversed(range(self.rewards.size(0))):
+                self.returns[step] = self.returns[step + 1] * \
+                    gamma * self.masks[step + 1] + self.rewards[step]
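For readers following along, below is a minimal, self-contained sketch of the same GAE recursion (the method from the paper linked in the README entry for notebook 13). The function name `compute_gae_returns`, the toy tensor shapes, and the usage lines are illustrative assumptions, not code from this commit:

```python
# Standalone sketch of the GAE return computation added to RolloutStorage above.
# Assumed shapes mirror the storage buffers: rewards and value_preds are
# [num_steps, num_processes, 1]; masks is [num_steps + 1, num_processes, 1].
import torch

def compute_gae_returns(rewards, value_preds, masks, next_value, gamma=0.99, gae_tau=0.95):
    num_steps = rewards.size(0)
    returns = torch.zeros_like(rewards)
    # Append the bootstrap value V(s_T) so value_preds has num_steps + 1 entries.
    value_preds = torch.cat([value_preds, next_value.unsqueeze(0)], dim=0)
    gae = 0
    for step in reversed(range(num_steps)):
        # One-step TD error; masks[step + 1] zeroes the bootstrap at episode ends.
        delta = rewards[step] + gamma * value_preds[step + 1] * masks[step + 1] - value_preds[step]
        # Exponentially weighted sum of TD errors (lambda = gae_tau).
        gae = delta + gamma * gae_tau * masks[step + 1] * gae
        # Return target = advantage estimate + value baseline, as in compute_returns.
        returns[step] = gae + value_preds[step]
    return returns

# Toy usage: a 5-step rollout over 16 parallel environments.
rewards = torch.rand(5, 16, 1)
values = torch.rand(5, 16, 1)
masks = torch.ones(6, 16, 1)
next_value = torch.rand(16, 1)
print(compute_gae_returns(rewards, values, masks, next_value).shape)  # torch.Size([5, 16, 1])
```

With gae_tau = 1.0 the targets reduce to the same bootstrapped discounted returns as the else branch above; smaller values trade variance for bias, which is why the commit defaults to 0.95.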
2 changes: 2 additions & 0 deletions utils/hyperparameters.py
@@ -12,6 +12,8 @@ def __init__(self):
self.value_loss_weight = 0.5
self.entropy_loss_weight = 0.001
self.grad_norm_max = 0.5
self.USE_GAE=True
self.gae_tau = 0.95

#algorithm control
self.USE_NOISY_NETS=False
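A possible usage note (assuming the config object built in a2c_devel.py is an instance of a `Config` class from utils/hyperparameters.py, which this commit does not show): the new defaults let an experiment fall back to plain n-step returns without touching RolloutStorage:

```python
# Hypothetical experiment snippet, not part of this commit.
from utils.hyperparameters import Config

config = Config()
config.USE_GAE = False   # compute_returns() then takes the original discounted-return branch
config.gae_tau = 0.95    # ignored while USE_GAE is False
```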
