In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

from collections import deque
import random
import torch
import cv2
import gym
from gym.wrappers.pixel_observation import PixelObservationWrapper


class Env():
    """LunarLander-v2 exposed as a stack of recent 84x84 grayscale frames.

    The underlying gym environment is wrapped in ``PixelObservationWrapper`` so that
    every observation is a dict carrying the rendered frame under ``"pixels"``.
    Each frame is resized, converted to grayscale, scaled by 1/255 and pushed into a
    fixed-length buffer; the observation handed to the agent is that buffer stacked
    into a ``(history_length, 84, 84)`` float32 tensor on ``self.device``.

    Args:
        action_size: number of discrete actions; ``action_space`` becomes
            ``[0, ..., action_size - 1]``.
        history_length: number of most recent frames kept in the observation stack.
        device: torch device (or device string) for the observation tensors. When
            omitted, ``cuda:0`` is used if CUDA is available and the CPU otherwise.
    """

    def __init__(self, action_size=4, history_length=4, device=None):
        # The device used to be hard-coded to "cuda:0", which made the first tensor
        # allocation fail on CPU-only machines; fall back to the CPU instead.
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.wrapped_env = PixelObservationWrapper(gym.make("LunarLander-v2"), pixels_only=True)
        self.action_space = list(range(action_size))
        self.window = history_length
        self.state_buffer = deque([], maxlen=self.window)

    def _reset_buffer(self):
        """Fill the frame buffer with blank (all-zero) frames."""
        for _ in range(self.window):
            self.state_buffer.append(torch.zeros(84, 84, device=self.device))

    def _process_observation(self, observation):
        """Turn a raw pixel observation into the current stacked-frame tensor.

        Args:
            observation: mapping with the rendered frame under the ``"pixels"`` key
                (assumed to be an RGB uint8 image -- TODO confirm against the wrapper).

        Returns:
            Tensor of shape ``(history_length, 84, 84)``, oldest frame first.
        """
        frame = cv2.resize(observation["pixels"], (84, 84), interpolation=cv2.INTER_AREA)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = torch.tensor(frame, dtype=torch.float32, device=self.device).div_(255)
        # The deque has maxlen=window, so appending drops the oldest frame.
        self.state_buffer.append(frame)
        return torch.stack(list(self.state_buffer), 0)

    def reset(self):
        """Reset the wrapped environment and return the first stacked observation."""
        self._reset_buffer()
        observation = self.wrapped_env.reset()
        return self._process_observation(observation)

    def close(self):
        """Close the wrapped gym environment."""
        self.wrapped_env.close()

    def step(self, action):
        """Advance one step and return ``(stacked_observation, reward, done, info)``."""
        observation, reward, done, info = self.wrapped_env.step(action)
        observation = self._process_observation(observation)
        return observation, reward, done, info

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable, factorised noise.

    Every weight/bias is ``mu + sigma * epsilon``: ``mu`` and ``sigma`` are trained
    parameters, ``epsilon`` is a noise buffer redrawn by ``reset_noise``. In eval
    mode the noise term is dropped and only ``mu`` is used.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        std_init: scale used to initialise the ``sigma`` parameters.
    """

    def __init__(self, in_features, out_features, std_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(*weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(*weight_shape))
        self.register_buffer('weight_epsilon', torch.empty(*weight_shape))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise the means uniformly and the sigmas to constants."""
        bound = 1 / math.sqrt(self.in_features)
        # Weight means are drawn before bias means so the RNG stream is consumed
        # in a fixed order.
        for mean in (self.weight_mu, self.bias_mu):
            mean.data.uniform_(-bound, bound)
        self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_features))
        self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_features))

    def _scale_noise(self, size):
        """Return ``size`` samples of ``sign(x) * sqrt(|x|)`` with ``x ~ N(0, 1)``."""
        gaussian = torch.randn(size)
        return torch.sign(gaussian) * torch.sqrt(torch.abs(gaussian))

    def reset_noise(self):
        """Redraw the noise buffers from one input-side and one output-side vector."""
        noise_in = self._scale_noise(self.in_features)
        noise_out = self._scale_noise(self.out_features)
        # Outer product of the two vectors, written with broadcasting:
        # entry [i, j] is noise_out[i] * noise_in[j].
        self.weight_epsilon.copy_(noise_out.unsqueeze(1) * noise_in.unsqueeze(0))
        self.bias_epsilon.copy_(noise_out)

    def forward(self, input):
        """Apply the layer: noisy weights while training, mean weights otherwise."""
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)


