Reference: [Train a MARIO-Playing RL Agent](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)

In [None]:
import random
import copy
from collections import deque
from pathlib import Path

import pandas as pd
import numpy as np

import torch
from torch import nn

import gym
from gym import spaces
from gym.spaces import Box

from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, adjacent_positions, row_col, translate, min_distance
from kaggle_environments import make, evaluate

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# Supervised

I'm using greedy risk averse goose from: [Greedy-Goose](https://www.kaggle.com/victordelafuente/greedy-risk-averse-improved-dead-end-detection)

In [None]:
%%writefile greedy-goose.py
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col, translate, adjacent_positions, min_distance
import random as rand
from enum import Enum, auto


def opposite(action):
    """Return the Action that points the opposite way of ``action``.

    Raises:
        TypeError: if ``action`` equals none of the four Action members.
    """
    reversals = (
        (Action.NORTH, Action.SOUTH),
        (Action.SOUTH, Action.NORTH),
        (Action.EAST, Action.WEST),
        (Action.WEST, Action.EAST),
    )
    for direction, reverse in reversals:
        if action == direction:
            return reverse
    raise TypeError(str(action) + " is not a valid Action.")

    

#Encoding of cell content to build states from observations
class CellState(Enum):
    """Integer codes for what occupies a board cell."""
    EMPTY = 0
    FOOD = 1
    GOOSE = 2


