## Imports

In [30]:
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
from collections import namedtuple, deque
import time
import gym
from gym.wrappers import RescaleAction
import random

## Actor and critic

In [22]:
def make_MLP(num_in, num_out, final_activation, hidden_dimensions=(256, 256)):
    """Build a fully connected ``keras.Sequential`` network.

    Layout: ``Input(num_in)`` -> one ReLU ``Dense`` per entry of
    ``hidden_dimensions`` -> optional ``Dense(num_out)``.  The LAST ``Dense``
    layer uses ``final_activation`` instead of ReLU (``None`` means linear).

    Args:
        num_in: size of the flat input vector, excluding the batch dimension.
        num_out: width of the output layer, or ``None`` to end the network at
            the last hidden layer.
        final_activation: Keras activation for the last ``Dense`` layer, or
            ``None`` for no activation.
        hidden_dimensions: iterable of hidden-layer widths, or ``None``.

    Returns:
        A ``keras.Sequential`` model whose input shape is ``(None, num_in)``.

    Raises:
        ValueError: if there are neither hidden layers nor an output layer,
            because the model would then contain no ``Dense`` layer at all.
    """
    layer_widths = list(hidden_dimensions) if hidden_dimensions is not None else []
    if num_out is not None:
        layer_widths.append(num_out)
    if not layer_widths:
        raise ValueError("make_MLP needs at least one hidden layer or an output layer")

    # Keras documents `shape` as a tuple that excludes the batch dimension.
    list_of_layers = [tf.keras.Input(shape=(num_in,))]
    for width in layer_widths[:-1]:
        list_of_layers.append(tf.keras.layers.Dense(width, activation='relu'))
    # activation=None is Dense's default (linear), so it can be forwarded as-is.
    list_of_layers.append(tf.keras.layers.Dense(layer_widths[-1], activation=final_activation))

    return keras.Sequential(list_of_layers)


class MLPTanhActor(keras.Model):
    """Policy network mapping a batch of states to actions in [-1, 1].

    The bound comes from the ``tanh`` activation on the output layer.
    """
    def __init__(self, input_dim, action_dim):
        super().__init__()
        self.net = make_MLP(
            num_in=input_dim,
            num_out=action_dim,
            final_activation='tanh',
        )
        # Build right away so the weights exist immediately after
        # construction (needed e.g. before calling `set_weights` on a copy).
        batch_input_shape = (None, input_dim)
        self.build(input_shape=batch_input_shape)

    def call(self, states: tf.Tensor):
        actions = self.net(states)
        return actions


class MLPCritic(keras.Model):
    """Q-network: scores each (state, action) row with a single linear output."""

    def __init__(self, input_dim, action_dim):
        super().__init__()
        joint_dim = input_dim + action_dim  # states and actions are concatenated
        self.net = make_MLP(num_in=joint_dim, num_out=1, final_activation=None)
        # Build eagerly so the weights exist right after construction.
        self.build(input_shape=(None, joint_dim))

    def call(self, states_and_actions: tuple):
        # Join the parts along the feature (last) axis before the MLP.
        joined = tf.concat(states_and_actions, axis=-1)
        return self.net(joined)

## Algorithm

In [23]:
def polyak_update(targ_net: keras.Model, pred_net: keras.Model, polyak: float) -> None:
    """Soft-update ``targ_net`` in place towards ``pred_net``.

    For each weight: ``targ <- (1 - polyak) * pred + polyak * targ``.
    Weights are matched by position in ``.weights``, so both networks must
    have the same architecture.
    """
    online_weights = pred_net.weights
    target_weights = targ_net.weights
    for idx, online_w in enumerate(online_weights):
        target_w = target_weights[idx]
        blended = tf.scalar_mul(1 - polyak, online_w) + tf.scalar_mul(polyak, target_w)
        target_w.assign(blended)

def save_net(net: keras.Model, save_dir: str, save_name: str) -> None:
    """Write ``net``'s weights to the file ``save_name`` inside ``save_dir``."""
    weights_path = os.path.join(save_dir, save_name)
    net.save_weights(weights_path)