class DQN(nn.Module):
    """Dueling, distributional Q-network built on two conv layers and noisy heads.

    For each action the network outputs a categorical distribution over ``atoms``
    support points. A value stream and an advantage stream are combined as
    ``value + advantage - mean(advantage over actions)`` before the softmax.

    Args:
        atoms: number of support points per action distribution.
        action_size: number of discrete actions.
        history_length: number of stacked input frames (conv input channels).
        hidden_size: width of the hidden noisy layer in each stream.
        noisy_std: ``std_init`` forwarded to every ``NoisyLinear`` layer.
    """

    def __init__(self, atoms, action_size, history_length, hidden_size=256, noisy_std=0.1):
        super().__init__()
        self.atoms = atoms
        self.action_size = action_size

        feature_layers = [
            nn.Conv2d(history_length, 32, 5, stride=5, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, 5, stride=5, padding=0),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*feature_layers)
        # Flattened conv output for 84x84 input: 84 -> 16 -> 3 per side, 64 channels.
        self.conv_output_size = 576

        # Value stream (one distribution) and advantage stream (one per action).
        self.fc_h_v = NoisyLinear(self.conv_output_size, hidden_size, std_init=noisy_std)
        self.fc_h_a = NoisyLinear(self.conv_output_size, hidden_size, std_init=noisy_std)
        self.fc_z_v = NoisyLinear(hidden_size, self.atoms, std_init=noisy_std)
        self.fc_z_a = NoisyLinear(hidden_size, self.action_size * self.atoms, std_init=noisy_std)

    def forward(self, x, use_log_softmax=False):
        """Return per-action atom (log-)probabilities of shape ``(batch, action_size, atoms)``."""
        features = self.convs(x).view(-1, self.conv_output_size)

        value = self.fc_z_v(F.relu(self.fc_h_v(features))).view(-1, 1, self.atoms)
        advantage = self.fc_z_a(F.relu(self.fc_h_a(features))).view(-1, self.action_size, self.atoms)

        # Dueling combination; the mean is taken over the action dimension.
        logits = value + advantage - advantage.mean(1, keepdim=True)
        if use_log_softmax:
            return F.log_softmax(logits, dim=2)
        return F.softmax(logits, dim=2)

    def reset_noise(self):
        """Redraw the noise of all four noisy layers."""
        for noisy_layer in (self.fc_h_v, self.fc_h_a, self.fc_z_v, self.fc_z_a):
            noisy_layer.reset_noise()

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import numpy as np
import torch
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_


