In [1]:
import os
# Restrict this kernel to GPU 0; must happen before torch initialises CUDA.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

# Config 配置表

In [2]:
suffix = "120521"  # experiment id
# Paths: override with the LUX_DATA_DIR / LUX_OUTPUT_DIR environment variables;
# the defaults are the original hard-coded locations.
data_dir = os.environ.get("LUX_DATA_DIR", '/home/workspace/lux-ai-2021')  # data root (replay folders, simple/ files)
output_dir = os.environ.get("LUX_OUTPUT_DIR", "/home/workspace/output")
output_dir = f"{output_dir}/{suffix}"  # per-experiment output directory
DEBUG = False  # debug mode: train on the first 1000 samples and disable early stopping
SEED = 42  # random seed

replays_dir = "above1900"  # replay sub-folder of data_dir
only_someteams = True  # only imitate replays won by the selected teams
LOWSCORE = 0  # minimum winner rating when only_someteams is False (0 disables the filter)
flip_list = ["xy"]  # augmentation: flip method(s), "x", "y" or "xy"
add_flip_data = False  # also load the pre-generated flipped replays ("<replays_dir>_<flip>")
flip_p = 0.3  # probability of flipping a training sample on the fly
is_transform = False  # enable on-the-fly flip augmentation

NUM_WORKERS = 8  # DataLoader worker processes
BATCH_SIZE = 128  # GPU batch size
TEST_PERCENT = 0.1  # validation fraction

n_early_stopping = 3  # early-stopping patience (epochs without validation improvement)
N_EPOCHS = 25  # number of epochs

INIT_LR = 1e-3  # initial (peak) learning rate
RLP_FACTOR = 0.1  # ReduceLROnPlateau factor
RLP_PAT = 1  # ReduceLROnPlateau patience

SCHEDULER = "CosineAnnealingLR"  # learning-rate scheduler
SCHEDULER_WARMUP = "GradualWarmupSchedulerV3"  # learning-rate warmup scheduler
WARMUP_EPO = 1  # number of warmup epochs
WARMUP_FACTOR = 10  # warmup multiplier: the LR ramps from INIT_LR / WARMUP_FACTOR up to INIT_LR
# T_max for CosineAnnealingLR: the epochs remaining after warmup, minus one
T_MAX= N_EPOCHS-WARMUP_EPO-1 if SCHEDULER_WARMUP=="GradualWarmupSchedulerV3" else N_EPOCHS-1

LOSS_FUNC = "CrossEntropy"  # loss function: "CrossEntropy" or "FocalLoss"
FOCAL_ALPHA = [0.4, 0.4, 0.4, 0.4, 0.8, 0.8]  # focal loss per-class alpha
FOCAL_GAMMA = 2  # focal loss gamma
FOCAL_BAL_IND = 5  # focal loss balance index (only used when alpha is a float)

XAVIER_INIT = False  # Xavier weight initialisation

# Import and Initialize

In [3]:
import glob
import json
import os
import random
import shutil
import tarfile
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from warmup_scheduler import GradualWarmupScheduler

In [4]:
os.makedirs(output_dir, exist_ok=True)
# Copy the contents of <data_dir>/simple next to the outputs so the final submission
# archive is self-contained (presumably the `lux` kit that agent.py imports).
# Done in Python instead of `!cp`: a failing shell command in a notebook only prints an
# error and execution continues, so a missing kit would only surface after training.
simplepy_dir = f"{data_dir}/simple/*"
simple_items = glob.glob(simplepy_dir)  # like the shell glob, skips hidden entries
if not simple_items:
    raise FileNotFoundError(f"nothing to copy: {simplepy_dir} matched no files")
for src in simple_items:
    dst = os.path.join(output_dir, os.path.basename(src))
    if os.path.isdir(src):
        # dirs_exist_ok mirrors `cp -r` merging into an existing directory on re-runs.
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        shutil.copy2(src, dst)

In [5]:

def seed_everything(seed_value):
    '''
    Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Note: PYTHONHASHSEED only affects Python interpreters started after it is
    set; it does not change str hashing in the already-running kernel.
    '''
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # kernels, which defeats deterministic=True. Both flags are harmless without CUDA.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_timediff(time1,time2):
    '''
    Format the elapsed seconds between two timestamps as "MM:SS".

    Minutes are not folded into hours, so long spans print e.g. "62:05".
    '''
    elapsed = time2 - time1
    minutes = int(elapsed // 60)
    seconds = int(elapsed % 60)
    return f"{minutes:02d}:{seconds:02d}"
    
    
def init_logger(log_file=None):
    '''
    Configure the module logger to write bare messages to stderr and `log_file`.

    Args:
        log_file: path of the log file; defaults to "<output_dir>/train.log".

    Returns:
        The configured logging.Logger.

    Existing handlers are removed (and closed) first, so re-running the cell
    does not attach duplicate handlers and repeat every log line.
    '''
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    if log_file is None:
        # The old default `output_dir + 'train.log'` was missing the "/" separator.
        log_file = f"{output_dir}/train.log"
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

# Seed every RNG, then attach console + file logging for this experiment.
seed_everything(seed_value=SEED)
LOGGER = init_logger(log_file=f'{output_dir}/train.log')

# Preprocessing

In [6]:
def to_label(action):
    '''
    Parse one action string into (unit_id, label).

    "m <id> <dir>" maps n/s/w/e/c to 0/1/2/3/5, "bcity <id>" maps to 4, and
    any other command gives label None. The unit id is always token 1.
    '''
    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]
    if command == 'bcity':
        return unit_id, 4
    if command != 'm':
        return unit_id, None
    move_labels = {'n': 0, 's': 1, 'w': 2, 'e': 3, 'c': 5}
    return unit_id, move_labels[tokens[2]]


