Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add keras resnet for Connect4 * Update architecture * Fix error * Correctly calculate loss
- Loading branch information
Showing 3 changed files with 179 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import sys | ||
sys.path.append('..') | ||
from utils import * | ||
|
||
import argparse | ||
from tensorflow.keras.models import * | ||
from tensorflow.keras.layers import * | ||
from tensorflow.keras.optimizers import Adam | ||
from tensorflow.keras.activations import * | ||
|
||
def relu_bn(inputs):
    """Apply a ReLU activation to *inputs*, then batch-normalize the result.

    Returns the output of a freshly created ``BatchNormalization`` layer.
    """
    activated = relu(inputs)
    return BatchNormalization()(activated)
|
||
def residual_block(x, filters, kernel_size=3):
    """Build one residual block on top of *x* and return its output.

    The branch is conv -> ReLU -> batch norm -> conv -> batch norm; its
    output is added to *x* and passed through a final ReLU.

    x: input tensor; it is summed with the branch output, so it presumably
       must already have ``filters`` channels (verify against callers).
    filters: number of filters for both convolutions.
    kernel_size: kernel size for both convolutions.
    """
    # Both convolutions share the same configuration; stride 1 with "same"
    # padding is what the original spells out for each of them.
    conv_config = dict(kernel_size=kernel_size,
                       strides=1,
                       filters=filters,
                       padding="same")

    branch = Conv2D(**conv_config)(x)
    branch = relu_bn(branch)
    branch = Conv2D(**conv_config)(branch)
    branch = BatchNormalization()(branch)

    # Skip connection followed by the block's output activation.
    merged = Add()([x, branch])
    return relu(merged)
|
||
def value_head(input):
    """Build the value-head trunk on top of *input* and return its features.

    Layers, in order: 1x1 convolution with a single filter, batch norm, ReLU,
    flatten, Dense(256) with ReLU, Dense(256) without an activation.
    The caller is expected to attach the final output layer.
    """
    # Collapse the feature maps to a single channel before flattening.
    features = Conv2D(kernel_size=1,
                      strides=1,
                      filters=1,
                      padding="same")(input)
    features = BatchNormalization()(features)
    features = relu(features)

    flattened = Flatten()(features)

    hidden = Dense(256)(flattened)
    hidden = relu(hidden)

    # Second dense layer is returned as-is (no activation applied here).
    return Dense(256)(hidden)
|
||
def policy_head(input):
    """Build the policy-head trunk on top of *input* and return it flattened.

    Layers, in order: 2x2 convolution with a single filter, batch norm, ReLU,
    flatten. The caller is expected to attach the final output layer.
    """
    features = Conv2D(kernel_size=2,
                      strides=1,
                      filters=1,
                      padding="same")(input)
    features = BatchNormalization()(features)
    features = relu(features)
    return Flatten()(features)
|
||
class Connect4NNet():
    """Residual policy/value network for Connect4, built with Keras.

    After construction the instance exposes:
      input_boards -- the Keras ``Input`` for a batch of boards
      pi           -- softmax policy output (``action_size`` units)
      v            -- tanh value output (1 unit)
      loss_pi      -- Keras loss identifier used for ``pi``
      loss_v       -- Keras loss identifier used for ``v``
      model        -- the compiled ``Model`` with outputs ``[pi, v]``
    """

    def __init__(self, game, args):
        """Build and compile the model.

        game: object providing ``getBoardSize()`` -> (x, y) and
              ``getActionSize()``.
        args: object with attributes ``num_channels``,
              ``num_residual_layers`` and ``lr``.
        """
        # game params
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()
        self.args = args

        # Neural Net
        # Inputs: boards arrive as (board_x, board_y); a trailing channel
        # axis of size 1 is added so Conv2D can consume them.
        self.input_boards = Input(shape=(self.board_x, self.board_y))
        inputs = Reshape((self.board_x, self.board_y, 1))(self.input_boards)

        # Stem: batch norm -> 3x3 conv -> ReLU + batch norm.
        bn1 = BatchNormalization()(inputs)
        conv1 = Conv2D(args.num_channels, kernel_size=3, strides=1, padding="same")(bn1)
        t = relu_bn(conv1)

        # Residual tower.
        for _ in range(self.args.num_residual_layers):
            t = residual_block(t, filters=self.args.num_channels)

        # Output layers on top of the two heads.
        self.pi = Dense(self.action_size, activation='softmax', name='pi')(policy_head(t))
        self.v = Dense(1, activation='tanh', name='v')(value_head(t))

        self.calculate_loss()

        # The loss list is positional: it lines up with outputs [pi, v].
        self.model = Model(inputs=self.input_boards, outputs=[self.pi, self.v])
        self.model.compile(loss=[self.loss_pi, self.loss_v], optimizer=Adam(args.lr))

    def calculate_loss(self):
        """Select the per-output losses handed to ``Model.compile``.

        Sets ``self.loss_pi`` and ``self.loss_v`` to Keras loss identifiers.
        ``compile`` needs identifiers or callables rather than pre-built loss
        tensors, and ``pi`` is already a softmax distribution, so the
        probability-based categorical cross-entropy is the matching loss.
        Keras evaluates both against the targets supplied at fit time.
        """
        self.loss_pi = 'categorical_crossentropy'
        self.loss_v = 'mean_squared_error'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import argparse | ||
