In [None]:
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* .

In [None]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
def gem(x, p=3, eps=1e-6):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Each (H, W) plane is reduced to (mean(clamp(x, eps) ** p)) ** (1 / p), so the
    result has shape (..., 1, 1). p=1 is average pooling; larger p moves towards max pooling.
    Values below ``eps`` are clamped so the fractional power stays real.
    """
    height, width = x.size(-2), x.size(-1)
    powered = x.clamp(min=eps).pow(p)
    pooled = F.avg_pool2d(powered, (height, width))
    return pooled.pow(1. / p)
class ResBlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN, add the (possibly projected) input, then ReLU.

    Args:
        input_dim: channels of the incoming tensor.
        output_dim: channels produced by both convolutions and by the block.
        kernel_size: (kh, kw) pair; padding (kh // 2, kw // 2) preserves H and W for odd kernels.
        bn: if true, a BatchNorm2d follows each convolution.
    """
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv1 = nn.Conv2d(
            input_dim, output_dim,
            kernel_size=kernel_size,
            padding=padding
        )
        # conv2 consumes conv1's output, so its input width is output_dim (not input_dim).
        self.conv2 = nn.Conv2d(
            output_dim, output_dim,
            kernel_size=kernel_size,
            padding=padding
        )
        self.bn1 = nn.BatchNorm2d(output_dim) if bn else None
        self.bn2 = nn.BatchNorm2d(output_dim) if bn else None
        # The skip connection needs matching channel counts. A 1x1 projection is added only
        # when the widths differ, so equal-width blocks keep exactly the same parameter
        # names (and therefore stay compatible with previously saved state_dicts).
        self.shortcut = (
            nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)
            if input_dim != output_dim else None
        )

    def forward(self, x):
        h = self.conv1(x)
        if self.bn1 is not None:
            h = self.bn1(h)
        h = F.relu_(h)
        h = self.conv2(h)
        if self.bn2 is not None:
            h = self.bn2(h)
        identity = x if self.shortcut is None else self.shortcut(x)
        return F.relu_(identity + h)
    
class AlphaNet(nn.Module):
    """Deep residual policy net (19 ResBlocks of width 64) with GeM pooling at the unit's cell.

    Input: 20 feature planes (conv0 expects 20 channels), presumably (batch, 20, 32, 32)
    as produced by ``make_input``. Output: (batch, 5) action logits.
    Attribute names (conv0, blocks, conv_neck, head_p) are state_dict keys — keep them.
    """
    def __init__(self):
        super().__init__()
        layers, filters, neck = 19, 64, 32
        self.conv0 = BasicConv2d(20, filters, (3, 3), True)
        self.blocks = nn.ModuleList(
            ResBlock(filters, filters, (3, 3), True) for _ in range(layers)
        )
        self.conv_neck = BasicConv2d(filters, neck, (1, 1), True)
        self.head_p = nn.Linear(neck, 5, bias=False)
        self.global_pool = gem

    def forward(self, x):
        batch = x.size(0)

        features = F.relu_(self.conv0(x))
        for block in self.blocks:
            features = block(features)
        # Zero everything except the cells selected by input channel 0 (the acting unit's cell).
        masked = self.conv_neck(features) * x[:, :1]
        # One representative value per channel.
        pooled = self.global_pool(masked).view(batch, -1)
        return self.head_p(pooled)

class BasicConv2d(nn.Module):
    """Conv2d with 'same' padding for odd kernels, optionally followed by BatchNorm2d.

    No activation is applied; callers add their own ReLU / residual connection.
    """
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size, padding=pad)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        out = self.conv(x)
        if self.bn is None:
            return out
        return self.bn(out)
    
class LuxNet(nn.Module):
    """Plain residual CNN policy: 12 conv-BN layers (width 32) with identity skips.

    The final feature map is summed over the cells selected by input channel 0
    (the acting unit's cell) and projected to 5 action logits.
    Attribute names (conv0, blocks, head_p) are state_dict keys — keep them.
    """
    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = BasicConv2d(20, filters, (3, 3), True)
        self.blocks = nn.ModuleList(
            BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)
        )
        self.head_p = nn.Linear(filters, 5, bias=False)

    def forward(self, x):
        features = F.relu_(self.conv0(x))
        for block in self.blocks:
            features = F.relu_(features + block(features))
        unit_mask = x[:, :1]
        batch, channels = features.size(0), features.size(1)
        pooled = (features * unit_mask).view(batch, channels, -1).sum(-1)
        return self.head_p(pooled)