class Agent():
    """Distributional (categorical) DQN agent with an online and a target network.

    Action values are the expectation of a categorical distribution over ``atoms``
    evenly spaced support points in ``[V_min, V_max]``. ``learn`` selects the next
    action with the online network, evaluates it with the target network, projects
    the n-step target distribution back onto the support and minimises the
    cross-entropy against the online prediction.

    Args:
        env: object exposing ``action_space`` (sized) and ``window`` (frame stack size).
        atoms: number of support points of the value distribution.
        V_min, V_max: bounds of the support.
        batch_size: number of transitions requested from the replay memory per update.
        multi_step: n of the n-step return; the bootstrap is discounted by ``discount ** n``.
        discount: per-step discount factor.
        norm_clip: maximum gradient norm applied before each optimiser step.
        lr, adam_eps: Adam learning rate and epsilon.
        device: torch device (or device string) for networks and the support. When
            omitted, ``cuda:0`` is used if CUDA is available and the CPU otherwise.
    """

    def __init__(self, env, atoms=51, V_min=-10.0, V_max=10.0, batch_size=32, multi_step=20, discount=0.99, norm_clip=10.0, lr=1.0e-4, adam_eps=1.5e-4, device=None):
        # The device used to be hard-coded to "cuda:0", which fails on CPU-only
        # machines as soon as the first tensor is moved; fall back to the CPU.
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.action_size = len(env.action_space)
        self.atoms = atoms
        self.Vmin = V_min
        self.Vmax = V_max
        self.support = torch.linspace(self.Vmin, self.Vmax, self.atoms).to(self.device)
        self.delta_z = (self.Vmax - self.Vmin) / (self.atoms - 1)
        self.batch_size = batch_size
        self.n = multi_step
        self.discount = discount
        self.norm_clip = norm_clip

        self.online_net = DQN(self.atoms, self.action_size, env.window).to(self.device)
        self.online_net.train()

        # The target network starts as a copy of the online network and is never
        # trained directly; it only changes through update_target_net().
        self.target_net = DQN(self.atoms, self.action_size, env.window).to(self.device)
        self.update_target_net()
        self.target_net.train()

        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr, eps=adam_eps)

    def train(self):
        """Put the online network in training mode."""
        self.online_net.train()

    def eval(self):
        """Put the online network in evaluation mode."""
        self.online_net.eval()

    def update_target_net(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def reset_noise(self):
        """Redraw the online network's noise."""
        self.online_net.reset_noise()

    def act(self, observation):
        """Return the greedy action (int) for a single, unbatched observation."""
        with torch.no_grad():
            # Expected value per action = sum over atoms of probability * support.
            return (self.online_net(observation.unsqueeze(0)) * self.support).sum(2).argmax(1).item()

    def act_e_greedy(self, state, epsilon=0.001):
        """Return a uniformly random action with probability ``epsilon``, else the greedy one."""
        return np.random.randint(0, self.action_size) if np.random.random() < epsilon else self.act(state)

    def learn(self, mem):
        """Run one optimisation step on a batch from ``mem`` and refresh its priorities.

        ``mem`` must provide ``sample(batch_size)`` returning
        ``(idxs, states, actions, returns, next_states, nonterminals, weights)`` and
        ``update_priorities(idxs, priorities)``.
        """
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(self.batch_size)

        # Log-probabilities of the distribution for the actions actually taken.
        log_ps = self.online_net(states, use_log_softmax=True)
        log_ps_a = log_ps[range(self.batch_size), actions]

        with torch.no_grad():
            # Online network picks the next action, target network evaluates it.
            pns = self.online_net(next_states)
            dns = self.support.expand_as(pns) * pns
            argmax_indices_ns = dns.sum(2).argmax(1)
            self.target_net.reset_noise()
            pns = self.target_net(next_states)
            pns_a = pns[range(self.batch_size), argmax_indices_ns]

            # Shift/shrink the support by the n-step return and discount, then find
            # the fractional position b of every atom on the fixed support.
            Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)
            Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)
            b = (Tz - self.Vmin) / self.delta_z
            l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
            # When b falls exactly on an atom (l == u) both weights below would be
            # zero; move one neighbour so the probability mass is not lost.
            l[(u > 0) & (l == u)] -= 1
            u[(l < (self.atoms - 1)) & (l == u)] += 1

            # Distribute each atom's probability between its two neighbours.
            m = states.new_zeros(self.batch_size, self.atoms)
            # Start index of every batch row in the flattened view of m; integer
            # arithmetic avoids any float rounding in the index computation.
            offset = (torch.arange(self.batch_size, dtype=actions.dtype, device=actions.device) * self.atoms).unsqueeze(1)
            m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))
            m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))

        # Cross-entropy between projected target and prediction, one value per sample.
        loss = -torch.sum(m * log_ps_a, 1)
        self.online_net.zero_grad()
        (weights * loss).mean().backward()
        clip_grad_norm_(self.online_net.parameters(), self.norm_clip)
        self.optimizer.step()

        # The unweighted per-sample losses become the new priorities.
        mem.update_priorities(idxs, loss.detach().cpu().numpy())

