In [46]:
import math
import copy
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"  # ポップアウトウィンドウを表示しないようにする
from ple import PLE
from ple.games.flappybird import FlappyBird
from collections import defaultdict
from itertools import chain

In [47]:
def make_anim(images, fps=60, true_image=False):
    """Wrap a sequence of frames in a moviepy ``VideoClip``.

    Args:
        images: Non-empty, indexable sequence of frame arrays. Each frame must
            support ``astype`` (and arithmetic when ``true_image`` is false),
            e.g. numpy arrays.
        fps: Playback rate. The clip lasts ``len(images) / fps`` seconds and
            ``clip.fps`` is set to this value.
        true_image: When true, each frame is cast to ``uint8`` unchanged.
            Otherwise the frame is mapped through ``(x + 1) / 2 * 255`` before
            the cast (presumably for frames scaled to [-1, 1] -- confirm with
            the callers that rely on this branch).

    Returns:
        tuple: ``(clip, duration)`` where ``clip`` is the moviepy ``VideoClip``
        and ``duration`` is its length in seconds.

    Raises:
        ValueError: If ``images`` is empty.
    """
    frame_count = len(images)
    if frame_count == 0:
        # Without this guard the failure would surface later as an opaque
        # IndexError raised from inside the frame callback.
        raise ValueError("make_anim() needs at least one frame")

    # Kept as a function-scope import so moviepy is only needed when a clip
    # is actually built.
    import moviepy.editor as mpy

    duration = frame_count / fps

    def make_frame(t):
        # At t == duration the computed index equals frame_count; clamp it to
        # the last frame explicitly rather than hiding every possible error
        # behind a bare ``except``.
        index = min(int(frame_count / duration * t), frame_count - 1)
        x = images[index]

        if true_image:
            return x.astype(np.uint8)
        return ((x + 1) / 2 * 255).astype(np.uint8)

    clip = mpy.VideoClip(make_frame, duration=duration)
    clip.fps = fps
    return clip, duration

