In [1]:
import gym
from gym import logger as gymlogger
gymlogger.set_level(40)

from parallelEnv import *
from replaybuffer import *

In [2]:
env = gym.make('LunarLander-v2')
env.seed(1234)
print('State shape: ', env.observation_space)
print('Number of actions: ', env.action_space)

buffer = ReplayBuffer(100)
state = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    buffer.push(state, action, reward, next_state, done)
    state = next_state
    if done:
        break
        
states, actions, rewards, next_states, dones = buffer.sample(16)

print(len(buffer.memory))
print(states.shape)
print(actions.shape)
print(rewards.shape)
print(next_states.shape)
print(dones.shape)

State shape:  Box(8,)
Number of actions:  Discrete(4)
84
(16, 8)
(16,)
(16,)
(16, 8)
(16,)


In [3]:
envs = parallelEnv('LunarLander-v2', n = 4)
n=len(envs.ps) # number of parallel instances

buffer = ReplayBuffer(100, parallel_envs = True)
state = envs.reset()
for _ in range(200):
    action = [envs.action_space.sample() for _ in range(n)]
    next_state, reward, done, _ = envs.step(action)
    buffer.push(state, action, reward, next_state, done)
    state = next_state
    if done.any():
        break
        
states, actions, rewards, next_states, dones = buffer.sample(16)

print(len(buffer.memory))
print(states.shape)
print(actions.shape)
print(rewards.shape)
print(next_states.shape)
print(dones.shape)

100
(16, 8)
(16,)
(16,)
(16, 8)
(16,)


In [4]:
env = gym.make('BipedalWalker-v3')
env.seed(1234)
print('State shape: ', env.observation_space)
print('Number of actions: ', env.action_space)

buffer = ReplayBuffer(100)
state = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    buffer.push(state, action, reward, next_state, done)
    state = next_state
    if done:
        break
        
states, actions, rewards, next_states, dones = buffer.sample(16)

print(len(buffer.memory))
print(states.shape)
print(actions.shape)
print(rewards.shape)
print(next_states.shape)
print(dones.shape)

State shape:  Box(24,)
Number of actions:  Box(4,)
84
(16, 24)
(16, 4)
(16,)
(16, 24)
(16,)


In [5]:
envs = parallelEnv('BipedalWalker-v3', n = 4)
n=len(envs.ps) # number of parallel instances

buffer = ReplayBuffer(100, parallel_envs = True)
state = envs.reset()
for _ in range(200):
    action = [envs.action_space.sample() for _ in range(n)]
    next_state, reward, done, _ = envs.step(action)
    buffer.push(state, action, reward, next_state, done)
    state = next_state
    if done.any():
        break
        
states, actions, rewards, next_states, dones = buffer.sample(16)

print(len(buffer.memory))
print(states.shape)
print(actions.shape)
print(rewards.shape)
print(next_states.shape)
print(dones.shape)

100
(16, 24)
(16, 4)
(16,)
(16, 24)
(16,)


In [3]:
env = gym.make('BipedalWalker-v3')
env.seed(1234)
print('State shape: ', env.observation_space)
print('Number of actions: ', env.action_space)

buffer = PrioritizedBuffer(100)
state = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_state, reward, done, _ = env.step(action)
    buffer.push(1, state, action, reward, next_state, done)
    state = next_state
    if done:
        break

State shape:  Box(24,)
Number of actions:  Box(4,)


In [4]:
states, actions, rewards, next_states, dones = buffer.sample(16)

In [5]:
print(buffer.tree.total())
print(states.shape)
print(actions.shape)
print(rewards.shape)
print(next_states.shape)
print(dones.shape)

100.0
(16, 24)
(16, 4)
(16,)
(16, 24)
(16,)


In [7]:
buffer.tree.tree

array([100.,  64.,  36.,  32.,  32.,  20.,  16.,  16.,  16.,  16.,  16.,
        12.,   8.,   8.,   8.,   8.,   8.,   8.,   8.,   8.,   8.,   8.,
         8.,   8.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,
         4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,   4.,
         4.,   4.,   4.,   4.,   4.,   2.,   2.,   2.,   2.,   2.,   2.,
         2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,
         2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,
         2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,
         2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,   2.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
         1.,   1.,   1.,   1.,   1.,   1.,   1.,   