In [None]:
# adapted from https://github.com/Kaixhin/Rainbow

import torch
import numpy as np
from collections import namedtuple


# One stored replay step: its timestep within the episode, the newest frame,
# the action taken, the reward and whether the episode continued afterwards.
Transition = namedtuple('Transition', ['timestep', 'state', 'action', 'reward', 'nonterminal'])
# All-zero placeholder used to pad frame stacks and n-step windows that would
# otherwise reach across an episode boundary.
blank_trans = Transition(
    timestep=0,
    state=torch.zeros(84, 84, dtype=torch.uint8),
    action=None,
    reward=0,
    nonterminal=False,
)


class SegmentTree():
    """Ring buffer of items paired with a sum tree over their priorities.

    The tree is stored in a flat array of ``2 * size - 1`` float32 nodes: indices
    ``0 .. size - 2`` are internal sums and indices ``size - 1 .. 2 * size - 2`` are
    the leaves, so leaf ``size - 1 + i`` holds the priority of ``data[i]``.

    Attributes are kept public because the replay memory reads them directly:
    ``index`` (next write position), ``full`` (buffer has wrapped at least once) and
    ``max`` (largest priority seen so far, starting at 1).
    """

    def __init__(self, size):
        if size < 1:
            raise ValueError("SegmentTree size must be at least 1")
        self.index = 0
        self.size = size
        self.full = False
        self.sum_tree = np.zeros((2 * size - 1, ), dtype=np.float32)
        self.data = np.array([None] * size)
        self.max = 1

    def _propagate(self, index, value):
        """Recompute every ancestor sum of tree node ``index``.

        ``value`` is unused and kept only for signature compatibility. The walk is
        iterative and stops at the root; a single-element tree (where the root is
        the only leaf) therefore needs no work instead of recursing without end.
        """
        while index > 0:
            parent = (index - 1) // 2
            left, right = 2 * parent + 1, 2 * parent + 2
            self.sum_tree[parent] = self.sum_tree[left] + self.sum_tree[right]
            index = parent

    def update(self, index, value):
        """Set the priority stored at tree node ``index`` (a tree index, not a data index)."""
        self.sum_tree[index] = value
        self._propagate(index, value)
        self.max = max(value, self.max)

    def append(self, data, value):
        """Store ``data`` with priority ``value`` at the write position, overwriting when full."""
        self.data[self.index] = data
        # update() also keeps self.max current.
        self.update(self.index + self.size - 1, value)
        self.index = (self.index + 1) % self.size
        self.full = self.full or self.index == 0

    def total(self):
        """Return the sum of all priorities (the root of the tree)."""
        return self.sum_tree[0]

    def _retrieve(self, index, value):
        """Descend from node ``index`` to the leaf whose cumulative range contains ``value``."""
        while True:
            left, right = 2 * index + 1, 2 * index + 2
            if left >= len(self.sum_tree):
                return index
            if value <= self.sum_tree[left]:
                index = left
            else:
                # Skip the mass of the left subtree and continue on the right.
                value = value - self.sum_tree[left]
                index = right

    def find(self, value):
        """Return ``(priority, data_index, tree_index)`` for cumulative priority ``value``."""
        index = self._retrieve(0, value)
        data_index = index - self.size + 1
        return (self.sum_tree[index], data_index, index)

    def get(self, data_index):
        """Return the stored item at ``data_index``, wrapping around the buffer size."""
        return self.data[data_index % self.size]


