In [1]:
import os, sys
import warnings
import random
import time

from vizdoom import *

import numpy as np
from skimage import transform, util, color
from collections import deque
import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [3]:
%matplotlib inline
warnings.filterwarnings('ignore')

In [4]:
# --- Model -------------------------------------------------------------------
im_size = [200, 200, 4]       # 200x200 frames, 4 stacked frames per state
action_size = 3               # left, right, shoot
learning_rate = 2e-4          # Adam step size (alpha)

# --- Training ----------------------------------------------------------------
total_episodes, max_steps = 500, 100   # episodes to run / step cap per episode
batch_size = 64
save_frequency = 100          # write a checkpoint every N episodes

# --- Epsilon-greedy exploration ----------------------------------------------
explore_start = 1.0           # exploration probability at decay_step 0
explore_stop = 0.01           # lower bound of the exploration probability
decay_rate = 1e-4             # exponential decay rate applied per decay step

# --- Q-learning --------------------------------------------------------------
gamma = 0.95                  # discount factor

# --- Replay memory -----------------------------------------------------------
pretrain_length = batch_size  # random experiences collected before training
que_length = 1_000_000        # maximum number of experiences kept in memory

# --- Run mode ----------------------------------------------------------------
training = False              # True runs pretraining + the training cell
trained_model = 'checkpoint400.pth.tar'  # checkpoint loaded for evaluation

In [5]:
def createEnvironment():
    """Start a DoomGame configured by ``scenarios/basic.cfg``.

    Returns:
        tuple: ``(game, actions)`` — the initialised game and the three
        one-hot button vectors ``[left, right, shoot]`` later passed to
        ``game.make_action`` (assumes the cfg declares its buttons in that
        order — TODO confirm against basic.cfg).
    """
    doom = DoomGame()
    doom.load_config("scenarios/basic.cfg")
    doom.init()

    move_left, move_right, fire = [1, 0, 0], [0, 1, 0], [0, 0, 1]
    return doom, [move_left, move_right, fire]

In [6]:
def processImg(img, crop_top=70, crop_bottom=10, size=(200, 200)):
    """Crop, scale to [0, 1] and resize one raw screen buffer.

    Args:
        img: raw screen buffer indexed rows-first. Presumably uint8 in
            [0, 255]; NOTE(review): frames are later stacked as 2-D planes, so
            this assumes a grayscale screen_format in scenarios/basic.cfg —
            confirm.
        crop_top: number of rows removed from the top.
        crop_bottom: number of rows removed from the bottom (0 keeps all;
            computed from the row count so 0 does not produce an empty slice).
        size: output (rows, cols) passed to ``skimage.transform.resize``.

    Returns:
        Float array resized to ``size`` (any trailing channel axis of ``img``
        is kept by ``resize``).
    """
    n_rows = img.shape[0]
    img = img[crop_top:n_rows - crop_bottom]
    img = img / 255
    return transform.resize(img, size)

# Sanity-check the preprocessing on the first frame of a throwaway game.
game, actions = createEnvironment()
try:
    img = processImg(game.get_state().screen_buffer)
    print(img.shape)
    plt.imshow(img)
    plt.show()
finally:
    # A separate game is created later for pretraining/training/evaluation,
    # so release this preview instance instead of leaving it running.
    game.close()

In [7]:
stack_size = 4  # frames per state; matches the 4 input channels of DQN.conv1
# Module-level frame history; callers rebind it from stackFrames' return value.
# `dtype=int` replaces the `np.int` alias (removed in NumPy 1.24).
stacked_frames = deque([np.zeros((200, 200), dtype=int) for _ in range(stack_size)],
                       maxlen=stack_size)


def stackFrames(stacked_frames, stateim, is_new_episode):
    """Preprocess a raw screen and push it onto the rolling frame stack.

    Args:
        stacked_frames: deque holding the previous ``stack_size`` processed
            frames. Ignored (may be None) when ``is_new_episode`` is True.
        stateim: raw screen buffer, passed through ``processImg``.
        is_new_episode: when True, the stack is reset to ``stack_size``
            copies of the current frame.

    Returns:
        tuple: ``(stacked_state, stacked_frames)`` — the frames stacked along
        axis 2 (for 2-D frames: H x W x stack_size) and the deque to pass to
        the next call.
    """
    frame = processImg(stateim)
    if is_new_episode:
        # Fill the whole stack with the first frame; no placeholder needed.
        stacked_frames = deque([frame] * stack_size, maxlen=stack_size)
    else:
        stacked_frames.append(frame)
    return np.stack(stacked_frames, axis=2), stacked_frames

