# Neural Network Models
Using the labelled Lichess data, this notebook attempts to find a neural network that fits it well.

## Pre-requisites

### If running on Google Colab
If you are not running on Google Colab, do not run the next two cells!

In [None]:
# Install the only dependency not available on Colab by default
!pip install chess

# Fetch the project's `chessbot` package: clone the lichess-neural-networks
# branch, rename the clone (so its folder name does not clash with the
# `chessbot` package being moved out of it), move neural_networks/chessbot
# into the current directory, then delete the rest of the clone.
!git clone -b lichess-neural-networks https://github.com/owenjaques/chessbot.git
!mv chessbot chessbot-repo
!mv chessbot-repo/neural_networks/chessbot .
!rm chessbot-repo -r

In [None]:
from google.colab import drive

drive.mount('/content/gdrive')
# No trailing slash: every later use appends '/<name>' (e.g. '/model',
# '/training_set.npz'), matching the local setup cell's './bin'.
working_directory = '/content/gdrive/MyDrive/chessbot_weights'
# On Colab the data sets live in the same Drive folder as the weights.
data_directory = working_directory

### If not running on Google Colab
Set `working_directory` to wherever you would like model checkpoints and logs saved, and `data_directory` to the folder containing the pre-processed data sets.

In [None]:
# Local (non-Colab) setup: checkpoints and TensorBoard logs are written under
# ./bin, and the .npz data sets are read from ../pre_processing/data.
!mkdir -p bin
working_directory = './bin'
data_directory = '../pre_processing/data'

## Load the Data

In [None]:
import chess
import chess.pgn
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
import math
from chessbot import modelinput

### Data generator for Keras

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Serve ``(X, y)`` batches from an ``.npz`` file to ``model.fit``/``evaluate``.

    The file must contain arrays ``X`` and ``y``. When ``pre_process`` is
    true, every entry of ``X`` is converted with
    ``input_generator.get_input_from_fen`` into one row of a float matrix of
    width ``input_generator.input_length()``; otherwise ``X`` is used exactly
    as stored.

    Args:
        data_file: Path to the ``.npz`` file holding ``X`` and ``y``.
        input_generator: Object providing ``input_length()`` and
            ``get_input_from_fen(fen)``.
        batch_size: Samples per batch; the final batch may be smaller.
        pre_process: Whether to convert ``X`` with ``input_generator``.
        save_file: If given (and ``pre_process`` is true), the processed
            ``X`` and ``y`` are written there with ``np.savez_compressed``
            (e.g. to reload later with ``pre_process=False``).

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """

    def __init__(self, data_file, input_generator, batch_size=32, pre_process=True, save_file=None):
        # Keras 3's PyDataset (aliased as Sequence) expects its initializer to run.
        super().__init__()
        self.batch_size = batch_size

        # NOTE: allow_pickle=True lets np.load unpickle object arrays; only
        # load trusted data files.
        with np.load(data_file, allow_pickle=True) as data:
            # Index the NpzFile once: every data[...] lookup re-reads and
            # decompresses the whole array from the archive.
            raw_X = data['X']
            self.y = data['y']
        self.n = len(self.y)

        if len(raw_X) != self.n:
            raise ValueError(f'X has {len(raw_X)} samples but y has {self.n}')

        if pre_process:
            self.X = np.empty((self.n, input_generator.input_length()))
            for i, fen in enumerate(raw_X):
                print(f'\rPre-processing input {i + 1}/{self.n}...', end='')
                self.X[i] = input_generator.get_input_from_fen(fen)
            # Terminate the '\r' progress line before any further output.
            print()

            if save_file is not None:
                print('Saving X, y...')
                np.savez_compressed(save_file, X=self.X, y=self.y)
        else:
            self.X = raw_X

    def __len__(self):
        """Return the number of batches per epoch (the last one may be partial)."""
        return math.ceil(self.n / self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as an ``(X, y)`` tuple of array slices."""
        low = idx * self.batch_size
        # Builtin min: np.min's second positional argument is an axis, not a value.
        high = min(low + self.batch_size, self.n)
        return self.X[low:high], self.y[low:high]

## Models

### Regression Model

In [None]:
# Data generators for training the model. Both splits are encoded with the
# 'positions' input representation, and each generator is asked (via
# save_file) to write its encoded arrays next to the source file under a
# 'positions_' prefix.
training_data = DataGenerator(
    f'{data_directory}/training_set.npz',
    modelinput.ModelInput('positions'),
    save_file=f'{data_directory}/positions_training_set.npz',
)
validation_data = DataGenerator(
    f'{data_directory}/validation_set.npz',
    modelinput.ModelInput('positions'),
    save_file=f'{data_directory}/positions_validation_set.npz',
)

In [None]:
# The actual model: a fully-connected regressor with three hidden ReLU layers
# of 512 units and a single linear output unit, trained with Adam on MSE and
# reporting mean absolute error.
model = keras.Sequential(
    [keras.layers.Dense(512, activation='relu') for _ in range(3)]
    + [keras.layers.Dense(1)]
)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='mse',
    metrics=[keras.metrics.MeanAbsoluteError()],
)

## Training a Model

In [None]:
# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=5,
    verbose=1)

