In [None]:
from tqdm import trange

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

from collections import deque
import random
import torch
import atari_py
import cv2


class Env():
    """Atari emulator wrapper producing stacked 84x84 grayscale observations.

    Every state returned by ``reset``/``step`` is the content of a fixed-length
    frame history stacked along dimension 0 (oldest frame first), i.e. a
    ``(history_length, 84, 84)`` float tensor on ``self.device``.
    """

    def __init__(self, max_episode_length=1.08e5, history_length=4, game="space_invaders", device=None):
        """Create the emulator, configure it and load the ROM.

        Args:
            max_episode_length: Value for the emulator's
                ``max_num_frames_per_episode`` setting. It is cast to ``int``
                because it is written through ``setInt``.
            history_length: Number of frames kept in the state history.
            game: ROM name resolved through ``atari_py.get_game_path``.
            device: Torch device (or device string) for all tensors created
                here. ``None`` selects ``cuda:0`` when CUDA is available and
                falls back to the CPU otherwise.
        """
        if device is None:
            # Do not assume a GPU: tensor creation on "cuda:0" fails without one.
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.ale = atari_py.ALEInterface()
        self.ale.setInt('max_num_frames_per_episode', int(max_episode_length))
        self.ale.setFloat('repeat_action_probability', 0)
        self.ale.setInt('frame_skip', 0)
        self.ale.setBool('color_averaging', False)
        self.ale.loadROM(atari_py.get_game_path(game))
        # Map contiguous agent action indices 0..k-1 to the emulator's action codes.
        self.actions = dict(enumerate(self.ale.getMinimalActionSet()))
        self.window = history_length
        self.state_buffer = deque([], maxlen=self.window)
        # Flag is only stored here; nothing in this class reads it.
        self.training = True

    def _reset_buffer(self):
        """Fill the frame history with all-zero 84x84 frames."""
        for _ in range(self.window):
            self.state_buffer.append(torch.zeros(84, 84, device=self.device))

    def _get_state(self):
        """Return the current screen as an 84x84 float32 tensor scaled by 1/255."""
        state = cv2.resize(self.ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_LINEAR)
        return torch.tensor(state, dtype=torch.float32, device=self.device).div_(255)

    def reset(self):
        """Start a new episode and return the initial stacked state.

        The history is cleared to zeros, the game is reset, and a random number
        (0 to 29) of steps with emulator action 0 (presumably NOOP) is played,
        restarting the game if it ends during those steps.
        """
        self._reset_buffer()
        self.ale.reset_game()
        for _ in range(random.randrange(30)):
            self.ale.act(0)
            if self.ale.game_over():
                self.ale.reset_game()
        observation = self._get_state()
        self.state_buffer.append(observation)
        return torch.stack(list(self.state_buffer), 0)

    def step(self, action):
        """Repeat ``action`` for up to four frames and return the new state.

        Args:
            action: Agent action index, a key of ``self.actions``.

        Returns:
            Tuple ``(state, reward, done)``: the stacked frame history, the
            summed emulator reward over the repeated frames, and whether the
            game is over.

        The new observation is the element-wise maximum of the screens seen
        after the third and fourth repeat. If the game ends before the third
        repeat, no screen is captured and the observation stays all zeros.
        """
        # The emulator action code is the same for every repeat; look it up once.
        ale_action = self.actions.get(action)
        frame_buffer = torch.zeros(2, 84, 84, device=self.device)
        reward, done = 0, False
        for t in range(4):
            reward += self.ale.act(ale_action)
            if t == 2:
                frame_buffer[0] = self._get_state()
            elif t == 3:
                frame_buffer[1] = self._get_state()
            done = self.ale.game_over()
            if done:
                break
        observation = frame_buffer.max(0)[0]
        self.state_buffer.append(observation)
        return torch.stack(list(self.state_buffer), 0), reward, done

