In [2]:
import glob
import logging
import os

from keras.callbacks import (
    EarlyStopping,
    TerminateOnNaN,
    ModelCheckpoint,
)
from keras.layers import (
    Dense,
    Dropout,
    CuDNNLSTM,
)
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential
from keras.utils import Sequence
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import numpy as np

%matplotlib inline

Using TensorFlow backend.


In [3]:
# Root-logger threshold for the notebook. TrainingSequence emits its messages
# with logging.debug(), which is below INFO, so those stay hidden unless this
# is lowered.
logger = logging.getLogger()
logger.setLevel('INFO')

In [4]:
# Shapes of the serialized batches: TrainingSequence reshapes X to
# (n_samples, MAX_HISTORY, N_FEATURES) and y to (n_samples, MAX_CARDS_IN_DRAW_PILE).
MAX_HISTORY: int = 48  # length of axis 1 of X (sequence steps fed to the LSTM)
N_FEATURES: int = 59  # length of axis 2 of X (features per step)
MAX_CARDS_IN_DRAW_PILE: int = 13  # width of each target row in y

In [5]:
class TrainingSequence(Sequence):
    """Keras ``Sequence`` that serves one pre-built batch per ``batch_*.npz`` file.

    Each file must contain a flat (1-D) ``X`` array, which is reshaped to
    ``(n_samples, MAX_HISTORY, N_FEATURES)``, and a ``y`` array, which is
    reshaped to ``(n_samples, MAX_CARDS_IN_DRAW_PILE)``.
    """

    def __init__(self, directory: str):
        # Sorted so the index -> file mapping is deterministic (and therefore
        # identical across worker processes and between runs).
        self.batches = sorted(glob.glob(os.path.join(directory, "batch_*.npz")))
        logging.debug("Found %d batches in %s", len(self.batches), directory)

    def __len__(self):
        """Number of batches (= number of matching ``.npz`` files)."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Load batch ``idx`` and return ``(X, y)`` reshaped for the model."""
        logging.debug("Loading batch %d", idx)
        # np.load on an .npz returns an NpzFile that holds the file open until
        # closed; the context manager releases it so workers don't leak fds.
        with np.load(self.batches[idx]) as batch:
            # Every key lookup re-reads the array from the archive, so read once.
            X_flat = batch["X"]
            y_flat = batch["y"]
        n_samples = len(X_flat) // MAX_HISTORY // N_FEATURES
        X = X_flat.reshape((n_samples, MAX_HISTORY, N_FEATURES))
        y = y_flat.reshape((n_samples, MAX_CARDS_IN_DRAW_PILE))
        return X, y

In [6]:
def build_model(input_shape: tuple, output_shape: int):
    """Build and compile a bidirectional CuDNN-LSTM softmax classifier.

    Args:
        input_shape: Shape of a single input sample, (steps, features). The
            notebook passes ``X[0].shape``, i.e. ``(MAX_HISTORY, N_FEATURES)``.
        output_shape: Number of softmax output units.

    Returns:
        A ``Sequential`` model compiled with categorical cross-entropy, the
        Adam optimizer, and an accuracy metric.
    """
    # Only the final LSTM state (return_sequences=False) feeds the classifier.
    encoder = Bidirectional(
        CuDNNLSTM(128, return_sequences=False),
        input_shape=input_shape,
    )
    classifier = Dense(output_shape, activation='softmax')

    model = Sequential([encoder, classifier])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def train(model, data, val_data, epochs: int = 50):
    """Fit ``model`` on ``data``, validating and early-stopping on ``val_data``.

    Args:
        model: A compiled Keras model (e.g. from ``build_model``).
        data: Training ``Sequence``.
        val_data: Held-out validation ``Sequence``; supplies ``val_loss``.
        epochs: Maximum number of epochs; early stopping may end sooner.

    Returns:
        ``(model, history)``: the trained model and the ``History`` object
        returned by ``fit_generator``. EarlyStopping restores the
        best-``val_loss`` weights when it stops training early.
    """
    history = model.fit_generator(
        data,
        epochs=epochs,
        # Validate on the held-out set so val_loss (and early stopping)
        # reflect generalization rather than training-set fit.
        validation_data=val_data,
        use_multiprocessing=True,
        workers=8,
        max_queue_size=16,
        callbacks=[
            EarlyStopping(
                monitor='val_loss', min_delta=0.001, patience=5,
                restore_best_weights=True),
            TerminateOnNaN(),
        ],
    )

    return model, history

In [None]:
# Data location. Set ALPHACATS_DATA_DIR to point elsewhere instead of editing
# this machine-specific default.
data_dir = os.environ.get(
    'ALPHACATS_DATA_DIR',
    '/home/timpalpant/go/src/github.com/timpalpant/alphacats/model')