def load_net(net: keras.Model, save_dir: str, save_name: str) -> None:
    """Restore ``net``'s weights from the file ``save_name`` inside ``save_dir``."""
    weights_path = os.path.join(save_dir, save_name)
    net.load_weights(weights_path)

In [36]:
ERROR_DATA = ["shape mismatch"]

class TD3:
    """Twin Delayed DDPG (TD3): one actor, two critics, and a target copy of each.

    Actions live in [-1, 1]: the actor ends in ``tanh`` and every noisy action
    is clipped back into that range.
    """

    def __init__(
        self,
        input_dim,
        action_dim,
        gamma=0.99,
        lr=3e-4,
        lr_schedule=None,
        polyak=0.995,
        action_noise=0.1,  # standard deviation of action noise
        target_noise=0.2,  # standard deviation of target smoothing noise
        noise_clip=0.5,  # max abs value of target smoothing noise
        policy_delay=2
    ):
        """
        Args:
            input_dim: size of the flat state vector.
            action_dim: size of the action vector.
            gamma: discount factor used in the TD targets.
            lr: Adam learning rate for the actor and both critics.
            lr_schedule: stored on the instance; NOTE(review): not used
                anywhere in this class.
            polyak: fraction of the old target weights kept per soft update.
            action_noise: std of Gaussian exploration noise in ``act``.
            target_noise: std of the target-policy smoothing noise.
            noise_clip: smoothing noise is clipped to [-noise_clip, noise_clip].
            policy_delay: actor and target networks are updated once every
                ``policy_delay`` critic updates.
        """

        # hyper-parameters

        self.input_dim = input_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.lr = lr
        self.lr_schedule = lr_schedule
        self.polyak = polyak

        self.action_noise = action_noise
        self.target_noise = target_noise
        self.noise_clip = noise_clip

        self.policy_delay = policy_delay

        # trackers

        self.num_Q_updates = tf.Variable(0)  # for delaying updates
        self.mean_Q1_value = tf.Variable(0, dtype=tf.float32)  # for logging; the actor does not get updated every iteration,
        # so this statistic is not available every iteration

        # networks
        # (keras.models.clone_model cannot be used for subclassed models)
        # (weirdly, the weights must be converted to numpy for set_weights to work)

        self.actor = MLPTanhActor(input_dim, action_dim)
        self.actor_targ = MLPTanhActor(input_dim, action_dim)
        self.actor_targ.set_weights([w.numpy() for w in self.actor.weights])

        self.Q1 = MLPCritic(input_dim, action_dim)
        self.Q1_targ = MLPCritic(input_dim, action_dim)
        self.Q1_targ.set_weights([w.numpy() for w in self.Q1.weights])

        self.Q2 = MLPCritic(input_dim, action_dim)
        self.Q2_targ = MLPCritic(input_dim, action_dim)
        self.Q2_targ.set_weights([w.numpy() for w in self.Q2.weights])

        # optimizers

        self.actor_optimizer = keras.optimizers.Adam(learning_rate=lr)
        self.Q1_optimizer = keras.optimizers.Adam(learning_rate=lr)
        self.Q2_optimizer = keras.optimizers.Adam(learning_rate=lr)

    @tf.function
    def act(self, state: np.array, deterministic: bool) -> np.array:
        """Return an action (1-D tensor of length ``action_dim``) for one unbatched state.

        With ``deterministic=False``, Gaussian noise with std ``action_noise``
        is added and the result is clipped to [-1, 1].
        """
        state_with_batch_dim = tf.reshape(state, (1, -1))
        greedy_action = tf.reshape(self.actor(state_with_batch_dim), (-1, ))
        if deterministic:
            return greedy_action
        # Must use TF's RNG: inside tf.function, np.random would run only once
        # at trace time and bake a single constant "noise" vector into the graph.
        noise = tf.random.normal(
            tf.shape(greedy_action), stddev=self.action_noise, dtype=greedy_action.dtype
        )
        return tf.clip_by_value(greedy_action + noise, -1.0, 1.0)

    @tf.function
    def update_networks(self, b, debug=False):
        """Run one TD3 training step on batch ``b``.

        Both critics are updated on every call. The actor and all target
        networks are updated only when the update counter is a multiple of
        ``policy_delay``.

        Args:
            b: a Batch with fields s, a, r, ns, d.  Assumed to hold 2-D float32
                tensors with r and d of shape (batch, 1), as produced by
                ReplayBuffer.sample — TODO confirm for other callers.
            debug: if True, add shape assertions to the graph.

        Returns:
            Dict of scalar logging statistics. '(actor) Q1 value' is the value
            recorded at the most recent actor update.
        """

        if debug:
            bs = b.ns.shape[0]  # for shape checking

        # compute targets (outside any tape: no gradients flow into target nets)

        na = self.actor_targ(b.ns)
        # Target policy smoothing; tf.shape keeps this valid for dynamic batch sizes.
        noise = tf.clip_by_value(
            tf.random.normal(tf.shape(na), stddev=self.target_noise, dtype=na.dtype),
            -self.noise_clip, self.noise_clip
        )
        smoothed_na = tf.clip_by_value(na + noise, -1, 1)

        # Clipped double-Q: bootstrap from the smaller of the two target critics.
        n_min_Q_targ = tf.math.minimum(self.Q1_targ((b.ns, smoothed_na)), self.Q2_targ((b.ns, smoothed_na)))

        targets = b.r + self.gamma * (1 - b.d) * n_min_Q_targ

        if debug:
            tf.Assert(na.shape == (bs, self.action_dim), ERROR_DATA)
            tf.Assert(n_min_Q_targ.shape == (bs, 1), ERROR_DATA)
            tf.Assert(targets.shape == (bs, 1), ERROR_DATA)

        # persistent=True: the same tape yields gradients for both critics.
        with tf.GradientTape(persistent=True) as tape:

            # compute predictions

            Q1_predictions = self.Q1((b.s, b.a))
            Q2_predictions = self.Q2((b.s, b.a))

            # compute td error

            Q1_loss = tf.reduce_mean((Q1_predictions - targets) ** 2)
            Q2_loss = tf.reduce_mean((Q2_predictions - targets) ** 2)

        if debug:
            tf.Assert(Q1_loss.shape == (), ERROR_DATA)
            tf.Assert(Q2_loss.shape == (), ERROR_DATA)

        # reduce td error

        Q1_gradients = tape.gradient(Q1_loss, self.Q1.trainable_weights)
        self.Q1_optimizer.apply_gradients(zip(Q1_gradients, self.Q1.trainable_weights))

        Q2_gradients = tape.gradient(Q2_loss, self.Q2.trainable_weights)
        self.Q2_optimizer.apply_gradients(zip(Q2_gradients, self.Q2.trainable_weights))

        self.num_Q_updates.assign_add(1)

        if self.num_Q_updates % self.policy_delay == 0:  # delayed policy update

            # compute policy loss (maximize Q1 of the actor's own actions)

            with tf.GradientTape() as tape:

                a = self.actor(b.s)
                Q1_values = self.Q1((b.s, a))
                policy_loss = - tf.reduce_mean(Q1_values)

            self.mean_Q1_value.assign(-policy_loss)  # logging purpose only
            if debug:
                tf.Assert(a.shape == (bs, self.action_dim), ERROR_DATA)
                tf.Assert(Q1_values.shape == (bs, 1), ERROR_DATA)
                tf.Assert(policy_loss.shape == (), ERROR_DATA)

            # reduce policy loss (only the actor's weights are updated)

            policy_gradients = tape.gradient(policy_loss, self.actor.trainable_weights)
            self.actor_optimizer.apply_gradients(zip(policy_gradients, self.actor.trainable_weights))

            # update target networks

            polyak_update(targ_net=self.actor_targ, pred_net=self.actor, polyak=self.polyak)
            polyak_update(targ_net=self.Q1_targ, pred_net=self.Q1, polyak=self.polyak)
            polyak_update(targ_net=self.Q2_targ, pred_net=self.Q2, polyak=self.polyak)

        return {
            # for learning the q functions
            '(qfunc) Q1 pred': tf.reduce_mean(Q1_predictions),
            '(qfunc) Q2 pred': tf.reduce_mean(Q2_predictions),
            '(qfunc) Q1 loss': Q1_loss,
            '(qfunc) Q2 loss': Q2_loss,
            # for learning the actor
            '(actor) Q1 value': self.mean_Q1_value
        }

    def save_actor(self, save_dir: str) -> None:
        """Save only the actor's weights to ``save_dir/actor.h5`` (creating the dir)."""
        os.makedirs(save_dir, exist_ok=True)
        save_net(net=self.actor, save_dir=save_dir, save_name="actor.h5")

    def load_actor(self, save_dir: str) -> None:
        """Load the actor's weights from ``save_dir/actor.h5``."""
        load_net(net=self.actor, save_dir=save_dir, save_name="actor.h5")

