In [None]:
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* .

动态调整学习率；修改网络，支持citytile；
添加了救援（多worker救一city）和worker避难；
添加了夜晚出城采集。
特定条件下根据上述规则进行决策，其他时候根据模型的输出。二者之间是存在一定矛盾的，需要权衡。

In [None]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and, when CUDA is available, the current and all GPUs), stores the
    seed in ``PYTHONHASHSEED`` and configures cuDNN for repeatable results.

    Args:
        seed_value: Integer seed shared by all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # Restrict cuDNN to deterministic convolution algorithms ...
    torch.backends.cudnn.deterministic = True
    # ... and disable the auto-tuner: benchmarking may select a different
    # algorithm on each run, which defeats the purpose of seeding.
    torch.backends.cudnn.benchmark = False


# A fixed seed keeps sampling, data splitting and weight initialisation repeatable.
seed = 42
seed_everything(seed)

# Preprocessing

In [None]:
def to_label(action):
    """Convert one raw action string into a training-label tuple.

    Recognised actions (tokens separated by single spaces):
        ``r x y``         city tile research       -> label slot 0, buildworker 0
        ``bw x y``        city tile builds worker  -> label slot 0, buildworker 1
        ``m unit_id d``   unit move; d in n/s/w/e  -> label 0..3
        ``bcity unit_id`` unit builds a city tile  -> label 4

    Any other action, a move with an unknown direction (including ``c``) and
    malformed strings yield ``label = None`` so the caller can skip them
    instead of crashing while parsing an episode.

    Args:
        action: Raw action string taken from an episode replay.

    Returns:
        Tuple ``(unit_id, label, is_citytile, buildworker, x, y)``.
    """
    strs = action.split(' ')
    kind = strs[0]

    if kind in ('r', 'bw'):
        try:
            x, y = int(strs[1]), int(strs[2])
        except (IndexError, ValueError):
            # Missing or non-numeric coordinates: not usable as a sample.
            return None, None, 0, 0, 0, 0
        return 0, 0, 1, int(kind == 'bw'), x, y

    if len(strs) < 2:
        # No unit id present, so the action cannot be attributed to a unit.
        return None, None, 0, 0, 0, 0
    unit_id = strs[1]

    if kind == 'm':
        direction = strs[2] if len(strs) > 2 else None
        # north, south, west, east; anything else maps to None.
        label = {'n': 0, 's': 1, 'w': 2, 'e': 3}.get(direction)
    elif kind == 'bcity':
        label = 4
    else:
        label = None
    return unit_id, label, 0, 0, 0, 0

# True when the observation's update list contains no resource ('r') entry.
def depleted_resources(obs):
    """Report whether ``obs['updates']`` lists no resource tiles at all."""
    return not any(update.split(' ')[0] == 'r' for update in obs['updates'])


def create_dataset_from_json(episode_dir, team_name='Toad Brigade'):
    """Build imitation-learning samples from recorded episode JSON files.

    Only episodes whose highest-reward player belongs to ``team_name`` are
    used.  For every step in which that player is ``ACTIVE``, the observation
    of the step is stored and each action issued in the following step is
    turned into a sample via ``to_label``.  Processing of an episode stops at
    the first observation without any resource entry.

    Args:
        episode_dir: Directory scanned for ``*.json`` files (files whose name
            contains ``output`` are ignored).
        team_name: Team whose winning episodes are imitated.

    Returns:
        ``(obses, samples)`` where ``obses`` maps ``"<episode id>_<step>"`` to
        a trimmed observation dict and ``samples`` is a list of tuples
        ``(is_citytile, x, y, obs_id, unit_id, label)``.
    """
    kept_keys = ['step', 'updates', 'player', 'width', 'height']
    obses = {}
    samples = []

    episode_paths = [p for p in Path(episode_dir).glob('*.json') if 'output' not in p.name]
    # tqdm only adds a progress bar around the iteration.
    for episode_path in tqdm(episode_paths):
        with open(episode_path) as fh:
            episode = json.load(fh)

        ep_id = episode['info']['EpisodeId']
        # Index of the player with the highest final reward (missing rewards count as 0).
        index = np.argmax([r or 0 for r in episode['rewards']])
        if episode['info']['TeamNames'][index] != team_name:
            continue

        steps = episode['steps']
        for i in range(len(steps) - 1):
            if steps[i][index]['status'] != 'ACTIVE':
                continue

            # Actions recorded at step i+1 are the response to the observation of step i.
            actions = steps[i + 1][index]['action']
            # The observation is read from player slot 0 of the step.
            full_obs = steps[i][0]['observation']

            if depleted_resources(full_obs):
                break

            full_obs['player'] = index
            obs_id = f'{ep_id}_{i}'
            obses[obs_id] = {k: v for k, v in full_obs.items() if k in kept_keys}

            for action in actions:
                unit_id, label, is_citytile, buildworker, x, y = to_label(action)
                if label is None:
                    continue
                if is_citytile:
                    # City-tile actions use labels 5 (research) and 6 (build worker).
                    samples.append((1, x, y, obs_id, 0, 5 + buildworker))
                else:
                    samples.append((0, 0, 0, obs_id, unit_id, label))

    return obses, samples

In [None]:
# Directory holding the recorded episode JSON files used as training data.
episode_dir = '../input/lux-ai-top-episodes'
# Parse every episode into stored observations and labelled action samples.
obses, samples = create_dataset_from_json(episode_dir)
print('obses:', len(obses), 'samples:', len(samples))

In [None]:
# Overview of the training set: number of samples per action label.
# The label is the last element of each sample tuple; it is reused later to stratify the split.
labels = [sample[-1] for sample in samples]
# Display names indexed by label value (0-4 unit actions, 5-6 city-tile actions).
actions = ['north', 'south', 'west', 'east', 'bcity','research','buildworker']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions[value]:^7}: {count:>3}')

