In [1]:
import tensorflow as tf
import keras
import keras.backend as K
from keras.layers import Dense, Dropout, Input
from keras.models import Sequential

In [2]:
import os
# Board dimensions come from the environment; an unset variable raises KeyError
# and a non-integer value raises ValueError.
VIRTUAL_HEIGHT, VIRTUAL_WIDTH = (
    int(os.environ[var_name]) for var_name in ("VIRTUAL_HEIGHT", "VIRTUAL_WIDTH")
)
# Cell count of the board, used as the length of the flat model input vector.
DIM = VIRTUAL_HEIGHT * VIRTUAL_WIDTH

In [3]:
# Module-level objects shared by SnakeModel.train_step.
cce_loss = keras.losses.CategoricalCrossentropy()  # loss against one-hot move targets
score_metric = keras.metrics.Mean(name="score")  # running mean reported by train_step

class SnakeModel(keras.Sequential):
    """Sequential policy network whose training step also plays one game move.

    Each call to ``train_step``:

    - runs a forward pass;
    - plays the arg-max move on the game at index ``idx`` through the
      module-level ``feeder`` object;
    - takes one gradient step on the categorical cross-entropy loss;
    - updates the module-level ``score_metric`` with the score returned by
      that move.

    NOTE(review): ``y_pred.numpy()`` and ``int(idx)`` only work when the step
    runs eagerly (e.g. ``compile(..., run_eagerly=True)``). The compile call is
    not visible here, so confirm it.
    """

    def train_step(self, data):
        """Run one optimisation step on one example and apply its move.

        Args:
            data: ``(x, y_true, idx)`` triple as delivered by ``fit``. ``idx``
                presumably identifies the game inside ``feeder`` (TODO confirm).

        Returns:
            dict with ``"mean_score"``, the running mean of the values returned
            by ``feeder.apply_move`` since the metric was last reset.
        """
        x, y_true, idx = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            loss = cce_loss(y_true, y_pred)

        # Play the most probable move. ``argmax`` runs over the flattened
        # predictions, so this assumes a batch size of 1 (as does ``int(idx)``).
        # TODO confirm batch_size=1 in fit().
        # ``feeder`` is only read here, so no ``global`` declaration is needed.
        predicted_direction = MOVE_CODE[y_pred.numpy().argmax()]
        score = feeder.apply_move(predicted_direction, int(idx))

        # Compute gradients and update weights.
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Track the game score rather than the loss. The metric is named
        # "score" and reported as "mean_score"; previously `score` went unused.
        score_metric.update_state(score)

        return {"mean_score": score_metric.result()}

    @property
    def metrics(self):
        """Metrics listed here so Keras can reset ``score_metric`` between epochs."""
        return [score_metric]

# Policy network: flat board vector (DIM) -> DIM -> 16 -> 4 move probabilities.
# NOTE(review): both hidden Dense layers are linear (activation=None), so the
# stack reduces to a single linear map before the softmax. Adding a
# nonlinearity would change what the weights loaded later actually compute.
model = SnakeModel([
    Input(shape=(DIM, )),
    Dense(DIM, activation=None),
    Dense(16, activation=None),
    Dense(4, activation="softmax"),
])

In [4]:
# Restore saved weights; the path is relative to the notebook's working directory
# and the saved file must match the layer stack defined above.
model.load_weights(filepath="models/epoch_1.h5")

In [5]:
from glob import glob
from objects import Table, Food, Snake
import pickle as pkl
def load_dataset(folder_path="dataset", maximum=None):
    """Unpickle every entry in ``folder_path`` and return the objects as a list.

    Args:
        folder_path: directory (str or path-like) whose entries are pickle
            files. They are presumably ``Table`` objects from the ``objects``
            module (TODO confirm).
        maximum: upper bound on the number of files read. ``None`` reads all
            of them and ``0`` reads none.

    Returns:
        list of unpickled objects, in sorted file-path order, so the same
        subset (and the same first game) is loaded on every machine.

    Security: ``pickle.load`` can execute arbitrary code embedded in a file.
    Only point this at trusted, locally generated data.
    """
    # glob() returns paths in filesystem-dependent order; sort them so that
    # `maximum` picks a reproducible subset.
    paths = sorted(glob(os.path.join(folder_path, "*")))

    games = []
    # Slicing with None keeps every path, so no special case is needed; this
    # also makes maximum=0 mean "load nothing" instead of "load everything".
    for path in paths[:maximum]:
        with open(path, "rb") as fin:
            games.append(pkl.load(fin))

    return games

# Read every pickled game under ./dataset; maximum=None places no cap on the count.
init_games = load_dataset(folder_path="dataset", maximum=None)
print(f"Loaded {len(init_games)} games")

Loaded 500 games


In [6]:
# Sanity check: move probabilities for the first loaded board, flattened into a
# single-row batch. reshape(1, -1) already flattens, so ravel() is unnecessary.
first_board = init_games[0].table
model.predict(first_board.reshape(1, -1))

array([[0.42262495, 0.2956784 , 0.09109779, 0.19059892]], dtype=float32)