In [1]:
from collections import namedtuple
import glob
import logging
import os
import gzip

from keras.callbacks import (
    EarlyStopping,
    TerminateOnNaN,
    ModelCheckpoint,
)
from keras.layers import (
    Dense,
    Dropout,
    LSTM,
)
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential
from keras.utils import Sequence
import numpy as np
import varint

import sample_pb2

Using TensorFlow backend.


In [2]:
# Grab the root logger and raise its threshold to INFO.  The helper functions
# in this notebook log through logging.debug(), so their messages stay hidden
# unless this level is lowered to logging.DEBUG.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

In [3]:
# One decoded training record: the parsed InfoSet message plus the two
# integers stored after it in the sample file.
Sample = namedtuple(
    "Sample",
    ["infoset", "num_cards_in_draw_pile", "kitten_position"],
)


def yield_samples(filename: str):
    """Lazily decode the samples stored in a gzip-compressed file.

    Each record is read as: a varint byte count, that many bytes parsed as a
    ``sample_pb2.InfoSet``, then two more varints (the number of cards in the
    draw pile and the kitten position).

    Iteration stops when the leading varint of the next record cannot be
    decoded, i.e. ``varint.decode_stream`` raises ``TypeError`` or
    ``EOFError``; that is treated as the end of the stream.

    Args:
        filename: path of the gzip file to read.

    Yields:
        One ``Sample`` per record, in file order.
    """
    logging.debug("Loading samples from %s", filename)
    with gzip.open(filename) as stream:
        while True:
            try:
                payload_size = varint.decode_stream(stream)
            except (TypeError, EOFError):
                # No further record header could be read: finish the generator.
                return

            payload = stream.read(payload_size)
            message = sample_pb2.InfoSet()
            message.ParseFromString(payload)

            # The two trailing integers follow the message in this fixed order.
            draw_pile_size = varint.decode_stream(stream)
            kitten_index = varint.decode_stream(stream)

            yield Sample(message, draw_pile_size, kitten_index)

In [21]:
class TrainingSequence(Sequence):
    """Keras ``Sequence`` that serves one encoded batch per sample file.

    Every file matching ``batch_*.pb.gz`` in ``directory`` is one item of the
    sequence; files are ordered by sorted path.

    Args:
        directory: directory that is searched for ``batch_*.pb.gz`` files.
        batch_size: maximum number of samples taken from each file.  Samples
            beyond this count are ignored; a file holding fewer samples
            produces a correspondingly smaller batch.
    """

    def __init__(self, directory: str, batch_size: int = 2048):
        self.batches = sorted(glob.glob(os.path.join(directory, "batch_*.pb.gz")))
        self.batch_size = batch_size
        logging.debug("Found %d batches in %s", len(self.batches), directory)

    def __len__(self):
        """Return the number of batch files found."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Load, truncate and encode the ``idx``-th batch file.

        Returns:
            ``(X, y)`` where ``X`` is ``encode_history(samples)`` and ``y`` is
            ``encode_targets(samples)``.

        Raises:
            IndexError: if ``idx`` is outside the list of batch files.
        """
        logging.debug("Loading batch %d", idx)
        batch = self.batches[idx]
        it = yield_samples(batch)
        try:
            # zip() advances range() before the sample generator, so it stops
            # after batch_size samples without reading an extra one, and it
            # simply ends early on a short file.  Calling next() in a list
            # comprehension instead leaks StopIteration to the caller when the
            # file holds fewer samples than requested.
            samples = [sample for _, sample in zip(range(self.batch_size), it)]
        finally:
            # Closing the generator unwinds its ``with gzip.open(...)`` block
            # even when the file was not read to the end.
            it.close()
        logging.debug("Encoding %d samples", len(samples))
        X = encode_history(samples)
        y = encode_targets(samples)
        return X, y

In [9]:
MAX_HISTORY = 48
# player (2) + action type (4) + card (10) + draw-pile position (13)
# + three slots of seen cards (3 x 10)
N_FEATURES = 2 + 4 + 10 + 13 + 3 * 10  # == 59
MAX_CARDS_IN_DRAW_PILE = 13


def encode_history(samples):
    '''
    Encode the action history of every sample as a float32 array of shape
    (len(samples), MAX_HISTORY, N_FEATURES).

    Row i of a sample describes its i-th history action; rows past the end
    of the history stay all-zero. Within a row the columns are:
    -  0..1   one hot encoded player (2)
    -  2..5   one hot encoded action type (4)
    -  6..15  one hot encoded card (10)
    - 16..28  one hot encoded position in draw pile (13)
    - 29..58  concatenated one hot cards seen (3x10)
    '''
    encoded = np.zeros((len(samples), MAX_HISTORY, N_FEATURES), dtype=np.float32)
    for row, sample in enumerate(samples):
        for step, action in enumerate(sample.infoset.history):
            # Basic indexing returns a view, so writes land in `encoded`.
            features = encoded[row, step]
            features[action.player] = 1
            features[2 + action.type] = 1
            features[6 + action.card] = 1
            features[16 + action.position_in_draw_pile] = 1
            for slot, card in enumerate(action.cards_seen):
                features[29 + 10 * slot + card] = 1
    return encoded