In [8]:
class DQN(nn.Module):
    """Convolutional Q-network over a stack of 4 frames.

    Args:
        img_size: side length of the square input frames; used to size fc1
            (200 reproduces the original 24*24*128 flattened size).
        n_actions: number of Q-values produced (default 3: left, right, shoot).
        elu_on_q: True (default) keeps the original head — fc1 with no
            activation followed by ELU on the fc2 output — so checkpoints
            trained with it behave the same. NOTE(review): that head bounds
            every Q-value below by -1 and collapses fc1/fc2 into one linear
            map; set False for ELU between fc1 and fc2 with linear Q outputs
            (requires retraining).
    """

    def __init__(self, img_size, n_actions=3, elu_on_q=True):
        super().__init__()  # must run before registering submodules
        self.img_size = img_size
        self.elu_on_q = elu_on_q

        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        side = self._conv_out_size(img_size)
        self.fc1 = nn.Linear(side * side * 128, 512)
        self.fc2 = nn.Linear(512, n_actions)

    @staticmethod
    def _conv_out_size(size):
        """Spatial side length after conv1..conv3 for a square input of side ``size``."""
        for kernel in (8, 4, 4):  # all three convs use stride 2, padding 1
            size = (size + 2 * 1 - kernel) // 2 + 1
        return size

    def forward(self, x):
        """Map a (batch, 4, S, S) float tensor to (batch, n_actions) Q-values."""
        x = F.elu(self.bn1(self.conv1(x)))
        x = F.elu(self.bn2(self.conv2(x)))
        x = F.elu(self.bn3(self.conv3(x)))

        x = x.view(x.size(0), -1)
        if self.elu_on_q:
            # Original head: linear fc1 -> fc2 -> ELU on the Q-values.
            return F.elu(self.fc2(self.fc1(x)))
        return self.fc2(F.elu(self.fc1(x)))