def depleted_resources(obs):
    '''
    Return True when obs['updates'] has no resource ("r ...") line, i.e. no
    resources are left on the map.
    '''
    return not any(line.split(' ')[0] == 'r' for line in obs['updates'])

def get_act_unit_id_set(actions):
    '''
    Return the set of unit ids that are the actor of an "m", "t", "bcity" or
    "p" action (token 1 of the action string).
    '''
    unit_commands = ("m", "t", "bcity", "p")
    split_actions = (act.split(" ") for act in actions)
    return {tokens[1] for tokens in split_actions if tokens[0] in unit_commands}

def get_obs_unit_id_set(updates, team_id):
    '''
    Return the ids of `team_id`'s units whose cooldown is below 1
    (presumably: able to act this turn).

    Reads "u <type> <team> <id> <x> <y> <cooldown> ..." update lines.
    '''
    team = str(team_id)
    ready_ids = set()
    for line in updates:
        if line[0] != "u":
            continue
        fields = line.split(" ")
        if fields[2] == team and float(fields[6]) < 1:
            ready_ids.add(fields[3])
    return ready_ids

def hometeam_lowscore_lose(episode_dir, ep_id, index, lowscore):
    '''
    Decide whether a replay should be dropped because a player's rating is too low.

    Reads agents[index]["updatedScore"] from "<episode_dir>/<ep_id>_info.json".
    Returns True if it is below `lowscore`, and False if it is at least `lowscore`.
    lowscore == 0 disables the check: it returns False without opening the file.
    '''
    if lowscore == 0:
        return False
    info_path = f"{episode_dir}/{ep_id}_info.json"
    with open(info_path) as info_file:
        info = json.load(info_file)
    score = info["agents"][index]["updatedScore"]
    if score >= lowscore:
        return False
    if score < lowscore:
        return True
        

def create_dataset_from_json(episode_dir, team_name=('Toad Brigade',)):
    '''
    Collect observations and (obs_id, unit_id, label) samples from replay JSONs.

    For each replay, the player with the highest reward is imitated. For every step
    where that player is ACTIVE:
      - the step's observation is stored under "<EpisodeId><flip-suffix>_<step>";
      - one sample is emitted per unit action (labels from to_label);
      - one "center" (label 5) sample is emitted per own unit with cooldown < 1
        that received no action.
    A replay stops contributing at the first step with no resources left on the map.

    Filtering uses module-level config. If the winner is not in `team_name`, the
    replay is skipped when `only_someteams` is set. Otherwise it is skipped only
    if the winner's updatedScore is below LOWSCORE.
    When `add_flip_data` is set, replays from the "<replays_dir>_<flip>" sibling
    directories (one per entry of flip_list) are added too.

    Args:
        episode_dir: directory with "<id>.json" replays and "<id>_info.json" metadata.
        team_name: team names to imitate (any container supporting `in`).

    Returns:
        (obses, samples): dict obs_id -> trimmed observation dict, and a list of
        (obs_id, unit_id, label) tuples.
    '''
    obses = {}
    samples = []
    isnot_attention_team_cnt = 0
    hometeam_lowscore_lose_cnt = 0

    def _is_replay(path):
        # Keep replay files only: skip agent output logs and "<id>_info.json" metadata.
        return 'output' not in path.name and "info" not in path.name

    episodes = [path for path in Path(episode_dir).glob('*.json') if _is_replay(path)]

    # Optionally add the full set of pre-generated flipped replays.
    if flip_list != [] and add_flip_data:
        for flip_method in flip_list:
            flip_episode_dir = episode_dir.replace(replays_dir, f"{replays_dir}_{flip_method}")
            episodes += [path for path in Path(flip_episode_dir).glob('*.json') if _is_replay(path)]

    for filepath in tqdm(episodes):
        # Flipped replays live in "<replays_dir>_<flip>"; their obs ids get the flip suffix.
        file_flip_method = filepath.parent.name.split("_")[-1]
        if file_flip_method not in flip_list:
            file_flip_method = ""
        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        index = np.argmax([r or 0 for r in json_load['rewards']])
        win_teamname = json_load['info']['TeamNames'][index]

        # Optionally restrict imitation to specific teams.
        if win_teamname not in team_name:
            isnot_attention_team_cnt += 1
            if only_someteams:
                continue
            if hometeam_lowscore_lose(episode_dir, ep_id, index, LOWSCORE):
                hometeam_lowscore_lose_cnt += 1
                continue

        # Collect the per-step observation and the winner's actions.
        for i in range(len(json_load['steps'])-1):
            if json_load['steps'][i][index]['status'] == 'ACTIVE':
                # Actions taken in response to step i are recorded on step i+1.
                actions = json_load['steps'][i+1][index]['action']
                # Player 0's entry is used as the shared observation (presumably only it
                # carries the full `updates` list in Kaggle replays; verify).
                obs = json_load['steps'][i][0]['observation']

                if depleted_resources(obs):
                    break

                obs['player'] = index
                obs = dict([
                    (k,v) for k,v in obs.items()
                    if k in ['step', 'updates', 'player', 'width', 'height']
                ])
                obs_id = f'{ep_id}{file_flip_method}_{i}'
                obses[obs_id] = obs

                # Own units that could act but were given no action -> "center" (5).
                act_unit_id_set = get_act_unit_id_set(actions)
                obs_unit_id_set = get_obs_unit_id_set(obs["updates"], index)
                center_unit_id_set = obs_unit_id_set - act_unit_id_set
                for unit_id in list(center_unit_id_set):
                    samples.append((obs_id, unit_id, 5))

                for action in actions:
                    unit_id, label = to_label(action)
                    if label is not None:
                        # label: 0,1,2,3,4,5 -> north, south, west, east, bcity, center
                        samples.append((obs_id, unit_id, label))
    if only_someteams:
        print(f"delete not attention replays: {isnot_attention_team_cnt}")
    elif LOWSCORE:
        print(f"delete low score replays: {hometeam_lowscore_lose_cnt}")
    return obses, samples