# Input For Training

In [None]:
# Input for Neural Network
def make_input(is_citytile, obs, unit_id, x, y):
    """Encode one observation as a (22, 32, 32) float32 feature tensor.

    The map is centred inside the 32x32 grid by shifting every coordinate by
    ``(32 - size) // 2``.  "Own" below means ``team == obs['player']``.

    Channels:
        0-1    acting unit: presence, cargo / 100 (unit samples only)
        2-4    own units other than the acting one: presence, cooldown / 6, cargo / 100
        5-7    opposing units: same three features
        8-9    own city tiles: presence, min(fuel / upkeep, 10) / 10
        10-11  opposing city tiles: same two features
        12-14  wood, coal, uranium amount / 800
        15-16  own / opposing research points, min(rp, 200) / 200
        17     step % 40 / 40
        18     step / 360
        19     1 inside the real map area
        20-21  acting city tile: presence and fuel ratio (city-tile samples only)

    Args:
        is_citytile: Truthy when the acting entity is a city tile.
        obs: Dict with 'width', 'height', 'step', 'player' and 'updates'.
        unit_id: Id of the acting unit; not matched for city-tile samples.
        x, y: Unshifted map coordinates of the acting city tile; only read
            when ``is_citytile`` is truthy.

    Returns:
        numpy.ndarray of shape (22, 32, 32) and dtype float32.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}

    b = np.zeros((22, 32, 32), dtype=np.float32)

    # Loop-local coordinates use their own names so that the ``x``/``y``
    # arguments (the acting city tile) are never overwritten while parsing.
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            ux = int(strs[4]) + x_shift
            uy = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3] and not is_citytile:
                # Position and cargo of the acting unit.
                b[:2, ux, uy] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Any other unit: channel group 2-4 (own) or 5-7 (opponent).
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, ux, uy] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # City tiles.
            team = int(strs[1])
            city_id = strs[2]
            tile_x = int(strs[3])
            tile_y = int(strs[4])
            if is_citytile and x == tile_x and y == tile_y:
                # The acting city tile gets its own channel pair.
                idx = 20
            else:
                idx = 8 + (team - obs['player']) % 2 * 2
            # Relies on the city's 'c' update appearing earlier in the list.
            b[idx:idx + 2, tile_x + x_shift, tile_y + y_shift] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources.
            r_type = strs[1]
            rx = int(strs[2]) + x_shift
            ry = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], rx, ry] = amt / 800
        elif input_identifier == 'rp':
            # Research points, broadcast over the whole plane.
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Per-city fuel ratio, looked up later by the 'ct' entries.
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b

class LuxDataset(Dataset):
    """Torch dataset that encodes each labelled sample lazily on access."""

    def __init__(self, obses, samples):
        # samples: sequence of (is_citytile, x, y, obs_id, unit_id, action) tuples.
        self.samples = samples
        # obses: mapping from obs_id to the observation dict it refers to.
        self.obses = obses

    def __len__(self):
        """Number of labelled samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(state, action)`` for sample ``idx``; ``state`` comes from ``make_input``."""
        sample = self.samples[idx]
        is_citytile, x, y, obs_id, unit_id, action = sample
        return make_input(is_citytile, self.obses[obs_id], unit_id, x, y), action

# Model for Training

In [None]:
# Neural Network for Lux AI
class BasicConv2d(nn.Module):
    """2-D convolution padded by half the kernel size, with optional batch norm."""

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        """
        Args:
            input_dim: Number of input channels.
            output_dim: Number of output channels.
            kernel_size: ``(height, width)`` tuple of the kernel.
            bn: Truthy to append a ``BatchNorm2d`` layer.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim, output_dim,
            kernel_size=kernel_size,
            padding=(kernel_size[0] // 2, kernel_size[1] // 2)
        )
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h) if self.bn is not None else h
        return h


class LuxNet(nn.Module):
    """Residual conv net mapping a (B, 22, 32, 32) state to 7 action logits.

    Outputs 0-4 are the unit actions and outputs 5-6 the city-tile actions.
    The feature map is pooled at the acting entity's location only: input
    channel 0 marks an acting unit, input channel 20 an acting city tile.
    """

    def __init__(self):
        super().__init__()
        layers, filters = 12, 32
        self.conv0 = BasicConv2d(22, filters, (3, 3), True)
        # ModuleList registers every block so its parameters are trained and saved.
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, 7, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))

        # Choose the pooling mask per sample with tensor ops only.  A Python
        # ``if`` on a batch-wide sum would (a) zero the logits of city-tile
        # rows whenever the batch also contains unit rows, and (b) be frozen
        # to a single branch by ``torch.jit.trace``.
        unit_mask = x[:, :1]
        tile_mask = x[:, 20:21]
        has_unit = unit_mask.sum(dim=(1, 2, 3), keepdim=True) != 0
        actor_mask = torch.where(has_unit, unit_mask, tile_mask)

        h_head = (h * actor_mask).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p

# Train