#This class encapsulates most of the low-level Hungry Geese stuff
class BornToNotMedalv2:
    """Greedy, risk-averse Hungry Geese agent with dead-end detection.

    Each turn ``agent_do`` rebuilds a snapshot of the board
    (``preprocess_env``) and then picks a move (``agent_strategy``).
    The default strategy walks towards the closest food through cells that
    are free, not next to an opponent head and not a detected dead end; when
    no such cell exists it falls back to progressively less strict random
    moves.
    """

    def __init__(self):
        self.DEBUG = False
        self.rows, self.columns = -1, -1
        self.my_index = -1
        self.my_head, self.my_tail = -1, -1
        self.geese = []
        self.heads = []
        self.tails = []
        self.food = []
        self.cell_states = []
        self.actions = list(Action)
        # Last action returned by a strategy; used to forbid reversing.
        self.previous_action = None
        self.step = 1

    def _adjacent_positions(self, position):
        """Return the cells adjacent to ``position`` on the current board."""
        return adjacent_positions(position, self.columns, self.rows)

    def _min_distance_to_food(self, position, food=None):
        """Return the distance from ``position`` to the nearest food cell.

        Args:
            position: board cell index.
            food: food cells to consider; defaults to the food seen in the
                last processed observation.
        """
        food = food if food is not None else self.food
        return min_distance(position, food, self.columns)

    def _row_col(self, position):
        """Convert a flat cell index into a (row, column) pair."""
        return row_col(position, self.columns)

    def _translate(self, position, direction):
        """Return the cell reached from ``position`` by moving ``direction``."""
        return translate(position, direction, self.columns, self.rows)

    def _scan_dead_ends(self):
        """Return the empty cells that currently have >= 3 blocked neighbours.

        The agent's own head is skipped when counting, so the cell right in
        front of the head is judged by its remaining neighbours only.
        """
        threshold = CellState.GOOSE.value * 3
        found = []
        for pos, state in enumerate(self.cell_states):
            if state != CellState.EMPTY.value:
                continue
            blocked = sum(
                self.cell_states[adj]
                for adj in self._adjacent_positions(pos)
                if adj != self.my_head
            )
            if blocked >= threshold:
                found.append(pos)
        return found

    def preprocess_env(self, observation, configuration):
        """Refresh every cached board attribute from a raw observation.

        Args:
            observation: raw observation mapping passed by the environment.
            configuration: raw configuration mapping passed by the environment.
        """
        observation = Observation(observation)
        configuration = Configuration(configuration)

        self.rows, self.columns = configuration.rows, configuration.columns
        self.my_index = observation.index
        self.hunger_rate = configuration.hunger_rate
        self.min_food = configuration.min_food

        my_goose = observation.geese[self.my_index]
        self.my_head, self.my_tail = my_goose[0], my_goose[-1]
        self.my_body = list(my_goose)

        # Opponents that are still alive (non-empty cell lists).
        opponents = [
            g for i, g in enumerate(observation.geese)
            if i != self.my_index and len(g) > 0
        ]
        self.geese = opponents
        self.geese_cells = [pos for g in opponents for pos in g]
        self.occupied = self.geese_cells + self.my_body

        self.heads = [g[0] for g in opponents]
        # g[1:-1] is empty for geese of length <= 2, so no length guard is needed.
        self.bodies = [pos for g in opponents for pos in g[1:-1]]
        self.tails = [g[-1] for g in opponents if len(g) > 1]
        self.food = list(observation.food)

        self.adjacent_to_heads = [pos for head in self.heads for pos in self._adjacent_positions(head)]
        self.adjacent_to_bodies = [pos for body in self.bodies for pos in self._adjacent_positions(body)]
        self.adjacent_to_tails = [pos for tail in self.tails for pos in self._adjacent_positions(tail)]
        self.adjacent_to_geese = self.adjacent_to_heads + self.adjacent_to_bodies
        self.danger_zone = self.adjacent_to_geese

        # Cell occupation: every cell covered by any goose (mine included).
        self.cell_states = [CellState.EMPTY.value] * (self.rows * self.columns)
        for pos in self.occupied:
            self.cell_states[pos] = CellState.GOOSE.value

        # Dead-end detection. Cells found in one pass are marked as blocked,
        # which may turn their neighbours into dead ends on the next pass;
        # repeat until a pass finds nothing new. Each cell is recorded once.
        self.dead_ends = []
        new_dead_ends = self._scan_dead_ends()
        while new_dead_ends:
            for pos in new_dead_ends:
                self.cell_states[pos] = CellState.GOOSE.value
            self.dead_ends.extend(new_dead_ends)
            new_dead_ends = self._scan_dead_ends()

    def _without_reverse(self, actions):
        """Drop the action that would reverse the previous move, if any."""
        if self.previous_action is None:
            return actions
        banned = opposite(self.previous_action)
        return [action for action in actions if action != banned]

    def _candidate_moves(self, is_acceptable):
        """Return non-reversing actions whose target cell passes ``is_acceptable``."""
        moves = [
            action for action in Action
            if is_acceptable(self._translate(self.my_head, action))
        ]
        return self._without_reverse(moves)

    def strategy_random(self, observation, configuration):
        """Pick a uniformly random action that does not reverse the last one."""
        action = rand.choice(self._without_reverse(list(Action)))
        self.previous_action = action
        return action.name

    def safe_position(self, future_position):
        """True if the cell is free, not next to an opponent head, and not a dead end."""
        return (
            future_position not in self.occupied
            and future_position not in self.adjacent_to_heads
            and future_position not in self.dead_ends
        )

    def valid_position(self, future_position):
        """True if the cell is free and not a dead end."""
        return (future_position not in self.occupied) and (future_position not in self.dead_ends)

    def free_position(self, future_position):
        """True if no goose currently covers the cell."""
        return future_position not in self.occupied

    def strategy_random_avoid_collision(self, observation, configuration):
        """Pick a random move, preferring valid cells, then free cells, then anything.

        Returns:
            str: name of the chosen Action.
        """
        dead_end_cell = False
        free_cell = True
        actions = self._candidate_moves(self.valid_position)
        if not actions:
            # Only dead ends are left: accept any unoccupied cell.
            dead_end_cell = True
            actions = self._candidate_moves(self.free_position)
            if not actions:
                # No alternatives: every neighbour is occupied.
                free_cell = False
                actions = self._without_reverse(self.actions)

        action = rand.choice(actions)
        self.previous_action = action
        if self.DEBUG:
            aux_pos = self._row_col(self._translate(self.my_head, self.previous_action))
            dead_ends = "" if not dead_end_cell else f', dead_ends={[self._row_col(p1) for p1 in self.dead_ends]}, occupied={[self._row_col(p2) for p2 in self.occupied]}'
            if free_cell:
                print(f'{id(self)}({self.step}): Random_ac_move {action.name} to {aux_pos} dead_end={dead_end_cell}{dead_ends}', flush=True)
            else:
                print(f'{id(self)}({self.step}): Random_ac_move {action.name} to {aux_pos} free_cell={free_cell}', flush=True)
        return action.name

    def strategy_greedy_avoid_risk(self, observation, configuration):
        """Move to the safe neighbouring cell that is closest to food.

        Falls back to ``strategy_random_avoid_collision`` when no safe,
        non-reversing move exists.

        Returns:
            str: name of the chosen Action.
        """
        actions = {}
        for action in Action:
            future_position = self._translate(self.my_head, action)
            if self.safe_position(future_position):
                actions[action] = self._min_distance_to_food(future_position)

        if self.previous_action is not None:
            actions.pop(opposite(self.previous_action), None)

        if not actions:
            return self.strategy_random_avoid_collision(observation, configuration)

        # min() keeps the first minimum in insertion (Action) order.
        action = min(actions, key=actions.get)
        self.previous_action = action
        if self.DEBUG:
            aux_pos = self._row_col(self._translate(self.my_head, self.previous_action))
            print(f'{id(self)}({self.step}): Greedy_ar_move {action.name} to {aux_pos}', flush=True)
        return action.name

    #Redefine this method
    def agent_strategy(self, observation, configuration):
        """Strategy hook; subclasses override this to change behaviour."""
        return self.strategy_greedy_avoid_risk(observation, configuration)

    def agent_do(self, observation, configuration):
        """Process one turn: refresh the board snapshot and return a move name."""
        self.preprocess_env(observation, configuration)
        move = self.agent_strategy(observation, configuration)
        self.step += 1
        return move

    
    
