In [1]:
import collections
import gym
import numpy as np
import tensorflow as tf
import tqdm

from matplotlib import pyplot as plt
from tensorflow.keras import layers
from typing import Any, List, Sequence, Tuple

env = gym.make('CartPole-v0')

seed = 42
env.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

eps = np.finfo(np.float32).eps.item()

In [14]:
class ActorCritic(tf.keras.Model):
    def __init__(self,
                 num_actions: int,
                 num_hidden_units: int):
        super().__init__()
        
        self.common = layers.Dense(num_hidden_units, activation='relu')
        self.actor = layers.Dense(num_actions)
        self.critic = layers.Dense(1)
        
    def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        x = self.common(inputs)
        return self.actor(x), self.critic(x)

In [74]:
env.step(1)

(array([-1.31111619,  0.15901837,  0.04051762, -0.37264538]),
 1.0,
 True,
 {'TimeLimit.truncated': True})

In [15]:
num_actions = env.action_space.n  #2
num_hidden_units=128

model = ActorCritic(num_actions, num_hidden_units)

In [44]:
def env_step(action: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # state, reward, done flag
    
    state, reward, done, _ = env.step(action)
    return (state.astype(np.float32),
            np.array(reward, np.int32),
            np.array(done, np.int32))

def tf_env_step(action: tf.Tensor) -> List[tf.Tensor]:
    return tf.numpy_function(env_step, [action],
                            [tf.float32, tf.int32, tf.int32])

In [82]:
tf.random.categorical([[0.5, 0.5]], 1)

<tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]], dtype=int64)>

In [45]:
def run_episode(initial_state: tf.Tensor,
                model: tf.keras.Model,
                max_steps: int) -> List[tf.Tensor]:
    action_probs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    values = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    rewards = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
    
    initial_state_shape = initial_state.shape
    state = initial_state
    
    for t in tf.range(max_steps):
        state = tf.expand_dims(state, 0)
        
        action_logits_t, value = model(state)
        
        action = tf.random.categorical(action_logits_t, 1)[0,0]
        action_probs_t = tf.nn.softmax(action_logits_t)
        
        values = values.write(t, tf.squeeze(value))
        
        action_probs = action_probs.write(t, action_probs_t[0, action])
        
        state, reward, done = tf_env_step(action)
        state.set_shape(initial_state_shape)
        
        rewards = rewards.write(t, reward)
        
        if tf.cast(done, tf.bool):
            break
            
    action_probs = action_probs.stack()
    values = values.stack()
    rewards = rewards.stack()
    
    return action_probs, values, rewards

In [46]:
def get_expected_return(rewards: tf.Tensor,
                        gamma: float,
                        standardize: bool = True) -> tf.Tensor:
    n = tf.shape(rewards)[0]
    returns = tf.TensorArray(dtype=tf.float32, size=n)
    
    rewards = tf.cast(rewards[::-1], dtype=tf.float32)
    discounted_sum = tf.constant(0.0)
    discounted_sum_shape = discounted_sum.shape
    for i in tf.range(n):
        reward = rewards[i]
        discounted_sum = reward + gamma * discounted_sum
        discounted_sum.set_shape(discounted_sum_shape)
        returns = returns.write(i, discounted_sum)
    returns = returns.stack()[::-1]
    
    if standardize:
        returns = ((returns - tf.math.reduce_mean(returns)) /
                   (tf.math.reduce_std(returns) + eps))
        
    return returns

In [52]:
huber_loss = tf.keras.losses.Huber(reduction = tf.keras.losses.Reduction.SUM)

def compute_loss(action_probs: tf.Tensor,
                 values: tf.Tensor,
                 returns: tf.Tensor) -> tf.Tensor:
    
    advantage = returns - values
    
    action_log_probs = tf.math.log(action_probs)
    actor_loss = -tf.math.reduce_sum(action_log_probs * advantage)
    
    critic_loss = huber_loss(values, returns)
    
    return actor_loss + critic_loss

In [55]:
optimizer = tf.keras.optimizers.Adam(learning_rate = 0.01)

@tf.function
def train_step(initial_state: tf.Tensor,
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               gamma: float,
               max_steps_per_episode: int) -> tf.Tensor:
    
    with tf.GradientTape() as tape:
        action_probs, values, rewards = run_episode(
            initial_state, model, max_steps_per_episode)
        
        returns = get_expected_return(rewards, gamma)
        
        action_probs, values, returns = [
            tf.expand_dims(x, 1) for x in [action_probs, values, returns]
        ]
        
        loss = compute_loss(action_probs, values, returns)
        
    grads = tape.gradient(loss, model.trainable_variables)
    
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    episode_reward = tf.math.reduce_sum(rewards)
    return episode_reward

In [57]:
%%time

max_episodes = 10000
max_steps_per_episode = 1000

reward_threshold = 195
running_reward = 0

gamma = 0.99

with tqdm.trange(max_episodes) as t:
    for i in t:
        initial_state = tf.constant(env.reset(), dtype=tf.float32)
        episode_reward = int(train_step(
            initial_state, model, optimizer, gamma, max_steps_per_episode))
        
        running_reward = episode_reward*0.01 + running_reward * .99
        
        t.set_description(f'Episode {i}')
        t.set_postfix(
            episode_reward=episode_reward, running_reward = running_reward)
        
        if i % 10 == 0:
            pass#print(f'Episode {i}: average reward: {avg_reward}')
        
        if running_reward > reward_threshold:
            break
            
print(f'\nSolved at episode {i}: average reward: {running_reward:.2f}!')

Episode 625:   6%|█▍                     | 625/10000 [09:30<2:22:36,  1.10it/s, episode_reward=200, running_reward=195]


Solved at episode 625: average reward: 195.00!
Wall time: 9min 30s



