In [None]:
# Install/upgrade the `kaggle-environments` package (provides `make`, imported below); all output is discarded.
# NOTE(review): `%pip` targets the running kernel's environment, and pinning a version would make runs reproducible.
!pip install kaggle-environments -U > /dev/null 2>&1
# Uncomment to copy the Lux AI 2021 files (presumably the `lux` kit imported below) into the working directory.
# !cp -r ../input/lux-ai-2021/* .

In [None]:
from kaggle_environments import make

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import namedtuple, deque
from lux.game import Game
from lux.game_map import RESOURCE_TYPES, Position
from tqdm.notebook import tqdm
import json
from pathlib import Path
from enum import Enum
from sklearn.preprocessing import normalize
import math
import os
import sys
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MultipleLocator
from mpl_toolkits.axes_grid1 import make_axes_locatable
import tqdm

import torch
import numpy as np
from lux.game_constants import GAME_CONSTANTS
from lux.game import Game
from lux import annotate,game
from lux.game_map import Position

In [None]:
# Create the checkpoint directories used by LuxAgent.save_model / load_model.
# os.makedirs(..., exist_ok=True) creates parents as needed and is idempotent,
# unlike `!mkdir`, which reports "File exists" errors whenever the cell is re-run.
for _subdir in ("qeval", "qnext", "supervised_model"):
    os.makedirs(os.path.join("models", _subdir), exist_ok=True)

In [None]:
class config:
    """Notebook-wide hyper-parameters and switches, read as class attributes (``config.lr`` etc.)."""

    # Hardware. No noticeable speed improvement with GPU was observed.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Optimisation schedule
    epochs = 20
    lr = 1e-5
    weight_decay = 1e-5

    # Epsilon-greedy exploration (LuxAgent.rl_choose_actions / decrement_epsilon)
    epsilon = .1
    eps_decay = .999
    eps_min = 0.00

    # DQN settings
    gamma = 0.95
    mem_len = 480        # replay-buffer capacity (deque maxlen)
    replace_cntr = 100   # learn steps between target-network syncs

    # Shapes
    num_layers = 19      # NOTE(review): not referenced in the code shown here
    num_actions = 9      # size of the model's output head
    base_size = 32       # NOTE(review): presumably the padded map size — confirm

    # Run mode: 'supervise', 'reinforce' or 'test'
    phase = 'supervise'
    debug_print = False

In [None]:
class Callback():
    """Training-loop hooks: accumulate the current game's loss and print progress.

    The hooks resolve these notebook globals when called: ``luxagent``,
    ``config``, ``game_state``, ``UNIT_TYPE``, plus ``json_steps`` (supervise
    phase) or ``obs`` / ``reward`` (reinforce phase).
    """

    def __init__(self):
        self.loss = 0.

    # -- run / epoch level ---------------------------------------------------
    def on_train_begin(self):
        print('Begin Training')

    def on_train_end(self):
        luxagent.save_model()

    def on_epoch_begin(self, epoch):
        final_epoch = config.epochs - 1
        print(f'EPOCH {epoch}/{final_epoch}')

    def on_epoch_end(self):
        # Only the Q-learning phase anneals its learning rate per epoch (halving it).
        if config.phase != 'reinforce':
            return
        luxagent.qeval.optimizer.param_groups[0]['lr'] *= .5

    # -- game level ------------------------------------------------------------
    def on_game_begin(self):
        self.loss = 0.

    @staticmethod
    def _map_size():
        """Current map dimensions formatted as 'WIDTHxHEIGHT'."""
        return f"{game_state.map.width}x{game_state.map.height}"

    @staticmethod
    def _count_units(player, kind):
        """Number of ``player``'s units whose type equals ``UNIT_TYPE[kind]``."""
        return sum(1 for unit in player.units if unit.type == UNIT_TYPE[kind])

    def on_game_end(self, idx):
        phase = config.phase
        if phase == 'supervise':
            steps = json_steps[-1][0]['observation']['step']
            print(f"{' ':3s}{self._map_size()}: Loss = {self.loss/(steps):.3e}  |  Game {idx}")
            # Decay the supervised learning rate by 5% after every game.
            luxagent.model.optimizer.param_groups[0]['lr'] *= 0.95
        elif phase == 'reinforce':
            steps = obs['step']
            print(f"{' ':3s}{self._map_size()}: Loss = {self.loss/(steps):.3e}  |  Reward = {reward:.4f}  |  Game {idx}")
        elif phase == 'test':
            print(f"{' ':3s}{self._map_size()}: Game {idx}")

        # Per-team summary of the final game state.
        for team, player in enumerate(game_state.players):
            workers = self._count_units(player, 'WORKER')
            carts = self._count_units(player, 'CART')
            citytiles = player.city_tile_count
            research = player.research_points
            print(f"{' ':6s}Team:{team}  Citytiles: {citytiles:>2d}  Workers: {workers:>2d}  Carts: {carts:>2d}  RP: {research:>3d}")

    # -- optimisation-step level (no-ops apart from loss accumulation) -------
    def on_loss_begin(self):
        pass

    def on_loss_end(self):
        pass

    def on_step_begin(self):
        pass

    def on_step_end(self, loss):
        self.loss += loss