In [48]:
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames, filename='movie_flappy_bird_DQN.mp4', interval=50):
    """Animate ``frames`` with matplotlib, save the animation, and show it inline.

    Despite the name, the animation is written to ``filename`` (an ``.mp4``
    path by default) and then rendered in the notebook through JSAnimation.

    Args:
        frames: Non-empty sequence of image arrays; the first frame's
            ``shape[0]`` / ``shape[1]`` set the figure height / width in
            pixels (at 72 dpi).
        filename: Path handed to ``anim.save``. Defaults to the file name this
            function previously hard-coded.
        interval: Delay between frames in milliseconds, forwarded to
            ``FuncAnimation``. Defaults to the previously hard-coded 50.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_frames_as_gif() needs at least one frame")

    first = frames[0]
    # One figure inch per 72 pixels so the image is drawn at its native size.
    fig = plt.figure(figsize=(first.shape[1] / 72.0, first.shape[0] / 72.0), dpi=72)
    patch = plt.imshow(first)
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=interval)

    anim.save(filename)
    display(display_animation(anim, default_mode='loop'))

In [49]:
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

In [50]:
class Net(nn.Module):
    """Fully connected network: two ReLU hidden layers, then a linear output.

    Takes inputs whose last dimension is ``num_states`` and produces
    ``num_actions`` outputs per input row.
    """

    def __init__(self, num_states, num_actions):
        super().__init__()
        width = 32  # size of both hidden layers
        # The attribute names become the state-dict keys, so they stay fixed.
        self.fc1 = nn.Linear(num_states, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, num_actions)

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then fc3 without an activation."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

In [51]:
class Brain:
    """Greedy action selector backed by a ``Net`` restored from a checkpoint."""

    def __init__(self, num_states, num_actions, weight_path='weight.pth'):
        """Build the network and load its parameters.

        Args:
            num_states: Input width passed to ``Net``.
            num_actions: Output width passed to ``Net``.
            weight_path: File holding the state dict to load. Defaults to
                ``'weight.pth'``, the path this class previously hard-coded.
        """
        self.model = Net(num_states, num_actions)
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        # map_location='cpu' lets a checkpoint that was saved from a GPU load
        # on a machine without CUDA; the model itself is built on the CPU.
        param = torch.load(weight_path, map_location='cpu')
        self.model.load_state_dict(param)
        # This class only runs inference, so start in evaluation mode.
        self.model.eval()

    def decide_action(self, state):
        """Return the index of the highest-scoring action as a 1x1 tensor.

        Args:
            state: Tensor fed to the network. The ``view(1, 1)`` below means
                the network output must contain exactly one row.
        """
        self.model.eval()  # keep the network in inference mode
        with torch.no_grad():
            action = self.model(state).max(1)[1].view(1, 1)
        return action

In [52]:
class Agent:
    """Thin wrapper that owns a ``Brain`` and forwards action requests to it."""

    def __init__(self, num_states, num_actions):
        # The brain is constructed from the same two sizes this agent receives.
        self.brain = Brain(num_states, num_actions)

    def get_action(self, state):
        """Return whatever the brain decides for ``state``."""
        return self.brain.decide_action(state)

In [53]:
class Environment:
    """Plays one FlappyBird episode with the agent and shows a replay.

    The raw game state is converted to player-relative, bucketed integer
    features before it is handed to the agent.
    """

    # Bucket width per state feature: each value is divided by its width and
    # truncated to an integer in ``get_relative_state``.
    bucket_range_per_feature = {
        'next_next_pipe_bottom_y': 40,
        'next_next_pipe_dist_to_player': 512,
        'next_next_pipe_top_y': 40,
        'next_pipe_bottom_y': 20,
        'next_pipe_dist_to_player': 20,
        'next_pipe_top_y': 20,
        'player_vel': 4,
        'player_y': 16
    }

    def __init__(self):
        self.game = FlappyBird()
        self.env = PLE(self.game, fps=30, display_screen=False)
        # One input per game-state feature (the original note gave 8).
        self.num_states = len(self.game.getGameState())
        # NOTE(review): the original note gave 1 here; check the real length of
        # the action set, since it sizes the network output.
        self.num_actions = len(self.env.getActionSet())
        self.agent = Agent(self.num_states, self.num_actions)

    def _observe(self):
        """Return the current bucketed game state as a 1 x N float tensor."""
        features = np.array(list(self.get_relative_state(self.game.getGameState())))
        # Convert numpy -> float tensor, then add a leading batch dimension.
        return torch.unsqueeze(torch.from_numpy(features).type(torch.FloatTensor), 0)

    def run(self, fps=60):
        """Play a single episode and display it as an inline video.

        Args:
            fps: Frame rate of the replay clip. Defaults to 60, the value that
                used to be hard-coded. NOTE(review): the environment itself is
                created with ``fps=30``, so the default replay presumably
                plays faster than the game ran -- confirm this is intended.
        """
        self.env.reset_game()  # start a fresh episode
        state = self._observe()

        # Record the initial screen, then one more screen after every action.
        frames = [self.env.getScreenRGB()]

        while not self.env.game_over():
            action = self.agent.get_action(state)
            # The agent returns a 1x1 tensor; turn it into a plain list index.
            self.env.act(self.env.getActionSet()[int(action)])
            frames.append(self.env.getScreenRGB())
            state = self._observe()

        print("len frames:", len(frames))
        clip, duration = make_anim(frames, fps=fps, true_image=True)
        clip = clip.rotate(-90)
        display(clip.ipython_display(fps=fps, autoplay=1, loop=1, max_duration = duration))

    def get_relative_state(self, state):
        """Convert a raw game-state dict into a tuple of bucket indices.

        The four pipe ``*_y`` features are made relative to ``player_y``. Every
        feature is then divided by its bucket width and truncated with
        ``int()``. The result follows alphabetical key order.

        Args:
            state: Dict with the keys of ``bucket_range_per_feature``. It is
                copied first, so the caller's dict is left untouched.

        Returns:
            tuple[int, ...]: One bucket index per feature.
        """
        state = copy.deepcopy(state)

        # Use pipe positions relative to the player instead of absolute ones.
        player_y = state['player_y']
        for key in ('next_next_pipe_bottom_y', 'next_next_pipe_top_y',
                    'next_pipe_bottom_y', 'next_pipe_top_y'):
            state[key] -= player_y

        # Iterate in alphabetical key order so the feature order is stable.
        return tuple(
            int(state[key] / self.bucket_range_per_feature[key])
            for key in sorted(state)
        )

In [54]:
# Build the environment (constructing it also creates the agent, which loads
# the saved network weights), then play one episode and show the replay.
flappy_env = Environment()
flappy_env.run()

  8%|▊         | 15/180 [00:00<00:01, 149.61it/s]

len frames: 179


 99%|█████████▉| 179/180 [00:00<00:00, 270.32it/s]
