In [None]:
import numpy as np
import torch
import torch.nn as nn
import gym

In [None]:
# Flip to False to force CPU execution even when a GPU is present.
use_cuda_if_available = True
use_cuda = torch.cuda.is_available() if use_cuda_if_available else False
# Single device used by every tensor and module created in this notebook.
DEVICE = torch.device("cuda" if use_cuda else "cpu")
print(f"Using CUDA: {use_cuda}")

Using CUDA: True


In [None]:
from contextlib import contextmanager

def as_tensor(data: np.ndarray, dtype=torch.float32, batch: bool = False) -> torch.Tensor:
  """Convert `data` to a `dtype` tensor on DEVICE via torch.as_tensor.

  With batch=True, a 1-D result gets a leading batch dimension of size 1,
  so a single observation can be fed to a network expecting (batch, dim).
  Tensors of any other rank are returned unchanged.
  """
  result = torch.as_tensor(data, dtype=dtype, device=DEVICE)
  add_batch_dim = batch and result.dim() == 1
  return result.unsqueeze(0) if add_batch_dim else result

@contextmanager
def eval_mode(net: nn.Module):
  """Temporarily put `net` in eval mode with autograd disabled.

  Yields the same module. On exit, training mode is re-enabled only if the
  module was in training mode on entry; an already-eval module is left as is.
  """
  restore_training = net.training
  try:
    net.eval()
    with torch.no_grad():
      yield net
  finally:
    if restore_training:
      net.train(mode=True)

In [None]:
import random
import os

# PyTorch's deterministic mode requires this cuBLAS workspace setting on CUDA
# (see the PyTorch reproducibility notes); it must be set before cuBLAS is used.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def set_seeds(seed=42, env: gym.Env = None):
  """Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

  Also enables torch.use_deterministic_algorithms. When an env is given, its
  own RNG and its action space's sampler are seeded too (old gym `env.seed` API).
  """
  for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
  torch.use_deterministic_algorithms(mode=True)
  if env:
    env.seed(seed)
    env.action_space.seed(seed)

In [None]:
from collections import namedtuple
from typing import Tuple

# One environment step: produced by make_episode, stored by ReplayBuffer, and
# returned (as batched tensors) by ReplayBuffer.sample.
Transition = namedtuple("Transition", [
    "state", "action", "reward", "next_state", "done"])

class ReplayBuffer:
  """Fixed-capacity circular buffer of transitions, stored as tensors on DEVICE.

  Each field lives in a preallocated float32 tensor of shape (capacity, dim).
  States and actions are flattened on insertion; rewards and done flags take
  one column each. Once the buffer is full, the oldest slot is overwritten.
  """
  def __init__(self, env: gym.Env, capacity: int):
    if capacity <= 0:
      raise ValueError(f"capacity must be positive, got {capacity}")
    self.size = 0
    self.capacity = capacity
    self.cur_slot = 0

    # np.prod(()) is the *float* 1.0 for scalar spaces (e.g. Discrete), and
    # torch.empty only accepts integer sizes, so cast explicitly.
    num_states = int(np.prod(env.observation_space.shape))
    num_actions = int(np.prod(env.action_space.shape))

    def make_buf(dim):
      # Allocate directly on DEVICE instead of on the CPU followed by a copy.
      return torch.empty(self.capacity, dim, dtype=torch.float32, device=DEVICE)

    self.state_buf = make_buf(num_states)
    self.action_buf = make_buf(num_actions)
    self.reward_buf = make_buf(1)
    self.next_state_buf = make_buf(num_states)
    self.done_buf = make_buf(1)

  def record(self, tr: Transition):
    """Write one transition into the next slot, overwriting the oldest when full."""
    slot = self.cur_slot
    self.state_buf[slot] = as_tensor(tr.state).view(-1)
    self.action_buf[slot] = as_tensor(tr.action).view(-1)
    self.reward_buf[slot] = tr.reward
    self.next_state_buf[slot] = as_tensor(tr.next_state).view(-1)
    self.done_buf[slot] = tr.done

    self.size = min(self.size + 1, self.capacity)
    self.cur_slot = (slot + 1) % self.capacity

  def sample(self, batch_size: int) -> Transition:
    """Draw `batch_size` stored transitions uniformly at random, with replacement.

    Returns:
      A Transition whose fields are batched tensors of shape (batch_size, dim).
    Raises:
      ValueError: if nothing has been recorded yet.
    """
    if self.size == 0:
      raise ValueError("cannot sample from an empty ReplayBuffer")
    indices = torch.randint(self.size, (batch_size,))
    return Transition(self.state_buf[indices], self.action_buf[indices],
                      self.reward_buf[indices], self.next_state_buf[indices],
                      self.done_buf[indices])

