In [10]:
%%writefile submission.py
import base64
import pickle
import zlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import Sequential
# from torch.distributions.categorical import Categorical
from kaggle_environments.envs.hungry_geese.hungry_geese import Action


class FlattenExtractor(nn.Module):
    """Some Information about FlattenExtractor"""
    def __init__(self):
        super(FlattenExtractor, self).__init__()
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

    def forward(self, x):
        x = self.flatten(x)
        return x


class MlpExtractor(nn.Module):
    """Some Information about MlpExtractor"""
    def __init__(self):
        super(MlpExtractor, self).__init__()
        self.shared_net = Sequential()
        self.policy_net = Sequential(
            nn.Linear(in_features=1083, out_features=64, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=64, out_features=64, bias=True),
            nn.Tanh()
        )
        self.value_net = Sequential(
            nn.Linear(in_features=1083, out_features=64, bias=True),
            nn.Tanh(),
            nn.Linear(in_features=64, out_features=64, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        p = self.policy_net(x)
        v = self.value_net(x)
        return p, v


class ActorCriticPolicy(nn.Module):
    """Some Information about ActorCriticPolicy"""
    def __init__(self):
        super(ActorCriticPolicy, self).__init__()
        self.feature_extractor = FlattenExtractor()
        self.mlp_extractor = MlpExtractor()
        self.action_net = nn.Linear(in_features=64, out_features=4, bias=True)
        self.value_net = nn.Linear(in_features=64, out_features=1, bias=True)

    def forward(self, x):
        x = self.feature_extractor(x)
        p, v = self.mlp_extractor(x)
        p = self.action_net(p)
        v = self.value_net(v)
        return p, v


state_dict = _STATE_DICT_

state_dict = pickle.loads(zlib.decompress(base64.b64decode(state_dict)))
model = ActorCriticPolicy()
model.load_state_dict(state_dict)
model.eval()

obs_prev = None
act_prev = None


# Modified from https://www.kaggle.com/yuricat/smart-geese-trained-by-reinforcement-learning
def process_obs(obs):
    global act_prev
    # my previous action
    obs_index = obs.index
    a = np.zeros(5, dtype=np.float32)
    if act_prev is not None:
        a[act_prev - 1] = 1
    a[-1] = obs.step % 40 / 40

    b = np.zeros((14, 7 * 11), dtype=np.float32)
    b[-1] = 1  # empty cells

    for p, pos_list in enumerate(obs['geese']):
        # head position
        for pos in pos_list[:1]:
            b[0 + (p - obs_index) % 4, pos] = 1
            b[-1, pos] = 0
        # tip position
        for pos in pos_list[-1:]:
            b[4 + (p - obs_index) % 4, pos] = 1
            b[-1, pos] = 0
        # whole position
        for pos in pos_list:
            b[8 + (p - obs_index) % 4, pos] = 1
            b[-1, pos] = 0

    # food
    for pos in obs['food']:
        b[-2, pos] = 1
        b[-1, pos] = 0

    c = np.concatenate((a, b.reshape(-1)))
    return c


def agent(obs, conf):
    global model, obs_prev, act_prev
    obs_backup = obs
    obs = process_obs(obs).reshape(1, -1)
    obs = torch.from_numpy(obs)
    p, v = model(obs)
    p = p.squeeze()
    p = F.softmax(p, dim=0)
    if act_prev is not None:
        act_oppo = (act_prev + 1) % 4 + 1
        p[act_oppo - 1] = 0
        p /= p.sum()
    # print(p)
    action = p.squeeze().argmax().item() + 1
    # action = Categorical(p).sample().item() + 1
    obs_prev = obs_backup
    act_prev = action
    return Action(action).name

Overwriting submission.py


In [11]:
import base64
import pickle
import zlib
from stable_baselines3 import PPO

model_path = 'models1/model_50000_steps.zip'
model = PPO.load(model_path)
print(model.policy)

state_dict = model.policy.to('cpu').state_dict()
state_dict = base64.b64encode(zlib.compress(pickle.dumps(state_dict)))

with open('submission.py', 'r') as file:
    src = file.read()
src = src.replace("_STATE_DICT_", f"{state_dict}")
with open('submission.py', 'w') as file:
    file.write(src)

ActorCriticPolicy(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential(
      (0): Linear(in_features=1083, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=1083, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=4, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)


In [12]:
from kaggle_environments import make

env = make("hungry_geese", debug=True)

env.run(["submission.py", "submission.py", "submission.py", "submission.py"])
# env.render(mode="ipython", width=800, height=700)

Goose Collision: WEST
Goose Collision: EAST
Goose Starved: Action.WEST


  'reward': 2201,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 1},
   'status': 'DONE'},
  {'action': 'NORTH',
   'reward': 5001,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 2},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 2201,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 3},
   'status': 'DONE'}],
 [{'action': 'EAST',
   'reward': 5101,
   'info': {},
   'observation': {'remainingOverageTime': 60,
    'step': 50,
    'geese': [[19], [], [32], []],
    'food': [11, 38],
    'index': 0},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 2201,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 1},
   'status': 'DONE'},
  {'action': 'NORTH',
   'reward': 5101,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 2},
   'status': 'ACTIVE'},
  {'action': 'NORTH',
   'reward': 2201,
   'info': {},
   'observation': {'remainingOverageTime': 60, 'index': 

In [20]:
!kaggle competitions submit -c hungry-geese -f submission.py -m "PPO MlpPolicy only self-play"

100%|█████████████████████████████████████████| 907k/907k [00:07<00:00, 129kB/s]
Successfully submitted to Hungry Geese