class ReplayMemory():
    """Prioritised n-step replay memory backed by a ``SegmentTree``.

    Only the newest frame of every observation is stored (as uint8 on the CPU);
    frame stacks and n-step returns are rebuilt from neighbouring transitions when
    a batch is sampled.

    Args:
        capacity: maximum number of stored transitions.
        history_length: number of frames stacked into one state.
        discount: per-step discount used for the n-step return.
        multi_step: n of the n-step return.
        priority_weight: exponent of the importance-sampling weights.
        priority_exponent: exponent applied to raw priorities in ``update_priorities``.
        device: torch device (or device string) for sampled tensors. When omitted,
            ``cuda:0`` is used if CUDA is available and the CPU otherwise.
    """

    def __init__(self, capacity, history_length, discount, multi_step, priority_weight=0.4, priority_exponent=0.5, device=None):
        # The device used to be hard-coded to "cuda:0", which makes sampling fail on
        # CPU-only machines; fall back to the CPU.
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.capacity = capacity
        self.history = history_length
        self.discount = discount
        self.n = multi_step
        self.priority_weight = priority_weight
        self.priority_exponent = priority_exponent
        self.t = 0  # timestep within the current episode
        self.transitions = SegmentTree(capacity)

    def append(self, state, action, reward, terminal):
        """Store one step; new entries get the largest priority seen so far.

        Args:
            state: stacked observation; only its last frame ``state[-1]`` is kept.
            action: action taken in ``state``.
            reward: reward received.
            terminal: True when the episode ended with this step.
        """
        state = state[-1].mul(255).to(dtype=torch.uint8, device=torch.device("cpu"))
        self.transitions.append(Transition(self.t, state, action, reward, not terminal), self.transitions.max)
        self.t = 0 if terminal else self.t + 1

    def _get_transition(self, idx):
        """Return the ``history + n`` transitions around data index ``idx``.

        Position ``history - 1`` holds the transition at ``idx``. Earlier slots are
        padded with ``blank_trans`` once the start of the episode is reached, later
        slots once a terminal transition is reached.
        """
        transition = np.array([None] * (self.history + self.n))
        transition[self.history - 1] = self.transitions.get(idx)
        for t in range(self.history - 2, -1, -1):
            if transition[t + 1].timestep == 0:
                transition[t] = blank_trans
            else:
                transition[t] = self.transitions.get(idx - self.history + 1 + t)
        for t in range(self.history, self.history + self.n):
            if transition[t - 1].nonterminal:
                transition[t] = self.transitions.get(idx - self.history + 1 + t)
            else:
                transition[t] = blank_trans
        return transition

    def _get_sample_from_segment(self, segment, i):
        """Draw one valid transition from the ``i``-th priority segment.

        Sampling repeats until the drawn index has non-zero priority and is far
        enough from the buffer's write position: more than ``n`` steps behind it and
        at least ``history`` steps ahead of it. NOTE(review): the loop has no retry
        limit, so it does not terminate if the segment holds no valid index.
        """
        valid = False
        while not valid:
            sample = np.random.uniform(i * segment, (i + 1) * segment)
            prob, idx, tree_idx = self.transitions.find(sample)
            if (self.transitions.index - idx) % self.capacity > self.n and (idx - self.transitions.index) % self.capacity >= self.history and prob != 0:
                valid = True

        transition = self._get_transition(idx)
        # Frames are stored as uint8; convert back to float32 in [0, 1].
        state = torch.stack([trans.state for trans in transition[:self.history]]).to(device=self.device).to(dtype=torch.float32).div_(255)
        next_state = torch.stack([trans.state for trans in transition[self.n:self.n + self.history]]).to(device=self.device).to(dtype=torch.float32).div_(255)
        action = torch.tensor([transition[self.history - 1].action], dtype=torch.int64, device=self.device)
        # Discounted sum of the n rewards starting at the sampled transition.
        R = torch.tensor([sum(self.discount ** n * transition[self.history + n - 1].reward for n in range(self.n))], dtype=torch.float32, device=self.device)
        nonterminal = torch.tensor([transition[self.history + self.n - 1].nonterminal], dtype=torch.float32, device=self.device)

        return prob, idx, tree_idx, state, action, R, next_state, nonterminal

    def sample(self, batch_size):
        """Sample a batch, one transition per equal-sized slice of the total priority.

        Returns:
            ``(tree_idxs, states, actions, returns, next_states, nonterminals, weights)``
            where ``weights`` are importance-sampling weights normalised by their maximum.
        """
        p_total = self.transitions.total()
        segment = p_total / batch_size
        batch = [self._get_sample_from_segment(segment, i) for i in range(batch_size)]
        probs, idxs, tree_idxs, states, actions, returns, next_states, nonterminals = zip(*batch)
        states, next_states = torch.stack(states), torch.stack(next_states)
        actions, returns, nonterminals = torch.cat(actions), torch.cat(returns), torch.stack(nonterminals)
        probs = np.array(probs, dtype=np.float32) / p_total
        # Number of stored transitions: full capacity once the buffer has wrapped.
        capacity = self.capacity if self.transitions.full else self.transitions.index
        weights = (capacity * probs) ** -self.priority_weight
        weights = torch.tensor(weights / weights.max(), dtype=torch.float32, device=self.device)
        return tree_idxs, states, actions, returns, next_states, nonterminals, weights

    def update_priorities(self, idxs, priorities):
        """Write ``priorities ** priority_exponent`` to the tree nodes ``idxs``."""
        priorities = np.power(priorities, self.priority_exponent)
        for idx, priority in zip(idxs, priorities):
            self.transitions.update(idx, priority)