In [None]:
def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs):
    """Train ``model`` and save a traced copy whenever validation accuracy improves.

    Each epoch runs a 'train' phase followed by a 'val' phase.  The learning
    rate is multiplied by 0.1 every 5 epochs (``StepLR``).  After an epoch
    whose validation accuracy beats the best seen so far, the model is traced
    on the CPU with a ``(1, 22, 32, 32)`` example and written to ``model.pth``.

    Runs on CUDA when available and falls back to the CPU otherwise.

    Args:
        model: Network to optimise; moved between devices in place.
        dataloaders_dict: Dict with 'train' and 'val' loaders yielding
            ``(states, actions)`` batches.
        criterion: Loss called as ``criterion(policy, actions)``.
        optimizer: Optimizer built on ``model.parameters()``.
        num_epochs: Number of epochs to run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    best_acc = 0.0
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    for epoch in range(num_epochs):
        # The model may have been moved to the CPU for tracing after the previous epoch.
        model.to(device)
        val_acc = 0.0

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                # eval() makes batch norm use its running statistics.
                model.eval()

            epoch_loss = 0.0
            epoch_correct = 0

            dataloader = dataloaders_dict[phase]
            for item in tqdm(dataloader, leave=False):
                states = item[0].to(device).float()
                actions = item[1].to(device).long()

                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    policy = model(states)
                    loss = criterion(policy, actions)
                    _, preds = torch.max(policy, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Weight the batch-mean loss by batch size so the epoch mean is exact.
                epoch_loss += loss.item() * len(policy)
                epoch_correct += torch.sum(preds == actions).item()

            data_size = len(dataloader.dataset)
            epoch_loss = epoch_loss / data_size
            epoch_acc = epoch_correct / data_size
            if phase == 'train':
                scheduler.step()
            else:
                val_acc = epoch_acc
            print(f'Epoch {epoch + 1}/{num_epochs} | {phase:^5} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}')

        # Checkpoint on validation accuracy (the model is still in eval mode here).
        if val_acc > best_acc:
            traced = torch.jit.trace(model.cpu(), torch.rand(1, 22, 32, 32))
            traced.save('model.pth')
            best_acc = val_acc

In [None]:
# Run the actual training.
model = LuxNet()
# Split the samples into training and validation subsets (15% validation),
# stratified by action label so both subsets share the same label distribution.
train, val = train_test_split(samples, test_size=0.15, random_state=42, stratify=labels)
batch_size = 64
# Training data is shuffled every epoch; validation data keeps its order.
# Each loader uses 2 worker processes.
train_loader = DataLoader(
    LuxDataset(obses, train), 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=2
)
val_loader = DataLoader(
    LuxDataset(obses, val), 
    batch_size=batch_size, 
    shuffle=False, 
    num_workers=2
)
dataloaders_dict = {"train": train_loader, "val": val_loader}
# Cross-entropy loss over the 7 action classes.
criterion = nn.CrossEntropyLoss()
# AdamW optimizer with an initial learning rate of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=30)

# Submission

In [None]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game
from lux.game_map import Cell, RESOURCE_TYPES, Position
from lux.constants import Constants
from lux.game_constants import GAME_CONSTANTS
from lux import annotate
import math
import sys


# Use the Kaggle simulation directory when it exists, otherwise the current directory.
path = '/kaggle_simulations/agent' if os.path.exists('/kaggle_simulations') else '.'
# Load the traced network written by the training notebook and switch it to inference mode.
model = torch.jit.load(f'{path}/model.pth')
model.eval()


# Input for Neural Network
# Must produce exactly the same encoding as the function used during training.
def make_input(is_citytile, obs, unit_id, x, y):
    """Encode one observation as a (22, 32, 32) float32 feature tensor.

    The map is centred inside the 32x32 grid by shifting every coordinate by
    ``(32 - size) // 2``.  "Own" below means ``team == obs['player']``.

    Channels:
        0-1    acting unit: presence, cargo / 100 (unit samples only)
        2-4    own units other than the acting one: presence, cooldown / 6, cargo / 100
        5-7    opposing units: same three features
        8-9    own city tiles: presence, min(fuel / upkeep, 10) / 10
        10-11  opposing city tiles: same two features
        12-14  wood, coal, uranium amount / 800
        15-16  own / opposing research points, min(rp, 200) / 200
        17     step % 40 / 40
        18     step / 360
        19     1 inside the real map area
        20-21  acting city tile: presence and fuel ratio (city-tile samples only)

    Args:
        is_citytile: Truthy when the acting entity is a city tile.
        obs: Dict with 'width', 'height', 'step', 'player' and 'updates'.
        unit_id: Id of the acting unit; not matched for city-tile samples.
        x, y: Unshifted map coordinates of the acting city tile; only read
            when ``is_citytile`` is truthy.

    Returns:
        numpy.ndarray of shape (22, 32, 32) and dtype float32.
    """
    width, height = obs['width'], obs['height']
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    b = np.zeros((22, 32, 32), dtype=np.float32)

    # Loop-local coordinates use their own names so that the ``x``/``y``
    # arguments (the acting city tile) are never overwritten while parsing.
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]

        if input_identifier == 'u':
            ux = int(strs[4]) + x_shift
            uy = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3] and not is_citytile:
                # Position and cargo of the acting unit.
                b[:2, ux, uy] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Any other unit: channel group 2-4 (own) or 5-7 (opponent).
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, ux, uy] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # City tiles.
            team = int(strs[1])
            city_id = strs[2]
            tile_x = int(strs[3])
            tile_y = int(strs[4])
            if is_citytile and x == tile_x and y == tile_y:
                # The acting city tile gets its own channel pair.
                idx = 20
            else:
                idx = 8 + (team - obs['player']) % 2 * 2
            # Relies on the city's 'c' update appearing earlier in the list.
            b[idx:idx + 2, tile_x + x_shift, tile_y + y_shift] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources.
            r_type = strs[1]
            rx = int(strs[2]) + x_shift
            ry = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], rx, ry] = amt / 800
        elif input_identifier == 'rp':
            # Research points, broadcast over the whole plane.
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Per-city fuel ratio, looked up later by the 'ct' entries.
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


game_state = None
def get_game_state(observation):
    """Create (on step 0) or refresh the module-level ``game_state`` and return it.

    On step 0 a new ``Game`` is initialised from the full update list, updated
    with everything after the first two entries, and tagged with the player
    id.  On every later step the existing object is updated in place.
    """
    global game_state

    first_turn = observation["step"] == 0
    updates = observation["updates"]

    if not first_turn:
        game_state._update(updates)
        return game_state

    game_state = Game()
    game_state._initialize(updates)
    game_state._update(updates[2:])
    game_state.id = observation["player"]
    return game_state

def sum_fuel(unit):
    """Fuel value of a unit's cargo: wood x1, coal x10, uranium x40."""
    cargo = unit.cargo
    return cargo.wood + cargo.coal * 10 + cargo.uranium * 40