In [7]:
# Parse every replay under <data_dir>/<replays_dir> into observations + samples.
obses, samples = create_dataset_from_json(episode_dir=f'{data_dir}/{replays_dir}')
print('obses:', len(obses), 'samples:', len(samples))

  0%|          | 0/1184 [00:00<?, ?it/s]

delete not attention replays: 117
obses: 362171 samples: 6945282


In [8]:
# Count the 6 action labels. nswe_maxcnt / center_cnt drive the "center"
# down-sampling in the next cell.
labels = [sample[-1] for sample in samples]
actions_fullname = ['north', 'south', 'west', 'east', 'bcity', 'center']
nswe_maxcnt = 0
# Initialise explicitly: otherwise, with no "center" samples, center_cnt is either
# undefined or silently reused from a previous run of this cell.
center_cnt = 0
for value, count in zip(*np.unique(labels, return_counts=True)):
    if value in [0, 1, 2, 3]:
        nswe_maxcnt = max(nswe_maxcnt, count)
    elif value == 5:
        center_cnt = count
    print(f'{actions_fullname[value]:^6}: {count:>3}')

north : 651993
south : 640506
 west : 587745
 east : 560672
bcity : 152445
center: 4351921


In [9]:
# "center" (no-op) samples dominate: keep each one with probability c_sample_rate so
# the expected number kept is half the count of the most frequent move direction.
# With no center samples at all, keep everything instead of dividing by zero.
c_sample_rate = nswe_maxcnt/center_cnt/2 if center_cnt else 1.0
print(f"c_sample_rate:{c_sample_rate}")

# random.random() is drawn only for center samples, in sample order (same RNG stream as a loop).
samples_new = [samp for samp in samples if samp[2] != 5 or random.random() <= c_sample_rate]

print("len(samples_new):", len(samples_new))
samples = samples_new

c_sample_rate:0.07490864379201737
len(samples_new): 2920047


In [10]:
# Action counts after down-sampling "center".
labels = [label for *_, label in samples]
actions_fullname = ['north', 'south', 'west', 'east', 'bcity', 'center']
for value, count in zip(*np.unique(labels, return_counts=True)):
    print(f'{actions_fullname[value]:^6}: {count:>3}')

north : 651993
south : 640506
 west : 587745
 east : 560672
bcity : 152445
center: 326686


# Training

In [11]:
# Input for Neural Network
def make_input(obs, unit_id):
    '''
    Encode one observation as a (20, 32, 32) float32 feature image for `unit_id`.

    Smaller boards are centred in the 32x32 canvas; indexing is b[channel, x, y].
    Channels:
        0-1   the unit being predicted: presence, cargo (wood+coal+uranium)/100
        2-4   other units of obs['player']'s team: presence, cooldown/6, cargo/100
        5-7   opponent units: presence, cooldown/6, cargo/100
        8-9   own city tiles: presence, min(fuel/lightupkeep, 10)/10 of their city
        10-11 opponent city tiles: same encoding
        12-14 wood / coal / uranium amount / 800
        15-16 research points min(rp, 200)/200 (own, opponent), whole plane
        17    step % 40 / 40 (position in the 40-step day/night cycle), whole plane
        18    step / 360, whole plane
        19    1 inside the real board area, 0 in the padding

    Args:
        obs: dict with 'width', 'height', 'step', 'player' and 'updates'
             (Lux update strings).
        unit_id: id string of the unit the sample is for.

    Returns:
        np.ndarray of shape (20, 32, 32), dtype float32.

    NOTE(review): a 'ct' line reads cities[city_id], which is filled by 'c'
    lines, so this assumes 'c' updates precede 'ct' updates (KeyError otherwise).
    The encoding must stay identical to make_input in agent.py, which feeds the
    trained model at inference time.
    '''
    width, height = obs['width'], obs['height'] # 12, 16, 24, or 32
    # Shift smaller boards so they sit in the centre of the 32x32 canvas
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                # The unit itself: channels 0,1
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                # Own units: channels 2,3,4; opponent units: channels 5,6,7
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                ) # TODO: units stacked on the same tile overwrite each other (last one wins)
        elif input_identifier == 'ct':
            # CityTiles
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs['player']) % 2 * 2
            # Own city tiles: channels 8,9; opponent city tiles: channels 10,11
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id] # city fuel / light upkeep, capped at 10 and scaled to [0, 1]
            )
        elif input_identifier == 'r':
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points
            team = int(strs[1])
            rp = int(strs[2])
            # Own research points: channel 15; opponent research points: channel 16
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle: position within the 40-step cycle, scaled to [0, 1)
    b[17, :] = obs['step'] % 40 / 40
    # Turns: step index divided by 360
    b[18, :] = obs['step'] / 360 
    # Map Size: mask of the real board area inside the padded canvas
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


# Label remapping for an "xy" flip (both axes mirrored): north<->south, west<->east;
# bcity (4) and center (5) are unchanged.
new_action_dict = {0:1, 1:0, 2:3, 3:2, 4:4, 5:5}

