In [1]:
import retro
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import base64, io

import numpy as np
from collections import deque, namedtuple

from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display 
import glob

In [2]:
# Replay buffer capacity: maximum number of (state, action, reward, next_state, done)
# transitions kept; once full, the oldest transition is evicted.
BUFFER_SIZE = 100_000
# Number of transitions drawn at random (without replacement) from the buffer per learning step.
BATCH_SIZE = 32
GAMMA = 0.99           # discount factor applied to the bootstrapped next-state value
TAU = 1e-3             # interpolation weight used when soft-updating the target network
LR = 5e-4              # Adam learning rate for the local Q-network
UPDATE_EVERY = 4       # perform one learning step every UPDATE_EVERY environment steps

# Prefer the first CUDA device when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Contra Force (NES) with retro's DISCRETE action mode, so actions are integer indices.
# NOTE(review): the Agent/QNetwork below hard-code 36 actions; confirm env.action_space.n == 36.
GAME_ID = 'ContraForce-Nes'
env = retro.make(GAME_ID, use_restricted_actions=retro.Actions.DISCRETE)

In [4]:
class QNetwork(nn.Module):
    """Fully convolutional Q-network mapping a batch of frames to one Q-value per action.

    The conv/pool geometry reduces a 3x224x240 input to a 1x6x6 map, i.e. 36
    values per sample. NOTE(review): 224x240 is assumed to be the frame size
    produced by ``env`` after moving channels first -- confirm.
    NOTE(review): pixels are fed unscaled (0-255); consider dividing by 255.
    """

    def __init__(self, seed, action_size=36):
        """
        Args:
            seed: passed to ``torch.manual_seed`` before the layers are built, so
                networks created with the same seed start with identical weights.
            action_size: expected number of Q-values per sample; ``forward``
                raises ``ValueError`` if the conv stack produces a different width.
        """
        super(QNetwork, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.action_size = action_size

        # Layer creation order is kept as-is: it determines the seeded initial weights.
        self.conv1 = nn.Conv2d(3, 32, 4, 2, 0)
        self.maxpool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 4, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 4, 2, 0)
        self.conv4 = nn.Conv2d(128, 1, 4, 2, 1)

    def forward(self, state):
        """Return Q-values of shape (batch, action_size) for a (batch, 3, H, W) float tensor."""
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv3(x))
        # Flatten per sample instead of view(-1, 36): with a wrong input resolution,
        # view() could silently re-fold the batch axis and mix samples across rows.
        q_values = self.conv4(x).flatten(start_dim=1)
        if q_values.shape[1] != self.action_size:
            raise ValueError(
                'QNetwork produced {} values per sample, expected {}; check the input '
                'frame size {}'.format(q_values.shape[1], self.action_size, tuple(state.shape[1:])))
        return q_values

