In [None]:
from obstacle_tower_env import ObstacleTowerEnv

import os
import time
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

# Detect an inline (notebook) matplotlib backend; `display` is only needed
# (and only imported) in that case, for live-updating plots.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode so figures can be refreshed without blocking the kernel.
plt.ion()

print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing on
# the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Timestamp identifying this run; used as the checkpoint sub-directory name.
training_time = int(time.time())

In [None]:
# Create the Obstacle Tower environment.
# NOTE(review): `retro=True` presumably selects the flattened single-integer
# action space that md_to_d() targets, and `realtime_mode=False` presumably
# lets the simulation run faster than real time — confirm against the
# obstacle_tower_env documentation.
env = ObstacleTowerEnv(retro=True, realtime_mode=False)
# Fixed seed and starting floor for this exploratory session
# (presumably take effect on the next reset — TODO confirm).
env.seed(550)
env.floor(0)
# First observation; later cells reuse `obs` for display and for sizing the network.
obs = env.reset()

In [None]:
print(env.observation_space)
print(env.action_space)

# Human-readable labels for the four components of a multi-discrete action,
# in the order [movement, camera rotation, jump, turn]. Each dict maps the
# component's integer value to its label.
movement_dict = dict(enumerate(["No-Op", "Forward", "Backward"]))
cam_rot_dict = dict(enumerate(["No-Op", "Counter-Clockwise", "Clockwise"]))
jump_dict = dict(enumerate(["No-Op", "Jump"]))
turn_dict = dict(enumerate(["No-Op", "Right", "Left"]))

# Show the observation returned by the initial reset.
plt.imshow(obs)

