In [None]:
import os
import gc
import math
import random
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import torch
import torch.nn as nn

import matplotlib.pyplot as plt

In [None]:
# run this if using kaggle notebooks
!cp -r ../input/lux-ai-2021/* .
# if working locally, download the `simple/lux` folder from here https://github.com/Lux-AI-Challenge/Lux-Design-2021/tree/master/kits/python
# and we recommend following instructions in there for local development with python bots

# for kaggle-environments
from lux.game import Game
from lux.game_map import Cell, RESOURCE_TYPES, Position
from lux.constants import Constants
from lux.game_constants import GAME_CONSTANTS
from lux import annotate
import math
import sys
from kaggle_environments import make

# Config

In [None]:
class Config:
    """Flat namespace of game parameters read from ``GAME_CONSTANTS["PARAMETERS"]``."""

    # Short-lived alias for the parameter table; deleted at the end of the
    # class body so it does not become a Config attribute.
    _params = GAME_CONSTANTS["PARAMETERS"]

    # Day / night cycle
    DAY_LENGTH = _params["DAY_LENGTH"]
    NIGHT_LENGTH = _params["NIGHT_LENGTH"]
    MAX_DAYS = _params["MAX_DAYS"]

    # Cargo capacity per unit type
    WORKER_CAPACITY = _params["RESOURCE_CAPACITY"]["WORKER"]
    CART_CAPACITY = _params["RESOURCE_CAPACITY"]["CART"]

    # Light upkeep per entity type
    CITY_LIGHT_UPKEEP = _params["LIGHT_UPKEEP"]["CITY"]
    WORKER_LIGHT_UPKEEP = _params["LIGHT_UPKEEP"]["WORKER"]
    CART_LIGHT_UPKEEP = _params["LIGHT_UPKEEP"]["CART"]

    # Wood and city building
    MAX_WOOD_AMOUNT = _params["MAX_WOOD_AMOUNT"]
    CITY_BUILD_COST = _params["CITY_BUILD_COST"]
    CITY_ADJACENCY_BONUS = _params["CITY_ADJACENCY_BONUS"]

    # Worker collection rates
    WORKER_COLLECTION_RATE_WOOD = _params["WORKER_COLLECTION_RATE"]["WOOD"]
    WORKER_COLLECTION_RATE_COAL = _params["WORKER_COLLECTION_RATE"]["COAL"]

    # Resource-to-fuel conversion
    RESOURCE_TO_FUEL_RATE_WOOD = _params["RESOURCE_TO_FUEL_RATE"]["WOOD"]

    # Cooldown constants
    WORKER_ACTION_COOLDOWN = _params["UNIT_ACTION_COOLDOWN"]["WORKER"]

    del _params

# Utility Functions

In [None]:
def get_relative_position_from_cell(source_cell, target_cell):
    """Describe where ``target_cell`` lies relative to ``source_cell``.

    Both arguments must expose ``.pos.x`` and ``.pos.y``.

    Returns:
        Tuple ``(dist, left, right, up, down, center)`` where ``dist`` is the
        Euclidean distance between the two positions and the remaining entries
        are 0/1 flags: ``left``/``right`` compare x coordinates, ``down`` is
        set when the target's y is smaller, ``up`` when it is larger, and
        ``center`` when both positions coincide.
    """
    src = source_cell.pos
    dst = target_cell.pos

    dist = np.sqrt((src.x - dst.x) ** 2 + (src.y - dst.y) ** 2)

    left = 1 if dst.x < src.x else 0
    right = 1 if dst.x > src.x else 0
    down = 1 if dst.y < src.y else 0
    up = 1 if dst.y > src.y else 0
    center = 1 if (src.x == dst.x and src.y == dst.y) else 0

    return (dist, left, right, up, down, center)


def get_resource_cells(game_state):
    """Return every map cell for which ``has_resource()`` is truthy.

    Cells are visited with x as the outer index and y as the inner index, and
    are returned in that order.
    """
    board = game_state.map
    n_cols, n_rows = board.width, board.height
    all_cells = (
        board.get_cell(col, row)
        for col in range(n_cols)
        for row in range(n_rows)
    )
    return [candidate for candidate in all_cells if candidate.has_resource()]

def get_citytile_cells(game_state):
    """Return the ``citytile`` object of every map cell that has one.

    Note that the returned items are the citytile objects themselves, not the
    cells holding them. Cells are scanned with x as the outer index and y as
    the inner index.
    """
    board = game_state.map
    n_cols, n_rows = board.width, board.height
    tiles = (
        board.get_cell(col, row).citytile
        for col in range(n_cols)
        for row in range(n_rows)
    )
    return [tile for tile in tiles if tile is not None]

In [None]:
def get_nearest_resources(unit, game_state, topk=5):
    """Describe the wood cells closest to ``unit``.

    Args:
        unit: Unit whose ``pos`` is the reference point.
        game_state: Game state exposing ``map``.
        topk: Maximum number of entries to return; ``None`` returns all of them.

    Returns:
        A list of dicts sorted by ascending ``dist``. Each dict holds
        ``dist`` (Euclidean distance divided by the map diagonal), ``amount``
        (wood amount divided by ``Config.MAX_WOOD_AMOUNT``) and the 0/1 flags
        ``left``, ``right``, ``up``, ``down`` and ``center`` as produced by
        ``get_relative_position_from_cell``.
    """
    game_map = game_state.map
    unit_cell = game_map.get_cell_by_pos(unit.pos)
    max_distance = np.sqrt(game_map.width ** 2 + game_map.height ** 2)

    nearest_resources = []
    for rcell in get_resource_cells(game_state):
        # Only wood is considered by this observation.
        if rcell.resource.type != RESOURCE_TYPES.WOOD:
            continue
        (dist, left, right, up, down, center) = get_relative_position_from_cell(unit_cell, rcell)
        nearest_resources.append({
            "dist": dist / max_distance,
            'amount': rcell.resource.amount / Config.MAX_WOOD_AMOUNT,
            "left": left, "right": right, "up": up, "down": down, "center": center,
        })

    # list.sort() sorts in place and returns None, so its result must not be
    # assigned back (the previous code did, and therefore always returned None).
    nearest_resources.sort(key=lambda r: r['dist'])
    return nearest_resources[:topk]

def get_game_state_worker_observations(game_state):
    """Return the nearest-resource observations for the current player's first unit.

    Returns:
        The value produced by ``get_nearest_resources`` for
        ``players[game_state.id].units[0]``, or an empty list when the player
        has no units.
    """
    units = game_state.players[game_state.id].units
    if not units:
        # Nothing to observe from; avoid indexing an empty unit list.
        return []
    # The observations used to be computed and then discarded.
    return get_nearest_resources(units[0], game_state)

In [None]:
class Convmap:
    """Builds per-cell feature planes (units, cities, resources, roads) from a game state."""

    def __init__(self, game_state):
        """Cache the map, its dimensions, the players, our team id and the turn."""
        self.game_state=game_state
        self.game_map=game_state.map
        self.width=self.game_map.width
        self.height = self.game_map.height
        
        self.players=self.game_state.players
        self.team=self.game_state.id
        self.turn=self.game_state.turn
    
    def get_city_features(self):
        """Return a ``(width, height, 3)`` array of city-tile features.

        Channels: 0 = our city tile, 1 = opponent city tile,
        2 = ``min(10, fuel / light_upkeep) / 10`` of the owning city.
        """
        city_feats=np.zeros( (self.width, self.height, 3) )
        for player in self.players:
            for city_id, city in player.cities.items():
                fuel=city.fuel
                light_upkeep=city.get_light_upkeep()
                # Capped at 10 and scaled to [0, 1]; e.g. fuel for 3 upkeeps -> 0.3.
                nights_can_survive = min(10, fuel/light_upkeep)/10
                
                for citytile in city.citytiles:
                    (x, y) = (citytile.pos.x, citytile.pos.y)
                    if city.team == self.team:
                        city_feats[x][y][0] = 1
                    else:
                        city_feats[x][y][1] = 1
                    city_feats[x][y][2] = nights_can_survive
        return city_feats
    
    def get_resource_features(self):
        """Return a ``(width, height, 2)`` array of resource features.

        Channels: 0 = cell has any resource, 1 = wood amount divided by
        ``Config.MAX_WOOD_AMOUNT`` (left at 0 for non-wood resources).
        """
        rfeats=np.zeros((self.width,self.height, 2))
        for i in range(self.width):
            for j in range(self.height):
                cell=self.game_map.get_cell(i, j)
                if not cell.has_resource():
                    continue
                rfeats[i][j][0] = 1
                resource=cell.resource
                if resource and (RESOURCE_TYPES.WOOD == resource.type):
                    rfeats[i][j][1]= resource.amount/Config.MAX_WOOD_AMOUNT
        return rfeats   
    
    
    def get_unit_features(self):
        """Return a ``(width, height, 4)`` array of unit features.

        Channels: 0 = our unit, 1 = opponent unit, 2 = total cargo divided by
        ``Config.WORKER_CAPACITY``, 3 = cooldown divided by
        ``Config.WORKER_ACTION_COOLDOWN``. When several units share a cell the
        last one visited overwrites the earlier values.
        """
        ufeats=np.zeros((self.width, self.height, 4))
        for player in self.players:
            for unit in player.units:
                (x, y) = (unit.pos.x, unit.pos.y)
                opp_team = 0 if (unit.team==self.team) else 1
                
                cargo = unit.cargo
                total_cargo = (cargo.wood + cargo.coal + cargo.uranium)/Config.WORKER_CAPACITY
                
                cooldown=unit.cooldown/Config.WORKER_ACTION_COOLDOWN
                
                ufeats[x][y][0] = 1-opp_team
                ufeats[x][y][1] = opp_team
                ufeats[x][y][2] = total_cargo
                ufeats[x][y][3] = cooldown
        return ufeats
    
    def get_road_features(self):
        """Return a ``(width, height, 1)`` array holding ``cell.road / 6`` (0 where there is no road)."""
        feats=np.zeros((self.width,self.height, 1))
        for i in range(self.width):
            for j in range(self.height):
                cell=self.game_map.get_cell(i, j)
                road = cell.road
                if (not road) or (road == 0):
                    continue
                feats[i][j][0] = road/6
        return feats
                
    def get_conv_features(self):
        """Return all feature planes stacked as a ``(width, height, 10)`` array.

        Channel order: 4 unit channels, 3 city channels, 2 resource channels,
        1 road channel.

        The array is intentionally unpadded. Earlier code also built a 32x32
        zero-padded copy, but it was never returned and its slice assignment
        (``pad[W:-W, H:-H]`` with ``W == 0``) raised on 32-wide maps, so that
        dead code has been removed.
        """
        unit_feats=self.get_unit_features()
        city_feats=self.get_city_features()
        resource_feats = self.get_resource_features()
        road_feats=self.get_road_features()
        
        return np.concatenate([unit_feats,
                               city_feats,
                               resource_feats,
                               road_feats
                              ], axis=2)

In [None]:
def get_unit_relative_position(x, y, width, height):
    """Encode every cell's position relative to the unit at ``(x, y)``.

    Args:
        x, y: Unit coordinates.
        width, height: Map dimensions.

    Returns:
        ``(width, height, 5)`` float array. Channel 0 is the Euclidean distance
        to ``(x, y)``, standardised to zero mean and unit variance over the
        map (all zeros when every distance is equal, e.g. a 1x1 map).
        Channels 1-4 are 0/1 flags for the cell being left (``i < x``),
        right (``i > x``), up (``j > y``) and down (``j < y``) of the unit.

    Fixes over the previous implementation:
        * direction flags are computed per cell; they used to be set once and
          never reset, so they stayed at 1 for all later cells;
        * the distance channel is a proper z-score; it used to subtract the
          mean of raw distances from max-scaled distances;
        * a zero standard deviation no longer produces NaN.
    """
    cols, rows = np.meshgrid(np.arange(width), np.arange(height), indexing="ij")
    dist = np.sqrt((x - cols) ** 2 + (y - rows) ** 2)

    dist_feats = np.zeros((width, height, 5))

    dist_std = dist.std()
    if dist_std > 0:
        dist_feats[:, :, 0] = (dist - dist.mean()) / dist_std

    dist_feats[:, :, 1] = cols < x   # left
    dist_feats[:, :, 2] = cols > x   # right
    dist_feats[:, :, 3] = rows > y   # up
    dist_feats[:, :, 4] = rows < y   # down
    return dist_feats

def get_allunit_states(game_state, conv_map):
    """Build the network inputs for every unit belonging to ``game_state.id``.

    Args:
        game_state: Current game state (``id``, ``map``, ``players``, ``turn``).
        conv_map: Shared map feature planes; copied per unit and extended with
            that unit's relative-position planes along the last axis.

    Returns:
        List of dicts with keys ``unit``, ``unit_conv_feats`` and
        ``unit_feats``. ``unit_feats`` is a length-5 vector:
        game completion, day progress, night progress, wood / 100 and
        cooldown / 6.
    """
    board = game_state.map
    n_cols, n_rows = board.width, board.height

    turn = game_state.turn
    cycle_step = turn % 40
    game_completion = (1 + turn) / 360
    days_completed = min(1, cycle_step / 30)
    nights_completed = max(0, (1 + cycle_step) - 30) / 10

    own_units = [
        candidate
        for player in game_state.players
        for candidate in player.units
        if candidate.team == game_state.id
    ]

    states = []
    for unit in own_units:
        planes = conv_map.copy()
        relative_planes = get_unit_relative_position(unit.pos.x, unit.pos.y, n_cols, n_rows)

        scalars = np.zeros(5)
        scalars[:] = (
            game_completion,
            days_completed,
            nights_completed,
            unit.cargo.wood / 100,
            unit.cooldown / 6,
        )

        states.append({
            'unit': unit,
            'unit_conv_feats': np.concatenate([planes, relative_planes], axis=2),
            'unit_feats': scalars,
        })
    return states

# model

In [None]:
class ConvBlock(nn.Module):
    """3x3 same-padding convolution followed by ReLU6, BatchNorm2d and Dropout2d(0.1)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU6()
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout2d(p=0.1)

    def forward(self, x):
        """Apply conv -> ReLU6 -> batch norm -> dropout; spatial size is preserved."""
        activated = self.relu1(self.conv1(x))
        return self.dropout1(self.bn1(activated))


class ConvBackbone(nn.Module):
    """Residual convolutional encoder mapping 15-channel maps to a 128-d vector."""

    def __init__(self):
        super().__init__()
        depth = 6
        hidden = 64
        self.pre_bn = nn.BatchNorm2d(15)
        self.conv0 = ConvBlock(15, hidden)
        self.blocks = nn.ModuleList(ConvBlock(hidden, hidden) for _ in range(depth))
        self.bn = nn.BatchNorm1d(hidden)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(hidden, 128)

    def forward(self, x):
        """Encode a ``(batch, 15, H, W)`` tensor into ``(batch, 128)``."""
        n = x.shape[0]
        feats = self.conv0(self.pre_bn(x))
        # Residual stack: each block's output is added to its input.
        for layer in self.blocks:
            feats = feats + layer(feats)
        # Sum over all spatial positions, leaving one value per channel.
        pooled = feats.view(n, feats.shape[1], -1).sum(dim=-1)
        return self.fc(self.dropout(self.bn(pooled)))

class ActorCritic(nn.Module):
    """Policy/value network over a map encoding and a 5-d unit feature vector.

    The map goes through ``ConvBackbone`` (128-d) and the unit features through
    a small MLP (16-d). Their concatenation feeds two separate heads: a 6-way
    softmax policy and a scalar value.
    """

    def __init__(self):
        super().__init__()
        fused_dim = 128 + 16  # backbone output + unit-feature embedding

        self.backbone = ConvBackbone()
        self.ffn = nn.Sequential(
            nn.BatchNorm1d(5),

            nn.Linear(5, 32),
            nn.BatchNorm1d(32),
            nn.ReLU6(),
            nn.Dropout(0.1),

            nn.Linear(32, 32),
            nn.BatchNorm1d(32),
            nn.ReLU6(),
            nn.Dropout(0.2),

            nn.Linear(32, 16),
        )
        self._policy_net = nn.Sequential(
            nn.BatchNorm1d(fused_dim),
            nn.Linear(fused_dim, 128),
            nn.ReLU6(),
            nn.Dropout(0.1),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU6(),
            nn.Dropout(0.1),

            nn.Linear(256, 6),
        )
        self._value_net = nn.Sequential(
            nn.BatchNorm1d(fused_dim),
            nn.Linear(fused_dim, 128),
            nn.ReLU6(),
            nn.Dropout(0.1),

            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU6(),
            nn.Dropout(0.1),

            nn.Linear(256, 1),
        )

    def forward(self, x1, x2):
        """Return ``(pis, values)``: action probabilities ``(batch, 6)`` and values ``(batch, 1)``."""
        map_code = self.backbone(x1)
        unit_code = self.ffn(x2)
        fused = torch.cat([map_code, unit_code], dim=-1)

        action_probs = self._policy_net(fused).softmax(dim=-1)
        state_values = self._value_net(fused)
        return (action_probs, state_values)

# play game

In [None]:
def has_boundary_violation(action, x, y, width, height):
    """Return 1 if the move ``action`` would leave the map from ``(x, y)``, else 0.

    Action codes: 0 moves to ``y - 1``, 1 to ``y + 1``, 2 to ``x + 1`` and
    3 to ``x - 1``. Any other action never violates the boundary.

    The result is always a plain ``int``; one branch used to return ``True``.
    """
    leaves_map = (
        (action == 0 and y == 0)
        or (action == 1 and y + 1 == height)
        or (action == 2 and x + 1 == width)
        or (action == 3 and x == 0)
    )
    return int(leaves_map)

def has_opposite_collision(boundary_violation, unit, action, game_state):
    """Return 1 if moving ``unit`` by ``action`` lands on an opponent's tile or unit.

    Args:
        boundary_violation: Truthy when the move already leaves the map; the
            check is then skipped and 0 is returned.
        unit: The moving unit (``pos``, ``team``).
        action: Action code; only the move codes 0-3 can collide.
        game_state: Game state exposing ``map`` and ``players``.

    Returns:
        1 when the destination cell holds a city tile or unit of another team,
        otherwise 0.
    """
    # (dx, dy) per move code, matching has_boundary_violation's conventions.
    deltas = {0: (0, -1), 1: (0, 1), 2: (1, 0), 3: (-1, 0)}
    if boundary_violation or action not in deltas:
        return 0

    # Start from the unit's own position. The previous code started both
    # coordinates at -1, so the unchanged coordinate pointed at the wrong cell.
    (dx, dy) = deltas[action]
    x_next = unit.pos.x + dx
    y_next = unit.pos.y + dy

    unit_team = unit.team
    next_team = unit_team

    next_citytile = game_state.map.get_cell(x_next, y_next).citytile
    if next_citytile:
        next_team = next_citytile.team

    # A unit standing on the destination takes precedence over the city tile.
    for player in game_state.players:
        for next_unit in player.units:
            if (next_unit.pos.x == x_next) and (next_unit.pos.y == y_next):
                next_team = next_unit.team
                break

    return int(next_team != unit_team)
    

def has_build_violation(action, unit, game_state):
    """Return 1 when ``action`` is the build action (5) but the unit cannot build, else 0.

    ``unit.can_build`` is only consulted for the build action.
    """
    is_build = (action == 5)
    return int(is_build and not unit.can_build(game_state.map))

def get_game_metrics(unit, game_state, action):
    """Collect the per-step metrics used later to compute rewards.

    Args:
        unit: The unit that took ``action``. If ``None``, the current player's
            first unit is used.
        game_state: Current game state.
        action: Action code chosen for ``unit``.

    Returns:
        Dict with ``num_cities``, ``num_tiles`` and ``num_units`` for the
        current player, the 0/1 flags ``boundary_violation``,
        ``build_violation`` and ``collision_violation`` for this action, and
        ``cargo_wood`` of ``unit``.
    """
    player = game_state.players[game_state.id]
    # The acting unit used to be overwritten unconditionally with
    # player.units[0], so violations and cargo described the wrong unit
    # whenever the player owned more than one.
    if unit is None:
        unit = player.units[0]

    (x, y) = (unit.pos.x, unit.pos.y)
    (width, height) = (game_state.map.width, game_state.map.height)

    boundary_violation = has_boundary_violation(action, x, y, width, height)
    collision_violation = has_opposite_collision(boundary_violation, unit, action, game_state)
    build_violation = has_build_violation(action, unit, game_state)

    cities = player.cities
    num_tiles = sum(len(city.citytiles) for city in cities.values())

    return {
        "num_cities": len(cities),
        "num_tiles": num_tiles,
        "num_units": len(player.units),
        "boundary_violation": boundary_violation,
        "build_violation": build_violation,
        "collision_violation": collision_violation,
        "cargo_wood": unit.cargo.wood
    }

# calculate rewards

In [None]:
def calculate_rewards(game_memory, evaluate):
    """Annotate each memory entry in place with ``total_reward`` and ``rts``.

    For every entry except the last, the step reward is
    ``50 * tiles + 20 * units`` taken from the *next* entry's metrics, minus
    10 for each boundary/build/collision violation of the current entry, plus
    ``2 * wood_gain / 20``, plus ``1 / 40`` when the action was not "stay" (4)
    and caused no boundary or collision violation.

    The last entry gets ``-50`` if it reports zero units or zero tiles;
    otherwise it keeps its ``total_reward`` (defaulting to 0.0).

    ``rts`` is the discounted return with discount 0.9, accumulated backwards.

    Args:
        game_memory: List of step dicts with ``game_metrics`` (and ``action``
            for all but the last). Modified in place; an empty list is a no-op.
        evaluate: Unused; kept for interface compatibility.
    """
    if not game_memory:
        # Nothing to score; previously this raised IndexError.
        return

    for cur, nxt in zip(game_memory, game_memory[1:]):
        action = cur['action']
        cur_metrics = cur["game_metrics"]
        nxt_metrics = nxt["game_metrics"]

        boundary_violation = cur_metrics["boundary_violation"]
        build_violation = cur_metrics["build_violation"]
        collision_violation = cur_metrics["collision_violation"]

        reward = (50 * nxt_metrics['num_tiles']) + (20 * nxt_metrics['num_units'])
        if boundary_violation:
            reward -= 10
        if build_violation:
            reward -= 10
        if collision_violation:
            reward -= 10

        wood_reward = (nxt_metrics['cargo_wood'] - cur_metrics['cargo_wood'])
        reward += (2 * wood_reward/20)
        # Small bonus for taking any valid non-"stay" action.
        if (not boundary_violation) and (not collision_violation) and (action != 4):
            reward += 1/40
        cur["total_reward"] = reward

    last = game_memory[-1]
    last_metrics = last["game_metrics"]
    if (last_metrics["num_units"] == 0) or (last_metrics["num_tiles"] == 0):
        last["total_reward"] = -50
    else:
        # Ensure the key exists so every entry can be read uniformly.
        last.setdefault("total_reward", 0.0)

    # Discounted returns, computed from the end of the episode backwards.
    last['rts'] = last['total_reward']
    for i in range(len(game_memory) - 2, -1, -1):
        game_memory[i]['rts'] = game_memory[i]['total_reward'] + 0.9*game_memory[i+1]['rts']

# running policy networks

In [None]:
def run_policy_network(unit, unit_conv_feats, unit_feats,
                       worker_policy,
                       game_state,
                       evaluate):
    """Run the policy for one unit and pick an action.

    Args:
        unit: Unit being controlled (``cargo.wood``, ``can_build``).
        unit_conv_feats: Channels-last map features for this unit.
        unit_feats: 1-D scalar feature vector for this unit.
        worker_policy: Model returning ``(action_probs, values)``; it is put
            in eval mode and called without gradient tracking.
        game_state: Current game state (only ``map`` is used).
        evaluate: If truthy, act greedily; otherwise sample from the policy.

    Returns:
        ``(action, values)`` where ``action`` is an index in 0..5 and
        ``values`` is the flattened value output as a NumPy array.

    The build action (index 5) is only eligible when the unit carries exactly
    100 wood and ``unit.can_build`` is true. When sampling with building
    eligible, the build probability is raised to at least 0.17.
    """
    # Channels-last -> channels-first by swapping axes 0 and 2, plus a batch axis.
    unit_conv_feats=torch.tensor(unit_conv_feats, dtype=torch.float32).transpose(0, 2).unsqueeze(0)
    unit_feats=torch.tensor(unit_feats, dtype=torch.float32).unsqueeze(0)
    
    worker_policy.eval()
    with torch.no_grad():
        (actions, values)=worker_policy(unit_conv_feats, unit_feats)
    
    actions=actions.view(-1).numpy()
    values=values.view(-1).numpy()
    
    if evaluate:
        if (unit.can_build(game_state.map)) and (unit.cargo.wood==100):
            action=np.argmax(actions)
        else:
            action=np.argmax(actions[:5])
    else:
        if (unit.cargo.wood == 100) and (unit.can_build(game_state.map)):
            if actions[-1] < 0.17:
                # Rescale the non-build actions to 0.83 total and give the
                # remainder to "build" so it is explored often enough.
                s=actions[:5].sum()
                actions[:5] = 0.83 * actions[:5]/(1e-8+s)
                actions[-1]=1-actions[:5].sum()
            action = np.random.choice(6, p=actions)
        else:
            s=actions[:5].sum()
            if s > 0:
                move_probs = actions[:5]/s
            else:
                # All probability mass sat on "build"; dividing by zero would
                # produce NaNs, so fall back to a uniform choice.
                move_probs = np.full(5, 0.2)
            action = np.random.choice(5, p=move_probs)
    return (action, values)

In [None]:
def play_game(worker_policy, evaluate=False, showui=False):
    """Play one 16x16 game against ``"simple_agent"`` and return the recorded steps.

    Args:
        worker_policy: Model used to choose each unit's action.
        evaluate: Passed to ``run_policy_network`` (greedy when truthy).
        showui: If truthy, render the finished game inline.

    Returns:
        List of step dicts (``conv_map``, ``unit_feats``, ``action``,
        ``game_metrics``, ``total_reward``, ``rts``, ``next_conv_map``,
        ``next_unit_feats``). The last two recorded entries are dropped
        because they have no complete successor.
    """
    game_state=None
    game_memory=[]

    def get_worker_actions(game_state, actions):
        """Choose and record one action for every unit of ours that can act this turn."""
        conv_map = Convmap(game_state).get_conv_features()
        all_units_feats=get_allunit_states(game_state, conv_map)
        for unit_states in all_units_feats:
            unit=unit_states['unit']
            # Units on cooldown must not be given actions.
            if not unit.can_act():
                continue
            unit_conv_feats= unit_states['unit_conv_feats']
            unit_feats=unit_states['unit_feats']
            
            (action, values) = run_policy_network(unit, unit_conv_feats, unit_feats,
                                                  worker_policy,
                                                  game_state, evaluate)
        
            if action == 0:
                actions.append( unit.move('n') )
            elif action==1:
                actions.append( unit.move('s') )
            elif action==2:
                actions.append( unit.move('e') )
            elif action==3:
                actions.append( unit.move('w') )
            elif action==4:
                pass  # stay in place: no command is issued
            elif action == 5:
                actions.append(unit.build_city())
        
            game_metrics = get_game_metrics(unit, game_state, action)
            game_memory.append({
                'conv_map': unit_conv_feats,
                'unit_feats': unit_feats,
                'action': action,
                'game_metrics': game_metrics
            })
        
    def agent(observation, configuration):
        """kaggle-environments agent callback: update state and return this turn's actions."""
        nonlocal game_state

        ### Do not edit ###
        if observation["step"] == 0:
            game_state = Game()
            game_state._initialize(observation["updates"])
            game_state._update(observation["updates"][2:])
            game_state.id = observation.player

        else:
            game_state._update(observation["updates"])

        ### AI Code goes down here! ###
        actions=[]
        player=game_state.players[game_state.id]
        # Collect actions once per turn. This used to run once per acting
        # unit while looping over *all* units each time, which duplicated
        # actions and memory entries as soon as more than one unit existed.
        if any(unit.can_act() for unit in player.units):
            get_worker_actions(game_state, actions)
        return actions

    
    env=make("lux_ai_2021",
             configuration={"seed": random.randint(0, 100000000), 
                            "loglevel": 0,
                            "annotations": False,
                            "width": 16,
                            "height": 16
                           },
             debug=True)
    
    env.run([ agent,  "simple_agent"])
    if showui:
        env.render(mode='ipython' , width=800, height=600)
    
    # Game stopped before turn 359: append a terminal sentinel with no units
    # or tiles so calculate_rewards assigns the terminal penalty.
    if game_state.turn < 359:
        game_memory.append({
            'total_reward':0,
            'game_metrics': {
                "num_units": 0,
                "num_cities": 0,
                "num_tiles": 0,
                "boundary_violation": 0,
                "build_violation": 0,
                "collision_violation": 0,
                "cargo_wood": 0
            }
        })
        
    calculate_rewards(game_memory, evaluate)
    for i in range(len(game_memory)-2):
        game_memory[i]['next_conv_map'] = game_memory[i+1]['conv_map']
        game_memory[i]['next_unit_feats'] = game_memory[i+1]['unit_feats']        
    game_memory=game_memory[:-2]
    return game_memory

# collect training samples

In [None]:
def collect_train_samples(worker_policy):
    """Play sampled (non-greedy) games until more than 300 transitions are collected.

    Returns:
        ``(num_games, max_city_count, memory)`` where ``memory`` is a list of
        transition dicts with keys ``conv_map``, ``unit_feats``,
        ``next_conv_map``, ``next_unit_feats``, ``action``, ``reward`` and
        ``rts``, and ``max_city_count`` is the largest ``num_cities`` seen.
    """
    batch_size = 300
    memory = []
    max_city_count = 0
    num_games = 0

    # Starts empty, so at least one game is always played.
    while len(memory) <= batch_size:
        num_games += 1
        for step in play_game(worker_policy, False, False):
            transition = {
                'conv_map': step['conv_map'],
                'unit_feats': step['unit_feats'],
                'next_conv_map': step['next_conv_map'],
                'next_unit_feats': step['next_unit_feats'],
                'action': step['action'],
                'reward': step['total_reward'],
            }
            max_city_count = max(max_city_count, step["game_metrics"]["num_cities"])
            transition['rts'] = step['rts']
            memory.append(transition)

    return (num_games, max_city_count, memory)

# evaluate

In [None]:
def evaluate(worker_policy, num_games=50):
    """Play greedy games and report averaged statistics.

    Args:
        worker_policy: Model to evaluate.
        num_games: Number of games to play (default 50, the previous
            hard-coded value).

    Returns:
        ``(avg_max_tiles, avg_steps, avg_mean_reward)``: the per-game maximum
        ``num_tiles``, the number of recorded steps, and the mean step reward,
        each averaged over the games.

    Raises:
        ValueError: If ``num_games`` is smaller than 1.
    """
    if num_games < 1:
        raise ValueError("num_games must be at least 1")

    total_tiles=0
    total_iterations=0
    eval_mean_reward=0
    for _ in range(num_games):
        game_memory=play_game(worker_policy, True, False)
        num_iterations=len(game_memory)
        max_tiles=0
        total_reward=0
        for data in game_memory:
            max_tiles=max(max_tiles, data["game_metrics"]["num_tiles"])
            total_reward+=data["total_reward"]
        
        # A game can yield no samples (play_game drops the last two entries);
        # count its mean reward as 0 instead of dividing by zero.
        if num_iterations > 0:
            total_reward/=num_iterations
        total_tiles+=max_tiles
        total_iterations += num_iterations
        eval_mean_reward+=total_reward
    
    total_tiles/=num_games
    total_iterations/=num_games
    eval_mean_reward/=num_games
    return (total_tiles, total_iterations, eval_mean_reward)

# training model

In [None]:
def get_batched_data(worker_policy):
    """Collect fresh training samples and stack them into tensors.

    Returns:
        ``(conv_map, next_conv_map, unit_feats, next_unit_feats, rewards,
        actions, rts)``. Map tensors are float32 with axes 1 and 3 swapped
        relative to the stored arrays; ``actions`` is ``torch.long``; the
        rest are float32.
    """
    (_num_games, _max_city_count, samples) = collect_train_samples(worker_policy)

    def column(key, dtype):
        # Gather one field across all samples into a single tensor.
        return torch.tensor(np.array([sample[key] for sample in samples]), dtype=dtype)

    conv_map = column('conv_map', torch.float32).transpose(1, 3)
    next_conv_map = column('next_conv_map', torch.float32).transpose(1, 3)
    unit_feats = column('unit_feats', torch.float32)
    next_unit_feats = column('next_unit_feats', torch.float32)
    rewards = column('reward', torch.float32)
    rts = column('rts', torch.float32)
    actions = column('action', torch.long)

    return (conv_map, next_conv_map,
            unit_feats, next_unit_feats,
            rewards, actions, rts)

In [None]:
def get_entropy(pi):
    """Return the mean Shannon entropy (in nats) of the distributions in ``pi``.

    Args:
        pi: Tensor of probabilities whose last dimension indexes the outcomes.

    Returns:
        Scalar tensor: the entropy summed over the last dimension and averaged
        over the remaining ones.
    """
    # Clamp only inside the log so that a probability of exactly 0 contributes
    # 0 (0 * log(1e-8)) instead of NaN (0 * -inf); other values are unchanged.
    H = -pi * torch.log(pi.clamp(min=1e-8))
    H = H.sum(dim=-1)
    return H.mean()

In [None]:
def get_old_policy_evaluation(model, batch):
    """Evaluate the pre-update policy on ``batch``.

    Args:
        model: Policy network returning ``(action_probs, values)``.
        batch: Tuple laid out as ``(conv_map, next_conv_map, unit_feats,
            next_unit_feats, rewards, actions, rts)``.

    Returns:
        ``(old_log_pis, Aold)``: the log-probability of each taken action
        (with 1e-8 added before the log) and the advantage estimate
        ``rts - values``. Both are computed without gradient tracking.
    """
    conv_map=batch[0]
    unit_feats=batch[2] 
    actions=batch[5]
    rts=batch[6]
    
    model.eval()
    with torch.no_grad():
        # Use the model that was passed in; this previously read the global
        # `worker_policy`, silently ignoring the argument.
        (pis, values) = model(conv_map, unit_feats)
    
    values=values.view(-1)
    
    Aold = (rts - values)
    old_pis = torch.gather(pis, 1,  actions.unsqueeze(-1)).view(-1)
    old_log_pis=torch.log(old_pis + 1e-8)
    return (old_log_pis, Aold)

def get_loss(model, optimizer, it):
    """Collect a fresh batch and run four clipped-surrogate (PPO-style) updates.

    Args:
        model: Actor-critic network to train.
        optimizer: Optimizer over ``model``'s parameters.
        it: Outer iteration index; every 10th iteration prints the first
            sample's action probabilities.

    Returns:
        ``(loss, actor_objective, critic_loss, entropy, num_move_actions,
        num_build_actions)`` with the first four taken from the last of the
        four update steps.

    Fixes over the previous implementation:
        * ``model`` is used throughout instead of the global ``worker_policy``;
        * the loss sign is corrected. The optimizer minimises ``loss``, so the
          surrogate objective and the entropy bonus must enter negatively and
          the critic error positively; it used to be the exact opposite;
        * the unused forward pass over the next states was removed. In train
          mode it also fed those states into the BatchNorm running statistics.
    """
    clamp=0.1
    batch= get_batched_data(model)
    
    conv_map=batch[0]
    unit_feats=batch[2] 
    actions=batch[5]
    rts=batch[6]
    
    train_actor_loss=None;train_critic_loss=None;train_entropy=None
    (old_log_pis, Aold) = get_old_policy_evaluation(model, batch)
    Aold = Aold.detach()
    
    num_move_actions=(actions!=5).sum().item()
    num_build_actions=(actions==5).sum().item()
        
    for i in range(4):
        model.train()
        (pis, values) = model(conv_map, unit_feats)
        values=values.view(-1)
        
        A = ( rts - values)
        prob_values = torch.gather(pis, 1,  actions.unsqueeze(-1)).view(-1)
        new_log_pis=torch.log(prob_values + 1e-8)
        
        # Probability ratio between the updated and the pre-update policy.
        r = torch.exp( new_log_pis - old_log_pis )
        surr1 = r * Aold
        surr2 = torch.clamp(r, 1-clamp, 1+clamp) * Aold
        
        actor_loss = torch.min(surr1, surr2).mean()   # surrogate objective (to maximise)
        critic_loss = torch.abs(A).mean()             # L1 value error (to minimise)
        
        H = get_entropy(pis)                          # entropy bonus (to maximise)
        
        loss = -actor_loss + 0.8 * critic_loss - 2 * H
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if it%10==0 and i==0:
            print(pis[0].detach().numpy())
            
        train_actor_loss=actor_loss.item();
        train_critic_loss=critic_loss.item();
        train_entropy=H.item()
    return (loss.item(), train_actor_loss, train_critic_loss, train_entropy, num_move_actions, num_build_actions)

In [None]:
def train_model(worker_policy, optimizer):
    """Run 300 training iterations, checkpointing and evaluating every 10th.

    On iterations 0, 10, 20, ... the whole model is saved to
    ``policy_network_<it+1>.pt``, training statistics are printed, and
    ``evaluate`` is run.

    Returns:
        ``(eval_iterations, eval_tiles, eval_rewards)``: one entry per
        evaluation, in order.
    """
    history_steps = []
    history_tiles = []
    history_rewards = []

    for it in range(300):
        stats = get_loss(worker_policy, optimizer, it)
        (train_loss, train_actor_loss, train_critic_loss,
         train_entropy, num_move_actions, num_build_actions) = stats

        if it % 10 != 0:
            gc.collect()
            continue

        torch.save(worker_policy, "policy_network_{}.pt".format(it+1))

        print()
        print("======"*10)
        print("Iteration:{} | Loss:{:.3f}".format(it+1, train_loss))
        print("Entropy:", train_entropy)
        print("Actor Loss:{:.3f} | Critic Loss:{:.3f}".format(train_actor_loss, train_critic_loss))
        print("Number Of Move Actions:", num_move_actions)
        print("Number Of Build Actions:",num_build_actions)
        print()
        print("==========="*10)

        print()
        print("Evaluation")
        (avg_tiles, avg_iterations, avg_reward) = evaluate(worker_policy)
        print("Avg tiles:", avg_tiles)
        print("Avg iterations:", avg_iterations)
        print("Aveg reward:", avg_reward)

        history_steps.append(avg_iterations)
        history_tiles.append(avg_tiles)
        history_rewards.append(avg_reward)
        print()
        print()
        gc.collect()

    return (history_steps, history_tiles, history_rewards)

In [None]:
%%time

# Create a fresh actor-critic policy and an Adam optimizer (lr=1e-4), then
# train. `worker_policy` must stay a module-level name: get_loss and
# get_old_policy_evaluation reference it as a global.
worker_policy=ActorCritic()
optimizer=torch.optim.Adam(worker_policy.parameters(), lr=1e-4)
(eval_iterations, eval_tiles, eval_rewards) = train_model(worker_policy, optimizer)


In [None]:
# Play one greedy (evaluate=True) game with the trained policy and render it inline.
game_memory = play_game(worker_policy, True, True)

In [None]:
# Count boundary and collision violations over the recorded game and print
# each step's reward, followed by the two totals.
collision_violations=0
boundary_violations=0

for mem in game_memory:
    game_metrics=mem['game_metrics']
    boundary_violations+=game_metrics["boundary_violation"]
    collision_violations += game_metrics["collision_violation"]
    print(mem['total_reward'])
print(boundary_violations, collision_violations)

In [None]:
# Number of recorded steps in the game played above.
len(game_memory)

In [None]:
# Average number of recorded steps per evaluation game, one point per evaluation round.
plt.plot(eval_iterations)
plt.show()

In [None]:
# Average of the per-game maximum city-tile count, one point per evaluation round.
plt.plot(eval_tiles)
plt.show()

In [None]:
# Average mean step reward, one point per evaluation round.
plt.plot(eval_rewards)
plt.show()

In [None]:
# Render another greedy game; the returned memory is discarded.
_ = play_game(worker_policy, True, True)

In [None]:
# Render one more greedy game (each game uses a new random seed); the returned memory is discarded.
_ = play_game(worker_policy, True, True)