In [None]:
class Model(nn.Module):
    """Maps (current state, previous state) to per-entity action probabilities.

    Each state supplies resource / city / unit map stacks; the current state
    additionally supplies per-entity type ids (``e_types``) and feature vectors
    (``e_info``, added to a 4-d type embedding). Pipeline:

    1. ``conv1``/``conv2``/``conv3`` embed the three map stacks and ``conv4``
       fuses them into one map per state (weights shared by both states).
    2. The two flattened fused maps form a length-2 sequence for a 2-layer LSTM
       (input 676 features, hidden 128).
    3. The flattened LSTM output (2 * 128) is broadcast to every entity and
       concatenated with ``emb(e_types) + e_info`` (4) -> 260 features, then fed
       through a 3-layer Transformer encoder.
    4. ``dec`` yields ``config.num_actions`` logits per entity; actions invalid
       for the entity type get -1e30 added before the softmax.

    NOTE(review): 676 == 26 * 26, which matches a 32x32 map after two 4x4 valid
    convolutions (32 -> 29 -> 26); inputs are presumably padded to
    ``config.base_size`` — confirm.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3,out_channels=1,kernel_size=(4,4),stride=1)
        self.conv2 = nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(4,4),stride=1)
        self.conv3 = nn.Conv2d(in_channels=2,out_channels=1,kernel_size=(4,4),stride=1)
        self.conv4 = nn.Conv2d(in_channels=9,out_channels=1,kernel_size=(4,4),stride=1)
        self.lstm = nn.LSTM(input_size=676, hidden_size=128, num_layers=2, batch_first=True)
        self.relu = nn.ReLU()
        self.emb = nn.Embedding(4,4)
        self.encoders = nn.TransformerEncoderLayer(d_model=260,nhead=10)
        self.enc = nn.TransformerEncoder(self.encoders,num_layers=3)
        self.dec = nn.Linear(260,config.num_actions)

        # Move the parameters first, then build the optimizer, as the torch.optim
        # docs recommend, so it is constructed over the device-resident parameters.
        self.to(config.device)
        self.optimizer = optim.Adam(self.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    def initialize_hidden(self):
        """Create zero LSTM (h, c) states. NOTE(review): not used by ``forward``."""
        self.h = torch.zeros([2, 1, 128], requires_grad=True)
        self.c = torch.zeros([2, 1, 128], requires_grad=True)

    def _encode_map(self, r_map, c_map, u_map):
        """Fuse one state's resource/city/unit maps into a single (1, 1, H*W) feature row.

        The city and unit conv outputs are permuted so that their batch axis becomes
        the channel axis before concatenation — presumably one map per city/unit
        layer; confirm against the state builder.
        """
        out_r = self.relu(self.conv1(r_map))
        out_c = self.relu(self.conv2(c_map)).permute(1,0,2,3)
        out_u = self.relu(self.conv3(u_map)).permute(1,0,2,3)
        fused = torch.cat([out_r, out_c, out_u], dim=1)
        return self.relu(self.conv4(fused)).view(1,1,-1)

    def forward(self, state1, state2):
        """Return an (n_entities, num_actions) tensor of action probabilities.

        Args:
            state1: ``[r_map, c_map, u_map, e_types, e_info]`` for the current state.
            state2: ``[r_map, c_map, u_map]`` for the previous state.
        """
        [r_map1, c_map1, u_map1, e_types, e_info], [r_map2, c_map2, u_map2] = state1, state2

        # Tensor.to() returns a new tensor rather than moving in place, so the
        # results must be rebound (the previous loop discarded them, leaving the
        # inputs on the CPU). Target the device the weights actually live on.
        device = self.dec.weight.device
        r_map1, c_map1, u_map1, e_types, e_info, r_map2, c_map2, u_map2 = (
            t.to(device) for t in (r_map1, c_map1, u_map1, e_types, e_info, r_map2, c_map2, u_map2)
        )

        map_ = torch.cat([self._encode_map(r_map1, c_map1, u_map1),
                          self._encode_map(r_map2, c_map2, u_map2)], dim=1)

        map_, _ = self.lstm(map_)
        # Broadcast the global map summary to every entity.
        map_ = map_.view(1,-1).expand(1,len(e_types),-1)

        embs = torch.add(self.emb(e_types), e_info).unsqueeze(0)

        out = torch.cat([embs, map_],dim=2)
        out = self.relu(self.enc(out))
        out = self.dec(out)

        # Additive mask with the same shape/dtype/device as the logits (a CPU
        # tensor here breaks on CUDA). In LuxAgent.convert_actions, indices 0-1
        # are 'bw'/'r' (city-tile actions) and 2+ are unit actions.
        invalid_actions = torch.zeros_like(out)
        invalid_actions[0,e_types==UNIT_TYPE['WORKER'],:2] = -1e30
        invalid_actions[0,e_types==UNIT_TYPE['CART'],:2] = -1e30
        invalid_actions[0,e_types==2,2:] = -1e30   # type id 2 presumably = city tile; confirm

        out = out + invalid_actions
        if config.debug_print:
            print(f'preds\n{out}')
        out = torch.softmax(out,dim=2).squeeze(0)

        return out

In [None]:
# One stored experience: (states, actions, rewards, next states, done flags).
Transitions = namedtuple('transitions', ['states', 'actions', 'rewards', 'states_', 'dones'])

class Replay():
    """Bounded FIFO experience buffer; the oldest entries drop once ``config.mem_len`` is reached."""

    def __init__(self):
        self.memory = deque(maxlen=config.mem_len)

    def store_transition(self, states, actions, rewards, states_, dones):
        """Append one transition (evicting the oldest when full)."""
        entry = Transitions(states=states, actions=actions, rewards=rewards,
                            states_=states_, dones=dones)
        self.memory.append(entry)

    def sample_memory(self):
        """Return the fields of one uniformly sampled transition as a 5-tuple."""
        (picked,) = random.sample(self.memory, 1)
        return picked.states, picked.actions, picked.rewards, picked.states_, picked.dones

    def __len__(self):
        return len(self.memory)

In [None]:
class LuxAgent():
    """Holds the network(s) for the current ``config.phase`` and implements learning.

    * 'supervise': ``self.model`` is trained to imitate recorded actions.
    * 'reinforce': ``self.qeval`` (online) and ``self.qnext`` (target) are trained
      with a Double-DQN style update sampled from ``self.memory``.
    * 'test': ``self.model`` is loaded from the supervised checkpoint.

    Methods resolve notebook globals at call time: ``config``, ``callback``,
    ``game_state``, ``Model`` and ``Replay``.
    """

    _SUPERVISED_PATH = './models/supervised_model/supervised_model.pth'
    _QEVAL_PATH = './models/qeval/lux_model_qeval.pth'
    _QNEXT_PATH = './models/qnext/lux_model_qnext.pth'

    # Output-head index for each Lux command key (moves are keyed with their direction).
    _ACTION_INDEX = {'bw':0, 'r':1, 'm n':2, 'm s':3, 'm e':4, 'm w':5, 'm c':6, 't':7, 'bcity':8}

    def __init__(self):
        self.initialize_model()
        self.memory = Replay()
        self.epsilon = config.epsilon
        self.gamma = config.gamma
        self.step_counter = 0

    @staticmethod
    def _load(path):
        """Unpickle a whole saved ``Model`` onto ``config.device``.

        NOTE(review): checkpoints are full pickled modules (see ``save_model``). On
        PyTorch >= 2.6 ``torch.load`` defaults to ``weights_only=True`` and will
        refuse them; pass ``weights_only=False`` there, and only for trusted files.
        """
        return torch.load(path, map_location=config.device)

    def initialize_model(self):
        """Create or load the network(s) required by ``config.phase``."""
        if config.phase == 'supervise':
            # Fresh model; use self._load(self._SUPERVISED_PATH) to resume training instead.
            self.model = Model()
        elif config.phase == 'reinforce':
            # Online and target networks both start from the supervised checkpoint.
            self.qeval = self._load(self._SUPERVISED_PATH)
            self.qnext = self._load(self._SUPERVISED_PATH)
        elif config.phase == 'test':
            self.model = self._load(self._SUPERVISED_PATH)

    def load_model(self):
        """Reload this phase's network(s) from their own checkpoints."""
        if config.phase == 'supervise':
            self.model = self._load(self._SUPERVISED_PATH)
        elif config.phase == 'reinforce':
            self.qeval = self._load(self._QEVAL_PATH)
            self.qnext = self._load(self._QNEXT_PATH)

    def save_model(self):
        """Pickle this phase's network(s) (whole modules, not state dicts)."""
        if config.phase == 'supervise':
            torch.save(self.model, self._SUPERVISED_PATH)
        elif config.phase == 'reinforce':
            torch.save(self.qeval, self._QEVAL_PATH)
            torch.save(self.qnext, self._QNEXT_PATH)

    def store_memory(self, states, actions, rewards, states_, dones):
        self.memory.store_transition(states, actions, rewards, states_, dones)

    def get_from_memory(self):
        return self.memory.sample_memory()

    def replace_target_network(self):
        """Copy the online weights into the target net every ``config.replace_cntr`` learn steps."""
        if self.step_counter % config.replace_cntr == 0:
            self.qnext.load_state_dict(self.qeval.state_dict())

    def decrement_epsilon(self):
        """Decay epsilon multiplicatively, never going below ``config.eps_min``."""
        # The old conditional form could overshoot below eps_min by one decay step.
        self.epsilon = max(self.epsilon * config.eps_decay, config.eps_min)

    def rl_choose_actions(self, state1, state2):
        """Greedy actions from ``qeval``; each entity's action is randomised with probability epsilon."""
        with torch.no_grad():   # action selection needs no autograd graph
            actions = torch.argmax(self.qeval.forward(state1, state2),dim=1)
        for idx,_ in enumerate(actions):
            if np.random.random_sample() <= self.epsilon:
                actions[idx] = np.random.choice(range(config.num_actions))
        return actions

    def rl_predict(self, state):
        """Online-network output for ``state = (state1, state2)``."""
        state1 = state[0]
        state2 = state[1]
        return self.qeval.forward(state1, state2)

    def rl_predict_(self, state):
        """Target-network output for ``state = (state1, state2)``."""
        state1 = state[0]
        state2 = state[1]
        return self.qnext.forward(state1, state2)

    def reinforcement_learn(self):
        """One Double-DQN update on a single transition sampled from memory."""
        if len(self.memory) < 1:
            return

        self.replace_target_network()

        s, a, r, s_, d = self.get_from_memory()

        q_all = self.rl_predict(s)
        a = a.long().unsqueeze(-1).to(q_all.device)
        q = q_all.gather(1,a)

        # The target is a constant for this update: without no_grad, backward()
        # also accumulated (never-cleared) gradients into qnext's parameters.
        with torch.no_grad():
            # Online net picks the next action, target net scores it. Take argmax of
            # the outputs directly: casting the probabilities (< 1) to long first
            # truncated them all to 0, so action 0 was always selected.
            a_ = torch.argmax(self.rl_predict(s_), dim=1).unsqueeze(-1)
            q_ = self.rl_predict_(s_).gather(1,a_)
            q_ = r + config.gamma * q_ * (1-d)

        callback.on_loss_begin()

        criterion = nn.MSELoss()
        # NOTE(review): compares Q-values summed over all entities rather than
        # per entity — confirm this is intended.
        loss = criterion(torch.sum(q), torch.sum(q_))

        callback.on_loss_end()

        self.qeval.optimizer.zero_grad()
        loss.backward()

        callback.on_step_begin()

        self.qeval.optimizer.step()

        callback.on_step_end(loss.item())

        self.step_counter += 1
        self.decrement_epsilon()

    def get_entities(self):
        """Ids of player 0's entities: city tiles as 'x y' strings, then unit ids."""
        player = game_state.players[0]
        entities = []
        for city in player.cities.values():
            for tile in city.citytiles:
                entities.append(f'{tile.pos.x} {tile.pos.y}')
        for unit in player.units:
            entities.append(unit.id)
        return entities

    def convert_actions(self, actions):
        """Map command keys ('bw', 'r', 'm n', ..., 'bcity') to output indices.

        Raises KeyError for commands without an output slot (e.g. build cart).
        """
        return [self._ACTION_INDEX[action] for action in actions]

    def reorder_targ_actions(self, targ_action_list):
        """Convert raw Lux action strings to output indices, preserving the given order.

        Move commands ('m <unit_id> <dir>') are keyed with their direction ('m n', ...);
        every other command is keyed by its first token ('bw', 'r', 't', 'bcity').

        NOTE(review): despite the name, no reordering against ``get_entities()`` is
        performed, so target row i corresponds to the i-th action string — confirm it
        matches the entity order of the model's predictions.
        """
        keys = []
        for raw in targ_action_list:
            tokens = raw.split()
            keys.append(f'{tokens[0]} {tokens[2]}' if tokens[0] == 'm' else tokens[0])
        return self.convert_actions(keys)

    def convert_targ_actions(self, preds, targ_action_list):
        """One-hot targets shaped like ``preds`` (row i = i-th action string)."""
        # zeros_like keeps preds' dtype and device (a CPU tensor breaks MSELoss on
        # CUDA) and does not require grad.
        targs = torch.zeros_like(preds)
        for idx,action in enumerate(self.reorder_targ_actions(targ_action_list)):
            targs[idx,action] = 1.
        return targs

    def sl_choose_actions(self, state1, state2):
        """Most probable action index per entity from the supervised model."""
        with torch.no_grad():
            actions = self.model.forward(state1, state2)
        if config.debug_print:
            print(actions)
        return torch.argmax(actions,dim=1)

    def supervise_learn(self, state, last_state, targ_actions):
        """One imitation step: MSE between predicted probabilities and one-hot targets."""
        pred = self.model.forward(state, last_state)
        targ = self.convert_targ_actions(pred, targ_actions)

        if config.debug_print:
            print(f'preds\n{pred}\ntargs\n{targ}\n')
        callback.on_loss_begin()
        criterion = nn.MSELoss()
        loss = criterion(pred, targ)
        callback.on_loss_end()

        self.model.optimizer.zero_grad()
        loss.backward()

        callback.on_step_begin()

        self.model.optimizer.step()

        callback.on_step_end(loss.item())

        self.decrement_epsilon()