In [None]:
class OUNoise:
  """Scalar Ornstein-Uhlenbeck process, simulated with Euler-Maruyama steps.

  Each step applies x += theta * (mean - x) * dt + std * sqrt(dt) * N(0, 1),
  drawing N(0, 1) from NumPy's global RNG. The state starts at `mean` and is
  never reset by this class.
  """
  def __init__(self, mean=0, std=0.2, theta=0.15, dt=1e-2):
    self.mean, self.std, self.theta, self.dt = mean, std, theta, dt
    self.x = mean

  def step(self, n=1):
    """Advance the process n times; return the state seen *before* each update.

    Returns a length-n array when n > 1, otherwise a single scalar.
    """
    increments = np.random.normal(size=n)
    samples = np.empty(n)
    root_dt = np.sqrt(self.dt)
    for i, w_i in enumerate(increments):
      samples[i] = self.x
      drift = self.theta * (self.mean - self.x) * self.dt
      diffusion = self.std * root_dt * w_i
      self.x += drift + diffusion
    if n > 1:
      return samples
    return samples[0]

In [None]:
import torch.nn as nn
import torch.nn.functional as func

def unif_init(fc: nn.Linear, a: float):
  """Overwrite fc's weight, then its bias, in place with samples from U(-a, a)."""
  for tensor in (fc.weight.data, fc.bias.data):
    nn.init.uniform_(tensor, -a, a)

def fanin_init(fc: nn.Linear):
  """Initialise fc's weight and bias from U(-1/sqrt(f), 1/sqrt(f)), f = fan-in.

  nn.Linear stores its weight as (out_features, in_features), so the fan-in
  is dimension 1 of the weight, not dimension 0 (which is the fan-out).
  """
  fan_in = fc.weight.data.size(1)
  unif_init(fc, 1 / np.sqrt(fan_in))

class Actor(nn.Module):
  """Deterministic DDPG policy network mu(s).

  BatchNorm is applied to the state input and after each hidden ReLU. A tanh
  output is mapped affinely onto the env's Box bounds, so every action lies
  within [low, high].
  """
  def __init__(self, env: gym.Env, hidden1=400, hidden2=300,
               final_layer_a=3e-3):
    super(Actor, self).__init__()

    self.env = env
    self.num_states = np.prod(env.observation_space.shape)
    self.num_actions = np.prod(env.action_space.shape)

    # Centre and half-width of the action box, used to rescale tanh's (-1, 1).
    # NOTE: plain attributes rather than registered buffers, so Module.to()
    # does not move them; as_tensor already places them on DEVICE.
    self.action_loc = (env.action_space.high + env.action_space.low) / 2.0
    self.action_loc = as_tensor(self.action_loc)

    self.action_scale = (env.action_space.high - env.action_space.low) / 2.0
    self.action_scale = as_tensor(self.action_scale)

    self.bn0 = nn.BatchNorm1d(self.num_states)
    self.fc1 = nn.Linear(self.num_states, hidden1)
    self.bn1 = nn.BatchNorm1d(hidden1)
    self.fc2 = nn.Linear(hidden1, hidden2)
    self.bn2 = nn.BatchNorm1d(hidden2)
    self.final = nn.Linear(hidden2, self.num_actions)

    fanin_init(self.fc1)
    fanin_init(self.fc2)
    unif_init(self.final, final_layer_a)

  def forward(self, s):
    """Map a batch of flattened states to actions inside the action bounds.

    Assumes `s` is (batch, num_states), as BatchNorm1d requires.
    """
    out = self.bn0(s)
    out = self.bn1(func.relu(self.fc1(out)))
    out = self.bn2(func.relu(self.fc2(out)))
    out = torch.tanh(self.final(out))
    out = self.action_scale * out + self.action_loc
    return out

  def policy(self, s, noise=None):
    """Choose an action for a single, unbatched observation `s`.

    Runs in eval mode without gradients. If `noise` is given (a scalar or an
    array broadcastable to the action), it is added and the result is clipped
    back into the action bounds. Returns a NumPy array shaped like
    env.action_space.
    """
    with eval_mode(self):
      s = as_tensor(np.ravel(s), batch=True)
      action = self.forward(s)[0]
      action = action.cpu().numpy()
      # Test `is not None` rather than truthiness. A noise sample of exactly
      # 0.0 (e.g. OUNoise's first sample) must still take this path, and a
      # multi-element noise array has no unambiguous truth value.
      if noise is not None:
        action += noise
        action = np.clip(action, self.env.action_space.low, self.env.action_space.high)
      return np.reshape(action, self.env.action_space.shape)

