I have started working on a stable_baselines3 self-play implementation for snake. There is still a lot of work to do before it is fully functional, but I figured I would open-source it so that others can build on it or give me tips if they see anything blatantly misconfigured, since I am relatively new to applying RL.

This notebook is a blend of https://github.com/hardmaru/slimevolleygym/blob/master/training_scripts/train_ppo_selfplay.py and https://www.kaggle.com/kwabenantim/stable-baselines-starter

In [None]:
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col
from kaggle_environments import evaluate, make

In [None]:
!pip install stable-baselines3

In [None]:
import matplotlib.pyplot as plt
import gym
from gym import spaces
import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common import logger
from stable_baselines3.common.callbacks import EvalCallback

from shutil import copyfile
import os

from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.monitor import Monitor

def transform_observation(obs, config):
    """Encode the first agent's observation as a 3-channel uint8 board.

    Args:
        obs: Sequence of per-agent states; only ``obs[0].observation`` is read
            (its ``geese`` and ``food`` fields hold flat cell indices).
        config: Object exposing ``rows`` and ``columns``.

    Returns:
        Array of shape ``(config.rows, config.columns, 3)`` and dtype uint8.
        Channel 0 marks cells of ``geese[0]``, channel 1 cells of ``geese[1]``
        and channel 2 the food cells; marked cells are 255, all others 0.
    """
    state = obs[0].observation
    board = np.zeros((config.rows, config.columns, 3), dtype=np.uint8)
    # Flat (cell, channel) view onto the same memory, so the row-major cell
    # indices used by the environment can be written directly.
    cells = board.reshape(-1, 3)
    occupied_per_channel = (state.geese[0], state.geese[1], state.food)
    for channel, occupied in enumerate(occupied_per_channel):
        for cell in occupied:
            cells[cell, channel] = 255
    return board

In [None]:
def transform_actions(actions):
    """Translate a discrete action code into a direction name.

    Codes 0, 1, 2 and 3 map to "NORTH", "EAST", "WEST" and "SOUTH"
    respectively. Any other value yields ``None``.
    """
    # Compared with ``==`` rather than looked up in a dict so that unhashable
    # scalar-like inputs (e.g. 0-d numpy arrays) are still accepted.
    code_to_direction = (
        (0, "NORTH"),
        (1, "EAST"),
        (2, "WEST"),
        (3, "SOUTH"),
    )
    for code, direction in code_to_direction:
        if actions == code:
            return direction
    return None

In [None]:
# Module-level "hungry_geese" environment created with kaggle_environments.make.
# NOTE(review): nothing in the visible code reads this variable (GeeseGym
# builds its own environment in __init__) — confirm whether it is still needed.
geese_env = make("hungry_geese")

In [None]:
# NOTE(review): these two constants are not referenced by the visible code.
REWARD_LOST = -1
REWARD_WON = 1


class GeeseGym(gym.Env):
    """Gym wrapper around a two-agent "hungry_geese" environment.

    The learning agent controls player 0. Observations are the 3-channel
    boards produced by ``transform_observation`` and actions are the four
    discrete codes understood by ``transform_actions``.
    """

    def __init__(self, debug=False):
        """Create the underlying environment and declare the gym spaces.

        Args:
            debug: Forwarded to ``make("hungry_geese", debug=...)``.
        """
        self.geese_env = make("hungry_geese", debug=debug)
        self.config = self.geese_env.configuration
        self.action_space = spaces.Discrete(4)

        board_shape = (self.config.rows, self.config.columns, 3)
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=board_shape,
            dtype=np.uint8,
        )

        self.reward_range = (-1, 1000)

    def reset(self):
        """Start a new two-agent episode and return the initial board."""
        self.obs = self.geese_env.reset(num_agents=2)
        return transform_observation(self.obs, self.config)

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: Discrete action code for player 0.

        Returns:
            Tuple ``(board, reward, done, info)`` where reward, status and
            info are all read from player 0's entry of the environment state.
        """
        # The opponent is not driven by any policy here: it always plays
        # action code 0, which transform_actions maps to "NORTH".
        joint_actions = [transform_actions(action), transform_actions(0)]
        self.obs = self.geese_env.step(joint_actions)

        board = transform_observation(self.obs, self.config)
        player = self.obs[0]
        reward = player.reward
        finished = player["status"] != "ACTIVE"
        return board, reward, finished, player["info"]

In [None]:
# Settings
# NOTE(review): SEED is not referenced anywhere in the visible code.
SEED = 17
# Total number of environment steps requested from model.learn().
NUM_TIMESTEPS = int(1e7)
# EVAL_FREQ / EVAL_EPISODES only appear in the commented-out SelfPlayCallback
# setup in the visible code.
EVAL_FREQ = int(1e4)
EVAL_EPISODES = int(1e2)
BEST_THRESHOLD = 0.01 # must achieve a mean score above this to replace prev best self

# Directory used for logger output and for the saved model files
# (best_model.zip, history_*.zip, final_model).
LOGDIR = "ppo1_selfplay"

class GeeseSelfPlayEnv(GeeseGym):
    """GeeseGym variant that keeps the newest saved "history*" model loaded.

    The instance registers itself as ``self.policy`` and exposes ``predict``
    so it can act as the opponent policy.
    NOTE(review): nothing in this class passes ``predict`` to the game step;
    confirm the base class actually consults ``self.policy``.
    """

    def __init__(self):
        super().__init__()
        self.policy = self
        self.best_model = None
        self.best_model_filename = None

    def predict(self, obs):
        """Return an action for ``obs``.

        Samples uniformly from the action space until a saved model has been
        loaded; afterwards delegates to that model's ``predict`` and returns
        only the action part of its result.
        """
        if self.best_model is None:
            return self.action_space.sample()
        chosen, _ = self.best_model.predict(obs)
        return chosen

    def reset(self):
        """Reload the newest checkpoint if it changed, then reset the game."""
        # Checkpoint names sort lexicographically, so the last entry is the
        # most recent generation.
        checkpoints = sorted(
            name for name in os.listdir(LOGDIR) if name.startswith("history")
        )
        if checkpoints:
            newest = os.path.join(LOGDIR, checkpoints[-1])
            if newest != self.best_model_filename:
                print("loading model: ", newest)
                self.best_model_filename = newest
                # Drop the previous model before loading its replacement.
                if self.best_model is not None:
                    del self.best_model
                self.best_model = PPO.load(newest, env=self)
        return super().reset()

class SelfPlayCallback(EvalCallback):
    """EvalCallback that archives a new generation when the threshold is beaten.

    Whenever ``best_mean_reward`` rises above BEST_THRESHOLD, the file
    ``best_model.zip`` in LOGDIR is copied to ``history_<generation>.zip``
    (generation zero-padded to 8 digits) and ``best_mean_reward`` is reset to
    BEST_THRESHOLD so the next generation has to beat the threshold again.
    NOTE(review): assumes the parent callback was configured to write
    best_model.zip into LOGDIR (best_model_save_path=LOGDIR) — confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_mean_reward = BEST_THRESHOLD
        self.generation = 0

    def _on_step(self) -> bool:
        keep_training = super()._on_step()
        improved = self.best_mean_reward > BEST_THRESHOLD
        if not (keep_training and improved):
            return keep_training

        self.generation += 1
        print("SELFPLAY: mean_reward achieved:", self.best_mean_reward)
        print("SELFPLAY: new best model, bumping up generation to", self.generation)
        archive_name = "history_" + str(self.generation).zfill(8) + ".zip"
        copyfile(
            os.path.join(LOGDIR, "best_model.zip"),
            os.path.join(LOGDIR, archive_name),
        )
        # Re-arm the threshold for the next generation.
        self.best_mean_reward = BEST_THRESHOLD
        return keep_training

