-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #162 from zuoxingdong/add_ddpg2
minor update
- Loading branch information
Showing
154 changed files
with
2,460 additions
and
604 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,41 @@ | ||
from collections import deque | ||
|
||
import random | ||
import numpy as np | ||
import torch | ||
|
||
from lagom.envs import flatdim | ||
from lagom.utils import tensorify | ||
|
||
|
||
class ReplayBuffer(object):
    r"""A fixed-capacity ring buffer of transitions for experience replay.

    Transitions are stored in pre-allocated float32 numpy arrays; once the
    buffer is full, the oldest entries are overwritten in circular fashion.

    .. note:
        Difference with DQN replay buffer: we handle raw observation (no pixel) for continuous control
        Thus we do not have transformation to and from 255. and np.uint8

    Args:
        env (Env): environment; its observation/action spaces (via ``flatdim``)
            determine the width of the storage arrays.
        capacity (int): max capacity of transition storage in the buffer. When the buffer overflows the
            old transitions are dropped.
        device (Device): PyTorch device that sampled batches are moved to.
    """
    def __init__(self, env, capacity, device):
        self.env = env
        self.capacity = capacity
        self.device = device

        # Pre-allocated float32 storage (saves half the memory vs. float64).
        obs_dim = flatdim(env.observation_space)
        act_dim = flatdim(env.action_space)
        self.observations = np.zeros([capacity, obs_dim], dtype=np.float32)
        self.actions = np.zeros([capacity, act_dim], dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_observations = np.zeros([capacity, obs_dim], dtype=np.float32)
        self.masks = np.zeros(capacity, dtype=np.float32)

        self.size = 0     # number of valid transitions currently stored
        self.pointer = 0  # next write index; wraps around at capacity

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

    def add(self, observation, action, reward, next_observation, done):
        """Store one non-batched transition, overwriting the oldest when full."""
        idx = self.pointer
        self.observations[idx] = observation
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.next_observations[idx] = next_observation
        # mask is 0 for a terminal transition, 1 otherwise
        self.masks[idx] = 1. - done

        # Advance the circular write pointer; size saturates at capacity.
        self.pointer = (idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            list: ``[observations, actions, rewards, next_observations, masks]``
            as tensors on ``self.device``.
        """
        idx = np.random.randint(0, self.size, size=batch_size)
        batch = [self.observations[idx],
                 self.actions[idx],
                 self.rewards[idx],
                 self.next_observations[idx],
                 self.masks[idx]]
        return [tensorify(x, self.device) for x in batch]
Oops, something went wrong.