# Data augmentation: mirror the coordinates inside observation updates
def modify_updates(mapsize, updates_list, flip="x"):
    '''
    Mirror the coordinates of a list of Lux update strings.

    Args:
        mapsize: largest valid coordinate (map side - 1); a coordinate c becomes mapsize - c.
            The same value is used for x and y.
        updates_list: update strings as found in obs['updates'].
        flip: "x" mirrors x, "y" mirrors y, "xy" mirrors both. Any other value
            leaves coordinates unchanged.

    Returns:
        A new list (the input is not modified), or None when `updates_list` is None.
        Lines of types without coordinates ("rp", "c", ...) are copied verbatim.
    '''
    if updates_list is None:
        return None
    if not updates_list:
        return []

    # Token positions of (x, y) for each update type that carries coordinates.
    coord_fields = {"r": (2, 3), "u": (4, 5), "ct": (3, 4), "ccd": (1, 2)}
    flip_x = flip in ("xy", "x")
    flip_y = flip in ("xy", "y")

    new_updates_list = []
    for status in updates_list:
        status_split = status.split(" ")
        fields = coord_fields.get(status_split[0])
        if fields is None:
            new_updates_list.append(status)
            continue
        x_idx, y_idx = fields
        if flip_x:
            status_split[x_idx] = str(mapsize - int(status_split[x_idx]))
        if flip_y:
            status_split[y_idx] = str(mapsize - int(status_split[y_idx]))
        new_updates_list.append(" ".join(status_split))
    return new_updates_list

In [12]:
class LuxDataset(Dataset):
    '''
    Torch dataset over (obs_id, unit_id, action) samples.

    Each item is (make_input(obs, unit_id), action). When `transform` is truthy,
    with probability `flip_p` the observation is mirrored with `flip_list[0]`
    and the action label is remapped to match.
    '''
    # Action relabelling per flip method (labels: 0 n, 1 s, 2 w, 3 e, 4 bcity, 5 center).
    # Lux convention: w/e move along x and n/s along y, so mirroring x swaps w<->e,
    # mirroring y swaps n<->s, and "xy" swaps both.
    FLIP_ACTION_MAPS = {
        "xy": {0: 1, 1: 0, 2: 3, 3: 2, 4: 4, 5: 5},
        "x": {0: 0, 1: 1, 2: 3, 3: 2, 4: 4, 5: 5},
        "y": {0: 1, 1: 0, 2: 2, 3: 3, 4: 4, 5: 5},
    }

    def __init__(self, obses, samples, transform=False):
        self.obses = obses
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs_id, unit_id, action = self.samples[idx]
        obs = self.obses[obs_id]

        # Data augmentation
        if self.transform and random.random() <= flip_p:
            flip_method = flip_list[0]
            map_size = obs["height"]-1
            # Work on a shallow copy. Writing back into `obs` would mutate the shared
            # self.obses entry, so later samples of the same observation would get a
            # mirrored board paired with an un-mirrored label.
            obs = dict(obs, updates=modify_updates(map_size, obs["updates"], flip_method))
            action = self.FLIP_ACTION_MAPS[flip_method][action]

        # Convert the map state into the image-like network input
        state = make_input(obs, unit_id)
        return state, action

In [13]:
# Neural Network for Lux AI
class BasicConv2d(nn.Module):
    '''
    Same-padded 2D convolution, optionally followed by BatchNorm2d (no activation).
    '''
    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        pad_h, pad_w = kernel_size[0] // 2, kernel_size[1] // 2
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size, padding=(pad_h, pad_w))
        if bn:
            self.bn = nn.BatchNorm2d(output_dim)
        else:
            self.bn = None

    def forward(self, x):
        out = self.conv(x)
        if self.bn is None:
            return out
        return self.bn(out)