def rollout(env, policy):
    """Play one full episode and return the summed reward.

    Args:
        env: Object with ``reset()`` returning an observation and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        policy: Object whose ``predict(obs)`` returns an ``(action, state)``
            pair; only the action is used.

    Returns:
        Sum of the rewards collected until ``done`` is truthy.
    """
    observation = env.reset()
    episode_return = 0
    while True:
        action, _ = policy.predict(observation)
        observation, step_reward, finished, _info = env.step(action)
        episode_return += step_reward
        if finished:
            return episode_return

In [None]:
def make_env(rank=0):
    """Build a factory for one monitored self-play environment.

    Vectorised environment wrappers expect a list of zero-argument callables;
    this returns such a callable. (Previously the inner function was never
    returned, so ``make_env()`` evaluated to ``None``.)

    Args:
        rank: Index of the worker; used as the monitor log file name inside
            LOGDIR.

    Returns:
        A callable that creates a ``GeeseSelfPlayEnv`` wrapped in ``Monitor``.
    """
    def _init():
        # The monitor writes its log under LOGDIR, so make sure it exists.
        os.makedirs(LOGDIR, exist_ok=True)
        log_file = os.path.join(LOGDIR, str(rank))
        return Monitor(GeeseSelfPlayEnv(), log_file)

    return _init


In [None]:
# Point the stable_baselines3 logger at LOGDIR.
logger.configure(folder=LOGDIR)
# The self-play and vectorised environments are currently disabled; training
# below runs on a single plain GeeseGym instance.
# NOTE(review): SubprocVecEnv takes a list of env-constructing callables, not
# env instances — confirm before re-enabling the second line.
# env = GeeseSelfPlayEnv()
# env = SubprocVecEnv([GeeseGym() for i in range(4)])
# NOTE(review): learning_rate=.01 and n_epochs=50 are far above the PPO
# defaults documented by stable_baselines3 (3e-4 and 10) — confirm intended.
model = PPO('MlpPolicy', GeeseGym(), verbose = 1, n_steps = 2048*16, batch_size = 128, n_epochs = 50, learning_rate = .01)
# Evaluation / self-play checkpointing is disabled as well: no callback is
# passed to model.learn() below.
# eval_callback = SelfPlayCallback(env,
#     best_model_save_path=LOGDIR,
#     log_path=LOGDIR,
#     eval_freq=EVAL_FREQ,
#     n_eval_episodes=EVAL_EPISODES,
#     deterministic=False)
model.learn(total_timesteps=NUM_TIMESTEPS)
model.save(os.path.join(LOGDIR, "final_model")) # probably never get to this point.

In [None]:
def run_test(model):
    """Play a single debug-mode GeeseGym episode with ``model``.

    The reward returned by every step is printed. Nothing is returned.

    Args:
        model: Object whose ``predict(obs)`` result has the action at index 0.
    """
    test_env = GeeseGym(debug=True)
    observation = test_env.reset()
    episode_over = False
    while not episode_over:
        prediction = model.predict(observation)
        chosen_action = prediction[0]
        observation, step_reward, episode_over, _info = test_env.step(chosen_action)
        print(step_reward)
        # Optional visualisation of the board after each step:
        # plt.imshow(observation)
        # plt.show()


run_test(model)