In [None]:
# Install/upgrade kaggle-environments; all installer output is discarded.
!pip install kaggle-environments -U > /dev/null 2>&1
# Copy the contents of the lux-ai-2021 input directory into the working directory.
!cp -r ../input/lux-ai-2021/* .

In [None]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and, when available,
    every CUDA device), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic kernel selection.

    Args:
        seed_value: Integer seed shared by all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # The hash seed is read at interpreter start-up, so this only influences
    # Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # These are plain flags and are harmless without a GPU.
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels from run
    # to run, which would undo the determinism requested above.
    torch.backends.cudnn.benchmark = False

# Seed all random number generators once with a fixed value so runs are repeatable.
seed = 42
seed_everything(seed)

In [None]:
def to_label_ct(action, obs):
    """Translate one city-tile action string into a class label and a padded position.

    Args:
        action: Space-separated action string. The first token is the action
            code; for the two handled codes, tokens 1 and 2 are the x and y
            map coordinates.
        obs: Mapping providing the map 'width' and 'height'.

    Returns:
        ``(label, unit_pos)``. ``label`` is 0 for code 'r' and 1 for code
        'bw'; ``unit_pos`` is ``(x, y)`` shifted so that the map sits centred
        in a 32x32 grid. Any other code yields ``(None, None)``.
    """
    # Offsets that centre a width x height map inside the 32x32 canvas.
    pad_x = (32 - obs['width']) // 2
    pad_y = (32 - obs['height']) // 2

    tokens = action.split(' ')
    label = {'r': 0, 'bw': 1}.get(tokens[0])
    if label is None:
        return None, None

    return label, (int(tokens[1]) + pad_x, int(tokens[2]) + pad_y)

def depleted_resources(obs):
    """Return True when no line in ``obs['updates']`` has 'r' as its first token."""
    return not any(line.split(' ')[0] == 'r' for line in obs['updates'])

def create_dataset_from_json_ctile(episode_dir, team_name='Toad Brigade'):
    """Build city-tile imitation-learning samples from saved episode JSON files.

    Only episodes whose highest-reward player belongs to ``team_name`` are
    used, and only that player's actions are turned into labels.

    Args:
        episode_dir: Directory scanned (non-recursively) for ``*.json`` files;
            files whose name contains 'output' are ignored.
        team_name: Team whose winning episodes are kept.

    Returns:
        ``(obses, samples)``. ``obses`` maps ``'<EpisodeId>_<step>'`` to a
        trimmed observation dict. ``samples`` is a list of
        ``(obs_id, action_map, mask)`` where both arrays are boolean with
        shape (2, 32, 32).
    """
    kept_keys = ['step', 'updates', 'player', 'width', 'height']
    obses = {}
    samples = []

    episode_paths = [p for p in Path(episode_dir).glob('*.json') if 'output' not in p.name]
    for episode_path in tqdm(episode_paths):
        with open(episode_path) as fh:
            episode = json.load(fh)

        ep_id = episode['info']['EpisodeId']
        # Index of the player with the highest reward; a falsy reward counts as 0.
        index = np.argmax([reward or 0 for reward in episode['rewards']])
        if episode['info']['TeamNames'][index] != team_name:
            continue

        for i in range(len(episode['steps']) - 1):
            if episode['steps'][i][index]['status'] != 'ACTIVE':
                continue

            # The actions are stored on the following step, while the
            # observation is read from player slot 0 of the current step.
            actions = episode['steps'][i + 1][index]['action']
            raw_obs = episode['steps'][i][0]['observation']

            # Drop the rest of the episode once no resource update remains.
            if depleted_resources(raw_obs):
                break

            raw_obs['player'] = index
            obs = {key: value for key, value in raw_obs.items() if key in kept_keys}

            obs_id = f'{ep_id}_{i}'
            obses[obs_id] = obs

            action_map = np.zeros((2, 32, 32))
            mask = np.zeros((2, 32, 32))
            for action in actions:
                label, unit_pos = to_label_ct(action, obs)
                if label is None:
                    continue
                px, py = unit_pos
                action_map[label, px, py] = 1
                # Both output channels are supervised on a tile that issued a labelled action.
                mask[:, px, py] = 1

            samples.append((obs_id, action_map.astype('bool'), mask.astype('bool')))

    return obses, samples

In [None]:
# Directory that holds the downloaded episode JSON files.
episode_dir = '../input/simulations-episode-scraper-match-downloader/'
# Parse the episodes into stored observations and (obs_id, action_map, mask) samples.
obses, samples = create_dataset_from_json_ctile(episode_dir)
print('obses:', len(obses), 'samples:', len(samples))

In [None]:
# Input for Neural Network
# Feature map size [14,32,32] and global features size [14,4,4]
def make_input(obs):
    """Encode one observation as the two network inputs.

    The map is centred inside a fixed 32x32 canvas. Channels of the returned
    feature map ``b``:
        0-2    own unit: presence, cooldown / 6, carried resources / 100
        3-5    opponent unit: the same three values
        6-7    own city tile: presence, min(fuel / light upkeep, 10) / 10
        8-9    opponent city tile: the same two values
        10-12  wood, coal, uranium amount / 800
        13     1 inside the real map area, 0 in the padding

    Planes of ``global_features`` (each scalar fills a whole 4x4 plane):
        0-1    own / opponent research points, min(rp, 200) / 200
        2      step % 40 / 40
        3      step / 360
        4-5    own / opponent unit count / 50
        6-7    own / opponent city-tile count / 50
        8-10   total wood / 24000, coal / 24000, uranium / 12000
        11-12  1 if the own team has >= 50 / >= 200 research points
        13     map width (raw value)

    Args:
        obs: Mapping with 'width', 'height', 'player', 'step' and 'updates'
            (a list of space-separated update strings).

    Returns:
        Tuple ``(b, global_features)``: a float32 array of shape (14, 32, 32)
        and a float64 array of shape (14, 4, 4).

    Raises:
        KeyError: If a 'ct' update references a city id that no earlier 'c'
            update defined, or a resource type other than wood/coal/uranium
            appears.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    global_features = np.zeros((14,4,4))

    b = np.zeros((14, 32, 32), dtype=np.float32)

    friendly_unit_cnt = 0
    opponent_unit_cnt = 0
    friendly_ctile_cnt = 0
    opponent_ctile_cnt = 0
    total_wood = 0
    total_coal = 0
    total_uranium = 0

    can_mine_coal = 0
    can_mine_uranium = 0

    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            # Units
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])

            team = int(strs[2])
            side = (team - obs['player']) % 2  # 0 = own team, 1 = opponent

            if side == 0:
                friendly_unit_cnt += 1
            else:
                opponent_unit_cnt += 1

            cooldown = float(strs[6])
            idx = side * 3
            b[idx:idx + 3, x, y] = (
                1,
                cooldown / 6,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            side = (team - obs['player']) % 2

            if side == 0:
                friendly_ctile_cnt += 1
            else:
                opponent_ctile_cnt += 1

            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 6 + side * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 10, 'coal': 11, 'uranium': 12}[r_type], x, y] = amt / 800
            if r_type == 'wood': total_wood += amt
            elif r_type == 'coal': total_coal += amt
            elif r_type == 'uranium': total_uranium += amt

        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            side = (team - obs['player']) % 2
            # Parenthesised like every other team test in this function; the
            # previous `team - obs['player'] % 2` only worked by accident of
            # operator precedence for team/player values of 0 and 1.
            if side == 0:
                if rp >= 50:
                    can_mine_coal = 1
                if rp >= 200:
                    can_mine_uranium = 1

            global_features[side, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    global_features[2, :] = obs['step'] % 40 / 40
    # Turns
    global_features[3, :] = obs['step'] / 360
    # Number of friendly units
    global_features[4, :] = friendly_unit_cnt / 50
    # Number of opponent units
    global_features[5, :] = opponent_unit_cnt / 50
    # Number of friendly city tiles
    global_features[6, :] = friendly_ctile_cnt / 50
    # Number of opponent city tiles
    global_features[7, :] = opponent_ctile_cnt / 50
    # Total Wood
    global_features[8, :] = total_wood / 24000
    # Total Coal
    global_features[9, :] = total_coal / 24000
    # Total Uranium
    global_features[10, :] = total_uranium / 12000
    global_features[11, :] = can_mine_coal
    global_features[12, :] = can_mine_uranium
    # Map size. NOTE(review): stored as the raw width, unlike the other
    # normalised planes; left untouched because the value is part of the
    # feature layout the models consume.
    global_features[13, :] = width

    # Mark the cells of the canvas that belong to the real map.
    b[13, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b, global_features


class LuxDataset(Dataset):
    """Map-style dataset that encodes stored observations on access.

    Args:
        obses: Mapping from observation id to observation dict.
        samples: Sequence of ``(obs_id, action_map, mask)`` tuples.
    """

    def __init__(self, obses, samples):
        self.samples = samples
        self.obses = obses

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(feature_map, global_features, action_map, mask)`` for sample ``idx``."""
        obs_id, action_map, mask = self.samples[idx]
        # Features are built on access rather than stored up front.
        feature_map, global_features = make_input(self.obses[obs_id])
        return feature_map, global_features, action_map, mask

# UNet

In [None]:
# copied from https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> BatchNorm -> ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second stage.
        mid_channels: Channels between the two stages; falls back to
            ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial size with a 2x2 max-pool, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsample, merge with a skip connection, then apply a DoubleConv.

    Args:
        in_channels: Channels of the concatenated tensor fed to the DoubleConv.
        out_channels: Channels produced by the block.
        bilinear: Use parameter-free bilinear upsampling when True, otherwise
            a stride-2 transposed convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Upsampling keeps the channel count, so the DoubleConv reduces
            # channels through a narrower middle layer instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2`` and fuse both."""
        upsampled = self.up(x1)

        # Tensors are N x C x H x W: dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        top, left = gap_h // 2, gap_w // 2

        # F.pad takes (left, right, top, bottom) for the last two dims; any
        # odd remainder goes to the right / bottom edge.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip connection first, upsampled features second.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [None]:
class UNet(nn.Module):
    """Three-level U-Net that injects global features at the bottleneck.

    Args:
        n_channels: Channels of the spatial input.
        n_classes: Channels of the per-cell output logits.
        n_channels_b: Channels of the global-feature tensor concatenated to
            the bottleneck.
        bilinear: Forwarded to the ``Up`` blocks.

    NOTE(review): the channel arithmetic below balances for ``bilinear=True``;
    it does not appear to add up for ``bilinear=False`` — confirm before
    using that mode.
    """

    def __init__(self, n_channels, n_classes, n_channels_b, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder: each Down halves the spatial size.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        bottleneck_channels = 512 // (2 if bilinear else 1)
        self.down3 = Down(256, bottleneck_channels)

        # Decoder: the input width of each Up is the upsampled channels plus
        # the channels of the matching skip connection.
        self.up1 = Up(512 + n_channels_b, 256, bilinear)
        self.up2 = Up(256 + 128, 128, bilinear)
        self.up3 = Up(128 + 64, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x, x_features):
        """Return logits with the spatial size of ``x``.

        ``x_features`` must have the spatial size of the bottleneck, i.e. the
        size of ``x`` divided by 8, because it is concatenated channel-wise.
        """
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        bottleneck = self.down3(skip3)

        fused = torch.cat((bottleneck, x_features), 1)

        decoded = self.up1(fused, skip3)
        decoded = self.up2(decoded, skip2)
        decoded = self.up3(decoded, skip1)
        return self.outc(decoded)

In [None]:
def train_model(model, dataloaders_dict, optimizer, num_epochs):
    """Train and validate ``model`` for ``num_epochs`` epochs, printing the loss per phase.

    The model is moved to the GPU when one is available and otherwise stays
    on the CPU. Each epoch runs a 'train' phase (with weight updates) and a
    'val' phase (no gradients). The model is left in eval mode afterwards.

    Args:
        model: Module called as ``model(states_1, states_2)`` that returns
            logits with the same shape as the action maps.
        dataloaders_dict: Mapping with 'train' and 'val' iterables. Every
            batch is indexable as (spatial input, global features, action
            map, mask), and each iterable exposes a sized ``.dataset``.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of epochs to run.

    Returns:
        None. Progress is reported via ``print``.
    """
    # Pick the device once instead of forcing CUDA on every epoch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)

            epoch_loss = 0.0

            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                states_1 = item[0].to(device).float()
                states_2 = item[1].to(device).float()
                actions = item[2].to(device).float()
                mask = item[3].to(device).float()

                # The mask zeroes the loss of every cell without a labelled
                # action; the reduction is still a mean over all cells.
                criterion = nn.BCEWithLogitsLoss(weight=mask)

                with torch.set_grad_enabled(is_train):
                    policy = model(states_1, states_2)
                    loss = criterion(policy, actions)

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Report the mean over labelled cells only: undo the division
                # by all cells and divide by the labelled ones. A batch with
                # no labelled cell has zero loss and is skipped, which also
                # avoids dividing by zero.
                labelled = mask.sum().item()
                if labelled > 0:
                    epoch_loss += loss.item() * mask.numel() / labelled * len(policy)

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size

            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f}')
        

In [None]:
# UNet(n_channels, n_classes, n_channels_b): 14 spatial input channels,
# 2 output channels and 14 global-feature channels.
model = UNet(14, 2, 14)
# Hold out 10% of the samples for validation; the fixed random_state keeps the split repeatable.
train, val = train_test_split(samples, test_size = 0.1, random_state = 42)
batch_size = 256
# Only the training loader shuffles; both build their inputs through LuxDataset.
train_loader = DataLoader(
    LuxDataset(obses, train),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val),
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)
# train_model looks the loaders up by the phase names 'train' and 'val'.
dataloaders_dict = {"train": train_loader, "val": val_loader}
#criterion = nn.BCEWithLogitsLoss()
#optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
# model = torch.jit.load(f'../input/luxai-models/submission_unet_with_ctile/model_ct.pth')

In [None]:
# Three training stages, each with a 10x smaller learning rate. A new AdamW
# instance is created per stage, so optimizer state is not carried over.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_model(model, dataloaders_dict, optimizer, num_epochs=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_model(model, dataloaders_dict, optimizer, num_epochs=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
train_model(model, dataloaders_dict, optimizer, num_epochs=2)

In [None]:
# Move the trained model to the CPU, trace it with random example inputs of
# shapes (1, 14, 32, 32) and (1, 14, 4, 4), and save the TorchScript module.
traced = torch.jit.trace(model.cpu(), (torch.rand(1, 14, 32, 32), torch.rand(1,14,4,4)))
traced.save('model_ct.pth')

In [None]:
# Load a previously saved TorchScript model from the input dataset, trace it
# again on the CPU and save it as model.pth in the working directory.
# NOTE: this rebinds `model`, replacing the network trained above.
model = torch.jit.load(f'../input/luxai-models/submission_unet_v6_updated/model.pth')
traced = torch.jit.trace(model.cpu(), (torch.rand(1, 14, 32, 32), torch.rand(1,14,4,4)))
traced.save('model.pth')

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game

# Load from /kaggle_simulations/agent when that tree exists, otherwise from the current directory.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'
#path = '../input/luxai-unet-immitationlearning/'
# TorchScript policy used below for unit actions.
model = torch.jit.load(f'{path}/model.pth')
model.eval()

# TorchScript policy used below for city-tile actions.
model_ct = torch.jit.load(f'{path}/model_ct.pth')
model_ct.eval()

# Input for Neural Network
# Feature map size [14,32,32] and global features size [14,4,4]
def make_input(obs):
    """Encode one observation as the two network inputs.

    The map is centred inside a fixed 32x32 canvas. Channels of the returned
    feature map ``b``:
        0-2    own unit: presence, cooldown / 6, carried resources / 100
        3-5    opponent unit: the same three values
        6-7    own city tile: presence, min(fuel / light upkeep, 10) / 10
        8-9    opponent city tile: the same two values
        10-12  wood, coal, uranium amount / 800
        13     1 inside the real map area, 0 in the padding

    Planes of ``global_features`` (each scalar fills a whole 4x4 plane):
        0-1    own / opponent research points, min(rp, 200) / 200
        2      step % 40 / 40
        3      step / 360
        4-5    own / opponent unit count / 50
        6-7    own / opponent city-tile count / 50
        8-10   total wood / 24000, coal / 24000, uranium / 12000
        11-12  1 if the own team has >= 50 / >= 200 research points
        13     map width (raw value)

    Args:
        obs: Mapping with 'width', 'height', 'player', 'step' and 'updates'
            (a list of space-separated update strings).

    Returns:
        Tuple ``(b, global_features)``: a float32 array of shape (14, 32, 32)
        and a float64 array of shape (14, 4, 4).

    Raises:
        KeyError: If a 'ct' update references a city id that no earlier 'c'
            update defined, or a resource type other than wood/coal/uranium
            appears.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    global_features = np.zeros((14,4,4))

    b = np.zeros((14, 32, 32), dtype=np.float32)

    friendly_unit_cnt = 0
    opponent_unit_cnt = 0
    friendly_ctile_cnt = 0
    opponent_ctile_cnt = 0
    total_wood = 0
    total_coal = 0
    total_uranium = 0

    can_mine_coal = 0
    can_mine_uranium = 0

    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            # Units
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])

            team = int(strs[2])
            side = (team - obs['player']) % 2  # 0 = own team, 1 = opponent

            if side == 0:
                friendly_unit_cnt += 1
            else:
                opponent_unit_cnt += 1

            cooldown = float(strs[6])
            idx = side * 3
            b[idx:idx + 3, x, y] = (
                1,
                cooldown / 6,
                (wood + coal + uranium) / 100
            )
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            side = (team - obs['player']) % 2

            if side == 0:
                friendly_ctile_cnt += 1
            else:
                opponent_ctile_cnt += 1

            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 6 + side * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 10, 'coal': 11, 'uranium': 12}[r_type], x, y] = amt / 800
            if r_type == 'wood': total_wood += amt
            elif r_type == 'coal': total_coal += amt
            elif r_type == 'uranium': total_uranium += amt

        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            side = (team - obs['player']) % 2
            # Parenthesised like every other team test in this function; the
            # previous `team - obs['player'] % 2` only worked by accident of
            # operator precedence for team/player values of 0 and 1.
            if side == 0:
                if rp >= 50:
                    can_mine_coal = 1
                if rp >= 200:
                    can_mine_uranium = 1

            global_features[side, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    global_features[2, :] = obs['step'] % 40 / 40
    # Turns
    global_features[3, :] = obs['step'] / 360
    # Number of friendly units
    global_features[4, :] = friendly_unit_cnt / 50
    # Number of opponent units
    global_features[5, :] = opponent_unit_cnt / 50
    # Number of friendly city tiles
    global_features[6, :] = friendly_ctile_cnt / 50
    # Number of opponent city tiles
    global_features[7, :] = opponent_ctile_cnt / 50
    # Total Wood
    global_features[8, :] = total_wood / 24000
    # Total Coal
    global_features[9, :] = total_coal / 24000
    # Total Uranium
    global_features[10, :] = total_uranium / 12000
    global_features[11, :] = can_mine_coal
    global_features[12, :] = can_mine_uranium
    # Map size. NOTE(review): stored as the raw width, unlike the other
    # normalised planes; left untouched because the value is part of the
    # feature layout the models consume.
    global_features[13, :] = width

    # Mark the cells of the canvas that belong to the real map.
    b[13, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b, global_features


game_state = None
def get_game_state(observation):
    """Return the module-level game state, brought up to date with ``observation``.

    On step 0 a new ``Game`` is created, initialised from the full update
    list, updated with everything after the first two entries, and tagged
    with the player index. On any later step the existing state is updated
    in place.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    game_state = Game()
    first_updates = observation["updates"]
    game_state._initialize(first_updates)
    # The first two entries are skipped here after being passed to _initialize.
    game_state._update(observation["updates"][2:])
    game_state.id = observation["player"]
    return game_state

def get_shift(observation):
    """Return the offset that centres the map inside the 32x32 network canvas.

    Only the map width is consulted; callers in this file apply the returned
    offset to both the x and the y coordinate.

    Args:
        observation: Mapping that provides the map 'width'.

    Returns:
        ``(32 - width) // 2`` as an int.
    """
    return (32 - observation['width']) // 2

def in_city(pos):
    """Return True when ``pos`` holds a city tile owned by the current player.

    Lookup is best effort: any ordinary error while resolving the cell (for
    example a position outside the map, or no game state yet) yields False.
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Deliberately broad, but unlike a bare `except:` it lets
        # KeyboardInterrupt and SystemExit propagate.
        return False


def call_func(obj, method, args=[]):
    """Look up ``method`` by name on ``obj`` and call it with positional ``args``."""
    bound = getattr(obj, method)
    return bound(*args)


unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]
def get_action(policy, unit, dest, shift):
    """Pick the highest-scoring action for ``unit`` unless its target cell is already claimed.

    Args:
        policy: Array indexed as ``[action, x, y]`` on the padded grid.
        unit: Unit exposing ``pos`` plus the methods named in ``unit_actions``.
        dest: Positions already claimed by earlier units this turn.
        shift: Offset added to both map coordinates to index ``policy``.

    Returns:
        ``(command, position)``: the chosen command and the cell the unit is
        expected to occupy, or a move to 'c' and the current position when
        the target is claimed and is not one of the player's city tiles.
    """
    cell_scores = policy[:, unit.pos.x + shift, unit.pos.y + shift]
    chosen = unit_actions[np.argmax(cell_scores)]

    # A falsy result from translate (e.g. for the build action, whose last
    # element is not a direction) means the unit stays where it is.
    target = unit.pos.translate(chosen[-1], 1)
    if not target:
        target = unit.pos

    claimed = target in dest and not in_city(target)
    if claimed:
        return unit.move('c'), unit.pos

    return call_func(unit, *chosen), target


def agent(observation, configuration):
    """Agent entry point: return the list of commands for this turn.

    City tiles that can act follow ``model_ct`` (channel 0 = research,
    channel 1 = build worker). Units that can act follow ``model``, except
    that a unit standing on one of the player's city tiles is left alone
    while ``turn % 40 >= 30``.

    Args:
        observation: Current observation; read both by key and through its
            ``player`` attribute.
        configuration: Environment configuration (unused).

    Returns:
        List of command strings.
    """
    global game_state
    game_state = get_game_state(observation)
    shift = get_shift(observation)
    player = game_state.players[observation.player]
    actions = []

    # Both policies consume the same encoded observation, so encode it and
    # build the input tensors once instead of once per model.
    state_1, state_2 = make_input(observation)
    map_input = torch.from_numpy(state_1).unsqueeze(0).float()
    global_input = torch.from_numpy(state_2).unsqueeze(0).float()

    # City Actions
    with torch.no_grad():
        policy_ct = model_ct(map_input, global_input).squeeze(0).numpy()

    for city in player.cities.values():
        for city_tile in city.citytiles:
            if not city_tile.can_act():
                continue
            choice = np.argmax(policy_ct[:, city_tile.pos.x + shift, city_tile.pos.y + shift])
            if choice == 0:
                actions.append(city_tile.research())
                player.research_points += 1
            elif choice == 1:
                actions.append(city_tile.build_worker())

    # Worker Actions
    with torch.no_grad():
        policy = model(map_input, global_input).squeeze(0).numpy()

    # Cells already claimed this turn, so later units avoid moving into them.
    dest = []
    for unit in player.units:
        if unit.can_act() and (game_state.turn % 40 < 30 or not in_city(unit.pos)):
            action, pos = get_action(policy, unit, dest, shift)
            actions.append(action)
            dest.append(pos)

    return actions

In [None]:
from kaggle_environments import make

# Smoke test: run agent.py against itself on a 24x24 map with debug output enabled.
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "annotations": True}, debug=True)
steps = env.run(['./agent.py', './agent.py'])

In [None]:
# Show the replay of the match above inline in the notebook.
env.render(mode="ipython", width=1200, height=800)

In [None]:
# Bundle every non-hidden entry of the working directory into submission.tar.gz.
!tar -czf submission.tar.gz *