# David White's Rule-Based Agent

In [None]:
class WorkerObjective(Enum):
    """High-level task a ``WorkerAgent`` pursues; member values are readable labels."""
    # Collect the best researched fuel and carry it to the nearest city tile.
    GatherFuel = "gather fuel"
    # Collect wood, then found a city tile on the nearest empty cell.
    BuildCity = "build city"
    # Walk to the nearest city tile and stay there.
    Rest = "rest"

# Wrapper class for worker unit objects
class WorkerAgent:
    """Per-worker state machine used by the rule-based agent.

    Tracks the wrapped Lux worker, its ``WorkerObjective``, the ``Mine`` it is
    assigned to (``self.mine``) and the tile it is walking to
    (``self.destination``). ``get_action`` is called once per turn and returns
    ``(action_or_None, (x, y))``, where ``(x, y)`` is the tile the worker will
    occupy afterwards.
    """
    def __init__(self, worker_obj, debug):
        self.debug = debug
        self.worker = worker_obj
        self.objective = WorkerObjective.BuildCity
        self.objective_changed = False
        self.mine = None
        self.destination = None

    # Reset worker object with latest version each turn
    def update(self, worker_obj):
        self.worker = worker_obj

    def _in_bounds(self, pos, game_map):
        """True when ``pos`` lies on ``game_map``."""
        return 0 <= pos.x < game_map.width and 0 <= pos.y < game_map.height

    def _find_open_mine(self, resource_type, mines):
        """Ask ``mines`` for a spot of ``resource_type``; the current mine is kept if none is free."""
        mine = mines.place_in_mine(self.worker, resource_type)
        if mine is not None:
            self.mine = mine

    def _get_best_fuel(self, player):
        """Most advanced researched fuel: uranium > coal > wood."""
        resource_type = RESOURCE_TYPES.WOOD
        if player.researched_coal():
            resource_type = RESOURCE_TYPES.COAL
        if player.researched_uranium():
            resource_type = RESOURCE_TYPES.URANIUM

        return resource_type

    def _at_mining_spot(self):
        mining_spot = self.get_mining_spot()
        if mining_spot is None:
            return False

        return self.worker.pos == mining_spot

    def _update_mining(self, controller):
        """Drop a depleted mining spot and move fuel gatherers to the best researched fuel."""
        if self._at_mining_spot():
            cell = controller.map.get_cell_by_pos(self.worker.pos)
            if not cell.has_resource():
                self.mine.report_resource_depleted(self.worker.pos, self.worker)
                self.mine = None

        if self.objective == WorkerObjective.GatherFuel:
            best_fuel = self._get_best_fuel(controller.player)
            if self.mine is not None and self.mine.resource_type != best_fuel:
                self.mine.release_worker(self.worker)
                # Forget the released mine so a failed search does not leave a
                # reference to a mine this worker no longer belongs to.
                self.mine = None
                self._find_open_mine(best_fuel, controller.mines)

    def _handle_objective_change(self):
        """On a new objective, leave the current mine and clear the destination."""
        if self.objective_changed:
            if self.mine is not None:
                self.mine.release_worker(self.worker)
                self.mine = None
            self.destination = None
            self.objective_changed = False
            if self.debug:
                print("Worker", self.worker.id, "assigned new objective", file=sys.stderr)

    def _handle_mine_assignment(self, controller):
        """Assign an idle, non-resting worker to a mine (wood unless gathering fuel)."""
        if self.destination is None and self.mine is None and self.objective != WorkerObjective.Rest:
            resource_type = RESOURCE_TYPES.WOOD
            if self.objective == WorkerObjective.GatherFuel:
                resource_type = self._get_best_fuel(controller.player)
            self._find_open_mine(resource_type, controller.mines)
            if self.debug:
                if self.mine is not None:
                    print("Worker", self.worker.id, "assigned to mining spot", self.get_mining_spot(), file=sys.stderr)
                else:
                    print("Unable to place worker", self.worker.id, "in mine", file=sys.stderr)

    def _handle_destination_arrival(self):
        if self.destination is not None and self.worker.pos == self.destination:
            self.destination = None
            if self.debug:
                print("Worker", self.worker.id, "arrived at their destination", file=sys.stderr)

    def _handle_destination_assignment(self, controller):
        """Pick a destination when there is none: home (resting), drop-off/build site (full cargo) or mine."""
        if self.destination is not None:
            return

        if self.objective == WorkerObjective.Rest and not self.on_city_tile(controller.map):
            closest_city_tile = controller.cities.get_nearest_city_tile(self.worker.pos)
            if closest_city_tile is not None:
                self.destination = closest_city_tile.pos
                # Only log when a destination exists (it may be None without cities).
                if self.debug:
                    print("Worker", self.worker.id, "destination set to city tile", (self.destination.x, self.destination.y), file=sys.stderr)
            return

        if self.worker.get_cargo_space_left() == 0:
            if self.mine is not None:
                self.mine.release_worker(self.worker)
                self.mine = None

            if self.debug:
                print("Worker", self.worker.id, "is at max cargo", file=sys.stderr)
            if self.objective == WorkerObjective.GatherFuel:
                nearest_city_tile = controller.cities.get_nearest_city_tile(self.worker.pos)
                if nearest_city_tile is not None:
                    self.destination = nearest_city_tile.pos
            elif self.objective == WorkerObjective.BuildCity:
                # (Alternative: build next to an existing city via
                # controller.cities.get_nearest_periph_pos.)
                self.destination = self.find_nearest_empty_tile(self.worker.pos, controller.map)
            if self.debug:
                print("Worker", self.worker.id, "destination changed to", self.destination, file=sys.stderr)
            return

        if not self._at_mining_spot():
            self.destination = self.get_mining_spot()
            if self.debug:
                print("Worker", self.worker.id, "returning to minining spot", self.get_mining_spot(), file=sys.stderr)
            return

    # Converts directions to degrees (w=0, s=90, e=180, n=270)
    def _to_degrees(self, direction):
        directions = ["w", "s", "e", "n"]
        return 90 * directions.index(direction)

    # Converts degrees to directions
    def _to_dir(self, degrees):
        directions = ["w", "s", "e", "n"]
        return directions[int((degrees % 360) / 90)]

    # Rotates direction `times` steps around the cycle w -> s -> e -> n -> w; "c" is unchanged
    def _rotate_dir(self, direction, times):
        if direction == "c":
            return "c"
        return self._to_dir(self._to_degrees(direction) + 90 * times)

    def get_mining_spot(self):
        """Position of this worker's assigned resource tile, or None."""
        if self.mine is None:
            return None

        tile = self.mine.get_assigned_spot(self.worker)
        if tile is not None:
            return Position(tile[0], tile[1])
        return None

    def set_objective(self, objective):
        if self.objective == objective:
            return
        self.objective = objective
        self.objective_changed = True
        if self.debug:
            print("Worker", self.worker.id, "has new objective", self.objective, file=sys.stderr)

    def get_step_direction(self, game_map, steps, avoid_city=False):
        """Direction of this turn's step towards ``self.destination``.

        ``steps`` holds ``(x, y)`` tiles already claimed by other workers this turn.
        If the direct step is claimed (or ``avoid_city`` is set and it is a city
        tile), the unclaimed, on-map, non-city neighbour closest to the destination
        is chosen; the direct direction is returned when no such neighbour exists.
        """
        direction = self.worker.pos.direction_to(self.destination)
        step = self.worker.pos.translate(direction, 1)

        # Staying put would block another worker's planned step: side-step to a free on-map tile.
        if direction == "c" and (step.x, step.y) in steps:
            for i in range(4):
                new_dir = self._rotate_dir("w", i)
                side = self.worker.pos.translate(new_dir, 1)
                if self._in_bounds(side, game_map) and (side.x, side.y) not in steps:
                    return new_dir
            return direction

        if avoid_city or (step.x, step.y) in steps:                     # Get best detour
            cell = game_map.get_cell(step.x, step.y)
            if cell.citytile is None and (step.x, step.y) not in steps:
                return direction

            shortest_dist = float("inf")
            best_dir = None
            for i in range(4):
                new_dir = self._rotate_dir(direction, i)
                candidate = self.worker.pos.translate(new_dir, 1)
                if not self._in_bounds(candidate, game_map) or (candidate.x, candidate.y) in steps:
                    continue
                if game_map.get_cell(candidate.x, candidate.y).citytile is not None:
                    continue
                dist = candidate.distance_to(self.destination)
                # Track the running minimum (it was never updated before, so the
                # last valid direction won instead of the closest one).
                if dist < shortest_dist:
                    shortest_dist = dist
                    best_dir = new_dir

            if best_dir is not None:
                return best_dir

        return direction

    def on_city_tile(self, game_map):
        """True when the worker currently stands on a city tile."""
        tile = game_map.get_cell_by_pos(self.worker.pos)
        return tile.citytile is not None

    def find_nearest_empty_tile(self, loc, game_map):
        """Breadth-first search from ``loc`` for the closest tile with no city tile and no resource.

        Returns ``loc`` itself if it is empty, otherwise the first empty tile in BFS
        order (neighbours expanded +x, +y, -x, -y), or None if the map has none.
        """
        if self.tile_is_empty(loc, game_map):
            return loc

        # Mark tiles when enqueued (not when popped) so each is queued at most once.
        seen = {(loc.x, loc.y)}
        queue = deque([loc])
        while queue:
            p = queue.popleft()
            if self.tile_is_empty(p, game_map):
                return p
            for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
                neighbor = Position(p.x + dx, p.y + dy)
                key = (neighbor.x, neighbor.y)
                if key not in seen and self._in_bounds(neighbor, game_map):
                    seen.add(key)
                    queue.append(neighbor)
        return None

    def tile_is_empty(self, pos, game_map):
        cell = game_map.get_cell(pos.x, pos.y)
        return cell.citytile is None and not cell.has_resource()

    def get_action(self, controller, steps):
        """Decide this turn's action; returns ``(action or None, (x, y) after acting)``."""
        self._update_mining(controller)

        if not self.worker.can_act():
            return None, (self.worker.pos.x, self.worker.pos.y)

        self._handle_objective_change()
        self._handle_mine_assignment(controller)
        self._handle_destination_arrival()

        if self.destination is None and self.objective == WorkerObjective.BuildCity and self.worker.can_build(controller.map):
            if self.debug:
                print("Worker", self.worker.id, "building city tile at", self.worker.pos, file=sys.stderr)
            return self.worker.build_city(), (self.worker.pos.x, self.worker.pos.y)

        self._handle_destination_assignment(controller)

        if self.destination is not None:
            avoid_city = self.objective == WorkerObjective.BuildCity and self.worker.get_cargo_space_left() == 0
            step_dir = self.get_step_direction(controller.map, steps, avoid_city)
            if self.debug:
                print("Worker", self.worker.id, "step direction:", step_dir, file=sys.stderr)
            step = self.worker.pos.translate(step_dir, 1)
            return self.worker.move(step_dir), (step.x, step.y)

        return None, (self.worker.pos.x, self.worker.pos.y)
    