In [None]:
# Export every candidate policy as TorchScript (model1.pth ... model6.pth) so that agent.py
# can load them with torch.jit.load without needing the Python class definitions.
# Each entry: (architecture class, or None when the source file is already TorchScript;
#              source file; output file).
EXPORT_SPECS = [
    (LuxNet, '../input/luxmodel/last_weight.pth', 'model1.pth'),
    (LuxNet, '../input/luxmodel/last_weight2.pth', 'model2.pth'),
    (LuxNet, '../input/luxmodel/exp_last_best_acc.pth', 'model3.pth'),
    (None, '../input/luxmodel/model1350.pth', 'model4.pth'),
    (None, '../input/models-lux/model_toad.pth', 'model5.pth'),
    (AlphaNet, '../input/luxmodel/exp012_bestacc_fold0.pth', 'model6.pth'),
]


def export_torchscript(model_cls, src_path, dst_path):
    """Save the model stored at ``src_path`` as a CPU TorchScript module at ``dst_path``.

    Args:
        model_cls: architecture class whose state_dict is stored in ``src_path``, or None
            when ``src_path`` already holds a TorchScript module.
        src_path: weights / TorchScript file to read.
        dst_path: output TorchScript file.
    """
    if model_cls is None:
        # torch.jit.trace on an existing ScriptModule is a no-op, so just load and re-save.
        model = torch.jit.load(src_path, map_location='cpu')
        model.eval()
        model.save(dst_path)
        return

    model = model_cls()
    # map_location keeps this working on CPU-only kernels even if weights were saved from a GPU.
    model.load_state_dict(torch.load(src_path, map_location='cpu'))
    model.eval()
    traced = torch.jit.trace(model.cpu(), torch.rand(1, 20, 32, 32))
    traced.save(dst_path)


for model_cls, src_path, dst_path in EXPORT_SPECS:
    export_torchscript(model_cls, src_path, dst_path)

In [None]:
def to_label(action):
    """Map a raw action string to ``(unit_id, label)``.

    'm <id> <dir>'  -> label 0/1/2/3 for n/s/w/e, None for 'c' (stay in place)
    'bcity <id>'    -> label 4
    anything else   -> label None
    The second space-separated token is always returned as ``unit_id``.
    """
    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]
    if command == 'bcity':
        return unit_id, 4
    if command == 'm':
        return unit_id, {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}[tokens[2]]
    return unit_id, None
def depleted_resources(obs):
    """Return True when ``obs['updates']`` contains no resource ('r') lines."""
    has_resource = any(update.split(' ')[0] == 'r' for update in obs['updates'])
    return not has_resource


def create_dataset_from_json(episode_dir,
                             step_from=0,
                             step_to=361,
                             team_name='Toad Brigade',
                             max_episodes=10,
                             seed=None):
    """Build imitation-learning samples from Kaggle Lux AI episode replay JSON files.

    Only episodes won by ``team_name`` are used, and only the winner's actions are kept.
    Reading an episode stops at the first step whose observation has no resource lines.

    Args:
        episode_dir: folder containing the episode ``*.json`` files; files whose name
            contains 'output' are skipped.
        step_from, step_to: only steps i with step_from <= i <= step_to are used.
        team_name: team whose winning episodes are imitated.
        max_episodes: number of randomly chosen episode files to read (None reads all).
        seed: seed for choosing the files; None uses the global ``random`` state.

    Returns:
        obses: dict '<episode_id>_<step>' -> observation trimmed to the keys
            step, updates, player, width, height (player = winner's index).
        samples: list of (obs_id, unit_id, label) tuples, label as returned by ``to_label``.
        episodes: ids of the episodes that were used.
    """
    obses = {}
    samples = []
    episodes = []

    episode_files = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(episode_files)
    episode_files = episode_files[:max_episodes]

    for filepath in tqdm(episode_files):
        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        # Winning agent's index (a None reward counts as 0).
        index = int(np.argmax([r or 0 for r in json_load['rewards']]))
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        episodes.append(ep_id)

        # The action taken at step i is stored in step i + 1, so the last step is never a sample.
        last_step = min(step_to, len(json_load['steps']) - 2)
        for i in range(max(step_from, 0), last_step + 1):
            if json_load['steps'][i][index]['status'] != 'ACTIVE':
                continue
            actions = json_load['steps'][i + 1][index]['action']
            # The observation is read from agent 0's entry (presumably the one carrying the
            # full 'updates' list); 'player' is then set to the winner's index.
            obs = json_load['steps'][i][0]['observation']

            if depleted_resources(obs):
                break

            obs['player'] = index
            obs = {
                k: v for k, v in obs.items()
                if k in ['step', 'updates', 'player', 'width', 'height']
            }
            obs_id = f'{ep_id}_{i}'
            obses[obs_id] = obs

            # Guard against a null action list in the replay.
            for action in actions or []:
                unit_id, label = to_label(action)
                if label is not None:
                    samples.append((obs_id, unit_id, label))

    return obses, samples, episodes

