In [1]:
import cv2
import sys
sys.path.append("game/")
import numpy as np
import random
from collections import deque

import torch
import torch.nn as nn
from torch.autograd import Variable

import wrapped_flappy_bird as game

from DQN import DQN

In [2]:
# Run on the GPU only when CUDA is actually usable; a hard-coded True makes the
# later `dqn.model.cuda()` call fail on CPU-only machines.
cuda = torch.cuda.is_available()

In [3]:
def preprocess(observation):
    """Convert a raw game frame into an 80x80x1 binary image.

    The frame is resized to 80x80, converted from BGR to grayscale, and then
    thresholded so that every pixel with a value above 1 becomes 255 and all
    other pixels become 0.

    observation -- BGR image as accepted by ``cv2.resize`` / ``cv2.cvtColor``

    Returns an array of shape (80, 80, 1).
    """
    resized = cv2.resize(observation, (80, 80))
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # Add a trailing single-channel axis.
    return binary.reshape((80, 80, 1))

In [4]:
class CFG(object):
    """Hyper-parameters and run options for the Flappy Bird DQN experiment.

    An instance is handed to ``DQN(cfg)``. Fields that this notebook reads
    directly are described below; how ``DQN`` itself consumes the remaining
    fields is not visible here and should be confirmed in ``DQN.py``.
    """
    # --- presumably consumed by DQN (optimizer / network / replay memory) ---
    lr = 0.001
    actions = 2
    is_training = True
    load_weight = False
    batch_size = 32
    mem_size = 5000
    epsilon = 0.9

    # Discount factor; the notebook uses it to accumulate the discounted
    # episode reward.
    gamma = 0.99

    # Bounds of the epsilon annealing schedule.
    initial_epsilon = 1.
    final_epsilon = 0.1
    # Short aliases for the same two bounds (`init_e` / `final_e` naming), so
    # either spelling resolves to the same value.
    init_e = initial_epsilon
    final_e = final_epsilon

    # Number of random-action steps played before training starts.
    observation = 100
    # Number of decrements over which epsilon goes from initial to final value.
    exploration = 50000
    # Number of training episodes.
    max_episode = 100000
    # Save a periodic checkpoint every this many episodes.
    save_checkpoint_freq = 100000

In [5]:
# Best average evaluation time seen so far; compared against the value
# returned by the evaluation routine to decide when to save a "best" checkpoint.
best_time_step = 0.
# Game environment used for the observation phase and for training.
flappyBird = game.GameState()
# Hyper-parameters, and the agent built from them.
cfg = CFG()
dqn = DQN(cfg)

In [6]:
# Move the agent's network to the GPU when the `cuda` flag is set.
# NOTE(review): only `dqn.model` is moved here; any tensors DQN creates
# internally must be placed on the same device — confirm in DQN.py.
if cuda:
    dqn.model = dqn.model.cuda()

In [7]:
# Play one step to obtain a first frame from the game.
# NOTE(review): [1, 0] is presumably the "do nothing" action — confirm against
# wrapped_flappy_bird.
action=[1,0]
o, r, terminal = flappyBird.frame_step(action)
# Redundant with the identical assignment in the setup cell; kept so that
# re-running this cell alone resets the best score.
best_time_step = 0.
# NOTE(review): this preprocessed frame is overwritten by the observation-phase
# loop before anything reads it, and no `dqn.set_initial_state()` call is made
# on the training path (the evaluation routine does call it) — confirm that DQN
# initializes its state itself.
o = preprocess(o)

In [8]:
# Observation phase: play `cfg.observation` steps with random actions and store
# each resulting transition in the agent, without training.
for i in range(cfg.observation):
    action = dqn.get_action_randomly()
    o, r, terminal = flappyBird.frame_step(action)
    o = preprocess(o)
    # Stores (next frame, action, reward, terminal flag) in the agent.
    dqn.storeTransition(o, action, r, terminal)