def agent_singleton(observation, configuration):
    """Agent entry point backed by one lazily created BornToNotMedalv2.

    The instance lives in the module-level name ``gus`` so that state such
    as the previous action survives between calls.
    """
    global gus

    if "gus" not in globals():
        gus = BornToNotMedalv2()

    return gus.agent_do(observation, configuration)


    

In [None]:
%%writefile choose_random.py
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col, translate, adjacent_positions, min_distance
import random as rand
from enum import Enum, auto

import random
import copy
# Action value -> value of the opposite direction.
# Assumes the Action enum is numbered NORTH=1, EAST=2, SOUTH=3, WEST=4 -- TODO confirm.
oppo_dict = {1: 3, 2: 4, 3: 1, 4: 2}
# Despite the name, this holds the value of the move that is FORBIDDEN on the
# next turn (the opposite of the move just made); None before the first move.
last_action = None
action_list = [1, 2, 3, 4]


def agent(obs, conf):
    """Return a uniformly random move that does not reverse the previous one.

    Args:
        obs: environment observation (unused).
        conf: environment configuration (unused).

    Returns:
        str: name of the chosen Action.
    """
    global last_action
    # Choose among ALL permitted moves. The previous randint(0, 2) index could
    # never reach the fourth entry when nothing was excluded (first turn).
    allow = [candidate for candidate in action_list if candidate != last_action]
    action = random.choice(allow)
    last_action = oppo_dict[action]
    return Action(action).name

# Environment

The standard environment, wrapped with a conversion step that turns each raw observation into the network's input planes.