In [None]:
# Build the emulator wrapper with its default settings.
env = Env()
# Number of discrete actions the agent can choose from.
action_space = len(env.actions)

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Module):
    """Linear layer whose weight and bias are perturbed by learnable noise.

    In training mode the effective parameters are ``mu + sigma * epsilon``,
    where ``epsilon`` is a buffer refreshed by ``reset_noise``. In evaluation
    mode only the ``mu`` parameters are used.
    """

    def __init__(self, in_features, out_features, std_init=0.5):
        """Allocate the parameters and noise buffers, then initialise both."""
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.std_init = std_init
        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(*weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(*weight_shape))
        self.register_buffer('weight_epsilon', torch.empty(*weight_shape))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise ``mu`` uniformly and ``sigma`` to constants.

        ``mu`` is drawn from ``U(-1/sqrt(in), 1/sqrt(in))``. The weight sigma
        is ``std_init / sqrt(in_features)`` and the bias sigma is
        ``std_init / sqrt(out_features)``.
        """
        bound = 1 / math.sqrt(self.in_features)
        weight_sigma0 = self.std_init / math.sqrt(self.in_features)
        bias_sigma0 = self.std_init / math.sqrt(self.out_features)
        self.weight_mu.data.uniform_(-bound, bound)
        self.weight_sigma.data.fill_(weight_sigma0)
        self.bias_mu.data.uniform_(-bound, bound)
        self.bias_sigma.data.fill_(bias_sigma0)

    def _scale_noise(self, size):
        """Draw ``size`` standard-normal samples and map them through sign(x) * sqrt(|x|)."""
        noise = torch.randn(size)
        magnitude = noise.abs().sqrt_()
        return noise.sign().mul_(magnitude)

    def reset_noise(self):
        """Resample the noise buffers from one input vector and one output vector."""
        # Draw order matters for reproducibility: input noise first, then output noise.
        eps_in = self._scale_noise(self.in_features)
        eps_out = self._scale_noise(self.out_features)
        # The weight noise is the outer product of the two vectors.
        self.weight_epsilon.copy_(eps_out.ger(eps_in))
        self.bias_epsilon.copy_(eps_out)

    def forward(self, input):
        """Apply the layer, with noisy parameters only while in training mode."""
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)


class DQN(nn.Module):
    """Dueling distributional Q-network with noisy fully connected heads.

    For every action the network outputs a categorical distribution over
    ``atoms`` support points, so the output shape is
    ``(batch, action_space, atoms)``.
    """

    def __init__(self, atoms, action_space, history_length, hidden_size=256, noisy_std=0.1):
        """Build the convolutional trunk and the value/advantage streams."""
        super().__init__()
        self.atoms, self.action_space = atoms, action_space

        conv_layers = [
            nn.Conv2d(history_length, 32, 5, stride=5, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, 5, stride=5, padding=0),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*conv_layers)
        # Flattened trunk output per sample: 64 channels x 3 x 3 for 84x84 input frames.
        self.conv_output_size = 576

        # Hidden and output layers of the value (v) and advantage (a) streams.
        self.fc_h_v = NoisyLinear(self.conv_output_size, hidden_size, std_init=noisy_std)
        self.fc_h_a = NoisyLinear(self.conv_output_size, hidden_size, std_init=noisy_std)
        self.fc_z_v = NoisyLinear(hidden_size, self.atoms, std_init=noisy_std)
        self.fc_z_a = NoisyLinear(hidden_size, self.action_space * self.atoms, std_init=noisy_std)

    def forward(self, x, use_log_softmax=False):
        """Return per-action atom probabilities (or log-probabilities).

        Args:
            x: Batch of stacked frames accepted by the convolutional trunk.
            use_log_softmax: When true, return log-probabilities instead of
                probabilities.
        """
        features = self.convs(x).view(-1, self.conv_output_size)

        value = self.fc_z_v(F.relu(self.fc_h_v(features))).view(-1, 1, self.atoms)
        advantage = self.fc_z_a(F.relu(self.fc_h_a(features))).view(-1, self.action_space, self.atoms)

        # Dueling combination: value plus advantage centred over the action dimension.
        logits = value + advantage - advantage.mean(1, keepdim=True)
        if use_log_softmax:
            return F.log_softmax(logits, dim=2)
        return F.softmax(logits, dim=2)

    def reset_noise(self):
        """Resample the noise of all four noisy layers."""
        for layer in (self.fc_h_v, self.fc_h_a, self.fc_z_v, self.fc_z_a):
            layer.reset_noise()

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import torch
import torch.optim as optim


class Agent():
    """Holds the online/target networks and selects actions from them.

    Action values are the expectation of each action's categorical
    distribution over ``atoms`` support points spaced evenly between ``V_min``
    and ``V_max``.
    """

    def __init__(self, env, atoms=51, V_min=-10.0, V_max=10.0, batch_size=32, multi_step=20, discount=0.99, norm_clip=10.0, lr=1.0e-4, adam_eps=1.5e-4):
        """Build both networks and the optimiser.

        Args:
            env: Environment exposing ``actions`` (its length is the number of
                actions) and ``window`` (frames per state).
            atoms: Number of support points of the value distribution.
            V_min: Lowest support value.
            V_max: Highest support value.
            batch_size: Stored as ``self.batch_size``.
            multi_step: Stored as ``self.n``.
            discount: Stored as ``self.discount``.
            norm_clip: Stored as ``self.norm_clip``.
            lr: Learning rate of the Adam optimiser.
            adam_eps: Epsilon of the Adam optimiser.
        """
        # Use the GPU when there is one; fall back to the CPU instead of failing.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.action_space = len(env.actions)
        self.atoms = atoms
        self.Vmin = V_min
        self.Vmax = V_max
        self.support = torch.linspace(self.Vmin, self.Vmax, self.atoms, device=self.device)
        self.delta_z = (self.Vmax - self.Vmin) / (self.atoms - 1)
        self.batch_size = batch_size
        self.n = multi_step
        self.discount = discount
        self.norm_clip = norm_clip

        self.online_net = DQN(self.atoms, self.action_space, env.window).to(self.device)
        self.online_net.train()

        # The target network starts as a copy of the online network, is kept in
        # train mode like it, and is excluded from gradient computation.
        self.target_net = DQN(self.atoms, self.action_space, env.window).to(self.device)
        self.update_target_net()
        self.target_net.train()

        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr, eps=adam_eps)

    def update_target_net(self):
        """Copy the online network's state into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def reset_noise(self):
        """Resample the noise of the online network."""
        self.online_net.reset_noise()

    def act(self, state):
        """Return the index of the action with the highest expected value.

        Args:
            state: A single (unbatched) state tensor; a batch dimension is
                added before it is passed to the online network.
        """
        with torch.no_grad():
            # Moving is a no-op when the state already lives on the agent's device.
            batch = state.to(self.device).unsqueeze(0)
            expected_values = (self.online_net(batch) * self.support).sum(2)
            return expected_values.argmax(1).item()