In [9]:
class NNet():
    """Owns a DQN, its Adam optimizer, and checkpoint save/load.

    States are handled in (batch, H, W, C) layout as produced by
    ``stackFrames`` and converted with ``transpose(1, 3)`` before the forward
    pass. That also swaps H and W, which is harmless as long as training and
    inference agree — they both do it here.
    """

    def __init__(self, img_size=200, lr=None, device=None):
        # Default to CUDA as before, falling back to CPU when it is unavailable.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.nnet = DQN(img_size).to(self.device)
        # Built once so Adam's moment estimates persist across train() calls
        # (re-creating it every step silently reset them).
        self.optimizer = optim.Adam(self.nnet.parameters(),
                                    lr=learning_rate if lr is None else lr)

    def _batch_tensor(self, batch, idx):
        """Stack field ``idx`` of every experience tuple into a float32 tensor on ``self.device``."""
        values = np.array([experience[idx] for experience in batch], dtype=np.float32)
        return torch.as_tensor(values, device=self.device)

    def train(self, memory, batch_size):
        """Run one Q-learning update on ``batch_size`` experiences sampled from ``memory``.

        Experiences are ``(state, one_hot_action, reward, next_state, done)``.
        The target is ``reward`` for terminal transitions and
        ``reward + gamma * max_a Q(next_state, a)`` otherwise.

        Returns:
            The detached scalar loss tensor.
        """
        # predict() switches to eval mode; restore training mode so
        # BatchNorm uses batch statistics and keeps updating its running stats.
        self.nnet.train()
        batch = memory.sample(batch_size)
        states_mb = self._batch_tensor(batch, 0)
        actions_mb = self._batch_tensor(batch, 1)
        rewards_mb = self._batch_tensor(batch, 2)
        next_states_mb = self._batch_tensor(batch, 3)
        dones_mb = self._batch_tensor(batch, 4)

        # Targets are constants for the update: no graph through the next-state pass.
        with torch.no_grad():
            q_next = self.nnet(next_states_mb.transpose(1, 3))
            targets_mb = rewards_mb + gamma * (1.0 - dones_mb) * q_next.max(dim=1)[0]

        q_pred = self.nnet(states_mb.transpose(1, 3))
        loss = self.loss(targets_mb, q_pred, actions_mb)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.detach()

    def predict(self, frames):
        """Return the Q-values (numpy array of length n_actions) for one stacked state."""
        self.nnet.eval()
        with torch.no_grad():
            x = torch.as_tensor(np.asarray(frames, dtype=np.float32), device=self.device)
            q_values = self.nnet(x.unsqueeze(0).transpose(1, 3))
        return q_values.cpu().numpy()[0]

    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Save the network's state_dict to ``folder/filename``, creating ``folder`` if needed."""
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            os.makedirs(folder)
        else:
            print("Saving Checkpoint")
        torch.save({
            'state_dict': self.nnet.state_dict(),
        }, filepath)

    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Load a state_dict saved by ``save_checkpoint``.

        Raises:
            FileNotFoundError: if ``folder/filename`` does not exist.
        """
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError("No model in path {}".format(filepath))
        # torch.load unpickles: only load checkpoints from trusted sources.
        checkpoint = torch.load(filepath, map_location=self.device)
        self.nnet.load_state_dict(checkpoint['state_dict'])

    def loss(self, target_Qs, output, actions_):
        """Mean squared error between targets and the Q-values of the taken (one-hot) actions."""
        Qs = torch.sum(output * actions_, dim=1)
        return ((target_Qs - Qs) ** 2).mean()

In [10]:
class Memory():
    """Bounded FIFO replay buffer; the oldest experiences are dropped first."""

    def __init__(self, quelength):
        self.buffer = deque(maxlen=quelength)

    def add(self, experience):
        """Append one experience, evicting the oldest once ``quelength`` is reached."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Return up to ``batch_size`` distinct experiences chosen uniformly at random.

        If the buffer holds fewer than ``batch_size`` items, all of them are
        returned (in random order) instead of raising; an empty buffer
        yields ``[]``.
        """
        buffer_size = len(self.buffer)
        k = min(batch_size, buffer_size)
        if k == 0:
            return []
        index = np.random.choice(buffer_size, k, replace=False)
        # NOTE: deque indexing is O(n) away from the ends; fine for small
        # batches, but a list-based ring buffer would be faster for huge buffers.
        return [self.buffer[i] for i in index]

In [11]:
# Game instance and action list used by pretraining, training and evaluation below.
(game,
 possible_actions) = createEnvironment()

In [12]:
memory = Memory(que_length)
game.new_episode()


def pretrain():
    """Seed ``memory`` with ``pretrain_length`` random-action experiences.

    Plays ``game`` with actions drawn uniformly from ``possible_actions`` and
    stores ``(state, action, reward, next_state, done)`` tuples. When an
    episode ends, ``next_state`` is an all-zero array and a new episode is
    started. Rebinds the module-level ``stacked_frames``.
    """
    # stackFrames returns the deque to use next. Without `global`, the
    # assignment makes `stacked_frames` local, and its first read (as the
    # argument below) raises UnboundLocalError.
    global stacked_frames
    state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer,
                                        is_new_episode=True)

    for _ in range(pretrain_length):
        action = random.choice(possible_actions)
        reward = game.make_action(action)
        done = game.is_episode_finished()

        if done:
            next_state = np.zeros(state.shape)
            memory.add((state, action, reward, next_state, done))

            game.new_episode()
            state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer,
                                                is_new_episode=True)
        else:
            next_state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer,
                                                     is_new_episode=False)
            memory.add((state, action, reward, next_state, done))
            state = next_state

In [13]:
# Shared Q-network wrapper used by predict() and by the training and evaluation cells.
nnet = NNet()
def predict(explore_start, explore_stop, decay_rate, decay_step, state, possible_actions):
    """Choose an action epsilon-greedily with an exponentially decaying epsilon.

    The exploration probability is
    ``explore_stop + (explore_start - explore_stop) * exp(-decay_rate * decay_step)``.
    A random action is taken when that probability exceeds a uniform draw.
    Otherwise the action is the argmax of ``nnet.predict(state)``.

    Returns:
        tuple: ``(action, explore_probability)`` with ``action`` taken from
        ``possible_actions``.
    """
    roll = np.random.rand()  # drawn first, exactly once, on every call
    span = explore_start - explore_stop
    explore_probability = explore_stop + span * np.exp(-decay_rate * decay_step)

    if roll >= explore_probability:
        q_values = nnet.predict(state)
        return possible_actions[int(np.argmax(q_values))], explore_probability
    return random.choice(possible_actions), explore_probability

def run():
    """Self-contained epsilon-greedy training loop on a fresh game instance.

    Runs ``total_episodes`` episodes of at most ``max_steps`` steps each.
    Actions are chosen epsilon-greedily from ``nnet``. Every transition is
    stored in a local replay ``Memory``. ``nnet.train`` is called after each
    step once that memory holds at least ``batch_size`` experiences.

    Epsilon starts at ``explore_start`` and decays once per episode towards
    ``explore_stop`` at ``decay_rate``. The game created here is closed on
    exit, including on error.
    """
    game, actions = createEnvironment()
    mem = Memory(que_length)
    frames = None  # replaced by stackFrames at the start of each episode
    epsilon = explore_start
    try:
        for episode in range(total_episodes):
            game.new_episode()
            state_im, frames = stackFrames(frames, game.get_state().screen_buffer, True)

            for _ in range(max_steps):
                if random.uniform(0, 1) > epsilon:
                    action = actions[int(np.argmax(nnet.predict(state_im)))]
                else:
                    action = random.choice(actions)

                reward = game.make_action(action)
                done = game.is_episode_finished()
                if done:
                    # Terminal: NNet.train uses only the reward for these targets.
                    new_im = np.zeros_like(state_im)
                else:
                    new_im, frames = stackFrames(frames, game.get_state().screen_buffer, False)

                mem.add((state_im, action, reward, new_im, done))
                if len(mem.buffer) >= batch_size:
                    nnet.train(mem, batch_size)
                if done:
                    break
                state_im = new_im

            epsilon = explore_stop + (explore_start - explore_stop) * np.exp(-decay_rate * (episode + 1))
    finally:
        game.close()

In [14]:
if training:
    pretrain()
    decay_step = 0  # global step counter driving the epsilon decay in predict()
    game.init()
    for episode in range(total_episodes):
        episode_rewards = []
        game.new_episode()
        state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer, True)

        for step in range(max_steps):
            decay_step += 1

            action, exploration_probability = predict(explore_start, explore_stop, decay_rate,
                                                      decay_step, state, possible_actions)
            reward = game.make_action(action)
            done = game.is_episode_finished()
            episode_rewards.append(reward)

            if done:
                # Terminal transition: push a blank frame as the next state.
                # `dtype=int` replaces the `np.int` alias (removed in NumPy 1.24).
                blank = np.zeros((im_size[0], im_size[1]), dtype=int)
                next_state, stacked_frames = stackFrames(stacked_frames, blank, False)
            else:
                next_state, stacked_frames = stackFrames(stacked_frames,
                                                         game.get_state().screen_buffer, False)
            memory.add((state, action, reward, next_state, done))

            loss = nnet.train(memory, batch_size=batch_size)
            if done:
                break
            state = next_state

        # Report after the episode so the index matches its own stats and the
        # final episode is reported too.
        print(f"Episode: {episode}, Total Reward: {np.sum(episode_rewards)}, "
              f"Loss: {loss}, Explore: {exploration_probability}")
        if episode % save_frequency == 0:
            nnet.save_checkpoint(filename=f'checkpoint{episode}.pth.tar')
    nnet.save_checkpoint(filename='final.pth.tar')

In [15]:
# Play the trained agent greedily; after each episode, enter any text to
# replay or an empty line to stop.
nnet.load_checkpoint(filename=trained_model)
game.init()

try:
    game.new_episode()
    state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer, True)
    while not game.is_episode_finished():
        Qs = nnet.predict(state)
        action = possible_actions[int(np.argmax(Qs))]
        game.make_action(action)

        if game.is_episode_finished():
            if not input('Enter anything to replay, or nothing to end'):
                break
            game.new_episode()
            state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer, True)
        else:
            state, stacked_frames = stackFrames(stacked_frames, game.get_state().screen_buffer, False)

    score = game.get_total_reward()
    print(f'Score: {score}')
finally:
    # Close the game even if inference or input() raises / is interrupted.
    game.close()
game.close()

Enter anything to replay, or nothing to enda
Enter anything to replay, or nothing to end
Score: 85.0
