In [None]:
# Install/upgrade the kaggle-environments package; all installer output is discarded.
!pip install kaggle-environments -U > /dev/null 2>&1
# Copy everything from the lux-ai-2021 input directory into the working directory.
!cp -r ../input/lux-ai-2021/* .

In [None]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and, when CUDA is available, every GPU), and configures cuDNN for
    repeatable results.

    Args:
        seed_value: Integer seed applied to all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    # Seeds PyTorch's default generator.
    torch.manual_seed(seed_value)
    # Recorded for any child process started later; the hash seed of the
    # already-running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        # Seed the current GPU and then every visible GPU.
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        # Restrict cuDNN to deterministic convolution algorithms.
        torch.backends.cudnn.deterministic = True
        # Benchmark mode auto-tunes the convolution algorithm per input shape
        # and may pick a different one on each run, so it has to be disabled
        # for the deterministic flag above to give repeatable results.
        torch.backends.cudnn.benchmark = False

# Seed all generators once, before any data is loaded or any model is built.
seed = 42
seed_everything(seed)

# Preprocessing

In [None]:
def to_label(action):
    """Translate a raw action string into a ``(unit_id, label)`` pair.

    The action is split on single spaces. The second token is always returned
    as the unit id, whatever the command is.

    Labels:
        * ``'m <unit> <dir>'`` (move): 0=north, 1=south, 2=west, 3=east;
          the centre direction ``'c'`` maps to ``None``.
        * ``'bcity <unit>'`` (build city): 4.
        * any other command: ``None``.

    Raises:
        IndexError: if the action has fewer tokens than the command needs.
        KeyError: if a move uses a direction other than c/n/s/w/e.
    """
    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]

    if command == 'bcity':
        return unit_id, 4

    if command == 'm':
        direction_labels = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}
        return unit_id, direction_labels[tokens[2]]

    return unit_id, None

# Detects observations in which the map no longer reports any resource tile.
def depleted_resources(obs):
    """Return True when no entry of ``obs['updates']`` is a resource ('r') update."""
    return not any(update.split(' ')[0] == 'r' for update in obs['updates'])


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'):
    """Build imitation-learning samples from episode replay JSON files.

    Every ``*.json`` file in ``episode_dir`` whose name does not contain
    ``'output'`` is read. An episode is used only when the player with the
    highest reward belongs to ``team_name``.

    Args:
        episode_dir: Directory containing the episode JSON files.
        team_name: Team whose winning episodes are imitated.

    Returns:
        A tuple ``(obses, samples)``:
            * ``obses`` maps ``'<EpisodeId>_<step>'`` to a reduced observation
              dict with the keys step, updates, player, width and height.
            * ``samples`` is a list of ``(obs_id, unit_id, label)`` tuples, one
              per action that ``to_label`` maps to a non-None label.
    """
    obses = {}
    samples = []
    kept_keys = ['step', 'updates', 'player', 'width', 'height']

    episodes = [path for path in Path(episode_dir).glob('*.json') if 'output' not in path.name]
    # tqdm renders a progress bar while the episodes are parsed.
    for filepath in tqdm(episodes):
        with open(filepath) as f:
            episode = json.load(f)

        ep_id = episode['info']['EpisodeId']
        # Index of the player with the highest reward (a missing reward counts as 0).
        index = np.argmax([r or 0 for r in episode['rewards']])
        # Only learn from episodes won by the requested team.
        if episode['info']['TeamNames'][index] != team_name:
            continue

        steps = episode['steps']
        for i in range(len(steps) - 1):
            # Skip steps in which the imitated player is not ACTIVE.
            if steps[i][index]['status'] != 'ACTIVE':
                continue

            # The actions chosen for the state at step i are stored on step i + 1.
            actions = steps[i + 1][index]['action']
            # The observation is always read from the first player's entry.
            raw_obs = steps[i][0]['observation']

            # Stop using this episode once the observation has no resource updates.
            if depleted_resources(raw_obs):
                break

            raw_obs['player'] = index
            obs_id = f'{ep_id}_{i}'
            # Keep only the fields that make_input needs.
            obses[obs_id] = {k: v for k, v in raw_obs.items() if k in kept_keys}

            for action in actions:
                unit_id, label = to_label(action)
                # Only moves in a direction and city builds become training samples.
                if label is not None:
                    samples.append((obs_id, unit_id, label))

    return obses, samples

In [None]:
# Directory holding the episode replay JSON files used for imitation learning.
episode_dir = '../input/d/shoheiazuma/lux-ai-episodes'
obses, samples = create_dataset_from_json(episode_dir)
# Report how many distinct observations and per-unit samples were extracted.
print('obses:', len(obses), 'samples:', len(samples))

In [None]:
# Collect the label (last element) of every sample into a list.
labels = [sample[-1] for sample in samples]
# Human-readable action names, indexed by label.
actions = ['north', 'south', 'west', 'east', 'bcity']
# Print how many samples exist for each action label.
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]:^5}: {count:>3}')

# Input For Training

In [None]:
# Input for Neural Network
def make_input(obs, unit_id):
    """Encode an observation as a ``(20, 32, 32)`` float32 array for one unit.

    The map is centred inside a fixed 32x32 grid. "Own" means the team equal
    to ``obs['player']``; the other team is the opponent.

    Channels:
        0      position of the unit ``unit_id``
        1      cargo of that unit, (wood + coal + uranium) / 100
        2-4    other own units: presence, cooldown / 6, cargo / 100
        5-7    opponent units: presence, cooldown / 6, cargo / 100
        8-9    own city tiles: presence, min(fuel / upkeep, 10) / 10 of the city
        10-11  opponent city tiles: same two values
        12-14  wood, coal, uranium amount / 800
        15-16  own / opponent research points, min(rp, 200) / 200 (whole plane)
        17     position in the day/night cycle, step % 40 / 40 (whole plane)
        18     game progress, step / 360 (whole plane)
        19     1 inside the real map area, 0 in the padding

    Args:
        obs: Dict with the keys width, height, step, player and updates.
        unit_id: Id of the unit the input is built for.

    Returns:
        The encoded ``numpy.ndarray``.

    Raises:
        KeyError: if a 'ct' update refers to a city with no earlier 'c' update.
    """
    width, height = obs['width'], obs['height']
    # Offsets that centre the map inside the 32x32 grid.
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    # city id -> normalised number of turns the city's fuel covers its upkeep
    city_fuel = {}
    resource_channel = {'wood': 12, 'coal': 13, 'uranium': 14}

    b = np.zeros((20, 32, 32), dtype=np.float32)

    def side(team):
        # 0 for the observing player's team, 1 for the opponent.
        return (team - obs['player']) % 2

    for update in obs['updates']:
        tokens = update.split(' ')
        kind = tokens[0]

        if kind == 'u':
            # Unit update: tokens are type, team, id, x, y, cooldown, wood, coal, uranium.
            x = int(tokens[4]) + x_shift
            y = int(tokens[5]) + y_shift
            wood = int(tokens[7])
            coal = int(tokens[8])
            uranium = int(tokens[9])
            cargo = (wood + coal + uranium) / 100
            if tokens[3] == unit_id:
                # Position and cargo of the unit being decided for.
                b[:2, x, y] = (1, cargo)
            else:
                # Any other unit: channels 2-4 (own) or 5-7 (opponent).
                team = int(tokens[2])
                cooldown = float(tokens[6])
                first = 2 + side(team) * 3
                b[first:first + 3, x, y] = (1, cooldown / 6, cargo)
        elif kind == 'ct':
            # City tile: channels 8-9 (own) or 10-11 (opponent).
            team = int(tokens[1])
            city_id = tokens[2]
            x = int(tokens[3]) + x_shift
            y = int(tokens[4]) + y_shift
            first = 8 + side(team) * 2
            b[first:first + 2, x, y] = (1, city_fuel[city_id])
        elif kind == 'r':
            # Resource tile.
            r_type = tokens[1]
            x = int(tokens[2]) + x_shift
            y = int(tokens[3]) + y_shift
            amt = int(float(tokens[4]))
            b[resource_channel[r_type], x, y] = amt / 800
        elif kind == 'rp':
            # Research points: channel 15 (own) or 16 (opponent).
            team = int(tokens[1])
            rp = int(tokens[2])
            b[15 + side(team), :] = min(rp, 200) / 200
        elif kind == 'c':
            # City summary; looked up by the 'ct' updates of the same city.
            fuel = float(tokens[3])
            lightupkeep = float(tokens[4])
            city_fuel[tokens[2]] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


class LuxDataset(Dataset):
    """Dataset of ``(state, action)`` pairs whose states are encoded on demand.

    Args:
        obses: Mapping from observation id to observation dict.
        samples: Sequence of ``(obs_id, unit_id, action)`` tuples.
    """

    def __init__(self, obses, samples):
        self.obses, self.samples = obses, samples

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(make_input(obs, unit_id), action)`` for sample ``idx``."""
        obs_id, unit_id, action = self.samples[idx]
        return make_input(self.obses[obs_id], unit_id), action

# Model for Training

In [None]:
# Neural Network for Lux AI
# Basic building block: a size-preserving 2-D convolution with optional batch norm.
class BasicConv2d(nn.Module):
    """Convolution followed by optional batch normalisation (no activation).

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: ``(height, width)`` of the convolution kernel.
        bn: When truthy, a ``BatchNorm2d`` is applied after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        # Padding by half the kernel size in each direction keeps the spatial
        # size unchanged for odd kernel sizes.
        same_padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(
            input_dim,
            output_dim,
            kernel_size=kernel_size,
            padding=same_padding,
        )
        # Normalises the convolution output per channel; None disables it.
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        """Apply the convolution and, if configured, batch normalisation."""
        out = self.conv(x)
        if self.bn is None:
            return out
        return self.bn(out)


class LuxNet(nn.Module):
    """Residual convolutional policy network over the 20-channel map encoding.

    An input convolution is followed by 12 residual blocks of 32 filters. The
    feature map is then masked with input channel 0 and summed over all
    positions, and a bias-free linear layer maps the 32 features to 5 logits.
    """

    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        # Input layer: 20 feature planes -> `filters` channels.
        self.conv0 = BasicConv2d(20, filters, (3, 3), True)
        # ModuleList registers every block as a sub-module.
        residual_blocks = [BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)]
        self.blocks = nn.ModuleList(residual_blocks)
        # Policy head: `filters` features -> 5 action logits.
        self.head_p = nn.Linear(filters, 5, bias=False)

    def forward(self, x):
        """Return action logits of shape ``(batch, 5)`` for input ``x``."""
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            # Residual connection; the sum is a new tensor, so the in-place
            # ReLU does not overwrite the block's input.
            h = F.relu_(h + block(h))
        # Weight the features by input channel 0, flatten the spatial
        # dimensions and sum over them.
        unit_mask = x[:, :1]
        masked = h * unit_mask
        h_head = masked.view(h.size(0), h.size(1), -1).sum(-1)
        return self.head_p(h_head)

# Train

In [None]:
def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs):
    """Train ``model`` and save a traced copy whenever validation accuracy improves.

    Each epoch runs a 'train' phase followed by a 'val' phase. The learning
    rate is multiplied by 0.1 every 5 epochs. When the validation accuracy of
    an epoch exceeds the best seen so far, the model is traced on CPU and
    written to ``model.pth`` in the current directory.

    Training runs on CUDA when available and on CPU otherwise.

    Args:
        model: Network taking a ``(batch, 20, 32, 32)`` float tensor.
        dataloaders_dict: Dict with the keys 'train' and 'val'; each loader
            yields ``(states, actions)`` batches and exposes ``.dataset``.
        criterion: Loss function called as ``criterion(policy, actions)``.
        optimizer: Optimiser over ``model``'s parameters.
        num_epochs: Number of epochs to run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    best_acc = 0.0
    # Step-wise learning-rate decay.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    for epoch in range(num_epochs):
        # Saving a traced copy moves the model to CPU, so it is moved back
        # to the training device at the start of every epoch.
        model.to(device)
        for phase in ['train', 'val']:
            if phase == 'train':
                # Training mode: batch norm uses and updates batch statistics.
                model.train()
            else:
                # Evaluation mode: batch norm uses its running statistics.
                model.eval()

            epoch_loss = 0.0
            epoch_correct = 0

            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                states = item[0].to(device).float()
                actions = item[1].to(device).long()
                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    # Call the module itself (not .forward) so hooks run.
                    policy = model(states)
                    loss = criterion(policy, actions)
                    # Index of the largest logit is the predicted action.
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # Weight the batch loss by batch size for an exact epoch mean.
                    epoch_loss += loss.item() * len(policy)
                    epoch_correct += torch.sum(preds == actions).item()

            # Normalise by the number of samples in this phase.
            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_correct / data_size
            if phase == 'train':
                scheduler.step()
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

        # epoch_acc holds the validation accuracy here ('val' is the last phase).
        if epoch_acc > best_acc:
            traced = torch.jit.trace(model.cpu(), torch.rand(1, 20, 32, 32))
            traced.save('model.pth')
            best_acc = epoch_acc

In [None]:
model = LuxNet()
# train_test_split divides the samples into a training set and a validation set.
# test_size: fraction of samples held out (an integer would be an absolute count).

# random_state: seed of the split, so that repeated runs produce the same
# partition.
# stratify keeps the label distribution the same in both parts.
train, val = train_test_split(samples, test_size=0.15, random_state=42, stratify=labels)
batch_size = 64
# shuffle=True reshuffles the training samples every epoch; validation order is fixed.
# num_workers is the number of worker processes used to load batches.
train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}
# Cross-entropy loss over the 5 action logits.
criterion = nn.CrossEntropyLoss()
# AdamW optimiser over all model parameters.
# lr is the initial learning rate; train_model decays it with a StepLR scheduler.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=25)

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game
from lux.game_map import Cell, RESOURCE_TYPES, Position
from lux.constants import Constants
from lux.game_constants import GAME_CONSTANTS
from lux import annotate
import math
import sys


# Load the traced model from /kaggle_simulations/agent when /kaggle_simulations
# exists (presumably the Kaggle match runtime - TODO confirm), otherwise from
# the current directory.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'
model = torch.jit.load(f'{path}/model.pth')
# Inference only: switch the loaded module to evaluation mode.
model.eval()


def make_input(obs, unit_id):
    """Encode an observation as a ``(20, 32, 32)`` float32 array for one unit.

    This has to produce exactly the same encoding as the function used to
    build the training data. The map is centred inside a fixed 32x32 grid.
    "Own" means the team equal to ``obs['player']``.

    Channels:
        0      position of the unit ``unit_id``
        1      cargo of that unit, (wood + coal + uranium) / 100
        2-4    other own units: presence, cooldown / 6, cargo / 100
        5-7    opponent units: presence, cooldown / 6, cargo / 100
        8-9    own city tiles: presence, min(fuel / upkeep, 10) / 10 of the city
        10-11  opponent city tiles: same two values
        12-14  wood, coal, uranium amount / 800
        15-16  own / opponent research points, min(rp, 200) / 200 (whole plane)
        17     position in the day/night cycle, step % 40 / 40 (whole plane)
        18     game progress, step / 360 (whole plane)
        19     1 inside the real map area, 0 in the padding

    Args:
        obs: Mapping with the keys width, height, step, player and updates.
        unit_id: Id of the unit the input is built for.

    Returns:
        The encoded ``numpy.ndarray``.

    Raises:
        KeyError: if a 'ct' update refers to a city with no earlier 'c' update.
    """
    width, height = obs['width'], obs['height']
    # Offsets that centre the map inside the 32x32 grid.
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    # city id -> normalised number of turns the city's fuel covers its upkeep
    city_fuel = {}
    resource_channel = {'wood': 12, 'coal': 13, 'uranium': 14}

    b = np.zeros((20, 32, 32), dtype=np.float32)

    def side(team):
        # 0 for the observing player's team, 1 for the opponent.
        return (team - obs['player']) % 2

    for update in obs['updates']:
        tokens = update.split(' ')
        kind = tokens[0]

        if kind == 'u':
            # Unit update: tokens are type, team, id, x, y, cooldown, wood, coal, uranium.
            x = int(tokens[4]) + x_shift
            y = int(tokens[5]) + y_shift
            wood = int(tokens[7])
            coal = int(tokens[8])
            uranium = int(tokens[9])
            cargo = (wood + coal + uranium) / 100
            if tokens[3] == unit_id:
                # Position and Cargo
                b[:2, x, y] = (1, cargo)
            else:
                # Units: channels 2-4 (own) or 5-7 (opponent).
                team = int(tokens[2])
                cooldown = float(tokens[6])
                first = 2 + side(team) * 3
                b[first:first + 3, x, y] = (1, cooldown / 6, cargo)
        elif kind == 'ct':
            # CityTiles: channels 8-9 (own) or 10-11 (opponent).
            team = int(tokens[1])
            city_id = tokens[2]
            x = int(tokens[3]) + x_shift
            y = int(tokens[4]) + y_shift
            first = 8 + side(team) * 2
            b[first:first + 2, x, y] = (1, city_fuel[city_id])
        elif kind == 'r':
            # Resources
            r_type = tokens[1]
            x = int(tokens[2]) + x_shift
            y = int(tokens[3]) + y_shift
            amt = int(float(tokens[4]))
            b[resource_channel[r_type], x, y] = amt / 800
        elif kind == 'rp':
            # Research Points: channel 15 (own) or 16 (opponent).
            team = int(tokens[1])
            rp = int(tokens[2])
            b[15 + side(team), :] = min(rp, 200) / 200
        elif kind == 'c':
            # Cities; looked up by the 'ct' updates of the same city.
            fuel = float(tokens[3])
            lightupkeep = float(tokens[4])
            city_fuel[tokens[2]] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


game_state = None
def get_game_state(observation):
    """Create or refresh the module-level ``game_state`` from an observation.

    On step 0 a new ``Game`` is created: it is initialised from the full
    update list, then updated with everything after the first two entries,
    and its ``id`` is set to the observing player. On any later step the
    existing game state is updated with the new updates.

    Returns:
        The module-level ``game_state``.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    game_state = Game()
    game_state._initialize(observation["updates"])
    game_state._update(observation["updates"][2:])
    game_state.id = observation["player"]
    return game_state

def sum_fuel(unit):
    """Return the cargo of ``unit`` weighted as wood x1, coal x10, uranium x40."""
    cargo = unit.cargo
    return cargo.wood + cargo.coal * 10 + cargo.uranium * 40

# Checks whether a position lies on one of our own city tiles.
def in_city(pos):
    """Return True when ``pos`` holds a city tile owned by ``game_state.id``.

    Any lookup failure (for example a position outside the map) is treated
    as "not in a city" and yields False.
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Best-effort lookup. Only ordinary errors are swallowed, so
        # KeyboardInterrupt and SystemExit still propagate.
        return False

def call_func(obj, method, args=()):
    """Call the method named ``method`` on ``obj`` and return its result.

    Args:
        obj: Object that owns the method.
        method: Name of the method to look up with ``getattr``.
        args: Iterable of positional arguments; defaults to no arguments.
            An immutable default avoids sharing one list between calls.
    """
    return getattr(obj, method)(*args)


# Action table indexed by policy label: 0=north, 1=south, 2=west, 3=east, 4=build city.
# Each entry is (method name, *arguments) and is dispatched with call_func(unit, *entry).
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]

# Reports the resource on `pos` that the player is already able to collect.
# step_out_city uses it to decide where a unit waiting in a city should go.
def could_get_res(pos, player):
    """Return ``(amount, fuel_per_unit)`` of the collectable resource on ``pos``.

    Wood is always collectable; coal requires at least 50 research points and
    uranium at least 200. The fuel value per unit is 1 for wood, 10 for coal
    and 40 for uranium.

    Args:
        pos: Position to inspect on ``game_state.map``.
        player: Player whose ``research_points`` decide what is collectable.

    Returns:
        ``(amount, fuel_per_unit)``, or ``(0, 0)`` when the cell has no
        resource, the resource is not researched yet, the resource type is
        unknown, or the lookup fails (for example ``pos`` is off the map).
    """
    fuel_per_unit = {"coal": 10, "uranium": 40, "wood": 1}
    required_research = {"wood": 0, "coal": 50, "uranium": 200}
    research_points = player.research_points
    try:
        cell = game_state.map.get_cell_by_pos(pos)
        if not cell.has_resource():
            return 0, 0
        resource = cell.resource
        needed = required_research.get(resource.type)
        if needed is None or research_points < needed:
            return 0, 0
        return resource.amount, fuel_per_unit[resource.type]
    except Exception:
        # Best-effort lookup: an unreadable cell counts as "nothing to collect".
        return 0, 0

def get_units_in_citytile(pos, player):
    """Return the player's units that stand on ``pos`` and can act this turn.

    The units are returned in the order in which they appear in ``player.units``.
    """
    return [
        unit
        for unit in player.units
        if unit.pos.x == pos.x and unit.pos.y == pos.y and unit.can_act()
    ]

# For a unit standing on a city tile: decide whether it should stay there or
# step out towards nearby collectable resources (lets units that sheltered for
# the night go out again).
def step_out_city(unit, player, dest):
    """Choose between staying on the city tile and moving towards resources.

    Args:
        unit: The unit to decide for; expected to stand on a city tile.
        player: The owning player (research points, cities, units).
        dest: Positions already claimed by other units this turn.

    Returns:
        ``(action, position)``: the action string and the position the unit
        will occupy. Any exception raised while deciding - including the unit
        not standing on a city tile - results in "stay" (``unit.move('c')``).
    """
    # Own units on the same tile that can still act this turn.
    incitys = get_units_in_citytile(unit.pos, player)
    x = unit.pos.x
    y = unit.pos.y
    # The four adjacent cells: 1=north, 2=south, 3=west, 4=east.
    subpos1 = Position(x, y - 1)
    subpos2 = Position(x, y + 1)
    subpos3 = Position(x - 1, y)
    subpos4 = Position(x + 1, y)
    # Cells at distance 2; the trailing comment names the adjacent cell(s)
    # each of them is reached through.
    subpossub1 = Position(x, y - 2)  # for 1
    subpossub2 = Position(x + 1, y - 1)  # for 1 4
    subpossub3 = Position(x - 1, y - 1)  # for 1 3
    subpossub4 = Position(x, y + 2)  # for 2
    subpossub5 = Position(x - 1, y + 1)  # for 2 3
    subpossub6 = Position(x + 1, y + 1)  # for 2 4
    subpossub7 = Position(x - 2, y)  # for 3
    subpossub8 = Position(x + 2, y)  # for 4
    try:
        # Collectable amount and fuel value per unit on each adjacent cell.
        subresamount1, subresparam1 = could_get_res(subpos1, player)
        subresamount2, subresparam2 = could_get_res(subpos2, player)
        subresamount3, subresparam3 = could_get_res(subpos3, player)
        subresamount4, subresparam4 = could_get_res(subpos4, player)
        # Raises AttributeError when the unit is not on a city tile; the
        # except clause below then turns that into "stay".
        cityid = game_state.map.get_cell_by_pos(unit.pos).citytile.cityid
        light_up_keep = check_light_up_keep(player, cityid)
        fuel = check_fuel(player, cityid)
        # NOTE(review): evaluates to (-turn) % 40, which is 0 on turns that are
        # a multiple of 40, whereas rescue_city uses 40 - turn % 40 (40 on
        # those turns) - confirm which is intended.
        nightdays = (40 - game_state.turn) % 40
        fuelneed = nightdays * light_up_keep
        aroundamount = subresamount1 + subresamount2 + subresamount3 + subresamount4
        aroundfuel = subresamount1 * subresparam1 + subresamount2 * subresparam2 + \
                     subresamount3 * subresparam3 + subresamount4 * subresparam4
        # Alone on the tile, city short of fuel, and the adjacent resources are
        # enough to cover the shortfall: stay.
        if (len(incitys) == 1 and fuel < fuelneed and aroundamount > nightdays - (fuelneed - fuel)/light_up_keep
                and aroundfuel > fuelneed - fuel):
            return unit.move('c'), unit.pos
        # Several units share the tile: the first of them stays.
        if(len(incitys) > 1 and incitys[0].id == unit.id):
            return unit.move('c'), unit.pos
        # Candidate move north: allowed when the target is unclaimed and a
        # resource lies north (adjacent or at distance 2) within the map.
        subresamountsub1, subresparamsub1 = could_get_res(subpossub1, player)
        subresamountsub2, subresparamsub2 = could_get_res(subpossub2, player)
        subresamountsub3, subresparamsub3 = could_get_res(subpossub3, player)
        if (subpos1 not in dest and ((subresamount1 and y > 0) or (subresamountsub1 and y > 1) or (
                subresamountsub2 and y > 0 and x < game_state.map.width - 1) or (
                                             subresamountsub3 and y > 0 and x > 0))):
            act = unit_actions[0]
            action1, pos1 = call_func(unit, *act), subpos1
        else:
            action1, pos1 = unit.move('c'), unit.pos
        # Candidate move south.
        subresamountsub4, subresparamsub4 = could_get_res(subpossub4, player)
        subresamountsub5, subresparamsub5 = could_get_res(subpossub5, player)
        subresamountsub6, subresparamsub6 = could_get_res(subpossub6, player)
        if (subpos2 not in dest and ((subresamount2 and y < game_state.map.height - 1) or (
                subresamountsub4 and y < game_state.map.height - 2) or (
                                             subresamountsub5 and y < game_state.map.height - 1 and x > 0) or (
                                             subresamountsub6 and y < game_state.map.height - 1 and x < game_state.map.width - 1))):
            act = unit_actions[1]
            action2, pos2 = call_func(unit, *act), subpos2
        else:
            action2, pos2 = unit.move('c'), unit.pos
        # Candidate move west.
        subresamountsub7, subresparamsub7 = could_get_res(subpossub7, player)
        if (subpos3 not in dest and ((subresamount3 and x > 0) or (subresamountsub3 and y > 0 and x > 0) or (
                subresamountsub5 and y < game_state.map.height - 1 and x > 0) or (subresamountsub7 and x > 1))):
            act = unit_actions[2]
            action3, pos3 = call_func(unit, *act), subpos3
        else:
            action3, pos3 = unit.move('c'), unit.pos
        # Candidate move east.
        subresamountsub8, subresparamsub8 = could_get_res(subpossub8, player)
        if (subpos4 not in dest and ((subresamount4 and x < game_state.map.width - 1) or (
                subresamountsub2 and y > 0 and x < game_state.map.width - 1) or (
                                             subresamountsub6 and y < game_state.map.height - 1 and x < game_state.map.width - 1) or (
                                             subresamountsub8 and x < game_state.map.width - 2))):
            act = unit_actions[3]
            action4, pos4 = call_func(unit, *act), subpos4
        else:
            action4, pos4 = unit.move('c'), unit.pos
        # Score each direction by the fuel value of the adjacent cell plus its
        # three distance-2 cells; a direction whose total amount is at most 1
        # scores 0.
        amount1 = subresamount1 + subresamountsub1 + subresamountsub2 + subresamountsub3
        if(amount1 > 1):
            fuel1 = subresamount1 * subresparam1 + subresamountsub1 * subresparamsub1 \
                    + subresamountsub2 * subresparamsub2 + subresamountsub3 * subresparamsub3
        else:
            fuel1 = 0

        amount2 = subresamount2 + subresamountsub4 + subresamountsub5 + subresamountsub6
        if (amount2 > 1):
            fuel2 = subresamount2 * subresparam2 + subresamountsub4 * subresparamsub4 \
                    + subresamountsub5 * subresparamsub5 + subresamountsub6 * subresparamsub6
        else:
            fuel2 = 0

        amount3 = subresamount3 + subresamountsub3 + subresamountsub5 + subresamountsub7
        if(amount3 > 1):
            fuel3 = subresamount3 * subresparam3 + subresamountsub3 * subresparamsub3 \
                    + subresamountsub5 * subresparamsub5 + subresamountsub7 * subresparamsub7
        else:
            fuel3 = 0

        amount4 = subresamount4 + subresamountsub2 + subresamountsub6 + subresamountsub8
        if(amount4 > 1):
            fuel4 = subresamount4 * subresparam4 + subresamountsub2 * subresparamsub2 \
                    + subresamountsub6 * subresparamsub6 + subresamountsub8 * subresparamsub8
        else:
            fuel4 = 0

        # Take the best-scoring direction if it is worth at least 8 fuel; ties
        # are resolved in the order north, south, west, east. The chosen
        # candidate may itself be "stay" when its target cell was claimed.
        if(fuel1 >= 8 and fuel1 >= fuel2 and fuel1 >= fuel3 and fuel1 >= fuel4):
            return action1, pos1
        if(fuel2 >= 8 and fuel2 >= fuel1 and fuel2 >= fuel3 and fuel2 >= fuel4):
            return action2, pos2
        if(fuel3 >= 8 and fuel3 >= fuel1 and fuel3 >= fuel2 and fuel3 >= fuel4):
            return action3, pos3
        if(fuel4 >= 8 and fuel4 >= fuel1 and fuel4 >= fuel2 and fuel4 >= fuel3):
            return action4, pos4

        return unit.move('c'), unit.pos
    # NOTE(review): the bare except also swallows KeyboardInterrupt/SystemExit
    # and hides programming errors in the block above; consider narrowing it to
    # `except Exception`.
    except:
        return unit.move('c'), unit.pos

def check_light_up_keep(player, cityid):
    """Return the light upkeep of the player's city ``cityid``, or False if absent."""
    matching_upkeeps = (
        city.get_light_upkeep()
        for city in player.cities.values()
        if city.cityid == cityid
    )
    # The generator is lazy: the upkeep is computed for the first match only.
    return next(matching_upkeeps, False)


def check_fuel(player, cityid):
    """Return the stored fuel of the player's city ``cityid``, or False if absent."""
    matching_fuel = (
        city.fuel
        for city in player.cities.values()
        if city.cityid == cityid
    )
    return next(matching_fuel, False)


def rescue_city(unit, player, dest):
    """Find a move onto an adjacent own city tile whose city is short of fuel.

    The neighbours are tried in the order north, south, west, east. A
    neighbour qualifies when it is inside the map, is one of our city tiles,
    has not been claimed in ``dest``, and its city's fuel is lower than
    ``(40 - turn % 40) * light upkeep``.

    Returns:
        ``(action, position)`` for the first qualifying neighbour, or False
        when none qualifies.
    """
    x = unit.pos.x
    y = unit.pos.y
    # (inside the map?, neighbouring position, index into unit_actions)
    neighbours = [
        (y > 0, Position(x, y - 1), 0),
        (y < game_state.map.height - 1, Position(x, y + 1), 1),
        (x > 0, Position(x - 1, y), 2),
        (x < game_state.map.width - 1, Position(x + 1, y), 3),
    ]

    for on_map, target, action_index in neighbours:
        if not (on_map and in_city(target)):
            continue
        cityid = game_state.map.get_cell_by_pos(target).citytile.cityid
        light_up_keep = check_light_up_keep(player, cityid)
        fuel = check_fuel(player, cityid)
        # Fuel the city would burn over the rest of the current 40-turn cycle.
        fuel_required = (40 - game_state.turn % 40) * light_up_keep
        if target not in dest and fuel_required > fuel:
            return call_func(unit, *unit_actions[action_index]), target

    return False


# Builds two parallel lists: fill_worker_list[i] should carry fuel to
# fill_citytile_list[i].
def fill_city_fuel_list(player, max_dis):
    """Pair each fuel-starved city with the closest unit able to refuel it.

    A city needs fuel when its stored fuel does not cover its light upkeep
    for ``min(10, 40 - turn % 40)`` turns. Eligible units carry cargo worth
    at least 40 fuel (see ``sum_fuel``) and can act this turn. For every
    needy city the (city tile, unit) pair with the smallest Manhattan
    distance is chosen, provided that distance is at most ``max_dis``; on
    equal distance the pair examined last wins.

    Args:
        player: Player whose cities and units are examined.
        max_dis: Largest Manhattan distance a unit may be away from the tile.

    Returns:
        ``(fill_worker_list, fill_citytile_list)``, two lists of equal
        length. The same unit can appear more than once when it is the
        closest carrier for several cities.
    """
    fill_worker_list = []
    fill_citytile_list = []
    turns_to_cover = min(10, 40 - (game_state.turn % 40))
    # Eligibility does not depend on the city, so it is evaluated once.
    carriers = [unit for unit in player.units if sum_fuel(unit) >= 40 and unit.can_act()]

    for city in player.cities.values():
        # `city` is already the object the id-based helper lookups would
        # return, so its upkeep and fuel are read directly.
        extra_need = turns_to_cover * city.get_light_upkeep() - city.fuel
        if extra_need <= 0:
            continue

        min_distance = 100
        best_unit = None
        best_tile = None
        for city_tile in city.citytiles:
            for unit in carriers:
                distance = abs(city_tile.pos.x - unit.pos.x) + abs(city_tile.pos.y - unit.pos.y)
                if distance <= min_distance and distance <= max_dis:
                    min_distance = distance
                    best_tile = city_tile
                    best_unit = unit

        if best_unit is not None and best_tile is not None:
            fill_worker_list.append(best_unit)
            fill_citytile_list.append(best_tile)

    return fill_worker_list, fill_citytile_list


# Path-finding for a unit that carries fuel to a city tile.
def get_fill_way(unit, citytile, dest):
    """Pick the next one-cell step that brings ``unit`` closer to ``citytile``.

    * Target on the same column or row: step straight towards it if that cell
      has not been claimed in ``dest``; otherwise give up.
    * Target on a diagonal: if both useful neighbouring cells are free, one of
      the two is picked at random; if only one is free it is taken (vertical
      before horizontal); if neither is free, give up.

    Returns:
        ``(action, position)`` for the chosen step, or False when no step is
        possible (this includes the unit already standing on the target).
    """
    target_x, target_y = citytile.pos.x, citytile.pos.y
    pos_x, pos_y = unit.pos.x, unit.pos.y

    north = Position(pos_x, pos_y - 1)
    south = Position(pos_x, pos_y + 1)
    west = Position(pos_x - 1, pos_y)
    east = Position(pos_x + 1, pos_y)

    def step(action_index, target):
        # Build the (action, position) result for a fixed direction.
        return call_func(unit, *unit_actions[action_index]), target

    same_column = target_x == pos_x
    same_row = target_y == pos_y

    # Straight-line cases.
    if same_column and target_y > pos_y and south not in dest:
        return step(1, south)
    if same_column and target_y < pos_y and north not in dest:
        return step(0, north)
    if same_row and target_x > pos_x and east not in dest:
        return step(3, east)
    if same_row and target_x < pos_x and west not in dest:
        return step(2, west)
    if same_column or same_row:
        # Aligned with the target but blocked, or already on it.
        return False

    # Diagonal cases: one vertical and one horizontal candidate.
    heading_south = target_y > pos_y
    heading_east = target_x > pos_x
    vertical_index, vertical_pos = (1, south) if heading_south else (0, north)
    horizontal_index, horizontal_pos = (3, east) if heading_east else (2, west)
    vertical_free = vertical_pos not in dest
    horizontal_free = horizontal_pos not in dest

    if vertical_free and horizontal_free:
        # Choose one of the two candidate action indices at random.
        if heading_south and heading_east:
            label = np.random.randint(0, 2) * 2 + 1  # 1 (south) or 3 (east)
        elif heading_south:
            label = np.random.randint(1, 3)  # 1 (south) or 2 (west)
        elif heading_east:
            label = np.random.randint(0, 2) * 3  # 0 (north) or 3 (east)
        else:
            label = np.random.randint(0, 2) * 2  # 0 (north) or 2 (west)
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1)
        return call_func(unit, *act), pos
    if vertical_free:
        return step(vertical_index, vertical_pos)
    if horizontal_free:
        return step(horizontal_index, horizontal_pos)
    return False


def get_action(policy, unit, player, dest):
    """Turn the network's action scores into a legal action for ``unit``.

    Labels are tried from the highest score to the lowest:

    * If label 4 (build city) is reached and the unit cannot build, the
      decision is delegated to ``step_out_city``.
    * Otherwise the action is taken when the resulting position has not been
      claimed in ``dest`` or lies on one of our city tiles.

    Returns:
        ``(action, position)``; "stay" if every label was rejected.
    """
    # argsort is ascending, so reverse it to visit the best label first.
    ranked_labels = np.argsort(policy)[::-1]

    for label in ranked_labels:
        if label == 4 and not unit.can_build(game_state.map):
            # The model asked for an action the rules do not allow.
            return step_out_city(unit, player, dest)

        act = unit_actions[label]
        # translate() yields nothing usable for a non-direction such as
        # 'build_city'; the unit then stays on its own cell.
        pos = unit.pos.translate(act[-1], 1) or unit.pos

        blocked = pos in dest and not in_city(pos)
        if blocked:
            continue
        return call_func(unit, *act), pos

    return unit.move('c'), unit.pos


def agent(observation, configuration):
    """Kaggle entry point: return the list of actions for the current turn.

    City tiles build a worker while there are fewer units than city tiles and
    otherwise research until uranium is researched. Each unit is then handled
    by the first rule that applies:

    1. Turns 24-34 of the 40-turn cycle: a unit selected by
       ``fill_city_fuel_list`` walks towards its assigned city tile.
    2. From turn 30 of the cycle, outside a city: a unit next to an own city
       that is short of fuel moves onto that city tile (``rescue_city``).
    3. Before turn 30 of the cycle, or outside a city: the neural network
       policy decides (``get_action``).
    4. From turn 30 of the cycle, inside a city: ``step_out_city`` decides.

    Args:
        observation: Kaggle observation (step, updates, player).
        configuration: Kaggle environment configuration (unused).

    Returns:
        List of action strings.
    """
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    # City Actions
    unit_count = len(player.units)
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if not city_tile.can_act():
                continue
            if unit_count < player.city_tile_count:
                # Room for another unit: build a worker.
                actions.append(city_tile.build_worker())
                unit_count += 1
            elif not player.researched_uranium():
                # Otherwise research, but stop once uranium is available so
                # the tile stays free to build units.
                actions.append(city_tile.research())
                player.research_points += 1

    # Worker Actions
    dest = []  # cells outside our cities that are already claimed this turn

    def commit(action, pos):
        # Record the action and reserve the target cell. Own city tiles are
        # never reserved, so several units may enter the same one.
        actions.append(action)
        if not in_city(pos):
            dest.append(pos)

    cycle_turn = game_state.turn % 40
    late_cycle = cycle_turn >= 30
    fill_worker_list = []
    fill_citytile_list = []
    if 24 <= cycle_turn <= 34:
        fill_worker_list, fill_citytile_list = fill_city_fuel_list(player, (34 - cycle_turn) // 2)

    for unit in player.units:
        # Rule 1: deliver fuel to the assigned city tile. A unit may be listed
        # for several cities; the first assignment with a free path wins.
        routed = False
        for worker, citytile in zip(fill_worker_list, fill_citytile_list):
            if not (unit == worker):
                continue
            # Call get_fill_way once only: it picks randomly between two
            # equally good cells, so a second call could return a different
            # step than the one that was just checked.
            step = get_fill_way(unit, citytile, dest)
            if step is not False:
                commit(*step)
                routed = True
                break
        if routed:
            continue

        # Every remaining rule requires a unit that can act.
        if not unit.can_act():
            continue

        unit_in_city = in_city(unit.pos)

        # Rule 2: late in the cycle, step onto an adjacent own city that is
        # short of fuel. The move returned by rescue_city is the one that is
        # executed; sending this unit through step_out_city instead always
        # produced "stay", because the unit is not on a city tile.
        if late_cycle and not unit_in_city:
            rescue = rescue_city(unit, player, dest)
            if rescue is not False:
                commit(*rescue)
                continue

        if not late_cycle or not unit_in_city:
            # Rule 3: let the network decide.
            state = make_input(observation, unit.id)
            # Inference only, so no gradients are tracked.
            with torch.no_grad():
                # unsqueeze(0) adds the batch dimension.
                p = model(torch.from_numpy(state).unsqueeze(0))

            policy = p.squeeze(0).numpy()
            commit(*get_action(policy, unit, player, dest))
        else:
            # Rule 4: late in the cycle and inside a city.
            commit(*step_out_city(unit, player, dest))

    return actions

In [None]:
from kaggle_environments import make

# Create a Lux AI 2021 environment on a 24x24 map with annotations enabled.
env = make("lux_ai_2021", configuration={"width": 24, "height": 24, "loglevel": 2, "annotations": True}, debug=True)
# Play the freshly written agent against itself.
steps = env.run(['agent.py', 'agent.py'])
# Show the replay inline in the notebook.
env.render(mode="ipython", width=1200, height=800)

In [None]:
# Pack every file in the working directory into the submission archive.
!tar -czf submission.tar.gz *