import os | ||
import shutil | ||
import time | ||
import random | ||
import numpy as np | ||
import math | ||
import sys | ||
sys.path.append('../..') | ||
from utils import * | ||
from NeuralNet import NeuralNet | ||
|
||
import logging | ||
import coloredlogs | ||
log = logging.getLogger(__name__) | ||
|
||
import argparse | ||
|
||
from .Connect4NNet import Connect4NNet as onnet | ||
|
||
# Hyperparameters for the Connect4 network. The whole mapping is handed to the
# network constructor by NNetWrapper; the wrapper itself reads only
# 'batch_size' and 'epochs'.
args = dotdict({
    'lr': 0.001,                # optimizer learning rate
    'dropout': 0.3,             # not read by the code in this file
    'epochs': 10,               # epochs per call to NNetWrapper.train
    'batch_size': 64,           # batch size per call to NNetWrapper.train
    'cuda': True,               # not read by the code in this file
    'num_channels': 128,        # part of the settings passed to the network
    'num_residual_layers': 20,  # part of the settings passed to the network
})
|
||
class NNetWrapper(NeuralNet):
    """Training, inference and checkpoint wrapper around the Connect4 Keras model."""

    def __init__(self, game):
        """Build the network for *game* and print its Keras summary."""
        self.nnet = onnet(game, args)
        self.nnet.model.summary()
        self.board_x, self.board_y = game.getBoardSize()
        self.action_size = game.getActionSize()

    def train(self, examples):
        """Fit the model on *examples*.

        examples: list of examples, each example is of form (board, pi, v)
        """
        # Transpose the list of triples into three parallel sequences.
        input_boards, target_pis, target_vs = list(zip(*examples))
        input_boards = np.asarray(input_boards)
        target_pis = np.asarray(target_pis)
        target_vs = np.asarray(target_vs)
        self.nnet.model.fit(x=input_boards,
                            y=[target_pis, target_vs],
                            batch_size=args.batch_size,
                            epochs=args.epochs)

    def predict(self, board):
        """Return ``(pi, v)`` for a single board.

        board: np array with board
        """
        # The model expects a batch, so add a leading axis of size 1.
        board = board[np.newaxis, :, :]

        pi, v = self.nnet.model.predict(board)

        # Strip the batch axis again.
        return pi[0], v[0]

    def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Save the model weights to ``folder/filename``, creating *folder* if needed."""
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".format(folder))
            # makedirs also creates missing parent directories; exist_ok
            # tolerates the folder appearing between the check and the call.
            os.makedirs(folder, exist_ok=True)
        else:
            print("Checkpoint Directory exists! ")
        self.nnet.model.save_weights(filepath)

    def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
        """Load model weights from ``folder/filename``."""
        filepath = os.path.join(folder, filename)
        # NOTE(review): there is deliberately no os.path.exists(filepath)
        # check; depending on the weight format Keras may store several files
        # under this prefix -- confirm before adding one.
        log.info('Loading Weights...')
        self.nnet.model.load_weights(filepath)
|
Empty file.