## Replay buffer

In [37]:
# Both records share one field layout: state, action, reward, next state, done flag.
Batch = namedtuple('Batch', ['s', 'a', 'r', 'ns', 'd'])  # a stacked mini-batch
Transition = namedtuple('Transition', ['s', 'a', 'r', 'ns', 'd'])  # a single environment step

In [38]:
class ReplayBuffer:
    """Fixed-capacity FIFO replay buffer with uniform sampling (with replacement).

    Once ``capacity`` transitions are stored, each push overwrites the oldest
    one. Storage is a plain list used as a ring buffer, so the random indexing
    done by ``sample`` is O(1). Indexing into the middle of a
    ``collections.deque`` would be O(n).
    """

    def __init__(self, capacity=int(1e6), batch_size=100):
        """
        Args:
            capacity: maximum number of stored transitions; ``None`` means
                unbounded and ``0`` stores nothing.
            batch_size: number of transitions returned by ``sample``.

        Raises:
            ValueError: if ``capacity`` is negative.
        """
        if capacity is not None and capacity < 0:
            raise ValueError("capacity must be non-negative")
        self.capacity = capacity
        self.memory = []
        self.batch_size = batch_size
        self._oldest = 0  # slot overwritten by the next push once the buffer is full

    def push(self, s, a, r, ns, d) -> None:
        """Store one transition, evicting the oldest one when full."""
        transition = Transition(s, a, r, ns, d)
        if self.capacity is None or len(self.memory) < self.capacity:
            self.memory.append(transition)
        elif self.capacity > 0:
            # Full: overwrite the oldest entry and advance the ring pointer.
            self.memory[self._oldest] = transition
            self._oldest = (self._oldest + 1) % self.capacity
        # capacity == 0: nothing can be stored, so the transition is dropped.

    def is_ready(self):
        """Return True once at least ``batch_size`` transitions are stored."""
        return len(self.memory) >= self.batch_size

    def sample(self) -> Batch:
        """Draw ``batch_size`` transitions uniformly WITH replacement.

        Returns:
            A Batch of float32 tensors: s, a, ns reshaped to
            (batch_size, -1); r, d reshaped to (batch_size, 1).
        """
        transitions = random.choices(self.memory, k=self.batch_size)  # sampling WITH replacement
        batch_raw = Batch(*zip(*transitions))
        # actually, converting to tf tensor is not necessary here; could have just used numpy reshape and astype
        s = tf.reshape(tf.convert_to_tensor(batch_raw.s, dtype=tf.float32), (self.batch_size, -1))
        a = tf.reshape(tf.convert_to_tensor(batch_raw.a, dtype=tf.float32), (self.batch_size, -1))
        r = tf.reshape(tf.convert_to_tensor(batch_raw.r, dtype=tf.float32), (self.batch_size, 1))
        ns = tf.reshape(tf.convert_to_tensor(batch_raw.ns, dtype=tf.float32), (self.batch_size, -1))
        d = tf.reshape(tf.convert_to_tensor(batch_raw.d, dtype=tf.float32), (self.batch_size, 1))
        return Batch(s, a, r, ns, d)