In [None]:
# Parse a random subset of the replays (create_dataset_from_json reads at most 10 files by default).
obses, samples, episodes = create_dataset_from_json("../input/lux-ai-episodes-score1800")
print('obses:', len(obses), 'samples:', len(samples), 'episodes:', len(episodes))
# Action label of every sample: 0-3 = move n/s/w/e, 4 = build city (see to_label).
labels = [sample[-1] for sample in samples]
def print_num_of_label(samples):
    """Print how many samples carry each action label, one line per label present."""
    action_names = ('north', 'south', 'west', 'east', 'bcity')
    values, counts = np.unique([sample[-1] for sample in samples], return_counts=True)
    for value, count in zip(values, counts):
        print(f'{action_names[value]:^5}: {count:>3}')
print_num_of_label(samples)  # class balance of the collected action labels

In [None]:
# Input for Neural Network
def make_input(obs, unit_id):
    """Encode an observation as a (20, 32, 32) float32 feature stack for one unit.

    The map is centred in a 32x32 canvas; axis 1 is x and axis 2 is y.
    Channels:
        0-1    the acting unit (``unit_id``): presence, cargo / 100
        2-4    other units of our team: presence, cooldown / 6, cargo / 100
        5-7    opponent units: same layout as 2-4
        8-9    our city tiles: presence, min(fuel / light upkeep, 10) / 10 of their city
        10-11  opponent city tiles: same layout as 8-9
        12-14  wood / coal / uranium amount / 800
        15-16  our / opponent research points, min(rp, 200) / 200
        17     step % 40 / 40 (day/night cycle)
        18     step / 360
        19     1 on cells inside the map
    A city's 'c' update must come before its 'ct' updates in ``obs['updates']``,
    otherwise the city lookup raises KeyError.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    city_power = {}

    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for update in obs['updates']:
        fields = update.split(' ')
        kind = fields[0]

        if kind == 'u':
            px, py = int(fields[4]) + x_shift, int(fields[5]) + y_shift
            cargo = (int(fields[7]) + int(fields[8]) + int(fields[9])) / 100
            if fields[3] == unit_id:
                planes[:2, px, py] = (1, cargo)
            else:
                base = 2 + (int(fields[2]) - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, float(fields[6]) / 6, cargo)
        elif kind == 'ct':
            base = 8 + (int(fields[1]) - obs['player']) % 2 * 2
            px, py = int(fields[3]) + x_shift, int(fields[4]) + y_shift
            planes[base:base + 2, px, py] = (1, city_power[fields[2]])
        elif kind == 'r':
            channel = {'wood': 12, 'coal': 13, 'uranium': 14}[fields[1]]
            px, py = int(fields[2]) + x_shift, int(fields[3]) + y_shift
            planes[channel, px, py] = int(float(fields[4])) / 800
        elif kind == 'rp':
            own_or_enemy = (int(fields[1]) - obs['player']) % 2
            planes[15 + own_or_enemy, :] = min(int(fields[2]), 200) / 200
        elif kind == 'c':
            fuel, upkeep = float(fields[3]), float(fields[4])
            city_power[fields[2]] = min(fuel / upkeep, 10) / 10

    step = obs['step']
    planes[17, :] = step % 40 / 40
    planes[18, :] = step / 360
    planes[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return planes


class LuxDataset(Dataset):
    """Map-style dataset yielding ``(make_input features, action label)`` per sample.

    Args:
        obses: dict obs_id -> trimmed observation (as built by create_dataset_from_json).
        samples: sequence of (obs_id, unit_id, label) tuples.
    """
    def __init__(self, obses, samples):
        self.obses = obses
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs_id, unit_id, action = self.samples[idx]
        features = make_input(self.obses[obs_id], unit_id)
        return features, action

In [None]:
def evaluator(model_path, n_data=500, dataset=None):
    """Print (and return) the top-1 accuracy of a TorchScript policy on imitation samples.

    Args:
        model_path: TorchScript file written by the export cell.
        n_data: number of leading samples to score; capped at the dataset size.
        dataset: indexable of (features, label) pairs; defaults to
            ``LuxDataset(obses, samples)`` built from the notebook globals.

    Returns:
        Accuracy as a float (0.0 when there is nothing to score).
    """
    # The file already holds a TorchScript module; re-tracing it would be a no-op.
    model = torch.jit.load(model_path, map_location='cpu')
    model.eval()

    if dataset is None:
        dataset = LuxDataset(obses, samples)
    n_eval = min(n_data, len(dataset))

    correct = 0
    for i in tqdm(range(n_eval)):
        data, label = dataset[i]  # build the features only once per sample
        with torch.no_grad():
            out = model(torch.from_numpy(data).unsqueeze(0))
        pred = out.argmax().item()
        if i < 5:
            print(f"pred:{pred}, gt:{label}, logit:{out}")
        if pred == label:
            correct += 1

    accuracy = correct / n_eval if n_eval else 0.0
    print(f"acc{accuracy}")
    return accuracy

In [None]:
# Top-1 accuracy of every exported model (model1.pth ... model6.pth) on the same samples.
for model_idx in range(1, 7):
    evaluator(f"./model{model_idx}.pth")

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game
from collections import Counter


# Model files live under /kaggle_simulations/agent when that directory exists, otherwise next to this script.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'


def _load_scripted(file_name):
    """Load a TorchScript policy from ``path`` and switch it to inference mode."""
    scripted = torch.jit.load(f"{path}/{file_name}")
    scripted.eval()
    return scripted


# Ensemble members used by the agent (model1/2/4 are exported but not loaded here).
model3 = _load_scripted("model3.pth")
model5 = _load_scripted("model5.pth")
model6 = _load_scripted("model6.pth")



def make_input(obs, unit_id):
    """Encode an observation as a (20, 32, 32) float32 feature stack for one unit.

    Must stay identical to the notebook's training-time make_input.
    The map is centred in a 32x32 canvas; axis 1 is x and axis 2 is y.
    Channels:
        0-1    the acting unit (``unit_id``): presence, cargo / 100
        2-4    other units of our team: presence, cooldown / 6, cargo / 100
        5-7    opponent units: same layout as 2-4
        8-9    our city tiles: presence, min(fuel / light upkeep, 10) / 10 of their city
        10-11  opponent city tiles: same layout as 8-9
        12-14  wood / coal / uranium amount / 800
        15-16  our / opponent research points, min(rp, 200) / 200
        17     step % 40 / 40 (day/night cycle)
        18     step / 360
        19     1 on cells inside the map
    A city's 'c' update must come before its 'ct' updates in ``obs['updates']``,
    otherwise the city lookup raises KeyError.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    city_power = {}

    planes = np.zeros((20, 32, 32), dtype=np.float32)

    for update in obs['updates']:
        fields = update.split(' ')
        kind = fields[0]

        if kind == 'u':
            px, py = int(fields[4]) + x_shift, int(fields[5]) + y_shift
            cargo = (int(fields[7]) + int(fields[8]) + int(fields[9])) / 100
            if fields[3] == unit_id:
                planes[:2, px, py] = (1, cargo)
            else:
                base = 2 + (int(fields[2]) - obs['player']) % 2 * 3
                planes[base:base + 3, px, py] = (1, float(fields[6]) / 6, cargo)
        elif kind == 'ct':
            base = 8 + (int(fields[1]) - obs['player']) % 2 * 2
            px, py = int(fields[3]) + x_shift, int(fields[4]) + y_shift
            planes[base:base + 2, px, py] = (1, city_power[fields[2]])
        elif kind == 'r':
            channel = {'wood': 12, 'coal': 13, 'uranium': 14}[fields[1]]
            px, py = int(fields[2]) + x_shift, int(fields[3]) + y_shift
            planes[channel, px, py] = int(float(fields[4])) / 800
        elif kind == 'rp':
            own_or_enemy = (int(fields[1]) - obs['player']) % 2
            planes[15 + own_or_enemy, :] = min(int(fields[2]), 200) / 200
        elif kind == 'c':
            fuel, upkeep = float(fields[3]), float(fields[4])
            city_power[fields[2]] = min(fuel / upkeep, 10) / 10

    step = obs['step']
    planes[17, :] = step % 40 / 40
    planes[18, :] = step / 360
    planes[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return planes


game_state = None
def get_game_state(observation):
    """Return the module-level Game, rebuilding it on step 0 and updating it otherwise.

    On step 0 a fresh Game is initialised from the full update list, updated with every
    line after the first two, and tagged with this agent's player id.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    updates = observation["updates"]
    game_state = Game()
    game_state._initialize(updates)
    game_state._update(updates[2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True if ``pos`` holds a city tile owned by this agent (``game_state.id``).

    Positions that cannot be looked up (e.g. off the map) count as "not in a city".
    """
    try:
        # Negative coordinates are always off the map. Reject them explicitly: a
        # list-backed grid lookup would not raise for them but silently wrap to the
        # opposite edge, possibly landing on one of our cities.
        if pos.x < 0 or pos.y < 0:
            return False
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Best-effort lookup (e.g. IndexError past the far edge): never crash the agent,
        # but do not swallow KeyboardInterrupt/SystemExit like a bare except would.
        return False


def call_func(obj, method, args=()):
    """Call ``obj.<method>(*args)`` and return its result."""
    bound = getattr(obj, method)
    return bound(*args)


# Policy label -> (unit method, *args); label order matches the networks' 5 outputs.
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_action(policy, unit, dest):
    """Pick the highest-scoring action whose target cell is still free this turn.

    Labels are tried from the largest policy score down. A label is accepted when its
    target cell is not already in ``dest`` or is one of our city tiles (presumably
    because several units may share a city tile). If every label is blocked the unit
    stays put with move('c').

    Returns:
        (action, target position); the caller records the position in ``dest``.
    """
    ranked_labels = np.argsort(policy)[::-1]
    for label in ranked_labels:
        act = unit_actions[label]
        target = unit.pos.translate(act[-1], 1)
        if not target:
            # translate() is expected to give a falsy result for 'build_city',
            # in which case the action targets the unit's own cell.
            target = unit.pos
        if target not in dest or in_city(target):
            return call_func(unit, *act), target

    return unit.move('c'), unit.pos


# While the player has fewer units than this, each unit's action comes from the weighted
# 3-model ensemble; at or above it only model6 is queried.
# NOTE(review): presumably to stay within the per-turn time budget — confirm.
ENSEMBLE_MAX_UNITS = 50


def _predict_policy(net, state):
    """Run one TorchScript policy on a single make_input feature stack and return its scores."""
    with torch.no_grad():
        return net(torch.from_numpy(state).unsqueeze(0)).squeeze(0).numpy()


def _ensemble_action(state, unit, dest):
    """Weighted majority vote: model3 (2 votes), model5 (2 votes), model6 (3 votes).

    model6's choice wins unless model3 and model5 agree with each other.
    Returns the winning action together with the destination *of that action*.
    """
    votes = Counter()
    destination = {}
    for net, weight in ((model3, 2), (model5, 2), (model6, 3)):
        action, pos = get_action(_predict_policy(net, state), unit, dest)
        votes[action] += weight
        destination.setdefault(action, pos)
    best_action = votes.most_common(1)[0][0]
    return best_action, destination[best_action]


def agent(observation, configuration):
    """Kaggle entry point: return this turn's list of action strings.

    City tiles build a worker while there are fewer units than city tiles, otherwise they
    research until uranium is researched. Every unit that can act picks a move/build action
    from the policy nets, except units standing in a friendly city when turn % 40 >= 30.
    Destinations chosen earlier in the turn are avoided unless they are friendly city tiles.
    """
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    # City actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if not city_tile.can_act():
                continue
            if unit_count < player.city_tile_count:
                actions.append(city_tile.build_worker())
                unit_count += 1
            elif not player.researched_uranium():
                actions.append(city_tile.research())
                # Local bookkeeping, presumably so researched_uranium() accounts for
                # research already queued this turn by earlier tiles.
                player.research_points += 1

    # Worker actions
    use_ensemble = len(player.units) < ENSEMBLE_MAX_UNITS
    dest = []  # cells already claimed this turn, for collision avoidance
    for unit in player.units:
        if not unit.can_act():
            continue
        # Units inside a friendly city stay put when turn % 40 >= 30.
        if game_state.turn % 40 >= 30 and in_city(unit.pos):
            continue
        state = make_input(observation, unit.id)
        if use_ensemble:
            action, pos = _ensemble_action(state, unit, dest)
        else:
            action, pos = get_action(_predict_policy(model6, state), unit, dest)
        actions.append(action)
        dest.append(pos)

    return actions

In [None]:
from kaggle_environments import make

# Local smoke test: play agent.py against itself on a 24x24 map and render the replay inline.
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=False)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)

In [None]:
!tar -czf submission.tar.gz *