# Setting things up to display videos

In [1]:
!pip install einops
from einops import rearrange, repeat
import torch as t
import torch.nn as nn
import numpy
import copy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# from https://colab.research.google.com/drive/1flu31ulJlgiRL1dnN2ir8wGh9p7Zij2t#scrollTo=odNaDE1zyrL2

#remove " > /dev/null 2>&1" to see what is going on under the hood
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!apt-get update > /dev/null 2>&1
!apt-get install cmake > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1
!pip install gym[box2d] > /dev/null 2>&1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setuptools
  Downloading setuptools-62.3.2-py3-none-any.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 4.3 MB/s 
[?25hInstalling collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 57.4.0
    Uninstalling setuptools-57.4.0:
      Successfully uninstalled setuptools-57.4.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m
Successfully installed setuptools-62.3.2


In [2]:
import gym
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40) #error only
import tensorflow as tf
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import math
import glob
import io
import base64
from IPython.display import HTML

from IPython import display as ipythondisplay

from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

"""
Utility functions to enable video recording of gym environment and displaying it
To enable video, just do "env = wrap_env(env)""
"""

def show_video():
  mp4list = glob.glob('video/*.mp4')
  if len(mp4list) > 0:
    mp4 = mp4list[0]
    video = io.open(mp4, 'r+b').read()
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
  else: 
    print("Could not find video")
    

def wrap_env(env):
  env = Monitor(env, './video', force=True)
  return env

In [3]:
def show_agent(agent, env):
  env = wrap_env(env)
  obs = env.reset()
  done = False
  while not done:
    env.render(mode='rgb_array')
    obs, _, done, _ = env.step(agent(obs, env))
  env.close()
  show_video()

def random_agent(obs, env):
  return env.action_space.sample()

show_agent(random_agent, gym.make('CartPole-v1'))


# Making an agent

In [5]:
class MLP(nn.Module):
  def __init__(self, in_size, hidden_size, out_size):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, out_size)
    )
  
  def forward(self, x):
    if isinstance(x, numpy.ndarray): x = t.from_numpy(x).float()
    return self.net(x)

In [6]:
def make_choice(env, eps, net, obs, device=t.device("cpu")):
  if t.rand(1) < eps: return env.action_space.sample()
  if isinstance(obs, numpy.ndarray): obs = t.from_numpy(obs).float()
  obs = obs.to(device)
  with t.no_grad():
    return t.argmax(net(obs)).item()

def evaluate(model, env, eps=0.3, device=t.device("cpu")):
  obs = env.reset()
  model.eval()
  total_reward = 0.
  done = False
  while not done:
    action = make_choice(env, eps, model, obs, device=device)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
  return total_reward

In [7]:
def to_agent(model):
  model.to(t.device("cpu"))
  def act(obs, env):
    with t.no_grad():
      return t.argmax(model(obs)).item()
  return act

# Training

In [8]:
def get_linear_decay_eps(eps_start, eps_end, n_steps):
  return lambda x: eps_start + (eps_end - eps_start) / n_steps * x

In [None]:
def train(model, env, eps_fn=lambda x: 0.05, 
          lr=1e-3, gamma=.98, 
          n_steps=20000, train_freq=16, mem_size=10000, batch_size=128,
          train_on="cpu"):
  if train_on == "cuda":
    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
  else: device = "cpu"
  model.to(device)
  model.train()

  optimizer = t.optim.Adam(model.parameters(), lr=lr)
  loss_fn = nn.MSELoss()

  # instead of experience buffer, tensors storing states, rewards, dones
  obss = t.zeros(mem_size, *env.observation_space.shape)
  acts = t.zeros(mem_size, dtype=t.long)
  rwds = t.zeros(mem_size)
  dones= t.zeros(mem_size, dtype=t.bool)

  obs = env.reset()
  obss[0] = t.from_numpy(obs).float()
  done = False

  for step in range(n_steps):

    # take an action and record info
    act = make_choice(env, eps_fn(step), model, obs, device=device)
    obs, rwd, done, _ = env.step(act)
    acts[step % mem_size] = act
    rwds[step % mem_size] = rwd
    dones[step % mem_size] = done

    # record next observation
    if done: 
      obs = env.reset()
    obss[(step + 1) % mem_size] = t.from_numpy(obs).float()

    # every train_freq steps, sample experiences and update the model
    if (step + 1) % train_freq == 0:
      optimizer.zero_grad()
      batch_indices = t.randperm(min(mem_size, step + 1))[:batch_size]
      obss_batch = obss[batch_indices].to(device)
      acts_batch = acts[batch_indices].to(device)
      rwds_batch = rwds[batch_indices].to(device)
      dones_batch=dones[batch_indices].to(device)
      next_states = obss[(batch_indices + 1) % mem_size].to(device)

      with t.no_grad():
        next_state_qs = t.max(model(next_states), dim=-1).values
      target = rwds_batch + ~dones_batch * gamma * next_state_qs
      model_out = model(obss_batch)[t.arange(batch_indices.size(0), device=device), acts_batch]
      loss = loss_fn(model_out, target)
      loss.backward()
      optimizer.step()
    
    if step % 2000 == 0:
      print(evaluate(model, env, eps=0., device=device))
      model.train()


