In [2]:
import gym
#import creversi.gym_reversi
#from creversi import *
import os
import datetime
import math
import random
import numpy as np
from collections import namedtuple 
from itertools import count
#from tqdm import tqdm_notebook as tqdm 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F

In [6]:
# Environment creation is disabled here (the creversi gym import is commented out):
#env = gym.make('Reversi-v0').unwrapped

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [7]:
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'next_actions', 'reward'))

class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records for experience replay.

    Once ``capacity`` transitions are stored, each new push overwrites the slot
    at ``position`` (the oldest entry), then advances ``position`` cyclically.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept; must be positive.

        Raises:
            ValueError: if ``capacity`` is not positive (``push`` would otherwise
                fail later with ZeroDivisionError / IndexError).
        """
        if capacity <= 0:
            raise ValueError("capacity must be positive, got %r" % (capacity,))
        self.capacity = capacity
        self.memory = []    # stored Transitions, at most `capacity` of them
        self.position = 0   # slot the next push writes to

    def push(self, *args):
        """Saves a transition.

        ``args`` are passed positionally to ``Transition``
        (state, action, next_state, next_actions, reward).
        """
        # Build the record before touching the buffer: if the argument count is
        # wrong, Transition() raises and no placeholder None is left behind
        # (a stray None would later break Transition(*zip(*batch))).
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            # While filling, `position` always equals len(self.memory).
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: (from ``random.sample``) if fewer than ``batch_size`` are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [8]:
k = 192          # channels of every convolutional layer
fcl_units = 256  # width of the hidden fully connected layer

class DQN(nn.Module):
    """Q-network: ten conv(3x3)/BatchNorm/ReLU layers followed by two linear layers.

    The first conv takes 2 input planes. Every conv uses kernel_size=3 and
    padding=1, so the spatial size is unchanged; ``forward`` then flattens each
    sample to ``k * 64`` features, i.e. it expects 64 spatial positions
    (presumably an 8x8 board -- confirm against the state encoder).
    The output has ``output_dim`` values squashed into [-1, 1] by tanh.
    """

    _num_conv = 10  # conv/bn pairs, registered as conv1/bn1 ... conv10/bn10

    def __init__(self, output_dim = 65):
        super(DQN, self).__init__()
        in_planes = 2
        for idx in range(1, self._num_conv + 1):
            # Register in the order conv1, bn1, conv2, bn2, ... so attribute names,
            # state_dict keys and parameter order stay compatible with checkpoints.
            setattr(self, f"conv{idx}", nn.Conv2d(in_planes, k, kernel_size=3, padding=1))
            setattr(self, f"bn{idx}", nn.BatchNorm2d(k))
            in_planes = k
        self.fcl1 = nn.Linear(k * 64, fcl_units)
        self.fcl2 = nn.Linear(fcl_units, output_dim)

    def forward(self, x):
        for idx in range(1, self._num_conv + 1):
            conv = getattr(self, f"conv{idx}")
            norm = getattr(self, f"bn{idx}")
            x = F.relu(norm(conv(x)))
        hidden = F.relu(self.fcl1(x.view(-1, k * 64)))
        return self.fcl2(hidden).tanh()

In [9]:
def get_state(board):
    """Placeholder for turning a board into a network input; not implemented.

    Always returns None.
    """
    return None

# --- Hyperparameters ---------------------------------------------------------
BATCH_SIZE = 256            # transitions sampled per optimisation step
GAMMA = 0.99                # discount applied to the bootstrapped next-state value
EPS_START = 0.9             # exploration rate when episodes_done == 0
EPS_END = 0.05              # exploration rate the decay approaches
EPS_DECAY = 2000            # decay scale, in episodes, of the exponential epsilon schedule
OPTIMIZE_PER_EPISODES = 16  # presumably: optimise every N episodes (not used in visible cells)
TARGET_UPDATE = 4           # presumably: target-net sync interval (not used in visible cells)

In [11]:
policy_net = DQN().to(device)
target_net = DQN().to(device)
# The target network starts as an exact copy of the policy network and is put
# in eval mode, so its BatchNorm layers use running statistics.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval();

In [12]:
# Only the policy network's parameters are optimised; the target network is
# never updated by this optimizer.
optimizer = optim.RMSprop(policy_net.parameters(), lr=1e-5)
# Replay buffer holding up to 2**17 (= 131072) transitions.
memory = ReplayMemory(2 ** 17)

In [13]:
def epsilon_greedy(state, legal_moves, episodes_done):
    """Choose an index into ``legal_moves`` with an epsilon-greedy policy.

    The exploration rate decays exponentially with the number of finished episodes:
    ``eps = EPS_END + (EPS_START - EPS_END) * exp(-episodes_done / EPS_DECAY)``.
    With probability ``1 - eps`` the legal move with the highest Q value from
    row 0 of ``policy_net(state)`` is chosen; otherwise a uniformly random one.

    Args:
        state: input passed to ``policy_net``.
        legal_moves: non-empty list of action indices (columns of the Q output).
        episodes_done: number of completed episodes; drives the epsilon decay.

    Returns:
        int: position of the chosen move *within* ``legal_moves`` (not the action id).

    Raises:
        ValueError: if ``legal_moves`` is empty.
    """
    if len(legal_moves) == 0:
        raise ValueError("epsilon_greedy needs at least one legal move")
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * episodes_done / EPS_DECAY)
    if sample > eps_threshold:
        # Evaluate in eval mode: in train mode BatchNorm would normalise this
        # single-state batch with its own statistics and also update the running
        # mean/var (no_grad does not prevent that). Restore the caller's mode after.
        was_training = policy_net.training
        policy_net.eval()
        try:
            with torch.no_grad():
                q = policy_net(state)
                _, best = q[0, legal_moves].max(0)
        finally:
            policy_net.train(was_training)
        # Return a plain int in both branches (max(0) yields a 0-dim tensor).
        select = int(best)
    else:
        select = random.randrange(len(legal_moves))
    return select

def select_action(state, board, episodes_done):
    """Pick a move for ``board`` with ``epsilon_greedy``.

    Returns:
        A pair ``(move, move_tensor)``: the chosen entry of ``board.legal_moves``
        and the same value as a ``[[move]]`` long tensor on ``device``.
    """
    candidates = list(board.legal_moves)
    chosen = candidates[epsilon_greedy(state, candidates, episodes_done)]
    return chosen, torch.tensor([[chosen]], device=device, dtype=torch.long)

In [15]:
losses = []  # scalar loss of every optimisation step (appended by optimize_model)

criterion = nn.SmoothL1Loss()

def optimize_model():
    """Perform one DQN update of ``policy_net`` on a random minibatch from ``memory``.

    Uses the module-level ``memory``, ``policy_net``, ``target_net``, ``optimizer``,
    ``BATCH_SIZE``, ``GAMMA``, ``device`` and ``Transition``; appends the scalar loss
    to ``losses``.

    Target: ``reward + GAMMA * V(s')``, where ``V(s') = -max_{a' in next_actions}
    Q_target(s', a')``. The value is negated because s' is evaluated from the
    opponent's side. ``V(s') = 0`` when ``next_state`` is None (terminal).

    ``next_actions`` is expected to be a non-empty sequence of action indices
    whenever ``next_state`` is not None.

    Returns:
        None. It returns early, also with None, when memory holds fewer than
        BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return None

    transitions = memory.sample(BATCH_SIZE)
    # Transpose: list of Transitions -> Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) for the actions actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) for every transition; stays 0 for terminal ones.
    next_state_values = torch.zeros(len(transitions), device=device)
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    # An all-terminal minibatch has nothing to bootstrap (torch.cat([]) would fail).
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        # Legal-move lists of the non-final transitions, row-aligned with the states.
        legal_lists = [list(a) for s, a in zip(batch.next_state, batch.next_actions)
                       if s is not None]
        # Pad each row to the longest list by repeating its first legal move.
        # The duplicate cannot change the row max, and unlike a fixed width of 30
        # this also handles positions with more legal moves.
        width = max(len(a) for a in legal_lists)
        non_final_next_actions = torch.tensor(
            [a + [a[0]] * (width - len(a)) for a in legal_lists],
            device=device, dtype=torch.long)
        with torch.no_grad():
            target_q = target_net(non_final_next_states)
            # Max over legal moves only; negated because s_{t+1} is the opponent's turn.
            next_state_values[non_final_mask] = -target_q.gather(1, non_final_next_actions).max(1)[0]

    # Compute the expected Q values
    expected_state_action_values = next_state_values * GAMMA + reward_batch

    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    losses.append(loss.item())

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1]; parameters with no gradient are skipped.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

    
    