class Workers:
    """Registry of ``WorkerAgent``s keyed by unit id, with objective allocation.

    ``task_proportions`` gives the fractions of workers assigned to
    [BuildCity, GatherFuel]; any remaining workers Rest.
    """
    def __init__(self, worker_list, debug):
        self.debug = debug
        self.workers = {}                              # Maps worker ids to WorkerAgent objs
        self.task_proportions = [0.5, 0.5, 0.0]

        for worker in worker_list:
            self.workers[worker.id] = WorkerAgent(worker, self.debug)

        if self.debug:
            # Log to stderr like the rest of the agent's debug output.
            print("Workers object initialized", file=sys.stderr)

        self._reassign_objectives()

    def _reassign_objectives(self):
        """Split workers (in dict order) into city builders, fuel gatherers, then resters."""
        num_city_builders = math.ceil(self.task_proportions[0] * len(self.workers))
        num_fuel_gatherers = math.ceil(self.task_proportions[1] * len(self.workers))
        worker_ids = self.workers.keys()

        for i, worker_id in enumerate(worker_ids):
            if i < num_city_builders:
                self.workers[worker_id].set_objective(WorkerObjective.BuildCity)
                continue
            if i < num_city_builders + num_fuel_gatherers:
                self.workers[worker_id].set_objective(WorkerObjective.GatherFuel)
                continue
            self.workers[worker_id].set_objective(WorkerObjective.Rest)

    def update(self, worker_list):
        """Sync with this turn's unit list: drop lost workers, refresh known ones, add new ones."""
        if self.debug:
            print("Updating workers object", file=sys.stderr)

        # Remove workers that were lost last turn
        current_ids = {worker.id for worker in worker_list}
        for lost_worker in set(self.workers) - current_ids:
            self.workers.pop(lost_worker)

        added = False
        for worker in worker_list:
            if worker.id in self.workers:
                self.workers[worker.id].update(worker)
                continue

            self.workers[worker.id] = WorkerAgent(worker, self.debug)
            added = True
            if self.debug:
                print("Worker added", file=sys.stderr)

        # Rebalance once, after all new workers are registered. Rebalancing per
        # added worker could switch an existing worker's objective more than once,
        # flagging objective_changed (and releasing its mine) even when its final
        # objective is unchanged.
        if added:
            self._reassign_objectives()

    def update_task_proportions(self, proportions):
        self.task_proportions = proportions
        self._reassign_objectives()

    def get_actions(self, controller):
        """Collect one action per worker, letting each see the tiles already claimed this turn."""
        actions = []
        steps = set()

        for worker in self.workers.values():
            action, step = worker.get_action(controller, steps)
            steps.add(step)
            if action is not None:
                actions.append(action)

        return actions