In [None]:
env = gym.make('CartPole-v1')
hidden_size = 64
model = MLP(*env.observation_space.shape,hidden_size, env.action_space.n)

In [None]:
train(model, env, get_linear_decay_eps(.5, .05, 100000), n_steps=100000, train_on="cuda")

9.0
42.0
21.0
56.0
154.0
44.0
86.0
301.0
254.0
466.0
500.0
315.0
500.0
345.0
500.0
500.0
353.0
410.0
310.0
500.0
331.0
500.0
500.0
500.0
340.0
375.0
500.0
442.0
336.0
500.0
500.0
342.0
491.0
500.0
500.0
367.0
500.0
370.0
347.0
421.0
381.0
291.0
266.0
500.0
280.0
337.0
270.0
305.0
234.0
252.0


In [None]:
agent = to_agent(model)

show_agent(agent, env)

In [None]:
env = gym.make("Acrobot-v1")

show_agent(random_agent, env)

In [None]:
env = gym.make('Acrobot-v1')
hidden_size = 64
model = MLP(*env.observation_space.shape,hidden_size, env.action_space.n)

In [None]:
def get_acrobot_eps(n_steps):
  def out(step):
    if step < .1 * n_steps:
      return 1. + (.1 - 1.) / n_steps * step
    else:
      return .1
  return out

In [None]:
train(model, env, get_acrobot_eps(100000), lr=1e-4, gamma=.99, 
          n_steps=100000, train_freq=4, mem_size=10000, batch_size=128,
          train_on="cuda")

-500.0
-500.0
-500.0
-350.0
-79.0
-76.0
-99.0
-84.0
-83.0
-141.0
-134.0
-115.0
-500.0
-113.0
-500.0
-137.0
-122.0
-94.0
-113.0
-88.0
-93.0
-121.0
-86.0
-112.0
-95.0
-90.0
-95.0
-500.0
-125.0
-90.0
-116.0
-150.0
-500.0
-153.0
-87.0
-500.0
-119.0
-500.0
-85.0
-500.0
-106.0
-500.0
-500.0
-148.0
-165.0
-97.0
-102.0
-500.0
-500.0
-500.0


In [None]:
agent = to_agent(model)

show_agent(agent, env)

# Learning Breakout

In [9]:
# it makes no fucking sense, but the notebook needs to try and fail to make a Breakout env
# without the ROM or else nothing will work
try: gym.make("BreakoutNoFrameskip-v0")
except: print("lmao")

lmao


In [10]:
! wget http://www.atarimania.com/roms/Roms.rar
! mkdir /content/ROM/
! unrar e /content/Roms.rar /content/ROM/
! python -m atari_py.import_roms /content/ROM/

--2022-06-03 17:01:58--  http://www.atarimania.com/roms/Roms.rar
Resolving www.atarimania.com (www.atarimania.com)... 195.154.81.199
Connecting to www.atarimania.com (www.atarimania.com)|195.154.81.199|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 19583716 (19M) [application/x-rar-compressed]
Saving to: ‘Roms.rar’


2022-06-03 17:02:02 (4.69 MB/s) - ‘Roms.rar’ saved [19583716/19583716]


UNRAR 5.50 freeware      Copyright (c) 1993-2017 Alexander Roshal


Extracting from /content/Roms.rar