In [5]:
class Agent():
    """Epsilon-greedy DQN agent: a local and a target Q-network plus a replay buffer.

    Hyperparameters come from the module-level constants LR, BUFFER_SIZE,
    BATCH_SIZE, GAMMA, TAU and UPDATE_EVERY; tensors are placed on ``device``.
    """

    def __init__(self, seed):
        """Build both Q-networks, the Adam optimizer and the replay buffer.

        Args:
            seed: seeds Python's ``random`` module and is passed to each QNetwork,
                so the local and target networks start with identical weights.
        """
        # Must equal the width of QNetwork's output (36 Q-values per sample).
        self.action_size = 36
        # random.seed() returns None, so store the seed value itself.
        random.seed(seed)
        self.seed = seed

        self.qnetwork_local = QNetwork(seed).to(device)
        self.qnetwork_target = QNetwork(seed).to(device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        self.memory = MemBuffer(BUFFER_SIZE, BATCH_SIZE, seed)
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition; every UPDATE_EVERY calls, learn from a sampled minibatch
        once the buffer holds more than BATCH_SIZE transitions."""
        self.memory.add(state, action, reward, next_state, done)

        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0:
            if len(self.memory) > BATCH_SIZE:
                experiences = self.memory.sample()
                self.learn(experiences, GAMMA)

    def act(self, state, eps=0.):
        """Choose an action for a single observation with epsilon-greedy exploration.

        Args:
            state: one channels-first observation as a numpy array; a batch
                axis is added before the forward pass.
            eps: probability of taking a uniformly random action.

        Returns:
            An integer action index in ``[0, action_size)``.
        """
        # The network is only evaluated on the greedy branch; the random draw
        # happens first so the Python RNG is consumed exactly as before.
        if random.random() > eps:
            state_t = torch.from_numpy(state).float().unsqueeze(0).to(device)
            self.qnetwork_local.eval()
            with torch.no_grad():
                action_values = self.qnetwork_local(state_t)
            self.qnetwork_local.train()
            return int(np.argmax(action_values.cpu().numpy()))
        return random.randrange(self.action_size)

    def learn(self, experiences, gamma):
        """Take one gradient step on the MSE between Q(s, a) and the one-step TD target.

        Args:
            experiences: (states, actions, rewards, next_states, dones) tensors as
                returned by ``MemBuffer.sample``.
            gamma: discount factor for the bootstrapped next-state value.
        """
        states, actions, rewards, next_states, dones = experiences

        # The TD target is a constant for this step; no_grad avoids building an
        # autograd graph through the target network that would only be discarded.
        with torch.no_grad():
            q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            q_targets = rewards + gamma * q_targets_next * (1 - dones)
        q_expected = self.qnetwork_local(states).gather(1, actions)

        loss = F.mse_loss(q_expected, q_targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)

    def soft_update(self, local_model, target_model, tau):
        """Blend weights in place: target = tau * local + (1 - tau) * target."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)



class MemBuffer:
    """Fixed-capacity FIFO replay memory with uniform random minibatch sampling."""

    def __init__(self, buffer_size, batch_size, seed):
        """
        Args:
            buffer_size: maximum number of stored transitions; the oldest is
                dropped when a new one arrives and the buffer is full.
            batch_size: number of transitions returned by ``sample``.
            seed: re-seeds Python's global ``random`` module (``random.seed``
                returns None, so ``self.seed`` ends up None).
        """
        self.seed = random.seed(seed)
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)
        self.action_size = 36

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the memory."""
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        """Draw ``batch_size`` distinct transitions and return them as tensors on ``device``.

        Returns:
            (states, actions, rewards, next_states, dones). States are stacked
            with a leading batch axis; assuming scalar actions/rewards/dones,
            those become (batch_size, 1) columns (long, float, float).
        """
        batch = random.sample(self.memory, k=self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        state_batch = torch.from_numpy(np.array(states)).float().to(device)
        action_batch = torch.from_numpy(np.vstack(actions)).long().to(device)
        reward_batch = torch.from_numpy(np.vstack(rewards)).float().to(device)
        next_state_batch = torch.from_numpy(np.array(next_states)).float().to(device)
        done_batch = torch.from_numpy(np.vstack(dones).astype(np.uint8)).float().to(device)

        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [6]:
def dqn(n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995,
        target_score=200.0, render=True, checkpoint_path='model.pth', plot_path='scores.jpg'):
    """Train the module-level ``agent`` on the module-level ``env`` with epsilon-greedy DQN.

    Args:
        n_episodes: maximum number of training episodes.
        max_t: maximum number of environment steps per episode.
        eps_start: initial exploration rate.
        eps_end: lower bound on the exploration rate.
        eps_decay: multiplicative decay applied to epsilon after every episode.
        target_score: stop early once the mean score over the last 100 episodes
            reaches this value (checked only once 100 episodes have been played).
        render: call ``env.render()`` after every step (slow; set False to train faster).
        checkpoint_path: file the local Q-network weights are saved to when
            training ends, whether or not ``target_score`` was reached.
        plot_path: image file the score curve is rewritten to after every episode.

    Returns:
        List of per-episode scores (sum of rewards).
    """
    scores = []
    rolling_mean = []
    scores_window = deque(maxlen=100)
    eps = eps_start
    for i_episode in range(1, n_episodes + 1):
        # Move the trailing axis to the front: conv1 expects channels-first input.
        state = np.moveaxis(env.reset(), -1, 0)
        score = 0
        for _ in range(max_t):
            action = agent.act(state, eps)
            next_state, reward, done, _info = env.step(action)
            next_state = np.moveaxis(next_state, -1, 0)
            agent.step(state, action, reward, next_state, done)
            if render:
                env.render()
            state = next_state
            score += reward
            if done:
                break
        scores_window.append(score)
        scores.append(score)
        window_mean = np.mean(scores_window)
        rolling_mean.append(window_mean)
        eps = max(eps_end, eps_decay * eps)

        # Explicit figure/axes handles; close exactly this figure after saving.
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(scores, label='score')
        ax.plot(rolling_mean, label='mean of last 100')
        ax.set_title('scores')
        ax.set_xlabel('episode')
        ax.set_ylabel('score')
        ax.grid()
        ax.legend()
        fig.savefig(plot_path)
        plt.close(fig)

        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, window_mean), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, window_mean))
        # Judge "solved" only over a full 100-episode window; otherwise a lucky
        # early episode ends training and i_episode-100 below is negative.
        if len(scores_window) == scores_window.maxlen and window_mean >= target_score:
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, window_mean))
            break
    # Save on every exit path, not only when solved, so a long run is never lost.
    torch.save(agent.qnetwork_local.state_dict(), checkpoint_path)
    return scores

# Train a freshly seeded agent; dqn() reads the module-level `agent` and `env`.
SEED = 0
agent = Agent(seed=SEED)
scores = dqn()

Episode 37	Average Score: 0.00