In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import cv2
import sys
sys.path.append("game/")
import wrapped_flappy_bird as game

import numpy as np
import collections
import random
from collections import deque

In [2]:
# --- Run mode -------------------------------------------------------------
is_training = True            # when True, get_action mixes in random actions
load_weight = False           # not read by any cell visible in this notebook

# --- Network / optimisation -----------------------------------------------
actions = 2                   # length of the one-hot action vector / Q-value output
lr = 0.001                    # learning rate handed to RMSprop
gamma = 0.99                  # discount factor used in the TD target
batch_size = 32               # transitions sampled from replay memory per update

# --- Replay memory ----------------------------------------------------------
mem_size = 5000               # save_transition drops the oldest entry above this size
observation = 100             # number of random warm-up steps stored before training

# --- Epsilon-greedy schedule ------------------------------------------------
epsilon = 0.9                 # current probability of taking a random action
initial_epsilon = 1.0         # together with final_epsilon, sets the decay step size
final_epsilon = 0.1           # decay stops once epsilon is no longer above this value
exploration = 50000           # number of decay steps between initial and final epsilon

# --- Training length / checkpointing ----------------------------------------
max_episode = 100000          # number of episodes run by the training loop
save_checkpoint_freq = 100000 # not read by any cell visible in this notebook