In [9]:
def test_dqn(model, episode, n_tests=5):
    """Evaluate the agent by letting it play several games with its optimal action.

       model -- dqn model
       episode -- current training episode (only used in the log line)
       n_tests -- number of evaluation games to play (default 5)

       Returns the average of ``model.time_step`` at the end of each game.
       Raises ValueError if ``n_tests`` is not positive.

       Side effects: ``model.time_step`` and ``model.current_state`` are
       overwritten by the evaluation games. ``model.is_training`` is switched
       to False for the duration of the evaluation and then restored to its
       previous value, so evaluation mode does not leak into training.
    """
    if n_tests <= 0:
        raise ValueError('n_tests must be positive, got {}'.format(n_tests))

    # Remember the caller's mode so it can be restored even if a game raises.
    was_training = getattr(model, 'is_training', True)
    model.is_training = False
    #model.set_eval()
    total_time = 0.
    try:
        for _ in range(n_tests):
            model.time_step = 0
            # Fresh game per test case; local name avoids shadowing the
            # notebook-level training environment.
            env = game.GameState()
            # Advance the game by one step before the agent starts acting
            # (the returned frame is not used).
            env.frame_step([1, 0])
            model.set_initial_state()
            while True:
                action = model.get_optim_action()
                o, r, terminal = env.frame_step(action)
                if terminal:
                    break
                o = preprocess(o)
                # Slide the frame window: drop the oldest frame along axis 0
                # and append the newest one.
                model.current_state = np.append(
                    model.current_state[1:,:,:], o.reshape((1,)+o.shape), axis=0)
                model.increase_time_step()
            total_time += model.time_step
    finally:
        model.is_training = was_training

    ave_time = total_time / n_tests
    print ('testing: episode: {}, average time: {}'.format(episode, ave_time))
    return ave_time

In [10]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write `state` to `filename` with torch.save.

    state -- dict of values to persist
    is_best -- when True, the file is additionally copied to
               'model_best.pth.tar' so the best model is easy to find
    filename -- destination path of the checkpoint
    """
    import shutil  # local import: only needed for the "best" copy
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


# Amount epsilon is lowered after every episode (loop-invariant).
epsilon_step = (cfg.initial_epsilon - cfg.final_epsilon) / cfg.exploration
# Most recent evaluation result; refreshed every 100 episodes (including 0).
ave_time = 0.

for episode in range(cfg.max_episode):
    dqn.time_step = 0
    total_reward = 0.
    # Play one episode, training on a batch after every environment step.
    while True:
        dqn.optimizer.zero_grad()
        action = dqn.get_action()
        o_next, r, terminal = flappyBird.frame_step(action)
        # Discounted return of the episode, for logging only.
        total_reward += cfg.gamma**dqn.time_step * r
        o_next = preprocess(o_next)
        dqn.storeTransition(o_next, action, r, terminal)
        # NOTE(review): the evaluation routine calls `increase_time_step()`
        # while this loop calls `increaseTimeStep()` — confirm which name DQN
        # actually defines.
        dqn.increaseTimeStep()
        dqn.model.train()
        dqn.trainByBatch()

        if terminal:
            break

    print ('episode: {}, epsilon: {:.4f}, max time step: {}, total reward: {:.6f}'.format(
            episode, dqn.epsilon, dqn.time_step, total_reward))

    # Linear epsilon annealing, clamped so it never drops below the floor.
    if dqn.epsilon > cfg.final_epsilon:
        dqn.epsilon = max(cfg.final_epsilon, dqn.epsilon - epsilon_step)

    if episode % 100 == 0:
        ave_time = test_dqn(dqn, episode)

    # NOTE(review): `dqn.state_dict()` assumes DQN itself exposes state_dict;
    # if only `dqn.model` is an nn.Module this should be
    # `dqn.model.state_dict()` — confirm in DQN.py.
    if ave_time > best_time_step:
        # New best evaluation score: save and mark as best.
        best_time_step = ave_time
        save_checkpoint({
                'episode': episode,
                'epsilon': dqn.epsilon,
                'state_dict': dqn.state_dict(),
                'best_time_step': best_time_step,
                 }, True, 'checkpoint-episode-%d.pth.tar' %episode)
    elif episode % cfg.save_checkpoint_freq == 0:
        # Periodic checkpoint; key is 'episode' to match the best-checkpoint
        # layout.
        save_checkpoint({
                'episode': episode,
                'epsilon': dqn.epsilon,
                'state_dict': dqn.state_dict(),
                'time_step': ave_time,
                 }, False, 'checkpoint-episode-%d.pth.tar' %episode)
    else:
        # Nothing was saved this episode, so skip the log line below.
        continue
    print ('save checkpoint, episode={}, ave time step={:.2f}'.format(
                episode, ave_time))

NotImplementedError: 