In [1]:
import cv2
import gym
import copy
import tqdm
import torch
import math
import random
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from collections import deque, namedtuple

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
torch.cuda.get_device_name()

'Tesla T4'

In [4]:
SEED = 42

In [5]:
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True

In [6]:
class ReplayMemory():
    def __init__(self, mem_capacity, batch_size):
        self.mem_capacity = mem_capacity
        self.batch_size = batch_size
        self.memory = deque(maxlen=self.mem_capacity)
        self.Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state')) 

    def __len__(self):
        return len(self.memory)

    def push(self, s, a, r, ns):
        s = torch.FloatTensor(s)
        a = torch.LongTensor([a])
        r = torch.FloatTensor([r])
        if ns is not None:
            ns = torch.FloatTensor(ns)

        transition = self.Transition(state=s, action=a, reward=r, next_state=ns)
        self.memory.append(transition)

    def sample(self):
        transitions = random.sample(self.memory, self.batch_size)
        return self.Transition(*(zip(*transitions)))

In [7]:
class SkipMax(gym.Wrapper):
    def __init__(self, env, skip=4):
        gym.Wrapper.__init__(self, env)
        self.skip = skip
        self.frame_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        total_reward = 0
        for i in range(self.skip):
            state, reward, done, info = self.env.step(action)

            if i == self.skip - 2:
                self.frame_buffer[0] = state

            if i == self.skip - 1:
                self.frame_buffer[1] = state

            total_reward += reward
            if done:
                break

        max_frame = self.frame_buffer.max(axis=0)
        return max_frame, total_reward, done, info

In [8]:
class WrapFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.width, self.height = 84, 84
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

In [9]:
class ClipReward(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, r):
        return np.sign(r)

In [10]:
class LazyFrames(object):
    def __init__(self, frames):
        self.frames = frames
        self.output = None

    def get_output(self):
        if self.output is None:
            self.output = np.concatenate(self.frames, axis=2)
            self.frames = None
        
        return self.output

    def __array__(self, dtype=None):
        output = self.get_output()
        if dtype is not None:
            output = output.astype(dtype)

        return output

    def __len__(self):
        return len(self.get_output())

    def __getitem__(self, index):
        return self.get_output()[index]

In [11]:
class StackFrames(gym.Wrapper):
    def __init__(self, env, k=4):
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque(maxlen=k)

        obs_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(obs_shape[0], obs_shape[1], obs_shape[2] * k), dtype=np.uint8)
    
    def get_frames(self):
        return LazyFrames(list(self.frames))

    def reset(self):
        obs = self.env.reset()
        for _ in range(self.k):
            self.frames.append(obs)        

        return self.get_frames()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)

        return self.get_frames(), reward, done, info