class Critic(nn.Module):
  """DDPG action-value network Q(s, a), producing one value per row.

  The state is batch-normalised and passed through the first hidden layer.
  The action is concatenated only at the second hidden layer; output shape is
  (batch, 1).
  """
  def __init__(self, env: gym.Env, hidden1=400, hidden2=300,
               final_layer_a=3e-3):
    super().__init__()

    self.env = env
    state_dim = np.prod(env.observation_space.shape)
    action_dim = np.prod(env.action_space.shape)

    self.bn0 = nn.BatchNorm1d(state_dim)
    self.fc1 = nn.Linear(state_dim, hidden1)
    self.bn1 = nn.BatchNorm1d(hidden1)
    self.fc2 = nn.Linear(hidden1 + action_dim, hidden2)
    self.final = nn.Linear(hidden2, 1)

    for hidden_layer in (self.fc1, self.fc2):
      fanin_init(hidden_layer)
    unif_init(self.final, final_layer_a)

  def forward(self, s, a):
    state_features = self.bn1(func.relu(self.fc1(self.bn0(s))))
    joint = torch.cat((state_features, a), dim=1)
    return self.final(func.relu(self.fc2(joint)))

In [None]:
import itertools
from typing import Iterable

def make_episode(env: gym.Env, action_source) -> Iterable[Transition]:
  """Lazily play one episode of `env`, yielding a Transition per step.

  `action_source(state)` chooses each action. This uses the old gym step API
  (4-tuple from env.step, bare observation from env.reset).

  The yielded `done` flag marks true termination only. When gym's TimeLimit
  wrapper cuts the episode short (info["TimeLimit.truncated"]), the
  transition is yielded with done=False so value targets still bootstrap
  from next_state. The episode itself still stops after that step.
  """
  state = env.reset()
  while True:
    action = action_source(state)
    next_state, reward, done, info = env.step(action)
    truncated = bool(info.get("TimeLimit.truncated", False))
    yield Transition(state, action, reward, next_state, done and not truncated)
    if done:
      break
    state = next_state

def create_replay_buffer(env: gym.Env, capacity: int, prefill: int):
  """Build a ReplayBuffer and fill it with `prefill` random-policy transitions.

  Actions come from env.action_space.sample(). Episodes are played back to
  back, and filling stops as soon as the buffer holds `prefill` transitions,
  possibly mid-episode.

  Raises:
    ValueError: if `prefill` exceeds `capacity`. The buffer size can never
      exceed its capacity, so filling would otherwise never finish.
  """
  if prefill > capacity:
    raise ValueError(
        f"prefill ({prefill}) cannot exceed replay buffer capacity ({capacity})")
  replay_buffer = ReplayBuffer(env, capacity)

  def random_action_source(state):
    return env.action_space.sample()

  while replay_buffer.size < prefill:
    for transition in make_episode(env, random_action_source):
      replay_buffer.record(transition)
      if replay_buffer.size >= prefill:
        break

  return replay_buffer

def update_target(target_net: nn.Module, source_net: nn.Module, tau: float):
  """Soft-update target_net towards source_net in place.

  Floating-point parameters and buffers become
  tau * source + (1 - tau) * target. Non-floating buffers (e.g. BatchNorm's
  integer `num_batches_tracked`) cannot be averaged meaningfully, so they are
  copied from the source. Entries present in only one of the two state dicts
  are skipped.
  """
  source_state = source_net.state_dict()
  with torch.no_grad():
    # state_dict() returns references that share storage with the module's
    # parameters and buffers, so in-place updates modify target_net directly.
    for name, target_tensor in target_net.state_dict().items():
      source_tensor = source_state.get(name)
      if source_tensor is None:
        continue
      if target_tensor.is_floating_point():
        target_tensor.mul_(1 - tau).add_(source_tensor, alpha=tau)
      else:
        target_tensor.copy_(source_tensor)

