# Introduction
Since the competition is now over, I want to share my (hopefully **silver**, or bronze) solution. My agent is an ensemble of 3 imitation-learning (IL) agents. I borrowed the idea from the [lux-ai-with-il-ensemble-of-models](https://www.kaggle.com/realneuralnetwork/lux-ai-with-il-ensemble-of-models) notebook, but instead of choosing the most common action (majority vote), I apply a softmax to each model's output, take a weighted average of the resulting probabilities, and choose the action with the highest probability. The 3 IL models are:
- [lux-ai-with-il-decreasing-learning-rate](https://www.kaggle.com/realneuralnetwork/lux-ai-with-il-decreasing-learning-rate) 
- [toad model from the original ensemble notebook](https://www.kaggle.com/realneuralnetwork/lux-ai-with-il-ensemble-of-models) — I tried many different models, but this one gave the best performance.
- [U-Net imitation learning](https://www.kaggle.com/bachngoh/luxai-unet-immitationlearning-lb-1100) — I was inspired by [this](https://www.kaggle.com/c/lux-ai-2021/discussion/289540) amazing post by nosound.

I also trained a separate U-Net for the city tile actions:
- [U-Net for city tiles](https://www.kaggle.com/bachngoh/luxai-unet-for-ctiles)

In [None]:
# Install/upgrade the Lux AI Kaggle environment quietly, then copy the
# competition files into the working directory (presumably including the
# `lux` package that agent.py imports -- confirm against the dataset).
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* .

In [None]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

The first model comes from the amazing [lux-ai-with-il-decreasing-learning-rate](https://www.kaggle.com/realneuralnetwork/lux-ai-with-il-decreasing-learning-rate) notebook!

In [None]:
# Export the IL model from the "decreasing learning rate" notebook as model_il.pth.
# NOTE(review): torch.jit.load already returns a ScriptModule, and tracing a
# ScriptModule appears to be a no-op that hands it back unchanged (with a
# warning) -- confirm; if so, this cell effectively re-saves the model on CPU.
il_checkpoint = "../input/lux-ai-with-il-decreasing-learning-rate/model.pth"
model = torch.jit.load(il_checkpoint)
model.eval()
example_board = torch.rand(1, 20, 32, 32)  # one 20-channel 32x32 board
traced = torch.jit.trace(model.cpu(), example_board)
traced.save('model_il.pth')

In [None]:
# Export the "toad" model from the original ensemble notebook as model_toad.pth.
# NOTE(review): the loaded object is already TorchScript, so the trace below is
# likely a no-op that returns it unchanged -- confirm.
toad_checkpoint = "../input/models-lux/model_toad.pth"
model = torch.jit.load(toad_checkpoint)
model.eval()
example_board = torch.rand(1, 20, 32, 32)  # same 20x32x32 input as the IL model
traced = torch.jit.trace(model.cpu(), example_board)
traced.save('model_toad.pth')

In [None]:
# Export the board-level U-Net unit policy as model_unet.pth. It takes two
# inputs: a 14x32x32 board and 14x4x4 global features (the shapes produced by
# make_input_unet in agent.py).
# NOTE(review): tracing an already-scripted module is likely a no-op -- confirm.
unet_checkpoint = "../input/d/bachngoh/luxai-models/submission_unet_v6_updated_ctile/model.pth"
model = torch.jit.load(unet_checkpoint)
model.eval()
example_inputs = (torch.rand(1, 14, 32, 32), torch.rand(1, 14, 4, 4))
traced = torch.jit.trace(model.cpu(), example_inputs)
traced.save('model_unet.pth')

In [None]:
# Export the U-Net city-tile policy as model_unet_ct.pth (same two inputs as
# the unit U-Net: a 14x32x32 board plus 14x4x4 global features).
# NOTE(review): tracing an already-scripted module is likely a no-op -- confirm.
ct_checkpoint = "../input/d/bachngoh/luxai-models/submission_unet_v6_updated_ctile/model_ct.pth"
model = torch.jit.load(ct_checkpoint)
model.eval()
example_inputs = (torch.rand(1, 14, 32, 32), torch.rand(1, 14, 4, 4))
traced = torch.jit.trace(model.cpu(), example_inputs)
traced.save('model_unet_ct.pth')

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game
from collections import Counter
from torch import nn


# Where the exported .pth files live: under /kaggle_simulations/agent when that
# sandbox directory exists, otherwise the current working directory.
if os.path.exists('/kaggle_simulations'):
    path = '/kaggle_simulations/agent'
else:
    path = '.'

# Every model is loaded once at import time and switched to eval mode.
model = torch.jit.load(f"{path}/model_il.pth")         # per-unit IL policy
model.eval()

model2 = torch.jit.load(f"{path}/model_unet.pth")      # board-wide U-Net unit policy
model2.eval()

model4 = torch.jit.load(f"{path}/model_toad.pth")      # per-unit "toad" IL policy
model4.eval()

model_ct = torch.jit.load(f"{path}/model_unet_ct.pth")  # board-wide U-Net city-tile policy
model_ct.eval()


def make_input(obs, unit_id):
    """Build the 20x32x32 float32 input for the per-unit IL and "toad" models.

    The Lux map is centred in a 32x32 grid (offset ``(32 - width) // 2`` on x
    and ``(32 - height) // 2`` on y) and indexed ``[channel, x, y]``.

    Channels:
        0-1   the unit ``unit_id`` itself: presence, cargo / 100
        2-4   other friendly units: presence, cooldown / 6, cargo / 100
        5-7   opponent units: same layout as 2-4
        8-9   friendly city tiles: presence, min(fuel / light upkeep, 10) / 10
        10-11 opponent city tiles: same layout as 8-9
        12-14 wood / coal / uranium amount / 800
        15-16 friendly / opponent research points, capped at 200, / 200
        17    step % 40 / 40
        18    step / 360
        19    1 on cells covered by the map

    ``obs`` needs ``width``, ``height``, ``player``, ``step`` and ``updates``
    (space-separated update strings). A city's ``c`` line must precede its
    ``ct`` lines, otherwise the city lookup raises ``KeyError``.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    fuel_ratio = {}
    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for line in obs['updates']:
        parts = line.split(' ')
        tag = parts[0]

        if tag == 'u':
            cx = int(parts[4]) + x_shift
            cy = int(parts[5]) + y_shift
            cargo = int(parts[7]) + int(parts[8]) + int(parts[9])
            if parts[3] == unit_id:
                # The unit we are choosing an action for.
                planes[0, cx, cy] = 1
                planes[1, cx, cy] = cargo / 100
                continue
            side = (int(parts[2]) - obs['player']) % 2  # 0 = ours, 1 = opponent
            cooldown = float(parts[6])
            base = 2 + 3 * side
            planes[base, cx, cy] = 1
            planes[base + 1, cx, cy] = cooldown / 6
            planes[base + 2, cx, cy] = cargo / 100
        elif tag == 'ct':
            side = (int(parts[1]) - obs['player']) % 2
            city = parts[2]
            cx = int(parts[3]) + x_shift
            cy = int(parts[4]) + y_shift
            base = 8 + 2 * side
            planes[base, cx, cy] = 1
            planes[base + 1, cx, cy] = fuel_ratio[city]
        elif tag == 'r':
            kind = parts[1]
            cx = int(parts[2]) + x_shift
            cy = int(parts[3]) + y_shift
            amount = int(float(parts[4]))
            channel = {'wood': 12, 'coal': 13, 'uranium': 14}[kind]
            planes[channel, cx, cy] = amount / 800
        elif tag == 'rp':
            team = int(parts[1])
            points = int(parts[2])
            planes[15 + (team - obs['player']) % 2] = min(points, 200) / 200
        elif tag == 'c':
            city = parts[2]
            fuel = float(parts[3])
            upkeep = float(parts[4])
            fuel_ratio[city] = min(fuel / upkeep, 10) / 10

    step = obs['step']
    planes[17] = step % 40 / 40
    planes[18] = step / 360
    planes[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return planes

def make_input_unet(obs):
    """Encode the whole board for the U-Net models.

    Returns ``(b, global_features)``:

    * ``b`` -- float32 array of shape (14, 32, 32), indexed
      ``[channel, x + x_shift, y + y_shift]`` so the map is centred:
      0-2 friendly units (presence, cooldown / 6, cargo / 100),
      3-5 opponent units (same layout),
      6-7 friendly city tiles (presence, min(fuel / light upkeep, 10) / 10),
      8-9 opponent city tiles (same layout),
      10-12 wood / coal / uranium amount / 800,
      13 on-map mask.
    * ``global_features`` -- float64 array of shape (14, 4, 4); each channel
      is filled with one scalar:
      0-1 friendly / opponent research points, capped at 200, / 200;
      2 step % 40 / 40; 3 step / 360;
      4-5 friendly / opponent unit counts / 50;
      6-7 friendly / opponent city-tile counts / 50;
      8-10 total wood / 24000, coal / 24000, uranium / 12000;
      11-12 1 if the friendly team has >= 50 / >= 200 research points;
      13 raw map width (not normalised -- presumably what the trained model
      expects, so it is left unchanged).

    ``obs`` needs ``width``, ``height``, ``player``, ``step`` and ``updates``
    (space-separated update strings). A city's ``c`` line must precede its
    ``ct`` lines, otherwise the city lookup raises ``KeyError``.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    global_features = np.zeros((14, 4, 4))

    b = np.zeros((14, 32, 32), dtype=np.float32)

    friendly_unit_cnt = 0
    opponent_unit_cnt = 0
    friendly_ctile_cnt = 0
    opponent_ctile_cnt = 0
    total_wood = 0
    total_coal = 0
    total_uranium = 0

    can_mine_coal = 0
    can_mine_uranium = 0

    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])

            # Units
            team = int(strs[2])
            side = (team - obs['player']) % 2  # 0 = ours, 1 = opponent

            if side == 0:
                friendly_unit_cnt += 1
            else:
                opponent_unit_cnt += 1

            cooldown = float(strs[6])
            idx = side * 3
            b[idx:idx + 3, x, y] = (
                1,
                cooldown / 6,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            side = (team - obs['player']) % 2

            if side == 0:
                friendly_ctile_cnt += 1
            else:
                opponent_ctile_cnt += 1

            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 6 + side * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 10, 'coal': 11, 'uranium': 12}[r_type], x, y] = amt / 800
            if r_type == 'wood':
                total_wood += amt
            elif r_type == 'coal':
                total_coal += amt
            elif r_type == 'uranium':
                total_uranium += amt

        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            side = (team - obs['player']) % 2
            # Parity of the team *difference*, matching the indexing used for
            # every other feature (not ``team - (player % 2)``).
            if side == 0:
                if rp >= 50:
                    can_mine_coal = 1
                if rp >= 200:
                    can_mine_uranium = 1

            global_features[side, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    global_features[2, :] = obs['step'] % 40 / 40
    # Turns
    global_features[3, :] = obs['step'] / 360
    # Number of friendly units
    global_features[4, :] = friendly_unit_cnt / 50
    # Number of opponent units
    global_features[5, :] = opponent_unit_cnt / 50
    # Number of friendly city tiles
    global_features[6, :] = friendly_ctile_cnt / 50
    # Number of opponent city tiles
    global_features[7, :] = opponent_ctile_cnt / 50
    # Total Wood
    global_features[8, :] = total_wood / 24000
    # Total Coal
    global_features[9, :] = total_coal / 24000
    # Total Uranium
    global_features[10, :] = total_uranium / 12000
    # Research thresholds reached by our team
    global_features[11, :] = can_mine_coal
    global_features[12, :] = can_mine_uranium
    # Map size (raw width, deliberately not normalised)
    global_features[13, :] = width

    # On-map mask
    b[13, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b, global_features


# Module-level Game mirror shared with in_city(); rebuilt on step 0.
game_state = None


def get_game_state(observation):
    """Return the persistent ``Game`` object, updated with this observation.

    On step 0 a new ``Game`` is created: it is initialised from all update
    lines, then ``_update`` is applied to the update lines after the first two,
    and its ``id`` is set to ``observation["player"]``. On every later step the
    existing game just consumes ``observation["updates"]``.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    game_state = Game()
    game_state._initialize(observation["updates"])
    # NOTE(review): the first two update lines are presumably the header
    # consumed by _initialize -- confirm against the lux kit.
    game_state._update(observation["updates"][2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True if ``pos`` holds a city tile owned by this agent.

    Reads the module-level ``game_state``. Any failure while looking the cell
    up (for example no game loaded yet, or a position the map cannot resolve)
    is treated as "not in a city" instead of being propagated.
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Deliberately best-effort, but no longer swallows KeyboardInterrupt
        # or SystemExit the way a bare ``except:`` did.
        return False


def call_func(obj, method, args=()):
    """Call ``obj.<method>(*args)`` and return its result.

    ``args`` may be any iterable. ``get_action`` relies on a one-character
    direction string such as ``'n'`` unpacking to a single positional argument.
    The default is an immutable empty tuple rather than a shared mutable list.
    """
    return getattr(obj, method)(*args)


# Index i of every unit policy vector corresponds to unit_actions[i]: a unit
# method name followed by its optional direction argument.
unit_actions = [
    ('move', 'n'),
    ('move', 's'),
    ('move', 'w'),
    ('move', 'e'),
    ('build_city',),
]


def get_action(policy, unit, dest):
    """Return ``(action, pos)`` for the best-scoring action that is allowed.

    Candidate actions are tried from highest to lowest score in ``policy``.
    One is accepted when the cell it leads to is not already claimed in
    ``dest``, or when that cell is one of our own city tiles (per ``in_city``).
    If no candidate qualifies, the unit stays put with ``move('c')``.
    """
    ranked = np.argsort(policy)[::-1]
    for choice in ranked:
        spec = unit_actions[choice]
        # translate() may give back a falsy value (e.g. for 'build_city'),
        # in which case the unit's own cell is the target.
        target = unit.pos.translate(spec[-1], 1) or unit.pos
        blocked = target in dest and not in_city(target)
        if not blocked:
            return call_func(unit, *spec), target

    return unit.move('c'), unit.pos

def get_action_unet(policy, unit, dest, shift, actions=None):
    """Pick a unit action directly from the U-Net's board-wide action map.

    ``policy`` is indexed ``[channel, x, y]`` on the 32x32 grid, so the unit's
    map position is offset by ``shift``. The best-scoring channel is mapped to
    an action through ``actions`` (default: ``unit_actions``). Unlike
    ``get_action`` there is no fallback to the next-best action: if the target
    cell is already claimed in ``dest`` and is not one of our city tiles, the
    unit stays put with ``move('c')``.

    Returns ``(action, pos)``.
    """
    if actions is None:
        # The previous code referenced an undefined ``unet_unit_actions``.
        # NOTE(review): assumes the U-Net's channels follow unit_actions'
        # order; agent() blends this map element-wise with the IL policies,
        # which relies on the same ordering.
        actions = unit_actions
    scores = policy[:, unit.pos.x + shift, unit.pos.y + shift]
    action = actions[int(np.argmax(scores))]
    pos = unit.pos.translate(action[-1], 1) or unit.pos
    if pos not in dest or in_city(pos):
        return call_func(unit, *action), pos

    return unit.move('c'), unit.pos

def get_shift(observation):
    """Return the offset that centres a ``width``-wide map on the 32x32 grid.

    This equals the x offset used by ``make_input``/``make_input_unet``. The
    agent applies it to both axes, which assumes square maps.
    """
    return (32 - observation['width']) // 2


def agent(observation, configuration):
    """Kaggle entry point: return this turn's list of action strings.

    City tiles follow the argmax of the U-Net city-tile model (``model_ct``):
    channel 0 -> research, channel 1 -> build a worker (while under the unit
    cap), anything else -> idle.

    Each unit that may act is scored by three models: the IL model
    (``model``), the board-wide U-Net (``model2``) and the "toad" model
    (``model4``). Each output goes through a softmax; the probabilities are
    averaged with weights 0.4 / 0.4 / 0.2 and the ensemble is passed to
    ``get_action``. That picks the most probable action whose destination is
    not already claimed by an earlier unit.

    Args:
        observation: Lux observation (``player`` via attribute access, plus
            ``step``, ``width``, ``height`` and ``updates``).
        configuration: unused.
    """
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    shift = get_shift(observation)
    # The U-Net models see the whole board, so encode and run them once per turn.
    state_1, state_2 = make_input_unet(observation)
    board = torch.from_numpy(state_1).unsqueeze(0).float()
    global_feats = torch.from_numpy(state_2).unsqueeze(0).float()
    with torch.no_grad():
        policy_ct = model_ct(board, global_feats).squeeze(0).numpy()
        policy_unet = model2(board, global_feats).squeeze(0).numpy()

    # City Actions
    unit_count = len(player.units)
    # NOTE(review): assumes the Lux rule that a player may own at most one unit
    # per city tile; builds beyond that would only be rejected by the engine.
    unit_cap = sum(len(city.citytiles) for city in player.cities.values())
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if not city_tile.can_act():
                continue
            ct_action = np.argmax(policy_ct[:, city_tile.pos.x + shift, city_tile.pos.y + shift])
            if ct_action == 0:
                actions.append(city_tile.research())
                player.research_points += 1
            elif ct_action == 1 and unit_count < unit_cap:
                actions.append(city_tile.build_worker())
                unit_count += 1

    # Worker Actions
    dest = []
    softmax = nn.Softmax(dim=0)
    for unit in player.units:
        if not unit.can_act():
            continue
        # Outside the first 30 steps of each 40-step cycle (presumably the
        # night phase), units already on one of our city tiles stay put.
        if game_state.turn % 40 >= 30 and in_city(unit.pos):
            continue

        # The IL and toad models take a unit-specific encoding, so they run per unit.
        state = torch.from_numpy(make_input(observation, unit.id)).unsqueeze(0)
        with torch.no_grad():
            p_il = model(state).squeeze(0)
            p_toad = model4(state).squeeze(0)
        p_unet = torch.from_numpy(policy_unet[:, unit.pos.x + shift, unit.pos.y + shift])

        ensemble = (
            softmax(p_il) * 0.4
            + softmax(p_unet) * 0.4
            + softmax(p_toad) * 0.2
        ).numpy()

        # Rank actions by the ensemble, not by any single model's output.
        action, pos = get_action(ensemble, unit, dest)

        actions.append(action)
        dest.append(pos)

    return actions

In [None]:
from kaggle_environments import make

# Smoke test: play agent.py against itself on a 24x24 map and render the replay.
env = make(
    "lux_ai_2021",
    configuration={"width": 24, "height": 24, "loglevel": 0, "annotations": True},
    debug=True,
)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

In [None]:
# Bundle everything in the working directory (agent.py, the exported model_*.pth
# files and the copied competition files) into the submission archive.
!tar -czf submission.tar.gz *