In [12]:
class WrapImage(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        obs_shape = env.observation_space.shape
        self.observation_shape = gym.spaces.Box(low=0.0, high=1.0, shape=(obs_shape[-1], obs_shape[0], obs_shape[1]), dtype=np.uint8)

    def observation(self, image):
        return np.swapaxes(image, 2, 0)

In [13]:
def get_env(env_name):
    env = gym.make(env_name)
    env = SkipMax(env)
    env = WrapFrame(env)
    env = ClipReward(env)
    env = StackFrames(env)
    env = WrapImage(env)
    return env

In [14]:
class DQN(nn.Module):
    def __init__(self, output_size):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.shape[0], -1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

In [15]:
class Agent():
    def __init__(self, env, net, memory, update_freq, learning_start, e_start, e_end, e_steps, gamma, target_update, print_every, render):
        self.env = env
        self.net = net.to(device)
        self.memory = memory
        self.update_freq = update_freq
        self.learning_start = learning_start
        self.e_start = e_start

        self.e_end = e_end
        self.e_steps = e_steps
        self.gamma = gamma
        self.target_update = target_update
        self.print_every = print_every
        self.render = render

        self.steps = 0
        self.episode = 0
        self.target_net = copy.deepcopy(self.net)
        self.target_net.eval()
        self.opt = torch.optim.Adam(self.net.parameters(), lr=1e-4)
        
    def get_epsilon(self):
        epsilon = self.e_end + (self.e_start - self.e_end) * math.exp(-1. * self.steps / self.e_steps)
        return epsilon

    def get_action(self, state):
        epsilon = self.get_epsilon()
        if random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0).to(device)
                Q = self.net(state)
                action = Q.max(1)[1].item()

        return action

    def train(self, episodes):
        ep_rewards = []

        for episode in tqdm.tqdm(range(episodes), total=episodes):
            done = False
            episode_reward = 0
            state = self.env.reset()

            while not done:
                if self.render == True and episode % self.print_every == 0:
                    env.render()

                action = self.get_action(state)
                next_state, reward, done, _ = self.env.step(action)
                episode_reward += reward
                self.memory.push(state, action, reward, None if done else next_state)
                state = next_state
                self.steps += 1

                if self.steps % self.update_freq == 0 and self.steps > self.learning_start:
                    ep_loss = self.optimize()

                if self.steps % (self.target_update * self.update_freq) == 0 and self.steps > self.learning_start:
                    self.target_net.load_state_dict(self.net.state_dict())

            ep_rewards.append(episode_reward)
            if episode % self.print_every == 0:
                avg_reward = np.mean(ep_rewards[-self.print_every:])
                print(f" episode: {episode} | avg_reward: {avg_reward:.4f}")

        return ep_rewards

    def optimize(self):
        mem_sample = self.memory.sample()
        non_terminal_mask = torch.ByteTensor(list(map(lambda ns: ns is not None, mem_sample.next_state)))

        state_batch = torch.cat(mem_sample.state).to(device)
        action_batch = torch.cat(mem_sample.action).unsqueeze(1).to(device)
        reward_batch = torch.cat(mem_sample.reward).unsqueeze(1).to(device)
        non_terminal_next_state_batch = torch.cat([ns for ns in mem_sample.next_state if ns is not None]).to(device)

        state_batch = state_batch.view(self.memory.batch_size, 4, 84, 84)
        non_terminal_next_state_batch = non_terminal_next_state_batch.view(-1, 4, 84, 84)

        Q_preds = self.net(state_batch)
        Q_vals = Q_preds.gather(1, action_batch)

        target_preds = self.target_net(non_terminal_next_state_batch)
        target_vals = torch.zeros(self.memory.batch_size, 1).to(device)
        target_vals[non_terminal_mask] = target_preds.max(1)[0].unsqueeze(1)
        
        expected_vals = reward_batch + (self.gamma * target_vals)
        loss = F.smooth_l1_loss(Q_vals, expected_vals.detach())

        self.opt.zero_grad()
        loss.backward()
        for p in self.net.parameters():
            p.grad.data.clamp_(-1, 1)
        self.opt.step()

        return loss.item()

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
episodes = 500
mem_capacity = 10000
batch_size = 32
render = False 
env_name = "PongNoFrameskip-v4"
output_size = gym.make(env_name).action_space.n
learning_start = 10000
update_freq = 1
e_start = 1.0
e_end = 0.01
e_steps = 30000
gamma = 0.99
target_update = 1000
print_every = 50

In [17]:
device

device(type='cuda')

In [18]:
env = get_env(env_name)
env.seed(SEED)
memory = ReplayMemory(mem_capacity, batch_size)
net = DQN(output_size)
agent = Agent(env, net, memory, update_freq, learning_start, e_start, e_end, 
              e_steps, gamma, target_update, print_every, render)

In [19]:
reward_history = agent.train(episodes)

  0%|          | 1/500 [00:01<10:43,  1.29s/it]

 episode: 0 | avg_reward: -21.0000


 10%|█         | 51/500 [07:53<1:45:01, 14.03s/it]

 episode: 50 | avg_reward: -20.1800


 20%|██        | 101/500 [26:22<2:45:44, 24.92s/it]

 episode: 100 | avg_reward: -10.7600


 20%|██        | 102/500 [26:46<2:43:07, 24.59s/it]

KeyboardInterrupt: ignored