In [None]:
import copy

# Record handed to train()'s `on_step` callback after every environment step.
StepStats = namedtuple("StepStats", [
    "overall_step", "episode_num", "step", "transition",
    "step_actor_loss", "step_critic_loss"])

# Record handed to train()'s `on_episode_end` callback after every episode.
EpisodeStats = namedtuple("EpisodeStats", ["episode_num", "reward"])

def train(env: gym.Env, actor: Actor, critic: Critic, actor_optim, critic_optim,
          noise: OUNoise, replay_buffer: ReplayBuffer, batch_size, gamma, tau,
          on_step=None, on_episode_end=None):
  """Run DDPG on `env` until a callback asks to stop.

  Each environment step does the following, in order:
    1. The transition is recorded into `replay_buffer`.
    2. One minibatch of `batch_size` updates the critic towards TD targets
       (discount `gamma`) computed with frozen target copies of both networks.
    3. The same minibatch updates the actor to maximise the critic's Q.
    4. Both target networks are soft-updated towards the online ones with
       rate `tau`.

  Args:
    on_step: optional callable taking a StepStats after every step; a truthy
      return stops training.
    on_episode_end: optional callable taking an EpisodeStats after every
      episode; a truthy return stops training.

  There is no other stopping condition. Without a callback that eventually
  returns truthy, this function loops forever.
  """
  # Target networks: only changed through update_target(), always in eval mode.
  target_actor = copy.deepcopy(actor)
  target_actor.eval()

  target_critic = copy.deepcopy(critic)
  target_critic.eval()

  def critic_loss(batch: Transition):
    # TD target r + gamma * (1 - done) * Q'(s', mu'(s')), with no gradient
    # flowing into the target networks.
    with torch.no_grad():
      next_action = target_actor(batch.next_state)
      next_q_value = target_critic(batch.next_state, next_action)
      target = batch.reward + gamma * (1 - batch.done) * next_q_value
    value_estimate = critic(batch.state, batch.action)
    return func.mse_loss(value_estimate, target)

  def actor_loss(batch: Transition):
    # Ascend Q(s, mu(s)) by minimising its negative mean.
    q_value = critic(batch.state, actor(batch.state))
    return -torch.mean(q_value)

  def update_step():
    """Do one critic step, one actor step, then the soft target updates.

    Returns (actor_loss, critic_loss) as Python floats.
    """
    batch = replay_buffer.sample(batch_size)

    critic.zero_grad()
    step_critic_loss = critic_loss(batch)
    step_critic_loss.backward()
    critic_optim.step()

    # actor_loss also back-propagates into the critic's parameters. Those
    # gradients are discarded by critic.zero_grad() on the next update.
    actor.zero_grad()
    step_actor_loss = actor_loss(batch)
    step_actor_loss.backward()
    actor_optim.step()

    update_target(target_actor, actor, tau)
    update_target(target_critic, critic, tau)
    return step_actor_loss.item(), step_critic_loss.item()

  def action_source(state):
    # Exploration: one noise sample is drawn per environment step.
    return actor.policy(state, noise.step())

  overall_step = 0  # environment steps taken across all episodes
  for episode_num in itertools.count():
    episode_reward = 0
    for step, transition in enumerate(make_episode(env, action_source)):
      episode_reward += transition.reward
      replay_buffer.record(transition)
      step_actor_loss, step_critic_loss = update_step()

      if on_step:
        stats = StepStats(overall_step, episode_num, step, transition,
                          step_actor_loss, step_critic_loss)
        if on_step(stats):
          return
      overall_step += 1

    if on_episode_end:
      if on_episode_end(EpisodeStats(episode_num, episode_reward)):
        return