# Whether ``pos`` is one of our own city tiles.
def in_city(pos):
    """Return True when the cell at ``pos`` holds a city tile owned by ``game_state.id``.

    Any position that cannot be resolved (negative coordinate, lookup failure,
    game state not initialised yet) yields False instead of raising.
    """
    try:
        # Reject negative coordinates explicitly.
        # NOTE(review): get_cell_by_pos presumably indexes nested lists, where
        # a negative index would wrap to the opposite map edge -- confirm.
        if pos.x < 0 or pos.y < 0:
            return False
        citytile = game_state.map.get_cell_by_pos(pos).citytile
        return citytile is not None and citytile.team == game_state.id
    except Exception:
        # Best-effort lookup; unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return False

def call_func(obj, method, args=()):
    """Call ``obj.<method>(*args)`` and return its result.

    Args:
        obj: Object owning the method.
        method: Name of the method to look up with ``getattr``.
        args: Iterable of positional arguments; defaults to none.  An
            immutable default avoids sharing one list between calls.
    """
    return getattr(obj, method)(*args)

# Unit actions indexed by model label: 0-3 are moves (n, s, w, e), 4 builds a city.
# Each entry is (method name, *arguments) and is unpacked into call_func.
unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',)]

# Resource lookup: amount on the tile and the fuel value of one unit of it.
def could_get_res(pos, player):
    """Return ``(amount, fuel_factor)`` for the resource at ``pos`` if collectable.

    Wood is always collectable, coal needs at least 50 research points and
    uranium at least 200.  ``(0, 0)`` is returned when the tile has no
    resource, the resource is not yet collectable, a coordinate is negative,
    or the cell lookup fails.

    Args:
        pos: Object with ``x`` and ``y`` attributes.
        player: Object exposing ``research_points``.
    """
    fuel_factor = {"coal": 10, "uranium": 40, "wood": 1}
    required_research = {"coal": 50, "uranium": 200, "wood": 0}
    research = player.research_points
    try:
        # Reject negative coordinates explicitly.
        # NOTE(review): get_cell_by_pos presumably indexes nested lists, where
        # a negative index would wrap to the opposite map edge -- confirm.
        if pos.x < 0 or pos.y < 0:
            return 0, 0
        cell = game_state.map.get_cell_by_pos(pos)
        if not cell.has_resource():
            return 0, 0
        res_type = cell.resource.type
        if res_type in required_research and research >= required_research[res_type]:
            return cell.resource.amount, fuel_factor[res_type]
        return 0, 0
    except Exception:
        # Best-effort lookup; unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return 0, 0

# Units of ``player`` standing on ``pos`` that are able to act this turn.
def get_units_in_citytile(pos, player):
    """Return the player's units located exactly at ``pos`` for which ``can_act()`` is true."""
    return [
        unit for unit in player.units
        if unit.pos.x == pos.x and unit.pos.y == pos.y and unit.can_act()
    ]