# Multiply the learning rate by 0.2 after each epoch without val_loss improvement.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=1,
    min_lr=1e-14,
    verbose=1)

# Save the model to <working_directory>/model whenever val_loss hits a new best.
# NOTE(review): Keras 3 requires a '.keras' suffix for full-model checkpoints;
# confirm which TensorFlow/Keras version this runs on.
checkpoint = keras.callbacks.ModelCheckpoint(
    f'{working_directory}/model',
    monitor='val_loss',
    save_best_only=True)

# TensorBoard logs under <working_directory>/logs, with histograms every epoch.
tensorboard = keras.callbacks.TensorBoard(
    log_dir=f'{working_directory}/logs',
    write_graph=True,
    write_images=True,
    histogram_freq=1)

model.fit(
    training_data,
    epochs=128,
    validation_data=validation_data,
    shuffle=True,
    callbacks=[early_stopping, reduce_lr, checkpoint, tensorboard])

## Model Evaluation

### Optionally load a previous model

In [None]:
# Replace `model` with the one saved at <working_directory>/model, the path the
# training cell's ModelCheckpoint callback writes its best model to.
model = keras.models.load_model(working_directory + '/model')

### Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Class-based confusion matrix over the whole validation set: the predicted
# and true class of each sample is the argmax over axis 1, and the tick labels
# assume class indices 0/1/2 mean Losing/Drawing/Winning.
# NOTE(review): this expects a 3-output classifier and one-hot (n, 3) labels.
# The regression model defined above ends in Dense(1) with an MSE loss, so the
# argmax of its (n, 1) output is always 0, and a 1-D `validation_data.y` would
# make np.argmax(..., axis=1) raise. Confirm which model/data this cell is for;
# the next cell's binned confusion matrix is the regression counterpart.
y_pred = model.predict(validation_data.X)
y_pred = np.argmax(y_pred, axis=1)
y_true = np.argmax(validation_data.y, axis=1)
cm = confusion_matrix(y_true, y_pred)

ax = sns.heatmap(cm, annot=True, fmt='d', xticklabels=['Losing', 'Drawing', 'Winning'], yticklabels=['Losing', 'Drawing', 'Winning'])
ax.set(xlabel='Predicted label', ylabel='True label')

### Histograms and predictions

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Loss and metrics over the whole validation set.
evaluation = model.evaluate(validation_data)
# Predict on the full validation inputs. Indexing the generator
# (validation_data[0]) would only return its first (X, y) batch.
# Flatten the (n, 1) regression output and the labels to 1-D for
# plotting and for confusion_matrix.
predictions = model.predict(validation_data.X).ravel()
y_actual = np.ravel(validation_data.y)

_, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
ax[0].hist(predictions, bins=50)
ax[0].set_title('Predicted labels')
ax[0].set_xlabel('label (y)')
ax[0].set_ylabel('no. of occurrences in dataset')
ax[1].hist(y_actual, bins=50)
ax[1].set_title('Actual labels')
ax[1].set_xlabel('label (y)')
ax[1].set_ylabel('no. of occurrences in dataset')
plt.show()

# Bin predictions and labels on the same 10-edge grid over [0, 1], then
# compare the resulting bin indices in a confusion matrix.
bin_edges = np.linspace(0, 1, 10)
predictions_binned = np.digitize(predictions, bins=bin_edges)
y_binned = np.digitize(y_actual, bins=bin_edges)
cm = confusion_matrix(y_binned, predictions_binned)
ax = sns.heatmap(cm, annot=True, fmt='d')
ax.set(xlabel='Predicted label', ylabel='True label')

## Why not play a game after all that training?

In [None]:
import time
from IPython.display import clear_output, display
import chessbot.chessbot
import importlib
# Reload so edits to chessbot/chessbot.py are picked up without restarting the kernel.
importlib.reload(chessbot.chessbot)
from chessbot.chessbot import ChessBot


def _read_human_move(board):
    """Prompt until the user types a legal move for ``board`` in UCI notation."""
    while True:
        try:
            # parse_uci rejects malformed and illegal moves with a ValueError
            # subclass; Board.push itself does not check legality.
            return board.parse_uci(input('Your move (UCI, e.g. e7e5): ').strip())
        except ValueError as error:
            print(f'Invalid move: {error}')


def play_game(model, exploration_rate=0.0, should_visualise=False):
    """Play ChessBot (White) against a human who types Black's moves.

    Args:
        model: Model handed to ``ChessBot`` to choose White's moves.
        exploration_rate: Passed through to ``ChessBot``.
        should_visualise: If true, display the board initially and after
            every move, pausing one second between updates.

    Returns:
        The result string of ``board.outcome(claim_draw=True)``.
    """
    white = ChessBot(model, chess.WHITE, exploration_rate)
    board = chess.Board()

    if should_visualise:
        display(board)

    while not board.is_game_over(claim_draw=True):
        if board.turn == chess.BLACK:
            move = _read_human_move(board)
        else:
            move = white.move(board)
        board.push(move)

        if should_visualise:
            time.sleep(1)
            clear_output(wait=True)
            display(board)

    return board.outcome(claim_draw=True).result()


play_game(model, should_visualise=True)