class CityWrapper:
    """Thin wrapper around a Lux city: spatial queries plus per-turn city-tile actions."""

    def __init__(self, city_obj, debug):
        self.city = city_obj
        self.debug = debug

    def get_nearest_periph_pos(self, loc, game_map):
        """First free on-map tile (no city tile, no resource) orthogonally adjacent to this city.

        City tiles are scanned from nearest to ``loc`` outwards; returns None if every
        neighbouring tile is taken. Only meaningful when ``loc`` lies outside the city.
        """
        if self.debug:
            print("Searching for city build location", file=sys.stderr)

        offsets = ((1, 0), (0, 1), (-1, 0), (0, -1))
        by_distance = sorted(self.city.citytiles, key=lambda t: t.pos.distance_to(loc))
        for city_tile in by_distance:
            for dx, dy in offsets:
                nx, ny = city_tile.pos.x + dx, city_tile.pos.y + dy
                if not (0 <= nx < game_map.width and 0 <= ny < game_map.height):
                    continue
                candidate = game_map.get_cell(nx, ny)
                if candidate.citytile is None and not candidate.has_resource():
                    return Position(nx, ny)

        return None

    def get_nearest_city_tile(self, loc):
        """This city's tile closest to ``loc`` (first one on ties), or None if it has no tiles."""
        return min(self.city.citytiles, key=lambda t: t.pos.distance_to(loc), default=None)

    def get_actions(self, controller, workers_needed):
        """Each ready tile builds a worker while ``workers_needed`` is unmet, otherwise researches.

        Returns ``(actions, number_of_workers_built)``.
        """
        queued = []
        built = 0
        for city_tile in self.city.citytiles:
            if not city_tile.can_act():
                continue
            if built < workers_needed:
                queued.append(city_tile.build_worker())
                built += 1
                if self.debug:
                    print("City tile", city_tile.pos, "creating new worker.", file=sys.stderr)
            else:
                queued.append(city_tile.research())
        return queued, built
    
class CitiesWrapper:
    """Collection of ``CityWrapper``s for the player's cities, rebuilt every turn via ``update``."""

    def __init__(self, cities_list, debug):
        self.cities = [CityWrapper(city, debug) for city in cities_list]
        self.debug = debug

    def update(self, cities_list):
        self.cities = [CityWrapper(city, self.debug) for city in cities_list]

    def _distance_from(self, city, loc):
        """Distance from ``loc`` to ``city``'s closest tile."""
        return city.get_nearest_city_tile(loc).pos.distance_to(loc)

    def get_nearest_city(self, loc):
        """CityWrapper whose closest tile is nearest to ``loc`` (first on ties), or None."""
        return min(self.cities, key=lambda city: self._distance_from(city, loc), default=None)

    def get_nearest_city_tile(self, loc):
        """City tile nearest to ``loc`` across all cities (first on ties), or None."""
        candidates = (city.get_nearest_city_tile(loc) for city in self.cities)
        return min(candidates, key=lambda tile: tile.pos.distance_to(loc), default=None)

    def get_nearest_periph_pos(self, loc, game_map):
        """Free tile next to the nearest city that has one, or None."""
        for city in sorted(self.cities, key=lambda c: self._distance_from(c, loc)):
            periph = city.get_nearest_periph_pos(loc, game_map)
            if periph is not None:
                return periph
        return None

    def get_actions(self, controller):
        """City-tile actions, building workers until there is one per city tile."""
        remaining = max(controller.state.num_city_tiles - controller.state.num_workers, 0)
        collected = []
        for city in self.cities:
            city_actions, built = city.get_actions(controller, remaining)
            collected += city_actions
            remaining -= built
        return collected


class Mine:
    """A cluster of resource tiles of one type; each tile is mined by at most one worker.

    ``resource_tiles`` holds ``(x, y)`` tuples and ``assigned_workers`` maps a
    worker id to the ``(x, y)`` tile it was given.
    """
    def __init__(self, game_state, resource_tile_set, resource_type, debug):
        # NOTE(review): game_state is accepted but not used here.
        self.resource_type = resource_type
        self.resource_tiles = resource_tile_set
        self.assigned_workers = {}                                      # Maps worker IDs to assigned worker_tile
        self.debug = debug

    def _find_cart_loc(self):
        # Find and return the best location to park the cart (not implemented)
        pass

    def _get_open_worker_tile(self, worker_pos):
        """Unassigned resource tile closest to ``worker_pos`` (first on ties).

        Raises IndexError when every tile is taken; check ``has_opening()`` first.
        """
        taken = set(self.assigned_workers.values())
        available = [tile for tile in self.resource_tiles if tile not in taken]
        if not available:
            raise IndexError("no unassigned resource tile left in this mine")
        return min(available, key=lambda tile: Position(tile[0], tile[1]).distance_to(worker_pos))

    def get_resource_tiles(self):
        return self.resource_tiles

    def worker_assigned(self, worker_id):                               # Checks if a given worker is assigned to mine
        return worker_id in self.assigned_workers

    def get_dist(self, loc):                                            # Returns the shortest distance between loc and all spots in mine
        return min((Position(x, y).distance_to(loc) for x, y in self.resource_tiles), default=float("inf"))

    def has_opening(self):                                              # Checks if there are any available spots in mine
        return len(self.resource_tiles) > len(self.assigned_workers)

    def assign_worker(self, worker):
        self.assigned_workers[worker.id] = self._get_open_worker_tile(worker.pos)

    def release_worker(self, worker):
        self.assigned_workers.pop(worker.id, None)

    def get_assigned_spot(self, worker):
        return self.assigned_workers.get(worker.id)

    def report_resource_depleted(self, pos, assigned_worker):
        """Remove the depleted tile at ``pos`` and free its worker.

        Repeated reports for the same tile are ignored instead of raising.
        """
        tile = (pos.x, pos.y)
        if tile in self.resource_tiles:
            self.resource_tiles.remove(tile)
        self.release_worker(assigned_worker)

    