In [None]:
class GeeseEnv(gym.Env):
    """Gym wrapper around the Kaggle ``hungry_geese`` environment.

    Observations are converted into 13 binary planes of shape (7, 11) and the
    reward is replaced by a hand-crafted one (see ``convert``).
    """

    def __init__(self, opponent=None, debug=False):
        """Hungry Geese Environment

        Args:
            opponent (list, optional): Opponent agents that fill the remaining
                seats. Defaults to ["random"].
            debug (bool, optional): Whether to output debug information.
                Defaults to False.
        """
        super().__init__()
        # Avoid a shared mutable default argument.
        if opponent is None:
            opponent = ["random"]

        # Number of environments in the vectorized environment.
        self.num_envs = 1

        # self.num_previous_observations = 1
        self.debug = debug

        # Permitted actions
        # ['NORTH', 'EAST', 'SOUTH', 'WEST']
        self.actions = [action.name for action in Action]
        # Defined Action Space(Must)
        self.action_space = spaces.Discrete(len(self.actions))

        # Environment and Configuration
        self.env = make("hungry_geese", debug=self.debug)
        self.rows = self.env.configuration.rows
        self.columns = self.env.configuration.columns
        self.hunger_rate = self.env.configuration.hunger_rate
        self.min_food = self.env.configuration.min_food

        # The learning agent takes the first seat (None); opponents fill the rest.
        self.trainer = self.env.train([None, *opponent])

        # Observation Space(Must?)
        # Defined value range and shape in output observation
        # NOTE(review): declared dtype is uint8 but convert() returns an
        # integer array created with dtype=int -- confirm which is intended.
        self.observation_space = Box(low=0, high=1, shape=(13, 7, 11), dtype=np.uint8)

        # Length of the agent's goose at the previous observation.
        self.length = 1

    def step(self, action):
        """
        Input agent action, Output observation after players action

        Args:
            action (int): index into ``self.actions``.

        Returns:
            np.ndarray: observation (same GeeseNet Input shape)
            int: reward computed by ``convert`` (the environment's own reward
                is discarded)
            done: whether end game or not
            dict: env information
        """
        obs, _, done, info = self.trainer.step(self.actions[action])
        conv_obs, conv_reward = self.convert(obs)
        return conv_obs, conv_reward, done, info

    def reset(self):
        """Reset Environment

        Returns:
            np.ndarray: observation
        """
        self.length = 1
        obs = self.trainer.reset()
        conv_obs, _ = self.convert(obs)
        return conv_obs

    def convert(self, observation):
        """Convert Observation

        Plane layout: 0-2 own head/body/tail, 3-11 head/body/tail of each
        opponent in seat order, 12 food. The "body" plane covers every cell of
        the goose, head and tail included.

        Reward: +50 when the goose grew since the last call, +5 when its
        length is unchanged, -50 when it shrank; overridden by +1000 when the
        agent is alive and no opponent is, and by -1000 when the agent is dead
        while an opponent is still alive.

        Args:
            observation (dict): Output observation at default environment

        Returns:
            np.ndarray: converted observation (same GeeseNet Input shape)
            int: reward
        """
        index = observation["index"]
        geese = observation["geese"]
        food = observation["food"]

        mappings = np.zeros((13, 77), dtype=int)

        # MY GEESE
        my_goose = geese[index]
        if len(my_goose) > 0:
            mappings[0][my_goose[0]] = 1
            for cell in my_goose:
                mappings[1][cell] = 1
            mappings[2][my_goose[-1]] = 1

        # OP GEESE: a dead opponent still consumes its three planes so the
        # remaining opponents keep a stable plane position.
        plane = 3
        for seat, goose in enumerate(geese):
            if seat == index:
                continue
            if len(goose) > 0:
                mappings[plane][goose[0]] = 1
                for cell in goose:
                    mappings[plane + 1][cell] = 1
                mappings[plane + 2][goose[-1]] = 1
            plane += 3

        # FOOD
        for f in food:
            mappings[12][f] = 1

        length = len(my_goose)
        if length > self.length:
            reward = 50
        elif length == self.length:
            reward = 5
        else:
            reward = -50
        self.length = length

        # Number of opponents that are still alive.
        remain = sum(1 for seat, goose in enumerate(geese) if len(goose) > 0 and seat != index)
        if remain == 0 and length > 0:
            reward = 1000
        elif remain != 0 and length == 0:
            reward = -1000

        return mappings.reshape(-1, 7, 11), reward

# Trainer

