In [64]:
import reverb
import time
import tensorflow as tf
import gym
import random
import os
import numpy as np
from multiprocessing  import Process, Queue
from queue            import Empty
from collections      import deque
from keras.models     import Sequential
from keras.layers     import Dense
from keras.optimizers import Adam

In [None]:
class Agent():
    def __init__(self, state_size, action_size):
        self.weight_backup      = "cartpole_weight.h5"
        self.state_size         = state_size
        self.action_size        = action_size
        self.memory             = deque(maxlen=2000)
        self.learning_rate      = 0.001
        self.gamma              = 0.95
        self.exploration_rate   = 1.0
        self.exploration_min    = 0.01
        self.exploration_decay  = 0.995
        self.brain              = self._build_model()

    def _build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))

        if os.path.isfile(self.weight_backup):
            model.load_weights(self.weight_backup)
            self.exploration_rate = self.exploration_min
        return model

    def save_model(self):
            self.brain.save(self.weight_backup)

    def act(self, state):
        if np.random.rand() <= self.exploration_rate:
            return random.randrange(self.action_size)
        act_values = self.brain.predict(state)
        return np.argmax(act_values[0])

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
        globalMemory.put((state, action, reward, next_state, done))

    def replay(self, sample_batch_size):
        if globalMemory.qsize() < sample_batch_size:
            return
        batch_indexes = set(random.sample(range(globalMemory.qsize()), sample_batch_size))
        replayCopy = []
        while True:
            try:
                item = globalMemory.get(block=False)
            except Empty:
                break
            else:
                replayCopy.append(item)
        for item in replayCopy:
            globalMemory.put(item)
        sample_batch = [replayCopy[index] for index in batch_indexes]
        for state, action, reward, next_state, done in sample_batch:
            target = reward
            if not done:
              target = reward + self.gamma * np.amax(self.brain.predict(next_state)[0])
            target_f = self.brain.predict(state)
            target_f[0][action] = target
            self.brain.fit(state, target_f, epochs=1, verbose=0)
        if self.exploration_rate > self.exploration_min:
            self.exploration_rate *= self.exploration_decay


In [58]:
#fake environment specifiers 
OBSERVATION_SPEC = tf.TensorSpec([10, 10], tf.uint8)
ACTION_SPEC = tf.TensorSpec([2], tf.float32)
EPISODE_LENGTH = 5
NUM_EPISODES = 10

def agent_step(unused_timestep, pid) -> tf.Tensor:
  return tf.convert_to_tensor([1,9],dtype=float)

def environment_step(unused_action) -> tf.Tensor:
  return tf.cast(tf.random.uniform(OBSERVATION_SPEC.shape, maxval=256),
                 OBSERVATION_SPEC.dtype)

In [59]:
def start_server(): 
    return reverb.Server(
    tables=[
        reverb.Table(
            name='my_table',
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sets Rate Limiter to a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(2),
            # The signature is optional but it is good practice to set it as it
            # enables data validation and easier dataset construction. Note that
            # the number of observations is larger than the number of actions.
            # The extra observation is the terminal state where no action is
            # taken.
            signature={
                'actions': tf.TensorSpec(
                    [EPISODE_LENGTH, *ACTION_SPEC.shape],
                    ACTION_SPEC.dtype),
                'observations': tf.TensorSpec(
                    [EPISODE_LENGTH + 1, *OBSERVATION_SPEC.shape],
                    OBSERVATION_SPEC.dtype),
            },
        ),
    ],
    # Sets the port to None to make the server pick one automatically.
    port=8000)

In [60]:
def display_samples(num_samples):
    # Each sample is an entire episode.
    # Adjusts the expected shapes to account for the whole episode length.
    dataset = reverb.TrajectoryDataset.from_table_signature(
      server_address=f'localhost:8000',
      table='my_table',
      max_in_flight_samples_per_worker=10,
      rate_limiter_timeout_ms=10)

    # Batches episodes together.
    # Each item is an episode of the format (observations, actions) as above.
    dataset = dataset.batch(5)

    # Sample has type reverb.ReplaySample.
    for sample in dataset.take(num_samples):
      #print(sample.data['observations'])
      print(sample.data['actions'])

In [61]:
MAX_TIME = 10
SLEEP_WAIT = 1

def server_thread():
    pid = os.getpid()
    print('server process:{}\n'.format(pid))
    start_time = time.perf_counter()
    curr_time = time.perf_counter()
    server = start_server()
    while((curr_time - start_time) < MAX_TIME):
        curr_time = time.perf_counter()
        time.sleep(SLEEP_WAIT)
    server.stop()
    print('server stopped')
    display_samples(1)

In [62]:
def actor_thread(num_eps=NUM_EPISODES):
    pid = os.getpid()
    print('actor process:{}\n'.format(pid))
    
    client = reverb.Client(f'localhost:8000')
    with client.trajectory_writer(num_keep_alive_refs=EPISODE_LENGTH+1) as writer:
      for _ in range(NUM_EPISODES):
        timestep = environment_step(None)

        for _ in range(EPISODE_LENGTH):
          action = agent_step(timestep)
          writer.append({'action': action,
                         'observation': timestep})

          timestep = environment_step(action)

        writer.append({'observation': timestep})

        writer.create_item(
            table='my_table',
            priority=1.5,
            trajectory={
                'actions': writer.history['action'][:-1],
                'observations': writer.history['observation'][:],
            })
        writer.end_episode(timeout_ms=1000)
    print('actor {} done\n'.format(pid))

In [63]:
num_actors = 5

if __name__ == '__main__':
    replays = Process(target=server_thread)
    replays.start()
    
    actors = []
    for i in range(0,num_actors):
        actors.append(Process(target=actor_thread,args=(NUM_EPISODES,)))
        actors[i].start()
        
    for actor in actors:
        actor.join()
    replays.join()

actor process:6783

actor process:6786

actor process:6797
server process:6781


actor process:6825

actor process:6849

actor 6797 done
actor 6783 done
actor 6825 done



actor 6786 done

actor 6849 done

server stopped
tf.Tensor(
[[[1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]]

 [[1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]]

 [[1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]]

 [[1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]
  [1. 9.]]

 [[6. 9.]
  [6. 9.]
  [6. 9.]
  [6. 9.]
  [6. 9.]]], shape=(5, 5, 2), dtype=float32)