In [None]:
def output_to_mdas(pred_output):
    """Decode a flat agent output index into a 4-component action array.

    The index is split base-3: the quotient selects the movement component
    and the remainder selects the camera-rotation component. The jump and
    turn components are never used by the agent and are always 0.

    Args:
        pred_output: integer output index chosen by the agent.

    Returns:
        np.ndarray ``[movement, cam_rot, jump, turn]``.
    """
    return np.array([pred_output // 3, pred_output % 3, 0, 0])

In [None]:
def probs_to_md(a_probs):
    """Threshold four action probabilities into a 4-component action list.

    Each probability >= 0.5 counts as "pressed". Slot 1 of the input drives
    the movement component, slot 3 drives jump, and slots 0 and 2 drive the
    two camera-rotation values (1 and 2 respectively); if both or neither of
    those two are pressed, the camera rotation stays 0. The turn component is
    always 0.

    Args:
        a_probs: numpy array with at least four probabilities.

    Returns:
        list ``[movement, cam_rot, jump, turn]`` of plain ints.
    """
    pressed = (a_probs >= 0.5).astype(int)
    rotate_one = pressed[0] == 1
    forward = pressed[1] == 1
    rotate_two = pressed[2] == 1
    jump = pressed[3] == 1

    # Opposite rotation requests cancel each other out.
    if rotate_one and not rotate_two:
        cam_rot = 1
    elif rotate_two and not rotate_one:
        cam_rot = 2
    else:
        cam_rot = 0

    return [1 if forward else 0, cam_rot, 1 if jump else 0, 0]

In [None]:
def md_to_d(md_action):
    """Flatten a 4-component action into a single discrete action index.

    Computes ``movement * 18 + cam_rot * 6 + jump * 3 + turn`` (written here
    in nested form); d_to_md() performs the inverse decoding.

    Args:
        md_action: indexable ``[movement, cam_rot, jump, turn]``.

    Returns:
        The flat integer action index.
    """
    movement, cam_rot, jump, turn = (md_action[k] for k in range(4))
    return ((movement * 3 + cam_rot) * 2 + jump) * 3 + turn

In [None]:
def d_to_md(d_action):
    """Decode a flat discrete action index into its four components.

    Inverse of md_to_d(): peels off the turn (base 3), jump (base 2) and
    camera-rotation (base 3) digits from least to most significant; whatever
    remains is the movement component.

    Args:
        d_action: flat integer action index.

    Returns:
        np.ndarray ``[movement, cam_rot, jump, turn]``.
    """
    remainder, md_turn = divmod(d_action, 3)
    remainder, md_jump = divmod(remainder, 2)
    md_movement, md_cam_rot = divmod(remainder, 3)
    return np.array([md_movement, md_cam_rot, md_jump, md_turn])

In [None]:
# Number of outputs of the agent's network. output_to_mdas() decodes output i
# as movement = i // 3 and camera rotation = i % 3, so 6 outputs cover
# movement in {0, 1} combined with camera rotation in {0, 1, 2}.
n_actions = 6

In [None]:
# --- Smoke test 1: take a single random step from the agent's reduced action set.
counter = 0
for i in range(1):
    random_out = random.randrange(n_actions)
    random_action_md = output_to_mdas(random_out)
    random_action = md_to_d(random_action_md)
    if i == 0:
        print(f"Random action:{random_action_md}")
        print(f"Movement:{movement_dict[random_action_md[0]]}, RotateCam:{cam_rot_dict[random_action_md[1]]}, "
              f"Jump:{jump_dict[random_action_md[2]]}, Turn:{turn_dict[random_action_md[3]]}")
    obs, reward, done, info = env.step(random_action)
    if done:
        print("Done")
    counter += 1
print(f"Reward:{reward}, Done:{done}")
plt.imshow(obs)
plt.show()

# --- Smoke test 2: repeat one fixed action, built through probs_to_md(), until
# the episode ends. The action does not change between steps, so it is computed
# once up front. The original loop also sampled and decoded a random action on
# every iteration, but overwrote it before use; that dead work is removed.
counter = 0
action_probs = np.array([0.2, 0.9, 0.3, 0.2])
fixed_action_md = probs_to_md(action_probs)
fixed_action = md_to_d(fixed_action_md)
while not done:
    random_action_md = fixed_action_md
    random_action = fixed_action
    if counter == 0:
        print(f"Random action:{random_action_md}")
        print(f"Movement:{movement_dict[random_action_md[0]]}, RotateCam:{cam_rot_dict[random_action_md[1]]}, "
              f"Jump:{jump_dict[random_action_md[2]]}, Turn:{turn_dict[random_action_md[3]]}")
    obs, reward, done, info = env.step(random_action)
    if done:
        print("Done")
    counter += 1
print(f"Reward:{reward}, Done:{done}")
plt.imshow(obs)
plt.show()

In [None]:
# Image pipeline applied to every observation: tensor -> PIL image -> tensor.
# Despite the name, it contains no resizing step.
resize = T.Compose([T.ToPILImage(), T.ToTensor()])


def get_screen(obs):
    """Convert an environment observation into a batched tensor on `device`.

    Args:
        obs: numpy image array with the channel axis last (it is moved to the
            front before conversion).

    Returns:
        Tensor with a leading batch dimension of size 1, placed on `device`.
    """
    channels_first = torch.from_numpy(obs.transpose((2, 0, 1)))
    batched = resize(channels_first).unsqueeze(0)
    return batched.to(device)


# Reset the environment and display the processed frame as a sanity check.
obs = env.reset()
plt.figure()
plt.imshow(
    get_screen(obs).cpu().squeeze(0).permute(1, 2, 0).numpy(),
    interpolation='none',
)
plt.title('Example screen')
plt.show()

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: four stride-2 conv/batch-norm/ReLU stages and a linear head.

    Args:
        h: height of the input images in pixels.
        w: width of the input images in pixels.
        outputs: number of Q-values (one per agent action) produced by the head.
        in_channels: number of channels of the input tensor. Defaults to 6,
            which matches the notebook's input of the current frame
            concatenated with the frame difference along the channel axis.
    """

    def __init__(self, h, w, outputs, in_channels=6):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts and seeded initialisation stay compatible.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=2)
        self.bn4 = nn.BatchNorm2d(32)

        def conv2d_size_out(size, kernel_size=3, stride=2):
            # Output length of an unpadded convolution along one dimension.
            return (size - (kernel_size - 1) - 1) // stride + 1

        # Apply the size formula once per conv layer to find the spatial size
        # of the final feature map, which fixes the linear head's input width.
        convw, convh = w, h
        for _ in range(4):
            convw = conv2d_size_out(convw)
            convh = conv2d_size_out(convh)
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    def forward(self, x):
        """Return Q-values of shape (batch, outputs) for a batch of input images."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        # Flatten everything but the batch dimension before the linear head.
        return self.head(x.view(x.size(0), -1))

In [None]:
# Number of transitions sampled from replay memory per optimisation step;
# optimize_model() does nothing until the memory holds at least this many.
BATCH_SIZE = 1024
# Discount factor applied to the target network's next-state value.
GAMMA = 0.999
# Epsilon-greedy schedule used by select_action(): the exploration threshold
# decays exponentially from EPS_START towards EPS_END, with EPS_DECAY acting
# as the decay constant measured in action-selection steps.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 10000
# The target network is re-synchronised with the policy network every
# TARGET_UPDATE episodes.
TARGET_UPDATE = 4

# Total reward of each training episode, in order.
ep_rewards = []

# Print aggregate reward statistics / save a checkpoint every this many episodes.
AGGREGATE_STATS_EVERY = 10
SAVE_MODEL_EVERY = 20

# Use a processed frame to discover the image size the networks must accept.
init_screen = get_screen(obs)
_, _, screen_height, screen_width = init_screen.shape

# The target network starts as an exact copy of the policy network and is kept
# in eval mode; it is only ever updated by copying weights.
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

In [None]:
# One step of experience: the state acted in, the action taken, the resulting
# state and the reward received.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of Transition tuples.

    Once `capacity` transitions are stored, each new one overwrites the
    oldest entry.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next pushed transition is written to

    def push(self, *args):
        """Store a transition built from `args` (the Transition fields, in order)."""
        slot = self.position
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve a new slot; it is filled in just below.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(2048)

# Count of actions selected so far; drives the epsilon decay schedule.
steps_done = 0


def select_action(state):
    """Choose an action index for `state` with an epsilon-greedy policy.

    The exploration threshold decays exponentially from EPS_START towards
    EPS_END as `steps_done` grows. With probability equal to the threshold a
    uniformly random action is returned; otherwise the action with the
    highest policy-network output is returned.

    Args:
        state: input tensor accepted by `policy_net`.

    Returns:
        Long tensor of shape (1, 1) holding the chosen action index.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if roll <= eps_threshold:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

    # Exploit: index of the largest network output, reshaped to (1, 1).
    with torch.no_grad():
        return policy_net(state).max(1)[1].view(1, 1)

In [None]:
# Smoke test for select_action(): feed it a 6-channel input built by stacking
# the current frame with itself along the channel axis, then decode and
# describe the chosen action.
temp_state = get_screen(obs)
temp_state = torch.cat([temp_state, temp_state], dim=1)
temp_out = select_action(temp_state)

random_action_md = output_to_mdas(temp_out.item())
random_action = md_to_d(random_action_md)

print(f"Random action:{random_action_md}")
print(
    f"Movement:{movement_dict[random_action_md[0]]}, RotateCam:{cam_rot_dict[random_action_md[1]]}, "
    f"Jump:{jump_dict[random_action_md[2]]}, Turn:{turn_dict[random_action_md[3]]}"
)

In [None]:
# Number of steps each finished training episode lasted.
episode_durations = []


def plot_durations():
    """Redraw figure 2 with the episode durations and their 100-episode mean.

    The running mean is only drawn once at least 100 episodes are recorded;
    its first 99 points are padded with zeros so it lines up with the raw
    curve. In an inline notebook backend the previous output is cleared and
    the figure re-displayed.
    """
    window = 100

    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())

    if len(durations_t) >= window:
        rolling = durations_t.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    # Brief pause so the figure gets a chance to update.
    plt.pause(0.001)
    if not is_ipython:
        return
    display.clear_output(wait=True)
    display.display(plt.gcf())

In [None]:
# Sanity check: run the (untrained) policy network greedily for one step and
# build the next network input the same way the training loop does.
last_screen = get_screen(obs)
current_screen = get_screen(obs)
state = current_screen - last_screen  # frame difference (all zeros on the first step)
input_state = torch.cat((current_screen, state), 1)
for i in range(1):
    with torch.no_grad():
        action = policy_net(input_state).max(1)[1].view(1, 1)

    obs, reward, done, info = env.step(md_to_d(output_to_mdas(action.item())))

    last_screen = current_screen
    current_screen = get_screen(obs)
    if done:
        # Terminal step: there is no successor state. Stop *before* building
        # the next input — concatenating a tensor with None would raise.
        next_state = None
        state = None
        input_state = None
        break

    next_state = current_screen - last_screen
    state = next_state
    input_state = torch.cat((current_screen, state), 1)

plt.imshow(obs)

In [None]:
def optimize_model():
    """Run one optimisation step of the policy network on a replay batch.

    Does nothing until the replay memory holds at least BATCH_SIZE
    transitions. Otherwise it samples a batch, computes the policy network's
    value for each taken action, builds the target
    ``reward + GAMMA * max_a target_net(next_state)`` (with a next-state value
    of 0 for terminal transitions), minimises the smooth L1 loss between the
    two, and clamps every gradient element to [-1, 1] before the optimiser
    step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # Terminal transitions are stored with next_state=None.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Value of the action that was actually taken in each sampled state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Value of the best next action according to the target network; stays 0
    # for terminal transitions. The guard covers a batch made up entirely of
    # terminal transitions, where torch.cat on an empty list would raise.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(torch.cat(non_final_next_states)).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        # Parameters that received no gradient have grad=None; skip them.
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [None]:
num_episodes = 500
print(f"{datetime.now()} Start training")
for i_episode in range(1, num_episodes+1):

    episode_reward = 0

    # Start a new episode. The network input is the current frame concatenated
    # (along the channel axis) with its difference from the previous frame;
    # on the first step that difference is all zeros.
    obs = env.reset()
    last_screen = get_screen(obs)
    current_screen = get_screen(obs)
    state = current_screen - last_screen
    input_state = torch.cat((current_screen, state), 1)
    for t in count():
        action = select_action(input_state)
        obs, reward, done, info = env.step(md_to_d(output_to_mdas(action.item())))

        episode_reward += reward

        reward = torch.tensor([reward], device=device)

        last_screen = current_screen
        current_screen = get_screen(obs)
        if not done:
            next_state = current_screen - last_screen
            next_input_state = torch.cat((current_screen, next_state), 1)
        else:
            # Terminal transitions are stored with next_state=None.
            next_state = None
            next_input_state = None

        memory.push(input_state, action, next_input_state, reward)

        state = next_state
        input_state = next_input_state

        optimize_model()
        if done:
            episode_durations.append(t + 1)
            # plot_durations() can be called here for a live duration plot.
            break

    # Periodically copy the policy weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

    print(f"{datetime.now()} Episode:{i_episode}, EpisodeReward:{episode_reward:.2f}")
    ep_rewards.append(episode_reward)

    # Statistics over the most recent episodes. The average is refreshed every
    # episode so the checkpoint filename below never uses a stale value, even
    # if SAVE_MODEL_EVERY is not a multiple of AGGREGATE_STATS_EVERY.
    recent_rewards = ep_rewards[-AGGREGATE_STATS_EVERY:]
    average_reward = sum(recent_rewards) / len(recent_rewards)
    if not i_episode % AGGREGATE_STATS_EVERY or i_episode == 1:
        min_reward = min(recent_rewards)
        max_reward = max(recent_rewards)
        print(f"{datetime.now()} Episode:{i_episode}, AvgReward:{average_reward:.2f}, MinReward:{min_reward:.2f}, MaxReward:{max_reward:.2f}")
    if not i_episode % SAVE_MODEL_EVERY:
        checkpoint_dir = f"models/{training_time}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(policy_net.state_dict(), f"{checkpoint_dir}/{int(time.time())}_{i_episode}_{average_reward:.2f}_model.pth")

print(f"{datetime.now()} Complete")

In [None]:
# Persist the per-episode training rewards. The directory is created first
# because it only exists already if a checkpoint was written during training.
os.makedirs("models", exist_ok=True)
np.save("models/ep_rewards.npy" , np.array(ep_rewards))

In [None]:
ep_rewards_eval = []
num_eval_episodes = 20
print(f"{datetime.now()} Start evaluation")

# Evaluate with the network in eval mode so the BatchNorm layers use their
# running statistics rather than per-frame batch statistics, and so that
# evaluation does not modify those statistics. The previous mode is restored
# afterwards in case training is resumed.
was_training = policy_net.training
policy_net.eval()
try:
    for _ in range(num_eval_episodes):
        done = False
        # Seeds are drawn from [560, 99990), which excludes the seed (550)
        # set when the environment was created.
        env.seed(np.random.randint(low=560, high=99990))
        env.floor(0)
        obs = env.reset()
        episode_reward = 0
        last_screen = get_screen(obs)
        current_screen = get_screen(obs)
        input_state = torch.cat((current_screen, current_screen - last_screen), 1)
        while not done:
            # Always act greedily: no exploration during evaluation.
            with torch.no_grad():
                action = policy_net(input_state).max(1)[1].view(1, 1)

            obs, reward, done, info = env.step(md_to_d(output_to_mdas(action.item())))

            episode_reward += reward

            last_screen = current_screen
            current_screen = get_screen(obs)
            if not done:
                input_state = torch.cat((current_screen, current_screen - last_screen), 1)
            else:
                input_state = None
        ep_rewards_eval.append(episode_reward)
        print(f"{datetime.now()} Episode reward:{episode_reward:.2f}")
finally:
    policy_net.train(was_training)

ep_rewards_eval = np.array(ep_rewards_eval)
print(f"{datetime.now()} Min episode reward:{ep_rewards_eval.min()}")
print(f"{datetime.now()} Max episode reward:{ep_rewards_eval.max()}")
print(f"{datetime.now()} Mean episode reward:{ep_rewards_eval.mean()}")

In [None]:
# Persist the evaluation rewards, creating the output directory if it is missing.
os.makedirs("models", exist_ok=True)
np.save("models/ep_rewards_eval.npy" , ep_rewards_eval)

In [None]:
# Close the environment now that training and evaluation are finished.
env.close()