In [None]:
class Geese:
    """Double-DQN style trainer around a GeeseNet (online + target network)."""

    def __init__(self, state_dim, action_dim, save_dir, load_weight=None):
        """Training Class

        Args:
            state_dim (tuple): observation shape
            action_dim (int): action length
            save_dir (Path): save path (log, checkpoint)
            load_weight (str, optional): checkpoint path. Currently NOT used:
                the network always starts from freshly initialised weights.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = GeeseNet(self.state_dim, self.action_dim).float()
#         self.net.load_state_dict(torch.load(load_weight))
        if self.use_cuda:
            self.net = self.net.to(device="cuda")

        # Rate of search to be performed
        self.exploration_rate = 1
        self.exploration_rate_decay = 0.9999
        self.exploration_rate_min = 0.005
        self.curr_step = 0

        self.save_every = 5e5

        self.memory = deque(maxlen=100000)
        self.batch_size = 32

        self.gamma = 0.9

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.025)
        self.loss_fn = torch.nn.SmoothL1Loss()
        self.burnin = 1e4  # Minimum number of steps required to train an experience.
        self.learn_every = 3  # Number of steps to indicate when to update Q_online
        self.sync_every = 1e4  # Number of steps to indicate when to synchronize Q_target & Q_online

    def _to_tensor(self, value):
        """Build a tensor from ``value`` on the device the network lives on."""
        return torch.tensor(value, device="cuda" if self.use_cuda else "cpu")

    def act(self, state):
        """Predict action from observation.
        (Random with a certain probability)

        Also decays the exploration rate and advances the step counter.

        Args:
            state (np.ndarray): observation

        Returns:
            int: action
        """
        if np.random.rand() < self.exploration_rate:
            action = np.random.randint(self.action_dim)
        else:
            state = self._to_tensor(state.__array__()).unsqueeze(0).float()
            # Action selection needs no gradients; skip building the graph.
            with torch.no_grad():
                action_values = self.net(state, model="online")
            action = torch.argmax(action_values, axis=1).item()

        # update exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
        self.curr_step += 1

        return action

    def cache(self, state, next_state, action, reward, done):
        """Store one transition in the replay memory as tensors.

        Args:
            state (np.ndarray): observation before the action
            next_state (np.ndarray): observation after the action
            action (int): action taken
            reward (int): reward received
            done (bool): whether the episode ended
        """
        transition = (
            self._to_tensor(state.__array__()),
            self._to_tensor(next_state.__array__()),
            self._to_tensor([action]),
            self._to_tensor([reward]),
            self._to_tensor([done]),
        )
        self.memory.append(transition)

    def recall(self):
        """Sample ``batch_size`` transitions and stack them field by field.

        Returns:
            tuple: (state, next_state, action, reward, done) batched tensors
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, state, action):
        """Return Q_online(s, a) for every sample of the batch."""
        q_values = self.net(state, model="online")
        # Index rows by the actual batch size instead of assuming batch_size.
        rows = torch.arange(q_values.shape[0], device=q_values.device)
        return q_values[rows, action]

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        """Return the TD target: the online net picks the next action, the target net scores it."""
        next_state_Q = self.net(next_state, model="online")
        best_action = torch.argmax(next_state_Q, axis=1)
        target_q = self.net(next_state, model="target")
        rows = torch.arange(target_q.shape[0], device=target_q.device)
        next_Q = target_q[rows, best_action]
        # Terminal transitions contribute the reward only.
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_Q_online(self, td_estimate, td_target):
        """Take one optimizer step on the online network and return the loss value."""
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy the online network's weights into the target network."""
        self.net.target.load_state_dict(self.net.online.state_dict())

    def save(self):
        """Write the network state dict to ``save_dir`` and print the path."""
        save_path = (
            self.save_dir / f"geese_net_{int(self.curr_step // self.save_every)}.chkpt"
        )
        torch.save(
            self.net.state_dict(),
            save_path,
        )
        print(f"GeeseNet saved to {save_path} at step {self.curr_step}")

    def learn(self):
        """Run the periodic sync/save checks and, when due, one training step.

        Returns:
            tuple: (mean Q estimate, loss), or (None, None) when no training
            step was performed on this call.
        """
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.burnin:
            return None, None

        if self.curr_step % self.learn_every != 0:
            return None, None

        # Sampling Memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimator
        td_est = self.td_estimate(state.float(), action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state.float(), done)

        # Back propagation of loss to Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

# Set Environment

In [None]:
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, adjacent_positions, row_col, translate, min_distance
from kaggle_environments import make
# The three remaining seats are all played by the random-walk agent.
opponent = ['choose_random.py'] * 3
env = GeeseEnv(opponent=opponent)


# Model(NN)