# For a worker standing on a city tile: decide whether it should step out of the
# city to gather a nearby collectable resource (used to send returned workers out again at night).
def step_out_city(unit, player, dest):
    """Choose a move that takes ``unit`` out of its city tile towards resources.

    For each of the four neighbouring cells the function sums the collectable
    resources on that cell and on the three cells around it on the far side,
    and picks the direction with the highest fuel value.  If no direction
    qualifies, or anything raises, the unit stays where it is.

    Args:
        unit: Unit standing on a city tile.
        player: Owning player; used for research level, units and cities.
        dest: Collection of positions tested with ``not in`` -- presumably
            the cells already claimed by other units this turn; confirm with caller.

    Returns:
        Tuple ``(action, position)``: the action produced by the unit's move
        method and the position the unit will occupy.
    """
    # Rationale from the original author: at night a worker that leaves the city
    # gathers first and burns fuel afterwards, so leaving can pay off when
    # resources are immediately reachable.
    # A direction qualifies only if both hold for the neighbour cell plus its three outward cells:
    #   1. the summed resource amount is greater than 1;
    #   2. the summed fuel value (amount * fuel factor) is at least 8.
    # Among qualifying directions the one with the highest fuel value wins.
    # Two additional checks below keep a unit inside the city (see the early returns).
    incitys = get_units_in_citytile(unit.pos, player)
    x = unit.pos.x
    y = unit.pos.y
    subpos1 = Position(x, y - 1) # y - 1; paired with unit_actions[0] ('n') below
    subpos2 = Position(x, y + 1) # y + 1; paired with unit_actions[1] ('s') below
    subpos3 = Position(x - 1, y) # x - 1; paired with unit_actions[2] ('w') below
    subpos4 = Position(x + 1, y) # x + 1; paired with unit_actions[3] ('e') below
    subpossub1 = Position(x, y - 2)  # for subpos1
    subpossub2 = Position(x + 1, y - 1)  # for sp1 sp4
    subpossub3 = Position(x - 1, y - 1)  # for sp1 sp3
    subpossub4 = Position(x, y + 2)  # for sp2
    subpossub5 = Position(x - 1, y + 1)  # for sp2 sp3
    subpossub6 = Position(x + 1, y + 1)  # for sp2 sp4
    subpossub7 = Position(x - 2, y)  # for sp3
    subpossub8 = Position(x + 2, y)  # for sp4
    try:
        # Resources on the four cells adjacent to the city tile; used for the "last worker" check.
        subresamount1, subresparam1 = could_get_res(subpos1, player)
        subresamount2, subresparam2 = could_get_res(subpos2, player)
        subresamount3, subresparam3 = could_get_res(subpos3, player)
        subresamount4, subresparam4 = could_get_res(subpos4, player)
        cityid = game_state.map.get_cell_by_pos(unit.pos).citytile.cityid
        light_up_keep = check_light_up_keep(player, cityid)
        # Fuel currently stored in the city.
        fuel = check_fuel(player, cityid)
        # (40 - turn) mod 40; the original author labels this the remaining night turns.
        nightdays = (40 - game_state.turn) % 40
        # Fuel the city burns over that many turns.
        fuelneed = nightdays * light_up_keep
        # Total amount and total fuel value on the four adjacent cells.
        aroundamount = subresamount1 + subresamount2 + subresamount3 + subresamount4
        aroundfuel = subresamount1 * subresparam1 + subresamount2 * subresparam2 + \
                     subresamount3 * subresparam3 + subresamount4 * subresparam4
        # Stay if this is the only actable unit on the tile, the city is short of fuel,
        # and the adjacent resources are large enough to cover the shortfall.
        # A zero light_up_keep makes the division raise, which the except below turns into "stay".
        if (len(incitys) == 1 and fuel < fuelneed and aroundamount > nightdays - (fuelneed - fuel)/light_up_keep
                and aroundfuel > fuelneed - fuel):
            return unit.move('c'), unit.pos
        # With several actable units on the tile, the first one in the list always stays.
        if(len(incitys) > 1 and incitys[0].id == unit.id):
            return unit.move('c'), unit.pos
        # For each direction: build the move if the target cell is not in dest and at least
        # one of its four relevant cells is on the map and holds a collectable resource.
        subresamountsub1, subresparamsub1 = could_get_res(subpossub1, player)
        subresamountsub2, subresparamsub2 = could_get_res(subpossub2, player)
        subresamountsub3, subresparamsub3 = could_get_res(subpossub3, player)
        if (subpos1 not in dest and ((subresamount1 and y > 0) or (subresamountsub1 and y > 1) or (
                subresamountsub2 and y > 0 and x < game_state.map.width - 1) or (
                                             subresamountsub3 and y > 0 and x > 0))):
            act = unit_actions[0]
            action1, pos1 = call_func(unit, *act), subpos1
        else:
            action1, pos1 = unit.move('c'), unit.pos
        subresamountsub4, subresparamsub4 = could_get_res(subpossub4, player)
        subresamountsub5, subresparamsub5 = could_get_res(subpossub5, player)
        subresamountsub6, subresparamsub6 = could_get_res(subpossub6, player)
        if (subpos2 not in dest and ((subresamount2 and y < game_state.map.height - 1) or (
                subresamountsub4 and y < game_state.map.height - 2) or (
                                             subresamountsub5 and y < game_state.map.height - 1 and x > 0) or (
                                             subresamountsub6 and y < game_state.map.height - 1 and x < game_state.map.width - 1))):
            act = unit_actions[1]
            action2, pos2 = call_func(unit, *act), subpos2
        else:
            action2, pos2 = unit.move('c'), unit.pos
        subresamountsub7, subresparamsub7 = could_get_res(subpossub7, player)
        if (subpos3 not in dest and ((subresamount3 and x > 0) or (subresamountsub3 and y > 0 and x > 0) or (
                subresamountsub5 and y < game_state.map.height - 1 and x > 0) or (subresamountsub7 and x > 1))):
            act = unit_actions[2]
            action3, pos3 = call_func(unit, *act), subpos3
        else:
            action3, pos3 = unit.move('c'), unit.pos
        subresamountsub8, subresparamsub8 = could_get_res(subpossub8, player)
        if (subpos4 not in dest and ((subresamount4 and x < game_state.map.width - 1) or (
                subresamountsub2 and y > 0 and x < game_state.map.width - 1) or (
                                             subresamountsub6 and y < game_state.map.height - 1 and x < game_state.map.width - 1) or (
                                             subresamountsub8 and x < game_state.map.width - 2))):
            act = unit_actions[3]
            action4, pos4 = call_func(unit, *act), subpos4
        else:
            action4, pos4 = unit.move('c'), unit.pos
        amount1 = subresamount1 + subresamountsub1 + subresamountsub2 + subresamountsub3
        # Condition 1: when the summed amount is not greater than 1 the fuel value is forced to 0.
        if(amount1 > 1):
            fuel1 = subresamount1 * subresparam1 + subresamountsub1 * subresparamsub1 \
                    + subresamountsub2 * subresparamsub2 + subresamountsub3 * subresparamsub3
        else:
            fuel1 = 0

        amount2 = subresamount2 + subresamountsub4 + subresamountsub5 + subresamountsub6
        if (amount2 > 1):
            fuel2 = subresamount2 * subresparam2 + subresamountsub4 * subresparamsub4 \
                    + subresamountsub5 * subresparamsub5 + subresamountsub6 * subresparamsub6
        else:
            fuel2 = 0

        amount3 = subresamount3 + subresamountsub3 + subresamountsub5 + subresamountsub7
        if(amount3 > 1):
            fuel3 = subresamount3 * subresparam3 + subresamountsub3 * subresparamsub3 \
                    + subresamountsub5 * subresparamsub5 + subresamountsub7 * subresparamsub7
        else:
            fuel3 = 0

        amount4 = subresamount4 + subresamountsub2 + subresamountsub6 + subresamountsub8
        if(amount4 > 1):
            fuel4 = subresamount4 * subresparam4 + subresamountsub2 * subresparamsub2 \
                    + subresamountsub6 * subresparamsub6 + subresamountsub8 * subresparamsub8
        else:
            fuel4 = 0
        
        # Condition 2 plus selection: take the first direction (in the order 1-4) whose
        # fuel value is at least 8 and not lower than any other direction's.
        if(fuel1 >= 8 and fuel1 >= fuel2 and fuel1 >= fuel3 and fuel1 >= fuel4):
            return action1, pos1
        if(fuel2 >= 8 and fuel2 >= fuel1 and fuel2 >= fuel3 and fuel2 >= fuel4):
            return action2, pos2
        if(fuel3 >= 8 and fuel3 >= fuel1 and fuel3 >= fuel2 and fuel3 >= fuel4):
            return action3, pos3
        if(fuel4 >= 8 and fuel4 >= fuel1 and fuel4 >= fuel2 and fuel4 >= fuel3):
            return action4, pos4

        # No direction qualifies: stay in the city.
        return unit.move('c'), unit.pos
    except:
        # Any failure above (missing city tile, division by zero, ...) falls back to staying put.
        return unit.move('c'), unit.pos
    
