In [1]:
#XVFB will be launched if you run on a server
import os
if type(os.environ.get("DISPLAY")) is not str or len(os.environ.get("DISPLAY"))==0:
    !bash ../xvfb start
    %env DISPLAY=:1
        
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import sys
try:
    sys.path.remove('/home/syuntoku14/catkin_ws/devel/lib/python2.7/dist-packages')
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
except:
    pass

import gym
from gym.core import ObservationWrapper
from gym.spaces import Box

# from scipy.misc import imresize
import cv2

from framebuffer import FrameBuffer
from replay_buffer import ReplayBuffer

In [2]:
class ReplayBuffer(ReplayBuffer):
    def concat(self, exp_replay):
        self._storage += exp_replay._storage
        excess = max(len(self._storage) - self._maxsize, 0)
        self._storage = self._storage[excess:]

class PreprocessAtari(ObservationWrapper):
    def __init__(self, env):
        """A gym wrapper that crops, scales image into the desired shapes and optionally grayscales it."""
        ObservationWrapper.__init__(self,env)
        
        self.img_size = (84, 84)
        self.observation_space = Box(0.0, 1.0, (1, self.img_size[0], self.img_size[1]))

    def _observation(self, img):
        img = img[34:-16, 8:-8, :]
        img = cv2.resize(img, self.img_size)
        img = img.mean(-1, keepdims=True)  # grayscale
        img = img.astype('float32') / 255.
               
        return img.transpose([2, 0, 1])
    
def make_env():
    env = gym.make("BreakoutDeterministic-v4")
    env = PreprocessAtari(env)
    env = FrameBuffer(env, n_frames=4, dim_order='pytorch')
    return env

In [3]:
env = make_env()
env.reset()
n_actions = env.action_space.n
state_dim = env.observation_space.shape

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: <class '__main__.PreprocessAtari'> doesn't implement 'observation' method. Maybe it implements deprecated '_observation' method.[0m


In [4]:
class DQN(nn.Module):
    def __init__(self, n_actions):
        super(DQN, self).__init__()
        # input obs, output n_actions
        # The network is based on "Mnih, 2015"
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.l1 = nn.Linear(64*7*7, 512)
        self.l2 = nn.Linear(512, n_actions)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size()[0], -1)
        x = self.l1(x)
        x = self.l2(x)
        return x

      
class DQNAgent:
    def __init__(self, state_shape, n_actions, epsilon=0, reuse=False):
        """A simple DQN agent"""
        
        self.dqn = DQN(n_actions)
        self.epsilon = epsilon

    def get_qvalues(self, states):
        """takes agent's observation, returns qvalues. """
        qvalues = self.dqn(states)
        return qvalues
    
    def get_qvalues_for_actions(self, qvalues, actions):
        return qvalues.gather(1, \
                actions.unsqueeze(0).transpose(0, 1)).squeeze()
   
    def sample_actions(self, qvalues):
        """pick actions given qvalues. Uses epsilon-greedy exploration strategy. """
        epsilon = self.epsilon
        batch_size, n_actions = qvalues.shape
        random_actions = torch.tensor(np.random.choice(n_actions, size=batch_size))
        best_actions = qvalues.argmax(1)
        should_explore = torch.tensor(np.random.choice([0, 1], batch_size, p = [1-epsilon, epsilon])).byte()
        return torch.where(should_explore, random_actions, best_actions)

      
def play_and_record(agent, env, exp_replay, n_steps=1):
    """
    Play the game for exactly n steps, record every (s,a,r,s', done) to replay buffer.
    :returns: return sum of rewards over time
    """
    # Make sure that the state is only one batch state, 4x84x84
    # State at the beginning of rollout
    s = env.framebuffer
    R = 0.0
    
    # Play the game for n_steps as per instructions above
    for t in range(n_steps):
        qvalues = agent.get_qvalues(torch.tensor(s).unsqueeze(0))
        action = agent.sample_actions(qvalues).item()
        next_s, r, done, _ = env.step(action)
        exp_replay.add(s, action, r, next_s, done=done)
        if done == True:
            s = env.reset()
        else:
            s = next_s
    return R 