class Mines:
    """All `Mine`s on the map: one per 4-connected cluster of same-type resource tiles."""

    def __init__(self, game_state, debug):
        self.mines = []
        self.debug = debug

        self._build_mines(game_state)

    def _is_valid_tile(self, game_state, x, y, w, h, resource_type, searched):
        """True if (x, y) is on the w*h map, not yet searched, and holds `resource_type`."""
        if x < 0 or x >= w or y < 0 or y >= h or (x, y) in searched:
            return False

        tile = game_state.map.get_cell(x, y)
        return tile.has_resource() and tile.resource.type == resource_type

    def _get_resource_cluster(self, game_state, x, y, w, h, resource_type, cluster_tiles=None, searched=None):
        """Flood-fill the cluster of `resource_type` tiles containing (x, y).

        Args:
            game_state: game whose `map.get_cell(x, y)` is queried.
            x, y: seed tile.
            w, h: map width and height.
            resource_type: only neighbours holding this resource type join the cluster.
            cluster_tiles, searched: optional sets to extend (mutated in place);
                fresh sets are used when omitted.

        Returns:
            (cluster_tiles, searched). If the seed has no resource, the cluster is
            unchanged and only the seed is marked searched.
        """
        # None defaults: a `set()` default is created once and would leak tiles
        # from one call into the next.
        cluster_tiles = set() if cluster_tiles is None else cluster_tiles
        searched = set() if searched is None else searched

        searched.add((x, y))
        if not game_state.map.get_cell(x, y).has_resource():
            return cluster_tiles, searched

        # Explicit stack instead of recursion, so large clusters cannot hit
        # Python's recursion limit.
        stack = [(x, y)]
        while stack:
            cx, cy = stack.pop()
            cluster_tiles.add((cx, cy))
            for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):
                nx, ny = cx + dx, cy + dy
                if self._is_valid_tile(game_state, nx, ny, w, h, resource_type, searched):
                    searched.add((nx, ny))  # mark on push so no tile is queued twice
                    stack.append((nx, ny))

        return cluster_tiles, searched

    def _build_mines(self, game_state):
        """Scan the whole map and create one Mine per resource cluster."""
        w, h = game_state.map.width, game_state.map.height
        searched = set()
        clusters = []  # (tiles, resource_type) pairs

        for x in range(w):
            for y in range(h):
                if (x, y) in searched:
                    continue
                tile = game_state.map.get_cell(x, y)
                if tile.has_resource():
                    cluster, new_searched = self._get_resource_cluster(game_state, x, y, w, h, tile.resource.type)
                    searched |= new_searched
                    clusters.append((cluster, tile.resource.type))

        # ToDo: Merge mines of same resource type that share borders

        for cluster, resource_type in clusters:
            self.mines.append(Mine(game_state, cluster, resource_type, self.debug))

        if self.debug:
            print("Clusters:", [cluster for cluster, _ in clusters], file=sys.stderr)

    def update(self, game_state):
        """Rebuild every mine whose `update_mine(game_state)` returns True.

        The remaining resource tiles of such a mine are flood-filled again, and
        each connected cluster found becomes a new Mine. A depleted tile can split
        one mine into several. Mines with no resource left are dropped. Worker
        assignments of rebuilt mines are not carried over.
        """
        w, h = game_state.map.width, game_state.map.height
        refreshed = []

        # Build a new list rather than removing from self.mines while iterating it,
        # which would skip the element after each removal.
        for mine in self.mines:
            if not mine.update_mine(game_state):
                refreshed.append(mine)
                continue

            searched = set()
            for tile in mine.get_resource_tiles():
                if tile in searched:
                    continue
                cell = game_state.map.get_cell(tile[0], tile[1])
                if cell.has_resource():
                    resource_type = cell.resource.type
                    cluster, new_searched = self._get_resource_cluster(game_state, tile[0], tile[1], w, h, resource_type)
                    searched |= new_searched
                    refreshed.append(Mine(game_state, cluster, resource_type, self.debug))

        # Slice-assign so existing references to self.mines see the new contents.
        self.mines[:] = refreshed

    def get_closest_mine(self, loc, resource_type):
        """Mine of `resource_type` owning the tile nearest to `loc`, or None if there is none."""
        closest_mine = None
        shortest_dist = float("inf")

        for mine in self.mines:
            if mine.resource_type != resource_type:
                continue
            for tile in mine.resource_tiles:
                dist = Position(tile[0], tile[1]).distance_to(loc)
                if dist < shortest_dist:
                    shortest_dist = dist
                    closest_mine = mine

        return closest_mine

    def place_in_mine(self, worker, resource_type):
        """Assign `worker` to the nearest mine of `resource_type` with a free tile.

        Returns the chosen Mine, or None when every such mine is full.
        """
        sorted_mines = sorted(
            (mine for mine in self.mines if mine.resource_type == resource_type),
            key=lambda mine: mine.get_dist(worker.pos),
        )

        for mine in sorted_mines:
            if mine.has_opening():
                mine.assign_worker(worker)
                return mine

        return None


class State:
    """Scalar snapshot of the match: unit/city counts, research and turn for both sides."""

    def __init__(self, game_state, player, opponent):
        self._update_state(game_state, player, opponent)

    def _update_state(self, game_state, player, opponent):
        """Re-read every statistic from the current game objects."""
        mine, theirs = player.units, opponent.units

        self.num_workers = sum(1 for u in mine if u.is_worker())
        self.num_carts = sum(1 for u in mine if u.is_cart())
        self.num_city_tiles = player.city_tile_count

        self.num_opponent_workers = sum(1 for u in theirs if u.is_worker())
        self.num_opponent_carts = sum(1 for u in theirs if u.is_cart())
        self.num_opponent_city_tiles = opponent.city_tile_count

        self.research_points = player.research_points
        self.opponent_research_points = opponent.research_points
        self.turn = game_state.turn

    def get_state_vector(self):
        """Return the statistics as a (1, 9) numpy array, own side first, turn last."""
        fields = [
            self.num_workers,
            self.num_carts,
            self.num_city_tiles,
            self.num_opponent_workers,
            self.num_opponent_carts,
            self.num_opponent_city_tiles,
            self.research_points,
            self.opponent_research_points,
            self.turn,
        ]
        return np.array(fields).reshape(1, -1)
        

class Controller:
    """Coordinates the rule-based sub-controllers (mines, workers, cities) for one player.

    Keeps the latest game state and both players, and collects each turn's
    action strings from the worker and city wrappers.
    """

    def __init__(self, game_state, player, opponent, debug):
        self.debug = debug
        self.game_state = game_state
        self.map = game_state.map
        self.player = player
        self.opponent = opponent
        self.state = State(game_state, player, opponent)
        self.mines = Mines(game_state, debug)
        self.workers = Workers(self._own_workers(player), debug)
        self.cities = CitiesWrapper(player.cities.values(), debug)

    @staticmethod
    def _own_workers(player):
        """`player`'s units that are workers."""
        return [unit for unit in player.units if unit.is_worker()]

    def update(self, game_state, player, opponent):
        """Point the controller at this turn's game objects and refresh cities, then workers."""
        self.game_state = game_state
        self.state._update_state(game_state, player, opponent)
        self.map = game_state.map
        self.player, self.opponent = player, opponent
        # Mine rebuilding is currently disabled:
        # self.mines.update(game_state)
        self.cities.update(player.cities.values())
        self.workers.update(self._own_workers(player))

    def get_state_vector(self):
        """(1, 9) statistics vector from the wrapped State."""
        return self.state.get_state_vector()

    def apply_agent_action(self, action):
        """Forward a high-level agent decision to the workers' task proportions."""
        self.workers.update_task_proportions(action)

    def get_actions(self):
        """Worker actions followed by city actions for this turn."""
        if self.debug:
            print("Turn", self.game_state.turn, file=sys.stderr)
        worker_part = self.workers.get_actions(self)
        city_part = self.cities.get_actions(self)
        return worker_part + city_part


def calculate_reward(s, s_prime, reward_weights):
    """Weighted sum of the change from `s[0]` to `s_prime[0]`, one weight per component."""
    delta = s_prime[0] - s[0]
    weights = np.array(reward_weights)
    return np.sum(delta * weights)

# Shared across base_agent calls; both are (re)created on step 0.
game_state = controller = None

def base_agent(observation, configuration):
    """Rule-based Kaggle agent: keep the global Game/Controller in sync and return actions.

    On step 0 a fresh Game and Controller are built; on later steps the game is
    updated from the observation and the existing controller is refreshed.
    """
    global game_state
    global controller

    first_step = observation["step"] == 0

    ### Do not edit ###
    if first_step:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation.player
    else:
        game_state._update(observation["updates"])

    me = observation.player
    player = game_state.players[me]
    opponent = game_state.players[(me + 1) % 2]

    if first_step:
        controller = Controller(game_state, player, opponent, False)
    else:
        controller.update(game_state, player, opponent)

    return controller.get_actions()