# Light upkeep of the player's city with the given id, as reported by the city object.
def check_light_up_keep(player, cityid):
    """Return ``get_light_upkeep()`` of the first city whose id matches, or False if none does."""
    upkeeps = (
        city.get_light_upkeep()
        for city in player.cities.values()
        if city.cityid == cityid
    )
    return next(upkeeps, False)

# Fuel currently stored in the player's city with the given id.
def check_fuel(player, cityid):
    """Return the ``fuel`` of the first city whose id matches, or False if none does."""
    stored = (
        city.fuel
        for city in player.cities.values()
        if city.cityid == cityid
    )
    return next(stored, False)

# Single-step rescue: move onto an adjacent own city tile whose city is short of fuel.
def rescue_city(unit, player, dest):
    """Return a move onto a neighbouring own city tile that needs fuel, else False.

    Neighbours are examined in the order north (y - 1), south (y + 1),
    west (x - 1), east (x + 1).  A neighbour qualifies when it lies on the
    map, is one of our city tiles, and its city's stored fuel is less than
    ``(40 - turn % 40) * light upkeep``.

    Args:
        unit: Unit that may deliver its cargo.
        player: Owning player, used to look up city upkeep and fuel.
        dest: Unused by this function.

    Returns:
        ``(action, position)`` for the first qualifying neighbour, or False.
    """
    x, y = unit.pos.x, unit.pos.y
    game_map = game_state.map
    # (index into unit_actions, neighbour position, neighbour lies on the map)
    candidates = (
        (0, Position(x, y - 1), y > 0),
        (1, Position(x, y + 1), y < game_map.height - 1),
        (2, Position(x - 1, y), x > 0),
        (3, Position(x + 1, y), x < game_map.width - 1),
    )
    for action_idx, neighbour, on_map in candidates:
        if not (on_map and in_city(neighbour)):
            continue
        cityid = game_state.map.get_cell_by_pos(neighbour).citytile.cityid
        upkeep = check_light_up_keep(player, cityid)
        stored = check_fuel(player, cityid)
        if (40 - game_state.turn % 40) * upkeep > stored:
            return call_func(unit, *unit_actions[action_idx]), neighbour
    return False

# Helper for fill_city_fuel_list: is a large collectable resource tile close to the unit?
def _rich_resource_nearby(unit, player, max_dis):
    """Return True if a collectable tile worth more than 250 fuel lies within range of ``unit``.

    The search covers every on-map cell whose Manhattan distance to the unit
    is at most ``max(0, max_dis - 5) + 1``.
    """
    pos_x = unit.pos.x
    pos_y = unit.pos.y
    tmp_range = max(0, max_dis - 5) + 1
    left_border = int(max(0, pos_x - tmp_range))
    right_border = int(min(game_state.map.width - 1, pos_x + tmp_range))
    down_border = int(max(0, pos_y - tmp_range))
    up_border = int(min(game_state.map.height - 1, pos_y + tmp_range))
    for x in range(left_border, right_border + 1):
        for y in range(down_border, up_border + 1):
            if abs(x - pos_x) + abs(y - pos_y) > tmp_range:
                continue
            if could_get_res(Position(x, y), player)[0] == 0:
                continue
            res = game_state.map.get_cell_by_pos(Position(x, y)).resource
            if ((res.type == "uranium" and res.amount * 40 > 250)
                    or (res.type == "coal" and res.amount * 10 > 250)
                    or (res.type == "wood" and res.amount > 250)):
                return True
    return False