def optimize(current_action_qvalues, optimizer, target_dqn, \
             reward_batch, next_obs_batch, is_done_batch):
    next_qvalues_target = target_dqn.get_qvalues(next_obs_batch)

    # compute state values by taking max over next_qvalues_target for all actions
    next_state_values_target = next_qvalues_target.max(1)[0]
    next_state_values_target = torch.where(torch.tensor(is_done_batch).byte(), \
                                 torch.tensor(reward_batch), \
                                 torch.tensor(next_state_values_target))
    
    # compute Q_reference(s,a) as per formula above.
    reference_qvalues = reward_batch + gamma * next_state_values_target

    # Define loss function for sgd.
    td_loss = (current_action_qvalues - reference_qvalues) ** 2
    td_loss = torch.mean(td_loss)

    optimizer.zero_grad()
    td_loss.backward()
    optimizer.step()
    
    return td_loss.item()
  
      
def evaluate(env, agent, n_games=1, greedy=False, t_max=10000):
    """ Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. """
    rewards = []
    for _ in range(n_games):
        s = env.reset()
        reward = 0
        for _ in range(t_max):
            s = torch.tensor(s).unsqueeze(0)
            qvalues = agent.get_qvalues(s)
            action = qvalues.argmax(dim=-1)[0] if greedy else agent.sample_actions(qvalues)[0]
            s, r, done, _ = env.step(action)
            reward += r
            if done: break
                
        rewards.append(reward)
    return np.mean(rewards)

    
def convert_to_tensor(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch):
    obs_batch = torch.tensor(obs_batch)
    act_batch = torch.tensor(act_batch)
    reward_batch = torch.tensor(reward_batch).float()
    next_obs_batch = torch.tensor(next_obs_batch)
    is_done_batch = is_done_batch.astype(np.int)
    return obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch
  

def save_data(folder_path, agent, mean_reward_history, td_loss_history):
    torch.save(policy_agent.dqn.state_dict(), folder_path + 'atari_dqn_state_dict.pt')
    with open(folder_path + 'mean_reward_history.l', 'wb') as f:
        pickle.dump(mean_reward_history, f)
    with open(folder_path + 'td_loss_history.l', 'wb') as f:
        pickle.dump(td_loss_history, f)
        

def load_data(folder_path):
    state_dict = None
    mean_reward_history = []
    td_loss_history = []
    
    state_dict = torch.load(folder_path + 'atari_dqn_state_dict.pt')
    with open(folder_path + 'mean_reward_history.l', 'rb') as f:
        mean_reward_history = pickle.load(f)
    with open(folder_path + 'td_loss_history.l', 'rb') as f:
        td_loss_history = pickle.load(f)
        
    return state_dict, mean_reward_history, td_loss_history

In [5]:
from tqdm import trange
from IPython.display import clear_output
import matplotlib.pyplot as plt
from pandas import DataFrame
moving_average = lambda x, span, **kw: DataFrame({'x':np.asarray(x)}).x.ewm(span=span, **kw).mean().values
%matplotlib inline

mean_rw_history = []
td_loss_history = []

gamma = 0.99
policy_agent = DQNAgent(state_dim, n_actions, epsilon=0.5)
target_agent = DQNAgent(state_dim, n_actions)

# Load the data

In [6]:
rl_path = './data/'
state_dict, mean_rw_history, td_loss_history = load_data(rl_path)
policy_agent.dqn.load_state_dict(state_dict)
policy_agent.dqn.eval()
target_agent.dqn.load_state_dict(state_dict)
target_agent.dqn.eval()

DQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (l1): Linear(in_features=3136, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=4, bias=True)
)

# Parallelize play and record

# 全体的な流れ

それぞれのPCでmultiprocessingを使って高速化させたplay_and_recordのデータを、中央サーバのredisに集める。それぞれは何らかの方法で同期通信する？

## TO DO
1. get state_dict from other process
2. pass the list to the other process

It seems that multi-thread method is not enough(accerelated 10%)
Multi-machines method is needed.

In [10]:
from threading import Thread
import time