In [None]:
# Shorthands into the Lux game constants used by the encoders and action helpers below.
RES = GAME_CONSTANTS['RESOURCE_TYPES']
DIR = GAME_CONSTANTS['DIRECTIONS']
UNIT_TYPE = GAME_CONSTANTS['UNIT_TYPES']
RP_REQ = GAME_CONSTANTS['PARAMETERS']['RESEARCH_REQUIREMENTS']      # research points needed per resource
FUEL_RATE = GAME_CONSTANTS['PARAMETERS']['RESOURCE_TO_FUEL_RATE']   # fuel per unit of carried resource

def normal(state):
    """Row-normalise (sklearn `normalize`, L2 by default) each channel of `state` in place.

    Channels 1, 9 and 14 are left untouched. Returns the same, mutated `state`.
    """
    skipped_channels = (1, 9, 14)  # NOTE(review): presumably categorical/raw channels — confirm
    for idx, channel in enumerate(state):
        if idx not in skipped_channels:
            state[idx] = torch.tensor(normalize(channel))
    return state

def get_input():
    """Encode the global `game_state` as model-input tensors.

    Boards are `config.base_size` (S) square. Map cells are shifted by
    `pad = S - map_width` on both axes, and research state is taken from player 0.

    Returns:
        resources: (1, 3, S, S) — amount, type token (wood 1, coal 2, uranium 3),
            research points player 0 still needs to mine that type.
        cities: (3, 2, S, S) — per team: tile cooldown, city fuel, city light upkeep.
        units: (5, 2, S, S) — per team: unit type + 1, cooldown, wood, coal, uranium.
        entity_type: 1-D tensor with one code per player-0 city tile/unit. Ready
            entities (cooldown <= 0) come first: 2 for a city tile, `unit.type` for a
            unit. They are followed by a 3 for each entity still on cooldown.
        entity_info: [cooldown, fuel, x, y] per player-0 entity, city tiles then
            units, in discovery order ((N, 4) when non-empty).
    """
    width, height = game_state.map_width, game_state.map_height
    size = config.base_size
    pad = size - width
    research = game_state.players[0].research_points

    res_token = {'wood': 1, 'coal': 2, 'uranium': 3}
    rp_req = {
        'wood': 0,
        'coal': max(0, RP_REQ['COAL'] - research),
        'uranium': max(0, RP_REQ['URANIUM'] - research),
    }

    # Resource board, indexed [y, x].
    resources = torch.zeros([size, size, 3])
    for row in range(height):
        for col in range(width):
            resource = game_state.map.map[row][col].resource
            if resource is not None:
                resources[row + pad, col + pad] = torch.tensor(
                    [resource.amount, res_token[resource.type], rp_req[resource.type]]
                )

    cities = torch.zeros([2, size, size, 3])
    units = torch.zeros([2, size, size, 5])
    ready_types = []
    busy_types = []
    entity_info = []
    for team, team_player in enumerate(game_state.players):
        is_me = team == 0

        for city in team_player.cities.values():
            for tile in city.citytiles:
                cities[team, tile.pos.y + pad, tile.pos.x + pad] = torch.tensor(
                    [tile.cooldown, city.fuel, city.light_upkeep]
                )
                if is_me:
                    entity_info.append([tile.cooldown, city.fuel, tile.pos.x, tile.pos.y])
                    if tile.cooldown <= 0.:
                        ready_types.append(2)
                    else:
                        busy_types.append(3)

        for unit in team_player.units:
            cargo = unit.cargo
            units[team, unit.pos.y + pad, unit.pos.x + pad] = torch.tensor(
                [unit.type + 1, unit.cooldown, cargo.wood, cargo.coal, cargo.uranium]
            )
            if is_me:
                fuel = (cargo.wood * FUEL_RATE['WOOD']
                        + cargo.coal * FUEL_RATE['COAL']
                        + cargo.uranium * FUEL_RATE['URANIUM'])
                entity_info.append([unit.cooldown, fuel, unit.pos.x, unit.pos.y])
                if unit.cooldown <= 0.:
                    ready_types.append(unit.type)
                else:
                    busy_types.append(3)

    return (
        resources.permute(2, 0, 1).unsqueeze(0),
        cities.permute(3, 0, 1, 2),
        units.permute(3, 0, 1, 2),
        torch.tensor(ready_types + busy_types),
        torch.tensor(entity_info),
    )

def get_nearest_cart(player, unit):
    """Return a unit of `player` other than `unit` at distance exactly 1 from it.

    Despite the name, any unit type qualifies (workers as well as carts).

    Args:
        player: owner whose `units` are searched.
        unit: reference unit; matched against the others by `id`.

    Returns:
        The first adjacent unit in `player.units` order, or None if there is none.
    """
    for candidate in player.units:
        if candidate.id != unit.id and candidate.pos.distance_to(unit.pos) == 1:
            return candidate
    return None


# Alias: action_strings refers to this helper as `get_nearest_unit`.
get_nearest_unit = get_nearest_cart

def get_res(unit):
    """Return `(resource_type, amount)` for the resource `unit` carries the most of.

    Ties are resolved in the order wood, coal, uranium.
    """
    holdings = (
        (RES['WOOD'], unit.cargo.wood),
        (RES['COAL'], unit.cargo.coal),
        (RES['URANIUM'], unit.cargo.uranium),
    )
    # max() returns the first maximal entry, matching the strict ">" ladder.
    return max(holdings, key=lambda pair: pair[1])

def action_strings(player, acts):
    """Translate per-actor action indices into Lux action strings.

    Actors are all of `player`'s city tiles (city by city) followed by its units;
    `acts[i]` applies to the i-th actor, and actors that cannot act are skipped.
    Action indices:
        0 build worker, 1 research (city tiles)
        2-6 move north/south/east/west/center (units)
        7 worker transfers its largest cargo resource to an adjacent friendly unit
        8 worker builds a city

    Returns:
        (actions, done): the list of action strings, and True on turn 359 or when
        the player has neither cities nor units left.
    """
    actions = []
    done = game_state.turn == 359 or (len(player.cities) == 0 and len(player.units) == 0)

    actors = [citytile for city in player.cities.values() for citytile in city.citytiles]
    actors.extend(player.units)

    for act, actor in zip(acts, actors):
        if not actor.can_act():
            continue
        action = None
        try:
            if act == 0:
                action = actor.build_worker()
            elif act == 1:
                action = actor.research()
            elif act == 2:
                action = actor.move(DIR['NORTH'])
            elif act == 3:
                action = actor.move(DIR['SOUTH'])
            elif act == 4:
                action = actor.move(DIR['EAST'])
            elif act == 5:
                action = actor.move(DIR['WEST'])
            elif act == 6:
                action = actor.move(DIR['CENTER'])
            elif act == 7 and actor.type == UNIT_TYPE['WORKER']:
                res, amount = get_res(actor)
                target = get_nearest_cart(player, actor)
                # The worker (actor) gives its cargo to the neighbour.
                if target is not None and amount > 0:
                    action = actor.transfer(target.id, res, amount)
            elif act == 8 and actor.type == UNIT_TYPE['WORKER']:
                action = actor.build_city()
            # Acts 9 (build_cart) and 10 (pillage) are not wired up.
        except Exception as err:
            # Best effort: e.g. a unit-only act given to a city tile raises
            # AttributeError. Skip the actor, but make it visible when debugging.
            if config.debug_print:
                print(f'action_strings: skipped act {act} for {actor}: {err!r}', file=sys.stderr)
            continue
        if action is not None:
            actions.append(action)
    return actions, done

# Supervised Learning