# Preparation step for the city-refuelling moves: pair units with the city tile they should walk to.
# The two returned lists are parallel: fill_worker_list[i] heads for fill_citytile_list[i].
# Several units may be assigned to different tiles of the same city.
def fill_city_fuel_list(player, max_dis):
    """Pair units with target city tiles for rescue and for shelter.

    Rescue: for every city (largest first) whose stored fuel is below
    ``1.3 * (min(10, 40 - turn % 40) * upkeep - fuel)``, actable units
    carrying more than 40 fuel and standing within ``max_dis`` (Manhattan) of
    one of its tiles are assigned to their nearest tile, closest units first.

    Shelter: while research points are below 200, actable non-cart units
    outside our cities that carry little fuel are sent to a reachable own
    city tile, unless a large collectable resource is close by.

    Args:
        player: Owning player (units, cities, research points).
        max_dis: Maximum Manhattan distance for rescuers; also scales the
            shelter heuristics.

    Returns:
        ``(fill_worker_list, fill_citytile_list)``, two parallel lists.
    """
    fill_worker_list = []
    fill_citytile_list = []
    # Larger cities first.  sorted() returns a new list, so its result must be kept.
    city_list = sorted(player.cities.values(), key=lambda c: len(c.citytiles), reverse=True)

    for city in city_list:
        cityid = city.cityid
        light_up_keep = check_light_up_keep(player, cityid)
        fuel = check_fuel(player, cityid)
        extra_need = min(10, 40 - (game_state.turn % 40)) * light_up_keep - fuel
        extra_need = 1.3 * extra_need
        if extra_need <= 0:
            continue

        # Collect every unit able to help, with its nearest tile of this city.
        available_list = []
        for unit in player.units:
            # Only units that can act this turn and carry more than 40 fuel qualify.
            if sum_fuel(unit) <= 40 or not unit.can_act():
                continue
            min_distance = 100
            citytile = None
            for city_tile in city.citytiles:
                tmp = abs(city_tile.pos.x - unit.pos.x) + abs(city_tile.pos.y - unit.pos.y)
                if tmp > max_dis:
                    continue
                if tmp < min_distance:
                    min_distance = tmp
                    citytile = city_tile
            if citytile is not None:
                # Store the distance to the chosen tile, not that of the last tile iterated.
                available_list.append([min_distance, unit, citytile])

        # Closest rescuers first (in-place sort so the ordering takes effect).
        available_list.sort(key=lambda entry: entry[0])
        for available in available_list:
            if extra_need <= 0:
                break
            # NOTE(review): the remaining need is reduced by the rescuer's
            # distance rather than by the fuel it carries; kept as in the
            # original -- confirm this is the intended saturation rule.
            extra_need = extra_need - available[0]
            fill_worker_list.append(available[1])
            fill_citytile_list.append(available[2])

    # The pairing mechanism is reused to shelter weak workers, but only while
    # research points are below 200.
    if player.research_points >= 200:
        return fill_worker_list, fill_citytile_list

    for unit in player.units:
        # Carts are not handled.
        if unit.is_cart():
            continue
        if in_city(Position(unit.pos.x, unit.pos.y)) or not unit.can_act():
            continue
        # NOTE(review): presumably the per-turn night fuel burn of the unit type -- confirm.
        tmp = 4 if unit.is_worker() else 10
        # Units carrying at least tmp * 10 fuel are left alone; this also keeps
        # the shelter set disjoint from the rescuers above (which carry more than 40).
        if sum_fuel(unit) >= tmp * 10:
            continue

        citytile = None
        min_distance = 100
        for city in player.cities.values():
            for city_tile in city.citytiles:
                dis = abs(city_tile.pos.x - unit.pos.x) + abs(city_tile.pos.y - unit.pos.y)
                # The tile must be no farther than the best so far and the trip
                # affordable with 70% of the carried fuel.
                if dis <= min_distance and tmp * (2 * dis - max(0, max_dis * 2 - 10)) <= 0.7 * sum_fuel(unit):
                    min_distance = dis
                    citytile = city_tile
        if citytile is None:
            continue

        # No need to flee when plenty of resources are close by.
        if _rich_resource_nearby(unit, player, max_dis):
            continue
        fill_worker_list.append(unit)
        fill_citytile_list.append(citytile)

    return fill_worker_list, fill_citytile_list

#为城市补充燃料的寻路的实现函数
#每一个unit被存入先前列表中的unit都会被调用此函数，dest用于保证己方unit不发生碰撞
#get_fill_way函数是通过指定的unit和citytile获取靠近的正确action
def get_fill_way(unit, citytile, dest):
    """Choose a single step that moves ``unit`` closer to ``citytile``.

    Parameters
    ----------
    unit : the unit to move; ``unit.pos`` is the starting cell.
    citytile : the target tile; ``citytile.pos`` is the goal cell.
    dest : positions already claimed this turn; a step onto any of them is
        rejected.

    Returns
    -------
    ``(action, position)`` for the chosen step, or ``False`` when the unit is
    already on the target cell or every distance-reducing step is claimed.
    """
    ux, uy = unit.pos.x, unit.pos.y
    tx, ty = citytile.pos.x, citytile.pos.y

    # unit_actions index -> (dx, dy) of the cell that action is treated as entering.
    offsets = {0: (0, -1), 1: (0, 1), 2: (-1, 0), 3: (1, 0)}

    def neighbour(label):
        dx, dy = offsets[label]
        return Position(ux + dx, uy + dy)

    def is_free(label):
        return neighbour(label) not in dest

    def fixed_step(label):
        return call_func(unit, *unit_actions[label]), neighbour(label)

    # Action index that closes the gap along each axis (None = already aligned).
    if ty > uy:
        vertical = 1
    elif ty < uy:
        vertical = 0
    else:
        vertical = None

    if tx > ux:
        horizontal = 3
    elif tx < ux:
        horizontal = 2
    else:
        horizontal = None

    if vertical is None and horizontal is None:
        # Already standing on the target cell: nothing to do.
        return False

    if horizontal is None:
        # Same column: only the vertical step is considered.
        return fixed_step(vertical) if is_free(vertical) else False

    if vertical is None:
        # Same row: only the horizontal step is considered.
        return fixed_step(horizontal) if is_free(horizontal) else False

    # Diagonal target: either axis reduces the distance.
    if is_free(vertical) and is_free(horizontal):
        # One draw from {0, 1}: 0 picks the vertical step, 1 the horizontal one.
        label = (vertical, horizontal)[np.random.randint(0, 2)]
        act = unit_actions[label]
        pos = unit.pos.translate(act[-1], 1)
        return call_func(unit, *act), pos
    if is_free(vertical):
        return fixed_step(vertical)
    if is_free(horizontal):
        return fixed_step(horizontal)
    return False
    