def sampling_on_thread(policy_agent, env, exp_replay, n_steps):
    play_and_record(policy_agent, env, exp_replay, n_steps=n_steps)

In [49]:
%%time
# Using single thread
exp_replay = ReplayBuffer(10**5)
play_and_record(policy_agent, env, exp_replay, n_steps=20000)

start
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor

tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9

tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9

2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor

2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor

tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9

3
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
1
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
0
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
2
0.0
tensor([[1.9258, 1.9411, 1.9944, 1.9544]], grad_fn=<ThAddmmBackward>)
3
0.0
tensor

KeyboardInterrupt: 

In [9]:
## Using Multi threads
#exp_replay = ReplayBuffer(10**5)
#thread_time_list = []
#for i in range(10):
#    threads_number = i
#    start = time.time()
#    exp_replays = [ReplayBuffer(10000) for _ in range(threads_number)]
#    ps = [Thread(target=sampling_on_thread, 
#                args=(policy_agent, env, exp_replays[i], int(20000/threads_number))) \
#                for i in range(threads_number)]
#
#    for i in range(threads_number):
#        ps[i].start()
#
#    for i in range(threads_number):
#        ps[i].join()
#
#    for i in range(threads_number):
#        exp_replay.concat(exp_replays[i])
#    thread_time_list.append(time.time()-start)

# Multi Processing with Pytorch

## TO DO
1. get state_dict from other process
2. pass the list to the other process

Have to share, state_dict, (exp_replay)

In [104]:
import torch.multiprocessing as mp

def sampling_on_process(dqn, env):
    s = env.framebuffer
    print(s.shape)
    print(dqn(torch.tensor(s).unsqueeze(0)))
    
exp_replay = ReplayBuffer(1000)
dqn = DQN(n_actions)
dqn.share_memory()

DQN(
  (conv1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (conv2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (l1): Linear(in_features=3136, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=4, bias=True)
)

In [105]:
p = mp.Process(target=sampling_on_process, args=(dqn, env))

In [106]:
s = env.framebuffer
dqn(torch.tensor(s).unsqueeze(0))

torch.Size([1, 4, 84, 84])
torch.Size([1, 32, 20, 20])


tensor([[ 0.0180,  0.0263,  0.0250, -0.0267]], grad_fn=<ThAddmmBackward>)

In [107]:
p.start()
p.join()

(4, 84, 84)
torch.Size([1, 4, 84, 84])


KeyboardInterrupt: 

# Use Redis

## To Do

- [x] SubBufferをリストで保存し、mainでconcatする
- [ ] ReplayBufferそのものを保存する(全部pickle.loadしないといけないのでアレかも)
- [ ] python同士で同期通信する

* cpickleで保存する
* jsonで保存する
* ndarrayで保存する

In [7]:
import redis
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
r = redis.StrictRedis(connection_pool=pool)

### Method1

In [8]:
%%time
main_replay = ReplayBuffer(2000)
sub_replay1 = ReplayBuffer(1000)
sub_replay2 = ReplayBuffer(1000)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 22.2 µs


In [9]:
play_and_record(policy_agent, env, sub_replay1, sub_replay1._maxsize)
play_and_record(policy_agent, env, sub_replay2, sub_replay2._maxsize)

0.0

In [10]:
print(len(main_replay))
print(len(sub_replay1))
print(len(sub_replay2))

0
1000
1000


In [11]:
# With pickle
import pickle

In [12]:
r.delete('sub_replays')

ConnectionError: Error 111 connecting to localhost:6379. Connection refused.

In [90]:
r.rpush('sub_replays', pickle.dumps(sub_replay1))
r.rpush('sub_replays', pickle.dumps(sub_replay2))

2

In [91]:
sub_replays_list = [pickle.loads(pickled_replay) 
                    for pickled_replay in r.lrange('sub_replays', 0, -1)]

In [92]:
sub_replays_list

[<__main__.ReplayBuffer at 0x7f307bc77438>,
 <__main__.ReplayBuffer at 0x7f307bc79fd0>]

In [93]:
for sub_replay in sub_replays_list:
    main_replay.concat(sub_replay)

print(len(main_replay))

2000