def encode_targets(samples):
    '''
    One hot encode each sample's kitten_position into a float32 array of
    shape (len(samples), MAX_CARDS_IN_DRAW_PILE).
    '''
    targets = np.zeros((len(samples), MAX_CARDS_IN_DRAW_PILE), dtype=np.float32)
    positions = (sample.kitten_position for sample in samples)
    for row, position in enumerate(positions):
        targets[row, position] = 1
    return targets

In [22]:
def build_model():
    """Create and compile the classifier.

    The network is a bidirectional LSTM (100 units per direction, last output
    only) over inputs of shape ``(MAX_HISTORY, N_FEATURES)``, followed by a
    softmax layer with ``MAX_CARDS_IN_DRAW_PILE`` outputs.  It is compiled
    with categorical cross-entropy loss, the Adam optimizer and an accuracy
    metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    recurrent = Bidirectional(LSTM(100, return_sequences=False),
                              input_shape=(MAX_HISTORY, N_FEATURES))
    classifier = Dense(MAX_CARDS_IN_DRAW_PILE, activation='softmax')

    model = Sequential()
    model.add(recurrent)
    model.add(classifier)

    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'])
    return model


def train(model, data, val_data, output_filename, epochs=50, workers=4):
    """Fit ``model`` on ``data`` and validate on ``val_data``.

    Training stops early when ``val_loss`` has not improved by at least 0.001
    for 15 epochs (the best weights are then restored), and aborts if the loss
    becomes NaN.  The best model seen so far is saved to ``output_filename``.

    Args:
        model: object exposing Keras' ``fit_generator``.
        data: training batches (passed to ``fit_generator`` as the generator).
        val_data: validation batches used to compute the ``val_*`` metrics.
        output_filename: path the checkpoint callback writes to.
        epochs: maximum number of epochs to run.
        workers: number of worker processes that load batches.

    Returns:
        ``(model, history)`` where ``history`` is whatever ``fit_generator``
        returned.
    """
    callbacks = [
        EarlyStopping(
            monitor='val_loss', min_delta=0.001, patience=15,
            restore_best_weights=True),
        TerminateOnNaN(),
        ModelCheckpoint(output_filename, save_best_only=True),
    ]

    history = model.fit_generator(
        data,
        epochs=epochs,
        # Validate on the held-out set.  Passing the training data here made
        # val_loss a training metric, so early stopping and the best-only
        # checkpoint could not react to overfitting.
        validation_data=val_data,
        use_multiprocessing=True,
        workers=workers,
        max_queue_size=12,
        callbacks=callbacks,
    )

    return model, history

In [23]:
model = build_model()
# summary() prints the layer table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the cell output.
model.summary()

# Directory holding the train/val/test sample folders.  Set the
# ALPHACATS_DATA_DIR environment variable to run the notebook on another
# machine; without it the original location is used.
data_dir = os.environ.get(
    'ALPHACATS_DATA_DIR',
    '/Users/timpalpant/go/src/github.com/timpalpant/alphacats/model')
data = TrainingSequence(os.path.join(data_dir, 'train-samples'))
val_data = TrainingSequence(os.path.join(data_dir, 'val-samples'))
# Loaded here but not used by the training call below.
test_data = TrainingSequence(os.path.join(data_dir, 'test-samples'))
output_filename = 'model.hd5'
model, history = train(model, data, val_data, output_filename)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bidirectional_8 (Bidirection (None, 200)               128000    
_________________________________________________________________
dense_10 (Dense)             (None, 13)                2613      
Total params: 130,613
Trainable params: 130,613
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/50
 11/100 [==>...........................] - ETA: 3:01 - loss: 2.4625 - acc: 0.1371

Process ForkPoolWorker-116:
Process ForkPoolWorker-121:
Process ForkPoolWorker-122:
Process ForkPoolWorker-117:
Process ForkPoolWorker-120:
Process ForkPoolWorker-118:
Process ForkPoolWorker-115:
  File "/usr/local/Cellar/python/3.7.2_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
KeyboardInterrupt
KeyboardInterrupt
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.2_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.2_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  Fil

KeyboardInterrupt


KeyboardInterrupt: 