In [None]:
def benchmark(env: gym.Env, actor: Actor, total_episodes=64):
  """Return the mean undiscounted episode reward of the noise-free policy.

  All RNGs are seeded via set_seeds() so scores are comparable across calls.
  The global Python/NumPy/PyTorch RNG states are saved beforehand and
  restored afterwards. This keeps a mid-training benchmark from resetting the
  exploration-noise and replay-sampling streams to the same seed every time.

  NOTE(review): set_seeds() also reseeds `env` itself, and that cannot be
  undone here. Pass a dedicated evaluation env to keep the training env's RNG
  untouched.
  """
  py_state = random.getstate()
  np_state = np.random.get_state()
  torch_state = torch.get_rng_state()
  cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

  def action_source(state):
    return actor.policy(state)

  try:
    set_seeds(env=env)
    episode_rewards = np.empty(total_episodes)
    for episode_num in range(total_episodes):
      episode_rewards[episode_num] = sum(
          transition.reward for transition in make_episode(env, action_source))
    return np.mean(episode_rewards)
  finally:
    random.setstate(py_state)
    np.random.set_state(np_state)
    torch.set_rng_state(torch_state)
    if cuda_states is not None:
      torch.cuda.set_rng_state_all(cuda_states)

In [None]:
def train_until_sufficient(env_id: str, req_score: float,
                           actor_lr: float = 1e-4, critic_lr: float = 1e-3):
  """Train a DDPG agent on `env_id` until its benchmark score exceeds `req_score`.

  Hyperparameters follow the DDPG paper (Lillicrap et al., 2016):
    - Adam with learning rate 1e-4 for the actor and 1e-3 for the critic
      (these two were previously swapped).
    - L2 weight decay 1e-2 on the critic only.
    - gamma = 0.99, tau = 1e-3.
    - Minibatch size 64 and a 1e6-transition replay buffer.
  The learning rates are exposed as keyword arguments.

  Every `benchmark_period` episodes, benchmark() scores the noise-free policy
  over `benchmark_size` episodes. Training stops once that score beats
  `req_score`. The environment is closed on exit, including when training is
  interrupted.

  Returns:
    The trained Actor.
  """
  env = gym.make(env_id)
  try:
    actor = Actor(env).to(DEVICE)
    critic = Critic(env).to(DEVICE)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=actor_lr)
    critic_optim = torch.optim.Adam(critic.parameters(),
                                    lr=critic_lr, weight_decay=1e-2)
    noise = OUNoise()
    replay_buffer = create_replay_buffer(env, int(1e6), int(1e3))
    batch_size = 64
    gamma = 1-1e-2
    tau = 1e-3

    benchmark_period = 16
    benchmark_size = 128

    episode_rewards = []
    def on_episode_end(stats: EpisodeStats):
      episode_rewards.append(stats.reward)
      mean_reward = np.mean(episode_rewards[-32:])  # mean of the last 32 episodes
      print(f"Episode #{stats.episode_num}. Mean reward = {mean_reward}")
      if stats.episode_num > 0 and stats.episode_num % benchmark_period == 0:
        score = benchmark(env, actor, total_episodes=benchmark_size)
        print(f"[Benchmark] Score: {score}")
        return score > req_score  # a truthy return makes train() stop
      return False

    train(env, actor, critic, actor_optim, critic_optim, noise, replay_buffer,
          batch_size, gamma, tau, on_episode_end=on_episode_end)
  finally:
    env.close()

  return actor

In [None]:
# Train on Pendulum-v1 until the benchmark score beats -160; the returned actor is the cell's output.
train_until_sufficient(env_id="Pendulum-v1", req_score=-160)

Episode #0. Mean reward = -1203.1900336244596
Episode #1. Mean reward = -1249.3253867591625
Episode #2. Mean reward = -1357.321775632476
Episode #3. Mean reward = -1347.8823730187823
Episode #4. Mean reward = -1402.2891948453982
Episode #5. Mean reward = -1439.6136207705115
Episode #6. Mean reward = -1452.9114968540562
Episode #7. Mean reward = -1490.2073530738626
Episode #8. Mean reward = -1514.6910180677291
Episode #9. Mean reward = -1522.4171326960243
Episode #10. Mean reward = -1542.2024431460952
Episode #11. Mean reward = -1565.94386172617
Episode #12. Mean reward = -1571.2318761419067
Episode #13. Mean reward = -1587.86450397867
Episode #14. Mean reward = -1593.874195859919
Episode #15. Mean reward = -1606.3637045254845
Episode #16. Mean reward = -1619.1206480854248
[Benchmark] Score: -1770.766601195583
Episode #17. Mean reward = -1628.4890640863985
Episode #18. Mean reward = -1638.416776483169
Episode #19. Mean reward = -1648.4736423368347
Episode #20. Mean reward = -1656.142797

KeyboardInterrupt: ignored