# Off-policy reinforcement learning approach


* Please see my [Kaggle post](https://www.kaggle.com/c/lux-ai-2021/discussion/294601) or my [GitHub repository](https://github.com/Fkaneko/kaggle_lux_ai) for more details.
![model](https://user-images.githubusercontent.com/61892693/145679319-db3b0f4e-e1eb-449d-94f8-c048d1bf2f79.png)
* Learning results
  - evaluated every 100 epochs
  - each evaluation plays 50 episodes against the [imitation learning baseline model](https://www.kaggle.com/realneuralnetwork/lux-ai-with-il-decreasing-learning-rate) (score ~1350), without internal unit-move resolution
  
  
| epoch | win rate | tile diff |
| --- | --- | --- |
|100 |  0.058 |   -28.4|
|200 |  0.212 |   -13.0|
|300 |  0.442 |    -0.0|
|400 |  0.615 |    10.4|
|500 |  0.731 |    15.4|
|600 |  0.673 |     9.7|
|700 |  0.712 |    14.1|
|800 |  0.750 |    15.1|

In [None]:
!pip install kaggle-environments -U > /dev/null 2>&1
!cp -r ../input/lux-ai-2021/* .

In [None]:
# !cp  ../input/lux-jit-models/sub_models/*.pth .
!cp ../input/lux-ai-off-policy-models/lux_ai_off_policy_models/*.onnx .
!cp ../input/onnxruntime-wheel-for-lux/onnxruntime_wheel/onnxruntime-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl .

# Load opponent: imitation learning model (score ~1350)
* Uses the imitation baseline from https://www.kaggle.com/huikang/lux-ai-agent-evaluation

In [None]:
!cp ../input/lux-ai-with-il-decreasing-learning-rate/model.pth .
!cp ../input/lux-ai-with-il-decreasing-learning-rate/agent.py  ./agent_imitation_baseline.py

# Agent (written to agent.py)

In [None]:
%%writefile agent.py
import copy
import json
import math
import os
import random
import subprocess
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import numpy as np
import skimage.measure
import skimage.transform
import torch
from lux.game import Game

# Submissions run from /kaggle_simulations/agent; fall back to the current
# directory when running locally / inside the notebook.
path = "/kaggle_simulations/agent" if os.path.exists("/kaggle_simulations") else "."
# Install onnxruntime from the wheel shipped next to the agent. Pattern from
# https://stackoverflow.com/questions/4256107/running-bash-commands-in-python
onnxruntime_wheel_path = os.path.join(
    path, "onnxruntime-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
)
bashCommand = "pip install --disable-pip-version-check" + " " + onnxruntime_wheel_path
# NOTE(review): the command string is split on whitespace, so a path containing
# spaces would break it. pip's stdout/stderr are captured but never inspected;
# an install failure only surfaces as an ImportError on the import below.
process = subprocess.Popen(
    bashCommand.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
output, error = process.communicate()
# Imported here (not with the other imports) because it is installed just above.
import onnxruntime as ort

# Policy network; the notebook copies the checkpoint to evaluate to this name.
model_path = os.path.join(path, "model_rl.onnx")  # loaded once, at import time
model = ort.InferenceSession(model_path)

# (dx, dy) offsets for the move actions n, s, w, e -- the same order as the
# "move" entries of ``unit_actions``; y grows southwards.
ACTION_DIRECTIONS = [(0, -1), (0, +1), (-1, 0), (1, 0)]


# Ordering of the units fed to the policy for one turn:
#   step          -- the game turn the sequence was built for
#   order_strings -- unit ids, in the same order as the model's sequence rows
UnitOrder = NamedTuple(
    "UnitOrder",
    [("step", int), ("order_strings", List[str])],
)


def get_cargo_fuel_value(unit):
    """
    Return the total fuel value of everything in ``unit``'s cargo.

    Uses the unit's own ``get_cargo_fuel_value()`` when it has one; otherwise
    weights the cargo as wood x1, coal x10, uranium x40.
    """
    if hasattr(unit, "get_cargo_fuel_value"):
        return unit.get_cargo_fuel_value()

    cargo = unit.cargo
    weighted = ((cargo.wood, 1), (cargo.coal, 10), (cargo.uranium, 40))
    return sum(amount * rate for amount, rate in weighted)


def to_numpy(tensor):
    """Convert a torch tensor or any array-like to a NumPy array.

    Tensors are detached from the autograd graph and moved to CPU first
    (``detach()`` is harmless for tensors that do not require grad).
    NumPy arrays are returned as-is (same object). Any other array-like
    (lists, scalars, ...) goes through ``np.asarray`` instead of silently
    yielding ``None``.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return np.asarray(tensor)



def get_unit_sequence_obs(
    game,
    player: int,
    b_active: np.ndarray,
    can_act_units: Dict[str, Any],
    turn: int = 0,
    unit_length: int = 64,
    action_dim: int = 5,
    input_dim=4,
):
    """
    Build the policy observation for every unit that can act this turn.

    Args:
        game: current Game object.
        player: player index (currently unused).
        b_active: (channel, y, x) feature stack, as produced by ``make_input``.
        can_act_units: unit id -> unit, for the units that can act.
        turn: game turn, stored in the returned ``UnitOrder``.
        unit_length: fixed sequence length (padded / truncated).
        action_dim: number of action entries per unit in ``rule_mask``.
        input_dim: number of per-unit input features in ``input_sequence``.

    Returns:
        ``(obs, unit_order)``: ``obs`` maps the model input names ("image",
        "input_sequence", "sequence_mask", "rule_mask") to arrays;
        ``unit_order.order_strings`` lists unit ids in sequence order.
    """
    if len(can_act_units) == 0:
        obs = {
            "image": b_active.astype(np.float32),
            "input_sequence": np.zeros((unit_length, input_dim), dtype=np.float32),
            "sequence_mask": np.zeros((unit_length, 1), dtype=np.int64),
            "rule_mask": np.zeros((unit_length, action_dim), dtype=np.int64),
        }
        return obs, UnitOrder(step=turn, order_strings=[])

    _, ordered_units = order_unit(
        units=list(can_act_units.values()),
        is_debug=False,
        random_start=False,
    )

    (
        _,
        sequence_mask,
        input_sequence,
        _,
        order_strings,
        action_masks,
    ) = generate_sequence(
        game=game,
        state=b_active,
        can_act_units=can_act_units,
        ordered_units=ordered_units,
        max_sequence=unit_length,
        input_size=b_active.shape[1:],
        # Keep the feature count consistent with the empty-sequence branch.
        in_features=input_dim,
        actions=None,
        action_length=action_dim,
    )

    obs = {
        "image": b_active.astype(np.float32),
        "input_sequence": input_sequence.astype(np.float32),
        "sequence_mask": sequence_mask.astype(np.int64),
        # Padded rows carry -1 here (see generate_sequence).
        "rule_mask": action_masks.astype(np.int64),
    }
    return obs, UnitOrder(step=turn, order_strings=order_strings)


def order_points(points, units, ind):
    """Greedy nearest-neighbour ordering starting at ``points[ind]``.

    ``points`` (x, y) tuples and ``units`` are parallel lists. Both are
    consumed in place (they are empty on return). Each step appends the
    remaining point closest (Euclidean) to the last one chosen, together with
    its unit. Ties go to the earliest remaining point.

    Returns:
        ``(ordered_points, ordered_units)``.
    """
    ordered_points = [points.pop(ind)]
    ordered_units = [units.pop(ind)]
    while points:
        last = np.array(ordered_points[-1])
        distances = np.linalg.norm(np.array(points) - last, axis=1)
        nearest = distances.argmin()
        ordered_points.append(points.pop(nearest))
        ordered_units.append(units.pop(nearest))
    return ordered_points, ordered_units


def order_unit(
    units: List[Any],
    is_debug: bool = False,
    random_start: bool = False,
    split_proba: float = 0.1,
) -> Tuple[List[Tuple[int, int]], List[Any]]:
    """
    Order units along a greedy nearest-neighbour path over their positions.

    Adapted from
    https://stackoverflow.com/questions/37742358/sorting-points-to-form-a-continuous-line

    Units are first sorted row-major by (y, x); the path starts at the first of
    them, or at a random one when ``random_start`` is set. With
    ``random_start`` the path is also reversed and, with probability
    ``split_proba``, rotated at a random index. The same reordering is applied
    to both returned lists, so they stay paired.

    Returns:
        ``(points, units)``: (x, y) positions and the matching units in path
        order; ``([], [])`` for no units.
    """
    ordered_units = sorted(units, key=lambda x: (x.pos.y, x.pos.x))
    points = [(unit.pos.x, unit.pos.y) for unit in ordered_units]
    if not points:
        # order_points needs at least one point to start from.
        return [], []

    start_ind = 0
    if random_start:
        start_ind = random.choice(range(len(points)))
    points_new, units_new = order_points(
        points.copy(), units=ordered_units.copy(), ind=start_ind
    )

    if random_start:
        # Reorder both lists identically so each point keeps its unit.
        points_new = points_new[::-1]
        units_new = units_new[::-1]
        if random.random() < split_proba:
            split_ind = random.choice(range(len(points)))
            points_new = points_new[split_ind:] + points_new[:split_ind]
            units_new = units_new[split_ind:] + units_new[:split_ind]

    if is_debug:
        # NOTE: ``plt`` (matplotlib.pyplot) is not imported in this module;
        # it must be made available before calling with is_debug=True.
        fig, ax = plt.subplots(1, 2, figsize=(10, 4))
        points = np.stack(points, axis=0)
        xn, yn = np.array(points_new).T
        ax[0].plot(points[:, 0], points[:, 1], "-o")  # original (sorted) points
        ax[1].plot(xn, yn, "-o")  # new (ordered) points
        ax[0].set_title("Original")
        ax[1].set_title("Ordered")
        ax[0].grid()
        ax[1].grid()
        plt.tight_layout()
        plt.show()
    return points_new, units_new


def generate_sequence(
    game: Any,
    state: np.ndarray,
    can_act_units: Dict[str, Any],
    # ordered_unit_points: List[Tuple[int, int]],
    ordered_units: List[Any],
    input_size: Tuple[int, int] = (32, 32),
    max_sequence: int = 128,
    no_action: int = 6,
    ignore_class_index: int = 20,
    in_features: int = 4,
    actions: Optional[Dict[str, int]] = None,
    action_length: int = 5,
):
    """Build the padded per-unit sequence consumed by the policy model.

    Each unit in ``ordered_units`` yields one row
    ``[x, y, cargo_space_left, cargo_fuel_value, action]``. Here (x, y) is the
    unit position shifted onto the ``input_size`` canvas, and ``action`` is
    popped from ``actions`` (which is therefore consumed) or defaults to
    ``no_action``. Rows beyond ``max_sequence`` are dropped. Shorter sequences
    are padded with -1.

    Returns:
        ``(orig_length, sequence_mask, input_sequence, output_sequence,
        unit_order, action_masks)``:
          * orig_length: number of real (unpadded) rows,
          * sequence_mask: (max_sequence, 1), 1 for real rows, 0 for padding,
          * input_sequence: first ``in_features`` columns divided by 100,
          * output_sequence: remaining columns, padding replaced by
            ``ignore_class_index``,
          * unit_order: unit ids in row order (NOT truncated),
          * action_masks: (max_sequence, action_length) legal-action mask
            (moves n/s/w/e then build_city), padded with -1.

    Raises:
        AssertionError: a unit is not marked in channel 2 of ``state``, or
            ``actions`` contains ids that were not sequenced.
        ValueError: ``ordered_units`` is empty (nothing to stack).
    """
    if actions is None:
        actions = {}

    # The canvas offset is the same for every unit: compute it once.
    y_shift, x_shift = calc_game_map_shift(input_size=input_size, game=game)
    ban_map = generate_pos_ban_map(b_active=state)

    unit_feature = []
    unit_order = []
    action_masks = []
    for unit in ordered_units:
        unit_id = unit.id
        act_unit = can_act_units[unit_id]
        pos_x = unit.pos.x + x_shift
        pos_y = unit.pos.y + y_shift
        unit_feature.append(
            [
                pos_x,
                pos_y,
                act_unit.get_cargo_space_left(),
                get_cargo_fuel_value(act_unit),
                actions.pop(unit_id, no_action),
            ]
        )
        # Channel 2 marks our units' positions; the unit must be there.
        assert state[2, pos_y, pos_x] == 1.0
        unit_order.append(unit_id)

        action_mask = np.ones((action_length,), dtype=np.int64)
        # Index 4 is build_city, indices 0-3 the n/s/w/e moves.
        action_mask[4] = int(act_unit.can_build(game.map))
        action_mask[:4] = not_go_pos(pos_y=pos_y, pos_x=pos_x, ban_map=ban_map)
        action_masks.append(action_mask)

    # Every supplied action must belong to a sequenced unit.
    assert len(actions) == 0
    unit_feature = np.stack(unit_feature, axis=0)[:max_sequence]
    action_masks = np.stack(action_masks, axis=0)[:max_sequence]

    orig_length = unit_feature.shape[0]
    pad_len = max_sequence - orig_length
    if pad_len > 0:
        unit_feature = np.pad(unit_feature, [[0, pad_len], [0, 0]], constant_values=-1)
        action_masks = np.pad(action_masks, [[0, pad_len], [0, 0]], constant_values=-1)
    sequence_mask = np.zeros_like(unit_feature[:, 0:1])
    sequence_mask[:orig_length, :] = 1.0

    input_sequence = unit_feature[:, :in_features] / 100.0
    # A view into unit_feature: padded targets (-1) become the ignore index.
    output_sequence = unit_feature[:, in_features:]
    output_sequence[output_sequence == -1] = ignore_class_index

    return (
        orig_length,
        sequence_mask,
        input_sequence,
        output_sequence,
        unit_order,
        action_masks,
    )


def calc_game_map_shift(input_size: List[int], game: Any) -> Tuple[int, int]:
    """Return the ``(y_shift, x_shift)`` that centres the game map on an
    ``input_size`` = (height, width) canvas (floor division for odd gaps)."""
    canvas_height, canvas_width = input_size[0], input_size[1]
    return (
        (canvas_height - game.map.height) // 2,
        (canvas_width - game.map.width) // 2,
    )


def crop_state(state: np.ndarray, game: Any, input_size=(32, 32)):
    """Cut the real game map out of a padded ``input_size`` canvas.

    Args:
        state: (y, x) map or (channel, y, x) stack on the padded canvas.
        game: Game whose ``map.height`` / ``map.width`` give the map size.
        input_size: (height, width) of the canvas.

    Returns:
        ``state`` cropped to ``(..., map.height, map.width)``. Arrays that are
        neither 2-D nor 3-D are returned unchanged.
    """
    y_shift, x_shift = calc_game_map_shift(input_size=input_size, game=game)
    if state.ndim not in (2, 3):
        return state
    # Rows are y (bounded by the map height), columns are x (map width).
    return state[
        ...,
        y_shift : y_shift + game.map.height,
        x_shift : x_shift + game.map.width,
    ]


def pred_with_onnx(model: ort.InferenceSession, obs: Dict[str, np.ndarray]):
    """Run the ONNX policy on one observation and return masked action logits.

    Every model input is looked up in ``obs`` by name and given a leading batch
    axis. The first model output has its batch axis squeezed. Actions whose
    ``obs["rule_mask"]`` entry is 0 get 1e32 subtracted, so argmax avoids
    them whenever some other action is allowed.
    """
    feeds = {
        node.name: to_numpy(obs[node.name])[np.newaxis]
        for node in model.get_inputs()
    }
    logits = model.run(None, feeds)[0]
    forbidden = obs["rule_mask"] == 0
    return (logits - forbidden * 1e32).squeeze(0)


def get_resource_distribution(
    b_active: np.ndarray, game: Any
) -> Tuple[np.ndarray, np.ndarray]:
    """Fuel-weighted resource map of the game area and its local average.

    Resource channels 12-14 (wood / coal / uranium) are weighted by their fuel
    rates 1 / 10 / 40. Coal counts only once our research points (channel 15,
    stored as rp / 200) reach 50, and uranium only once they reach 200.

    Args:
        b_active: (channel, y, x) feature stack on the 32x32 canvas.
        game: Game providing the map size.

    Returns:
        ``(res_avg_map, resource_map)``. Both are (1, map.height, map.width).
        ``res_avg_map`` is the 5x5 mean of ``resource_map`` (stride 1,
        padding 2).
    """
    input_size = (32, 32)
    y_shift, x_shift = calc_game_map_shift(input_size=input_size, game=game)

    research_point = int(b_active[[15]].max() * 200)
    research_mask = [True, research_point >= 50, research_point >= 200]

    fuel_rate = [1, 10, 40]
    resource_ch = [12, 13, 14]
    resource_map = np.zeros_like(b_active[[12]])
    for channel, rate, unlocked in zip(resource_ch, fuel_rate, research_mask):
        if unlocked:
            resource_map += b_active[[channel]] * rate

    # Crop rows by the map height and columns by the map width so the result
    # lines up with the (height, width) city maps it is multiplied with.
    resource_map = resource_map[
        :,
        y_shift : y_shift + game.map.height,
        x_shift : x_shift + game.map.width,
    ]

    res_avg_map = torch.nn.functional.avg_pool2d(
        input=torch.from_numpy(resource_map), kernel_size=5, stride=1, padding=2
    ).numpy()

    return res_avg_map, resource_map


def get_act_cities_map(player_cities: list, game_map: Any):
    """Locate the player's city tiles that can act this turn.

    Two city representations are accepted: objects exposing ``city_cells``
    (each cell holding a ``city_tile``) and objects exposing ``citytiles``.

    Returns:
        ``(act_cities_map, posyx2tile)``. ``act_cities_map`` is a float32
        (1, height, width) map, 1.0 on actionable tiles. ``posyx2tile`` maps
        ``y -> {x: city_tile}`` for those same tiles.
    """
    act_cities_map = np.zeros((1, game_map.height, game_map.width), dtype=np.float32)
    posyx2tile = {}

    for city in player_cities:
        if hasattr(city, "city_cells"):
            tiles = (cell.city_tile for cell in city.city_cells)
        else:
            tiles = city.citytiles
        for tile in tiles:
            if not tile.can_act():
                continue
            y, x = tile.pos.y, tile.pos.x
            act_cities_map[:, y, x] = 1.0
            posyx2tile.setdefault(y, {})[x] = tile

    return act_cities_map, posyx2tile


def decide_worker_gen_place(
    res_avg_map: np.ndarray,
    act_cities_map_max: np.ndarray,
    resource_map: np.ndarray,
    num_units: int,
    num_city_tiles: int,
):
    """Choose the city tiles on which to build new workers, best first.

    Actionable city tiles (``act_cities_map_max``) are ranked by the averaged
    resource value under them (``res_avg_map``). If none of them sees a
    positive value, a coarser map is used instead: ``resource_map`` is
    block-averaged with blocks of about a third of the map per side, then
    resized back to its original spatial shape.

    Returns:
        At most ``num_city_tiles - num_units`` (y, x) rows (a negative
        difference slices from the end), or an empty list.
    """
    budget = num_city_tiles - num_units
    candidates = serarch_with_pooled_feats(res_avg_map, act_cities_map_max)
    if len(candidates) > 0:
        return candidates[:budget]

    spatial_shape = resource_map.shape[1:]
    block = (resource_map.shape[1] // 3, resource_map.shape[2] // 3)
    coarse = skimage.measure.block_reduce(resource_map.squeeze(), block, np.mean)
    coarse = skimage.transform.resize(
        coarse,
        tuple(spatial_shape),
        order=None,
        mode="constant",
        clip=True,
        preserve_range=False,
    )
    return serarch_with_pooled_feats(coarse, act_cities_map_max)[:budget]


def serarch_with_pooled_feats(res_avg_map, act_cities_map_max):
    """Rank the cells where ``res_avg_map * act_cities_map_max`` is positive.

    Returns:
        An (N, 2) array of (row, col) coordinates sorted by that product,
        largest first (the reverse of ``np.argsort``), or an empty list when no
        cell is positive.
    """
    score = (res_avg_map * act_cities_map_max).squeeze()
    hits = np.where(score > 0)
    if hits[0].shape[0] == 0:
        return []
    coords = np.stack(hits, axis=-1)
    ranking = np.argsort(score[hits])[::-1]
    return coords[ranking]


def check_action_plan(
    action_code: int,
    our_city: np.ndarray,
    pos_x: int,
    pos_y: int,
    current_plan: np.ndarray,
    is_center: bool = False,
):
    """Reserve the cell a unit's action leads to and detect collisions.

    Args:
        action_code: 0-3 = move n/s/w/e (indexes ``ACTION_DIRECTIONS``),
            4 = build_city (the unit stays in place).
        our_city: boolean (y, x) map of our city tiles.
        pos_x, pos_y: the unit's current position.
        current_plan: boolean (y, x) map, True where no unit has claimed the
            cell yet. Updated in place.
        is_center: treat the action as staying in place.

    A move that would leave ``our_city``'s bounds is treated as staying put.
    If the target cell is free or one of our city tiles, the unit claims it
    (city tiles are never marked as taken). Otherwise the unit's current cell
    is marked taken and, for a move, the unit is told to stay instead.

    Returns:
        ``(current_plan, use_cooldown_as_center)``.
    """
    # Accept Python ints and NumPy integer scalars (np.argmax returns the latter).
    assert isinstance(action_code, (int, np.integer)), type(action_code)
    is_center = is_center or (action_code == 4)

    use_cooldown_as_center = False
    if is_center:
        pos_y_next = pos_y
        pos_x_next = pos_x
    else:
        direc = ACTION_DIRECTIONS[action_code]
        pos_y_next = direc[1] + pos_y
        pos_x_next = direc[0] + pos_x
        # Off-map moves (e.g. from a random agent) are treated as staying put.
        if (
            (pos_x_next >= our_city.shape[1])
            or (pos_y_next >= our_city.shape[0])
            or (pos_x_next < 0)
            or (pos_y_next < 0)
        ):
            pos_y_next = pos_y
            pos_x_next = pos_x

    is_no_unit = current_plan[pos_y_next, pos_x_next]
    is_citytile = our_city[pos_y_next, pos_x_next]
    is_ok = is_no_unit or is_citytile
    if is_ok:
        if not is_citytile:
            current_plan[pos_y_next, pos_x_next] = False
    else:
        current_plan[pos_y, pos_x] = False
        if not is_center:
            use_cooldown_as_center = True

    return current_plan, use_cooldown_as_center


def check_is_center_action(action_code: int):
    """Return whether ``action_code`` is the explicit "stay in place" action (5)."""
    center_code = 5
    return action_code == center_code


def generate_pos_ban_map(b_active: np.ndarray):
    """Boolean (y, x) map of cells a unit must not move into.

    A cell is banned when any of the following holds:
      * it lies outside the real map (channel 19 == 0),
      * it holds an opponent city tile (channel 10 > 0),
      * a unit whose cooldown channel is positive sits on it -- ours
        (channel 3) or the opponent's (channel 6) -- unless the cell is one
        of our own city tiles (channel 8 > 0).
    """
    own_city = b_active[8] > 0
    blocked_by_unit = (b_active[3] > 0) | (b_active[6] > 0)
    blocked_by_unit[own_city] = False
    outside_or_enemy_city = (b_active[19] == 0) | (b_active[10] > 0)
    return outside_or_enemy_city | blocked_by_unit


class StateHist:
    """Per-player rolling history features built from successive feature stacks.

    Tracked on an ``input_size`` = (height, width) canvas:
      * ``unit_hists``: the last ``unit_hist_length`` unit maps
        (our units + 0.5 * opponent units, from channels 2 and 5),
      * ``resource_reduction_map``: current resources (channels 12-14 summed)
        relative to the amount seen at the first step,
      * ``resource_delta``: drop of that ratio since the previous step,
      * ``player_unit_track`` / ``opponent_unit_track``: unit presence
        accumulated at 1/360 per step, clipped to [0, 1].
    """

    def __init__(
        self, input_size: Optional[List[int]] = None, unit_hist_length: int = 4
    ):
        # A fresh list per instance: no shared mutable default, and tuples are
        # accepted too (``reset`` concatenates this with a list).
        self.input_size = [32, 32] if input_size is None else list(input_size)
        self.unit_hist_length = unit_hist_length
        self.reset()

    def reset(self):
        """Clear all history and restart the step counter at 0."""
        self.unit_hists = np.zeros(
            [self.unit_hist_length] + self.input_size, dtype=np.float32
        )

        self.initial_resource_map = np.zeros(self.input_size, dtype=np.float32)
        self.resource_reduction_map = np.zeros(self.input_size, dtype=np.float32)
        self.resource_delta = np.zeros(self.input_size, dtype=np.float32)

        self.opponent_unit_track = np.zeros(self.input_size, dtype=np.float32)
        self.player_unit_track = np.zeros(self.input_size, dtype=np.float32)
        self.step = 0

    def initialize(self, initial_state: np.ndarray):
        """Record the first-step resource map as the normalisation reference."""
        self.initial_resource_map = self._extract_resource(b_active=initial_state)
        # ~1 where resources exist, 0 elsewhere.
        self.resource_reduction_map = self.initial_resource_map / (
            self.initial_resource_map + 1e-6
        )

    def _extract_resource(self, b_active: np.ndarray):
        """Sum of the wood / coal / uranium channels (12-14)."""
        return b_active[[12, 13, 14]].sum(axis=0)

    def _get_unit_hist(self, b_active: np.ndarray):
        """Return ``(our unit map, opponent unit map, combined history frame)``."""
        our_units = b_active[2]
        opp_units = b_active[5]
        unit_hist = our_units + opp_units * 0.5
        return our_units, opp_units, unit_hist

    def get_hist(self):
        """Return ``(right, left)`` history stacks.

        ``right``: unit history frames + resource ratio + resource delta
        (``unit_hist_length + 2`` channels). ``left``: player and opponent
        unit tracks (2 channels).
        """
        targets = [
            self.resource_reduction_map,
            self.resource_delta,
            self.player_unit_track,
            self.opponent_unit_track,
        ]
        targets = np.stack(targets, axis=0)
        targets = np.concatenate([self.unit_hists, targets], axis=0)
        return targets[:-2], targets[-2:]

    def update(self, current_state: np.ndarray, env_step: int):
        """Fold one step's (channel, y, x) feature stack into the history.

        Raises:
            AssertionError: if ``env_step`` skips or repeats a step within an
                episode.
        """
        # A new episode restarts at env step 0: drop the previous episode's
        # history instead of failing the step-continuity assertion.
        if env_step == 0 and self.step != 0:
            self.reset()
        if self.step == 0:
            self.initialize(initial_state=current_state)

        assert self.step == env_step, (self.step, env_step)

        player_units, opp_units, unit_hist = self._get_unit_hist(b_active=current_state)

        # Shift the history window by one frame and append the newest frame.
        self.unit_hists[0:-1] = self.unit_hists[1:]
        self.unit_hists[-1] = unit_hist

        self.opponent_unit_track += opp_units / 360
        self.player_unit_track += player_units / 360
        self.opponent_unit_track = self.opponent_unit_track.clip(0, 1)
        self.player_unit_track = self.player_unit_track.clip(0, 1)

        current_res_stack = (self._extract_resource(b_active=current_state)) / (
            self.initial_resource_map + 1e-6
        )
        self.resource_delta = self.resource_reduction_map - current_res_stack
        self.resource_reduction_map = current_res_stack
        self.step += 1


def not_go_pos(pos_y: int, pos_x: int, ban_map: np.ndarray):
    """Return a 0/1 int array over the moves n, s, w, e (``ACTION_DIRECTIONS``
    order): 1 if the destination is inside ``ban_map`` and not banned there.

    Bounds come from ``ban_map``'s own (y, x) shape rather than a fixed
    32x32 canvas.
    """
    height, width = ban_map.shape[:2]
    action_mask = []
    for dx, dy in ACTION_DIRECTIONS:
        next_y = pos_y + dy
        next_x = pos_x + dx
        inside = 0 <= next_y < height and 0 <= next_x < width
        action_mask.append(inside and not ban_map[next_y, next_x])
    return np.array(action_mask).astype(int)


def find_closest_city_tile(pos, player):
    """Return the player's city tile nearest to ``pos``, or None if there is none.

    Distance is ``city_tile.pos.distance_to(pos)``. On ties, the first tile
    encountered (city iteration order, then tile order) wins.
    """
    best_tile = None
    best_dist = math.inf
    # player.cities maps city id -> city; each city lists its tiles in ``citytiles``.
    every_tile = (tile for city in player.cities.values() for tile in city.citytiles)
    for tile in every_tile:
        dist = tile.pos.distance_to(pos)
        if dist < best_dist:
            best_tile, best_dist = tile, dist
    return best_tile


def make_input(obs, unit_id, is_xy_order: bool = False):
    """Encode a raw Lux observation as a (20, 32, 32) float32 feature stack.

    The map is centred on a 32x32 canvas. Channel layout, as written below:
      0-1   : presence / cargo (/100) of the unit whose id equals ``unit_id``
      2-4   : our units (presence, cooldown / 6, total cargo / 100)
      5-7   : opponent units (same three features)
      8-9   : our city tiles (presence, min(fuel / light upkeep, 10) / 10)
      10-11 : opponent city tiles (same two features)
      12-14 : wood / coal / uranium amount / 800
      15-16 : our / opponent research points, min(rp, 200) / 200
      17    : day/night phase, step % 40 / 40
      18    : game progress, step / 360
      19    : 1 inside the real map area, 0 on the padding
    Features are written in (x, y) order, then transposed to (y, x) unless
    ``is_xy_order`` is True.

    Returns:
        ``(b, city_mask, resource_mask)``: the feature stack, plus boolean
        maps of channels 8/10 and of "any resource", cropped to the map area.
    """
    width, height = obs["width"], obs["height"]
    x_shift = (32 - width) // 2
    y_shift = (32 - height) // 2
    # city id -> normalised fuel ratio, filled by "c" lines and read by "ct".
    # NOTE(review): this assumes a city's "c" update precedes its "ct" updates
    # in obs["updates"]; otherwise cities[city_id] raises KeyError.
    cities = {}

    b = np.zeros((20, 32, 32), dtype=np.float32)

    for update in obs["updates"]:
        strs = update.split(" ")
        input_identifier = strs[0]

        if input_identifier == "u":
            x = int(strs[4]) + x_shift
            y = int(strs[5]) + y_shift
            wood = int(strs[7])
            coal = int(strs[8])
            uranium = int(strs[9])
            if unit_id == strs[3]:
                # Position and Cargo
                b[:2, x, y] = (1, (wood + coal + uranium) / 100)
            else:
                # Units: channel base 2 for our team, 5 for the opponent.
                team = int(strs[2])
                cooldown = float(strs[6])
                idx = 2 + (team - obs["player"]) % 2 * 3
                b[idx : idx + 3, x, y] = (
                    1,
                    cooldown / 6,
                    (wood + coal + uranium) / 100,
                )
        elif input_identifier == "ct":
            # CityTiles: channel base 8 for our team, 10 for the opponent.
            team = int(strs[1])
            city_id = strs[2]
            x = int(strs[3]) + x_shift
            y = int(strs[4]) + y_shift
            idx = 8 + (team - obs["player"]) % 2 * 2
            b[idx : idx + 2, x, y] = (1, cities[city_id])
        elif input_identifier == "r":
            # Resources
            r_type = strs[1]
            x = int(strs[2]) + x_shift
            y = int(strs[3]) + y_shift
            amt = int(float(strs[4]))
            b[{"wood": 12, "coal": 13, "uranium": 14}[r_type], x, y] = amt / 800
        elif input_identifier == "rp":
            # Research Points: channel 15 for our team, 16 for the opponent.
            team = int(strs[1])
            rp = int(strs[2])
            b[15 + (team - obs["player"]) % 2, :] = min(rp, 200) / 200
        elif input_identifier == "c":
            # Cities
            city_id = strs[2]
            fuel = float(strs[3])
            lightupkeep = float(strs[4])
            cities[city_id] = min(fuel / lightupkeep, 10) / 10

    # Day/Night Cycle
    b[17, :] = obs["step"] % 40 / 40
    # Turns
    b[18, :] = obs["step"] / 360
    # Map Size: marks the real map area inside the padded canvas.
    b[19, x_shift : 32 - x_shift, y_shift : 32 - y_shift] = 1

    if not is_xy_order:
        b = b.transpose(0, 2, 1)

    # NOTE(review): after the transpose the spatial axes are (y, x), but the
    # crops below still slice axis 1 with x_shift and axis 2 with y_shift.
    # This only matters for non-square maps -- confirm maps are always square.
    return (
        b,
        b[[8, 10], x_shift : 32 - x_shift, y_shift : 32 - y_shift] == 1,
        b[[12, 13, 14], x_shift : 32 - x_shift, y_shift : 32 - y_shift].sum(axis=0) > 0,
    )


# Module-level Game shared across turns; rebuilt at step 0 by get_game_state.
game_state = None


def get_game_state(observation):
    """Return the shared Game for this observation.

    At step 0 a new ``Game`` is created from the full update list and tagged
    with ``observation["player"]``. On later steps the existing Game is
    updated in place.
    """
    global game_state

    if observation["step"] != 0:
        game_state._update(observation["updates"])
        return game_state

    updates = observation["updates"]
    game_state = Game()
    game_state._initialize(updates)
    # The first two update lines go to _initialize only (presumably the player
    # id and map size -- verify against lux.game.Game._initialize).
    game_state._update(updates[2:])
    game_state.id = observation["player"]
    return game_state


def in_city(pos):
    """Return True if ``pos`` holds a city tile owned by this agent's team.

    Best-effort: any exception raised while looking up the cell (for example
    when ``game_state`` is not initialised yet) is treated as "not in a city".
    """
    try:
        city = game_state.map.get_cell_by_pos(pos).citytile
        return city is not None and city.team == game_state.id
    except Exception:
        # Unlike a bare ``except:``, this does not swallow
        # KeyboardInterrupt / SystemExit.
        return False


def call_func(obj, method, args=()):
    """Call ``obj.<method>(*args)`` and return the result.

    ``args`` may be any iterable. The agent passes e.g. ``"n"`` from
    ``("move", "n")``, which unpacks to the single argument ``"n"``. The
    default is an immutable empty tuple rather than a shared mutable list.
    """
    return getattr(obj, method)(*args)


# Policy action code -> (unit method name, *args); the index is the action code.
unit_actions = [
    ("move", "n"),
    ("move", "s"),
    ("move", "w"),
    ("move", "e"),
    ("build_city",),
    #     ("move", "c"),
]


ACTION_DIM = len(unit_actions)


# One history tracker per player index; StateHist() resets itself when built.
hist_folder = {player_index: StateHist() for player_index in (0, 1)}
# History features are only used by the 6-action model variant.
USE_HIST = ACTION_DIM == 6
# Maximum number of units fed to the policy per turn.
unit_length = 128


def agent(observation, configuration):
    """Kaggle entry point: return this turn's list of action strings.

    1. City tiles: while there are fewer units than city tiles, build workers,
       first on the tiles chosen by ``decide_worker_gen_place`` and then on any
       other tile that can act. Remaining actionable tiles research until
       uranium is researched.
    2. Units: every unit that can act is placed in one sequence and scored by
       the ONNX policy in a single call. Each unit takes its argmax action;
       ``check_action_plan`` turns predicted collisions into ``move("c")``.

    ``configuration`` is part of the Kaggle agent signature and is unused.
    """
    global game_state
    global hist_folder

    game_state = get_game_state(observation)
    team = observation.player
    player = game_state.players[team]
    actions = []

    state, _, _ = make_input(observation, "u_-00", is_xy_order=False)
    # Channels 0-1 are only set for the unit whose id matches make_input's
    # second argument; clear them.
    state[:2] = 0.0

    # ---- City actions ----
    city_tile_count = player.city_tile_count
    player_cities = list(player.cities.values())
    unit_count = len(player.units)

    player_tiles = None
    act_cities_map_max, posyx2tile = get_act_cities_map(
        player_cities=player_cities, game_map=game_state.map
    )

    if (len(posyx2tile) > 0) and (unit_count < city_tile_count):
        res_avg_map, resource_dist = get_resource_distribution(
            b_active=state, game=game_state
        )
        places = decide_worker_gen_place(
            res_avg_map=res_avg_map,
            resource_map=resource_dist,
            act_cities_map_max=act_cities_map_max,
            num_units=unit_count,
            num_city_tiles=city_tile_count,
        )
        if len(places) > 0:
            player_tiles = [posyx2tile[y][x] for y, x in places]

    acted_cities = []
    if player_tiles is not None:
        for city_tile in player_tiles:
            assert city_tile.can_act()
            actions.append(city_tile.build_worker())
            unit_count += 1
            acted_cities.append(city_tile)
        assert unit_count <= player.city_tile_count

    for city in player.cities.values():
        for city_tile in city.citytiles:
            if city_tile.can_act() and (city_tile not in acted_cities):
                if unit_count < player.city_tile_count:
                    actions.append(city_tile.build_worker())
                    unit_count += 1
                elif not player.researched_uranium():
                    actions.append(city_tile.research())
                    # Count the pending point locally, presumably so that
                    # researched_uranium() reflects it for later tiles.
                    player.research_points += 1

    # ---- Worker actions ----
    can_act_units = {unit.id: unit for unit in player.units if unit.can_act()}

    # True where no unit has claimed the cell yet this turn.
    current_action_plan = np.ones(
        (game_state.map.height, game_state.map.width), dtype=bool
    )
    if len(can_act_units) > 0:
        obs, unit_order = get_unit_sequence_obs(
            game=game_state,
            player=team,
            b_active=state,
            can_act_units=can_act_units,
            turn=game_state.turn,
            unit_length=unit_length,
            action_dim=ACTION_DIM,
            input_dim=4,
        )
        if USE_HIST:
            hist_right, hist_left = hist_folder[team].get_hist()
            obs["image"][:2] = hist_left
            obs["image"] = np.concatenate([obs["image"], hist_right], axis=0)

        action_logit = pred_with_onnx(model=model, obs=obs)
        our_city = crop_state(state[8] > 0, game=game_state)

        for seq_ind, unit_id in enumerate(unit_order.order_strings):
            # Units beyond the model's sequence length get no action.
            if seq_ind >= unit_length:
                break
            unit = can_act_units[unit_id]
            action_code = np.argmax(action_logit[seq_ind])
            is_center = check_is_center_action(action_code=action_code)
            current_action_plan, use_cooldown_as_center = check_action_plan(
                action_code=action_code,
                our_city=our_city,
                pos_x=unit.pos.x,
                pos_y=unit.pos.y,
                current_plan=current_action_plan,
                is_center=is_center,
            )
            if use_cooldown_as_center:
                actions.append(unit.move("c"))
            elif not is_center:
                act = unit_actions[action_code]
                actions.append(call_func(unit, *act))

    if USE_HIST:
        hist_folder[team].update(current_state=state, env_step=game_state.turn)

    return actions


# RL model (score ~1200)
* This model was trained for ~500 epochs with V-trace and then ~500 epochs with UPGO.


In [None]:
!cp ./model_vtrace500_upgo500.onnx ./model_rl.onnx

In [None]:
from kaggle_environments import make

# One seeded 32x32 game: agent.py (which loads model_rl.onnx at import time)
# vs. the imitation-learning baseline. The fixed seed makes runs comparable.
env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True, "seed":1}, debug=False)
steps = env.run(['agent.py', 'agent_imitation_baseline.py'])
env.render(mode="ipython", width=1200, height=800)

# UPGO models vs. imitation model
* UPGO converges faster than V-trace.
* The cells below visualize match results at epochs 100, 300 and 800.

## Epoch 100 result (500 x 100 episodes)

In [None]:
!cp ./model_100.onnx ./model_rl.onnx

In [None]:
from kaggle_environments import make

# Same seeded match as above, now played with the epoch-100 checkpoint that
# was copied to model_rl.onnx (agent.py loads it at import time).
env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True, "seed":1}, debug=False)
steps = env.run(['agent.py', 'agent_imitation_baseline.py'])
env.render(mode="ipython", width=1200, height=800)

## Epoch 300 result (500 x 300 episodes)

In [None]:
!cp ./model_300.onnx ./model_rl.onnx

In [None]:
from kaggle_environments import make

# Same seeded match, now played with the epoch-300 checkpoint that was copied
# to model_rl.onnx (agent.py loads it at import time).
env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True, "seed":1}, debug=False)
steps = env.run(['agent.py', 'agent_imitation_baseline.py'])
env.render(mode="ipython", width=1200, height=800)

## Epoch 800 result (500 x 800 episodes)

In [None]:
!cp ./model_800.onnx ./model_rl.onnx

In [None]:
from kaggle_environments import make

# Same seeded match, now played with the epoch-800 checkpoint that was copied
# to model_rl.onnx (agent.py loads it at import time).
env = make("lux_ai_2021", configuration={"width": 32, "height": 32, "loglevel": 2, "annotations": True, "seed":1}, debug=False)
steps = env.run(['agent.py', 'agent_imitation_baseline.py'])
env.render(mode="ipython", width=1200, height=800)

In [None]:
!tar -cvzf submission.tar.gz *