## Speed test

Run 1000 calls to `update_networks` on a single fixed random batch; this should take no more than about 10 seconds.

In [39]:
algorithm = TD3(input_dim=2, action_dim=1)

batch_size = 100

# A single fixed batch of standard-normal data with no terminal transitions (d = 0).
# Fields are drawn in the order s, a, r, ns, d.
batch = Batch(
    s=np.random.randn(batch_size, 2).astype('float32'),
    a=np.random.randn(batch_size, 1).astype('float32'),
    r=np.random.randn(batch_size, 1).astype('float32'),
    ns=np.random.randn(batch_size, 2).astype('float32'),
    d=np.zeros((batch_size, 1)).astype('float32')
)

num_updates = 1000
start = time.perf_counter()
for _ in range(num_updates):
    stats = algorithm.update_networks(batch)
elapsed = time.perf_counter() - start
elapsed  # wall-clock seconds, shown as the cell's output

4.694819589999952

## Test on Pendulum-v0

In [41]:
env = RescaleAction(gym.make("Pendulum-v0"), -1, 1)  # agent acts in [-1, 1]
algorithm = TD3(input_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0])
buffer = ReplayBuffer()

num_epochs = 50  # each "epoch" is one environment episode
for epoch in range(num_epochs):
    start_time = time.perf_counter()
    obs = env.reset()
    ret = 0
    while True:
        action = algorithm.act(obs, deterministic=False).numpy()
        next_obs, reward, done, info = env.step(action)
        ret += reward
        # Only a genuine terminal state should stop bootstrapping. An episode
        # cut off by gym's TimeLimit wrapper (flagged via 'TimeLimit.truncated')
        # is not terminal, so it is stored with d=False.
        terminal = done and not info.get("TimeLimit.truncated", False)
        buffer.push(obs, action, reward, next_obs, terminal)
        if buffer.is_ready():
            algorithm.update_networks(buffer.sample())
        if done:
            break
        obs = next_obs
    print(epoch + 1, ret, time.perf_counter() - start_time)