class LuxNet(nn.Module):
    '''
    Policy network: a stem conv plus a stack of residual conv blocks over the
    input planes, pooled at the acting unit's position into per-action logits.

    Args:
        layers: number of residual BasicConv2d blocks (default 12).
        filters: channel width of every conv layer (default 32).
        in_channels: number of input planes; must match make_input (20).
        n_actions: number of output logits; defaults to len(actions_fullname).
    '''
    def __init__(self, layers=12, filters=32, in_channels=20, n_actions=None):
        super().__init__()
        if n_actions is None:
            n_actions = len(actions_fullname)
        self.conv0 = BasicConv2d(in_channels, filters, (3, 3), True)
        self.blocks = nn.ModuleList([BasicConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.head_p = nn.Linear(filters, n_actions, bias=False)

    def forward(self, x):
        h = F.relu_(self.conv0(x))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        # Input channel 0 is 1 only at the acting unit's cell, so this masked sum
        # picks out that cell's feature vector: (B, filters).
        h_head = (h * x[:,:1]).view(h.size(0), h.size(1), -1).sum(-1)
        p = self.head_p(h_head)
        return p

In [14]:
# Focal loss (kept for experiments; the final result did not use it)
class MultiFocalLoss(nn.Module):
    """
    Multi-class focal loss ('Focal Loss for Dense Object Detection',
    https://arxiv.org/abs/1708.02002) with optional label smoothing:

        FL = -alpha_t * (1 - p_t) ** gamma * log(p_t)

    :param num_class: number of classes C.
    :param alpha: per-class weights. None -> all ones; list/ndarray of length C ->
                  normalised to sum to 1; float -> `alpha` for class `balance_index`
                  and `1 - alpha` for every other class.
    :param gamma: focusing parameter; gamma > 0 down-weights well-classified examples.
    :param balance_index: class receiving `alpha` when alpha is a float.
    :param smooth: optional value in [0, 1]; the one-hot target is clamped to
                   [smooth, 1 - smooth].
    :param size_average: mean over samples if True, otherwise sum.
    """

    def __init__(self, num_class=6, alpha=0.3, gamma=2, balance_index=5, smooth=None, size_average=True):
        super(MultiFocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        # Normalise `alpha` into a (num_class, 1) weight tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, input, target):
        prob = F.softmax(input, dim=1)

        if prob.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> (N*m, C)
            prob = prob.view(prob.size(0), prob.size(1), -1)
            prob = prob.permute(0, 2, 1).contiguous()
            prob = prob.view(-1, prob.size(-1))
        target = target.view(-1).long()

        epsilon = 1e-10
        alpha = self.alpha.to(prob.device)

        # One-hot target built directly on the right device (no CPU round-trip).
        one_hot_key = torch.zeros(target.size(0), self.num_class, device=prob.device, dtype=prob.dtype)
        one_hot_key.scatter_(1, target.unsqueeze(1), 1.0)
        if self.smooth:
            one_hot_key = torch.clamp(one_hot_key, self.smooth, 1.0 - self.smooth)
        pt = (one_hot_key * prob).sum(1) + epsilon
        logpt = pt.log()

        # One weight per sample, flattened to (N,) to line up with pt. Indexing the
        # (C, 1) alpha with an (N, 1) index gives (N, 1, 1), which broadcasts against
        # (N,) into an (N, 1, N) tensor and mixes every sample's weight with every
        # other sample's loss.
        alpha_t = alpha[target].view(-1)
        loss = -alpha_t * torch.pow(1 - pt, self.gamma) * logpt

        return loss.mean() if self.size_average else loss.sum()
    
# Learning-rate warmup scheduler
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    '''
    Linear learning-rate warmup that hands over to `after_scheduler`.

    While last_epoch <= total_epoch, the lr ramps linearly from base_lr (or from 0
    when multiplier == 1.0) to base_lr * multiplier. After that, the wrapped
    scheduler's base_lrs are rescaled by `multiplier` once, and each further
    get_lr() call steps it and returns its get_lr(). Without an after_scheduler
    the lr stays at base_lr * multiplier.

    Relies on attributes presumably set by warmup_scheduler.GradualWarmupScheduler
    (multiplier, total_epoch, after_scheduler, finished) and the torch LR-scheduler
    base class (last_epoch, base_lrs).
    '''
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        # Pass-through to the parent constructor (positional after_scheduler).
        super(GradualWarmupSchedulerV3, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        '''Return the per-param-group learning rates for the current last_epoch.'''
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # First epoch after warmup: start the wrapped scheduler from the warmed-up lr.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                self.after_scheduler.step()
                # NOTE(review): for chainable schedulers such as CosineAnnealingLR, calling
                # get_lr() right after step() may not equal get_last_lr() (recent torch
                # computes it from the current group lr). Confirm against the installed torch.
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        
        if self.multiplier == 1.0:
            # Ramp from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Ramp from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

In [15]:
model = LuxNet()


if XAVIER_INIT:
    # Xavier-uniform weights and zero biases for every conv (stem first, then each
    # residual block, same order as before), then the linear head. Iterating
    # model.blocks instead of range(12) keeps this correct for any network depth.
    conv_layers = [model.conv0.conv] + [block.conv for block in model.blocks]
    for conv in conv_layers:
        nn.init.xavier_uniform_(conv.weight)
        nn.init.constant_(conv.bias, 0.0)
    nn.init.xavier_uniform_(model.head_p.weight)

## 数据加载 和 loss, optimizer, scheduler

In [16]:
# Stratified train/validation split, data loaders, loss, optimizer and LR schedulers.
if DEBUG:
    samples = samples[:1000]
# Derive the stratification labels from `samples` itself so they always line up,
# even if the label-count cells above were run out of order.
labels = [sample[-1] for sample in samples]

train, val = train_test_split(samples, test_size=TEST_PERCENT, random_state=SEED, stratify=labels)
train_dataset = LuxDataset(obses, train, is_transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS
)

# Validation always uses un-augmented observations: random flips here would make
# the accuracy used for model selection and early stopping noisy.
val_loader = DataLoader(
    LuxDataset(obses, val, False),
    batch_size=BATCH_SIZE*2,
    shuffle=False,
    num_workers=NUM_WORKERS
)
dataloaders_dict = {"train": train_loader, "val": val_loader}

if LOSS_FUNC == "CrossEntropy":
    criterion = nn.CrossEntropyLoss()
elif LOSS_FUNC == "FocalLoss":
    criterion = MultiFocalLoss(num_class=len(actions_fullname),
                               alpha=FOCAL_ALPHA,
                               gamma=FOCAL_GAMMA,
                               balance_index=FOCAL_BAL_IND)
else:
    raise ValueError(f"unsupported LOSS_FUNC: {LOSS_FUNC!r}")

# With warmup, start at INIT_LR / WARMUP_FACTOR; the warmup scheduler (multiplier
# WARMUP_FACTOR) ramps it up to INIT_LR.
optimizer = torch.optim.AdamW(model.parameters(), lr=INIT_LR/WARMUP_FACTOR if SCHEDULER_WARMUP == "GradualWarmupSchedulerV3" else INIT_LR)

if SCHEDULER == "CosineAnnealingLR":
    scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=1e-7)
elif SCHEDULER == "ReduceLROnPlateau":
    # mode 'max': expects an accuracy-like metric.
    scheduler = ReduceLROnPlateau(optimizer, 'max', factor=RLP_FACTOR, patience=RLP_PAT, min_lr=1e-7)
else:
    raise ValueError(f"unsupported SCHEDULER: {SCHEDULER!r}")

if SCHEDULER_WARMUP =="GradualWarmupSchedulerV3":
    scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=WARMUP_FACTOR, total_epoch=WARMUP_EPO, after_scheduler=scheduler)

In [17]:
%%time
best_acc = 0.0
valid_acc_max_cnt = 0
LOGGER.info(f'epoch|times|   lr   |tloss vloss| t_acc   n    s    w    e    b    c | v_acc   n     s     w     e     b     c   ')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    model.cuda()

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        epoch_loss = 0.0
        epoch_acc = 0
        acc_true_list = np.zeros(len(actions_fullname))
        acc_cnt_list  = np.zeros(len(actions_fullname))

        dataloader = dataloaders_dict[phase]
        for item in tqdm(dataloader, leave=False):
            # item存入gpu
            states = item[0].cuda().float()
            actions = item[1].cuda().long()

            optimizer.zero_grad() # 优化器置零
            with torch.set_grad_enabled(phase == 'train'):
                policy = model(states) # 模型训练
                loss = criterion(policy, actions) # 计算loss
                _, preds = torch.max(policy, 1) # 获取预测的action

                if phase == 'train':
                    loss.backward() # 反向传播
                    optimizer.step() # 优化器迭代

                epoch_loss += loss.item() * len(policy) # loss记录
                # accuracy记录
                epoch_true = preds == actions.data 
                epoch_acc += torch.sum(epoch_true)

                # 全局accuracy记录
                for i in range(len(actions_fullname)):
                    acc_true_list[i] += torch.sum((actions.data == i)  & (epoch_true))
                    acc_cnt_list[i]  += torch.sum((actions.data == i))

        data_size = len(dataloader.dataset)

        if phase == 'train':
            train_loss = epoch_loss / data_size
            train_acc_l = [epoch_acc.double() / data_size]
            train_acc_l.extend(acc_true_list/acc_cnt_list)
        elif phase == 'val':
            valid_loss = epoch_loss / data_size
            valid_acc_l = [epoch_acc.double() / data_size]
            valid_acc_l.extend(acc_true_list/acc_cnt_list)


    cur_lr = optimizer.param_groups[0]["lr"] 
    elapsed_time = get_timediff(start_time, time.time())
    LOGGER.info(f'{epoch+1:>2}/{N_EPOCHS}|{elapsed_time}|{cur_lr:.2e}|{train_loss:.3f} {valid_loss:.3f}|{train_acc_l[0]:.4f} {train_acc_l[1]:.2f} {train_acc_l[2]:.2f} {train_acc_l[3]:.2f} {train_acc_l[4]:.2f} {train_acc_l[5]:.2f} {train_acc_l[6]:.2f}|{valid_acc_l[0]:.4f} {valid_acc_l[1]:.3f} {valid_acc_l[2]:.3f} {valid_acc_l[3]:.3f} {valid_acc_l[4]:.3f} {valid_acc_l[5]:.3f} {valid_acc_l[6]:.3f} {"*" if valid_acc_l[0] > best_acc else ""}')

    # 调度器更新
    if SCHEDULER_WARMUP == "GradualWarmupSchedulerV3":
        scheduler_warmup.step()
    elif SCHEDULER_WARMUP in ["CosineAnnealingLR","CosineAnnealingWarmRestarts"]:
        scheduler.step()
    elif SCHEDULER_WARMUP == "ReduceLROnPlateau":
        scheduler.step(epoch_acc)

    # 保存模型和earlystopping
    if valid_acc_l[0] > best_acc:
        traced = torch.jit.trace(model.cpu(), torch.rand(1, 20, 32, 32))
        traced.save(f'{output_dir}/model.pth')
        best_acc = valid_acc_l[0]
        valid_acc_max_cnt=0
    else:
        valid_acc_max_cnt+=1
    
    if valid_acc_max_cnt >= n_early_stopping and not DEBUG:
        LOGGER.info("EarlyStop")
        break

epoch|times|   lr   |tloss vloss| t_acc   n    s    w    e    b    c | v_acc   n     s     w     e     b     c   


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 1/25|12:12|1.00e-04|0.980 0.834|0.6154 0.64 0.64 0.63 0.61 0.59 0.51|0.6756 0.709 0.652 0.699 0.695 0.745 0.545 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 2/25|12:08|1.00e-03|0.819 0.743|0.6809 0.71 0.70 0.70 0.68 0.67 0.56|0.7120 0.769 0.707 0.728 0.683 0.781 0.598 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 3/25|12:09|9.91e-04|0.725 0.706|0.7176 0.74 0.74 0.73 0.72 0.72 0.59|0.7252 0.754 0.748 0.709 0.740 0.739 0.620 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 4/25|12:09|9.77e-04|0.694 0.687|0.7294 0.75 0.75 0.75 0.73 0.74 0.61|0.7325 0.716 0.777 0.737 0.782 0.740 0.583 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 5/25|12:09|9.54e-04|0.676 0.670|0.7362 0.76 0.76 0.75 0.74 0.75 0.61|0.7383 0.721 0.750 0.768 0.803 0.664 0.617 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 6/25|12:09|9.23e-04|0.664 0.665|0.7413 0.76 0.76 0.76 0.75 0.75 0.62|0.7408 0.765 0.776 0.722 0.731 0.766 0.663 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 7/25|12:10|8.84e-04|0.654 0.661|0.7447 0.77 0.76 0.76 0.75 0.76 0.62|0.7422 0.799 0.745 0.728 0.781 0.718 0.594 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 8/25|12:10|8.37e-04|0.646 0.669|0.7478 0.77 0.77 0.76 0.75 0.76 0.63|0.7393 0.736 0.785 0.753 0.721 0.836 0.617 


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

 9/25|12:09|7.85e-04|0.639 0.645|0.7506 0.77 0.77 0.77 0.76 0.77 0.63|0.7472 0.781 0.724 0.771 0.763 0.769 0.646 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

10/25|12:10|7.27e-04|0.631 0.636|0.7535 0.77 0.77 0.77 0.76 0.77 0.63|0.7514 0.759 0.751 0.782 0.764 0.780 0.649 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

11/25|12:09|6.64e-04|0.625 0.638|0.7563 0.78 0.77 0.77 0.76 0.77 0.64|0.7502 0.770 0.767 0.776 0.751 0.810 0.602 


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

12/25|12:09|5.99e-04|0.618 0.630|0.7588 0.78 0.78 0.78 0.77 0.77 0.64|0.7528 0.776 0.765 0.770 0.735 0.831 0.645 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

13/25|12:02|5.32e-04|0.610 0.623|0.7619 0.78 0.78 0.78 0.77 0.78 0.64|0.7565 0.766 0.797 0.774 0.725 0.783 0.668 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

14/25|12:09|4.64e-04|0.603 0.617|0.7648 0.78 0.78 0.78 0.77 0.78 0.65|0.7592 0.775 0.773 0.779 0.765 0.756 0.659 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

15/25|12:09|3.96e-04|0.595 0.615|0.7679 0.79 0.78 0.78 0.78 0.78 0.65|0.7602 0.792 0.740 0.778 0.782 0.782 0.657 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

16/25|12:10|3.31e-04|0.587 0.607|0.7712 0.79 0.79 0.79 0.78 0.79 0.65|0.7635 0.794 0.774 0.761 0.788 0.796 0.629 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

17/25|12:10|2.69e-04|0.579 0.601|0.7745 0.79 0.79 0.79 0.78 0.79 0.66|0.7659 0.762 0.782 0.774 0.792 0.767 0.683 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

18/25|12:10|2.11e-04|0.570 0.596|0.7780 0.79 0.79 0.79 0.79 0.79 0.66|0.7674 0.789 0.784 0.793 0.766 0.775 0.646 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

19/25|12:10|1.58e-04|0.562 0.590|0.7812 0.80 0.80 0.80 0.79 0.80 0.67|0.7710 0.793 0.780 0.769 0.782 0.802 0.679 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

20/25|12:10|1.12e-04|0.555 0.584|0.7843 0.80 0.80 0.80 0.79 0.80 0.67|0.7733 0.796 0.791 0.792 0.778 0.759 0.658 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

21/25|12:10|7.25e-05|0.548 0.580|0.7873 0.80 0.80 0.80 0.79 0.81 0.67|0.7748 0.784 0.795 0.793 0.778 0.802 0.666 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

22/25|12:10|4.13e-05|0.542 0.577|0.7894 0.81 0.80 0.80 0.80 0.81 0.68|0.7758 0.789 0.796 0.796 0.777 0.797 0.662 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

23/25|12:10|1.86e-05|0.538 0.576|0.7912 0.81 0.81 0.81 0.80 0.81 0.68|0.7764 0.797 0.789 0.788 0.783 0.797 0.667 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

24/25|12:10|4.73e-06|0.535 0.575|0.7923 0.81 0.81 0.81 0.80 0.81 0.68|0.7766 0.794 0.788 0.790 0.779 0.803 0.678 *


  0%|          | 0/20532 [00:00<?, ?it/s]

  0%|          | 0/1141 [00:00<?, ?it/s]

25/25|12:10|1.00e-07|0.534 0.574|0.7928 0.81 0.81 0.81 0.80 0.81 0.68|0.7770 0.795 0.790 0.792 0.783 0.799 0.669 *


CPU times: user 4h 54min 32s, sys: 10min 48s, total: 5h 5min 20s
Wall time: 5h 4min 8s


# Submission

## 生成 agent.py 用于submission

In [18]:
%%writefile agent.py
import os
import numpy as np
import torch
from lux.game import Game


# Kaggle mounts the submission under /kaggle_simulations/agent; locally use the cwd.
path = '.'
if os.path.exists('/kaggle_simulations'):
    path = '/kaggle_simulations/agent'
model = torch.jit.load(f'{path}/model.pth')
model.eval()


def make_input(obs, unit_id):
    '''
    Encode one observation as a (20, 32, 32) float32 feature image for `unit_id`.

    Must produce exactly the encoding the loaded model was trained on. Smaller
    boards are centred in the 32x32 canvas; indexing is b[channel, x, y].
    Channels:
        0-1   the unit being predicted: presence, cargo (wood+coal+uranium)/100
        2-4   other units of obs['player']'s team: presence, cooldown/6, cargo/100
        5-7   opponent units: presence, cooldown/6, cargo/100
        8-9   own city tiles: presence, min(fuel/lightupkeep, 10)/10 of their city
        10-11 opponent city tiles: same encoding
        12-14 wood / coal / uranium amount / 800
        15-16 research points min(rp, 200)/200 (own, opponent), whole plane
        17    step % 40 / 40, whole plane
        18    step / 360, whole plane
        19    1 inside the real board area, 0 in the padding

    NOTE(review): 'ct' lines read cities[city_id], filled by 'c' lines, so this
    assumes 'c' updates precede 'ct' updates in obs['updates'].
    '''
    width, height = obs['width'], obs['height']
    # Offsets that centre the board in the 32x32 canvas
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    cities = {}
    
    b = np.zeros((20, 32, 32), dtype=np.float32)
    
    for update in obs['updates']:
        strs = update.split(' ')
        input_identifier = strs[0]
        
        if input_identifier == 'u':
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo (channels 0,1)
                b[:2, x, y] = (
                    1,
                    (wood + coal + uranium) / 100
                )
            else:
                # Units: own team -> channels 2-4, opponent -> channels 5-7
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs['player']) % 2 * 3
                b[idx:idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100
                )
        elif input_identifier == 'ct':
            # CityTiles: own -> channels 8,9, opponent -> channels 10,11
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs['player']) % 2 * 2
            b[idx:idx + 2, x, y] = (
                1,
                cities[city_id]
            )
        elif input_identifier == 'r':
            # Resources (channels 12-14)
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{'wood': 12, 'coal': 13, 'uranium': 14}[r_type], x, y] = amt / 800
        elif input_identifier == 'rp':
            # Research Points: own -> channel 15, opponent -> channel 16
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs['player']) % 2, :] = min(rp, 200) / 200
        elif input_identifier == 'c':
            # Cities: fuel / light upkeep, capped at 10 and scaled to [0, 1]
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10
    
    # Day/Night Cycle
    b[17, :] = obs['step'] % 40 / 40
    # Turns
    b[18, :] = obs['step'] / 360
    # Map Size
    b[19, x_shift:32 - x_shift, y_shift:32 - y_shift] = 1

    return b