In [None]:
import numpy as np
from datetime import datetime


# Smoke test of the environment / agent / replay-memory wiring: a few short
# episodes with learning disabled. The objects and hyper-parameters created here
# (env, agent, mem, replay_frequency, max_steps, learning_start_step,
# priority_weight_increase) are notebook globals.
env = Env(history_length=10)
agent = Agent(env)
mem = ReplayMemory(100_000, env.window, agent.discount, agent.n)

episodes = 3
replay_frequency = 1
reward_clip = 1
max_steps = 200_000
learning_start_step = 1600
# Per-step increment that moves mem.priority_weight linearly to 1 over the
# steps between learning_start_step and max_steps.
priority_weight_increase = (1 - mem.priority_weight) / (max_steps - learning_start_step)

rewards = []
ep_rewards = []

print(f"{datetime.now()}, start training")
steps = 0
for episode_ix in range(1, episodes+1):
    observation, ep_reward, ep_steps, done = env.reset(), 0, 0, False
    # Full-episode loop is disabled; each "episode" is capped at 3 steps below.
    # NOTE(review): the capped loop does not check `done` before stepping again.
#     while not done:
    for _ in range(3):
        if steps % replay_frequency == 0:
            agent.reset_noise()
        action = agent.act(observation)
        next_observation, reward, done, info = env.step(action)
        rewards.append(reward)
        ep_reward += reward
        ep_steps += 1
        steps += 1
        print(f"{datetime.now()}, episode:{episode_ix:2d}, step:{steps:5d}, action:{action}, reward:{reward:9.4f}, done:{done}")
        # The raw reward is logged above; the clipped reward is what gets stored.
        if reward_clip > 0:
            reward = max(min(reward, reward_clip), -reward_clip)
        mem.append(observation, action, reward, done)
        if steps >= learning_start_step:
            mem.priority_weight = min(mem.priority_weight + priority_weight_increase, 1)
            if steps % replay_frequency == 0:
                # Learning is intentionally switched off in this smoke test.
                pass
#                 agent.learn(mem)
            
        observation = next_observation
    ep_rewards.append(ep_reward)
    print(f"{datetime.now()}, episode:{episode_ix:2d}, step:{steps:5d}, reward:{ep_reward:10.4f}")