In [3]:
class Net(nn.Module):
    """Convolutional Q-network: a stack of 4 frames in, one Q-value per action out.

    Two conv layers feed two fully connected layers. With 80x80 input frames
    the conv stack produces a 64x10x10 feature map, which is what ``map_size``
    (and therefore the input width of ``fc1``) is hard-wired to.

    Args:
        n_actions: width of the output layer. When omitted, the module-level
            ``actions`` setting is used, which matches the original behaviour.
    """

    def __init__(self, n_actions=None):
        super(Net, self).__init__()
        if n_actions is None:
            # Fall back to the notebook-level configuration value.
            n_actions = actions
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        # (channels, height, width) of the conv2 output for 80x80 inputs.
        self.map_size = (64, 10, 10)
        self.fc1 = nn.Linear(self.map_size[0]*self.map_size[1]*self.map_size[2], 256)
        self.fc2 = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return the Q-values for a batch of stacked frames.

        Args:
            x: float tensor of shape (batch, 4, H, W); H and W must yield a
                10x10 map after the two conv layers (80x80 does).

        Returns:
            Tensor of shape (batch, n_actions) with one Q-value per action.
        """
        x = F.relu(self.conv1(x), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x), inplace=True)
        x = self.fc2(x)
        return x

    # The method was originally misspelled ``foward``, which meant calling the
    # module directly (``model(x)``) failed because nn.Module dispatches to
    # ``forward``. The old name is kept as an alias so existing call sites
    # (``model.foward(...)``) continue to work.
    foward = forward

In [8]:
# Create the Q-network. Every later cell (optimizer, action selection,
# training loop) reads this `model` global, and no other cell defines it, so
# without this line the notebook fails on a fresh kernel.
model = Net()
# To train on GPU, uncomment the line below and also move the input tensors
# to the GPU in the cells that build them.
#model.cuda()

In [9]:
# RMSprop over all network parameters, using the learning rate from the
# configuration cell.
optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=lr,
)

In [10]:
# A single all-zero 80x80 frame, and a state made of four such frames stacked
# along axis 0 (shape (4, 80, 80)). Used as the state at the start of a game
# and after a terminal transition.
empty_frame = np.zeros((80, 80), dtype=np.float32)
empty_state = np.stack([empty_frame] * 4, axis=0)

In [11]:
# Game emulator instance stepped by the warm-up and training loops.
flappyBird = game.GameState()
# Replay buffer of (state, action, reward, next_state, terminal) tuples;
# its size is capped inside save_transition.
replay_memory = collections.deque()
# The agent starts from the all-zero frame stack.
currt_state = empty_state

In [12]:
def get_action_randomly():
    """Draw a random one-hot action, biased towards index 0.

    Index 0 is chosen with probability 0.8 and index 1 otherwise.
    NOTE(review): presumably index 1 is the "flap" action of the game
    wrapper - confirm against wrapped_flappy_bird.

    Returns:
        float32 numpy array of length ``actions`` with a single 1.
    """
    # A draw in [0.8, 1.0) selects index 1; anything below selects index 0.
    chosen = int(random.random() >= 0.8)
    one_hot = np.zeros(actions, dtype=np.float32)
    one_hot[chosen] = 1
    return one_hot

In [13]:
def get_action_optim(model, currt_state):
    """Return the greedy (highest Q-value) action for a single state.

    Args:
        model: network exposing ``foward(tensor)`` that maps a batch of states
            to a (batch, n_actions) tensor of Q-values.
        currt_state: numpy array holding one state (no batch dimension); it is
            converted with ``torch.from_numpy`` and given a batch dimension.

    Returns:
        float32 numpy one-hot array whose length equals the number of Q-values
        produced by ``model``, with a 1 at the arg-max position.
    """
    state = torch.from_numpy(currt_state).unsqueeze(0)
    # `Variable(..., volatile=True)` no longer disables autograd (it is ignored
    # since PyTorch 0.4); torch.no_grad() is the supported way to run pure
    # inference without building a graph.
    with torch.no_grad():
        q_value = model.foward(state)
    _, action_index = torch.max(q_value, dim=1)
    # Size the one-hot from the network output so the index is always valid,
    # and convert the 0-dim index tensor to a plain Python int.
    action = np.zeros(q_value.size(1), dtype=np.float32)
    action[action_index[0].item()] = 1
    return action

In [14]:
def get_action(model, currt_state):
    """Epsilon-greedy action selection.

    While ``is_training`` is set, a random action is returned with probability
    ``epsilon``; in every other case the greedy action from ``model`` is used.

    Args:
        model: Q-network forwarded to ``get_action_optim``.
        currt_state: current state array forwarded to ``get_action_optim``.

    Returns:
        One-hot action array from whichever helper was selected.
    """
    # The random draw only happens in training mode (short-circuit `and`).
    explore = is_training and random.random() <= epsilon
    if explore:
        return get_action_randomly()
    return get_action_optim(model, currt_state)

In [15]:
def save_transition(replay_memory, currt_state, param_dict):
    """Record one environment step in the replay memory and advance the state.

    Args:
        replay_memory: deque of stored transitions; mutated in place.
        currt_state: state array the action was taken from.
        param_dict: dict with keys ``"observation"`` (new frame, already shaped
            so it can be concatenated along axis 0), ``"action"``,
            ``"reward"`` and ``"terminal"``.

    Returns:
        Tuple ``(state, replay_memory)`` where ``state`` is the next state, or
        ``empty_state`` if the transition was terminal.
    """
    o_next, action, reward, terminal = (
        param_dict[key] for key in ("observation", "action", "reward", "terminal")
    )

    # Slide the frame window: drop the oldest frame, append the newest one.
    next_state = np.concatenate((currt_state[1:], o_next), axis=0)

    replay_memory.append((currt_state, action, reward, next_state, terminal))
    # Keep the buffer bounded by discarding the oldest transition.
    if len(replay_memory) > mem_size:
        replay_memory.popleft()

    # After a terminal step the next episode starts from the blank state.
    following_state = empty_state if terminal else next_state
    return following_state, replay_memory

In [16]:
# Convert a raw game frame into a binarised 80x80 grayscale image.
def preprocess(observation):
    """Shrink, grayscale and binarise a frame.

    Args:
        observation: colour image as returned by the game's ``frame_step``
            (treated as BGR by the colour conversion).

    Returns:
        Array of shape (1, 80, 80) whose pixels are 255 where the grayscale
        value exceeded 1 and 0 elsewhere.
    """
    resized = cv2.resize(observation, (80, 80))
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # Leading axis of size 1 lets the frame be appended to a frame stack.
    return np.reshape(binary, (1, 80, 80))

In [17]:
# Mean-squared-error loss between predicted and target Q-values.
# (The variable name keeps its original spelling because the training loop
# refers to it.)
ceriterion = nn.MSELoss()
# RMSprop over all network parameters with the configured learning rate.
optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=lr,
)

In [None]:
# Warm-up phase: play `observation` random steps and store them so the replay
# memory holds transitions before the first gradient update samples from it.
for step in range(observation):
    action = get_action_randomly()
    raw_frame, reward, terminal = flappyBird.frame_step(action)
    param_dict = dict(
        observation=preprocess(raw_frame),
        action=action,
        reward=reward,
        terminal=terminal,
    )
    currt_state, replay_memory = save_transition(replay_memory, currt_state, param_dict)

In [None]:
# Main DQN training loop: one environment step followed by one gradient step
# on a random minibatch from the replay memory, repeated until the episode
# ends; epsilon is decayed once per finished episode.
for episode in range(max_episode):
    time_step = 0
    total_reward = 0.
    while True:
        # ---- Act in the environment ------------------------------------
        action = get_action(model, currt_state)
        o_next, r, terminal = flappyBird.frame_step(action)
        # Discounted return of the episode so far: sum_t gamma^t * r_t.
        # (The previous form multiplied the running total by gamma^t, which
        # is not a return.)
        total_reward += (gamma ** time_step) * r
        o_next = preprocess(o_next)
        param_dict = {
            "observation":o_next,
            "action":action,
            "reward":r,
            "terminal":terminal
        }
        currt_state, replay_memory = save_transition(replay_memory, currt_state, param_dict)
        time_step += 1

        # ---- Sample a minibatch of stored transitions --------------------
        # Each entry is (state, action, reward, next_state, terminal).
        minibatch = random.sample(replay_memory, batch_size)

        state_batch = torch.from_numpy(np.array([data[0] for data in minibatch]))
        action_batch = torch.from_numpy(np.array([data[1] for data in minibatch]))
        reward_batch = torch.from_numpy(
            np.array([data[2] for data in minibatch]).astype(np.float32))
        next_state_batch = torch.from_numpy(np.array([data[3] for data in minibatch]))
        # 1.0 where the episode continued, 0.0 where the transition was terminal.
        not_terminal = torch.from_numpy(
            np.array([not data[4] for data in minibatch], dtype=np.float32))

        #state_batch = state_batch.cuda()
        #next_state_batch = next_state_batch.cuda()

        # ---- TD target: r + gamma * max_a' Q(s', a') for non-terminal s' ----
        # `volatile=True` is ignored by current PyTorch, so the target pass is
        # wrapped in no_grad() to keep it out of the autograd graph.
        with torch.no_grad():
            max_q, _ = torch.max(model.foward(next_state_batch), dim=1)
            y = reward_batch + gamma * max_q * not_terminal

        # ---- Q(s, a) for the action actually taken -------------------------
        # The one-hot action masks out the Q-value of the other action.
        q_value = model.foward(state_batch)
        q_value = torch.sum(torch.mul(action_batch, q_value), dim=1)

        # ---- Gradient step -------------------------------------------------
        optimizer.zero_grad()
        loss = ceriterion(q_value, y)
        loss.backward()
        optimizer.step()

        if terminal:
            break

    # Linear epsilon decay, applied once per episode and clamped so it never
    # drops below final_epsilon.
    # NOTE(review): `exploration` is consumed in episodes here; confirm that
    # is intended rather than a per-step schedule.
    if epsilon > final_epsilon:
        delta = (initial_epsilon - final_epsilon)/exploration
        epsilon = max(final_epsilon, epsilon - delta)