In [1]:
import tensorflow as tf
import keras
import keras.backend as K
from keras.layers import Dense, Dropout, Input
from keras.models import Sequential

In [2]:
import os
from dotenv import load_dotenv

load_dotenv()
VIRTUAL_HEIGHT = int(os.environ["VIRTUAL_HEIGHT"])
VIRTUAL_WIDTH = int(os.environ["VIRTUAL_WIDTH"])
DIM = VIRTUAL_WIDTH * VIRTUAL_HEIGHT
BATCH_SIZE = 64
MOVE_CODE = {0:"right", 1:"left", 2:"up", 3:"down"}

## Data stuff

In [4]:
from glob import glob
from objects import Table, Food, Snake
import pickle as pkl
def load_dataset(folder_path=f"dataset_{VIRTUAL_HEIGHT}x{VIRTUAL_WIDTH}", maximum=None):
    paths = glob(folder_path + "/*")

    games = []
    for path in paths[:maximum if maximum else len(paths)]:
        with open(path, "rb") as fin:
            table = pkl.load(fin)
            games.append(table)
            
    return games

init_games = load_dataset(maximum=None)
print(f"Loaded {len(init_games)} games")

Loaded 500 games


In [11]:
print(init_games[99])

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [15]:
game = feeder.games[0]

In [16]:
np.sum(game.table[game.table==1])+1

2.0

In [18]:
len(feeder.games)

64

## Train feeder

In [8]:
import random
import numpy as np

class TrainFeeder:
    def __init__(self, init_games, n=BATCH_SIZE):
#         assert len(init_games) >= n
        self.init_games = init_games
        self.games = random.choices(self.init_games, k=n)
    
    @property
    def feed(self):
        while True:
            vectors = [game.table.ravel().reshape(1, -1) for game in self.games]
            for idx, vector in enumerate(vectors):
                yield vector, self.games[idx].move_quality().reshape(1, -1), np.array(idx)
    
    def apply_moves(self, directions: list):
        results = []
        for idx, direction in enumerate(directions):
            result = self.games[idx].snake.move(direction, return_state=True)
            results.append(result)
            
            if result == -1:
                # Reset the game
                self.games[idx] = random.choice(self.init_games)
            
        return np.array(results)
    
    def apply_move(self, direction, idx):
        result = self.games[idx].snake.move(direction, return_state=True)
        
        if result == -1:
            # Reset the game
            self.games[idx] = random.choice(self.init_games)
            
        return result
    
    def length_mean(self):
        # Sum the body + head of the snake. Body is annotated with 1 and head with 3
        length = [np.sum(game.table[game.table==1])+1 for game in self.games]
        return np.mean(length)
    
feeder = TrainFeeder(init_games=init_games)

## Model

In [5]:
score_metric = keras.metrics.Mean(name="score")
cce_loss = keras.losses.CategoricalCrossentropy()

class SnakeModel(keras.Sequential):
    def train_step(self, data):
        x, y_true, idx = data
        
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True) # Forward pass
            
            # Compute loss
            loss = cce_loss(y_true, y_pred)
        
        # Move 
        global feeder
        predicted_direction = MOVE_CODE[y_pred.numpy().argmax()]
        score = feeder.apply_move(predicted_direction, int(idx))
        
        # Compute gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        # Compute the metric score here
        score_metric.update_state(loss)
        
        return {"mean_score": score_metric.result(), "mean_length": feeder.length_mean()}
    
    @property
    def metrics(self):
        return [score_metric]

model = SnakeModel()
model.add(Input(shape=(DIM, )))
model.add(Dense(DIM, activation=None))
model.add(Dense(16, activation=None))
model.add(Dense(4, activation="softmax"))

In [6]:
model.compile(optimizer="adam", run_eagerly=True)
model.summary()

Model: "snake_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 1024)              1049600   
_________________________________________________________________
dense_1 (Dense)              (None, 16)                16400     
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 68        
Total params: 1,066,068
Trainable params: 1,066,068
Non-trainable params: 0
_________________________________________________________________


In [7]:
import logging
logging.basicConfig(format="%(asctime)s || %(message)",
                    level=logging.DEBUG)

In [8]:
feeder.feed.__next__()[0].shape

(1, 1024)

In [9]:
from keras.callbacks import ModelCheckpoint

history = model.fit(x=feeder.feed,
    epochs=3,
    verbose=1,
    callbacks=[ModelCheckpoint("models/epoch_{epoch}.h5", verbose=False)],
    steps_per_epoch=100,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


In [10]:
with open("models/history.pkl", "wb") as fout:
    pkl.dump(history.history, fout)

FileNotFoundError: [Errno 2] No such file or directory: 'models/history.pkl'