game_state = None
def get_game_state(observation):
    '''
    Return the module-level Game: build it on step 0, otherwise apply this step's updates.
    '''
    global game_state

    updates = observation["updates"]
    if observation["step"] != 0:
        game_state._update(updates)
        return game_state

    game_state = Game()
    game_state._initialize(updates)
    game_state._update(updates[2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    '''
    Return True if `pos` is a city tile owned by this agent (game_state.id).

    Returns False for positions with a negative coordinate, and (best effort)
    whenever the map lookup fails, e.g. before game_state is initialised.
    '''
    # A move off the west/north edge yields x == -1 or y == -1. If the map lookup
    # indexes Python lists (presumably; TODO confirm in lux.game_map), a negative
    # index would silently wrap to the opposite edge instead of failing.
    if pos.x < 0 or pos.y < 0:
        return False
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Best effort, as before, but no longer swallows KeyboardInterrupt/SystemExit.
        return False


def call_func(obj, method, args=[]):
    return getattr(obj, method)(*args)


unit_actions = [('move', 'n'), ('move', 's'), ('move', 'w'), ('move', 'e'), ('build_city',), ('move', 'c')]
def get_action(policy, unit, dest):
    '''
    Pick the highest-scoring action in `policy` whose target tile is usable.

    Actions are tried in descending policy order. An action is accepted when its
    target tile is not in `dest` (targets already claimed this turn) or is one of
    our own city tiles. Falls back to staying in place.

    Returns:
        (action, target position)
    '''
    ranked_labels = np.argsort(policy)[::-1]
    for label in ranked_labels:
        candidate = unit_actions[label]
        # Presumably translate() returns None for 'build_city'; then target the unit's own tile.
        target = unit.pos.translate(candidate[-1], 1) or unit.pos
        if target in dest and not in_city(target):
            continue
        return call_func(unit, *candidate), target

    return unit.move('c'), unit.pos


def agent(observation, configuration):
    '''
    Kaggle entry point: return this turn's list of actions.

    City tiles build workers until there is one unit per city tile, otherwise
    research until uranium is researched. Units that can act are driven by the
    policy network. During the last 10 steps of each 40-step cycle, units that
    are already on one of our city tiles are left idle.
    '''
    global game_state

    game_state = get_game_state(observation)
    player = game_state.players[observation.player]
    actions = []

    # City actions
    n_units = len(player.units)
    all_city_tiles = (tile for city in player.cities.values() for tile in city.citytiles)
    for city_tile in all_city_tiles:
        if not city_tile.can_act():
            continue
        if n_units < player.city_tile_count:
            actions.append(city_tile.build_worker())
            n_units += 1
        elif not player.researched_uranium():
            actions.append(city_tile.research())
            player.research_points += 1

    # Worker actions
    dest = []
    for unit in player.units:
        if not unit.can_act():
            continue
        if game_state.turn % 40 >= 30 and in_city(unit.pos):
            continue
        state = make_input(observation, unit.id)
        with torch.no_grad():
            logits = model(torch.from_numpy(state).unsqueeze(0))
        action, pos = get_action(logits.squeeze(0).numpy(), unit, dest)
        actions.append(action)
        dest.append(pos)

    return actions

Writing agent.py


## 测试是否work，并打包

In [19]:
# Move the generated agent next to model.pth and the copied kit, then run a local
# self-play game as a smoke test. shutil.move raises on failure, whereas a failing
# `!mv` only prints an error. An explicit destination *file* makes re-runs overwrite
# the previous agent.py; with a bare directory destination, shutil.move refuses when
# the file already exists. It is also a no-op when the cwd is already output_dir.
shutil.move("agent.py", os.path.join(output_dir, "agent.py"))
os.chdir(output_dir)

from kaggle_environments import make

env = make("lux_ai_2021", configuration={"width": 12, "height": 12, "loglevel": 2, "annotations": True}, debug=False)
steps = env.run(['agent.py', 'agent.py'])
env.render(mode="ipython", width=950, height=800)

Loading environment football failed: No module named 'gfootball'


In [20]:
# Pack everything in output_dir (model.pth, agent.py, the copied simple/ files, logs)
# into the submission archive. This uses tarfile instead of `!tar -czf $zipname *`,
# so it does not depend on the working directory set by an earlier cell and never
# packs a previous archive into the new one on re-runs.
zipname = f"{suffix}_submission.tar.gz"
zip_path = os.path.join(output_dir, zipname)
with tarfile.open(zip_path, "w:gz") as tar:
    for entry in sorted(os.listdir(output_dir)):
        # Mirror the shell glob (no hidden entries) and skip the archive itself.
        if entry.startswith(".") or entry == zipname:
            continue
        tar.add(os.path.join(output_dir, entry), arcname=entry)