In [None]:
# Create the agent with its default hyperparameters for this environment.
agent = Agent(env)

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import torch
import numpy as np
from collections import namedtuple


# One stored step: position within its episode, the frame, the action taken,
# the reward received, and whether the episode continued afterwards.
Transition = namedtuple('Transition', ('timestep', 'state', 'action', 'reward', 'nonterminal'))


class SegmentTree():
    """Fixed-capacity ring buffer with a binary sum tree over per-item values.

    ``sum_tree`` is stored as a flat array of ``2 * size - 1`` nodes: the last
    ``size`` entries are the leaves (one per data slot) and every internal node
    holds the sum of its two children, so ``sum_tree[0]`` is the total.
    """

    def __init__(self, size):
        """Create an empty tree with room for ``size`` items.

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError("size must be at least 1")
        self.index = 0  # next data slot to write
        self.size = size
        self.full = False  # becomes True once the write position wraps around
        self.sum_tree = np.zeros((2 * size - 1, ), dtype=np.float32)
        self.data = np.array([None] * size)
        self.max = 1  # running maximum of all values seen (starts at 1)

    def _propagate(self, index, value):
        """Recompute the sums of all ancestors of tree node ``index``.

        ``value`` is accepted for interface compatibility and is not used.
        The walk is iterative and stops at the root, so a root node
        (``index == 0``, which is also the only leaf when ``size == 1``) is
        left untouched instead of being treated as its own child.
        """
        while index > 0:
            parent = (index - 1) // 2
            left = 2 * parent + 1
            self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[left + 1]
            index = parent

    def update(self, index, value):
        """Set tree node ``index`` to ``value`` and refresh the sums above it."""
        self.sum_tree[index] = value
        self._propagate(index, value)
        self.max = max(value, self.max)

    def append(self, data, value):
        """Store ``data`` in the next slot with the given ``value``.

        The oldest entry is overwritten once the buffer has wrapped around.
        """
        self.data[self.index] = data
        # Leaves start at offset size - 1; update() also tracks the running max.
        self.update(self.index + self.size - 1, value)
        self.index = (self.index + 1) % self.size
        self.full = self.full or self.index == 0


class ReplayMemory():
    """Replay buffer that records transitions in a ``SegmentTree``.

    Only the newest frame of each state is stored, as a ``uint8`` tensor on
    the CPU, together with its position within the current episode.
    """

    def __init__(self, capacity, history_length, discount, multi_step, priority_weight=0.4, priority_exponent=0.5):
        """Remember the settings and create the underlying tree of ``capacity`` slots."""
        self.device = torch.device("cuda:0")
        self.capacity, self.history = capacity, history_length
        self.discount, self.n = discount, multi_step
        self.priority_weight, self.priority_exponent = priority_weight, priority_exponent
        # Timestep of the next transition within the current episode.
        self.t = 0
        self.transitions = SegmentTree(capacity)

    def append(self, state, action, reward, terminal):
        """Record one transition with the tree's current maximum value.

        Args:
            state: Stacked frames; only ``state[-1]`` is kept, scaled by 255
                and converted to ``uint8`` on the CPU.
            action: Action taken in ``state``.
            reward: Reward received for that action.
            terminal: Whether the episode ended with this step; the timestep
                counter restarts at 0 afterwards.
        """
        cpu = torch.device("cpu")
        newest_frame = state[-1].mul(255).to(dtype=torch.uint8, device=cpu)
        record = Transition(self.t, newest_frame, action, reward, not terminal)
        self.transitions.append(record, self.transitions.max)
        if terminal:
            self.t = 0
        else:
            self.t = self.t + 1

In [None]:
# Replay buffer with 100k slots, sharing the environment's frame window and
# the agent's discount and multi-step settings.
mem = ReplayMemory(100_000, env.window, agent.discount, agent.n)
# Number of environment steps performed by the interaction loop.
T_max = 100_000
# presumably the step at which learning is meant to start; in the visible code
# it only enters the increment computed below — TODO confirm
learn_start = 1_600
# Increment that takes the priority weight from its initial value to 1 when
# applied once per step over the (T_max - learn_start) remaining steps.
priority_weight_increase = (1 - mem.priority_weight) / (T_max - learn_start)

In [None]:
import matplotlib.pyplot as plt

In [None]:
def _show_gray(frame):
    """Render one 2-D frame tensor as a grayscale matplotlib figure."""
    plt.figure()
    plt.imshow(frame.cpu().numpy(), cmap="gray")
    plt.show()


# Emulator screen as it is before the environment has been reset.
_show_gray(env._get_state())

# Newest frame of the stacked state right after a reset.
state, done = env.reset(), False
_show_gray(state[-1])

# Let the agent pick one action, apply it, and show the resulting newest frame.
action = agent.act(state)
state, reward, done = env.step(action)
_show_gray(state[-1])

In [None]:
# Rewards are clamped to [-reward_clip, reward_clip]; a value <= 0 disables clamping.
reward_clip = 1

# `done` starts as True so that the first iteration resets the environment.
T, done = 0, True
# The agent's noise is resampled every `replay_frequency` steps (every step for 1).
replay_frequency = 1
for T in trange(1, T_max + 1):
    if done:
        state, done = env.reset(), False

    if T % replay_frequency == 0:
        agent.reset_noise()

    action = agent.act(state)
    next_state, reward, done = env.step(action)
    if reward_clip > 0:
        reward = max(min(reward, reward_clip), -reward_clip)
    # Store the state the action was taken in, not the state that resulted from it.
    mem.append(state, action, reward, done)
    # NOTE(review): within the lines visible here `state` is never advanced to
    # `next_state`, so the agent would keep acting on the post-reset state until
    # the episode ends — confirm the loop body continues with `state = next_state`.