/
OthelloNNet.py
32 lines (27 loc) · 2.15 KB
/
OthelloNNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
sys.path.append('..')
from utils import *
import argparse
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
class OthelloNNet():
    """Keras policy/value network for an Othello-style board game.

    The network works in this order:
    - Reshape the raw board into a single-channel image.
    - Pass it through four conv -> batch-norm -> ReLU stages. The first two
      use 'same' padding; the last two use 'valid' padding.
    - Flatten, then apply two fully connected layers (1024 and 512 units),
      each followed by batch-norm, ReLU and dropout.
    - Branch into two heads: ``pi``, a softmax over the game's actions, and
      ``v``, a single tanh unit.

    The compiled Keras model is stored on ``self.model``.

    Args:
        game: object providing ``getBoardSize()`` (returning a
            ``(board_x, board_y)`` pair) and ``getActionSize()``.
        args: hyper-parameter namespace read for ``num_channels`` (filters
            per convolution), ``dropout`` (dropout rate after each fully
            connected layer) and ``lr`` (learning rate passed to ``Adam``).

    Raises:
        ValueError: if either board dimension is smaller than
            ``_MIN_BOARD_DIM``.
    """

    # Each of the two 3x3 'valid' convolutions trims 2 cells from every
    # spatial axis. At least 2 * 2 + 1 cells per axis are therefore needed
    # for the final feature map to be non-empty.
    _MIN_BOARD_DIM = 5

    def __init__(self, game, args):
        """Build and compile the policy/value model for ``game`` with ``args``."""
        # game params
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()
        self.args = args

        # Fail fast with a readable message instead of a low-level shape
        # error from the 'valid' convolutions below.
        if min(self.board_x, self.board_y) < self._MIN_BOARD_DIM:
            raise ValueError(
                "OthelloNNet requires both board dimensions to be >= {}, "
                "got {}x{}".format(self._MIN_BOARD_DIM, self.board_x, self.board_y)
            )

        # Neural Net
        self.input_boards = Input(shape=(self.board_x, self.board_y))  # s: batch_size x board_x x board_y
        x_image = Reshape((self.board_x, self.board_y, 1))(self.input_boards)  # batch_size x board_x x board_y x 1

        h_conv1 = self._conv_bn_relu(x_image, args.num_channels, 'same')  # batch_size x board_x x board_y x num_channels
        h_conv2 = self._conv_bn_relu(h_conv1, args.num_channels, 'same')  # batch_size x board_x x board_y x num_channels
        h_conv3 = self._conv_bn_relu(h_conv2, args.num_channels, 'valid')  # batch_size x (board_x-2) x (board_y-2) x num_channels
        h_conv4 = self._conv_bn_relu(h_conv3, args.num_channels, 'valid')  # batch_size x (board_x-4) x (board_y-4) x num_channels
        h_conv4_flat = Flatten()(h_conv4)

        s_fc1 = self._dense_bn_relu_dropout(h_conv4_flat, 1024, args.dropout)  # batch_size x 1024
        s_fc2 = self._dense_bn_relu_dropout(s_fc1, 512, args.dropout)  # batch_size x 512

        self.pi = Dense(self.action_size, activation='softmax', name='pi')(s_fc2)  # batch_size x self.action_size
        self.v = Dense(1, activation='tanh', name='v')(s_fc2)  # batch_size x 1

        self.model = Model(inputs=self.input_boards, outputs=[self.pi, self.v])
        # Losses are listed in output order: cross-entropy for pi, MSE for v.
        self.model.compile(loss=['categorical_crossentropy', 'mean_squared_error'], optimizer=Adam(args.lr))

    @staticmethod
    def _conv_bn_relu(x, filters, padding):
        """Apply a bias-free 3x3 Conv2D, BatchNormalization on axis 3, then ReLU."""
        # The construction order (Activation, BatchNormalization, Conv2D) is
        # deliberate. Keras numbers auto-generated layer names in creation
        # order, so keeping it stable keeps layer names stable.
        return Activation('relu')(BatchNormalization(axis=3)(
            Conv2D(filters, 3, padding=padding, use_bias=False)(x)))

    @staticmethod
    def _dense_bn_relu_dropout(x, units, rate):
        """Apply a bias-free Dense layer, BatchNormalization on axis 1, ReLU, then Dropout."""
        # Same deliberate outer-to-inner construction order as _conv_bn_relu,
        # so auto-generated layer names stay stable.
        return Dropout(rate)(Activation('relu')(BatchNormalization(axis=1)(
            Dense(units, use_bias=False)(x))))