#根据网络的输出值决定worker的动作
def get_action(policy, unit, player, dest):
    """Turn the model's policy vector into one action for ``unit``.

    Labels are tried from highest to lowest score. Labels 5 and 6 are skipped
    here (``agent`` uses them for the city-tile research / build-worker
    choice). When label 4 comes up but ``unit.can_build`` is false, the
    decision is delegated to ``step_out_city``. Any other label is accepted as
    soon as its resulting position is either unclaimed in ``dest`` or inside a
    city.

    Returns ``(action, position)``; falls back to ``unit.move('c')`` at the
    unit's current position if no label is acceptable.
    """
    ranked_labels = np.argsort(policy)[::-1]  # indices sorted by descending score
    for label in ranked_labels:
        if label in (5, 6):
            continue

        if label == 4 and not unit.can_build(game_state.map):
            return step_out_city(unit, player, dest)

        act = unit_actions[label]
        # translate() may yield a falsy result for a non-direction argument;
        # in that case the unit is treated as staying where it is.
        pos = unit.pos.translate(act[-1], 1) or unit.pos

        # Reject only when the cell is already claimed and is not a city cell.
        if pos in dest and not in_city(pos):
            continue
        return call_func(unit, *act), pos

    return unit.move('c'), unit.pos
    
#对战调用的主函数
def agent(observation, configuration):
    """Build the list of actions for the current turn.

    City tiles are handled first, then every unit. A unit follows a
    hand-written rule when one applies and the model's policy otherwise:

    * from cycle turn 20 on, units returned by ``fill_city_fuel_list`` walk
      towards their assigned city tile via ``get_fill_way``;
    * at night (cycle turn >= 30) a unit outside a city first tries
      ``rescue_city``;
    * at night a unit inside a city calls ``step_out_city``;
    * in every other case a worker's action comes from ``get_action`` applied
      to the model output.

    Parameters
    ----------
    observation : per-turn observation; passed to ``get_game_state`` and
        ``make_input``, and ``observation.player`` indexes the player list.
    configuration : not used by this function.

    Returns
    -------
    list
        Actions collected from city tiles and units.

    Side effects: rebinds the module-level ``game_state`` and increments
    ``player.research_points`` for every research action queued (presumably
    so that ``researched_uranium()`` is up to date for later city tiles in the
    same turn).
    """
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    # Position inside the 40-turn cycle that all thresholds below refer to.
    cycle_turn = game_state.turn % 40
    is_night = cycle_turn >= 30
    # 20 is a tunable hyperparameter: the cycle turn from which rescue and
    # shelter planning starts.
    prepare_for_night = cycle_turn >= 20

    # ---- City actions ----
    unit_count = len(player.units)
    workers_queued = 0  # workers ordered this turn, not yet part of player.units
    for city in player.cities.values():
        for city_tile in city.citytiles:
            if not city_tile.can_act():
                continue

            can_build = unit_count + workers_queued < player.city_tile_count
            can_research = not player.researched_uranium()

            if can_build and can_research:
                # Both options are open, so let the model choose between them.
                state = make_input(1, observation, 0, city_tile.pos.x, city_tile.pos.y)
                with torch.no_grad():
                    p = model(torch.from_numpy(state).unsqueeze(0))

                policy = p.squeeze(0).numpy()
                for label in np.argsort(policy)[::-1]:
                    # Only the two city-tile outputs are relevant here.
                    if not (label == 5 or label == 6):
                        continue
                    if label == 5:
                        actions.append(city_tile.research())
                        player.research_points += 1
                    else:
                        actions.append(city_tile.build_worker())
                        workers_queued += 1
                    break
            elif can_build:
                actions.append(city_tile.build_worker())
                workers_queued += 1
            elif can_research:
                actions.append(city_tile.research())
                player.research_points += 1

    # ---- Unit actions ----
    dest = []  # non-city cells already claimed by an earlier unit this turn
    fill_worker_list = []
    fill_citytile_list = []
    if prepare_for_night:
        # The second argument shrinks as the cycle approaches its end.
        fill_worker_list, fill_citytile_list = fill_city_fuel_list(
            player, (39 - cycle_turn) // 2
        )

    for unit in player.units:
        # A unit that cannot act must neither emit an action nor reserve a
        # cell; every branch below requires can_act().
        if not unit.can_act():
            continue

        # 1) Planned trip to a city tile (refuel / shelter).
        handled = False
        if prepare_for_night:
            for fill_worker, citytile in zip(fill_worker_list, fill_citytile_list):
                if not (unit == fill_worker):
                    continue
                # Call once: the result may involve a random draw, so the
                # value that is checked must be the value that is used.
                step = get_fill_way(unit, citytile, dest)
                if step:
                    action, pos = step
                    actions.append(action)
                    if not in_city(pos):
                        dest.append(pos)
                    handled = True
                    break
        if handled:
            continue

        unit_in_city = in_city(unit.pos)

        # 2) Night, outside a city: try the rescue rule first.
        if is_night and not unit_in_city:
            rescue = rescue_city(unit, player, dest)
            if rescue:
                action, pos = rescue
                actions.append(action)
                if not in_city(pos):
                    dest.append(pos)
                continue

        if not is_night or not unit_in_city:
            # 3) Daytime, or outside a city: the model decides.
            # TODO: cart logic (carts get no action in this branch).
            if unit.is_worker():
                state = make_input(0, observation, unit.id, 0, 0)
                # Inference only: no gradient tracking needed.
                with torch.no_grad():
                    # unsqueeze(0) adds the leading batch dimension.
                    p = model(torch.from_numpy(state).unsqueeze(0))

                policy = p.squeeze(0).numpy()

                action, pos = get_action(policy, unit, player, dest)
                actions.append(action)
                if not in_city(pos):
                    dest.append(pos)
        else:
            # 4) Night, inside a city: step out to keep collecting resources.
            action, pos = step_out_city(unit, player, dest)
            actions.append(action)
            if not in_city(pos):
                dest.append(pos)

    return actions

In [None]:
from kaggle_environments import make

# Self-play run: two copies of agent.py play each other on a 24x24 map,
# and the resulting episode is rendered inline in the notebook.
env = make(
    "lux_ai_2021",
    configuration={
        "width": 24,
        "height": 24,
        "loglevel": 2,
        "annotations": True,
    },
    debug=True,
)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=1200, height=800)