In [1]:
import gym
from multiprocessing import Process, Queue

In [2]:
class RL_Process(Process):
    def __init__(self, *args, env: gym.Env = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.recv_messages = Queue()
        self.running = False
        self.message_handlers = {'quit': self.quit}
        self.env = env
        
    def create_networks(self):
        import tensorflow as tf
        self.tf = tf
        self.actor = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters = 16, kernel_size = 7, strides = 4, input_shape = self.env.observation_space.shape),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation = 'relu'),
            tf.keras.layers.Dense(env.action_space.n)
        ])
        self.critic = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters = 16, kernel_size = 7, strides = 4, input_shape = self.env.observation_space.shape),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation = 'relu'),
            tf.keras.layers.Dense(1)
        ])
        self.optimizer = tf.keras.optimizers.Adam()
        
    def quit(self, message):
        print("Received quit message")
        self.running = False
        
    def run(self):
        self.running = True
        self.create_networks()
        while self.running:
            if not self.recv_messages.empty():
                message = self.recv_messages.get()
                message_id = message['id']
                if message_id in self.message_handlers:
                    self.message_handlers[message_id](message)
                else:
                    raise ValueError(f"Invalid message received: {message}")
        

In [3]:
test_process = RL_Process()
test_process.start()

Process RL_Process-1:
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "<ipython-input-2-de9aec5873dc>", line 31, in run
    self.create_networks()
  File "<ipython-input-2-de9aec5873dc>", line 13, in create_networks
    tf.keras.layers.Conv2D(filters = 16, kernel_size = 7, strides = 4, input_shape = self.env.observation_space.shape),
AttributeError: 'NoneType' object has no attribute 'observation_space'


In [4]:
test_process.recv_messages.put({'id': 'quit'})

In [8]:
def discount_rewards(rewards, gamma, standardized: bool = False):
    discounted_rewards = np.zeros_like(rewards)
    R = 0
    for t in reversed(range(len(rewards))):                   
        R = R * gamma + rewards[t]
        discounted_rewards[t] = R
    if standardized:
        mean = np.mean(discounted_rewards)
        discounted_rewards -= mean
        standard_deviation = np.std(discounted_rewards)
        discounted_rewards/=(standard_deviation + np.finfo(np.float32).eps) 
    return discounted_rewards

class Worker_Process(RL_Process):
    def __init__(self, *args, batch_size = 30, nn_process_queue = None, worker_id = 0, max_timesteps = 100, gamma = 0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.batch_size = batch_size
        self.max_timesteps = max_timesteps
        self.nn_process_queue = nn_process_queue
        self.worker_id = worker_id
        self.message_handlers.update(train = self.train_batch)
        self.message_handlers['train'] = self.train_batch
    
    def choose_action(self, obs):
        obs = np.expand_dims(obs, axis = 0)
        logits = self.actor.predict(obs)
        probablity_weights = tf.nn.softmax(logits = logits).numpy()[0]
        action = np.random.choice(env.action_space.n, 1, p = probablity_weights)[0]
        return action

    def estimate_value(self, obs):
        obs = np.expand_dims(obs, axis = 0)
        value = self.critic.predict(obs)
        return value
   
    def critic_loss(self, observations, rewards):
        huber_loss = self.tf.keras.losses.Huber(reduction=self.tf.keras.losses.Reduction.SUM)
        values = self.critic(observations)
        loss = huber_loss(values, rewards)
        return loss

    def actor_loss(self, actions, observations, values, rewards):
        advantage = rewards - values
        logits = self.actor(observations)
        negative_log_prob = self.tf.nn.sparse_softmax_cross_entropy_with_logits(logits = logits, labels = actions)
        loss = self.tf.reduce_mean(negative_log_prob*advantage)
        return loss
    
    def run_episode(self):
        obs = self.env.reset()
        observations = []
        values = []
        rewards = []
        actions = []
        for t in range(self.max_timesteps):
            action = self.choose_action(obs)
            value = self.estimate_value(obs)
            observations.append(obs)
            values.append(value)
            actions.append(action)
            obs, reward, done, info = self.env.step(action)
            rewards.append(reward)
            if done:
                break
        return observations, actions, values, rewards
    
    def run_batch(self, episodes):
        batch_observations = []
        batch_values = []
        batch_rewards = []
        batch_actions = []
        with tqdm(total = episodes, desc = 'Batch Progress') as progress_bar:
            for episode in range(episodes):
                observations, actions, values, rewards = run_episode()
                progress_bar.set_postfix_str(f'Episode Reward: {sum(rewards)}')
                progress_bar.update()
                batch_observations.extend(observations)
                batch_actions.extend(actions)
                batch_values.extend(values)
                rewards = discount_rewards(rewards, gamma = self.gamma, standardized = False)
                batch_rewards.extend(rewards)
        return batch_observations, batch_actions, batch_values, batch_rewards
    
    def train_step(self, observations, actions, values, rewards):
        values = np.array(values)
        rewards = np.array(rewards)
        observations = np.array(observations)
        # Step 1. Train actor using critic
        with tf.GradientTape() as tape:
            loss = self.actor_loss(actions, observations, values, rewards)
            actor_gradients = tape.gradient(loss, self.actor.trainable_variables)
            self.optimizer.apply_gradients(zip(actor_gradients, self.actor.trainable_variables))
 
        # Step 2. Train critic
        with tf.GradientTape() as tape:
            loss = self.critic_loss(observations, rewards)
            critic_gradients = tape.gradient(loss, self.critic.trainable_variables)
            self.optimizer.apply_gradients(zip(critic_gradients, self.critic.trainable_variables))
        
        return actor_gradients, critic_gradients
    
    def train_batch(self, message):
        batches = message['batches']
        mean_values = []
        mean_rewards = []
        episodes = []
        for batch in range(batches):
            observations, actions, values, rewards = self.run_batch(message['episodes'])
            episodes.append(batch * self.batch_size)
            mean_values.append(np.mean(values))
            mean_rewards.append(np.mean(rewards))
            actor_gradients, critic_gradients = self.train_step(observations, actions, values, rewards)
            new_message = dict(id = 'gradients', actor_gradients = actor_gradients, critic_gradients = critic_gradients)
            self.nn_process_queue.put(new_message)
        self.nn_process_queue.put(dict(id = 'complete', worker_id = self.worker_id))
        

In [6]:
class NN_Process(RL_Process):
    def __init__(self, *args, max_workers = 4, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_workers = max_workers
        self.workers = [Worker_Process(nn_process_queue=self.recv_messages, worker_id=i) for i in range(max_workers)]
        self.message_handlers['gradients'] = self.apply_gradients
        self.message_handlers['complete'] = self.worker_complete
    def apply_gradients(self, message):
        self.optimizer.apply_gradients(zip(message['actor_gradients'], self.actor.trainable_variables))
        self.optimizer.apply_gradients(zip(message['critic_gradients'], self.critic.trainable_variables))
        
    def worker_complete(self, message):
        worker_id = message['worker_id']
        if worker_id < self.max_workers:
            self.workers[worker_id].recv_messages.put(dict(id = 'quit'))
        else:
            raise ValueError(f'Worker ID {worker_id} is invalid')
    def start_training(self, total_episodes = 100, batch_size = 50):
        for worker in self.workers:
            if not worker.is_alive():
                worker.start()
        self.total_episodes = total_episodes
        self.batch_size = batch_size
        self.start()