Thanks for a very interesting and fun competition, and for the great public contributions such as HandyRL.

I first trained with replays and imitation learning, reaching at best around 1050, but I wanted to learn more about RL, so I dropped imitation learning and switched to the HandyRL kernel https://github.com/DeNA/HandyRL

After some practice with HandyRL and a closer study of its code, it was time for the real training, roughly two to three weeks before the deadline. During this period the team behind HandyRL also published more details, including a conference presentation (https://www.kaggle.com/c/hungry-geese/discussion/247266). After reading it, I implemented the Risk Averse Greedy Goose (https://www.kaggle.com/ilialar/risk-averse-greedy-goose) as an opponent (I believe the implementation works) and also used V-trace.
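For readers new to V-trace: it comes from the IMPALA paper (Espeholt et al., 2018) and corrects value targets for the lag between the policy that generated the episodes and the policy being trained, using truncated importance ratios. Below is a minimal, generic NumPy sketch for a single trajectory, not HandyRL's own implementation; the function name and the default clipping thresholds `rho_bar = c_bar = 1.0` are assumptions for illustration.

```python
import numpy as np


def vtrace_targets(rewards, values, bootstrap_value, log_rhos,
                   gamma=1.0, rho_bar=1.0, c_bar=1.0):
    """Return V-trace value targets v_s for one trajectory of length T.

    rewards, values, log_rhos: arrays of shape (T,), where
    log_rhos[t] = log pi(a_t | x_t) - log mu(a_t | x_t)
    (pi = policy being trained, mu = policy that generated the data).
    bootstrap_value: V(x_T), the value estimate after the last step.
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    rhos = np.exp(np.asarray(log_rhos, dtype=np.float64))

    clipped_rhos = np.minimum(rho_bar, rhos)  # truncated ratios for the TD errors
    cs = np.minimum(c_bar, rhos)              # truncated trace coefficients

    # V(x_{t+1}) for every t, with the bootstrap value after the last step
    next_values = np.append(values[1:], bootstrap_value)
    deltas = clipped_rhos * (rewards + gamma * next_values - values)

    # backward recursion: v_s - V(x_s) = delta_s + gamma * c_s * (v_{s+1} - V(x_{s+1}))
    vs_minus_v = np.zeros_like(values)
    acc = 0.0
    for t in reversed(range(len(values))):
        acc = deltas[t] + gamma * cs[t] * acc
        vs_minus_v[t] = acc
    return values + vs_minus_v
```

With zero log-ratios (on-policy) and `gamma = 1`, every target reduces to the observed return, e.g. 1.0 for rewards `[0, 0, 1]`; ratios below 1 shrink the correction, while ratios above 1 are clipped to 1.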

My plan was to train two models. The first was a small model combined with AlphaZero-style search, on the idea that roughly halving the prediction time compared with the standard model would allow more search per move. The second was a large model, on the idea that more parameters would give a better RL agent. But the time pressure grew, and I dropped both ideas in favor of training a medium-sized model for longer, reasoning that a longer-trained model was the best choice.

For the longer training I needed to save the episodes and the other training state so the run could be resumed, so I wrote some extra code for that. I started with joblib's lzma compression, which reduced the files to about 50% of their original size but was much slower, so after a while I switched to plain .npy files: time pressure won over disk space.
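A minimal sketch of the two storage options, assuming each episode is a picklable Python object. To keep it dependency-free, the standard-library `lzma` module stands in for joblib's lzma compression; the function names and the example episode dicts are hypothetical. The `.npy` variant stores the episodes as a 1-D object array, which is fast to write and read but uncompressed.

```python
import lzma
import os
import pickle
import tempfile

import numpy as np


def save_episodes_npy(episodes, path):
    """Store a list of picklable episodes as a 1-D object array in a .npy file."""
    arr = np.empty(len(episodes), dtype=object)
    for i, episode in enumerate(episodes):
        arr[i] = episode  # element-wise assignment keeps the array 1-D
    np.save(path, arr, allow_pickle=True)


def load_episodes_npy(path):
    return list(np.load(path, allow_pickle=True))


def save_episodes_lzma(episodes, path):
    """Smaller on disk but slower: pickle the list and compress it with LZMA."""
    with lzma.open(path, 'wb') as f:
        pickle.dump(episodes, f)


def load_episodes_lzma(path):
    with lzma.open(path, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    episodes = [{'steps': 120, 'outcome': {0: 1.0, 1: -1.0}}]  # hypothetical episode records
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'episodes.npy')
        save_episodes_npy(episodes, path)
        print(load_episodes_npy(path) == episodes)  # True
```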

I also made some other minor changes:
-	Added mixed precision, which allowed doubling the batch size.
-	Changed Adam to AdamP https://arxiv.org/abs/2006.08217
-	Changed ReLU to hard-swish (with inplace=True). Swish is generally a good replacement for ReLU, and h-swish from https://arxiv.org/abs/1905.02244 is a faster approximation of it. Based on feedback I received, this change might not make a big difference in RL; swapping this activation is just a habit of mine.
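As a reference for what `F.hardswish` computes in `GeeseNet`, here is a NumPy sketch of h-swish next to swish, following the MobileNetV3 definition h-swish(x) = x * ReLU6(x + 3) / 6. Only NumPy is used, and the helper names are illustrative.

```python
import numpy as np


def relu6(x):
    return np.clip(x, 0.0, 6.0)


def hard_swish(x):
    """h-swish(x) = x * ReLU6(x + 3) / 6, a piecewise approximation of swish."""
    x = np.asarray(x, dtype=np.float64)
    return x * relu6(x + 3.0) / 6.0


def swish(x):
    """swish(x) = x * sigmoid(x)."""
    x = np.asarray(x, dtype=np.float64)
    return x / (1.0 + np.exp(-x))


# largest absolute difference between the two on [-6, 6]
xs = np.linspace(-6.0, 6.0, 1201)
max_gap = np.max(np.abs(hard_swish(xs) - swish(xs)))
```

h-swish is exactly 0 for x <= -3 and exactly x for x >= 3, and on [-6, 6] it stays within roughly 0.15 of swish, with the largest gap around |x| = 3. Replacing the sigmoid's exponential with a clipped linear function is the motivation for calling it the cheaper variant.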

I finally trained a model for 3310 epochs (probably not enough for the top 50), which I submitted on the final day. For the resumed part I used 200k saved episodes; once that maximum is reached, new episodes replace the oldest ones, as in the original code. A larger maximum would be preferable, but because of testing and time pressure I used a smaller one, and I also wanted to try training on only the 200k most recent episodes.
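The "keep only the newest episodes" behaviour described above can be sketched as a fixed-size first-in, first-out buffer. This is an illustration with `collections.deque`, not HandyRL's trainer code; `resume_episode_buffer` is a hypothetical helper, and 200k would be passed as `max_episodes`.

```python
from collections import deque


def resume_episode_buffer(saved_episodes, max_episodes):
    """Rebuild the episode buffer from saved episodes, keeping only the newest max_episodes.

    A deque with maxlen keeps the last max_episodes items of the iterable, and every
    append beyond the limit drops the oldest item from the left (first in, first out).
    """
    return deque(saved_episodes, maxlen=max_episodes)


buffer = resume_episode_buffer(range(10), max_episodes=4)  # keeps 6, 7, 8, 9
buffer.append(10)                                          # drops 6
```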

I have shared the changes and training in this notebook.

All credit to the original contributors and creators of HandyRL! My manual prediction is that they will win.


In [None]:

!git clone https://github.com/Dena/HandyRL.git

!cp -r HandyRL/. .

In [None]:
!pip install -r requirements.txt
!pip install catalyst
!pip install kaggle_environments
!mkdir models

In [None]:
%%writefile handyrl/envs/kaggle/hungry_geese.py
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]

# kaggle_environments licensed under Copyright 2020 Kaggle Inc. and the Apache License, Version 2.0
# (see https://github.com/Kaggle/kaggle-environments/blob/master/LICENSE for details)

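# changed from the original HandyRL file: hard-swish activations and a risk-averse greedy rule-based opponent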
# wrapper of Hungry Geese environment from kaggle

import random
import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# You need to install kaggle_environments, requests
from kaggle_environments import make

from ...environment import BaseEnvironment


class TorusConv2d(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = torch.cat([x[:,:,:,-self.edge_size[1]:], x, x[:,:,:,:self.edge_size[1]]], dim=3)
        h = torch.cat([h[:,:,-self.edge_size[0]:], h, h[:,:,:self.edge_size[0]]], dim=2)
        h = self.conv(h)
        h = self.bn(h) if self.bn is not None else h
        return h


class GeeseNet(nn.Module):
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32

        self.conv0 = TorusConv2d(17, filters, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 4, bias=False)
        self.head_v = nn.Linear(filters * 2, 1, bias=False)

    def forward(self, x, _=None):
        h = F.hardswish(self.conv0(x), inplace=True)
        for block in self.blocks:
            h = F.hardswish(h + block(h), inplace=True)
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        h_avg = h.view(h.size(0), h.size(1), -1).mean(-1)
        p = self.head_p(h_head)
        v = torch.tanh(self.head_v(torch.cat([h_head, h_avg], 1)))

        return {'policy': p, 'value': v}


class Environment(BaseEnvironment):
    ACTION = ['NORTH', 'SOUTH', 'WEST', 'EAST']
    DIRECTION = [[-1, 0], [1, 0], [0, -1], [0, 1]]
    NUM_AGENTS = 4

    def __init__(self, args={}):
        super().__init__()
        self.env = make("hungry_geese")
        self.reset()

    def reset(self, args={}):
        obs = self.env.reset(num_agents=self.NUM_AGENTS)
        self.update((obs, {}), True)

    def update(self, info, reset):
        obs, last_actions = info
        if reset:
            self.obs_list = []
        self.obs_list.append(obs)
        self.last_actions = last_actions

    def action2str(self, a, player=None):
        return self.ACTION[a]

    def str2action(self, s, player=None):
        return self.ACTION.index(s)

    def direction(self, pos_from, pos_to):
        if pos_from is None or pos_to is None:
            return None
        x, y = pos_from // 11, pos_from % 11
        for i, d in enumerate(self.DIRECTION):
            nx, ny = (x + d[0]) % 7, (y + d[1]) % 11
            if nx * 11 + ny == pos_to:
                return i
        return None

    def __str__(self):
        # output state
        obs = self.obs_list[-1][0]['observation']
        colors = ['\033[33m', '\033[34m', '\033[32m', '\033[31m']
        color_end = '\033[0m'

        def check_cell(pos):
            for i, geese in enumerate(obs['geese']):
                if pos in geese:
                    if pos == geese[0]:
                        return i, 'h'
                    if pos == geese[-1]:
                        return i, 't'
                    index = geese.index(pos)
                    pos_prev = geese[index - 1] if index > 0 else None
                    pos_next = geese[index + 1] if index < len(geese) - 1 else None
                    directions = [self.direction(pos, pos_prev), self.direction(pos, pos_next)]
                    return i, directions
            if pos in obs['food']:
                return 'f'
            return None

        def cell_string(cell):
            if cell is None:
                return '.'
            elif cell == 'f':
                return 'f'
            else:
                index, directions = cell
                if directions == 'h':
                    return colors[index] + '@' + color_end
                elif directions == 't':
                    return colors[index] + '*' + color_end
                elif max(directions) < 2:
                    return colors[index] + '|' + color_end
                elif min(directions) >= 2:
                    return colors[index] + '-' + color_end
                else:
                    return colors[index] + '+' + color_end

        cell_status = [check_cell(pos) for pos in range(7 * 11)]

        s = 'turn %d\n' % len(self.obs_list)
        for x in range(7):
            for y in range(11):
                pos = x * 11 + y
                s += cell_string(cell_status[pos])
            s += '\n'
        for i, geese in enumerate(obs['geese']):
            s += colors[i] + str(len(geese) or '-') + color_end + ' '
        return s

    def step(self, actions):
        # state transition
        obs = self.env.step([self.action2str(actions.get(p, None) or 0) for p in self.players()])
        self.update((obs, actions), False)

    def diff_info(self, _):
        return self.obs_list[-1], self.last_actions

    def turns(self):
        # players to move
        return [p for p in self.players() if self.obs_list[-1][p]['status'] == 'ACTIVE']

    def terminal(self):
        # check whether terminal state or not
        for obs in self.obs_list[-1]:
            if obs['status'] == 'ACTIVE':
                return False
        return True

    def outcome(self):
        # return terminal outcomes
        # 1st: 1.0 2nd: 0.33 3rd: -0.33 4th: -1.00
        rewards = {o['observation']['index']: o['reward'] for o in self.obs_list[-1]}
        outcomes = {p: 0 for p in self.players()}
        for p, r in rewards.items():
            for pp, rr in rewards.items():
                if p != pp:
                    if r > rr:
                        outcomes[p] += 1 / (self.NUM_AGENTS - 1)
                    elif r < rr:
                        outcomes[p] -= 1 / (self.NUM_AGENTS - 1)
        return outcomes

    def legal_actions(self, player):
        # return legal action list
        return list(range(len(self.ACTION)))

    def action_length(self):
        # maximum action label (it determines output size of policy function)
        return len(self.ACTION)

    def players(self):
        return list(range(self.NUM_AGENTS))

    # def rule_based_action(self, player):
    #     from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, GreedyAgent
    #     action_map = {'N': Action.NORTH, 'S': Action.SOUTH, 'W': Action.WEST, 'E': Action.EAST}

    #     agent = GreedyAgent(Configuration({'rows': 7, 'columns': 11}))
    #     agent.last_action = action_map[self.ACTION[self.last_actions[player]][0]] if player in self.last_actions else None
    #     obs = {**self.obs_list[-1][0]['observation'], **self.obs_list[-1][player]['observation']}
    #     action = agent(Observation(obs))
    #     return self.ACTION.index(action)

    def rule_based_action(self, player):

        import random

        from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col

        def get_nearest_cells(x,y):
            # returns all cells reachable from the current one
            result = []
            for i in (-1,+1):
                result.append(((x+i+7)%7, y))
                result.append((x, (y+i+11)%11))
            return result

        def find_closest_food(table):
            # returns the first step toward the closest food item
            new_table = table.copy()
            
            
            # (direction of the step, axis, code)
            possible_moves = [
                (1, 0, 1),
                (-1, 0, 2),
                (1, 1, 3),
                (-1, 1, 4)
            ]
            
            # shuffle possible options to add variability
            random.shuffle(possible_moves)
            
            
            updated = False
            for roll, axis, code in possible_moves:

                shifted_table = np.roll(table, roll, axis)
                
                if (table == -2).any() and (shifted_table[table == -2] == -3).any(): # we have found some food at the first step
                    return code
                else:
                    mask = np.logical_and(new_table == 0,shifted_table == -3)
                    if mask.sum() > 0:
                        updated = True
                    new_table += code * mask
                if (table == -2).any() and shifted_table[table == -2][0] > 0: # we have found some food
                    return shifted_table[table == -2][0]
                
                # else - update newly reachable cells
                mask = np.logical_and(new_table == 0,shifted_table > 0)
                if mask.sum() > 0:
                    updated = True
                new_table += shifted_table * mask

            # if we updated anything - continue the recursion
            if updated:
                return find_closest_food(new_table)
            # if not - return some step
            else:
                return table.max()

        
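        # HandyRL action indices (0 NORTH, 1 SOUTH, 2 WEST, 3 EAST) -> step codes used below
        action_to_step = {0: 2, 1: 1, 2: 4, 3: 3}

        def agent(obs_dict, config_dict):
            # recover this player's previous move so that reversing can be forbidden;
            # last_actions is {} right after reset, which gives None on the first turn
            self.last_step = action_to_step.get(self.last_actions.get(player))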
            
            observation = Observation(obs_dict)
            configuration = config_dict
            player_index = observation.index
            player_goose = observation.geese[player_index]
            player_head = player_goose[0]
            player_row, player_column = row_col(player_head, configuration.columns)


            table = np.zeros((7,11))
            # 0 - emply cells
            # -1 - obstacles
            # -4 - possible obstacles
            # -2 - food
            # -3 - head
            # 1,2,3,4 - reachable on the current step cell, number is the id of the first step direction
            
            legend = {
                1: 'SOUTH',
                2: 'NORTH',
                3: 'EAST',
                4: 'WEST'
            }
            
            # let's add food to the map
            for food in observation.food:
                x,y = row_col(food, configuration.columns)
                table[x,y] = -2 # food
                
            # let's add all cells that are forbidden
            for i in range(4):
                opp_goose = observation.geese[i]
                if len(opp_goose) == 0:
                    continue
                    
                is_close_to_food = False
                    
                if i != player_index:
                    x,y = row_col(opp_goose[0], configuration.columns)
                    possible_moves = get_nearest_cells(x,y) # head can move anywhere
                    
                    for x,y in possible_moves:
                        if table[x,y] == -2:
                            is_close_to_food = True
                    
                        table[x,y] = -4 # possibly forbidden cells
                
                # usually we ignore the last tail cell but there are exceptions
                tail_change = -1
                if obs_dict['step'] % 40 == 39:
                    tail_change -= 1
                
                # we assume that the goose will eat the food
                if is_close_to_food:
                    tail_change += 1
                if tail_change >= 0:
                    tail_change = None
                    

                for n in opp_goose[:tail_change]:
                    x,y = row_col(n, configuration.columns)
                    table[x,y] = -1 # forbidden cells
            
            # going back is forbidden according to the new rules
            x,y = row_col(player_head, configuration.columns)
            if self.last_step is not None:
                if self.last_step == 1:
                    table[(x + 6) % 7,y] = -1
                elif self.last_step == 2:
                    table[(x + 8) % 7,y] = -1
                elif self.last_step == 3:
                    table[x,(y + 10)%11] = -1
                elif self.last_step == 4:
                    table[x,(y + 12)%11] = -1
                
            # add head position
            table[x,y] = -3
            
            # the first step toward the nearest food
            step = int(find_closest_food(table))
            
            # if there is not available steps try to go to possibly dangerous cell
            if step not in [1,2,3,4]:
                x,y = row_col(player_head, configuration.columns)
                if table[(x + 8) % 7,y] == -4:
                    step = 1
                elif table[(x + 6) % 7,y] == -4:
                    step = 2
                elif table[x,(y + 12)%11] == -4:
                    step = 3
                elif table[x,(y + 10)%11] == -4:
                    step = 4
                        
                # otherwise take a random step (and most likely lose)
                else:
                    step = np.random.randint(4) + 1
            
            self.last_step = step
            return legend[step]
        obs = {**self.obs_list[-1][0]['observation'], **self.obs_list[-1][player]['observation']}
        action = agent(obs, Configuration({'rows': 7, 'columns': 11}))
        return self.ACTION.index(action)


    def rule_based_action2(self, player):

        from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col
        import numpy as np
        import random
        import copy

        self.frame = 0
        self.opposites = {Action.EAST: Action.WEST, Action.WEST: Action.EAST, Action.NORTH: Action.SOUTH, Action.SOUTH: Action.NORTH}
        self.action_meanings = {Action.EAST: (1, 0), Action.WEST: (-1, 0), Action.NORTH: (0, -1), Action.SOUTH: (0, 1)}
        self.action_names = {(1, 0): Action.EAST, (-10, 0): Action.EAST, (-1, 0): Action.WEST, (10, 0): Action.WEST, (0, -1): Action.NORTH, (0, 6): Action.NORTH, (0, -6): Action.SOUTH, (0, 1): Action.SOUTH}
        strValue = {Action.EAST: 'EAST', Action.WEST: 'WEST', Action.NORTH: 'NORTH', Action.SOUTH: 'SOUTH'}
        self.all_last_actions = [None, None, None, None]
        self.revert_last_actions = [None, None, None, None]
        self.last_observation = None

        class Obs:
            pass

        def setLastActions(observation, configuration):
            #global frame, revert_last_actions, all_last_actions
            if not self.frame == 0:
                for i in range(4):
                    setLastAction(observation, configuration, i)
            self.revert_last_actions = copy.deepcopy(self.all_last_actions)


        def revertLastActions():
            #global revert_last_actions, all_last_actions
            self.all_last_actions = copy.deepcopy(self.revert_last_actions)


        def setLastAction(observation, configuration, gooseIndex):
            #global last_observation, all_last_actions, action_names
            if len(observation.geese[gooseIndex]) > 0:
                oldGooseRow, oldGooseCol = row_col(self.last_observation.geese[gooseIndex][0], configuration.columns)
                newGooseRow, newGooseCol = row_col(observation.geese[gooseIndex][0], configuration.columns)
                self.all_last_actions[gooseIndex] = self.action_names[
                    ((newGooseCol - oldGooseCol) % configuration.columns, (newGooseRow - oldGooseRow) % configuration.rows)]


        def getValidDirections(observation, configuration, gooseIndex):
            #global all_last_actions, opposites
            directions = [Action.EAST, Action.WEST, Action.NORTH, Action.SOUTH]
            returnDirections = []
            for direction in directions:
                row, col = getRowColForAction(observation, configuration, gooseIndex, direction)
                if not willGooseBeThere(observation, configuration, row, col) and not self.all_last_actions[gooseIndex] == self.opposites[
                    direction]:
                    returnDirections.append(direction)        
            if len(returnDirections) == 0:
                return directions
            return returnDirections


        def randomTurn(observation, configuration, actionOverrides, rewards, fr):
            newObservation = cloneObservation(observation)
            for i in range(4):
                if len(observation.geese[i]) > 0:
                    if i in actionOverrides.keys():
                        newObservation = performActionForGoose(observation, configuration, i, newObservation, actionOverrides[i])
                    else:
                        newObservation = randomActionForGoose(observation, configuration, i, newObservation)

            checkForCollisions(newObservation, configuration)
            updateRewards(newObservation, configuration, rewards, fr)
            hunger(newObservation, fr)
            return newObservation


        def hunger(observation, fr):
            # every 40 frames each goose loses its last segment
            if fr % 40 == 0:
                for g, goose in enumerate(observation.geese):
                    observation.geese[g] = goose[:-1]
                    


        def updateRewards(observation, configuration, rewards, fr):
            for g, goose in enumerate(observation.geese):
                if len(goose) > 0:
                    rewards[g] = 2 * fr + len(goose)

        def checkForCollisions(observation, configuration):
            killed = []
            for g, goose in enumerate(observation.geese):
                if len(goose) > 0:
                    for o, otherGoose in enumerate(observation.geese):
                        for p, part in enumerate(otherGoose):
                            if not (o == g and p == 0):
                                if goose[0] == part:
                                    killed.append(g)

            for kill in killed:
                observation.geese[kill] = []


        def cloneObservation(observation):
            newObservation = Obs()
            newObservation.index = observation.index
            newObservation.geese = copy.deepcopy(observation.geese)
            newObservation.food = copy.deepcopy(observation.food)
            return newObservation


        def randomActionForGoose(observation, configuration, gooseIndex, newObservation):
            validActions = getValidDirections(observation, configuration, gooseIndex)
            action = random.choice(validActions)
            row, col = getRowColForAction(observation, configuration, gooseIndex, action)
            newObservation.geese[gooseIndex] = [row * configuration.columns + col] + newObservation.geese[gooseIndex]
            if not isFoodThere(observation, configuration, row, col):
                newObservation.geese[gooseIndex] = newObservation.geese[gooseIndex][0:len(newObservation.geese[gooseIndex])-1]  
            return newObservation

        def performActionForGoose(observation, configuration, gooseIndex, newObservation, action):
            row, col = getRowColForAction(observation, configuration, gooseIndex, action)
            newObservation.geese[gooseIndex][:0] = [row * configuration.columns + col]
            if not isFoodThere(observation, configuration, row, col):
                newObservation.geese[gooseIndex] = newObservation.geese[gooseIndex][0:len(newObservation.geese[gooseIndex])-1]  
            return newObservation
                

        def isFoodThere(observation, configuration, row, col):
            for food in observation.food:
                foodRow, foodCol = row_col(food, configuration.columns)
                if foodRow == row and foodCol == col:
                    return True
            return False


        def willGooseBeThere(observation, configuration, row, col):
            for goose in observation.geese:
                for p, part in enumerate(goose):
                    if not p == len(goose) - 1:
                        partRow, partCol = row_col(part, configuration.columns)
                        if partRow == row and partCol == col:
                            return True
            return False


        def getRowColForAction(observation, configuration, gooseIndex, action):
            #global action_meanings
            gooseRow, gooseCol = row_col(observation.geese[gooseIndex][0], configuration.columns)
            actionRow = (gooseRow + self.action_meanings[action][1]) % configuration.rows
            actionCol = (gooseCol + self.action_meanings[action][0]) % configuration.columns
            return actionRow, actionCol


        def simulateMatch(observation, configuration, firstMove, depth):
            #global frame
            actionOverrides = {observation.index: firstMove}
            revertLastActions()
            simulationFrame = self.frame + 1
            newObservation = cloneObservation(observation)
            rewards = [0, 0, 0, 0]
            count = 0
            while count < depth:
                newObservation = randomTurn(newObservation, configuration, actionOverrides, rewards, simulationFrame)
                actionOverrides = {}
                simulationFrame += 1
                count += 1
            return rewards


        def simulateMatches(observation, configuration, numMatches, depth):
            options = getValidDirections(observation, configuration, observation.index)
            rewardTotals = []
            for o, option in enumerate(options):
                rewardsForOption = [0, 0, 0, 0]
                for i in range(numMatches):
                    matchRewards = simulateMatch(observation, configuration, option, depth)
                    for j in range(4):
                        rewardsForOption[j] += matchRewards[j]
                rewardTotals.append(rewardsForOption)
            scores = []
            for o, option in enumerate(options):
                rewards = rewardTotals[o]
                if len(rewards) <= 0:
                    mean = 0
                else:
                    mean = sum(rewards) / len(rewards)
                if mean == 0:
                    scores.append(0)
                else:
                    scores.append(rewards[observation.index] / mean)
            
            # print('frame: ', frame)
            # print('options: ', options)
            # print('scores: ', scores)
            # print('reward totals: ', rewardTotals)
            # print('lengths: ')
            # print('0: ', len(observation.geese[0]))
            # print('1: ', len(observation.geese[1]))
            # print('2: ', len(observation.geese[2]))
            # print('3: ', len(observation.geese[3]))

            return options[scores.index(max(scores))]



        def agent(obs_dict, config_dict):
            #global last_observation, all_last_actions, opposites, frame
            observation = Observation(obs_dict)
            configuration = config_dict
            setLastActions(observation, configuration)
            myLength = len(observation.geese[observation.index])
            if myLength < 5:
                my_action = simulateMatches(observation, configuration, 300, 3)
            elif myLength < 9:
                my_action = simulateMatches(observation ,configuration, 120, 6)
            else:
                my_action = simulateMatches(observation, configuration, 85, 9)
            
            self.last_observation = cloneObservation(observation)
            self.frame += 1
            return strValue[my_action]

        obs = {**self.obs_list[-1][0]['observation'], **self.obs_list[-1][player]['observation']}
        action = agent(obs, Configuration({'rows': 7, 'columns': 11}))
        return self.ACTION.index(action)    


    def net(self):
        return GeeseNet

    def observation(self, player=None):
        if player is None:
            player = 0

        b = np.zeros((self.NUM_AGENTS * 4 + 1, 7 * 11), dtype=np.float32)
        obs = self.obs_list[-1][0]['observation']

        for p, geese in enumerate(obs['geese']):
            # head position
            for pos in geese[:1]:
                b[0 + (p - player) % self.NUM_AGENTS, pos] = 1
            # tip position
            for pos in geese[-1:]:
                b[4 + (p - player) % self.NUM_AGENTS, pos] = 1
            # whole position
            for pos in geese:
                b[8 + (p - player) % self.NUM_AGENTS, pos] = 1

        # previous head position
        if len(self.obs_list) > 1:
            obs_prev = self.obs_list[-2][0]['observation']
            for p, geese in enumerate(obs_prev['geese']):
                for pos in geese[:1]:
                    b[12 + (p - player) % self.NUM_AGENTS, pos] = 1

        # food
        for pos in obs['food']:
            b[16, pos] = 1

        return b.reshape(-1, 7, 11)


if __name__ == '__main__':
    e = Environment()
    for _ in range(100):
        e.reset()
        while not e.terminal():
            print(e)
            actions = {p: e.legal_actions(p) for p in e.turns()}
            print([[e.action2str(a, p) for a in alist] for p, alist in actions.items()])
            e.step({p: random.choice(alist) for p, alist in actions.items()})
        print(e)
        print(e.outcome())


In [None]:
%%writefile handyrl/evaluation.py
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]

# evaluation of policies or planning algorithms

import random
import time
import multiprocessing as mp

from .environment import prepare_env, make_env
from .connection import send_recv, accept_socket_connections, connect_socket_connection
from .agent import RandomAgent, RuleBasedAgent, Agent, EnsembleAgent, SoftAgent


network_match_port = 9876


def view(env, player=None):
    if hasattr(env, 'view'):
        env.view(player=player)
    else:
        print(env)


def view_transition(env):
    if hasattr(env, 'view_transition'):
        env.view_transition()
    else:
        pass


class NetworkAgentClient:
    def __init__(self, agent, env, conn):
        self.conn = conn
        self.agent = agent
        self.env = env

    def run(self):
        while True:
            command, args = self.conn.recv()
            if command == 'quit':
                break
            elif command == 'outcome':
                print('outcome = %f' % args[0])
            elif hasattr(self.agent, command):
                if command == 'action' or command == 'observe':
                    view(self.env)
                ret = getattr(self.agent, command)(self.env, *args, show=True)
                if command == 'action':
                    player = args[0]
                    ret = self.env.action2str(ret, player)
            else:
                ret = getattr(self.env, command)(*args)
                if command == 'update':
                    reset = args[1]
                    if reset:
                        self.agent.reset(self.env, show=True)
                    view_transition(self.env)
            self.conn.send(ret)


class NetworkAgent:
    def __init__(self, conn):
        self.conn = conn

    def update(self, data, reset):
        return send_recv(self.conn, ('update', [data, reset]))

    def outcome(self, outcome):
        return send_recv(self.conn, ('outcome', [outcome]))

    def action(self, player):
        return send_recv(self.conn, ('action', [player]))

    def observe(self, player):
        return send_recv(self.conn, ('observe', [player]))


def exec_match(env, agents, critic, show=False, game_args={}):
    ''' match with shared game environment '''
    if env.reset(game_args):
        return None
    for agent in agents.values():
        agent.reset(env, show=show)
    while not env.terminal():
        if show:
            view(env)
        if show and critic is not None:
            print('cv = ', critic.observe(env, None, show=False)[0])
        turn_players = env.turns()
        actions = {}
        for p, agent in agents.items():
            if p in turn_players:
                actions[p] = agent.action(env, p, show=show)
            else:
                agent.observe(env, p, show=show)
        if env.step(actions):
            return None
        if show:
            view_transition(env)
    outcome = env.outcome()
    if show:
        print('final outcome = %s' % outcome)
    return outcome


def exec_network_match(env, network_agents, critic, show=False, game_args={}):
    ''' match with divided game environment '''
    if env.reset(game_args):
        return None
    for p, agent in network_agents.items():
        info = env.diff_info(p)
        agent.update(info, True)
    while not env.terminal():
        if show:
            view(env)
        if show and critic is not None:
            print('cv = ', critic.observe(env, None, show=False)[0])
        turn_players = env.turns()
        actions = {}
        for p, agent in network_agents.items():
            if p in turn_players:
                action = agent.action(p)
                actions[p] = env.str2action(action, p)
            else:
                agent.observe(p)
        if env.step(actions):
            return None
        for p, agent in network_agents.items():
            info = env.diff_info(p)
            agent.update(info, False)
    outcome = env.outcome()
    for p, agent in network_agents.items():
        agent.outcome(outcome[p])
    return outcome


def build_agent(raw, env):
    if raw == 'random':
        return RandomAgent()
    elif raw == 'rulebase':
        return RuleBasedAgent()
    return None


class Evaluator:
    def __init__(self, env, args):
        self.env = env
        self.args = args
        self.default_opponent = 'rulebase'

    def execute(self, models, args):
        opponents = self.args.get('eval', {}).get('opponent', [])
        if len(opponents) == 0:
            opponent = self.default_opponent
        else:
            #opponent = random.choice(opponents)
            opponent = self.default_opponent

        agents = {}
        for p, model in models.items():
            if model is None:
                agents[p] = build_agent(opponent, self.env)
            else:
                agents[p] = Agent(model, self.args['observation'])

        outcome = exec_match(self.env, agents, None)
        if outcome is None:
            print('None episode in evaluation!')
            return None
        return {'args': args, 'result': outcome, 'opponent': opponent}


def wp_func(results):
    games = sum([v for k, v in results.items() if k is not None])
    win = sum([(k + 1) / 2 * v for k, v in results.items() if k is not None])
    if games == 0:
        return 0.0
    return win / games


def eval_process_mp_child(agents, critic, env_args, index, in_queue, out_queue, seed, show=False):
    random.seed(seed + index)
    env = make_env({**env_args, 'id': index})
    while True:
        args = in_queue.get()
        if args is None:
            break
        g, agent_ids, pat_idx, game_args = args
        print('*** Game %d ***' % g)
        agent_map = {env.players()[p]: agents[ai] for p, ai in enumerate(agent_ids)}
        if isinstance(list(agent_map.values())[0], NetworkAgent):
            outcome = exec_network_match(env, agent_map, critic, show=show, game_args=game_args)
        else:
            outcome = exec_match(env, agent_map, critic, show=show, game_args=game_args)
        out_queue.put((pat_idx, agent_ids, outcome))
    out_queue.put(None)


def evaluate_mp(env, agents, critic, env_args, args_patterns, num_process, num_games, seed):
    in_queue, out_queue = mp.Queue(), mp.Queue()
    args_cnt = 0
    total_results, result_map = [{} for _ in agents], [{} for _ in agents]
    print('total games = %d' % (len(args_patterns) * num_games))
    time.sleep(0.1)
    for pat_idx, args in args_patterns.items():
        for i in range(num_games):
            if len(agents) == 2:
                # When playing two player game,
                # the number of games with first or second player is equalized.
                first_agent = 0 if i < (num_games + 1) // 2 else 1
                tmp_pat_idx, agent_ids = (pat_idx + '-F', [0, 1]) if first_agent == 0 else (pat_idx + '-S', [1, 0])
            else:
                tmp_pat_idx, agent_ids = pat_idx, random.sample(list(range(len(agents))), len(agents))
            in_queue.put((args_cnt, agent_ids, tmp_pat_idx, args))
            for p in range(len(agents)):
                result_map[p][tmp_pat_idx] = {}
            args_cnt += 1

    network_mode = agents[0] is None
    if network_mode:  # network battle mode
        agents = network_match_acception(num_process, env_args, len(agents), network_match_port)
    else:
        agents = [agents] * num_process

    for i in range(num_process):
        in_queue.put(None)
        args = agents[i], critic, env_args, i, in_queue, out_queue, seed
        if num_process > 1:
            mp.Process(target=eval_process_mp_child, args=args).start()
            if network_mode:
                for agent in agents[i]:
                    agent.conn.close()
        else:
            eval_process_mp_child(*args, show=True)

    finished_cnt = 0
    while finished_cnt < num_process:
        ret = out_queue.get()
        if ret is None:
            finished_cnt += 1
            continue
        pat_idx, agent_ids, outcome = ret
        if outcome is not None:
            for idx, p in enumerate(env.players()):
                agent_id = agent_ids[idx]
                oc = outcome[p]
                result_map[agent_id][pat_idx][oc] = result_map[agent_id][pat_idx].get(oc, 0) + 1
                total_results[agent_id][oc] = total_results[agent_id].get(oc, 0) + 1

    for p, r_map in enumerate(result_map):
        print('---agent %d---' % p)
        for pat_idx, results in r_map.items():
            print(pat_idx, {k: results[k] for k in sorted(results.keys(), reverse=True)}, wp_func(results))
        print('total', {k: total_results[p][k] for k in sorted(total_results[p].keys(), reverse=True)}, wp_func(total_results[p]))


def network_match_acception(n, env_args, num_agents, port):
    waiting_conns = []
    accepted_conns = []

    for conn in accept_socket_connections(port):
        if len(accepted_conns) >= n * num_agents:
            break
        waiting_conns.append(conn)

        if len(waiting_conns) == num_agents:
            conn = waiting_conns[0]
            accepted_conns.append(conn)
            waiting_conns = waiting_conns[1:]
            conn.send(env_args)  # send accept with environment arguments

    agents_list = [
        [NetworkAgent(accepted_conns[i * num_agents + j]) for j in range(num_agents)]
        for i in range(n)
    ]

    return agents_list


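def get_model(env, model_path):
    import torch
    from .model import ModelWrapper
    model = env.net()()
    state = torch.load(model_path, map_location='cpu')
    # train.py saves a checkpoint dict ({'model': ..., 'optim': ..., ...}); a plain state_dict is also accepted
    if 'model' in state:
        state = state['model']
    model.load_state_dict(state)
    model.eval()
    return ModelWrapper(model)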


def client_mp_child(env_args, model_path, conn):
    env = make_env(env_args)
    model = get_model(env, model_path)
    NetworkAgentClient(Agent(model), env, conn).run()


def eval_main(args, argv):
    env_args = args['env_args']
    prepare_env(env_args)
    env = make_env(env_args)

    model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth'
    num_games = int(argv[1]) if len(argv) >= 2 else 100
    num_process = int(argv[2]) if len(argv) >= 3 else 1

    agent1 = Agent(get_model(env, model_path))
    critic = None

    print('%d process, %d games' % (num_process, num_games))

    seed = random.randrange(10 ** 8)  # integer bound: float arguments are rejected by newer Python versions
    print('seed = %d' % seed)

    agents = [agent1] + [RandomAgent() for _ in range(len(env.players()) - 1)]

    evaluate_mp(env, agents, critic, env_args, {'default': {}}, num_process, num_games, seed)


def eval_server_main(args, argv):
    print('network match server mode')
    env_args = args['env_args']
    prepare_env(env_args)
    env = make_env(env_args)

    num_games = int(argv[0]) if len(argv) >= 1 else 100
    num_process = int(argv[1]) if len(argv) >= 2 else 1

    print('%d process, %d games' % (num_process, num_games))

    seed = random.randrange(10 ** 8)  # integer bound: float arguments are rejected by newer Python versions
    print('seed = %d' % seed)

    evaluate_mp(env, [None] * len(env.players()), None, env_args, {'default': {}}, num_process, num_games, seed)


def eval_client_main(args, argv):
    print('network match client mode')
    while True:
        try:
            host = argv[1] if len(argv) >= 2 else 'localhost'
            conn = connect_socket_connection(host, network_match_port)
            env_args = conn.recv()
        except EOFError:
            break

        model_path = argv[0] if len(argv) >= 1 else 'models/latest.pth'
        mp.Process(target=client_mp_child, args=(env_args, model_path, conn)).start()
        conn.close()

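`compute_loss` in the training cell below clips the importance ratios ρ = π/μ at 1.0 (`clipped_rhos` and `cs`). It then passes them, together with `lambda`, to `compute_target` in losses.py, which is not shown here. To illustrate what a V-trace target does with those ratios, here is a minimal NumPy sketch for a single trajectory. It follows the recursion from the IMPALA paper (Espeholt et al., 2018), with λ folded into the trace coefficients. The function name and arguments are hypothetical, and this is not HandyRL's `compute_target`. As in `compute_loss`, the final multiplication of the advantage by the clipped ρ is left to the caller.

```python
import numpy as np


def vtrace_targets(values, bootstrap_value, rewards, log_rhos,
                   gamma=1.0, lam=1.0, clip_rho=1.0, clip_c=1.0):
    """V-trace value targets and (unweighted) advantages for one trajectory.

    values          -- V(x_t) for t = 0..T-1 from the learner's value head
    bootstrap_value -- V(x_T), e.g. the final outcome at the end of an episode
    rewards         -- r_t for t = 0..T-1
    log_rhos        -- log pi(a_t|x_t) - log mu(a_t|x_t) (learner minus actor policy)
    """
    values = np.asarray(values, dtype=np.float64)
    rewards = np.asarray(rewards, dtype=np.float64)
    rhos = np.exp(np.asarray(log_rhos, dtype=np.float64))
    clipped_rhos = np.minimum(rhos, clip_rho)   # truncated importance weight rho_t
    cs = lam * np.minimum(rhos, clip_c)         # trace coefficient c_t with lambda folded in

    next_values = np.append(values[1:], bootstrap_value)
    deltas = clipped_rhos * (rewards + gamma * next_values - values)

    # backward recursion: v_t - V(x_t) = delta_t + gamma * c_t * (v_{t+1} - V(x_{t+1}))
    vs_minus_v = np.zeros_like(values)
    acc = 0.0
    for t in reversed(range(len(values))):
        acc = deltas[t] + gamma * cs[t] * acc
        vs_minus_v[t] = acc
    vs = values + vs_minus_v

    # advantage r_t + gamma * v_{t+1} - V(x_t); the caller multiplies it by clipped rho_t
    next_vs = np.append(vs[1:], bootstrap_value)
    advantages = rewards + gamma * next_vs - values
    return vs, advantages
```

With on-policy data (all log-ratios 0) and λ = 1, the targets reduce to ordinary bootstrapped returns. With λ = 0 they become one-step TD targets. Clipping ρ and c at 1.0 keeps the targets bounded when the actor's policy lags behind the learner's.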

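The training cell resumes from several separate joblib files: the episode buffer, the counters and the GradScaler state. If the kernel is interrupted while one of them is being written, the next restart can fail on a truncated file. One possible alternative is sketched below, using hypothetical helpers `save_training_state` / `load_training_state` and plain `pickle` instead of joblib. It writes everything into one file via a temporary file and `os.replace`, which on POSIX systems swaps the file in atomically.

```python
import os
import pickle
import tempfile
from collections import deque


def save_training_state(path, episodes, counters):
    """Write the replay buffer and trainer counters to `path` in one atomic step."""
    directory = os.path.dirname(os.path.abspath(path))
    state = {'episodes': list(episodes), 'counters': dict(counters)}
    # write to a temporary file in the same directory, then rename it over `path`,
    # so an interrupted save leaves the previous file untouched
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)  # clean up the partial temporary file
        raise


def load_training_state(path, maxlen=None):
    """Return (episodes as a deque, counters dict) written by save_training_state."""
    with open(path, 'rb') as f:
        state = pickle.load(f)
    return deque(state['episodes'], maxlen=maxlen), state['counters']
```

Keeping the counters (`steps`, `data_cnt`, `data_cnt_ema`, ...) in the same file as the episodes also guarantees that both come from the same point in training.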
In [None]:
%%writefile handyrl/train.py
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]

# training

import os
import time
import copy
import threading
import random
import bz2
import base64
import pickle
import warnings
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import torch.optim as optim
import psutil

from .environment import prepare_env, make_env
from .util import map_r, bimap_r, trimap_r, rotate
from .model import to_torch, to_gpu, RandomModel, ModelWrapper
from .losses import compute_target
from .connection import MultiProcessJobExecutor
from .connection import accept_socket_connections
from .worker import WorkerCluster, WorkerServer
from torch.cuda.amp import autocast
from joblib import dump, load

def make_batch(episodes, args):
    """Making training batch

    Args:
        episodes (Iterable): list of episodes
        args (dict): training configuration

    Returns:
        dict: PyTorch input and target tensors

    Note:
        Basic data shape is (B, T, P, ...) .
        (B is batch size, T is time length, P is player count)
    """

    obss, datum = [], []

    def replace_none(a, b):
        return a if a is not None else b

    for ep in episodes:
        moments_ = sum([pickle.loads(bz2.decompress(ms)) for ms in ep['moment']], [])
        moments = moments_[ep['start'] - ep['base']:ep['end'] - ep['base']]
        players = list(moments[0]['observation'].keys())
        if not args['turn_based_training']:  # solo training
            players = [random.choice(players)]

        obs_zeros = map_r(moments[0]['observation'][moments[0]['turn'][0]], lambda o: np.zeros_like(o))  # template for padding
        p_zeros = np.zeros_like(moments[0]['policy'][moments[0]['turn'][0]])  # template for padding

        # data that is changed by training configuration
        if args['turn_based_training'] and not args['observation']:
            obs = [[m['observation'][m['turn'][0]]] for m in moments]
            p = np.array([[m['policy'][m['turn'][0]]] for m in moments])
            act = np.array([[m['action'][m['turn'][0]]] for m in moments], dtype=np.int64)[..., np.newaxis]
            amask = np.array([[m['action_mask'][m['turn'][0]]] for m in moments])
        else:
            obs = [[replace_none(m['observation'][player], obs_zeros) for player in players] for m in moments]
            p = np.array([[replace_none(m['policy'][player], p_zeros) for player in players] for m in moments])
            act = np.array([[replace_none(m['action'][player], 0) for player in players] for m in moments], dtype=np.int64)[..., np.newaxis]
            amask = np.array([[replace_none(m['action_mask'][player], p_zeros + 1e32) for player in players] for m in moments])

        # reshape observation
        obs = rotate(rotate(obs))  # (T, P, ..., ...) -> (P, ..., T, ...) -> (..., T, P, ...)
        obs = bimap_r(obs_zeros, obs, lambda _, o: np.array(o))

        # datum that is not changed by training configuration
        v = np.array([[replace_none(m['value'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
        rew = np.array([[replace_none(m['reward'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
        ret = np.array([[replace_none(m['return'][player], [0]) for player in players] for m in moments], dtype=np.float32).reshape(len(moments), len(players), -1)
        oc = np.array([ep['outcome'][player] for player in players], dtype=np.float32).reshape(1, len(players), -1)

        emask = np.ones((len(moments), 1, 1), dtype=np.float32)  # episode mask
        tmask = np.array([[[m['policy'][player] is not None] for player in players] for m in moments], dtype=np.float32)
        omask = np.array([[[m['value'][player] is not None] for player in players] for m in moments], dtype=np.float32)

        progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total']

        # pad each array if step length is short
        if len(tmask) < args['forward_steps']:
            pad_len = args['forward_steps'] - len(tmask)
            obs = map_r(obs, lambda o: np.pad(o, [(0, pad_len)] + [(0, 0)] * (len(o.shape) - 1), 'constant', constant_values=0))
            p = np.pad(p, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            v = np.concatenate([v, np.tile(oc, [pad_len, 1, 1])])
            act = np.pad(act, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            rew = np.pad(rew, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            ret = np.pad(ret, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            emask = np.pad(emask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            tmask = np.pad(tmask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            omask = np.pad(omask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=0)
            amask = np.pad(amask, [(0, pad_len), (0, 0), (0, 0)], 'constant', constant_values=1e32)
            progress = np.pad(progress, [(0, pad_len), (0, 0)], 'constant', constant_values=1)

        obss.append(obs)
        datum.append((p, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))

    p, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = zip(*datum)

    obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
    p = to_torch(np.array(p))
    v = to_torch(np.array(v))
    act = to_torch(np.array(act))
    oc = to_torch(np.array(oc))
    rew = to_torch(np.array(rew))
    ret = to_torch(np.array(ret))
    emask = to_torch(np.array(emask))
    tmask = to_torch(np.array(tmask))
    omask = to_torch(np.array(omask))
    amask = to_torch(np.array(amask))
    progress = to_torch(np.array(progress))

    return {
        'observation': obs,
        'policy': p, 'value': v,
        'action': act, 'outcome': oc,
        'reward': rew, 'return': ret,
        'episode_mask': emask,
        'turn_mask': tmask, 'observation_mask': omask,
        'action_mask': amask,
        'progress': progress,
    }


def forward_prediction(model, hidden, batch, args):
    """Forward calculation via neural network

    Args:
        model (torch.nn.Module): neural network
        hidden: initial hidden state (..., B, P, ...)
        batch (dict): training batch (output of make_batch() function)

    Returns:
        tuple: batch outputs of neural network
    """

    observations = batch['observation']  # (B, T, P, ...)

    if hidden is None:
        # feed-forward neural network
        obs = map_r(observations, lambda o: o.view(-1, *o.size()[3:]))
        outputs = model(obs, None)
    else:
        # sequential computation with RNN
        outputs = {}
        for t in range(batch['turn_mask'].size(1)):
            obs = map_r(observations, lambda o: o[:, t].reshape(-1, *o.size()[3:]))  # (..., B * P, ...)
            omask_ = batch['observation_mask'][:, t]
            omask = map_r(hidden, lambda h: omask_.view(*h.size()[:2], *([1] * (len(h.size()) - 2))))
            hidden_ = bimap_r(hidden, omask, lambda h, m: h * m)  # (..., B, P, ...)
            if args['turn_based_training'] and not args['observation']:
                hidden_ = map_r(hidden_, lambda h: h.sum(1))  # (..., B * 1, ...)
            else:
                hidden_ = map_r(hidden_, lambda h: h.view(-1, *h.size()[2:]))  # (..., B * P, ...)
            outputs_ = model(obs, hidden_)
            for k, o in outputs_.items():
                if k == 'hidden':
                    next_hidden = outputs_['hidden']
                else:
                    outputs[k] = outputs.get(k, []) + [o]
            next_hidden = bimap_r(next_hidden, hidden, lambda nh, h: nh.view(h.size(0), -1, *h.size()[2:]))  # (..., B, P or 1, ...)
            hidden = trimap_r(hidden, next_hidden, omask, lambda h, nh, m: h * (1 - m) + nh * m)
        outputs = {k: torch.stack(o, dim=1) for k, o in outputs.items() if o[0] is not None}

    for k, o in outputs.items():
        o = o.view(*batch['turn_mask'].size()[:2], -1, o.size(-1))
        if k == 'policy':
            # gather turn player's policies
            outputs[k] = o.mul(batch['turn_mask']).sum(2, keepdim=True) - batch['action_mask']
        else:
            # mask valid target values and cumulative rewards
            outputs[k] = o.mul(batch['observation_mask'])

    return outputs


def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
    """Calculate loss value

    Returns:
        tuple: losses and statistic values and the number of training data
    """

    tmasks = batch['turn_mask']
    omasks = batch['observation_mask']

    losses = {}
    dcnt = tmasks.sum().item()
    turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True)

    losses['p'] = (-log_selected_policies * turn_advantages).sum()
    if 'value' in outputs:
        losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
    if 'return' in outputs:
        losses['r'] = F.smooth_l1_loss(outputs['return'], targets['return'], reduction='none').mul(omasks).sum()

    entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(tmasks.sum(-1))
    losses['ent'] = entropy.sum()

    base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
    entropy_loss = entropy.mul(1 - batch['progress'] * (1 - args['entropy_regularization_decay'])).sum() * -args['entropy_regularization']
    losses['total'] = base_loss + entropy_loss

    return losses, dcnt


def compute_loss(batch, model, hidden, args):
    outputs = forward_prediction(model, hidden, batch, args)
    actions = batch['action']
    emasks = batch['episode_mask']
    clip_rho_threshold, clip_c_threshold = 1.0, 1.0

    log_selected_b_policies = F.log_softmax(batch['policy']  , dim=-1).gather(-1, actions) * emasks
    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks

    # thresholds of importance sampling
    log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
    rhos = torch.exp(log_rhos)
    clipped_rhos = torch.clamp(rhos, 0, clip_rho_threshold)
    cs = torch.clamp(rhos, 0, clip_c_threshold)
    outputs_nograd = {k: o.detach() for k, o in outputs.items()}

    if 'value' in outputs_nograd:
        values_nograd = outputs_nograd['value']
        if args['turn_based_training'] and values_nograd.size(2) == 2:  # two player zerosum game
            values_nograd_opponent = -torch.stack([values_nograd[:, :, 1], values_nograd[:, :, 0]], dim=2)
            values_nograd = (values_nograd + values_nograd_opponent) / (batch['observation_mask'].sum(dim=2, keepdim=True) + 1e-8)
        outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)

    # compute targets and advantage
    targets = {}
    advantages = {}

    value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs
    return_args = outputs_nograd.get('return', None), batch['return'], batch['reward'], args['lambda'], args['gamma'], clipped_rhos, cs

    targets['value'], advantages['value'] = compute_target(args['value_target'], *value_args)
    targets['return'], advantages['return'] = compute_target(args['value_target'], *return_args)

    if args['policy_target'] != args['value_target']:
        _, advantages['value'] = compute_target(args['policy_target'], *value_args)
        _, advantages['return'] = compute_target(args['policy_target'], *return_args)

    # compute policy advantage
    total_advantages = clipped_rhos * sum(advantages.values())

    return compose_losses(outputs, log_selected_t_policies, total_advantages, targets, batch, args)


class Batcher:
    def __init__(self, args, episodes):
        self.args = args
        self.episodes = episodes
        self.shutdown_flag = False

        self.executor = MultiProcessJobExecutor(self._worker, self._selector(), self.args['num_batchers'], num_receivers=2)

    def _selector(self):
        while True:
            yield [self.select_episode() for _ in range(self.args['batch_size'])]

    def _worker(self, conn, bid):
        print('started batcher %d' % bid)
        while not self.shutdown_flag:
            episodes = conn.recv()
            batch = make_batch(episodes, self.args)
            conn.send(batch)
        print('finished batcher %d' % bid)

    def run(self):
        self.executor.start()

    def select_episode(self):
        while True:
            ep_idx = random.randrange(min(len(self.episodes), self.args['maximum_episodes']))
            accept_rate = 1 - (len(self.episodes) - 1 - ep_idx) / self.args['maximum_episodes']
            if random.random() < accept_rate:
                break
        ep = self.episodes[ep_idx]
        turn_candidates = 1 + max(0, ep['steps'] - self.args['forward_steps'])  # change start turn by sequence length
        st = random.randrange(turn_candidates)
        ed = min(st + self.args['forward_steps'], ep['steps'])
        st_block = st // self.args['compress_steps']
        ed_block = (ed - 1) // self.args['compress_steps'] + 1
        ep_minimum = {
            'args': ep['args'], 'outcome': ep['outcome'],
            'moment': ep['moment'][st_block:ed_block],
            'base': st_block * self.args['compress_steps'],
            'start': st, 'end': ed, 'total': ep['steps']
        }
        return ep_minimum

    def batch(self):
        return self.executor.recv()

    def shutdown(self):
        self.shutdown_flag = True
        self.executor.shutdown()


class Trainer:
    def __init__(self, args, model):
        self.args = args
        if self.args['restart_with_saved_epochs']:
            # resume with the replay buffer saved by Learner.feed_episodes()
            print('start load saved episodes....')
            self.episodes = load('geese2genepisodes.job', mmap_mode=None)
            print('loaded saved episodes', len(self.episodes))
        else:
            self.episodes = deque()
        restart_epoch = self.args['restart_epoch']
        self.gpu = torch.cuda.device_count()
        self.model = model
        self.default_lr = 3e-8

        # counters and GradScaler state saved by a previous session
        restart = self.args['restart_with_saved_env_states']
        if restart:
            self.batch_cnt = load('batch_cntload.job', mmap_mode=None)
            self.data_cnt = load('data_cntload.job', mmap_mode=None)
            self.data_cnt_ema = load('data_cnt_emaload.job', mmap_mode=None)
            self.steps = load('stepsload.job', mmap_mode=None)
            print('loaded saved batch_cnt, data_cnt, data_cnt_ema and steps')
        else:
            self.batch_cnt = 0
            self.data_cnt = 0
            self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps']
            self.steps = 0
        self.steps_old = self.steps  # step count at the start of this session, used by run()

        self.params = list(self.model.parameters())
        lr = self.default_lr * self.data_cnt_ema
        self.optimizer = optim.Adam(self.params, lr=lr, weight_decay=1e-5) if len(self.params) > 0 else None
        self.lock = threading.Lock()
        self.batcher = Batcher(self.args, self.episodes)
        self.updated_model = None, 0
        self.update_flag = False
        self.shutdown_flag = False
        self.scaler = torch.cuda.amp.GradScaler()
        if restart:
            self.scaler.load_state_dict(pickle.loads(bz2.decompress(base64.b64decode(load('scalerload.job', mmap_mode=None)))))
            print('loaded saved GradScaler')
        self.wrapped_model = ModelWrapper(self.model)
        self.once = True
        self.trained_model = self.wrapped_model
        if self.gpu > 1:
            self.trained_model = nn.DataParallel(self.wrapped_model)

    def update(self):
        if len(self.episodes) < self.args['minimum_episodes']:
            return None, 0  # return None before training
        self.update_flag = True
        while True:
            time.sleep(0.1)
            model, steps = self.recheck_update()
            if model is not None:
                break
        return model, steps

    def report_update(self, model, steps):
        self.lock.acquire()
        self.update_flag = False
        self.updated_model = model, steps
        self.lock.release()

    def recheck_update(self):
        self.lock.acquire()
        flag = self.update_flag
        self.lock.release()
        return (None, -1) if flag else self.updated_model

    def shutdown(self):
        self.shutdown_flag = True
        self.batcher.shutdown()

    def train(self):
        if self.optimizer is None:  # non-parametric model
            print()
            return

        # if self.once:
        #   self.optimizer.load_state_dict(pickle.loads(bz2.decompress(base64.b64decode(load('optimload.job', mmap_mode=None)))))
        #   print('loaded saved optimizer')
        #   self.once = False    

        # per-epoch accumulators for the loss report (self.data_cnt is cumulative across restarts)
        loss_sum, epoch_data_cnt = {}, 0
        if self.gpu > 0:
            self.trained_model.cuda()
        self.trained_model.train()

        # train on at least one batch per epoch, then continue until an update is requested
        while epoch_data_cnt == 0 or not (self.update_flag or self.shutdown_flag):
            batch = self.batcher.batch()
            batch_size = batch['value'].size(0)
            player_count = batch['value'].size(2)
            hidden = self.wrapped_model.init_hidden([batch_size, player_count])
            if self.gpu > 0:
                batch = to_gpu(batch)
                hidden = to_gpu(hidden)
            # forward pass and loss computation in mixed precision
            with autocast():
                losses, dcnt = compute_loss(batch, self.trained_model, hidden, self.args)

            # scaled backward pass; unscale before clipping so the norm threshold applies to the true gradients
            self.optimizer.zero_grad()
            self.scaler.scale(losses['total']).backward()
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.params, 4.0)
            self.scaler.step(self.optimizer)  # skips the update if the gradients contain inf/NaN

            self.batch_cnt += 1
            self.data_cnt += dcnt
            epoch_data_cnt += dcnt
            for k, l in losses.items():
                loss_sum[k] = loss_sum.get(k, 0.0) + l.item()

            self.steps += 1
            self.scaler.update()

        print('loss = %s' % ' '.join([k + ':' + '%.3f' % (l / epoch_data_cnt) for k, l in loss_sum.items()]))

        self.data_cnt_ema = self.data_cnt_ema * 0.8 + self.data_cnt / (1e-2 + self.batch_cnt) * 0.2
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.default_lr * self.data_cnt_ema / (1 + self.steps * 1e-5)
        self.model.cpu()
        self.model.eval()
        return copy.deepcopy(self.model)

    def run(self):
        print('waiting training')
        while not self.shutdown_flag:
            if len(self.episodes) < self.args['minimum_episodes']:
                time.sleep(1)
                continue
            if self.steps == 0 or self.steps == self.steps_old:
                # start the batcher once per session: on a fresh start, or before the first epoch after a restart
                self.batcher.run()
                print('started training')
            model = self.train()
            self.report_update(model, self.steps)
        print('finished training')


class Learner:
    def __init__(self, args, net=None, remote=False):
        train_args = args['train_args']
        env_args = args['env_args']
        train_args['env'] = env_args
        args = train_args

        self.args = args
        random.seed(args['seed'])

        self.env = make_env(env_args)
        eval_modify_rate = (args['update_episodes'] ** 0.85) / args['update_episodes']
        self.eval_rate = max(args['eval_rate'], eval_modify_rate)
        self.shutdown_flag = False
        self.flags = set()

        # trained datum
        self.model_epoch = self.args['restart_epoch']
        self.model_class = net if net is not None else self.env.net()
        train_model = self.model_class()
        if self.model_epoch == 0:
            self.model = RandomModel(self.env)
        else:
            self.model = train_model
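            try:
                # model state saved by a previous session
                self.model.load_state_dict(pickle.loads(bz2.decompress(base64.b64decode(load('modelload.job', mmap_mode=None)))), strict=False)
                print('loaded prev trained saved model')
            except Exception:
                # fall back to models/<epoch>.pth; update_model() stores the weights under the 'model' key,
                # and with strict=False the whole checkpoint dict would otherwise be silently ignored
                checkpoint = torch.load(self.model_path(self.model_epoch))
                self.model.load_state_dict(checkpoint.get('model', checkpoint), strict=False)
                print('loaded model from %s' % self.model_path(self.model_epoch))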

        # generated datum
        self.generation_results = {}
        self.num_episodes = 0

        # evaluated datum
        self.results = {}
        self.results_per_opponent = {}
        self.num_results = 0

        # multiprocess or remote connection
        self.worker = WorkerServer(args) if remote else WorkerCluster(args)

        # thread connection
        self.trainer = Trainer(args, train_model)

        self.runs = 0

    def shutdown(self):
        self.shutdown_flag = True
        self.trainer.shutdown()
        self.worker.shutdown()
        self.thread.join()

    def model_path(self, model_id):
        return os.path.join('models', str(model_id) + '.pth')

    def latest_model_path(self):
        return os.path.join('models', 'latest.pth')


    def update_model(self, model, steps):
        # get latest model and save it together with the state needed to resume training
        print('updated model(%d)' % steps)
        self.model_epoch += 1
        self.model = model
        os.makedirs('models', exist_ok=True)
        checkpoint = {
            'model': model.state_dict(),
            'optim': self.trainer.optimizer.state_dict(),
            'scaler': self.trainer.scaler.state_dict(),
            'data_cnt_ema': self.trainer.data_cnt_ema,
            'batch_cnt': self.trainer.batch_cnt,
            'data_cnt': self.trainer.data_cnt,
            'steps': self.trainer.steps,
        }
        torch.save(checkpoint, self.model_path(self.model_epoch))
        torch.save(checkpoint, self.latest_model_path())

    def feed_episodes(self, episodes):
        # analyze generated episodes
        for episode in episodes:
            self.runs = self.runs + 1
            if episode is None:
                continue
            for p in episode['args']['player']:
                model_id = episode['args']['model_id'][p]
                outcome = episode['outcome'][p]
                n, r, r2 = self.generation_results.get(model_id, (0, 0, 0))
                self.generation_results[model_id] = n + 1, r + outcome, r2 + outcome ** 2

        # store generated episodes
        self.trainer.episodes.extend([e for e in episodes if e is not None])

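        if self.runs > self.args['save_no_episodes']:
            # periodically save the replay buffer so training can resume from it;
            # joblib lzma compression saved about 50% disk space but was much slower:
            # with open('geese2genepisodes.job', 'wb') as f:
            #     dump(self.trainer.episodes, f, compress=('lzma', 3))
            dump(self.trainer.episodes, 'geese2genepisodes.job')

            no_epi = len(self.trainer.episodes)
            np.save('no_ep.npy', no_epi)

            self.runs = 0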


        mem_percent = psutil.virtual_memory().percent
        mem_ok = mem_percent <= 95
        maximum_episodes = self.args['maximum_episodes'] if mem_ok else int(len(self.trainer.episodes) * 95 / mem_percent)

        if not mem_ok and 'memory_over' not in self.flags:
            warnings.warn("memory usage %.1f%% with buffer size %d" % (mem_percent, len(self.trainer.episodes)))
            self.flags.add('memory_over')

        while len(self.trainer.episodes) > maximum_episodes:
            self.trainer.episodes.popleft()

    def feed_results(self, results):
        # store evaluation results
        for result in results:
            if result is None:
                continue
            for p in result['args']['player']:
                model_id = result['args']['model_id'][p]
                res = result['result'][p]
                n, r, r2 = self.results.get(model_id, (0, 0, 0))
                self.results[model_id] = n + 1, r + res, r2 + res ** 2

                if model_id not in self.results_per_opponent:
                    self.results_per_opponent[model_id] = {}
                opponent = result['opponent']
                n, r, r2 = self.results_per_opponent[model_id].get(opponent, (0, 0, 0))
                self.results_per_opponent[model_id][opponent] = n + 1, r + res, r2 + res ** 2

    def update(self):
        # call update to every component
        print()
        print('epoch %d' % self.model_epoch)

        if self.model_epoch not in self.results:
            print('win rate = Nan (0)')
        else:
            def output_wp(name, results):
                n, r, r2 = results
                mean = r / (n + 1e-6)
                name_tag = ' (%s)' % name if name != '' else ''
                print('win rate%s = %.3f (%.1f / %d)' % (name_tag, (mean + 1) / 2, (r + n) / 2, n))

            if len(self.args.get('eval', {}).get('opponent', [])) <= 1:
                output_wp('', self.results[self.model_epoch])
            else:
                output_wp('total', self.results[self.model_epoch])
                for key in sorted(list(self.results_per_opponent[self.model_epoch])):
                    output_wp(key, self.results_per_opponent[self.model_epoch][key])

        if self.model_epoch not in self.generation_results:
            print('generation stats = Nan (0)')
        else:
            n, r, r2 = self.generation_results[self.model_epoch]
            mean = r / (n + 1e-6)
            std = (r2 / (n + 1e-6) - mean ** 2) ** 0.5
            print('generation stats = %.3f +- %.3f' % (mean, std))

        model, steps = self.trainer.update()
        if model is None:
            model = self.model
        self.update_model(model, steps)

        # clear flags
        self.flags = set()

    def server(self):
        # central conductor server
        # returns as list if getting multiple requests as list
        print('started server')
        prev_update_episodes = self.args['minimum_episodes']
        while self.model_epoch < self.args['epochs'] or self.args['epochs'] < 0:
            # first update only after minimum_episodes + update_episodes episodes, then one every update_episodes
            next_update_episodes = prev_update_episodes + self.args['update_episodes']
            while not self.shutdown_flag and self.num_episodes < next_update_episodes:
                conn, (req, data) = self.worker.recv()
                multi_req = isinstance(data, list)
                if not multi_req:
                    data = [data]
                send_data = []

                if req == 'args':
                    for _ in data:
                        args = {'model_id': {}}

                        # decide role
                        if self.num_results < self.eval_rate * self.num_episodes:
                            args['role'] = 'e'
                        else:
                            args['role'] = 'g'

                        if args['role'] == 'g':
                            # generation configuration
                            args['player'] = self.env.players()
                            for p in self.env.players():
                                if p in args['player']:
                                    args['model_id'][p] = self.model_epoch
                                else:
                                    args['model_id'][p] = -1
                            self.num_episodes += 1
                            if self.num_episodes % 100 == 0:
                                print(self.num_episodes, end=' ', flush=True)

                        elif args['role'] == 'e':
                            # evaluation configuration
                            args['player'] = [self.env.players()[self.num_results % len(self.env.players())]]
                            for p in self.env.players():
                                if p in args['player']:
                                    args['model_id'][p] = self.model_epoch
                                else:
                                    args['model_id'][p] = -1
                            self.num_results += 1

                        send_data.append(args)

                elif req == 'episode':
                    # report generated episodes
                    self.feed_episodes(data)
                    send_data = [None] * len(data)

                elif req == 'result':
                    # report evaluation results
                    self.feed_results(data)
                    send_data = [None] * len(data)

                elif req == 'model':
                    for model_id in data:
                        model = self.model
                        if model_id != self.model_epoch:
                            try:
                                model = self.model_class()
                                model.load_state_dict(torch.load(self.model_path(model_id))['model'], strict=False)
                            except Exception:
                                # return latest model if failed to load specified model
                                pass
                        send_data.append(pickle.dumps(model))

                if not multi_req and len(send_data) == 1:
                    send_data = send_data[0]
                self.worker.send(conn, send_data)
            prev_update_episodes = next_update_episodes
            self.update()
        print('finished server')

    def run(self):
        try:
            # open training thread
            self.thread = threading.Thread(target=self.trainer.run)
            self.thread.start()
            # open generator, evaluator
            self.worker.run()
            self.server()

        finally:
            self.shutdown()


def train_main(args):
    prepare_env(args['env_args'])  # preparing environment is needed in stand-alone mode
    learner = Learner(args=args)
    learner.run()


def train_server_main(args):
    learner = Learner(args=args, remote=True)
    learner.run()
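The learner keeps, per model id, a triple `(n, r, r2)`: the number of games, the sum of outcomes and the sum of squared outcomes. That is enough to recover the mean and standard deviation without storing every game, and `update()` then maps the mean outcome from [-1, 1] to the printed "win rate" in [0, 1] via `(mean + 1) / 2`. A minimal stdlib sketch of that bookkeeping, assuming outcomes lie in [-1, 1] as in the code above; `add_result` and `summarize` are hypothetical helper names:

```python
import math
from typing import Dict, Tuple

Stats = Tuple[int, float, float]  # (games, sum of outcomes, sum of squared outcomes)


def add_result(table: Dict[int, Stats], model_id: int, outcome: float) -> None:
    """Accumulate one outcome, mirroring the (n, r, r2) update in feed_results."""
    n, r, r2 = table.get(model_id, (0, 0.0, 0.0))
    table[model_id] = (n + 1, r + outcome, r2 + outcome ** 2)


def summarize(stats: Stats) -> Tuple[float, float, float]:
    """Return (mean outcome, std of outcome, win rate in [0, 1])."""
    n, r, r2 = stats
    mean = r / (n + 1e-6)                         # same epsilon guard as update()
    var = max(r2 / (n + 1e-6) - mean ** 2, 0.0)   # clamp tiny negative rounding error
    win_rate = (mean + 1) / 2                     # maps [-1, 1] onto [0, 1]
    return mean, math.sqrt(var), win_rate


if __name__ == '__main__':
    table: Dict[int, Stats] = {}
    for outcome in [1.0, -1.0, 1.0, 1.0]:
        add_result(table, model_id=7, outcome=outcome)
    mean, std, wr = summarize(table[7])
    print('win rate = %.3f, mean = %.3f +- %.3f' % (wr, mean, std))
```

The clamp before the square root is an addition of this sketch: with the original expression, floating-point error could make the variance slightly negative, and a negative float raised to `0.5` gives a complex number in Python.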


In [None]:
%%writefile handyrl/agent.py
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]
# changed
# agent classes

import random

import numpy as np

from .util import softmax


class RandomAgent:
    def reset(self, env, show=False):
        pass

    def action(self, env, player, show=False):
        actions = env.legal_actions(player)
        return random.choice(actions)

    def observe(self, env, player, show=False):
        return [0.0]


class RuleBasedAgent(RandomAgent):
    def action(self, env, player, show=False):
        #if hasattr(env, 'rule_based_action'):
        return env.rule_based_action(player)
        #return random.choices([env.rule_based_action(player), env.rule_based_action2(player)], k=1, weights=[1, 1])[0]
        # else:
        #     return random.choice(env.legal_actions(player))


def print_outputs(env, prob, v):
    if hasattr(env, 'print_outputs'):
        env.print_outputs(prob, v)
    else:
        print('v = %f' % v)
        print('p = %s' % (prob * 1000).astype(int))


class Agent:
    def __init__(self, model, observation=False, temperature=0.0):
        # model might be a neural net, or some planning algorithm such as game tree search
        self.model = model
        self.hidden = None
        self.observation = observation
        self.temperature = temperature

    def reset(self, env, show=False):
        self.hidden = self.model.init_hidden()

    def plan(self, obs):
        outputs = self.model.inference(obs, self.hidden)
        self.hidden = outputs.pop('hidden', None)
        return outputs

    def action(self, env, player, show=False):
        outputs = self.plan(env.observation(player))
        actions = env.legal_actions(player)
        p = outputs['policy']
        v = outputs.get('value', None)
        mask = np.ones_like(p)
        mask[actions] = 0
        p = p - mask * 1e32

        if show:
            print_outputs(env, softmax(p), v)

        if self.temperature == 0:
            ap_list = sorted([(a, p[a]) for a in actions], key=lambda x: -x[1])
            return ap_list[0][0]
        else:
            return random.choices(np.arange(len(p)), weights=softmax(p / self.temperature))[0]

    def observe(self, env, player, show=False):
        v = None
        if self.observation:
            outputs = self.plan(env.observation(player))
            v = outputs.get('value', None)
            if show:
                print_outputs(env, None, v)
        return v if v is not None else [0.0]


class EnsembleAgent(Agent):
    def reset(self, env, show=False):
        self.hidden = [model.init_hidden() for model in self.model]

    def plan(self, obs):
        # average every output head (except the recurrent state) over the ensemble members
        outputs = {}
        for i, model in enumerate(self.model):
            o = model.inference(obs, self.hidden[i])
            for k, v in o.items():
                if k == 'hidden':
                    self.hidden[i] = v
                else:
                    outputs[k] = outputs.get(k, []) + [v]
        for k, vl in outputs.items():
            outputs[k] = np.mean(vl, axis=0)
        return outputs


class SoftAgent(Agent):
    def __init__(self, model, observation=False):
        super().__init__(model, observation=observation, temperature=1.0)
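`Agent.action` removes illegal moves by subtracting a huge constant (`1e32`) from their logits, then either takes the best legal action (temperature 0) or samples from `softmax(p / temperature)`, so illegal actions end up with probability zero. A numpy sketch of the same idea, assuming a fixed action space of four moves as in Hungry Geese; `softmax` is re-implemented here and `select_action` is a hypothetical name, not part of HandyRL:

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def select_action(logits, legal_actions, temperature=0.0, rng=None):
    """Pick an action index from logits, never choosing an illegal one."""
    logits = np.asarray(logits, dtype=np.float64)
    mask = np.ones_like(logits)
    mask[legal_actions] = 0
    masked = logits - mask * 1e32          # illegal logits become hugely negative
    if temperature == 0:
        return int(np.argmax(masked))      # greedy choice among legal actions
    rng = rng or np.random.default_rng()
    probs = softmax(masked / temperature)  # illegal actions get probability 0
    return int(rng.choice(len(probs), p=probs))


if __name__ == '__main__':
    logits = [2.0, 5.0, 1.0, 0.5]          # action 1 has the highest logit...
    legal = [0, 2, 3]                      # ...but it is illegal in this state
    print(select_action(logits, legal))    # greedy: best legal action
    print(select_action(logits, legal, temperature=1.0, rng=np.random.default_rng(0)))
```

`SoftAgent` corresponds to the `temperature=1.0` branch; a lower temperature sharpens the distribution toward the greedy choice.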


In [None]:
%%writefile config.yaml

env_args:
    env: 'HungryGeese'
    source: 'handyrl.envs.kaggle.hungry_geese'


train_args:
    turn_based_training: False  # always False for Hungry Geese
    observation: False
    gamma: 0.8
    forward_steps: 32
    compress_steps: 4
    entropy_regularization: 2.0e-3
    entropy_regularization_decay: 0.3
    update_episodes: 500
    batch_size: 800
    minimum_episodes: 50000
    maximum_episodes: 200000
    eval_rate: 0.1
    epochs: 5000
    num_batchers: 28
    worker:
        num_parallel: 36
    lambda: 0.7
    policy_target: 'VTRACE'
    value_target: 'VTRACE'
    seed: 0
    restart_epoch: 0
    save_no_episodes: 200000  # dump the episode buffer to disk after this many new episodes
    restart_with_saved_epochs: False
    restart_with_saved_env_states: False


worker_args:
    server_address: ''
    num_parallel: 48
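The config selects V-trace targets (`policy_target` / `value_target: 'VTRACE'`) with `gamma: 0.8` and `lambda: 0.7`. As a reference point, here is a minimal numpy sketch of V-trace value targets in the recursive form from the IMPALA paper (Espeholt et al., 2018): v_s = V(x_s) + δ_s + γ·c_s·(v_{s+1} - V(x_{s+1})), with δ_s = ρ_s·(r_s + γ·V(x_{s+1}) - V(x_s)), truncated ratios ρ_s = min(ρ̄, π/μ) and trace coefficients c_s = λ·min(c̄, π/μ). It assumes a single trajectory with a bootstrap value after the last step and clip thresholds ρ̄ = c̄ = 1; HandyRL's own implementation works on batched tensors and may differ in details such as episode-end handling:

```python
import numpy as np


def vtrace_targets(rewards, values, bootstrap_value, ratios,
                   gamma=0.8, lam=0.7, rho_bar=1.0, c_bar=1.0):
    """V-trace value targets for one trajectory of length T.

    rewards, values, ratios: arrays of shape (T,), with ratios = pi(a|x) / mu(a|x)
    bootstrap_value: V(x_T), the value estimate after the last step
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    ratios = np.asarray(ratios, dtype=np.float64)
    rhos = np.minimum(rho_bar, ratios)        # truncated importance weights
    cs = lam * np.minimum(c_bar, ratios)      # lambda-scaled trace coefficients
    next_values = np.append(values[1:], bootstrap_value)
    deltas = rhos * (rewards + gamma * next_values - values)  # weighted TD errors

    targets = np.empty_like(values)
    correction = 0.0  # holds v_{s+1} - V(x_{s+1}); zero after the last step
    for s in reversed(range(len(values))):
        correction = deltas[s] + gamma * cs[s] * correction
        targets[s] = values[s] + correction
    return targets


if __name__ == '__main__':
    r = [0.0, 0.0, 1.0]
    v = [0.2, 0.4, 0.6]
    print(vtrace_targets(r, v, bootstrap_value=0.0, ratios=[1.0, 1.0, 1.0]))
```

Two limiting cases are useful sanity checks: on-policy with λ = 1 the targets reduce to discounted returns, and with λ = 0 they reduce to one-step TD targets r_s + γ·V(x_{s+1}).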

In [None]:
import os
import sys
import yaml

from handyrl.train import train_main 
from handyrl.train import train_server_main
from handyrl.worker import worker_main
from handyrl.evaluation import eval_main
from handyrl.evaluation import eval_server_main
from handyrl.evaluation import eval_client_main
with open('config.yaml') as f:
    args = yaml.safe_load(f)
print(args)

In [None]:
import pickle
import bz2
import base64
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import dump, load

if args['train_args']['restart_with_saved_env_states']:
    import gc

    restart_epoch = args['train_args']['restart_epoch']
    # read the checkpoint once instead of once per field
    checkpoint = torch.load(f'models/{restart_epoch}.pth')
    modelload = base64.b64encode(bz2.compress(pickle.dumps(checkpoint['model'])))
    optimload = base64.b64encode(bz2.compress(pickle.dumps(checkpoint['optim'])))
    scalerload = base64.b64encode(bz2.compress(pickle.dumps(checkpoint['scaler'])))
    data_cnt_emaload = checkpoint['data_cnt_ema']
    stepsload = checkpoint['steps']
    batch_cntload = checkpoint['batch_cnt']
    data_cntload = checkpoint['data_cnt']
    print('loaded model, opt, scaler, steps, data_cnt_ema...')

    # write each item to its own lzma-compressed joblib file, e.g. modelload.job
    to_save = {
        'modelload': modelload,
        'optimload': optimload,
        'scalerload': scalerload,
        'data_cnt_emaload': data_cnt_emaload,
        'stepsload': stepsload,
        'batch_cntload': batch_cntload,
        'data_cntload': data_cntload,
    }
    for name, obj in to_save.items():
        with open(f'{name}.job', 'wb') as f:
            dump(obj, f, compress=('lzma', 3))

    del checkpoint, to_save, modelload, optimload, scalerload, data_cnt_emaload, stepsload
    gc.collect()
    print('saved model, opt, scaler, steps, data_cnt_ema...')
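The restart cell stores the model, optimizer and scaler state dicts as `base64(bz2(pickle(...)))` byte strings inside lzma-compressed joblib files. After reading a file back with `joblib.load`, the three layers have to be undone in reverse order. A stdlib-only sketch of that round trip; `encode_state` and `decode_state` are hypothetical helper names, and a plain dict stands in for a real `state_dict()` (a PyTorch state dict holds tensors, which pickle the same way but need `torch` installed to unpickle):

```python
import base64
import bz2
import pickle


def encode_state(obj) -> bytes:
    """pickle -> bz2 -> base64, the same layering used in the restart cell."""
    return base64.b64encode(bz2.compress(pickle.dumps(obj)))


def decode_state(blob: bytes):
    """Undo the layers in reverse order: base64 -> bz2 -> pickle.

    Only unpickle files you created yourself; pickle can execute arbitrary code.
    """
    return pickle.loads(bz2.decompress(base64.b64decode(blob)))


if __name__ == '__main__':
    # stand-in for torch.load(...)['optim']; real state dicts contain tensors
    fake_state = {'step': 1200, 'lr': 1e-4, 'betas': (0.9, 0.999)}
    blob = encode_state(fake_state)
    restored = decode_state(blob)
    print(restored == fake_state, len(blob))
```

The base64 layer only makes the payload ASCII-safe (useful when embedding it in a submission file); for files that stay on disk, pickling the state dict directly into the joblib file would avoid the extra encoding step.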


In [None]:
import warnings
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

train_main(args)