In [None]:
class PartConv2d(nn.Module):
    """Feature extractor that processes own, opponent and food planes separately.

    Channels 0-2 (own goose), 3-11 (three opponents, three channels each) and
    12+ (food) go through their own conv -> conv -> batch-norm -> linear
    stack. The three opponents share one stack and their batch-norm outputs
    are summed before the linear layer. The three results are concatenated
    along the channel dimension.
    """

    def __init__(self):
        super().__init__()
        # Own goose branch.
        self.conv_my_3_1 = nn.Conv2d(3, 6, kernel_size=4)
        self.conv_my_3_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_my = nn.BatchNorm2d(2)
        self.linear_my = nn.Linear(7, 6)

        # Opponent branch (weights shared by all three opponents).
        self.conv_op_3_1 = nn.Conv2d(3, 6, kernel_size=4)
        self.conv_op_3_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_op = nn.BatchNorm2d(2)
        self.linear_op = nn.Linear(7, 6)

        # Food branch.
        self.conv_fd_1_1 = nn.Conv2d(1, 6, kernel_size=4)
        self.conv_fd_1_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_fd = nn.BatchNorm2d(2)
        self.linear_fd = nn.Linear(7, 6)

    def forward(self, input):
        own_planes = input[:, 0:3]
        rival_planes = (input[:, 3:6], input[:, 6:9], input[:, 9:12])
        food_planes = input[:, 12:]

        own = self.linear_my(
            self.bn_my(self.conv_my_3_2(self.conv_my_3_1(own_planes)))
        )

        # Run every opponent through the shared stack, then accumulate the
        # normalised maps into the first one.
        rival_maps = [
            self.bn_op(self.conv_op_3_2(self.conv_op_3_1(planes)))
            for planes in rival_planes
        ]
        rivals = rival_maps[0]
        for extra in rival_maps[1:]:
            rivals += extra
        rivals = self.linear_op(rivals)

        meal = self.linear_fd(
            self.bn_fd(self.conv_fd_1_2(self.conv_fd_1_1(food_planes)))
        )
        return torch.cat((own, rivals, meal), dim=1)

    

class GeeseNet(nn.Module):
    """Q-network holding an ``online`` model and a frozen ``target`` copy."""

    def __init__(self, input_dim, output_dim):
        """Build both networks.

        Args:
            input_dim (tuple): (channels, rows, columns); rows must be 7 and
                columns 11.
            output_dim (int): number of actions (size of the output layer).
        """
        super().__init__()
        print(input_dim)
        _, r, c = input_dim

        assert r == 7
        assert c == 11

        self.online = nn.Sequential(
            PartConv2d(),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(108, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 8),
            nn.LeakyReLU(),
            nn.Linear(8, output_dim),
        )

        # The target network starts as an exact copy and is never trained directly.
        self.target = copy.deepcopy(self.online)

        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run ``input`` through the selected network.

        Args:
            input (torch.Tensor): batch of observations.
            model (str): "online" or "target".

        Raises:
            ValueError: if ``model`` is neither "online" nor "target"
                (previously this silently returned None).
        """
        if model == "online":
            return self.online(input)
        if model == "target":
            return self.target(input)
        raise ValueError(f"model must be 'online' or 'target', got {model!r}")

# Train

In [None]:
import numpy as np
import time, datetime
import matplotlib.pyplot as plt


class MetricLogger:
    """Collect per-episode training metrics, log them to a file and plot them."""

    def __init__(self, save_dir):
        """Create the log file (with header) and prepare the metric buffers.

        Args:
            save_dir (Path): directory receiving the log file and the plots.
        """
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History of metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, appended on every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Metrics of the current episode
        self.init_episode()

        # Record the time
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate one environment step into the current episode.

        Args:
            reward: reward received at this step.
            loss: training loss, or None when no update was performed.
            q: mean Q estimate matching ``loss`` (None when ``loss`` is None).
        """
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        # Compare against None: a loss of exactly 0.0 is a valid training step.
        if loss is not None:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        """Record the metrics at the end of an episode and reset the counters."""
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        """Reset the accumulators of the current episode."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Print, log and plot the averages over the last 100 episodes.

        Args:
            episode (int): current episode number.
            epsilon (float): current exploration rate.
            step (int): current global step count.
        """
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        # One plot per metric, overwritten on every call.
        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

# Checkpoints, the log file and the plots all go to the working directory.
save_dir = Path(".")

geese = Geese(
    state_dim=(13, 7, 11),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    load_weight="../input/geese-weight/geese_net_0.chkpt",
)

logger = MetricLogger(save_dir)