Extracting  /content/ROM/128 in 1 Game Select ROM (Unknown) ~.bin          0%  OK 
Extracting  /content/ROM/2 in 1 - Chess, Othello (Atari) (Prototype).bin       0%  OK 
Extracting  /content/ROM/2 Pak Special - Cavern Blaster, City War (1992) (HES) (773-867) (PAL).bin       0%  OK 
Extracting  /content/ROM/2 Pak Special - Challenge, Surfing (1990) (HES) (771-333) (PAL).bin       0%  OK 
Extracting  /content/ROM/2 Pak Special - Dolphin, Oin

In [11]:
!pip install stable_baselines3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting stable_baselines3
  Downloading stable_baselines3-1.5.0-py3-none-any.whl (177 kB)
[K     |████████████████████████████████| 177 kB 14.8 MB/s 
Collecting gym==0.21
  Downloading gym-0.21.0.tar.gz (1.5 MB)
[K     |████████████████████████████████| 1.5 MB 73.2 MB/s 
Building wheels for collected packages: gym
  Building wheel for gym (setup.py) ... [?25l[?25hdone
  Created wheel for gym: filename=gym-0.21.0-py3-none-any.whl size=1616798 sha256=9705a77e3d7693d5f83c83e7e815e9dc79f75325fdf4e2504d29eb98cc69bd68
  Stored in directory: /root/.cache/pip/wheels/76/ee/9c/36bfe3e079df99acf5ae57f4e3464ff2771b34447d6d2f2148
Successfully built gym
Installing collected packages: gym, stable-baselines3
  Attempting uninstall: gym
    Found existing installation: gym 0.17.3
    Uninstalling gym-0.17.3:
      Successfully uninstalled gym-0.17.3
Successfully installed gym-0.21.0 stable-baseline

In [12]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [13]:
from gdrive.MyDrive.mlab.days.atari_wrappers import AtariWrapper

In [14]:
# if you get error "no registered env with id" then delete and restart the runtime
env = gym.make("BreakoutNoFrameskip-v0") 
env = AtariWrapper(env)
show_agent(random_agent, env)

In [None]:
class AtariCNN(nn.Module):
  def __init__(self, obs_n_channels, n_action_space):
    super().__init__()
    self.obs_n_channels = obs_n_channels
    self.net = nn.Sequential(
      nn.Conv2d(obs_n_channels, 32, 8, stride=4),
      nn.ReLU(),
      nn.Conv2d(32, 64, 4, stride=2),
      nn.ReLU(),
      nn.Conv2d(64, 64, 3, stride=1),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(3136, n_action_space)
    )

  def forward(self, x):
    if isinstance(x, numpy.ndarray): x = t.from_numpy(x).float()
    if len(x.shape) == 3: x = x.unsqueeze(0)
    if x.size(-1) == self.obs_n_channels: # channel inputs are at the end
      x = rearrange(x, 'b h w c -> b c h w')
    return self.net(x)

In [None]:
env = AtariWrapper(gym.make("BreakoutNoFrameskip-v0"))
model = AtariCNN(env.observation_space.shape[-1], env.action_space.n)

In [None]:
n_steps = 1000000
train(model, env, eps_fn=get_linear_decay_eps(1., 0.01, n_steps), 
          lr=3e-5, gamma=.99, 
          n_steps=n_steps, train_freq=64, mem_size=20000, batch_size=512,
          train_on="cuda")

1.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0
2.0
1.0
0.0
1.0
0.0
0.0
0.0
0.0
2.0
0.0
2.0
1.0
0.0
1.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
2.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
1.0
1.0
0.0
0.0
2.0
2.0
2.0
0.0
2.0
0.0
2.0
1.0
2.0
0.0
0.0
1.0
1.0
0.0
0.0
0.0
1.0
1.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0
0.0
2.0
0.0
1.0
0.0
0.0
0.0
0.0
1.0
0.0
3.0
0.0
0.0
0.0
0.0
0.0
0.0
2.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
2.0
1.0
3.0
0.0
0.0
1.0
1.0
0.0
2.0
2.0
1.0
0.0
0.0
0.0
0.0
1.0
1.0
1.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
3.0
1.0
1.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0
1.0
1.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
1.0
0.0
1.0
0.0
1.0
0.0
1.0
1.0
0.0
2.0
1.0
0.0
0.0
0.0
0.0
1.0
0.0
1.0
2.0
0.0
0.0
2.0
1.0
1.0
0.0
0.0
0.0
0.0
0.0
1.0
2.0
2.0
0.0
0.0
1.0
2.0
0.0
0.0
2.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
2.0
0.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
1.0
0.0
0.0
0.0
1.0
0.0
2.0
1.0
2.0
1.0
1.0
1.0
1.0
0.0
0.0
4.0
1.0
0.0
0.0
1.0
1.0
1.0


In [None]:
env = gym.make("BreakoutNoFrameskip-v0")
env = AtariWrapper(env)
agent = to_agent(model)

In [None]:
show_agent(agent, env)

# Extensions

In [15]:
def train(model, env, eps_fn=lambda x: 0.05, 
          lr=1e-3, gamma=.98, 
          double=False, double_update_freq = 1000, multistep=1,
          n_steps=20000, train_freq=16, eval_freq=2000,
          mem_size=10000, batch_size=128,
          train_on="cpu"):
  if train_on == "cuda":
    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
  else: device = "cpu"

  # set up model(s)
  model.to(device)
  if double:
    target_model = copy.deepcopy(model)
    target_model.eval()
  model.train()

  # set up auxiliary tensors for time discounting
  discount = gamma ** t.arange(multistep + 1, device=device)
  multistep_ranges = repeat(t.arange(multistep, device=device), 'ms -> bs ms', bs=batch_size)

  # set up optimization
  optimizer = t.optim.Adam(model.parameters(), lr=lr)
  loss_fn = nn.MSELoss()

  # instead of experience buffer, tensors storing states, rewards, dones
  obss = t.zeros(mem_size, *env.observation_space.shape, device=device)
  acts = t.zeros(mem_size, dtype=t.long, device=device)
  rwds = t.zeros(mem_size, device=device)
  dones= t.zeros(mem_size, dtype=t.bool, device=device)

  obs = env.reset()
  obss[0] = t.from_numpy(obs).float()
  done = False

  for step in range(n_steps):

    # take an action and record info
    act = make_choice(env, eps_fn(step), model, obs, device=device)
    obs, rwd, done, _ = env.step(act)
    acts[step % mem_size] = act
    rwds[step % mem_size] = rwd
    dones[step % mem_size] = done

    # record next observation
    if done: 
      obs = env.reset()
    obss[(step + 1) % mem_size] = t.from_numpy(obs).float()

    # every train_freq steps, sample experiences and update the model
    if (step + 1) % train_freq == 0 and step >= batch_size:

      batch_ids = t.randperm(min(step, mem_size), device=device)[:batch_size]
      obss_batch = obss[batch_ids]
      acts_batch = acts[batch_ids]

      multistep_ids = (batch_ids.unsqueeze(1) + multistep_ranges) % mem_size # : [batch_size, multistep]
      multistep_rwds = rwds[multistep_ids]
      multistep_dones = (dones[multistep_ids])
      ever_done = multistep_dones.any(-1)
      done_step = t.argmax(multistep_dones.float(), dim=-1) + 1
      done_step[~ever_done] = multistep
      still_alive = t.cumprod(~multistep_dones, dim=-1)
      still_alive = t.cat((t.zeros(batch_size, 1, dtype=t.bool, device=device), still_alive), dim=-1)
      batch_discounts = discount * still_alive

      last_states = obss[(batch_ids + done_step) % mem_size]
      with t.no_grad():
        if double:
          last_state_qs = t.max(target_model(last_states), dim=-1).values
        else:
          last_state_qs = t.max(model(last_states), dim=-1).values
      multistep_rwds = t.cat((multistep_rwds, last_state_qs.unsqueeze(1)), dim=-1)
      target = t.sum(batch_discounts * multistep_rwds, dim=-1)
      model_out = model(obss_batch)[t.arange(batch_size, device=device), acts_batch]
      
      loss = loss_fn(model_out, target)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    
    if (step + 1) % double_update_freq * train_freq == 0:
      target_model.load_state_dict(model.state_dict())

    if step % eval_freq == 0:
      score = evaluate(model, env, eps=0., device=device)
      print(score)
      model.train()

In [12]:
env = gym.make('CartPole-v1')
hidden_size = 64
model = MLP(*env.observation_space.shape,hidden_size, env.action_space.n)

In [13]:
train(model, env, eps_fn=lambda x: 0.05, 
          lr=1e-3, gamma=.98, 
          double=True, double_update_freq = 1000, multistep=3,
          n_steps=100000, train_freq=16, eval_freq=2000,
          mem_size=10000, batch_size=16,
          train_on="cpu")

11.0
9.0
11.0
11.0
10.0
72.0
31.0
32.0
99.0
87.0
89.0
63.0
132.0
205.0
376.0
234.0
451.0
188.0
199.0
500.0
162.0
249.0
281.0
500.0
272.0
303.0
243.0
291.0
347.0
203.0
219.0
456.0
276.0
233.0
264.0
345.0
446.0
404.0
342.0
222.0
143.0
298.0
145.0
228.0
307.0
313.0
164.0
181.0
362.0
335.0


In [14]:
show_agent(to_agent(model), env)

In [16]:
class DuelingDQN(nn.Module):
  def __init__(self, obs_n_channels, n_action_space):
    super().__init__()
    self.obs_n_channels = obs_n_channels
    self.net = nn.Sequential(
      nn.Conv2d(obs_n_channels, 32, 8, stride=4),
      nn.ReLU(),
      nn.Conv2d(32, 64, 4, stride=2),
      nn.ReLU(),
      nn.Conv2d(64, 64, 3, stride=1),
      nn.ReLU(),
      nn.Flatten()
    )
    self.value = nn.Linear(3136, 1)
    self.advantage = nn.Linear(3136, n_action_space)
  def forward(self, x):
    if isinstance(x, numpy.ndarray): x = t.from_numpy(x).float()
    if len(x.shape) == 3: x = x.unsqueeze(0)
    if x.size(-1) == self.obs_n_channels: # channel inputs are at the end
      x = rearrange(x, 'b h w c -> b c h w')
    conv_out = self.net(x)
    val_out = self.value(conv_out)
    adv_out = self.advantage(conv_out)
    return val_out + adv_out - t.mean(adv_out, dim=-1, keepdim=True)

In [21]:
env = AtariWrapper(gym.make("BreakoutNoFrameskip-v0"))
model = DuelingDQN(env.observation_space.shape[-1], env.action_space.n)

In [22]:
n_steps = 1000000
train(model, env, eps_fn=get_linear_decay_eps(1., 0.01, n_steps), 
          lr=3e-5, gamma=.99, 
          double=True, double_update_freq = 1000, multistep=3,
          n_steps=n_steps, train_freq=16, eval_freq=10000,
          mem_size=10000, batch_size=128,
          train_on="cuda")

0.0
0.0
0.0
0.0
2.0
2.0
0.0
0.0
0.0
0.0
0.0
4.0
2.0
0.0
0.0
1.0
1.0
0.0
0.0
1.0
1.0
1.0
0.0
0.0
1.0
0.0
1.0
1.0
0.0
4.0
2.0
0.0
3.0
2.0
1.0
1.0
0.0
0.0
2.0
3.0
0.0
1.0
0.0
0.0
1.0
0.0
2.0
0.0
5.0
0.0
4.0
3.0
1.0
1.0
5.0
1.0
1.0
6.0
0.0
0.0
4.0
0.0
0.0
0.0
0.0
3.0
2.0
0.0
1.0
2.0
9.0
1.0
0.0
5.0
0.0
5.0
1.0
0.0
1.0
1.0
2.0
1.0
2.0
1.0
4.0
0.0
3.0
4.0
5.0
7.0
2.0
0.0
2.0
0.0
0.0


KeyboardInterrupt: ignored

In [23]:
t.save(model.state_dict(), '/content/gdrive/MyDrive/mlab/days/w3d3/breakout_trained')

In [17]:
model = DuelingDQN(env.observation_space.shape[-1], env.action_space.n)
model.load_state_dict(t.load('/content/gdrive/MyDrive/mlab/days/w3d3/breakout_trained'))

<All keys matched successfully>

In [20]:
env.reset()
show_agent(to_agent(model), env)

# Policy gradients

In [4]:
class PolicyMLP(nn.Module):
  def __init__(self, in_size, hidden_size, out_size):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, out_size),
        nn.LogSoftmax(dim=-1)
    )
  
  def forward(self, x):
    if isinstance(x, numpy.ndarray): x = t.from_numpy(x).float()
    return self.net(x)

class PolicyCNN(nn.Module):
  def __init__(self, obs_n_channels, n_action_space):
    super().__init__()
    self.obs_n_channels = obs_n_channels
    self.net = nn.Sequential(
      nn.Conv2d(obs_n_channels, 32, 8, stride=4),
      nn.ReLU(),
      nn.Conv2d(32, 64, 4, stride=2),
      nn.ReLU(),
      nn.Conv2d(64, 64, 3, stride=1),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(3136, 512),
      nn.Linear(512, n_action_space),
      nn.LogSoftmax(dim=-1)
    )

  def forward(self, x):
    if isinstance(x, numpy.ndarray): x = t.from_numpy(x).float()
    if len(x.shape) == 3: x = x.unsqueeze(0)
    if x.size(-1) == self.obs_n_channels: # channel inputs are at the end
      x = rearrange(x, 'b h w c -> b c h w')
    return self.net(x)

In [5]:
def policy_sample(env, eps, net, obs, device=t.device("cpu")):
  if t.rand(1) < eps: return env.action_space.sample()
  if isinstance(obs, numpy.ndarray): obs = t.from_numpy(obs).float()
  obs = obs.to(device)
  with t.no_grad():
    return t.multinomial(net(obs).exp(), 1).item()

def policy_eval(model, env, eps=0.3, device=t.device("cpu")):
  obs = env.reset()
  model.eval()
  total_reward = 0.
  done = False
  while not done:
    action = policy_sample(env, eps, model, obs, device=device)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
  return total_reward

def policy_to_agent(model):
  model.to(t.device("cpu"))
  def act(obs, env):
    with t.no_grad():
      return t.multinomial(model(obs).exp(), 1).item()
  return act

In [6]:
def policy_train(model, env, eps_fn=lambda x: 0.05, 
          lr=1e-3,
          n_steps=20000, eval_freq=2000,
          batch_size=5000, train_on="cpu"):
  if train_on == "cuda":
    device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
  else: device = t.device("cpu")

  model.to(device)
  model.train()
  optimizer = t.optim.Adam(model.parameters(), lr=lr)

  p_acts = t.zeros(batch_size, device=device)
  rwds = t.zeros(batch_size, device=device)

  step = 0

  while step < n_steps:
    running_rwd = 0.
    ep_len = 0
    done = False
    train = False # is it time to train?
    eval = False # is it time to evaluate?
    obs = env.reset()

    while not done:
      logprob_act = model(obs)
      with t.no_grad():
        act = t.multinomial(logprob_act.exp(), 1).item()
      p_acts[step % batch_size] = logprob_act[act]
      obs, rwd, done, _ = env.step(act)

      running_rwd += rwd
      step += 1
      ep_len += 1

      if (step - 1) % eval_freq == 0: eval = True
      if step % batch_size == 0: train = True

    # update the rewards for the entire episode
    rwds[(step - ep_len) % batch_size : step % batch_size] = running_rwd

    # every train_freq steps, sample experiences and update the model
    if train and step >= batch_size:

      loss = - t.mean(p_acts * rwds)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      p_acts = t.zeros(batch_size, device=device)
      rwds = t.zeros(batch_size, device=device)

    if eval:
      score = policy_eval(model, env, eps=0., device=device)
      print(score)
      model.train()

In [7]:
env = gym.make('CartPole-v1')
model = PolicyMLP(*env.observation_space.shape, 64, env.action_space.n)
show_agent(policy_to_agent(model), env)

In [8]:
policy_train(model, env, eps_fn=lambda x: 0.05, 
          lr=1e-2, 
          n_steps=200000, eval_freq=5000,
          batch_size=5000,
          train_on="cpu")

63.0
18.0
27.0
10.0
47.0
106.0
68.0
31.0
51.0
73.0
99.0
54.0
131.0
180.0
165.0
171.0
81.0
155.0
165.0
166.0
104.0
107.0
60.0
55.0
35.0
42.0
111.0
94.0
102.0
110.0
92.0
119.0
98.0
112.0
135.0
103.0
107.0
110.0
114.0
48.0
59.0


In [12]:
policy_eval(model, env, eps=0)

76.0

In [16]:
show_agent(policy_to_agent(model), env)