data = TrainingSequence(os.path.join(data_dir, 'train-samples'))
val_data = TrainingSequence(os.path.join(data_dir, 'val-samples'))
test_data = TrainingSequence(os.path.join(data_dir, 'test-samples'))

# Fail early with a clear message rather than an IndexError at data[0].
for split_name, split in (('train', data), ('val', val_data), ('test', test_data)):
    if len(split) == 0:
        raise FileNotFoundError(
            'No batch_*.npz files found for the %s split under %s'
            % (split_name, data_dir))

# Infer the model's input/output sizes from the first training batch.
X, y = data[0]
input_shape = X[0].shape
output_shape = y[0].shape[0]
model = build_model(input_shape, output_shape)
# summary() prints the table itself and returns None; don't wrap it in print().
model.summary()

model, history = train(model, data, val_data)
model.save('model2.hd5')

In [None]:
# Training vs. validation loss per epoch.
h = history.history
epochs = np.arange(len(h['loss'])) + 1

fig, ax = plt.subplots()
ax.plot(epochs, h['loss'], label='loss')
ax.plot(epochs, h['val_loss'], label='val_loss')
ax.set_xticks(epochs)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss')
# Trailing ';' suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# One row of predicted probabilities per test sample; later cells assume the
# rows follow test_data's batch order.
predicted = model.predict_generator(generator=test_data)

In [None]:
# Inspect the first test batch: decoded samples, ground-truth targets, and the
# matching slice of predictions.
# NOTE(review): yield_samples is not defined in any visible cell — presumably
# from an earlier or deleted cell; confirm it exists on a fresh kernel.
first_test_batch = test_data.batches[0]
samples = list(yield_samples(first_test_batch))
actual = test_data[0][1]
# Assumes the first len(actual) prediction rows belong to batch 0.
pred = predicted[:len(actual)]

In [None]:
def has_see_the_kitten(sample: Sample):
    """Return True if the sample's history contains a SEE_THE_FUTURE play that
    revealed an EXPLODING_KITTEN among the cards seen.

    NOTE(review): ``Sample`` and ``sample_pb2`` are not defined in the visible
    cells; presumably imported in an earlier cell — confirm.
    """
    return any(
        action.type == sample_pb2.Action.PLAY_CARD
        and action.card == sample_pb2.SEE_THE_FUTURE
        and any(seen == sample_pb2.EXPLODING_KITTEN for seen in action.cards_seen)
        for action in sample.infoset.history
    )

# Index of the LAST sample where See the Future revealed the kitten (-1 if none).
matching = [idx for idx, sample in enumerate(samples) if has_see_the_kitten(sample)]
selected = matching[-1] if matching else -1
print(selected)

samples[4000]

- **Fig 1:** Normal learning: equal probability across the valid positions, and no probability on invalid positions.
- **Fig 2:** See the Future revealed the exploding kitten; this knowledge resets after a shuffle.
- **Fig 3:** See the Future showed no kitten; this knowledge resets after a shuffle.
- **Fig 4:** The player inserts the exploding kitten, knows its exact position, and that position shifts after each draw.
- **Fig 5:** The other player inserts the kitten.

In [None]:
positions = np.arange(MAX_CARDS_IN_DRAW_PILE)
BAR_WIDTH = 0.5
# Test-sample indices (into pred/actual) shown side by side.
EXAMPLE_INDICES = (3090, 3096, 4000)


def plot_prediction(ax, predicted_probs, actual_probs, title=None):
    """Draw predicted vs. actual per-position probabilities as paired bars on ``ax``."""
    ax.bar(positions, predicted_probs, width=BAR_WIDTH, label='predicted')
    ax.bar(positions + BAR_WIDTH, actual_probs, width=BAR_WIDTH, label='actual')
    # Bars are center-aligned, so each predicted/actual pair spans
    # [p - w/2, p + 3w/2] and is centered at p + w/2.
    ax.set_xticks(positions + BAR_WIDTH / 2)
    ax.set_xticklabels(positions)
    ax.set_xlabel('Position')
    ax.set_ylabel('Probability')
    if title is not None:
        ax.set_title(title)
    ax.legend()
    return ax


fig, axes = plt.subplots(1, len(EXAMPLE_INDICES), figsize=(12, 4), squeeze=False)
for ax, idx in zip(axes[0], EXAMPLE_INDICES):
    plot_prediction(ax, pred[idx], actual[idx], title='Sample %d' % idx)

In [None]:
# Predicted probability row for a single test sample.
pred[4021, :]