episodes = 50000
for e in range(episodes):
    state = env.reset()
    done = False

    # Play one full game, learning after every step.
    while not done:
        action = geese.act(state)
        next_state, reward, done, info = env.step(action)
        geese.cache(state, next_state, action, reward, done)
        q, loss = geese.learn()
        logger.log_step(reward, loss, q)
        state = next_state

    logger.log_episode()

    # Report the moving averages every 5000 episodes.
    if e % 5000 == 0:
        logger.record(episode=e, epsilon=geese.exploration_rate, step=geese.curr_step)

# Final checkpoint once training is over.
geese.save()

# Submission

In [None]:
%%writefile dqnv1.py

from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col
import torch
from torch import nn
import numpy as np
from collections import deque
import random
import copy
import gym
from gym import spaces
from gym.spaces import Box
from pathlib import Path
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, adjacent_positions, row_col, translate, min_distance
from kaggle_environments import make
import os

class PartConv2d(nn.Module):
    """Feature extractor that processes own, opponent and food planes separately.

    Channels 0-2 (own goose), 3-11 (three opponents, three channels each) and
    12+ (food) go through their own conv -> conv -> batch-norm -> linear
    stack. The three opponents share one stack and their batch-norm outputs
    are summed before the linear layer. The three results are concatenated
    along the channel dimension.
    """

    def __init__(self):
        super().__init__()
        # Own goose branch.
        self.conv_my_3_1 = nn.Conv2d(3, 6, kernel_size=4)
        self.conv_my_3_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_my = nn.BatchNorm2d(2)
        self.linear_my = nn.Linear(7, 6)

        # Opponent branch (weights shared by all three opponents).
        self.conv_op_3_1 = nn.Conv2d(3, 6, kernel_size=4)
        self.conv_op_3_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_op = nn.BatchNorm2d(2)
        self.linear_op = nn.Linear(7, 6)

        # Food branch.
        self.conv_fd_1_1 = nn.Conv2d(1, 6, kernel_size=4)
        self.conv_fd_1_2 = nn.Conv2d(6, 2, kernel_size=2)
        self.bn_fd = nn.BatchNorm2d(2)
        self.linear_fd = nn.Linear(7, 6)

    def forward(self, input):
        own_planes = input[:, 0:3]
        rival_planes = (input[:, 3:6], input[:, 6:9], input[:, 9:12])
        food_planes = input[:, 12:]

        own = self.linear_my(
            self.bn_my(self.conv_my_3_2(self.conv_my_3_1(own_planes)))
        )

        # Run every opponent through the shared stack, then accumulate the
        # normalised maps into the first one.
        rival_maps = [
            self.bn_op(self.conv_op_3_2(self.conv_op_3_1(planes)))
            for planes in rival_planes
        ]
        rivals = rival_maps[0]
        for extra in rival_maps[1:]:
            rivals += extra
        rivals = self.linear_op(rivals)

        meal = self.linear_fd(
            self.bn_fd(self.conv_fd_1_2(self.conv_fd_1_1(food_planes)))
        )
        return torch.cat((own, rivals, meal), dim=1)

    

class GeeseNet(nn.Module):
    """Q-network holding an ``online`` model and a frozen ``target`` copy."""

    def __init__(self, input_dim, output_dim):
        """Build both networks.

        Args:
            input_dim (tuple): (channels, rows, columns); rows must be 7 and
                columns 11.
            output_dim (int): number of actions (size of the output layer).
        """
        super().__init__()
        print(input_dim)
        _, r, c = input_dim

        assert r == 7
        assert c == 11

        self.online = nn.Sequential(
            PartConv2d(),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(108, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 8),
            nn.LeakyReLU(),
            nn.Linear(8, output_dim),
        )

        # The target network starts as an exact copy and is never trained directly.
        self.target = copy.deepcopy(self.online)

        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        """Run ``input`` through the selected network.

        Args:
            input (torch.Tensor): batch of observations.
            model (str): "online" or "target".

        Raises:
            ValueError: if ``model`` is neither "online" nor "target"
                (previously this silently returned None).
        """
        if model == "online":
            return self.online(input)
        if model == "target":
            return self.target(input)
        raise ValueError(f"model must be 'online' or 'target', got {model!r}")
        