In [None]:
def supervising_agent(observation, configuration):
    """Imitation-learning agent driven by a recorded replay (`json_steps`).

    Each step it encodes the game with `get_input` and asks `luxagent` for a
    prediction. It then trains on the replay's recorded actions for this seat and
    returns those recorded actions, so the game follows the replay. Predictions are
    turned into action strings only for debug output.
    """
    global game_state, last_state

    ### Do not edit ###
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation.player
    else:
        game_state._update(observation["updates"])

    ### AI Code goes down here! ###
    player = game_state.players[observation.player]
    step = observation["step"]

    true_actions = []
    state = get_input()
    if step == 0:
        last_state = state[:3]
    if len(state[3]) > 0:
        pred = luxagent.sl_choose_actions(state, last_state)
        # Replay entry t+1 holds the actions taken at observation step t.
        if step + 1 < len(json_steps):
            true_actions = json_steps[step + 1][observation.player]['action'] or []
            luxagent.supervise_learn(state, last_state, true_actions)
        if config.debug_print:
            acts, _ = action_strings(player, pred)
            print(f'{acts}   {true_actions}')
    last_state = state[:3]
    return true_actions

In [None]:
def action_agent(observation, configuration):
    """Opponent that replays the recorded actions for its seat from `json_steps`.

    It does not touch the global `game_state`. `supervising_agent` owns that
    object, and parsing the same updates into it a second time would advance it
    twice per step (assumes the Lux kit's `Game._update` increments `turn` — verify).

    Returns:
        The recorded action list for this step, or [] when the replay has no entry.
    """
    try:
        # Replay entry t+1 holds the actions taken at observation step t.
        actions = json_steps[observation["step"] + 1][observation.player]["action"]
    except (IndexError, KeyError, TypeError):
        # Replay exhausted or malformed for this step: do nothing this turn.
        return []
    return actions or []

In [None]:
def supervise_train(episode_dir='../input/lux-ai-toad-brigade-episodes', max_episodes=50):
    """Imitation learning over recorded replays.

    Each replay JSON in `episode_dir` (files with 'output' in their name are
    skipped) is re-played with its original seed. `supervising_agent` learns from
    the recorded actions while `action_agent` replays the other seat.

    Args:
        episode_dir: directory holding the replay ``*.json`` files.
        max_episodes: maximum number of replays to use (None for all).
    """
    global json_steps
    # Sorted so the selected subset is reproducible; glob order is arbitrary.
    episodes = sorted(path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name)

    callback.on_train_begin()
    for idx, filepath in enumerate(episodes[:max_episodes]):
        # Read and close the replay before running the (long) game.
        with open(filepath) as f:
            json_load = json.load(f)
        seed = json_load['configuration']['seed']
        json_steps = json_load['steps']

        callback.on_game_begin()
        env = make("lux_ai_2021", configuration={"loglevel": 0, "annotations": True, "seed": seed}, debug=True)
        env.run([supervising_agent, action_agent])
        callback.on_game_end(idx)
    callback.on_train_end()


config.debug_print = False
config.phase = 'supervise'
callback = Callback()
luxagent = LuxAgent()
supervise_train()

# Reinforcement Learning

In [None]:
# Settings for the reinforcement-learning phase.
config.debug_print = False
config.lr = 1e-5

In [None]:
class Rewards:
    """Per-team totals from the previous step; `calc_reward` rewards the change from them."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Restore the baseline: 1 city tile, 1 worker, 0 carts, 0 research points."""
        self.citytiles, self.workers = 1, 1
        self.carts, self.rp = 0, 0

def calc_reward(player):
    """Shaped reward for `player` at this step. Updates `rewards[player.team]` as a side effect.

    The sum of:
      * log base 100 of the fuel in its cities plus unit cargo (-1 when there is none),
      * the change in city tiles, workers and carts since the previous call,
      * the change in research points, only while below the uranium requirement.
    """
    baseline = rewards[player.team]

    citytiles = player.city_tile_count
    rp = player.research_points
    workers = sum(1 for unit in player.units if unit.type == UNIT_TYPE['WORKER'])
    carts = sum(1 for unit in player.units if unit.type == UNIT_TYPE['CART'])

    # Same summation order as a running total: city fuel first, then each unit's cargo.
    fuel_terms = [city.fuel for city in player.cities.values()]
    for unit in player.units:
        fuel_terms += [
            unit.cargo.wood * FUEL_RATE['WOOD'],
            unit.cargo.coal * FUEL_RATE['COAL'],
            unit.cargo.uranium * FUEL_RATE['URANIUM'],
        ]
    total_fuel = sum(fuel_terms)

    if total_fuel > 0:
        fuel_reward = math.log(total_fuel, 100)
    else:
        fuel_reward = -1
    city_reward = citytiles - baseline.citytiles
    work_reward = workers - baseline.workers
    cart_reward = carts - baseline.carts
    rp_reward = rp - baseline.rp if rp < RP_REQ['URANIUM'] else 0

    baseline.citytiles = citytiles
    baseline.workers = workers
    baseline.carts = carts
    baseline.rp = rp

    return city_reward + fuel_reward + work_reward + cart_reward + rp_reward

def reinforce_agent(observation, configuration):
    """RL agent: encode the game, act with `luxagent`, store a transition and learn.

    Module-level globals carried between steps:
        game_state: the Lux Game, rebuilt on step 0 and updated afterwards.
        obs: the latest observation.
        last_state: `state[:3]` (resources, cities, units from get_input) of the
            previous step, fed to the model alongside the current input.
        mem_last_state: the (state, last_state) pair stored on the previous step.
        reward: the latest value from calc_reward.

    Returns the list of action strings for this step.
    """
    global game_state, obs, last_state, mem_last_state, reward
    obs = observation
    
    ### Do not edit ###
    if observation["step"] == 0:
        game_state = Game()
        game_state._initialize(observation["updates"])
        game_state._update(observation["updates"][2:])
        game_state.id = observation.player
    else:
        game_state._update(observation["updates"])
        
    actions = []
    
    ### AI Code goes down here! ### 
    player = game_state.players[observation.player]
    # NOTE(review): opponent, width and height are not used below.
    opponent = game_state.players[(observation.player + 1) % 2]
    width, height = game_state.map.width, game_state.map.height

    state = get_input()
    if observation['step']==0:
        # No previous step yet: use the current encoding as its own history.
        last_state = state[:3]
        mem_last_state = (state, last_state)
    pred = luxagent.rl_choose_actions(state, last_state)
    actions, done = action_strings(player, pred)
    
    # Reward is measured on the current state, i.e. before `actions` are applied.
    reward = calc_reward(player)
    
    mem_state = (state, last_state)
    # NOTE(review): the stored tuple is (current input, current prediction, reward,
    # previous input, done). If store_memory expects (s, a, r, s', done), the s'
    # slot holds the *previous* state and the reward reflects the previous step's
    # actions — confirm against LuxAgent.store_memory.
    luxagent.store_memory(mem_state, pred, reward, mem_last_state, done)
    luxagent.reinforcement_learn()
    mem_last_state = (state, last_state)
    last_state = state[:3]
    return actions

In [None]:
def reinforce_train(sizes=(12,), opponent="random_agent"):
    """Reinforcement-learning loop.

    For each of `config.epochs` epochs, plays one game per map size in `sizes`,
    with `reinforce_agent` as player 0 against `opponent`. The per-team reward
    baselines are reset after every game. The last game is rendered at the end.

    Args:
        sizes: map widths/heights to play each epoch (e.g. 12, 16, 24, 32).
        opponent: second agent passed to `env.run` (a kaggle agent name or callable).
    """
    global env
    game_num = 0

    callback.on_train_begin()
    for epoch in range(config.epochs):
        callback.on_epoch_begin(epoch)
        for size in sizes:
            callback.on_game_begin()
            env = make("lux_ai_2021", configuration={"loglevel": 0, "width": size, "height": size, "annotations": True}, debug=True)
            env.run([reinforce_agent, opponent])
            # calc_reward diffs against these totals; start the next game fresh.
            rewards[0].reset()
            rewards[1].reset()
            callback.on_game_end(game_num)
            game_num += 1
        callback.on_epoch_end()
    callback.on_train_end()
    # Only render when at least one game was played here (otherwise `env` may not exist).
    if game_num > 0:
        env.render(mode="ipython", width=1200, height=800)


rewards = {0: Rewards(), 1: Rewards()}
config.phase = 'reinforce'
callback = Callback()
luxagent = LuxAgent()
reinforce_train()