### Reference
- [Smart Geese Trained by Reinforcement Learning](https://www.kaggle.com/yuricat/smart-geese-trained-by-reinforcement-learning)
- [Alpha Zero General](https://github.com/suragnair/alpha-zero-general)

In [None]:
%%writefile submission.py
import pickle
import bz2
import base64
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import functools, collections
from copy import deepcopy
from kaggle_environments.envs.hungry_geese.hungry_geese import Action, translate
from kaggle_environments.helpers import histogram

sigmoid = lambda x: 1 / (1 + np.exp(-x))
eps = 10**(-6)
DEBUG = True

In [None]:
%%writefile -a submission.py
# The modelâ€™s parameters from https://www.kaggle.com/yuricat/smart-geese-trained-by-reinforcement-learning
PARAM = b'XXXXX'
PARAM_SELF = b'YYYYY'

In [None]:
%%writefile -a submission.py
class MCTS():
    def __init__(self, game, nn_agent_self, nn_agent_pubhrl, eps=1e-8, cpuct_self=1.0, cpuct_other=1.0):
        self.game = game
        self.nn_agent_self = nn_agent_self
        self.nn_agent_pubhrl = nn_agent_pubhrl
        self.eps = eps
        self.cpuct_self = cpuct_self
        self.cpuct_other = cpuct_other
        
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)
        self.Pm = {}  # masked initial policy (returned by neural net times masking)

        self.Vs = {}  # stores game.getValidMoves for board s
        
        self.last_obs = None

    def getActionProb(self, obs, timelimit=1.0):
        extra_time = obs.remainingOverageTime
        obs_step = obs.step
        remaining_steps = 220 - obs.step
        print(obs)
        print(len(obs.geese[obs.index]), [len(goose) for goose in obs.geese])
        
        s = self.game.stringRepresentation(obs)
        i = obs.index

        start_time = time.time()
        while time.time() - start_time < timelimit + extra_time/(remaining_steps/4):
            self.search(obs, self.last_obs)

            counts = [
                self.Nsa[(s, i, a)] if (s, i, a) in self.Nsa else 0
                for a in range(self.game.getActionSize())
            ]
            prob = counts / (np.sum(counts)+eps)
            
            target_prob = max(self.Ps[s, i])
            if time.time() - start_time > timelimit and (extra_time < 10 or max(prob) >= target_prob):
                break
                        
        self.last_obs = obs
        
        a = np.argmax(prob)
        
        if DEBUG:
            print(s,i,a)
            print(len(self.Qsa), len(self.Nsa), len(self.Ns), len(self.Ps), len(self.Vs))
            print("self.Qsa", self.Qsa[s,i,a])
            print("self.Nsa", self.Nsa[s,i,a])
            print("self.Ns",  self.Ns[s])
            print("self.Ps",  " ".join(f"{x:.4f}" for x in self.Ps[s,i]))
            print("self.Vs",  self.Vs[s,i])
            print("prob   ",  " ".join(f"{x:.4f}" for x in prob))
            print()
        
        return prob

    def search(self, obs, last_obs, prev_v=0):
        s = self.game.stringRepresentation(obs)
        
        if obs.step >= 200:
            lengths = sorted(len(goose) for goose in obs.geese)[::-1]
            position = lengths.index(len(obs.geese[obs.index]))
            scores = {0:1, 1:0.5, 2:-0.5, 3:-1}
            return [scores[position]]*4

        if s not in self.Ns:
            values = [-10] * 4
            for i in range(4):
                if len(obs.geese[i]) == 0:
                    continue
                    
                valids = self.game.getValidMoves(obs, last_obs, i)
                # leaf node
                if sum(v == 0 for v in valids) >= 3:
                    self.Ps[(s, i)], values[i] = valids, prev_v
                elif obs.step >= 192:  # random rollouts
                    self.Ps[(s, i)], values[i] = [0.25, 0.25, 0.25, 0.25], prev_v
                elif i == obs.index:
                    self.Ps[(s, i)], values[i] = self.nn_agent_self.predict(obs, last_obs, i)
                else:
                    self.Ps[(s, i)], values[i] = self.nn_agent_pubhrl.predict(obs, last_obs, i)                    
                
                self.Pm[s, i] = (valids + self.Ps[s, i]) * valids  # masking invalid moves
                sum_Ps_s = np.sum(self.Pm[s, i])
                if sum_Ps_s > 0:
                    self.Pm[(s, i)] /= sum_Ps_s  # renormalize

                self.Vs[(s, i)] = valids
                self.Ns[s] = 0
            return values

        best_acts = [None] * 4
        for i in range(4):
            if len(obs.geese[i]) == 0:
                continue
            
            valids = self.Vs[(s, i)]
            cur_best = -float('inf')
            best_act = self.game.actions[-1]

            # pick the action with the highest upper confidence bound
            for a in range(self.game.getActionSize()):
                if i == obs.index:
                    cpuct = self.cpuct_self
                else:
                    cpuct = self.cpuct_other
                if valids[a]:
                    if (s, i, a) in self.Qsa:
                        u = self.Qsa[(s, i, a)] + cpuct * self.Ps[(s, i)][a] * math.sqrt(
                                self.Ns[s]) / (1 + self.Nsa[(s, i, a)])
                    else:
                        u = cpuct * self.Ps[(s, i)][a] * math.sqrt(
                            self.Ns[s] + self.eps)  # Q = 0 ?

                    if u > cur_best:
                        cur_best = u
                        best_act = self.game.actions[a]
                        
            best_acts[i] = best_act
        
        next_obs = self.game.getNextState(obs, last_obs, best_acts)
        values = self.search(next_obs, obs)

        for i in range(4):
            if len(obs.geese[i]) == 0:
                continue
                
            a = self.game.actions.index(best_acts[i])
            v = values[i]
            if (s, i, a) in self.Qsa:
                self.Qsa[(s, i, a)] = (self.Nsa[(s, i, a)] * self.Qsa[
                    (s, i, a)] + v) / (self.Nsa[(s, i, a)] + 1)
                self.Nsa[(s, i, a)] += 1

            else:
                self.Qsa[(s, i, a)] = v
                self.Nsa[(s, i, a)] = 1 + sigmoid(v)  # to tie break when needed

        self.Ns[s] += 1
        return values

In [None]:
%%writefile -a submission.py
class HungryGeese(object):
    def __init__(self,
                 rows=7,
                 columns=11,
                 actions=[Action.NORTH, Action.SOUTH, Action.WEST, Action.EAST],
                 hunger_rate=40):
        self.rows = rows
        self.columns = columns
        self.actions = actions
        self.hunger_rate = hunger_rate

    def getActionSize(self):
        return len(self.actions)

    def getNextState(self, obs, last_obs, directions):
        next_obs = deepcopy(obs)
        next_obs.step += 1
        geese = next_obs.geese
        food = next_obs.food
                
        for i in range(4):
            goose = geese[i]
            
            if len(goose) == 0: 
                continue
            
            head = translate(goose[0], directions[i], self.columns, self.rows)
            
            # Check action direction
            if last_obs is not None and head == last_obs.geese[i][0]:
                geese[i] = []
                continue

            # Consume food or drop a tail piece.
            if head in food:
                food.remove(head)
            else:
                goose.pop()
            
            # Add New Head to the Goose.
            goose.insert(0, head)

            # If hunger strikes remove from the tail.
            if next_obs.step % self.hunger_rate == 0:
                if len(goose) > 0:
                    goose.pop()

        goose_positions = histogram(
            position
            for goose in geese
            for position in goose
        )

        # Check for collisions.
        for i in range(4):
            if len(geese[i]) > 0:
                head = geese[i][0]
                if goose_positions[head] > 1:
                    geese[i] = []
        
        return next_obs

    def getValidMoves(self, obs, last_obs, index):        
        foods = obs.food        
        geese = deepcopy(obs.geese)        
        pos = geese[index][0]
        
        maxlen_goose = max(len(goose) for goose in geese)
        num_goose = sum(len(goose) > 0 for goose in geese)
        
        potential_tail_strike = collections.defaultdict(lambda: 1)
        potential_head_collision = collections.defaultdict(lambda: 1)
        for goose_idx, goose in enumerate(geese):
            if goose_idx == index or not goose:
                continue
            for action in self.actions:
                nex_loc = translate(goose[0], action, self.columns, self.rows)
                head_collision_factor = 1
                if len(geese[index]) < len(goose):
                    potential_head_collision[nex_loc] = 0.111  # avoid because of definite loss
                elif num_goose == 2 and len(geese[index]) >= maxlen_goose:
                    potential_head_collision[nex_loc] = 3.333  # secure first place
                else:
                    potential_head_collision[nex_loc] = 0.888  # would prefer higher placing
                if nex_loc in foods:
                    potential_tail_strike[goose[-1]] = 0.101
        
        next_poss = [translate(pos, action, self.columns, self.rows) for action in self.actions]
        
        mask_head_collision = np.array([potential_head_collision[next_pos] for next_pos in next_poss])
        mask_tail_strike    = np.array([potential_tail_strike[next_pos] for next_pos in next_poss])

        obstacles = {position for goose in geese for position in goose[:-1]}
        if last_obs:
            obstacles.add(last_obs.geese[index][0])            
    
        mask_valid = np.array([1.0 if next_pos not in obstacles else 0 
                               for next_pos in next_poss])
    
        return mask_valid * mask_tail_strike * mask_head_collision

    def stringRepresentation(self, obs):      
        return str(obs.geese + obs.food)

In [None]:
%%writefile -a submission.py
# Neural Network for Hungry Geese
class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = torch.cat([x[:,:,:,-self.edge_size[1]:], x, x[:,:,:,:self.edge_size[1]]], dim=3)
        h = torch.cat([h[:,:,-self.edge_size[0]:], h, h[:,:,:self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.bn(h) if self.bn is not None else h
        return h

class GeeseNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = TorusConv2d(17, filters, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 4, bias=False)
        self.head_v = nn.Linear(filters * 2, 1, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        h_avg = h.view(h.size(0), h.size(1), -1).mean(-1)
        p = torch.softmax(self.head_p(h_head), 1)
        v = torch.tanh(self.head_v(torch.cat([h_head, h_avg], 1)))

        return p, v

class NNAgent():
    def __init__(self, state_dict):
        self.model = GeeseNet()
        self.model.load_state_dict(state_dict)
        self.model.eval()
        
    def predict(self, obs, last_obs, index):
        x = self._make_input(obs, last_obs, index)
        with torch.no_grad():
            xt = torch.from_numpy(x).unsqueeze(0)
            p, v = self.model(xt)
            
        return p.squeeze(0).detach().numpy(), v.item()
        
    # Input for Neural Network
    def _make_input(self, obs, last_obs, index):
        b = np.zeros((17, 7 * 11), dtype=np.float32)
        
        for p, pos_list in enumerate(obs.geese):
            # head position
            for pos in pos_list[:1]:
                b[0 + (p - index) % 4, pos] = 1
            # tip position
            for pos in pos_list[-1:]:
                b[4 + (p - index) % 4, pos] = 1
            # whole position
            for pos in pos_list:
                b[8 + (p - index) % 4, pos] = 1

        # previous head position
        if last_obs is not None:
            for p, pos_list in enumerate(last_obs.geese):
                for pos in pos_list[:1]:
                    b[12 + (p - index) % 4, pos] = 1

        # food
        for pos in obs.food:
            b[16, pos] = 1

        return b.reshape(-1, 7, 11)

In [None]:
%%writefile -a submission.py
game = HungryGeese()

state_dict_self = pickle.loads(bz2.decompress(base64.b64decode(PARAM)))
agent_self = NNAgent(state_dict_self)

state_dict_pubhrl = pickle.loads(bz2.decompress(base64.b64decode(PARAM)))
agent_pubhrl = NNAgent(state_dict_pubhrl)

mcts = MCTS(game, agent_self, agent_pubhrl)

def alphageese_agent(obs, config):
    action = game.actions[np.argmax(
        mcts.getActionProb(obs, timelimit=config.actTimeout))]
    return action.name

In [None]:
%%writefile -a submission.py

# class Struct(object):
#     # convert dictionary into object to allow instance.attribute notation
#     def __init__(self, data):
#         for name, value in data.items():
#             setattr(self, name, self._wrap(value))

#     def _wrap(self, value):
#         if isinstance(value, (tuple, list, set, frozenset)): 
#             return type(value)([self._wrap(v) for v in value])
#         else:
#             return Struct(value) if isinstance(value, dict) else value

# ## test code
# config = {'episodeSteps': 200, 'actTimeout': 1, 'runTimeout': 1200, 
#           'columns': 11, 'rows': 7, 'hunger_rate': 40, 'min_food': 2, 'max_length': 99}

# # [????] better to get stuck because game is ending
# obs = {'remainingOverageTime': 60, 'index': 1, 'step': 197, 'geese': [[], 
#     [36,35,24,25,14,3,4,15,16,27,38,39,40,29,28,17,18,7,6,5], 
#     [56,45,46,57,68,2,13,12,23,34,33,43,42,31,20,21,10,9,75,64,65], 
#     [30,41,52,63,62,51,50,49,48,59,60,61,72,73,74,8,19]], 'food': [26, 69]}  

# alphageese_agent(Struct(obs), Struct(config))

# # [0100] https://www.kaggle.com/c/hungry-geese/submissions?dialog=episodes-episode-24354313
# obs = {'remainingOverageTime': 36.25855599999999, 'index': 3, 'step': 195, 'geese': [
#     [76, 75, 74, 73, 72, 6, 7, 8, 9, 20, 21, 10], 
#     [3, 2, 13, 14, 25, 24, 23, 22, 11, 0, 66, 67, 68, 69, 70, 71, 5], 
#     [65, 64, 63, 52, 53, 42, 31, 32, 43, 54, 44, 45, 46, 57, 56], 
#     [36, 37, 38, 27, 26, 15, 16, 17, 28, 29, 18, 19, 30, 41, 40, 51, 62, 61, 50, 49, 48, 47]], 'food': [34, 39]}

# alphageese_agent(Struct(obs), Struct(config))

# # [0001] https://www.kaggle.com/c/hungry-geese/submissions?dialog=episodes-episode-24354751
# obs = {'remainingOverageTime': 8.744749000000029, 'index': 3, 'step': 159, 'geese': [
#     [28, 17, 18, 7, 6, 5, 71, 70, 59, 48, 49, 38, 37], 
#     [57, 46, 47, 58, 69, 3, 14, 25, 36, 35, 24, 23, 12, 13, 2], 
#     [29, 30, 19, 20, 9, 8, 74, 73, 62, 61, 50, 51, 52, 63, 64, 53, 42, 41], 
#     [21, 32, 22, 33, 44, 45, 56, 67, 66, 76, 10, 0]], 'food': [60, 54]}

# alphageese_agent(Struct(obs), Struct(config))

In [None]:
url = "https://tonghuikang.github.io/analysis_HandyRL/strings/pubhrl.txt"
url_self = "https://tonghuikang.github.io/analysis_HandyRL/strings/pubhrl-trained-on-boiler-adverse.txt"

import urllib
params = next(urllib.request.urlopen(url)).decode("utf-8")
params_self = next(urllib.request.urlopen(url_self)).decode("utf-8")

with open("submission.py", "r") as f:
    s = f.read()
s = s.replace("YYYYY", params)
s = s.replace("XXXXX", params_self)
with open("submission.py", "w") as f:
    f.write(s)

In [None]:
# check if the code runs
!python3 submission.py

In [None]:
from kaggle_environments import make
env = make("hungry_geese", debug=True)
env.reset()
env.run(['submission.py', "greedy", "greedy", "greedy"])
env.render(mode="ipython", width=700, height=600)