1 -1377.7201519550424 1.794488460000025
2 -1529.3996065565661 1.0556090200000199
3 -1456.6756359985886 1.0967808179999565
4 -1344.584958734682 1.1548143249999612
5 -1539.4519860086668 1.1703114599999935
6 -1208.306171330558 1.0985516450000432
7 -1222.7494143507333 1.1925283159999935
8 -1258.2740468115094 1.2040690029999723
9 -1298.5261874564592 1.2209370260000014
10 -1289.4815854038097 1.2726620660000663
11 -1602.7285927115747 1.1414879430000155
12 -1329.5986483250476 1.0790147920000663
13 -1267.603610474868 1.0285639100000026
14 -1557.8745009175545 1.1278508709999642
15 -1060.1218081699435 1.0668025140000736
16 -1068.9396667373805 1.0879064770000468
17 -1302.2124557656587 1.1822623059999842
18 -1098.2811893056407 1.2260191759999088
19 -1082.3421652002253 1.2925586159999511
20 -934.9824548898551 1.0865474239999457
21 -895.2772368792777 1.255021227000043
22 -747.08303465816 1.1789143299999978
23 -780.4148414479563 1.173073058
24 -912.0859684696242 1.0808101179999312
25 -656.359061961861

In [42]:
algorithm.save_actor("./saved_models/pendulum")  # persists only the actor's weights

In [43]:
algorithm.load_actor("./saved_models/pendulum")  # restores only the actor's weights

In [49]:
# Watch the trained policy act greedily: exploration noise is for training
# only, so evaluation uses deterministic=True.
obs = env.reset()
while True:
    action = algorithm.act(obs, deterministic=True).numpy()
    next_obs, reward, done, info = env.step(action)
    env.render()
    if done:
        break
    obs = next_obs