def convert_numpy(observation):
    """Convert a raw observation into binary planes of shape (-1, 7, 11).

    Plane layout: 0-2 own head/body/tail, then head/body/tail of each
    opponent in seat order, and food in plane 12. The "body" plane covers
    every cell of the goose, head and tail included.

    Args:
        observation (dict): raw observation with "index", "geese" and "food".

    Returns:
        np.ndarray: planes with ``len(geese) * 3 + 1`` channels. All zeros
        when the agent's own goose is empty.
    """
    index = observation["index"]
    geese = observation["geese"]
    food = observation["food"]

    mappings = np.zeros((len(geese) * 3 + 1, 77), dtype=int)

    # MY GEESE
    my_goose = geese[index]
    if len(my_goose) == 0:
        return mappings.reshape(-1, 7, 11)

    mappings[0][my_goose[0]] = 1
    for cell in my_goose:
        mappings[1][cell] = 1
    mappings[2][my_goose[-1]] = 1

    # OP GEESE: a dead opponent still consumes its three planes.
    plane = 3
    for seat, goose in enumerate(geese):
        if seat == index:
            continue
        # Encode the opponent in this seat. The previous code read
        # geese[index] here and therefore drew the agent's own goose into
        # every opponent plane.
        if len(goose) > 0:
            mappings[plane][goose[0]] = 1
            for cell in goose:
                mappings[plane + 1][cell] = 1
            mappings[plane + 2][goose[-1]] = 1
        plane += 3

    # FOOD
    for f in food:
        mappings[12][f] = 1

    return mappings.reshape(-1, 7, 11)

def my_dqn(observation, configuration):
    """Agent entry point: pick the action with the highest online Q-value.

    The network is built and its checkpoint loaded on the first call, then
    kept in the module-level name ``model`` for later calls.

    Args:
        observation: raw environment observation.
        configuration: environment configuration (unused).

    Returns:
        str: name of the chosen Action.
    """
    global model, obs_prep, last_action, last_observation, previous_observation

    tgz_agent_path = '/kaggle_simulations/agent/'
    normal_agent_path = '.'
    model_name = "geese_net_0"

    if globals().get("model") is None:
        # Look next to the notebook first, then in the submission directory.
        file_name = os.path.join(normal_agent_path, f'{model_name}.chkpt')
        if not os.path.exists(file_name):
            file_name = os.path.join(tgz_agent_path, f'{model_name}.chkpt')

        net = GeeseNet((13, 7, 11), 4)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        net.load_state_dict(torch.load(file_name, map_location="cpu"))
        # Publish the model only after loading succeeded, so a failed load
        # is retried (and fails loudly) instead of leaving random weights.
        model = net
        last_action = -1

    conv_obs = convert_numpy(observation)
    tensor_obs = torch.tensor(conv_obs).unsqueeze(0).float()
    # Inference only: no gradients are needed.
    with torch.no_grad():
        n_out = model(tensor_obs, "online")
    pred = torch.argmax(n_out, axis=1).item()
    actions = [action.name for action in Action]

    last_action = actions[pred]
    return last_action

In [None]:
import kaggle_environments
from kaggle_environments import make, evaluate, utils

# Play one rendered match: the trained agent against three random opponents.
# Note: this rebinds `env`, which previously held the GeeseEnv training wrapper.
env = make("hungry_geese", debug=True)

env.reset()
env.run(["dqnv1.py"] + ["choose_random.py"] * 3)
env.render(mode="ipython", width=800, height=700)

In [None]:


# 100 evaluation games: the submission in seat 0 against three random agents.
result = evaluate("hungry_geese", ["dqnv1.py"] + ["choose_random.py"] * 3, num_episodes=100)

In [None]:
# One row per evaluation game, one column per seat.
result_df = pd.DataFrame(result, columns=["Submission", "Opponent1", "Opponent2", "Opponent3"])
# Rank the seats within each game (1 = best value); ties share the smaller rank.
result_rank_df = result_df.rank(ascending=False, axis=1, method="min")

# Heatmap of how often each seat obtained each rank.
sns.heatmap(
    pd.concat(
        [result_rank_df[seat].value_counts() for seat in result_df.columns],
        axis=1,
    ),
    cmap='Oranges',
    annot=True,
)

In [None]:
!tar cvzf submission.tar.gz dqnv1.py geese_net_0.chkpt