print(f"{datetime.now()}, end training")
env.close()

In [None]:
# Main training loop: act, store the transition, learn every `replay_frequency`
# steps once `learn_start` steps have passed, evaluate after every 10th training
# episode and sync the target network every `target_update` steps.
reward_clip = 1
target_update = 2000
# `T_max` and `learn_start` were used below without ever being defined; the step
# budget and warm-up length are the `max_steps` / `learning_start_step` values
# that `priority_weight_increase` was computed from.
T_max = max_steps
learn_start = learning_start_step

rewards = []
eval_rewards = []
episode_rewards = []
eval_episode_rewards = []
episode_reward = 0.0
episode_steps = 0
episode_count = 0

# NOTE(review): the smoke-test cell calls env.close() before this cell runs;
# confirm the environment can still be reset afterwards.
done = True
for T in range(1, T_max + 1):
    if done:
        state, done = env.reset(), False

    if T % replay_frequency == 0:
        agent.reset_noise()

    action = agent.act(state)
    # Env.step returns four values (observation, reward, done, info).
    next_state, reward, done, info = env.step(action)

    rewards.append(reward)
    episode_reward += reward
    episode_steps += 1
    if done:
        episode_count += 1
        episode_rewards.append(episode_reward)
        print(f"{datetime.now()}, T:{T}, Episode:{episode_count}, Steps:{episode_steps}, Avg. Reward:{episode_reward/episode_steps:.4f}, Total Reward:{episode_reward}")
        episode_reward = 0.0
        episode_steps = 0

    # Statistics above use the raw reward; the replay memory stores the clipped one.
    if reward_clip > 0:
        reward = max(min(reward, reward_clip), -reward_clip)
    mem.append(state, action, reward, done)

    if T >= learn_start:
        mem.priority_weight = min(mem.priority_weight + priority_weight_increase, 1)

        if T % replay_frequency == 0:
            agent.learn(mem)

        if episode_count % 10 == 0 and done:
            # One evaluation episode on the same environment. Training has just
            # finished an episode (done is True), so the next training iteration
            # resets the environment again.
            agent.eval()

            eval_episode_reward = 0.0
            eval_episode_steps = 0
            eval_state, eval_done = env.reset(), False
            while not eval_done:
                eval_action = agent.act_e_greedy(eval_state)
                eval_state, eval_reward, eval_done, eval_info = env.step(eval_action)
                eval_rewards.append(eval_reward)
                eval_episode_reward += eval_reward
                eval_episode_steps += 1
            eval_episode_rewards.append(eval_episode_reward)
            print(f"{datetime.now()}, T:{T}, Eval_Episode:{episode_count}, Steps:{eval_episode_steps}, "
                  f"Avg. Reward:{eval_episode_reward/eval_episode_steps:.4f}, Total Reward:{eval_episode_reward}")
            agent.train()

        if T % target_update == 0:
            agent.update_target_net()

    state = next_state

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt


# Apply matplotlib's "dark_background" style sheet to the plots in this notebook.
plt.style.use("dark_background")

In [None]:
# Histogram of the raw per-step training rewards. The count axis is fixed to
# [0, 500] before plotting, so taller bins are cut off at the top.
fig, ax = plt.subplots()
ax.set_ylim([0, 500])
ax.hist(rewards)
ax.set(title="Per-step reward distribution (training)", xlabel="Reward", ylabel="Count")
plt.show()

# Total reward of each completed training episode, in order.
fig, ax = plt.subplots()
ax.plot(episode_rewards)
ax.set(title="Training episode reward", xlabel="Episode", ylabel="Total reward")
plt.show()

# Total reward of each evaluation episode, in order.
fig, ax = plt.subplots()
ax.plot(eval_episode_rewards)
ax.set(title="Evaluation episode reward", xlabel="Evaluation run", ylabel="Total reward")
plt.show()