# Hungry geese: Reinforcement Learning with stable-baselines3


This notebook is a lab for testing Reinforcement Learning solutions on Kaggle's "hungry-geese" environment. It uses stable-baselines3 for training, while the submission agent depends only on PyTorch. Treat it as a general guide to working with stable-baselines3 and PyTorch to implement a custom policy network for "hungry-geese". Play with it and modify it to your liking ;-)

Check: 
* stable-baselines3 docs: https://stable-baselines3.readthedocs.io/en/master/
* DQN paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
* Thanks to this notebook for its example of a custom PyTorch policy: https://www.kaggle.com/toshikazuwatanabe/connect4-make-submission-with-stable-baselines3


## Important note

Training this method for enough steps in Kaggle's notebook environment sometimes fails during game rendering. If a failed run reports error 137, the process ran out of memory; try reducing the replay buffer size. I think DQN needs long training runs on this environment, so train on your own server for better results.
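To pick a replay buffer size before a run, a rough estimate of its memory footprint helps. The sketch below assumes the buffer keeps one array for observations and, unless SB3's `optimize_memory_usage` option is enabled, a second array of the same shape for next observations, both stored with the observation space dtype. `replay_buffer_bytes` is a hypothetical helper; it ignores actions, rewards, done flags and Python overhead, so treat the result as a lower bound.

```python
import numpy as np


def replay_buffer_bytes(buffer_size, obs_shape, obs_dtype=np.float32,
                        store_next_obs=True):
    """Approximate bytes used by the observation arrays of a replay buffer.

    Assumes one (buffer_size, *obs_shape) array for observations plus, when
    store_next_obs is True, a second one of the same shape for next observations.
    """
    per_obs = int(np.prod(obs_shape)) * np.dtype(obs_dtype).itemsize
    copies = 2 if store_next_obs else 1
    return buffer_size * per_obs * copies


# One centered 7x11 grid per observation (no stacked history), float32 cells.
full = replay_buffer_bytes(1_000_000, (1, 7, 11))
small = replay_buffer_bytes(100_000, (1, 7, 11))
print(f"{full / 1e6:.0f} MB vs {small / 1e6:.0f} MB")  # prints: 616 MB vs 62 MB
```

Under these assumptions, the default `buffer_size=1000000` already needs about 616 MB for single grids, and stacking previous observations multiplies that by the stack depth.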

Enjoy, and please let me know about any mistakes you find. Thanks!


## Changelog
* v125: Fixed a bug where the observation space had the wrong data type after the change to the [-1,1] range. Thanks to "Mahesh Abnave" (@maheshabnave999) for reporting it!
* v122: Fixed a bug introduced when the observation space changed to the [-1,1] range. Thanks to "Mahesh Abnave" (@maheshabnave999) for reporting it!
* v119:
    * Removed the negative reward for just moving, since it ended up producing a suicidal agent. Now the reward for moving is 1.1 and for eating is 2
    * Added RISK cells (-0.5) to the observation for dead ends and cells adjacent to heads
* v116:
    * Changed the observation to properly mark the previous position as visited when the agent's length is only one
    * Fixed a bug in the negative reward for the previous move: it was checked against the wrong action
* v100: fixed some bugs in the submission agent
    * fixed final weight adjustment
    * fixed main.py using an old version of the observation processor
* v95: fixed a bug in the final agent's output weight adjustment
* v92: the observation input for the policy has been simplified to a grid of values in [0,1]; see the environment description for the encoding
* v90: Major bugfix: fixed missing save/load of the policy network (only the feature_extractor net was saved/loaded and used).
    * Now the code is simpler since you just specify layers and sizes on creation
* v84: added an observation centered on the head, various bugfixes, changed 2 of the opponents to vanilla greedy agents, and some other minor modifications
* v75: FIXED HUGE BUG in the environment reward that led the agent to assume every move was rewarded, sorry for not detecting it earlier :-(
* v59: removed -100 rewards (they seemed to prevent learning)
* v54: added policy action weight modification based on common sense (collisions are bad, same for dead ends, and dangerous moves get a reduced weight)
* v46: modified reward values to
* v41: added back the reward for the opposite move
* v40: added a reward based on the final ranking at game end
* v39:
    * KNOWN ISSUES: during training, the policy still seems to converge to almost always predicting the same action :-( 
    * added a positive reward for surviving a game
    * added an extra -1 reward when losing
    * removed the negative reward for the opposite move
    * removed the negative reward for moving away from the previously nearest food (it might have been eaten, or going there might be a bad move)
    * (fixed bug) removed a mistaken negative reward 
* v37: fixed a bug in the negative cs_reward incorporation
* v25: added the previous n grid observations to the DQN policy input (n=2); the total input depth is n previous observations + 1 for the current one
* v23: the final agent now avoids performing the opposite of the last move taken
* v21: added dropout and changed the agent's action choice to be probabilistic, based on policy outputs
* v17: fixes in policy training


## Preliminary preparations

* Install libraries
* Save our improved training opponent to disk. Here I'm using the greedy risk-averse goose from: https://www.kaggle.com/victordelafuente/hungry-geese-template-class-greedy-risk-averse
* Part of the same template class code is reused, with some minor modifications, in the "utility functions" section



In [None]:
!pip install stable-baselines3 kaggle-environments > /dev/null 2>&1

In [None]:
%%writefile greedy-goose.py
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col, translate, adjacent_positions, min_distance
import random as rand
from enum import Enum, auto


def opposite(action):
    if action == Action.NORTH:
        return Action.SOUTH
    if action == Action.SOUTH:
        return Action.NORTH
    if action == Action.EAST:
        return Action.WEST
    if action == Action.WEST:
        return Action.EAST
    raise TypeError(str(action) + " is not a valid Action.")

    

#Encoding of cell content used to build states from observations
class CellState(Enum):
    EMPTY = 0
    FOOD = auto()
    GOOSE = auto()


#This class encapsulates most of the low-level Hungry Geese stuff
class BornToNotMedalv2:    
    def __init__(self):
        self.DEBUG=False
        self.rows, self.columns = -1, -1        
        self.my_index = -1
        self.my_head, self.my_tail = -1, -1
        self.geese = []
        self.heads = []
        self.tails = []
        self.food = []
        self.cell_states = []
        self.actions = [action for action in Action]
        self.previous_action = None
        self.step = 1

        
    def _adjacent_positions(self, position):
        return adjacent_positions(position, self.columns, self.rows)
 

    def _min_distance_to_food(self, position, food=None):
        food = food if food is not None else self.food
        return min_distance(position, food, self.columns)

    
    def _row_col(self, position):
        return row_col(position, self.columns)
    
    
    def _translate(self, position, direction):
        return translate(position, direction, self.columns, self.rows)
        
        
    def preprocess_env(self, observation, configuration):
        observation = Observation(observation)
        configuration = Configuration(configuration)
        
        self.rows, self.columns = configuration.rows, configuration.columns        
        self.my_index = observation.index
        self.hunger_rate = configuration.hunger_rate
        self.min_food = configuration.min_food

        self.my_head, self.my_tail = observation.geese[self.my_index][0], observation.geese[self.my_index][-1]        
        self.my_body = [pos for pos in observation.geese[self.my_index][1:-1]]

        
        self.geese = [g for i,g in enumerate(observation.geese) if i!=self.my_index  and len(g) > 0]
        self.geese_cells = [pos for g in self.geese for pos in g if len(g) > 0]
        
        self.occupied = [p for p in self.geese_cells]
        self.occupied.extend([p for p in observation.geese[self.my_index]])
        
        
        self.heads = [g[0] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 0]
        self.bodies = [pos  for i,g in enumerate(observation.geese) for pos in g[1:-1] if i!=self.my_index and len(g) > 2]
        self.tails = [g[-1] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 1]
        self.food = [f for f in observation.food]
        
        self.adjacent_to_heads = [pos for head in self.heads for pos in self._adjacent_positions(head)]
        self.adjacent_to_bodies = [pos for body in self.bodies for pos in self._adjacent_positions(body)]
        self.adjacent_to_tails = [pos for tail in self.tails for pos in self._adjacent_positions(tail)]
        self.adjacent_to_geese = self.adjacent_to_heads + self.adjacent_to_bodies
        self.danger_zone = self.adjacent_to_geese
        
        #Cell occupation
        self.cell_states = [CellState.EMPTY.value for _ in range(self.rows*self.columns)]
        for g in self.geese:
            for pos in g:
                self.cell_states[pos] = CellState.GOOSE.value
        for pos in self.heads:
            self.cell_states[pos] = CellState.GOOSE.value
        for pos in self.my_body:
            self.cell_states[pos] = CellState.GOOSE.value
                
        #detect dead-ends
        self.dead_ends = []
        for pos_i,_ in enumerate(self.cell_states):
            if self.cell_states[pos_i] != CellState.EMPTY.value:
                continue
            adjacent = self._adjacent_positions(pos_i)
            adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
            num_blocked = sum(adjacent_states)
            if num_blocked>=(CellState.GOOSE.value*3):
                self.dead_ends.append(pos_i)
        
        #check for extended dead-ends
        new_dead_ends = [pos for pos in self.dead_ends]
        while new_dead_ends!=[]:
            for pos in new_dead_ends:
                self.cell_states[pos]=CellState.GOOSE.value
                self.dead_ends.append(pos)
            
            new_dead_ends = []
            for pos_i,_ in enumerate(self.cell_states):
                if self.cell_states[pos_i] != CellState.EMPTY.value:
                    continue
                adjacent = self._adjacent_positions(pos_i)
                adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
                num_blocked = sum(adjacent_states)
                if num_blocked>=(CellState.GOOSE.value*3):
                    new_dead_ends.append(pos_i)                                    
        
                
    def strategy_random(self, observation, configuration):
        if self.previous_action!=None:
            action = rand.choice([action for action in Action if action!=opposite(self.previous_action)])
        else:
            action = rand.choice([action for action in Action])
        self.previous_action = action
        return action.name
                        
                        
    def safe_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.adjacent_to_heads) and (future_position not in self.dead_ends)
    
    
    def valid_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.dead_ends)    

    
    def free_position(self, future_position):
        return (future_position not in self.occupied) 
    
                        
    def strategy_random_avoid_collision(self, observation, configuration):
        dead_end_cell = False
        free_cell = True
        actions = [action 
                   for action in Action 
                   for future_position in [self._translate(self.my_head, action)]
                   if self.valid_position(future_position)] 
        if self.previous_action!=None:
            actions = [action for action in actions if action!=opposite(self.previous_action)] 
        if actions==[]:
            dead_end_cell = True
            actions = [action 
                       for action in Action 
                       for future_position in [self._translate(self.my_head, action)]
                       if self.free_position(future_position)]
            if self.previous_action!=None:
                actions = [action for action in actions if action!=opposite(self.previous_action)] 
            #no alternatives
            if actions==[]:
                free_cell = False
                actions = self.actions if self.previous_action==None else [action for action in self.actions if action!=opposite(self.previous_action)] 

        action = rand.choice(actions)
        self.previous_action = action
        if self.DEBUG:
            aux_pos = self._row_col(self._translate(self.my_head, self.previous_action))
            dead_ends = "" if not dead_end_cell else f', dead_ends={[self._row_col(p1) for p1 in self.dead_ends]}, occupied={[self._row_col(p2) for p2 in self.occupied]}'
            if free_cell:
                print(f'{id(self)}({self.step}): Random_ac_move {action.name} to {aux_pos} dead_end={dead_end_cell}{dead_ends}', flush=True)
            else:
                print(f'{id(self)}({self.step}): Random_ac_move {action.name} to {aux_pos} free_cell={free_cell}', flush=True)
        return action.name
    
    
    def strategy_greedy_avoid_risk(self, observation, configuration):        
        actions = {  
            action: self._min_distance_to_food(future_position)
            for action in Action 
            for future_position in [self._translate(self.my_head, action)]
            if self.safe_position(future_position)
        }
  
        if self.previous_action!=None:
            actions.pop(opposite(self.previous_action), None)
        if actions:
            action = min(actions.items(), key=lambda x: x[1])[0]
            self.previous_action = action
            if self.DEBUG:
                aux_pos = self._row_col(self._translate(self.my_head, self.previous_action))
                print(f'{id(self)}({self.step}): Greedy_ar_move {action.name} to {aux_pos}', flush=True)
            return action.name
        else:
            return self.strategy_random_avoid_collision(observation, configuration)
    
    
    #Redefine this method
    def agent_strategy(self, observation, configuration):
        action = self.strategy_greedy_avoid_risk(observation, configuration)
        return action
    
    
    def agent_do(self, observation, configuration):
        self.preprocess_env(observation, configuration)
        move = self.agent_strategy(observation, configuration)
        self.step += 1
        #if self.DEBUG:
        #    aux_pos = self._translate(self.my_head, self.previous_action), self._row_col(self._translate(self.my_head, self.previous_action))
        #    print(f'{id(self)}({self.step}): Move {move} to {aux_pos} internal_vars->{vars(self)}', flush=True)
        return move

    
    
def agent_singleton(observation, configuration):
    global gus    
    
    try:
        gus
    except NameError:
        gus = BornToNotMedalv2()
            
    action = gus.agent_do(observation, configuration)

    
    return action

## Creating our custom environment for stable-baselines3

Create our custom training environment, connecting Kaggle's hungry-geese env to gym.Env.

Main processes are done by:
* **_preprocess_env()**: basic preprocessing of the observation (copied from the base template class)
* **process_env_obs()**: processes the raw environment observation into a matrix-like spaces.Box; each cell holds a float value in [-1,1]
    * food -> 1
    * empty -> 0.4
    * tails -> -.5
    * bodies -> -.8
    * heads -> -1
    * my_head -> 0; the observation is always centered on it
    * if the agent's length is 1, its previous position is marked as BODY*.5 (-.4); otherwise, without previous observations, learning to avoid opposite moves at length 1 is confusing...
    * cells adjacent to opponent heads -> -0.5 (RISK), dead ends -> -0.25 (RISK/2); only empty or food cells are overwritten
    
* **common_sense_rewards()**: common-sense rewards and penalties (avoid collisions, seek food...)
* **step()**: environment step
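As an illustration of the encoding above, here is a minimal NumPy sketch that builds the centered grid from flat cell indices. It is a simplified re-implementation, not the notebook's `process_env_obs()`: `encode_board` is a hypothetical helper, it assumes the default 7×11 board, and it skips the dead-end and previous-position marks.

```python
import numpy as np

ROWS, COLUMNS = 7, 11  # default hungry-geese board (odd sizes, so there is a center cell)
EMPTY, FOOD, MY_HEAD = 0.4, 1.0, 0.0
HEAD, BODY, TAIL, RISK = -1.0, -0.8, -0.5, -0.5


def encode_board(my_goose, other_geese, food, rows=ROWS, columns=COLUMNS):
    """Return a (1, rows, columns) float32 grid centered on my goose's head.

    Geese are lists of flat cell indices (head first); food is a list of indices.
    """
    grid = np.full((rows, columns), EMPTY, dtype=np.float32)

    def mark(pos, value):
        r, c = divmod(pos, columns)  # same convention as kaggle's row_col()
        grid[r, c] = value

    others = [g for g in other_geese if g]  # dead geese have empty lists
    for goose in others:
        mark(goose[0], HEAD)
        for pos in goose[1:-1]:
            mark(pos, BODY)
        if len(goose) > 1:
            mark(goose[-1], TAIL)

    mark(my_goose[0], MY_HEAD)
    for pos in my_goose[1:-1]:
        mark(pos, BODY)
    if len(my_goose) > 1:
        mark(my_goose[-1], TAIL)

    for pos in food:
        mark(pos, FOOD)

    # Cells next to an opponent head are risky; only empty or food cells are overwritten.
    for goose in others:
        hr, hc = divmod(goose[0], columns)
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            r, c = (hr + dr) % rows, (hc + dc) % columns  # the board wraps around
            if grid[r, c] > 0:
                grid[r, c] = RISK

    # Roll the grid so that my head sits on the center cell.
    head_r, head_c = divmod(my_goose[0], columns)
    grid = np.roll(grid, columns // 2 - head_c, axis=1)
    grid = np.roll(grid, rows // 2 - head_r, axis=0)
    return grid[np.newaxis]
```

Centering with `np.roll` works because the board wraps around: shifting every row and column by the same offset keeps neighbors adjacent, and with the odd 7×11 board the head always lands on cell (3, 5).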

You can edit the number of previous observations (`num_previous_observations`). If you set it to zero, remember to set the move-to-eat common-sense reward to 1, because the length can't be compared with the previous step to detect growth.
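The history handling in `step()` and `reset()` keeps a list of previous grids and concatenates them with the current one along the first axis. A minimal sketch of the same idea with a `collections.deque` (`FrameStack` is a hypothetical helper written for this example) shows why the observation space shape must be `(num_previous_observations + 1, rows, columns)`:

```python
from collections import deque

import numpy as np


class FrameStack:
    """Keep the last n observations and stack them with the current one on axis 0."""

    def __init__(self, num_previous):
        self.num_previous = num_previous
        self.previous = deque(maxlen=num_previous)  # oldest frame is dropped automatically
        self.current = None

    def reset(self, first_obs):
        # At t=0 there is no history yet, so repeat the first observation.
        self.previous.clear()
        self.previous.extend([first_obs] * self.num_previous)
        self.current = first_obs
        return self.stacked()

    def push(self, obs):
        if self.num_previous > 0:
            self.previous.append(self.current)
        self.current = obs
        return self.stacked()

    def stacked(self):
        # Oldest frame first, current frame last: shape (num_previous + 1, rows, columns).
        return np.concatenate((*self.previous, self.current), axis=0)
```

Seeding the history with copies of the first observation keeps the stacked shape constant from the very first step.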

In [None]:
import gym
from gym import spaces

from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, adjacent_positions, row_col, translate, min_distance
from kaggle_environments import make

from enum import Enum, auto
import numpy as np


class CellState(Enum):
    EMPTY = 0
    FOOD = auto()
    HEAD = auto()
    BODY = auto()
    TAIL = auto()
    MY_HEAD = auto()
    MY_BODY = auto()
    MY_TAIL = auto()
    ANY_GOOSE = auto()
    

class ObservationProcessor:
    
    def __init__(self, rows, columns, hunger_rate, min_food, debug=False, center_head=True):
        self.debug = debug
        self.rows, self.columns = rows, columns
        self.hunger_rate = hunger_rate
        self.min_food = min_food
        self.previous_action = -1
        self.last_action = -1
        self.last_min_distance_to_food = self.rows*self.columns #initial max value to mark no food seen so far
        self.center_head = center_head

    #***** BEGIN: utility functions ******   
    
    def opposite(self, action):
        if action == Action.NORTH:
            return Action.SOUTH
        if action == Action.SOUTH:
            return Action.NORTH
        if action == Action.EAST:
            return Action.WEST
        if action == Action.WEST:
            return Action.EAST
        raise TypeError(str(action) + " is not a valid Action.")
        
        
    def _adjacent_positions(self, position):
        return adjacent_positions(position, self.columns, self.rows)

    
    def _min_distance_to_food(self, position, food=None):
        food = food if food is not None else self.food
        return min_distance(position, food, self.columns)
    
    
    def _row_col(self, position):
        return row_col(position, self.columns)

    
    def _translate(self, position, direction):
        return translate(position, direction, self.columns, self.rows)     

    
    def _preprocess_env(self, obs):
        observation = Observation(obs)
        
        self.my_index = observation.index

        if len (observation.geese[self.my_index])>0:
            self.my_head = observation.geese[self.my_index][0]
            self.my_tail = observation.geese[self.my_index][-1]        
            self.my_body = [pos for pos in observation.geese[self.my_index][1:-1]]
        else:
            self.my_head = -1
            self.my_tail = -1
            self.my_body = []

        
        self.geese = [g for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 0]
        self.geese_cells = [pos for g in self.geese for pos in g if len(g) > 0]
        
        self.occupied = [p for p in self.geese_cells]
        self.occupied.extend([p for p in observation.geese[self.my_index]])
        
        
        self.heads = [g[0] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 0]
        self.bodies = [pos  for i,g in enumerate(observation.geese) for pos in g[1:-1] if i!=self.my_index and len(g) > 2]
        self.tails = [g[-1] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 1]
        self.food = [f for f in observation.food]
        
        self.adjacent_to_heads = [pos for head in self.heads for pos in self._adjacent_positions(head)]
        self.adjacent_to_bodies = [pos for body in self.bodies for pos in self._adjacent_positions(body)]
        self.adjacent_to_tails = [pos for tail in self.tails for pos in self._adjacent_positions(tail)]
        self.adjacent_to_geese = self.adjacent_to_heads + self.adjacent_to_bodies
        self.danger_zone = self.adjacent_to_geese
        
        #Cell occupation
        self.cell_states = [CellState.EMPTY.value for _ in range(self.rows*self.columns)]
        for g in self.geese:
            for pos in g:
                self.cell_states[pos] = CellState.ANY_GOOSE.value
        for pos in self.heads:
            self.cell_states[pos] = CellState.ANY_GOOSE.value
        for pos in self.my_body:
            self.cell_states[pos] = CellState.ANY_GOOSE.value
        self.cell_states[self.my_tail] = CellState.ANY_GOOSE.value
                
        #detect dead-ends
        self.dead_ends = []
        for pos_i,_ in enumerate(self.cell_states):
            if self.cell_states[pos_i] != CellState.EMPTY.value:
                continue
            adjacent = self._adjacent_positions(pos_i)
            adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
            num_blocked = sum(adjacent_states)
            if num_blocked>=(CellState.ANY_GOOSE.value*3):
                self.dead_ends.append(pos_i)
        
        #check for extended dead-ends
        new_dead_ends = [pos for pos in self.dead_ends]
        while new_dead_ends!=[]:
            for pos in new_dead_ends:
                self.cell_states[pos]=CellState.ANY_GOOSE.value
                self.dead_ends.append(pos)
            
            new_dead_ends = []
            for pos_i,_ in enumerate(self.cell_states):
                if self.cell_states[pos_i] != CellState.EMPTY.value:
                    continue
                adjacent = self._adjacent_positions(pos_i)
                adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
                num_blocked = sum(adjacent_states)
                if num_blocked>=(CellState.ANY_GOOSE.value*3):
                    new_dead_ends.append(pos_i)    
                    
                        
    def safe_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.adjacent_to_heads) and (future_position not in self.dead_ends)
    
    
    def valid_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.dead_ends)    

    
    def free_position(self, future_position):
        return (future_position not in self.occupied)  
    
    #***** END: utility functions ******
    
    
    def process_env_obs(self, obs):
        self._preprocess_env(obs)
        
        EMPTY = .4
        HEAD = -1
        BODY = MY_BODY = -.8
        TAIL = MY_TAIL = -.5
        MY_HEAD = 0
        FOOD = 1
        RISK = -.5
        
        #Example: {'remainingOverageTime': 12, 'step': 0, 'geese': [[62], [50]], 'food': [7, 71], 'index': 0}
        #observation = [[CellState.EMPTY.value for _ in range(self.columns)] for _ in range(self.rows)]
        observation = [[EMPTY for _ in range(self.columns)] for _ in range(self.rows)]
        
        #Other agents
        for pos in self.heads:
            r, c = self._row_col(pos)
            observation[r][c] = HEAD #CellState.HEAD.value
        for pos in self.bodies:
            r, c = self._row_col(pos)
            observation[r][c] = BODY #CellState.BODY.value
        for pos in self.tails:
            r, c = self._row_col(pos)
            observation[r][c] = TAIL #CellState.TAIL.value

        #Me
        r, c = self._row_col(self.my_head)
        observation[r][c] = MY_HEAD #-1 #CellState.MY_HEAD.value
        if self.my_head != self.my_tail:
            r, c = self._row_col(self.my_tail)
            observation[r][c] = MY_TAIL #CellState.MY_TAIL.value
        for pos in self.my_body:
            r, c = self._row_col(pos)
            observation[r][c] = MY_BODY #CellState.MY_BODY.value
            
        #Food
        for pos in self.food:
            r, c = self._row_col(pos)
            observation[r][c] = FOOD #CellState.FOOD.value
        
        
        if (self.previous_action!=-1):
            aux_previous_pos = self._translate(self.my_head, self.opposite(self.previous_action))
            r, c = self._row_col(aux_previous_pos)
            if observation[r][c]>0:
                observation[r][c] = MY_BODY * .5 #Marked to avoid opposite moves
        
        #Add risk mark
        for pos in self.adjacent_to_heads:
            r, c = self._row_col(pos)
            if observation[r][c] > 0:
                    observation[r][c] = RISK

        #Add risk mark
        for pos in self.dead_ends:
            r, c = self._row_col(pos)
            if observation[r][c] > 0:
                    observation[r][c] = RISK/2
        
        
        if self.center_head:
            #NOTE: assumes odd number of rows and columns
            head_row, head_col = self._row_col(self.my_head)
            v_center = (self.columns // 2) # col 5 on 0-10 (11 columns)
            v_roll = v_center - head_col
            h_center = (self.rows // 2) # row 3 on 0-7 (7 rows)
            h_roll = h_center - head_row
            observation = np.roll(observation, v_roll, axis=1)
            observation = np.roll(observation, h_roll, axis=0)

        return np.array([observation], dtype=np.float32) #must match observation_space dtype
    
    
    def common_sense_rewards(self, action):
        if self.my_head==-1:
            if self.debug:
                print("DIED!!")
            return -2
        
        reward = 0
        future_position = self._translate(self.my_head, action)
        check_opposite = (self.previous_action!=-1)
        
        if future_position in self.occupied:
            if self.debug:
                print("Move to occupied")
            reward = -2 #this action meant death        
        elif check_opposite and (self.previous_action==self.opposite(action)): #opposite is currently a patch until Action.opposite works...
            if self.debug:
                print("Move to opposite direction, previous", self.previous_action, "vs now",action)
            reward = -2 #this action meant death
        elif (future_position in self.food) and (future_position not in self.adjacent_to_heads):
            if self.debug:
                print("Safe move to EAT!")
            reward = 2 #eating is good! 
        elif future_position in self.dead_ends:
            if self.debug:
                print("Move to dead end")
            reward = 0
        else:
            min_distance_to_food = self._min_distance_to_food(future_position)
            
            if min_distance_to_food<=self.last_min_distance_to_food:
                if self.debug:
                    print("Move to food")
                #Removed positive rewards here, eating reward will be considered via gamma (future rewards) if agent gets to food
                if future_position in self.danger_zone:
                    reward = 0 #0.1 
                else:
                    reward = 0 #0.2 
            else:
                #ignore might be moving away, but also the nearest food could have been eaten... NO PENALTY HERE!
                reward = 0 
                
            self.last_min_distance_to_food=min_distance_to_food
                
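        #The action just taken is the reference both for the next observation
        #(previous head position) and for the next opposite-move check
        self.previous_action = action
        self.last_action = action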
        return reward
    
    
#Initial template from: https://stable-baselines.readthedocs.io/en/master/guide/custom_env.html
class HungryGeeseEnv(gym.Env):
    
    def __init__(self, dummy_env=False, opponent=['greedy','greedy','greedy-goose.py'], action_offset=1, debug=False, defaults=[7,11,10,2]):
        super(HungryGeeseEnv, self).__init__()
        self.num_envs = 1
        self.num_previous_observations = 0
        self.debug=debug
        self.actions = [action for action in Action]
        self.action_offset=action_offset
        if not dummy_env:
            self.env = make("hungry_geese", debug=self.debug)
            self.rows = self.env.configuration.rows
            self.columns = self.env.configuration.columns
            self.hunger_rate = self.env.configuration.hunger_rate
            self.min_food = self.env.configuration.min_food
            self.trainer = self.env.train([None, *opponent])
        else:
            self.env = None
            self.rows = defaults[0]
            self.columns = defaults[1]
            self.hunger_rate = defaults[2]
            self.min_food = defaults[3]

        # Define action and observation space
        # They must be gym.spaces objects        
        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=-1, high=1,
                                            shape=(self.num_previous_observations+1, self.rows, self.columns), dtype=np.float32)
        self.reward_range = (-2, 102) #min/max of the shaped reward computed in step()
        self.step_num=1        
        self.observation_preprocessor = ObservationProcessor(self.rows, self.columns, self.hunger_rate, self.min_food, debug=self.debug, center_head=True)
        self.observation = []
        self.previous_observation = []
    
    
    def step(self, action):        
        action += self.action_offset
        action = Action(action)        
        cs_rewards = self.observation_preprocessor.common_sense_rewards(action)
        if self.debug:
            if cs_rewards!=0:
                print("CS reward", action.name, self.observation, cs_rewards)
            else:
                print("CS ok", action.name)
        

        obs, reward, done, _ = self.trainer.step(action.name)

        if len(self.observation)>0:
            #Not initial step, t=0
            self.previous_observation.append(self.observation)
            #Keep list constrained to max length
            if len(self.previous_observation)>self.num_previous_observations:
                del self.previous_observation[0]
            
        self.observation = self.observation_preprocessor.process_env_obs(obs)
        
        if len(self.previous_observation)==0:
            #Initial step, t=0
            self.previous_observation = [self.observation for _ in range(self.num_previous_observations)]
        
        info = {}
        #if self.debug:
        #    print(action, reward, cs_rewards, done, "\n"+"\n".join([str(o) for o in self.observation]))
        
        env_reward = reward
        if len(self.previous_observation)>0:
            unique_before, counts_before = np.unique(self.previous_observation[-1], return_counts=True)
            unique_now, counts_now = np.unique(self.observation, return_counts=True)
            before = dict(zip(unique_before, counts_before))
            now = dict(zip(unique_now, counts_now))
            count_length = lambda d: d.get(CellState.MY_HEAD.value, 0) + d.get(CellState.MY_BODY.value, 0) + d.get(CellState.MY_TAIL.value, 0)
            if count_length(now)>count_length(before):
                reward = 2 #Ate
            else:
                reward = 0 #Just moving
            if self.debug:
                print(f'{self.step_num} {count_length(now)} {count_length(before)} R {reward}')
        else:
            reward = 0 # no way to check previuos length use common sense reward on move to food instead ;-)
        if done:
            #game ended            
            if self.observation_preprocessor.my_head == -1:
                #DIED, but what final ranking?
                rank = len(self.observation_preprocessor.geese)+1
                if self.debug:
                    print("Rank on end", rank, "geese", self.observation_preprocessor.geese)
                if rank == 4:
                    reward = -2
                elif rank == 3:
                    reward = 0
                elif rank == 2:
                    reward = 0
                else:
                    reward = 100
            else:
                reward = 1 #survived the game!?
        elif reward<1:
            reward=1.1 #0 #set to 0 if staying alive is not enough
        elif reward>1:
            #ate something!!! :-)
            reward = 1
        
        if self.debug and done:
            print("DONE!", self.observation, env_reward, reward, cs_rewards)
        
        reward = cs_rewards if cs_rewards<0 else cs_rewards+reward #if cs_reward<0 use it only      
        self.step_num += 1
        
        if self.num_previous_observations>0:
            observations = np.concatenate((*self.previous_observation, self.observation), axis=0)
            return observations, reward, done, info
        else:
            return self.observation, reward, done, info

        
    def reset(self):
        self.observation_preprocessor = ObservationProcessor(self.rows, self.columns, self.hunger_rate, self.min_food, debug=self.debug, center_head=True)
        obs = self.trainer.reset()
        self.observation = self.observation_preprocessor.process_env_obs(obs)
        self.previous_observation = [self.observation for _ in range(self.num_previous_observations)]
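        if self.num_previous_observations>0:
            #Stack history + current observation, same shape as step() returns
            return np.concatenate((*self.previous_observation, self.observation), axis=0)
        return self.observation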
    
    
    def render(self,**kwargs):
        self.env.render(**kwargs)

## Using our custom environment

1. First we instantiate the environment

In [None]:
env = HungryGeeseEnv(opponent=['greedy','greedy','greedy-goose.py'], debug=False) #Train against 2 basic greedy agents and 1 improved, risk-averse greedy agent

In [None]:
"""
from stable_baselines3.common.env_checker import check_env
check_env(env)
"""

## Training DQN

From the stable_baselines3 DQN documentation:
* > class stable_baselines3.dqn.DQN(policy, env, learning_rate=0.0001, buffer_size=1000000, learning_starts=50000, batch_size=32, tau=1.0, gamma=0.99, train_freq=4, gradient_steps=1, n_episodes_rollout=-1, optimize_memory_usage=False, target_update_interval=10000, exploration_fraction=0.1, exploration_initial_eps=1.0, exploration_final_eps=0.05, max_grad_norm=10, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
    * > learning_rate (Union[float, Callable[[float], float]]) – The learning rate, it can be a function of the current progress remaining (from 1 to 0)
    * > learning_starts (int) – how many steps of the model to collect transitions for before learning starts
    * > batch_size (Optional[int]) – Minibatch size for each gradient update
    * > buffer_size (int) – size of the replay buffer
    * > gamma (float) – the discount factor
    * > target_update_interval (int) – update the target network every target_update_interval environment steps.
    * > exploration_fraction (float) – fraction of entire training period over which the exploration rate is reduced
    * > exploration_initial_eps (float) – initial value of random action probability
    * > exploration_final_eps (float) – final value of random action probability
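
The following sketch makes `gamma`, `tau` and `target_update_interval` concrete. It is a minimal plain-Python version of the two DQN ingredients those parameters control: the bootstrapped target computed with the target network, and the Polyak update that copies the online weights into that target network. The function names are hypothetical and the weights are plain lists. SB3 does the same on tensors and, as far as I know, minimises a Huber loss between Q(s, a) and this target.

```python
def td_target(reward, next_q_values, done, gamma=0.9):
    """One-step DQN target: r + gamma * max_a' Q_target(s', a'), no bootstrap at episode end."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def polyak_update(target_weights, online_weights, tau=1.0):
    """target <- (1 - tau) * target + tau * online; tau=1 is a hard copy."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_weights, online_weights)]


gamma = 0.9  # value used in this notebook's params
print(td_target(1.1, [0.5, 2.0, 1.0, 0.0], done=False, gamma=gamma))  # 1.1 + 0.9 * 2.0
print(1.0 / (1.0 - gamma))  # effective horizon of roughly 10 steps
```

The notebook configures `tau=1` and `target_update_interval=10000`, so the target network is overwritten by the online network every 10,000 environment steps. With `gamma=0.9`, rewards a few dozen steps ahead are discounted almost to nothing (0.9**30 is about 0.04).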
    
### Custom Policy

From: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html

* to change the size of the default network, pass net_arch=[l1,..,ln] in policy_kwargs
* > custom feature extractor (e.g. custom CNN when using images), you can define class that derives from BaseFeaturesExtractor and then pass it to the model when training
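
Because `net_arch` is just a list of hidden-layer sizes, it is easy to estimate how big the network gets. The sketch below assumes a flattened 7x11 board (77 inputs, i.e. `num_previous_observations=0`) and the 4 actions of hungry-geese. `mlp_param_count` is a hypothetical helper, not part of stable-baselines3.

```python
def mlp_param_count(n_inputs, net_arch, n_outputs):
    """Weights + biases of a fully connected net: n_inputs -> net_arch... -> n_outputs."""
    sizes = [n_inputs, *net_arch, n_outputs]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


custom = mlp_param_count(77, [2000, 1000, 500, 1000, 500, 100], 4)
default = mlp_param_count(77, [64, 64], 4)
# DQN keeps two copies of the network (q_net and q_net_target)
print(f"custom: {custom:,} params per network, ~{2 * custom * 4 / 1e6:.0f} MB for both in float32")
print(f"default [64, 64]: {default:,} params per network")
```

The custom architecture used below is roughly 400 times larger than the default one.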

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

import torch as th
import torch.nn as nn
import torch.nn.functional as F


model_name = "dqnv1"
m_env = Monitor(env, model_name, allow_early_resets=True) 


policy_kwargs = dict(
    net_arch = [2000, 1000, 500, 1000, 500, 100]
)

TRAIN_STEPS = int(1e6)
alpha_0 = 1e-6
alpha_end = 1e-9

def learning_rate_f(progress_remaining):
    #SB3 calls this with progress_remaining going from 1 (start) to 0 (end of training)
    #default =  1e-4
    initial = alpha_0
    final = alpha_end
    interval = initial-final
    return final+interval*progress_remaining

params ={
    'gamma': .9,
    'batch_size': 100,
     #'train_freq': 500,
    'target_update_interval': 10000,
    'learning_rate': learning_rate_f,
    'learning_starts': 1000,
    'exploration_fraction': .2,
    'exploration_initial_eps': .05,
    'tau': 1,
    'exploration_final_eps': .01,
    'buffer_size': 100000,
    #'verbose': 1,
}

#remove **params to use the default parameters
trainer = DQN('MlpPolicy', m_env, policy_kwargs=policy_kwargs, **params)

#You can check policy architecture with:
#print(trainer.policy.net_arch) #prints: [64, 64] for default DQN policy
#Or check model.policy
print(trainer.policy)
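
The exploration settings above only make sense together with `TRAIN_STEPS`. This is a minimal sketch of the linear epsilon decay, written to follow what I understand SB3's `get_linear_fn` schedule does. Epsilon goes from `exploration_initial_eps` to `exploration_final_eps` over the first `exploration_fraction` of training and then stays flat. The helper names are hypothetical.

```python
def linear_eps(progress_remaining, start=0.05, end=0.01, end_fraction=0.2):
    """Epsilon as a function of SB3-style progress_remaining (1 at start, 0 at the end)."""
    progress = 1.0 - progress_remaining
    if progress > end_fraction:
        return end
    return start + progress * (end - start) / end_fraction


def eps_at_step(step, total_steps=1_000_000, **kwargs):
    return linear_eps(1.0 - step / total_steps, **kwargs)


for step in (0, 100_000, 200_000, 500_000):
    print(step, round(eps_at_step(step), 4))
```

With these values the agent acts randomly at most 5% of the time, even at the start. That is much less exploration than the SB3 defaults (1.0 down to 0.05).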

In [None]:
trainer.learn(total_timesteps=TRAIN_STEPS, callback=None)

# Results

Let's see how the training goes

In [None]:
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

df = pd.read_csv(f'{model_name}.monitor.csv', header=1, index_col='t')

df.rename(columns = {'r':'Episode Reward', 'l':'Episode Length'}, inplace = True) 
plt.figure(figsize=(20,5))
sns.regplot(data=df, y='Episode Reward', x=df.index)

In [None]:
plt.figure(figsize=(20,5))
sns.regplot(data=df, y='Episode Length', x=df.index, color="orange")
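
`sns.regplot` fits a single straight line, which can hide how noisy the episodes are. Below is a minimal stdlib sketch of a trailing moving average over the same monitor file. It assumes the format that stable-baselines3's `Monitor` writes: a `#{...}` JSON line followed by `r,l,t` columns. The sample text is made up for illustration, and `read_monitor` and `moving_average` are hypothetical helpers.

```python
import csv


def read_monitor(text):
    """Parse Monitor CSV text into (episode rewards, episode lengths)."""
    lines = text.splitlines()
    rows = list(csv.DictReader(lines[1:]))  # skip the '#{...}' JSON header line
    return [float(row["r"]) for row in rows], [int(row["l"]) for row in rows]


def moving_average(values, window):
    """Trailing mean over the last `window` values (shorter at the beginning)."""
    out, total = [], 0.0
    for i, value in enumerate(values):
        total += value
        if i >= window:
            total -= values[i - window]
        out.append(total / min(i + 1, window))
    return out


sample = '#{"t_start": 0.0, "env_id": null}\nr,l,t\n1.1,10,0.5\n-2.0,3,0.7\n5.3,25,1.9\n'
rewards, lengths = read_monitor(sample)
print(moving_average(rewards, window=2))
# On the real file: read_monitor(open(f"{model_name}.monitor.csv").read())
```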

# Save our DQN model

Here I use a little trick to avoid writing all the key mappings for model save/load by hand ;-)

This is needed because stable-baselines3 stores the weights under different names than the ones our plain pytorch network will look for when it is loaded later, outside of stable-baselines3.

In [None]:
state_dict = trainer.policy.to('cpu').state_dict()
print("\n".join(state_dict.keys())) #use this to check keys ;-)

In [None]:
""" If yout don't do this an Error will be raised due to different naming when loading model outside stable_baselines3
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
Key names have format:
q_net.q_net.0.weight
q_net.q_net.0.bias
...
q_net_target.q_net.0.weight
q_net_target.q_net.0.bias
...
"""
    
adapted_state_dict ={
    new_key : state_dict[old_key]
    for old_key in state_dict.keys()
    for new_key in ["layer"+".".join(old_key.split(".")[-2:])] #use last 3 components of name
    if old_key.find("q_net_target.") != -1 #we only want the policy weights
}
print(adapted_state_dict.keys())
th.save(adapted_state_dict, f'{model_name}.pt')
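
The renaming works because SB3's `q_net` is an `nn.Sequential` that alternates `Linear` and `ReLU` modules. Only the `Linear` layers have parameters, so the surviving indices are 0, 2, 4, and so on. That is why `MyNN` below names its layers `layer0`, `layer2`, etc. Here is a minimal plain-Python sketch of the same mapping on made-up keys, with strings standing in for tensors. `adapt_keys` is hypothetical. Its `prefix` argument picks either the online network (`q_net.`) or its lagged target copy (`q_net_target.`).

```python
def adapt_keys(state_dict, prefix="q_net."):
    """Map '<prefix>q_net.<index>.<weight|bias>' keys to 'layer<index>.<weight|bias>'."""
    adapted = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # 'q_net_target.' does not start with 'q_net.', so the two never mix
        index, kind = key.split(".")[-2:]
        adapted[f"layer{index}.{kind}"] = value
    return adapted


fake_state_dict = {
    "q_net.q_net.0.weight": "W0", "q_net.q_net.0.bias": "b0",
    "q_net.q_net.2.weight": "W2", "q_net_target.q_net.0.weight": "T0",
}
print(adapt_keys(fake_state_dict))
print(adapt_keys(fake_state_dict, prefix="q_net_target."))
```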


# Our agent for submission and test game

Now let's put it all together. (See file for full source)

I've tried to minimize the changes needed to the policy network so that copying it into the submission code is easy. Our submission agent will use:
* CellState & ObservationProcessor: to process the observation into a grid of states used as input for our policy network
* the MyNN pytorch network, with the architecture inspected from "trainer.policy" and the adapted layer names
* my_dqn: the main loop. It chooses the action probabilistically from the network outputs with "choices()", with an epsilon chance of a fully random choice (epsilon-greedy or other selection schemes could be applied instead)
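
The core idea of `ObservationProcessor.process_env_obs` is to paint each board position onto a rows x columns grid. The grid is then rolled, since the board wraps around, so that our head always sits in the central cell. Below is a simplified pure-Python sketch of that idea. It covers only a few of the cell values used in the real code (no tails, risk or dead-end marks). `encode_centered` is a hypothetical name, and positions are converted with `divmod(pos, columns)` as in `row_col`.

```python
# Cell values taken from process_env_obs (subset)
EMPTY, HEAD, BODY, MY_HEAD, FOOD = 0.4, -1.0, -0.8, 0.0, 1.0


def encode_centered(rows, columns, my_head, heads=(), bodies=(), food=()):
    grid = [[EMPTY] * columns for _ in range(rows)]
    for positions, value in ((heads, HEAD), (bodies, BODY), ((my_head,), MY_HEAD), (food, FOOD)):
        for pos in positions:
            r, c = divmod(pos, columns)
            grid[r][c] = value
    # Same as np.roll by (rows//2 - head_row, columns//2 - head_col): out[i] = in[(i - shift) % n]
    head_row, head_col = divmod(my_head, columns)
    dr, dc = rows // 2 - head_row, columns // 2 - head_col
    return [[grid[(r - dr) % rows][(c - dc) % columns] for c in range(columns)]
            for r in range(rows)]


board = encode_centered(7, 11, my_head=0, food=[1])
print(board[3][5], board[3][6])  # head in the centre, food one cell to its right
```

Centering means the network always sees the surroundings relative to its own head. It does not have to learn the same situation separately for each of the 77 possible head positions.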

Only changes to MyNN (apart from removing dropouts):
> class MyNN(nn.Module):
>
>     def __init__(self):
>         super(MyNN, self).__init__()


As for the "main-loop", general structure is as follows:
         
> def my_dqn(observation, configuration):
>
>      [initializations]
>      ...    
>      [maintain list of last observations]
>      ...    
>      #Convert to grid encoded with CellState values
>      aux_observation = obs_prep.process_env_obs(observation)
>      #predict with aux_observation.shape = (last_observations x rows x cols)
>      ...    
>      tensor_obs = th.Tensor([aux_observation])
>      n_out = model(tensor_obs) #Example: tensor([[0.2742, 0.2653, 0.2301, 0.2303]], ...) 
>      ...    
>      [edit action weights to take into account common sense, collide is bad, etc...]
>      [choose probabilistic next move based on prediction outputs
>      with epsilon probability of fully random, always avoid opposite of last move]
>      ...   
>      return action_predicted
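
The action-selection part of that loop can be tried in isolation. This minimal sketch follows the same steps:
1. Drop the move opposite to the last one.
2. Shift the Q-values so all weights are positive.
3. Cap the weight of risky moves.
4. Sample with `random.choices`, or fully at random with probability epsilon.

`select_action` and its `caps` argument are hypothetical. The action codes assume kaggle's `Action` enum order NORTH=1, EAST=2, SOUTH=3, WEST=4.

```python
import random

# Assumption: kaggle_environments Action values NORTH=1, EAST=2, SOUTH=3, WEST=4
OPPOSITE = {1: 3, 2: 4, 3: 1, 4: 2}


def select_action(q_values, last_action=None, caps=None, epsilon=0.0, rng=random):
    """q_values[i] is the network output for action i + 1; caps maps action -> max weight."""
    actions = [1, 2, 3, 4]
    weights = list(q_values)
    if last_action is not None:
        i = actions.index(OPPOSITE[last_action])  # reversing into our own neck is never allowed
        del actions[i], weights[i]
    shift = abs(min(weights))
    weights = [shift + w + 1e-5 for w in weights]  # every weight strictly positive
    for i, action in enumerate(actions):
        if caps and action in caps:
            weights[i] = min(weights[i], caps[action])  # e.g. 1e-8 collision, 1e-2 dead end
    if rng.random() < epsilon:
        return rng.choice(actions), weights
    return rng.choices(actions, weights=weights)[0], weights


rng = random.Random(0)
print(select_action([0.2, 0.1, 0.3, -0.4], last_action=1, caps={2: 1e-8}, rng=rng))
```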

In [None]:
%%writefile main.py
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, adjacent_positions, row_col, translate, min_distance
from kaggle_environments import make

import gym
from gym import spaces

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from enum import Enum, auto
import numpy as np
import os
import random as rand


class CellState(Enum):
    EMPTY = 0
    FOOD = auto()
    HEAD = auto()
    BODY = auto()
    TAIL = auto()
    MY_HEAD = auto()
    MY_BODY = auto()
    MY_TAIL = auto()
    ANY_GOOSE = auto()
    

class ObservationProcessor:
    
    def __init__(self, rows, columns, hunger_rate, min_food, debug=False, center_head=True):
        self.debug = debug
        self.rows, self.columns = rows, columns
        self.hunger_rate = hunger_rate
        self.min_food = min_food
        self.previous_action = -1
        self.last_action = -1
        self.last_min_distance_to_food = self.rows*self.columns #initial max value to mark no food seen so far
        self.center_head = center_head

    #***** BEGIN: utility functions ******   
    
    def opposite(self, action):
        if action == Action.NORTH:
            return Action.SOUTH
        if action == Action.SOUTH:
            return Action.NORTH
        if action == Action.EAST:
            return Action.WEST
        if action == Action.WEST:
            return Action.EAST
        raise TypeError(str(action) + " is not a valid Action.")
        
        
    def _adjacent_positions(self, position):
        return adjacent_positions(position, self.columns, self.rows)

    
    def _min_distance_to_food(self, position, food=None):
        food = food if food is not None else self.food
        return min_distance(position, food, self.columns)
    
    
    def _row_col(self, position):
        return row_col(position, self.columns)

    
    def _translate(self, position, direction):
        return translate(position, direction, self.columns, self.rows)     

    
    def _preprocess_env(self, obs):
        observation = Observation(obs)
        
        self.my_index = observation.index

        if len (observation.geese[self.my_index])>0:
            self.my_head = observation.geese[self.my_index][0]
            self.my_tail = observation.geese[self.my_index][-1]        
            self.my_body = [pos for pos in observation.geese[self.my_index][1:-1]]
        else:
            self.my_head = -1
            self.my_tail = -1
            self.my_body = []

        
        self.geese = [g for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 0]
        self.geese_cells = [pos for g in self.geese for pos in g if len(g) > 0]
        
        self.occupied = [p for p in self.geese_cells]
        self.occupied.extend([p for p in observation.geese[self.my_index]])
        
        
        self.heads = [g[0] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 0]
        self.bodies = [pos  for i,g in enumerate(observation.geese) for pos in g[1:-1] if i!=self.my_index and len(g) > 2]
        self.tails = [g[-1] for i,g in enumerate(observation.geese) if i!=self.my_index and len(g) > 1]
        self.food = [f for f in observation.food]
        
        self.adjacent_to_heads = [pos for head in self.heads for pos in self._adjacent_positions(head)]
        self.adjacent_to_bodies = [pos for body in self.bodies for pos in self._adjacent_positions(body)]
        self.adjacent_to_tails = [pos for tail in self.tails for pos in self._adjacent_positions(tail)]
        self.adjacent_to_geese = self.adjacent_to_heads + self.adjacent_to_bodies
        self.danger_zone = self.adjacent_to_geese
        
        #Cell occupation
        self.cell_states = [CellState.EMPTY.value for _ in range(self.rows*self.columns)]
        for g in self.geese:
            for pos in g:
                self.cell_states[pos] = CellState.ANY_GOOSE.value
        for pos in self.heads:
            self.cell_states[pos] = CellState.ANY_GOOSE.value
        for pos in self.my_body:
            self.cell_states[pos] = CellState.ANY_GOOSE.value
        if self.my_tail != -1: #no tail to mark once we are dead
            self.cell_states[self.my_tail] = CellState.ANY_GOOSE.value
                
        #detect dead-ends
        self.dead_ends = []
        for pos_i,_ in enumerate(self.cell_states):
            if self.cell_states[pos_i] != CellState.EMPTY.value:
                continue
            adjacent = self._adjacent_positions(pos_i)
            adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
            num_blocked = sum(adjacent_states)
            if num_blocked>=(CellState.ANY_GOOSE.value*3):
                self.dead_ends.append(pos_i)
        
        #check for extended dead-ends
        new_dead_ends = [pos for pos in self.dead_ends]
        while new_dead_ends!=[]:
            for pos in new_dead_ends:
                self.cell_states[pos]=CellState.ANY_GOOSE.value
                self.dead_ends.append(pos)
            
            new_dead_ends = []
            for pos_i,_ in enumerate(self.cell_states):
                if self.cell_states[pos_i] != CellState.EMPTY.value:
                    continue
                adjacent = self._adjacent_positions(pos_i)
                adjacent_states = [self.cell_states[adj_pos] for adj_pos in adjacent if adj_pos!=self.my_head]
                num_blocked = sum(adjacent_states)
                if num_blocked>=(CellState.ANY_GOOSE.value*3):
                    new_dead_ends.append(pos_i)    
                    
                        
    def safe_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.adjacent_to_heads) and (future_position not in self.dead_ends)
    
    
    def valid_position(self, future_position):
        return (future_position not in self.occupied) and (future_position not in self.dead_ends)    

    
    def free_position(self, future_position):
        return (future_position not in self.occupied)  
    
    #***** END: utility functions ******
    
    
    def process_env_obs(self, obs):
        self._preprocess_env(obs)
        
        EMPTY = .4
        HEAD = -1
        BODY = MY_BODY = -.8
        TAIL = MY_TAIL = -.5
        MY_HEAD = 0
        FOOD = 1
        RISK = -.5
        
        #Example: {'remainingOverageTime': 12, 'step': 0, 'geese': [[62], [50]], 'food': [7, 71], 'index': 0}
        #observation = [[CellState.EMPTY.value for _ in range(self.columns)] for _ in range(self.rows)]
        observation = [[EMPTY for _ in range(self.columns)] for _ in range(self.rows)]
        
        #Other agents
        for pos in self.heads:
            r, c = self._row_col(pos)
            observation[r][c] = HEAD #CellState.HEAD.value
        for pos in self.bodies:
            r, c = self._row_col(pos)
            observation[r][c] = BODY #CellState.BODY.value
        for pos in self.tails:
            r, c = self._row_col(pos)
            observation[r][c] = TAIL #CellState.TAIL.value

        #Me
        r, c = self._row_col(self.my_head)
        observation[r][c] = MY_HEAD #-1 #CellState.MY_HEAD.value
        if self.my_head != self.my_tail:
            r, c = self._row_col(self.my_tail)
            observation[r][c] = MY_TAIL #CellState.MY_TAIL.value
        for pos in self.my_body:
            r, c = self._row_col(pos)
            observation[r][c] = MY_BODY #CellState.MY_BODY.value
            
        #Food
        for pos in self.food:
            r, c = self._row_col(pos)
            observation[r][c] = FOOD #CellState.FOOD.value
        
        
        if (self.previous_action!=-1):
            aux_previous_pos = self._translate(self.my_head, self.opposite(self.previous_action))
            r, c = self._row_col(aux_previous_pos)
            if observation[r][c]>0:
                observation[r][c] = MY_BODY * .5 #Marked to avoid opposite moves
        
        #Add risk mark
        for pos in self.adjacent_to_heads:
            r, c = self._row_col(pos)
            if observation[r][c] > 0:
                observation[r][c] = RISK

        #Add risk mark
        for pos in self.dead_ends:
            r, c = self._row_col(pos)
            if observation[r][c] > 0:
                observation[r][c] = RISK/2
        
        if self.center_head:
            #NOTE: assumes odd number of rows and columns
            head_row, head_col = self._row_col(self.my_head)
            v_center = (self.columns // 2) # col 5 on 0-10 (11 columns)
            v_roll = v_center - head_col
            h_center = (self.rows // 2) # row 3 on 0-6 (7 rows)
            h_roll = h_center - head_row
            observation = np.roll(observation, v_roll, axis=1)
            observation = np.roll(observation, h_roll, axis=0)

        return np.array([observation])
    
    
class MyNN(nn.Module):
    def __init__(self):
        super(MyNN, self).__init__()
        """use names generated on adapted saved_dict
        dict_keys(['layer0.weight', 'layer0.bias', 'layer2.weight', 'layer2.bias', ...])

        net_arch as seen before:
          (q_net): QNetwork(
            (features_extractor): FlattenExtractor(
              (flatten): Flatten(start_dim=1, end_dim=-1)
            )
            (q_net): Sequential(
              (0): Linear(...)
              (1): ReLU()
              ...
            )
          )
        """
        #net_arch = [2000, 1000, 500, 1000, 500, 100]
        self.layer0 = nn.Linear(77, 2000)
        self.layer2 = nn.Linear(2000, 1000)
        self.layer4 = nn.Linear(1000, 500)
        self.layer6 = nn.Linear(500, 1000)
        self.layer8 = nn.Linear(1000, 500)
        self.layer10 = nn.Linear(500, 100)
        self.layer12 = nn.Linear(100, 4)

    def forward(self, x):
        x = nn.Flatten()(x)  # no feature extractor means flatten (check policy arch on DQN creation)
        for layer in [self.layer0, self.layer2, self.layer4, self.layer6, self.layer8, self.layer10]:
            x = F.relu(layer(x))
        x = self.layer12(x)
        return x
            
        
def my_dqn(observation, configuration):
    global model, obs_prep, last_action, last_observation, previous_observation

    tgz_agent_path = '/kaggle_simulations/agent/'
    normal_agent_path = '/kaggle/working'
    model_name = "dqnv1"
    num_previous_observations = 0
    epsilon = 0
    init = False
    debug = False

    try:
        model
    except NameError:
        init=True
    else:
        if model is None:
            init = True
    if init:
        #initializations
        defaults = [configuration.rows,
                    configuration.columns,
                    configuration.hunger_rate,
                    configuration.min_food]

        model = MyNN()
        last_action = -1
        last_observation = []
        previous_observation = []
        
        file_name = os.path.join(normal_agent_path, f'{model_name}.pt')
        if not os.path.exists(file_name):
            file_name = os.path.join(tgz_agent_path, f'{model_name}.pt')
            
        model.load_state_dict(th.load(file_name))
        obs_prep = ObservationProcessor(configuration.rows, configuration.columns, configuration.hunger_rate, configuration.min_food)
    
    #maintaint list of  last observations
    if num_previous_observations>0 and len(last_observation)>0:
        #Not initial step, t=0
        previous_observation.append(last_observation)
        #Keep list constrained to max length
        if len(previous_observation)>num_previous_observations:
            del previous_observation[0]
            
    #Convert to grid encoded with CellState values
    aux_observation = [obs_prep.process_env_obs(observation)] 
    last_observation = aux_observation

    if num_previous_observations>0 and len(previous_observation)==0:
        #Initial step, t=0
        previous_observation = [last_observation for _ in range(num_previous_observations)]

    if num_previous_observations>0:
        aux_observation = np.concatenate((*previous_observation, last_observation), axis=0)
    else:
        aux_observation = last_observation
        
    #predict with aux_observation.shape = (last_observations x rows x cols)
    tensor_obs = th.Tensor([aux_observation])
    with th.no_grad(): #inference only, no gradients needed
        n_out = model(tensor_obs) #raw Q-values, one per action (no softmax), e.g. tensor([[0.2742, 0.2653, 0.2301, 0.2303]])
    
    #choose probabilistic next move based on prediction outputs
    #with epsilon probability of fully random, always avoid opposite of last move
    actions = [action.value for action in Action]
    weights = list(n_out[0].detach().numpy())
    if last_action!=-1:
        #Avoid dying by stupidity xD
        remove_index = actions.index(obs_prep.opposite(Action(last_action)).value)
        del actions[remove_index]
        del weights[remove_index]    
    random=False

    min_value = abs(min(weights))
    weights = [min_value+w+1e-5 for w in weights] #Total of weights must be greater than zero  

    
    #Reduce weight to penalize bad moves (collisions, etc...)
    weights_changed = False
    weights_before = [w for w in weights]
    for index, action in enumerate(actions):
        future_position = obs_prep._translate(obs_prep.my_head, Action(action))
        if not obs_prep.free_position(future_position):
            weights[index] = min(weights[index], 1e-8) #Collision is worst case
            weights_changed = True
        elif future_position in obs_prep.dead_ends:
            weights[index] = min(weights[index],1e-2) #dead ends
            weights_changed = True
        elif future_position in obs_prep.adjacent_to_heads:
            weights[index] = min(weights[index],1e-8) #adjacent to heads
            weights_changed = True
    
    
    
    if debug and weights_changed:
        print(aux_observation)
        print(f'Adapted weights: before {weights_before} and after {weights} for actions {[Action(a).name for a in actions]}')
    #elif debug and not weights_changed:
    #    print(f'Action weights {weights}')

    if rand.random() < epsilon:
        prediction = rand.choice(actions)
        random=True
    else:
        prediction = rand.choices(actions, weights=weights)[0] 
    action_predicted = Action(prediction).name
    
    #print(observation) #Uncomment to debug a bit too much...
    #if (last_action!=-1) and debug:
    #    print(last_observation)
    #    print(f'valid_actions={actions}, w={weights}, chose={Action(prediction).name}, rand={random}',
    #          f'previous={Action(last_action).name}, opposite={Action(obs_prep.opposite(Action(last_action)).value).name}') 
    
    last_action = prediction
    return action_predicted #return action
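
The dead-end detection in `_preprocess_env` is a small iterative flood. A free cell with at least 3 blocked neighbours is a dead end; our own head does not count as blocking. Each dead end is then treated as blocked, and the scan repeats until nothing new is found. Below is a standalone sketch on a wrapping board, assuming the blocked cells are given as a set of positions. `find_dead_ends` is a hypothetical helper, not the notebook's code.

```python
def find_dead_ends(blocked, rows, columns, my_head=None):
    """Return dead-end positions in the order they are discovered."""
    blocked = set(blocked)

    def neighbours(pos):
        r, c = divmod(pos, columns)
        return [((r - 1) % rows) * columns + c, ((r + 1) % rows) * columns + c,
                r * columns + (c - 1) % columns, r * columns + (c + 1) % columns]

    dead_ends = []
    while True:
        new = [pos for pos in range(rows * columns)
               if pos not in blocked
               and sum(n in blocked for n in neighbours(pos) if n != my_head) >= 3]
        if not new:
            return dead_ends
        dead_ends.extend(new)
        blocked.update(new)


# 5x5 board: a two-cell pocket (12, 13) open only to the right of 13
print(find_dead_ends({7, 8, 17, 18, 11}, rows=5, columns=5))
```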

Prepare the upload: submission.tar.gz

In [None]:
!tar cvzf submission.tar.gz main.py dqnv1.pt

Let's watch a game. So far the agent is not learning much :-(

**EDIT:** Since the game render is problematic in most browsers under certain scenarios, this code is commented out for long training runs

In [None]:
"""
import kaggle_environments
from kaggle_environments import make, evaluate, utils

env = make("hungry_geese", debug=False) #set debug to True to see agent internals each step

env.reset()
env.run(["main.py","greedy","greedy", "greedy-goose.py"])
env.render(mode="ipython", width=700, height=500)
"""

Results are far from good but... keep working!!

In [None]:
import kaggle_environments
from kaggle_environments import evaluate

evaluate(
    "hungry_geese",
    ["main.py", "greedy","greedy", "greedy-goose.py"],
    num_episodes=10
)