In [None]:
# Fetch the Atari ROM archive, unpack it under /content/ROM/, and register
# the extracted ROMs with atari_py so the Atari environments can load them.
! wget http://www.atarimania.com/roms/Roms.rar
! mkdir /content/ROM/
! unrar x /content/Roms.rar /content/ROM/
! python -m atari_py.import_roms /content/ROM/ROMS

In [None]:
# Install the Atari emulator bindings plus the headless-rendering stack
# (virtual X display, OpenGL bindings, ffmpeg) used to record videos.
# NOTE(review): `xvfb python-opengl` is requested by two of the apt-get
# lines below; the second request is redundant but harmless.
!pip install -q ale-py
!apt-get install x11-utils > /dev/null 2>&1
!pip install pyglet > /dev/null 2>&1
!apt-get install -y xvfb python-opengl > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install -U colabgymrender

In [None]:
from collections import deque
import numpy as np
import gym
from gym import spaces, wrappers
import cv2
cv2.ocl.setUseOpenCL(False)

class NoopResetEnv(gym.Wrapper):
    """Start each episode by executing a random number of no-op actions.

    The count is drawn via ``np_random.integers(1, noop_max + 1)`` unless
    ``override_num_noops`` is set, in which case that fixed count is used.
    """

    def __init__(self, env, noop_max=30):
        """
        Args:
            env: (Gym Environment) the environment to wrap; action 0 must be
                 'NOOP'
            noop_max: (int) upper limit passed to the random draw
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """
        Reset the environment, then step through the no-op actions.
        Args:
            Extra keywords passed to env.reset() call
        Return:
            the observation produced by the last no-op step (or by the
            reset that followed it, if that step ended the episode)
        """
        self.env.reset(**kwargs)
        noops = self.override_num_noops
        if noops is None:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _reward, episode_over, _info = self.env.step(
                self.noop_action)
            if episode_over:
                # The no-ops finished the episode; start a fresh one.
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, action):
        """Forward the action to the wrapped environment unchanged."""
        return self.env.step(action)


class FireResetEnv(gym.Wrapper):
    """Press FIRE (then action 2) on reset for games that wait for it.

    The wrapped environment must expose 'FIRE' as action 1 and have at
    least three actions.
    """

    def __init__(self, env):
        """
        Args:
            env: (Gym Environment) the environment to wrap
        """
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        """
        Reset the environment and take actions 1 and 2 once each.
        Args:
            Extra keywords passed to env.reset() call
        Return:
            the first observation of the (possibly re-reset) episode
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            # Keep the fresh observation instead of discarding it.
            obs = self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Without this assignment the caller would receive the last
            # observation of the episode that just ended.
            obs = self.env.reset(**kwargs)
        return obs

    def step(self, action):
        """Forward the action to the wrapped environment unchanged."""
        return self.env.step(action)


class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode.

    The underlying game is only truly reset once it reports `done` itself;
    `was_real_done` remembers whether that happened on the last step.
    """

    def __init__(self, env):
        """
        Args:
            env: (Gym Environment) the environment to wrap
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        """
        Step the environment and flag a lost life as `done`.
        Args:
            action: the action forwarded to the wrapped environment
        Return:
            observation, reward, done, information
        """
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        remaining = self.env.unwrapped.ale.lives()
        # Require remaining > 0: for Qbert the emulator can sit at zero
        # lives for a few frames, and in that case the reset must wait
        # until the environment itself advertises done.
        lost_life = remaining > 0 and remaining < self.lives
        if lost_life:
            done = True
        # Always store the new count so that bonus lives are picked up.
        self.lives = remaining
        return obs, reward, done, info

    def reset(self, **kwargs):
        """
        Reset the wrapped environment only when the game is really over.
        After a mere life loss a single no-op step is taken instead, so all
        game states stay reachable while the learner sees episodic lives.
        Args:
            Extra keywords passed to env.reset() call
        Return:
            ([int] or [float]) the first observation of the environment
        """
        if not self.was_real_done:
            # Advance past the lost-life state with a no-op.
            obs, _, _, _ = self.env.step(0)
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action `skip` times and max-pool the last two frames."""

    def __init__(self, env, skip=4):
        """
        Args:
            env: (Gym Environment) the environment
            skip: (int) how many times every action is repeated
        """
        gym.Wrapper.__init__(self, env)
        # Two-slot buffer of raw observations, pooled in `step`.
        buffer_shape = (2,) + env.observation_space.shape
        self._obs_buffer = np.zeros(
            buffer_shape, dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action):
        """
        Repeat the action, accumulate the reward, and return the
        element-wise maximum of the two buffered observations.
        Args:
            action: ([int] or [float]) the action
        Return:
            ([int] or [float], [float], [bool], dict) observation, reward,
                 done, information
        """
        done = None
        total_reward = 0.0
        first_buffered = self._skip - 2
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            # Only the final two repeats are stored (slot 0, then slot 1).
            slot = i - first_buffered
            if slot in (0, 1):
                self._obs_buffer[slot] = obs
            if done:
                break

        return self._obs_buffer.max(axis=0), total_reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and return its observation."""
        return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
    """Reward wrapper that keeps only the sign of every reward."""

    def __init__(self, env):
        """
        Args:
            env: (Gym Environment) the environment
        """
        super().__init__(env)

    def reward(self, reward):
        """
        Map the reward onto {+1, 0, -1} according to its sign.
        """
        clipped = np.sign(reward)
        return clipped


class WarpFramePyTorch(gym.ObservationWrapper):
    """Grayscale and resize frames to 84x84 with a leading channel axis."""

    def __init__(self, env):
        """
        Args:
            env: (Gym Environment) the environment
        """
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        # Channel-first layout: (1, height, width).
        warped_shape = (1, self.height, self.width)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=warped_shape,
            dtype=env.observation_space.dtype)

    def observation(self, frame):
        """
        Convert an RGB frame to a grayscale (1, height, width) array.
        Args:
            frame: ([int] or [float]) environment frame

        Return:
            ([int] or [float]) the observation
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        target_size = (self.width, self.height)
        resized = cv2.resize(gray, target_size, interpolation=cv2.INTER_AREA)
        return resized[None, :, :]


class FrameStackPyTorch(gym.Wrapper):
    """Stack the `n_frames` most recent observations along the first axis."""

    def __init__(self, env, n_frames):
        """Stack n_frames last frames.
        Args:
            env: (Gym Environment) the environment; its observations must
                 be uint8
            n_frames: (int) the number of frames to stack
        """
        assert env.observation_space.dtype == np.uint8

        gym.Wrapper.__init__(self, env)
        self.n_frames = n_frames
        self.frames = deque([], maxlen=n_frames)
        shp = env.observation_space.shape

        self.observation_space = spaces.Box(
            low=np.min(env.observation_space.low),
            high=np.max(env.observation_space.high),
            shape=(shp[0] * n_frames, shp[1], shp[2]),
            dtype=env.observation_space.dtype)

    def reset(self, **kwargs):
        """
        Reset the environment and fill the stack with the first frame.
        Args:
            Extra keywords passed to env.reset() call (accepted for parity
            with the other wrappers in this file)
        Return:
            (LazyFrames) the stacked observation
        """
        obs = self.env.reset(**kwargs)
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return self._get_ob()

    def step(self, action):
        """
        Step the environment and push the new frame onto the stack.
        Args:
            action: the action forwarded to the wrapped environment
        Return:
            stacked observation, reward, done, information
        """
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Wrap the currently buffered frames in a LazyFrames object."""
        assert len(self.frames) == self.n_frames
        return LazyFrames(list(self.frames))


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def __init__(self, env):
        """
        Args:
            env: (Gym Environment) the environment
        """
        gym.ObservationWrapper.__init__(self, env)
        source_shape = env.observation_space.shape
        self.observation_space = spaces.Box(
            low=0, high=1.0, shape=source_shape, dtype=np.float32)

    def observation(self, observation):
        """Return the observation as a float32 array scaled by 1/255."""
        # Careful: materialising float arrays undoes the memory
        # optimisation, so use this with smaller replay buffers only.
        as_float = np.array(observation).astype(np.float32)
        return as_float / 255.0


class LazyFrames(object):
    """Hold a list of frames and join them into one array only on demand.

    Every access (`__array__`, `len`, indexing) rebuilds the joined array
    from the stored frame list.
    """

    def __init__(self, frames):
        """
        Args:
            frames: (list of arrays) frames to keep; the dtype of the first
                    one is used for the joined array
        """
        self._frames = frames
        self.dtype = frames[0].dtype

    def _force(self):
        """Join the stored frames along their first axis."""
        stacked = np.array(self._frames, dtype=self.dtype)
        return np.concatenate(stacked, axis=0)

    def __array__(self, dtype=None):
        """Return the joined frames, cast to `dtype` when one is given."""
        joined = self._force()
        if dtype is None:
            return joined
        return joined.astype(dtype)

    def __len__(self):
        """Length of the joined array along its first axis."""
        return len(self._force())

    def __getitem__(self, i):
        """Index into the joined array."""
        joined = self._force()
        return joined[i]


def make_atari(env_id):
    """
    Build an Atari environment with no-op resets and frame skipping.
    Args:
        env_id: (str) the environment ID; must be a 'NoFrameskip' variant
    Return:
        (Gym Environment) the wrapped atari environment
    """
    base_env = gym.make(env_id)
    assert 'NoFrameskip' in base_env.spec.id
    with_noops = NoopResetEnv(base_env, noop_max=30)
    return MaxAndSkipEnv(with_noops, skip=4)


def wrap_deepmind_pytorch(env, episode_life=True, clip_rewards=True,
                          frame_stack=True, scale=False):
    """
    Apply the DeepMind-style Atari wrappers in a fixed order.
    Args:
        env: (Gym Environment) the atari environment
        episode_life: (bool) treat a lost life as the end of an episode
        clip_rewards: (bool) clip rewards to their sign
        frame_stack: (bool) stack the 4 most recent frames
        scale: (bool) scale observations to float32 in [0, 1]
    Return:
        (Gym Environment) the wrapped atari environment
    """
    wrapped = EpisodicLifeEnv(env) if episode_life else env
    action_names = wrapped.unwrapped.get_action_meanings()
    if 'FIRE' in action_names:
        wrapped = FireResetEnv(wrapped)
    # Frame warping is always applied; the remaining wrappers are optional.
    wrapped = WarpFramePyTorch(wrapped)
    if clip_rewards:
        wrapped = ClipRewardEnv(wrapped)
    if scale:
        wrapped = ScaledFloatFrame(wrapped)
    if frame_stack:
        wrapped = FrameStackPyTorch(wrapped, 4)
    return wrapped


def make_pytorch_env(env_id, episode_life=True, clip_rewards=True,
                     frame_stack=True, scale=False):
    """
    Create an Atari environment and apply the DeepMind-style wrappers.
    Args:
        env_id: (str) the environment ID
        episode_life, clip_rewards, frame_stack, scale: (bool) forwarded to
            wrap_deepmind_pytorch
    Return:
        (Gym Environment) the fully wrapped environment
    """
    return wrap_deepmind_pytorch(
        make_atari(env_id), episode_life, clip_rewards, frame_stack, scale)


def wrap_monitor(env, log_dir):
    """
    Wrap the environment in a Monitor that records every episode.
    Args:
        env: (Gym Environment) the environment
        log_dir: (str) directory handed to the Monitor wrapper
    Return:
        (Gym Environment) the monitored environment
    """
    def _always_record(x):
        # Answer True for every query so that no episode is skipped.
        return True

    return wrappers.Monitor(env, log_dir, video_callable=_always_record)

In [None]:
from collections import deque
import numpy as np
import torch
import operator

class MultiStepBuff:
    """Short FIFO of (state, action, reward) used to form n-step returns."""

    def __init__(self, maxlen=3):
        """
        Args:
            maxlen: (int) maximum number of buffered transitions
        """
        super().__init__()
        self.maxlen = int(maxlen)
        self.reset()

    def append(self, state, action, reward):
        """Push one transition onto the three parallel queues."""
        pairs = (
            (self.states, state),
            (self.actions, action),
            (self.rewards, reward),
        )
        for queue, item in pairs:
            queue.append(item)

    def get(self, gamma=0.99):
        """
        Pop the oldest transition together with its discounted return.
        Args:
            gamma: (float) discount factor
        Return:
            oldest state, oldest action, discounted sum of buffered rewards
        """
        assert len(self.rewards) > 0
        oldest_state = self.states.popleft()
        oldest_action = self.actions.popleft()
        return oldest_state, oldest_action, self._nstep_return(gamma)

    def _nstep_return(self, gamma):
        """Discounted sum of all buffered rewards; drops the oldest reward."""
        discounted = [
            reward * (gamma ** power)
            for power, reward in enumerate(self.rewards)
        ]
        total = np.sum(discounted)
        self.rewards.popleft()
        return total

    def reset(self):
        """Replace the three queues with empty bounded deques."""
        self.states = deque(maxlen=self.maxlen)
        self.actions = deque(maxlen=self.maxlen)
        self.rewards = deque(maxlen=self.maxlen)

    def is_empty(self):
        """True when no rewards are buffered."""
        return len(self) == 0

    def is_full(self):
        """True when `maxlen` rewards are buffered."""
        return len(self) == self.maxlen

    def __len__(self):
        """Number of buffered rewards."""
        return len(self.rewards)


class LazyMemory(dict):
    """Replay memory keeping states in lists and scalars in ring arrays.

    'state' / 'next_state' are Python lists trimmed from the front once the
    capacity is exceeded; 'action' / 'reward' / 'done' are preallocated
    (capacity, 1) arrays written at the rotating position `_p`.
    """
    state_keys = ['state', 'next_state']
    np_keys = ['action', 'reward', 'done']
    keys = state_keys + np_keys

    def __init__(self, capacity, state_shape, device):
        """
        Args:
            capacity: (int) maximum number of stored transitions
            state_shape: (tuple) shape of a single state
            device: torch device that sampled batches are moved to
        """
        super().__init__()
        self.capacity = int(capacity)
        self.state_shape = state_shape
        self.device = device
        self.reset()

    def reset(self):
        """Drop all stored data and rewind the write position."""
        for key in self.state_keys:
            self[key] = []

        column = (self.capacity, 1)
        self['action'] = np.empty(column, dtype=np.int64)
        self['reward'] = np.empty(column, dtype=np.float32)
        self['done'] = np.empty(column, dtype=np.float32)

        # _n: number of valid entries, _p: next write slot in the arrays.
        self._n = 0
        self._p = 0

    def append(self, state, action, reward, next_state, done,
               episode_done=None):
        """Store one transition (`episode_done` is accepted but unused)."""
        self._append(state, action, reward, next_state, done)

    def _append(self, state, action, reward, next_state, done):
        """Write one transition and advance the ring position."""
        slot = self._p
        self['action'][slot] = action
        self['reward'][slot] = reward
        self['done'][slot] = done
        self['state'].append(state)
        self['next_state'].append(next_state)

        self._n = min(self._n + 1, self.capacity)
        self._p = (slot + 1) % self.capacity

        self.truncate()

    def truncate(self):
        """Drop the oldest states until at most `capacity` remain."""
        while len(self['state']) > self.capacity:
            del self['state'][0]
            del self['next_state'][0]

    def sample(self, batch_size):
        """Draw `batch_size` uniformly random transitions as tensors."""
        indices = np.random.randint(low=0, high=len(self), size=batch_size)
        return self._sample(indices, batch_size)

    def _sample(self, indices, batch_size):
        """
        Gather the transitions at `indices` (positions in the ring arrays).
        Return:
            states, actions, rewards, next_states, dones as tensors on
            `self.device`; states are divided by 255
        """
        # When the ring is full, the state lists start at the oldest entry
        # while the arrays start at slot 0, so shift array positions by -_p.
        if self._n == self.capacity:
            bias = -self._p
        else:
            bias = 0

        batch_shape = (batch_size, *self.state_shape)
        states = np.empty(batch_shape, dtype=np.uint8)
        next_states = np.empty(batch_shape, dtype=np.uint8)

        for row, index in enumerate(indices):
            list_pos = np.mod(index + bias, self.capacity)
            states[row, ...] = self['state'][list_pos]
            next_states[row, ...] = self['next_state'][list_pos]

        states = torch.ByteTensor(states).to(self.device).float() / 255.
        next_states = torch.ByteTensor(
            next_states).to(self.device).float() / 255.
        actions = torch.LongTensor(self['action'][indices]).to(self.device)
        rewards = torch.FloatTensor(self['reward'][indices]).to(self.device)
        dones = torch.FloatTensor(self['done'][indices]).to(self.device)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of stored states."""
        return len(self['state'])

    def get(self):
        """Return the stored data as a plain dict."""
        return dict(self)

    def load(self, memory):
        """
        Append the contents of another memory dict to this one.
        Args:
            memory: (dict) mapping with the same keys as this memory
        """
        for key in self.state_keys:
            self[key].extend(memory[key])

        num_data = len(memory['state'])
        start = self._p
        if start + num_data > self.capacity:
            # Wrap around: fill the tail of the ring, then its head.
            mid_index = self.capacity - start
            end_index = num_data - mid_index
            for key in self.np_keys:
                self[key][start:] = memory[key][:mid_index]
                self[key][:end_index] = memory[key][mid_index:]
        else:
            for key in self.np_keys:
                self[key][start:start + num_data] = memory[key]

        self._n = min(self._n + num_data, self.capacity)
        self._p = (start + num_data) % self.capacity
        self.truncate()
        assert self._n == len(self)


class LazyMultiStepMemory(LazyMemory):
    """LazyMemory that stores n-step transitions via a MultiStepBuff."""

    def __init__(self, capacity, state_shape, device, gamma=0.99,
                 multi_step=3):
        """
        Args:
            capacity: (int) maximum number of stored transitions
            state_shape: (tuple) shape of a single state
            device: torch device that sampled batches are moved to
            gamma: (float) discount factor for the n-step return
            multi_step: (int) number of steps; 1 disables the n-step buffer
        """
        super().__init__(capacity, state_shape, device)

        self.gamma = gamma
        self.multi_step = int(multi_step)
        if self.multi_step != 1:
            self.buff = MultiStepBuff(maxlen=self.multi_step)

    def append(self, state, action, reward, next_state, done):
        """
        Store a transition, routing it through the n-step buffer if enabled.
        """
        if self.multi_step == 1:
            self._append(state, action, reward, next_state, done)
            return

        self.buff.append(state, action, reward)

        # A full buffer yields one n-step transition ending at next_state.
        if self.buff.is_full():
            state, action, reward = self.buff.get(self.gamma)
            self._append(state, action, reward, next_state, done)

        # At the end of an episode, flush whatever is still buffered.
        if done:
            while not self.buff.is_empty():
                state, action, reward = self.buff.get(self.gamma)
                self._append(state, action, reward, next_state, done)
import numpy as np
import torch

In [None]:
class LinearAnneaer:
    """Linear schedule from `start_value` to `end_value` over `num_steps`.

    The (misspelled) class name is kept as-is because other code in this
    file refers to it.
    """

    def __init__(self, start_value, end_value, num_steps):
        """
        Args:
            start_value: (float) value of the schedule at step 0
            end_value: (float) value reached after `num_steps` steps
            num_steps: (int) positive number of annealing steps
        """
        assert num_steps > 0 and isinstance(num_steps, int)

        self.steps = 0
        self.start_value = start_value
        self.end_value = end_value
        self.num_steps = num_steps

        # Slope and intercept of the line value = a * steps + b.
        self.a = (self.end_value - self.start_value) / self.num_steps
        self.b = self.start_value

    def step(self):
        """Advance the schedule by one step, saturating at `num_steps`."""
        self.steps = min(self.num_steps, self.steps + 1)

    def get(self):
        """
        Return the current value of the schedule.

        Before the first `step()` call this is `start_value`; previously
        that call tripped an assertion.
        """
        assert 0 <= self.steps <= self.num_steps
        return self.a * self.steps + self.b

def disable_gradients(network):
    """Turn off gradient tracking for every parameter of `network`."""
    for weight in network.parameters():
        weight.requires_grad = False

In [None]:
from copy import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math


def initialize_weights_xavier(m, gain=1.0):
    """Xavier-uniform init for Linear/Conv2d weights; biases are zeroed.

    Modules of any other type are left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=gain)
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

def initialize_weights_he(m):
    """Kaiming-uniform init for Linear/Conv2d weights; biases are zeroed.

    Modules of any other type are left untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    torch.nn.init.kaiming_uniform_(m.weight)
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)

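# For 84x84 input frames the convolutional stack outputs a 7x7x64 feature map
# (7 * 7 * 64 = 3136 values once flattened).
# DQNBase is the convolutional feature extractor used by the networks below.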
class DQNBase(nn.Module):
    """Three-layer convolutional encoder that flattens its output."""

    def __init__(self, num_channels, embedding_dim=7*7*64):
        """
        Args:
            num_channels: (int) number of input channels
            embedding_dim: (int) expected size of the flattened output;
                           only checked in `forward`, it does not change
                           the layers
        """
        super().__init__()

        layers = [
            nn.Conv2d(num_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            Flatten(),
        ]
        self.net = nn.Sequential(*layers).apply(initialize_weights_he)
        self.embedding_dim = embedding_dim

    def forward(self, states):
        """Encode a batch of states into (batch_size, embedding_dim)."""
        state_embedding = self.net(states)
        expected_shape = (states.shape[0], self.embedding_dim)
        assert state_embedding.shape == expected_shape

        return state_embedding

class BaseModel(nn.Module):
    """Common base class for the models in this file; adds no behavior."""

    def __init__(self):
        super(BaseModel, self).__init__()

class DQN(BaseModel):
    """Convolutional encoder followed by a two-layer Q-value head."""

    def __init__(self, num_channels, num_actions, embedding_dim=7*7*64):
        """
        Args:
            num_channels: (int) number of input channels
            num_actions: (int) number of Q-values produced per state
            embedding_dim: (int) input size of the Q-value head
        """
        super().__init__()
        # Convolutional feature extractor.
        self.feature = DQNBase(num_channels=num_channels)
        # Fully connected head mapping embeddings to Q-values.
        self.q_net = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim

    def forward(self, states=None):
        """Return Q-values of shape (batch_size, num_actions)."""
        num_states = states.shape[0]
        q_values = self.q_net(self.feature(states))
        return q_values.view(num_states, self.num_actions)


class DuelNoisyQ(nn.Module):
    """Dueling, noisy, categorical Q-network.

    Value and advantage streams are built from NoisyLinear layers and emit
    `num_atoms` logits each; `forward` turns them into expected Q-values
    over a fixed support.
    """

    def __init__(self, num_channels, num_actions, embedding_dim=7*7*64,
                 fc1_units=256, fc2_units=256, num_atoms=51,
                 v_min=-10.0, v_max=10.0):
        """
        Args:
            num_channels: (int) number of input channels
            num_actions: (int) number of actions
            embedding_dim: (int) size of the flattened feature vector
            fc1_units: (int) width of the hidden noisy layers
            fc2_units: (int) unused; kept for interface compatibility
            num_atoms: (int) number of atoms of the value distribution
            v_min: (float) smallest value of the support
            v_max: (float) largest value of the support
                NOTE(review): the original code never defined the support;
                the [-10, 10] default is an assumption -- confirm it suits
                the reward scale in use.
        """
        super(DuelNoisyQ, self).__init__()

        # Feature extractor of DQN.
        self.feature = DQNBase(num_channels=num_channels)

        # Advantage stream.
        self.advantage_hidden_layer = NoisyLinear(embedding_dim, fc1_units)
        self.advantage_layer = NoisyLinear(fc1_units, num_actions * num_atoms)

        # Value stream.
        self.value_hidden_layer = NoisyLinear(embedding_dim, fc1_units)
        self.value_layer = NoisyLinear(fc1_units, num_atoms)

        self.v_net = nn.Sequential(
            self.value_hidden_layer,
            nn.ReLU(),
            self.value_layer,
        )

        self.adv_net = nn.Sequential(
            self.advantage_hidden_layer,
            nn.ReLU(),
            self.advantage_layer,
        )

        self.num_channels = num_channels
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim

        # `forward` reads both of these; they were missing before, so any
        # forward pass raised AttributeError.
        self.atom_size = num_atoms
        # Non-persistent so the state_dict keys stay exactly as before.
        self.register_buffer(
            'support', torch.linspace(v_min, v_max, num_atoms),
            persistent=False)

    def forward(self, states=None):
        """Return expected Q-values of shape (batch_size, num_actions)."""
        feature = self.feature(states)
        advantage = self.adv_net(feature).view(
            -1, self.num_actions, self.atom_size)
        value = self.v_net(feature).view(-1, 1, self.atom_size)

        # Dueling aggregation, done per atom.
        q_atoms = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Softmax over atoms gives a distribution; its mean is the Q-value.
        dist = torch.softmax(q_atoms, dim=-1)
        q = torch.sum(dist * self.support, dim=2)

        return q

    def reset_noise(self):
        """Resample the noise of all noisy layers."""
        # NoisyLinear exposes `sample()` for this; it has no `reset_noise`.
        self.advantage_hidden_layer.sample()
        self.advantage_layer.sample()
        self.value_hidden_layer.sample()
        self.value_layer.sample()


class NoisyLinear(nn.Module):
    """Linear layer with learnable factorized noise on weight and bias.

    In training mode the noise buffers perturb the mean parameters; in
    evaluation mode only the mean parameters are used.
    """

    def __init__(self, in_features, out_features, sigma=0.5):
        """
        Args:
            in_features: (int) size of each input sample
            out_features: (int) size of each output sample
            sigma: (float) scale used to initialise the noise parameters
        """
        super().__init__()

        # Learnable parameters.
        weight_shape = (out_features, in_features)
        self.mu_W = nn.Parameter(torch.FloatTensor(*weight_shape))
        self.sigma_W = nn.Parameter(torch.FloatTensor(*weight_shape))
        self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
        self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))

        # Factorized noise vectors (buffers, not parameters).
        self.register_buffer('eps_p', torch.FloatTensor(in_features))
        self.register_buffer('eps_q', torch.FloatTensor(out_features))

        self.in_features = in_features
        self.out_features = out_features
        self.sigma = sigma

        self.reset()
        self.sample()

    def reset(self):
        """Initialise the mean and sigma parameters."""
        in_scale = np.sqrt(self.in_features)
        out_scale = np.sqrt(self.out_features)
        bound = 1 / in_scale
        self.mu_W.data.uniform_(-bound, bound)
        self.mu_bias.data.uniform_(-bound, bound)
        self.sigma_W.data.fill_(self.sigma / in_scale)
        self.sigma_bias.data.fill_(self.sigma / out_scale)

    def f(self, x):
        """Fill `x` in place with normal noise; return sign(x)*sqrt(|x|)."""
        noise = x.normal_()
        return noise.sign().mul(noise.abs().sqrt())

    def sample(self):
        """Draw fresh values for both noise buffers."""
        for buffer in (self.eps_p, self.eps_q):
            buffer.copy_(self.f(buffer))

    def forward(self, x):
        """Apply the (noisy when training) affine transformation to `x`."""
        if not self.training:
            return F.linear(x, self.mu_W, self.mu_bias)

        noisy_weight = self.mu_W + self.sigma_W * self.eps_q.ger(self.eps_p)
        noisy_bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
        return F.linear(x, noisy_weight, noisy_bias)


In [None]:
class QRDQN(nn.Module):
    """Dueling quantile-regression network built on the DQNBase encoder."""

    def __init__(self, num_channels, num_actions, embedding_dim=7*7*64, num_quants=50):
        """
        Args:
            num_channels: (int) number of input channels
            num_actions: (int) number of actions
            embedding_dim: (int) size of the flattened feature vector
            num_quants: (int) number of quantiles per action
        """
        super().__init__()

        self.features = DQNBase(num_channels=num_channels)
        # Advantage stream: one output per (quantile, action) pair.
        self.advantage_net = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions * num_quants),
        )
        # Value stream: one output per quantile.
        self.value_net = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_quants),
        )

        self.embedding_dim = embedding_dim
        self.num_actions = num_actions
        self.num_quants = num_quants

    def forward(self, states=None):
        """Return quantiles of shape (batch_size, num_quants, num_actions)."""
        num_states = states.shape[0]
        embedding = self.features(states)

        advantages = self.advantage_net(embedding).view(
            num_states, self.num_quants, self.num_actions)
        baselines = self.value_net(embedding).view(
            num_states, self.num_quants, 1)
        # Dueling aggregation: centre the advantages over the action axis.
        centred = advantages - advantages.mean(dim=2, keepdim=True)
        quantiles = baselines + centred

        expected_shape = (num_states, self.num_quants, self.num_actions)
        assert quantiles.shape == expected_shape

        return quantiles

    def calculate_q(self, states=None):
        """Return Q-values (mean over quantiles), (batch_size, num_actions)."""
        num_states = states.shape[0]

        # Averaging the quantiles gives the expectation of the distribution.
        q = self(states=states).mean(dim=1)
        assert q.shape == (num_states, self.num_actions)

        return q



In [None]:
from abc import ABC, abstractmethod
import os
from torch.optim import Adam
from collections import deque
import random


class BaseAgent(ABC):
    """Shared training / evaluation loop for the value-based agents.

    Subclasses must create `online_net` and `target_net` and implement
    `learn`.
    """

    def __init__(self, env, test_env, num_steps=5*(10**7),
                 batch_size=32, memory_size=10**6, gamma=0.99, multi_step=1,
                 update_interval=4, target_update_interval=10000,
                 start_steps=50000, epsilon_train=0.01, epsilon_eval=0.001,
                 epsilon_decay_steps=250000, double_q_learning=False,
                 dueling_net=False, noisy_net=False, use_per=False,
                 log_interval=100, eval_interval=250000, num_eval_steps=125000,
                 max_episode_steps=27000, grad_cliping=5.0, cuda=True, seed=0, num_quants=50):
        """
        Args:
            env: training environment
            test_env: evaluation environment
            num_steps: (int) total number of training steps
            batch_size: (int) batch size used when training the networks
            memory_size: (int) capacity of the replay memory
            gamma: (float) one-step discount factor
            multi_step: (int) number of steps of the n-step return
            update_interval: (int) environment steps between learn() calls
            target_update_interval: (int) steps between target-net syncs
            start_steps: (int) steps of purely random acting before learning
            epsilon_train: (float) final epsilon of the training schedule
            epsilon_eval: (float) epsilon used during evaluation
            epsilon_decay_steps: (int) length of the epsilon schedule
            eval_interval: (int) steps between evaluations
            num_eval_steps: (int) minimum number of steps per evaluation
            max_episode_steps: (int) step limit for a single episode
            grad_cliping: gradient clipping value (stored only)
            cuda: (bool) use CUDA when it is available
            seed: (int) seed for torch, numpy and both environments
            num_quants: (int) number of quantiles (stored only)
            double_q_learning, dueling_net, noisy_net, use_per,
            log_interval: accepted but not used by this base class
        """
        self.env = env
        self.test_env = test_env

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.env.seed(seed)
        self.test_env.seed(2**31-1-seed)
        self.device = torch.device(
            "cuda" if cuda and torch.cuda.is_available() else "cpu")
        self.online_net = None
        self.target_net = None
        self.memory = LazyMultiStepMemory(
                memory_size, self.env.observation_space.shape,
                self.device, gamma, multi_step)
        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.best_eval_score = -np.inf  # best mean evaluation return so far
        self.num_actions = self.env.action_space.n  # size of action space
        self.num_steps = num_steps  # total number of training steps
        self.batch_size = batch_size  # batch size when training networks
        self.eval_interval = eval_interval  # steps between evaluations
        self.num_eval_steps = num_eval_steps  # steps per evaluation
        self.gamma = gamma ** multi_step  # discount of the n-step target
        self.start_steps = start_steps
        self.epsilon_train = LinearAnneaer(
            1.0, epsilon_train, epsilon_decay_steps)
        self.epsilon_eval = epsilon_eval  # epsilon used when evaluating
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.max_episode_steps = max_episode_steps  # per-episode step limit
        self.grad_cliping = grad_cliping  # gradient clipping value
        self.num_quants = num_quants
        # Training statistics kept across episodes. They used to live in
        # lists local to train_episode, which were always empty when read.
        self.episode_returns = []  # return of every finished episode
        self.avg_returns = []  # running mean over the last 100 returns

    def run(self):
        """Train episode by episode until `num_steps` is exceeded."""
        while True:
            self.train_episode()
            if self.steps > self.num_steps:
                break

    def is_update(self):
        """True when a learning update is due at the current step."""
        return self.steps % self.update_interval == 0\
            and self.steps >= self.start_steps

    def is_random(self, eval=False):
        """Decide whether the next action is random (epsilon-greedy)."""
        # Always act randomly during the warm-up phase.
        if self.steps < self.start_steps:
            return True
        if eval:
            return np.random.rand() < self.epsilon_eval
        return np.random.rand() < self.epsilon_train.get()

    def update_target(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(
            self.online_net.state_dict())

    def explore(self):
        """Sample a random action from the training env's action space."""
        action = self.env.action_space.sample()
        return action

    def select_action(self, state):
        """Return the greedy action of the online network for `state`."""
        state = torch.ByteTensor(
            state).unsqueeze(0).to(self.device).float() / 255.
        with torch.no_grad():
            action = self.online_net.calculate_q(states=state).argmax().item()
        return action

    @abstractmethod
    def learn(self):
        """Perform one learning update; implemented by subclasses."""
        pass

    def train_episode(self):
        """Run one training episode and record its return."""
        self.online_net.train()
        self.target_net.train()

        self.episodes += 1
        episode_return = 0.
        episode_steps = 0

        done = False
        state = self.env.reset()

        while (not done) and episode_steps <= self.max_episode_steps:
            if self.is_random(eval=False):
                action = self.explore()
            else:
                action = self.select_action(state)
            next_state, reward, done, _ = self.env.step(action)
            self.memory.append(state, action, reward, next_state, done)
            self.steps += 1
            episode_steps += 1
            episode_return += reward
            state = next_state

            self.train_step_interval()

        # Record the return first so the running mean is never taken over
        # an empty list (previously NaN plus a RuntimeWarning).
        self.episode_returns.append(episode_return)
        avg_return = np.mean(self.episode_returns[-100:])
        self.avg_returns.append(avg_return)

        print(f'Episode: {self.episodes:<4} '
              f'episode steps: {episode_steps:<4} '
              f'total steps: {self.steps:<4}  '
              f'return: {episode_return:<5.1f}'
              )

    def train_step_interval(self):
        """Per-step bookkeeping: epsilon, target sync, learning, eval."""
        self.epsilon_train.step()

        if self.steps % self.target_update_interval == 0:
            self.update_target()

        if self.is_update():
            self.learn()

        if self.steps % self.eval_interval == 0:
            self.evaluate()
            # evaluate() switches the online net to eval mode; undo that.
            self.online_net.train()

    def evaluate(self):
        """Run evaluation episodes until `num_eval_steps` is exceeded."""
        self.online_net.eval()
        num_episodes = 0
        num_steps = 0
        total_return = 0.0

        while True:
            state = self.test_env.reset()
            episode_steps = 0
            episode_return = 0.0
            done = False
            while (not done) and episode_steps <= self.max_episode_steps:
                if self.is_random(eval=True):
                    action = self.explore()
                else:
                    action = self.select_action(state)

                next_state, reward, done, _ = self.test_env.step(action)
                num_steps += 1
                episode_steps += 1
                episode_return += reward
                state = next_state

            num_episodes += 1
            total_return += episode_return

            if num_steps > self.num_eval_steps:
                break

        mean_return = total_return / num_episodes

        if mean_return > self.best_eval_score:
            self.best_eval_score = mean_return

        print('-' * 60)
        print(f'Num steps: {self.steps:<5}  '
              f'return: {mean_return:<5.1f}')
        print('-' * 60)

    def __del__(self):
        """Close both environments if they were ever assigned."""
        # __init__ may have failed before the attributes existed.
        for attr in ('env', 'test_env'):
            environment = getattr(self, attr, None)
            if environment is not None:
                environment.close()


def calculate_huber_loss(td_errors, kappa=1.0):
    """Element-wise Huber loss of `td_errors` with threshold `kappa`.

    Quadratic where |error| <= kappa, linear elsewhere.
    """
    magnitude = td_errors.abs()
    quadratic = 0.5 * td_errors.pow(2)
    linear = kappa * (magnitude - 0.5 * kappa)
    return torch.where(magnitude <= kappa, quadratic, linear)

def calculate_quantile_huber_loss(td_errors, taus, weights=None, kappa=1.0):
    """Quantile Huber loss averaged over the batch.

    Args:
        td_errors: tensor of shape (batch_size, N, N_dash)
        taus: quantile fractions; must not require gradients
        weights: optional per-sample weights multiplied into the
                 (batch_size, 1) per-sample losses
        kappa: (float) Huber threshold
    Return:
        scalar tensor with the mean loss
    """
    assert not taus.requires_grad
    batch_size, N, N_dash = td_errors.shape
    full_shape = (batch_size, N, N_dash)

    # Plain Huber loss for every element.
    huber = calculate_huber_loss(td_errors, kappa)
    assert huber.shape == full_shape

    # Asymmetric weighting |tau - 1{error < 0}| of the Huber loss.
    below_zero = (td_errors.detach() < 0).float()
    quantile_huber = torch.abs(taus[..., None] - below_zero) * huber / kappa
    assert quantile_huber.shape == full_shape

    # Sum over dim 1, average over dim 2 -> one loss per sample.
    per_sample = quantile_huber.sum(dim=1).mean(dim=1, keepdim=True)
    assert per_sample.shape == (batch_size, 1)

    if weights is None:
        return per_sample.mean()
    return (per_sample * weights).mean()

def evaluate_quantile_at_action(s_quantiles, actions):
    """Pick, for every sample, the quantile values of one action.

    Args:
        s_quantiles: tensor indexed as (batch, quantile, action).
        actions: integer tensor with one action index per sample; it is
            given a trailing axis and expanded to (batch, N, 1), so it is
            expected to arrive as (batch, 1).

    Returns:
        Tensor of shape (batch, N, 1) gathered along the action axis.
    """
    assert s_quantiles.size(0) == actions.size(0)
    batch_size, num_quantiles = s_quantiles.size(0), s_quantiles.size(1)

    # Repeat each sample's action index once per quantile.
    index = actions.unsqueeze(-1).expand(batch_size, num_quantiles, 1)
    return s_quantiles.gather(dim=2, index=index)



class DQNAgent(BaseAgent):
    """Quantile-regression DQN agent built on ``BaseAgent``.

    Owns an online and a target ``QRDQN`` network and trains the online one
    with the quantile Huber loss between its action quantiles and the
    bootstrapped target quantiles.
    """

    def __init__(self, env, test_env, num_steps=5*(10**7),
                 batch_size=32, N=200, kappa=1.0, lr=5e-5, memory_size=10**6,
                 gamma=0.99, multi_step=1, update_interval=4,
                 target_update_interval=10000, start_steps=50000,
                 epsilon_train=0.01, epsilon_eval=0.001,
                 epsilon_decay_steps=250000, double_q_learning=False,
                 dueling_net=False, noisy_net=False, use_per=False,
                 log_interval=100, eval_interval=250000, num_eval_steps=125000,
                 max_episode_steps=27000, grad_cliping=None, cuda=True,
                 seed=0, num_quants=50):
        """Build the networks, optimizer and quantile midpoints.

        All arguments except ``N``, ``kappa``, ``lr`` and ``num_quants`` are
        forwarded unchanged to ``BaseAgent.__init__``.

        Args:
            N: accepted for interface compatibility; not used by this class.
            kappa: Huber threshold passed to the quantile Huber loss.
            lr: learning rate of the Adam optimizer.
            num_quants: number of quantiles per action produced by ``QRDQN``.
        """
        super(DQNAgent, self).__init__(
            env, test_env, num_steps, batch_size, memory_size,
            gamma, multi_step, update_interval, target_update_interval,
            start_steps, epsilon_train, epsilon_eval, epsilon_decay_steps,
            double_q_learning, dueling_net, noisy_net, use_per, log_interval,
            eval_interval, num_eval_steps, max_episode_steps, grad_cliping,
            cuda, seed)

        # These must be set before the networks are built: the QRDQN
        # constructor calls below read ``self.num_quants``.
        self.num_quants = num_quants
        self.kappa = kappa

        num_channels = env.observation_space.shape[0]
        self.online_net = QRDQN(
            num_channels=num_channels, num_actions=self.num_actions,
            num_quants=self.num_quants).to(self.device)
        self.target_net = QRDQN(
            num_channels=num_channels, num_actions=self.num_actions,
            num_quants=self.num_quants).to(self.device)

        # Copy parameters of the learning network to the target network.
        self.update_target()
        # Disable calculations of gradients of the target network.
        disable_gradients(self.target_net)

        self.optim = Adam(
            self.online_net.parameters(),
            lr=lr, eps=1e-2/batch_size)
        # Kept as a public attribute; the quantile loss below does not use it.
        self.loss_fn = torch.nn.MSELoss()

        # Midpoints of ``num_quants`` equal-width bins on [0, 1],
        # stored with shape (1, num_quants).
        taus = torch.arange(
            0, self.num_quants + 1, device=self.device,
            dtype=torch.float32) / self.num_quants
        self.tau_hats = ((taus[1:] + taus[:-1]) / 2.0).view(1, self.num_quants)

    def learn(self):
        """Sample one batch from memory and take one optimizer step."""
        self.learning_steps += 1
        states, actions, rewards, next_states, dones = \
            self.memory.sample(self.batch_size)
        weights = None

        quantile_loss, _mean_q, _errors = self.calculate_loss(
            states, actions, rewards, next_states, dones, weights)

        self.optim.zero_grad()
        quantile_loss.backward()
        # Clip gradient norms to stabilize training when a threshold was
        # configured. NOTE(review): assumes BaseAgent stores the constructor
        # argument as ``self.grad_cliping`` -- confirm. Without it (or with
        # the default ``None``) no clipping happens, as before.
        max_norm = getattr(self, 'grad_cliping', None)
        if max_norm:
            torch.nn.utils.clip_grad_norm_(
                self.online_net.parameters(), max_norm)
        self.optim.step()

    def calculate_loss(self, states, actions, rewards, next_states, dones,
                       weights):
        """Compute the quantile Huber loss for one batch.

        Returns:
            A tuple ``(loss, mean_q, errors)``: the scalar loss tensor, the
            mean next-state Q-value as a Python float, and a detached
            (batch_size, 1) tensor of absolute TD errors summed over axis 1
            and averaged over the remaining axis.
        """
        # Quantiles of the taken actions, shape (batch, num_quants, 1).
        current_sa_quantiles = evaluate_quantile_at_action(
            self.online_net(states=states), actions)
        assert current_sa_quantiles.shape == (
            self.batch_size, self.num_quants, 1)

        with torch.no_grad():
            # Calculate Q values of next states.
            # NOTE(review): the online network is always used to pick the
            # greedy next action, regardless of the ``double_q_learning``
            # constructor flag -- confirm this is intended.
            next_q = self.online_net.calculate_q(states=next_states)

            # Calculate greedy actions.
            next_actions = torch.argmax(next_q, dim=1, keepdim=True)
            assert next_actions.shape == (self.batch_size, 1)

            # Target-network quantiles of the greedy next actions,
            # transposed to (batch, 1, num_quants) so that subtracting the
            # current quantiles broadcasts to every (i, j) pair.
            next_sa_quantiles = evaluate_quantile_at_action(
                self.target_net(states=next_states),
                next_actions).transpose(1, 2)
            assert next_sa_quantiles.shape == (
                self.batch_size, 1, self.num_quants)

            # Calculate target quantile values.
            target_sa_quantiles = rewards[..., None] + (
                1.0 - dones[..., None]) * self.gamma * next_sa_quantiles
            assert target_sa_quantiles.shape == (
                self.batch_size, 1, self.num_quants)

        td_errors = target_sa_quantiles - current_sa_quantiles
        assert td_errors.shape == (
            self.batch_size, self.num_quants, self.num_quants)

        quantile_huber_loss = calculate_quantile_huber_loss(
            td_errors, self.tau_hats, weights, self.kappa)

        return quantile_huber_loss, next_q.detach().mean().item(), \
            td_errors.detach().abs().sum(dim=1).mean(dim=1, keepdim=True)

In [None]:
# Initialize Training and Testing Environments
# Both wrap the same game id. The test environment is created with
# episode_life=False and clip_rewards=False (presumably so evaluation reports
# full-game, unclipped scores -- confirm against make_pytorch_env).
env = make_pytorch_env("BreakoutNoFrameskip-v4")
test_env = make_pytorch_env("BreakoutNoFrameskip-v4", episode_life=False, clip_rewards=False)


In [None]:
# Start Training
# Builds the agent with all default hyper-parameters, then calls run(),
# which DQNAgent does not define itself (inherited from its base class).
agent = DQNAgent(env=env, test_env=test_env)
agent.run()

In [None]:
# Save the online network's weights (its state_dict) to a local file.
weight_path = "qnet.pth"
torch.save(agent.online_net.state_dict(), weight_path)

In [None]:
import os
from google.colab import drive
# Mount Google Drive and build the path of a weight file stored in it.
drive.mount('/content/drive')
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = "rl_coding_proj3"
# NOTE(review): this is a relative path ('drive/...'); it points inside the
# mount at '/content/drive' only when the working directory is '/content'.
GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
actor_path = "weight_50.pth"
path = os.path.join(GOOGLE_DRIVE_PATH, actor_path)

In [None]:
# Recreate the environments and an agent for evaluation.
env = make_pytorch_env("BreakoutNoFrameskip-v4")
test_env = make_pytorch_env("BreakoutNoFrameskip-v4", episode_life=False, clip_rewards=False)
# NOTE(review): `Agent1` is not defined in the part of the file reviewed here
# (the agent class defined in this section is `DQNAgent`) -- confirm it exists.
# NOTE(review): no saved weights are loaded into the agent in this cell --
# confirm that evaluating a freshly constructed network is intended.
agent = Agent1(env=env, test_env=test_env)

import time
# Reference point for the elapsed-time print after evaluate() runs.
start_time = time.time()
def evaluate(max_episode_steps=27000, num_eval_steps=120000, verbose=True):
    """Run greedy evaluation episodes and report the mean episode return.

    Uses the notebook globals ``test_env`` and ``agent``. Actions are the
    argmax of ``agent.online_net.calculate_q`` on the current observation,
    which is converted to a float tensor scaled by 1/255.

    Args:
        max_episode_steps: an episode is cut off once its step count
            exceeds this value.
        num_eval_steps: no new episode is started once the total number of
            steps exceeds this value (the running episode always finishes).
        verbose: when true, print the per-episode and per-step progress
            lines; the final summary is printed regardless.

    Returns:
        The mean return over all evaluated episodes.
    """
    num_episodes = 0
    num_steps = 0
    total_return = 0.0
    while True:
        state = test_env.reset()
        episode_steps = 0
        episode_return = 0.0
        done = False
        if verbose:
            print("number of episodes: ", num_episodes)
        while (not done) and episode_steps <= max_episode_steps:
            state = torch.ByteTensor(state).unsqueeze(0).to(agent.device).float() / 255.
            with torch.no_grad():
                action = agent.online_net.calculate_q(states=state).argmax().item()
            next_state, reward, done, _ = test_env.step(action)
            num_steps += 1
            episode_steps += 1
            episode_return += reward
            state = next_state
            if verbose:
                print("number of steps: ", num_steps)
                print("number of episode steps: ", episode_steps)
                print("reward: ", episode_return)
        num_episodes += 1
        total_return += episode_return

        if num_steps > num_eval_steps:
            break

    # At least one episode has always been completed at this point.
    mean_return = total_return / num_episodes
    print('-' * 60)
    print(f'return: {mean_return:<5.1f}')
    print('-' * 60)
    return mean_return

# Run the evaluation, then print the seconds elapsed since start_time.
evaluate()
print(time.time()-start_time)

In [None]:
import gym
import random
import matplotlib.pyplot as plt
from colabgymrender.recorder import Recorder

# Play one episode greedily with the current agent while recording it to
# ./video, then play the recording back.
env = make_pytorch_env("BreakoutNoFrameskip-v4")
directory = './video'
env = Recorder(env, directory)
state = env.reset()
terminal = False
while not terminal:
    state = torch.ByteTensor(state).unsqueeze(0).to(agent.device).float() / 255.
    with torch.no_grad():
        # Select the action from per-action Q-values. Calling the network
        # directly yields quantiles with a separate action axis, so an
        # argmax over that flattened output is not a valid action index.
        action = agent.online_net.calculate_q(states=state).argmax().item()
    state, r, terminal, info = env.step(action)
env.play()