# Write submission.py

## Imports

In [1]:
%%writefile submission.py
# MIT License

# Copyright (c) 2021 Choi Yeonung
# Copyright (c) 2020 DeNA Co., Ltd.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import base64
import pickle
import zlib
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaggle_environments import make
from kaggle_environments.envs.hungry_geese.hungry_geese import Action, adjacent_positions

Overwriting submission.py


## Global variables

In [2]:
%%writefile -a submission.py

# Module-level state shared between successive calls to agent().
prev_action = None  # name of the action returned by the previous agent() call
prev_obs = None  # observation received by the previous agent() call
env = None  # simulator instance, created lazily by get_env()
mcts = None  # MCTS searcher, created lazily on the first agent() call
# The right-hand side below is a placeholder token: the weight-export cell
# text-replaces it with a bytes literal (base64 of zlib-compressed pickle).
# Do not repeat the token anywhere else in this file, or it gets expanded too.
state_dict = _STATE_DICT_

Appending to submission.py


## Neural network

In [3]:
%%writefile -a submission.py

class TorusConv2d(nn.Module):
    """2-D convolution with wrap-around (toroidal) padding.

    The input is padded on each spatial side with values copied from the
    opposite edge, so the convolution preserves the spatial size and treats
    the board as a torus. An optional batch normalisation follows.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: ``(height, width)`` of the kernel; half of each
            dimension (integer division) is used as the wrap padding.
        bn: when truthy, apply ``nn.BatchNorm2d`` after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        """Apply the wrapped convolution to ``x`` of layout (N, C, H, W)."""
        pad_h, pad_w = self.edge_size
        h = x
        # A zero pad must be skipped explicitly: ``t[..., -0:]`` selects the
        # whole tensor, which would duplicate the input instead of padding
        # by nothing (this happens for a kernel dimension of 1).
        if pad_w:
            h = torch.cat([h[:, :, :, -pad_w:], h, h[:, :, :, :pad_w]], dim=3)
        if pad_h:
            h = torch.cat([h[:, :, -pad_h:], h, h[:, :, :pad_h]], dim=2)
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h


class GeeseNet(nn.Module):
    """Residual torus-convolution network with a policy and a value head.

    The trunk is one input convolution followed by 12 residual
    ``TorusConv2d`` layers of 32 filters. The policy head produces 4 logits
    and the value head produces one ``tanh``-squashed scalar per sample.
    """

    def __init__(self):
        super().__init__()
        in_planes = 17  # input is (17, 7, 11): planes, rows, columns
        depth, width = 12, 32
        self.conv0 = TorusConv2d(in_planes, width, (3, 3), True)
        residual_layers = [
            TorusConv2d(width, width, (3, 3), True) for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(residual_layers)
        self.head_p = nn.Linear(width, 4, bias=False)
        self.head_v = nn.Linear(width * 2, 1, bias=False)

    def forward(self, x, _=None):
        """Return ``(policy_logits, value)`` for a batch ``x``.

        The second positional argument is accepted and ignored.
        """
        feat = F.relu_(self.conv0(x))
        for layer in self.blocks:
            feat = F.relu_(feat + layer(feat))

        n, c = feat.size(0), feat.size(1)
        # Input plane 0 acts as a spatial mask: features are summed only
        # where it is non-zero (presumably the player's own head cell).
        masked = feat * x[:, :1]
        at_head = masked.view(n, c, -1).sum(-1)
        pooled = feat.view(n, c, -1).mean(-1)

        policy_logits = self.head_p(at_head)
        value_in = torch.cat([at_head, pooled], 1)
        value = torch.tanh(self.head_v(value_in))
        return policy_logits, value

Appending to submission.py


## Utility functions

In [4]:
%%writefile -a submission.py

def find_prev_actions(curr_obs, prev_obs):
    """Infer the move each goose made between two observations.

    For every goose, the head position in ``prev_obs`` is compared with the
    head position in ``curr_obs``; the action whose adjacent cell (on an
    11x7 board) equals the new head is reported by name.

    Returns:
        A list with one entry per goose: the action name, or ``None`` when
        the goose had no head in ``prev_obs`` or no adjacent cell matches.
        When ``prev_obs`` is ``None`` the result is four ``None`` entries.
    """
    if prev_obs is None:
        return [None] * 4

    heads_before = [g[0] if len(g) else None for g in prev_obs.geese]
    heads_after = [g[0] if len(g) else None for g in curr_obs.geese]

    inferred = []
    for before, after in zip(heads_before, heads_after):
        move = None
        if before is not None:
            # Action values are 1-based, in the same order as the neighbours.
            neighbours = adjacent_positions(before, 11, 7)
            for number, cell in enumerate(neighbours, 1):
                if cell == after:
                    move = Action(number).name
                    break
        inferred.append(move)

    return inferred


def get_env(obs, config):
    """Return the shared simulator, synchronised with the observation ``obs``.

    On the first call a 4-agent ``hungry_geese`` environment is created and
    its initial step is overwritten; on later calls the current state is
    appended to the step history. In both cases ``obs`` becomes the
    observation of state slot 0, and each agent's ``action`` is set to the
    move inferred from the previous observation. ``config`` is not used.
    """
    global env, prev_obs

    first_call = env is None
    if first_call:
        env = make('hungry_geese')
        env.reset(4)

    env.state[0].observation = obs
    if first_call:
        env.steps[0] = env.state
    else:
        env.steps.append(env.state)

    # Record the last move of every goose on the simulator state.
    for idx, last_move in enumerate(find_prev_actions(obs, prev_obs)):
        env.state[idx].action = last_move

    return env


def remove_illegal_action(pi, prev_action):
    """Mask the reverse of ``prev_action`` in ``pi`` and renormalise in place.

    Args:
        pi: 1-D NumPy float array of action weights, indexed by
            ``Action.value - 1``. It is modified in place.
        prev_action: name of the previously played action, or a falsy value
            when there is none (``pi`` is then returned untouched).

    Returns:
        The same array object ``pi``. After masking, it sums to 1: either
        the remaining weights are rescaled or, when they carry no positive
        mass, the weight is spread uniformly over the non-masked actions.
    """
    if not prev_action:
        return pi

    banned = Action.opposite(Action[prev_action]).value - 1
    pi[banned] = 0

    total = pi.sum()
    if total > 0:
        pi /= total
    else:
        # NumPy division by zero yields NaN instead of raising, so the zero
        # (or NaN) total has to be detected explicitly. Fall back to a
        # uniform choice among the moves that are still legal.
        legal = len(pi) - 1
        if legal > 0:
            pi[:] = 1.0 / legal
            pi[banned] = 0
    return pi


def select_action(pi):
    """Return the name of a highest-weight action, breaking ties at random.

    ``pi`` is a 1-D array of weights indexed by ``Action.value - 1``.
    """
    top_indices = np.flatnonzero(pi == np.max(pi))
    chosen = np.random.choice(top_indices)
    return Action(chosen + 1).name

Appending to submission.py


## MCTS class

In [5]:
%%writefile -a submission.py

class MCTS:
    """Monte Carlo tree search guided by a policy/value network.

    Statistics are kept in dictionaries keyed by ``s`` (the raw bytes of the
    encoded board) or by ``(s, a)`` with ``a`` in ``range(4)``. They persist
    across calls, so later searches reuse earlier results.

    Args:
        nnet: module called as ``nnet(boards)`` returning
            ``(policy_logits, values)``; put into eval mode, and moved to
            CUDA when available.
        num_simuls: number of simulations run per ``get_policy`` call.
        max_depth: maximum recursion depth of one simulation.
    """

    def __init__(self, nnet, num_simuls=50, max_depth=10):
        self.nnet = nnet
        self.cuda = torch.cuda.is_available()
        if self.cuda:
            self.nnet.cuda()
        self.nnet.eval()

        self.num_simuls = num_simuls
        self.max_depth = max_depth
        self.c_puct = 1

        self.Qsa = {}  # mean value observed for edge (s, a)
        self.Nsa = {}  # number of times edge (s, a) was taken
        self.Ns = {}  # number of times state s was expanded through
        self.Ps = {}  # masked, normalised prior policy for state s
        self.Es = {}  # terminal reward for state s, or None if not terminal
        self.Vs = {}  # 0/1 mask of valid actions for state s

    def _state_key(self, obs, prev_obs):
        """Encode an observation; return ``(board, hashable key)``."""
        board = self.get_board(obs, prev_obs)
        board = self.get_player_board(board, obs.index, 4)
        # tobytes() replaces the deprecated ndarray.tostring(); same bytes.
        return board, board.tobytes()

    def get_policy(self, env, prev_obs):
        """Run the simulations and return the root visit distribution.

        Returns:
            A float array of length 4 that sums to 1. When the root has no
            recorded visits (for example no simulations were run), the
            stored prior for the root is returned, or a uniform
            distribution if there is none.
        """
        for _ in range(self.num_simuls):
            self.search(env.clone(), prev_obs, self.max_depth)

        curr_obs = env.state[0].observation
        _, s = self._state_key(curr_obs, prev_obs)

        counts = np.array(
            [self.Nsa.get((s, a), 0) for a in range(4)], dtype=np.float64
        )
        total = counts.sum()
        if total > 0:
            return counts / total

        # No visits: dividing would raise, so fall back to something usable.
        prior = self.Ps.get(s)
        if prior is not None:
            prior_sum = np.sum(prior)
            if prior_sum > 0:
                return np.array(prior, dtype=np.float64) / prior_sum
        return np.full(4, 0.25)

    def search(self, env, prev_obs, remaining):
        """Run one simulation from the current state of ``env``.

        ``env`` is stepped in place. The acting player's move is chosen by
        the upper-confidence rule; the other players' moves are the top
        choices of the network policy. Edge statistics are updated on the
        way back up.

        Args:
            env: simulator whose ``state[0].observation`` is the position.
            prev_obs: observation one step earlier, or ``None``.
            remaining: how many more plies may be simulated.

        Returns:
            The terminal reward when the state is terminal, otherwise the
            network value at the leaf (or at the depth limit).
        """
        curr_obs = env.state[0].observation
        me = curr_obs.index
        board, s = self._state_key(curr_obs, prev_obs)

        if s not in self.Es:
            self.Es[s] = self.get_reward(curr_obs, me, 4)

        if self.Es[s] is not None:  # terminal node
            return self.Es[s]

        # One encoded board per player, evaluated as a single batch.
        batch = np.stack(
            [self.get_player_board(board, player, 4) for player in range(4)]
        )
        boards = torch.from_numpy(batch.astype(np.float32))
        if self.cuda:
            boards = boards.contiguous().cuda()

        with torch.no_grad():
            pis, vs = self.nnet(boards)
            pis = F.softmax(pis, 1).cpu().numpy()
            vs = vs.cpu().numpy()

        # Default joint action: every player follows the network policy.
        # remove_illegal_action edits each row of ``pis`` in place.
        actions = []
        for i, pi in enumerate(pis):
            pi = remove_illegal_action(pi, env.state[i].action)
            actions.append(select_action(pi))

        v = vs[me]

        if s not in self.Ps:  # leaf node: store prior and stop
            prior = remove_illegal_action(pis[me], env.state[me].action)
            valids = np.where(prior > 0, 1, 0)
            if not valids.any():
                # Degenerate prior (all zero or NaN). Last resort: allow
                # every move so that action selection still has a choice.
                valids = np.ones_like(valids)

            prior = prior * valids  # masking invalid moves
            prior_sum = np.sum(prior)
            if prior_sum > 0:
                prior = prior / prior_sum  # renormalize
            else:
                prior = valids / np.sum(valids)

            self.Ps[s] = prior
            self.Vs[s] = valids
            self.Ns[s] = 0

            return v

        if remaining == 0:
            return v

        # pick the action with the highest upper confidence bound
        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1
        for a in range(4):
            if not valids[a]:
                continue
            if (s, a) in self.Qsa:
                u = self.Qsa[(s, a)] + self.c_puct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
            else:
                # Unvisited edge: treated as Q = 0.
                u = self.c_puct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8)
            if u > cur_best:
                cur_best = u
                best_act = a
        a = best_act

        actions[me] = Action(a + 1).name

        env.step(actions)

        v = self.search(env, curr_obs, remaining - 1)

        if (s, a) in self.Qsa:
            visits = self.Nsa[(s, a)]
            self.Qsa[(s, a)] = (visits * self.Qsa[(s, a)] + v) / (visits + 1)
            self.Nsa[(s, a)] = visits + 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1

        return v

    def get_board(self, obs, prev_obs):
        """Encode an observation as a ``(17, 7, 11)`` uint8 array of 0/1.

        Planes are grouped in fours, one per goose, ordered relative to
        ``obs.index``: 0-3 heads, 4-7 tail tips, 8-11 body cells excluding
        the head, 12-15 heads in ``prev_obs`` (left empty when it is
        ``None``). Plane 16 marks food.
        """
        board = np.zeros((4 * 4 + 1, 7 * 11), np.uint8)

        for i, goose in enumerate(obs.geese):
            rel = (i - obs.index) % 4

            # head position
            for head_pos in goose[:1]:
                board[rel, head_pos] = 1

            # tip position
            for tip_pos in goose[-1:]:
                board[4 + rel, tip_pos] = 1

            # body positions (every cell except the head)
            for body_pos in goose[1:]:
                board[8 + rel, body_pos] = 1

        # previous head position
        if prev_obs is not None:
            for i, goose in enumerate(prev_obs.geese):
                for pos in goose[:1]:
                    board[12 + (i - obs.index) % 4, pos] = 1

        for food_pos in obs.food:
            board[-1, food_pos] = 1

        return board.reshape(-1, 7, 11)

    def get_player_board(self, board, player, num_agents):
        """Return a copy of ``board`` with two players' planes swapped.

        In each of the four plane groups, the plane at offset 0 is exchanged
        with the plane at offset ``player``. ``board`` is not modified.
        """
        new_board = board.copy()
        indices = np.arange(0, num_agents * 4, num_agents)
        first_planes = new_board[indices]  # fancy indexing yields a copy
        new_board[indices] = new_board[indices + player]
        new_board[indices + player] = first_planes
        return new_board

    def get_reward(self, obs, player, num_agents):
        """Return the terminal reward for ``player``, or ``None`` if ongoing.

        The game is ongoing while the player's goose is non-empty, at least
        two geese are non-empty and ``obs.step`` is below 199. Otherwise
        the player is ranked by goose length (a tie counts as half a place)
        and the rank is mapped linearly to ``[-1, 1]``.
        """
        own_length = len(obs.geese[player])
        alive = sum(1 for goose in obs.geese if len(goose) > 0)

        if own_length > 0 and alive >= 2 and obs.step < 199:
            return None

        rank = 1
        for i, goose in enumerate(obs.geese):
            if i == player:
                continue
            if len(goose) > own_length:
                rank += 1
            elif len(goose) == own_length:
                rank += 0.5
        return (num_agents + 1 - 2 * rank) / (num_agents - 1)

    

Appending to submission.py


## Agent function

In [6]:
%%writefile -a submission.py

def agent(obs, config):
    """Entry point called once per turn; returns the chosen action name.

    On the first call the embedded weights are decoded (replacing the
    module-level ``state_dict`` with the decoded dictionary), loaded into a
    ``GeeseNet`` and wrapped in an ``MCTS`` searcher. Every call then
    synchronises the simulator with ``obs``, runs the search, masks the
    reverse of the previous move and picks the top-weighted action. The
    chosen action and ``obs`` are remembered for the next call.
    """
    global state_dict, prev_action, prev_obs, mcts

    if mcts is None:
        # NOTE: pickle.loads is applied to the weight blob embedded in this
        # file; it must only ever contain data produced by the author.
        network = GeeseNet()
        state_dict = pickle.loads(zlib.decompress(base64.b64decode(state_dict)))
        network.load_state_dict(state_dict)
        mcts = MCTS(network)

    game = get_env(obs, config)
    visit_probs = mcts.get_policy(game, prev_obs)
    chosen = select_action(remove_illegal_action(visit_probs, prev_action))

    prev_action, prev_obs = chosen, obs
    return chosen

Appending to submission.py


# Write the weights into submission.py

In [7]:
import base64
import pickle
import zlib
import torch
from kaggle_environments import make

# Load the trained checkpoint on CPU and strip the "module." fragment from
# every parameter name (such names typically come from a wrapped model).
loaded_model = torch.load('temp/best.pth.tar', map_location='cpu')
new_state_dict = {
    name.replace("module.", ""): tensor
    for name, tensor in loaded_model['state_dict'].items()
}

# Serialise: pickle -> zlib -> base64, giving an ASCII-only bytes object.
state_dict = base64.b64encode(zlib.compress(pickle.dumps(new_state_dict)))

# Substitute the placeholder token in the generated file with the textual
# form of the bytes object (a b'...' literal), then write the file back.
with open('submission.py', 'r') as f:
    src = f.read()

src = src.replace("_STATE_DICT_", str(state_dict))

with open('submission.py', 'w') as f:
    f.write(src)

In [None]:
# Local smoke test: run one game with the generated agent in all four seats
# and render the replay inline in the notebook.
env = make("hungry_geese", debug=True)
env.run(["submission.py"] * 4)  # white, blue, green, red
env.render(mode="ipython", width=700, height=550)