In [None]:
import numpy as np
import pandas as pd
import json
from tqdm.notebook import tqdm
import zipfile
import os

In [None]:
def make_input(obs):
    """Rasterise one Lux AI observation into an (18, 32, 32) float32 array.

    The map is centred on a 32x32 canvas. Channels, as filled in below:
        0-2    own units: presence, cooldown / 6, cargo (capped at 100) / 100
        3-5    opponent units, same layout
        6-7    own city tiles: presence, fuel ratio of the owning city
        8-9    opponent city tiles, same layout
        10-12  wood / coal / uranium amount / 800
        13-14  own / opponent research points (capped at 200) / 200
        15     position inside the 40-step cycle (step % 40 / 40)
        16     step / 360
        17     1 inside the map area, 0 in the padding

    "Own" means the team equal to obs['player']. Parsing stops at the first
    'ccd' update, so anything listed after it is ignored. A 'ct' line looks up
    the ratio stored by its city's 'c' line, which therefore has to come first.

    Args:
        obs: dict with 'width', 'height', 'step', 'player' and 'updates'
            (a list of space-separated update strings).

    Returns:
        np.ndarray of shape (18, 32, 32) and dtype float32.
    """
    map_w, map_h = obs['width'], obs['height']
    pad_x = (32 - map_w) // 2
    pad_y = (32 - map_h) // 2

    # Normaliser for the city fuel ratio: 10 per remaining 40-step cycle,
    # minus the steps already spent past step 30 of the current cycle.
    cycle, turn_in_cycle = divmod(obs['step'], 40)
    nights_left = 10 * (9 - cycle) - max(turn_in_cycle - 30, 0)

    def side(team):
        # 0 when `team` is the observing player, 1 otherwise.
        return (team - obs['player']) % 2

    resource_channel = {'wood': 10, 'coal': 11, 'uranium': 12}
    fuel_ratio = {}
    planes = np.zeros((18, 32, 32), dtype=np.float32)

    for line in obs['updates']:
        fields = line.split(' ')
        tag = fields[0]

        if tag == 'ccd':
            # Roads: stop reading updates here.
            break

        if tag == 'u':
            # Units
            team = int(fields[2])
            col = int(fields[4]) + pad_x
            row = int(fields[5]) + pad_y
            cooldown = float(fields[6])
            cargo = int(fields[7]) + int(fields[8]) + int(fields[9])
            base = side(team) * 3
            planes[base:base + 3, col, row] = [
                1,
                cooldown / 6,
                min(cargo, 100) / 100,
            ]
            continue

        if tag == 'ct':
            # CityTiles
            team = int(fields[1])
            city_id = fields[2]
            col = int(fields[3]) + pad_x
            row = int(fields[4]) + pad_y
            base = 6 + side(team) * 2
            planes[base:base + 2, col, row] = [1, fuel_ratio[city_id]]
            continue

        if tag == 'r':
            # Resources
            kind = fields[1]
            col = int(fields[2]) + pad_x
            row = int(fields[3]) + pad_y
            amount = int(fields[4])
            planes[resource_channel[kind], col, row] = amount / 800
            continue

        if tag == 'rp':
            # Research Points
            team = int(fields[1])
            points = int(fields[2])
            planes[13 + side(team), :, :] = min(points, 200) / 200
            continue

        if tag == 'c':
            # Cities
            city_id = fields[2]
            fuel = float(fields[3])
            upkeep = float(fields[4])
            fuel_ratio[city_id] = min(fuel / upkeep, nights_left) / nights_left

    # Day/Night Cycle
    planes[15, :, :] = turn_in_cycle / 40
    # Turns
    planes[16, :, :] = obs['step'] / 360
    # Map Size
    planes[17, pad_x:32 - pad_x, pad_y:32 - pad_y] = 1

    return planes

In [None]:
def get_pos(obs, unit_id):
    """Look up a unit's update line by id and return three of its integer fields.

    Scans obs['updates'] in order and stops at the first 'u' line whose fourth
    field equals `unit_id`.

    Returns:
        (int(field 1), int(field 4), int(field 5)) of that line - the caller
        in this file treats them as (unit type, x, y) - or None when no line
        matches.
    """
    # Lazy pipeline: lines after the first match are never split or inspected.
    matching_rows = (
        fields
        for fields in (line.split(' ') for line in obs['updates'])
        if fields[0] == 'u' and fields[3] == unit_id
    )
    row = next(matching_rows, None)
    if row is None:
        return None
    return int(row[1]), int(row[4]), int(row[5])


def to_label(obs, actions):
    """Build the (32, 32) int8 action-label map for one observation.

    Each labelled action is written at the acting unit's cell, shifted so the
    map is centred on the 32x32 canvas (the same centring make_input uses):
        0  no labelled action on this cell
        1  move north   2  move south   3  move west   4  move east
        5  build city

    Only units whose type field is 0 are labelled (workers in Lux AI,
    presumably - confirm against the game spec). Actions that cannot be
    labelled are skipped rather than raising: other commands, a move with
    direction 'c' or an unknown/missing direction, and actions naming a unit
    id that is not present in obs['updates'].

    Args:
        obs: dict with 'width', 'updates' and optionally 'height' (the width
            is used for both axes when 'height' is absent).
        actions: iterable of space-separated action strings; None is treated
            as empty.

    Returns:
        np.ndarray of shape (32, 32) and dtype int8.
    """
    x_shift = (32 - obs['width']) // 2
    # Shift y by the map height so labels stay aligned with the x/y shifts
    # applied to the input planes, even if the map is not square.
    y_shift = (32 - obs.get('height', obs['width'])) // 2

    # Index the unit lines once instead of rescanning every update per action.
    # setdefault keeps the first line for an id, matching a first-match scan.
    unit_rows = {}
    for update in obs['updates']:
        fields = update.split(' ')
        if fields[0] == 'u' and len(fields) > 5:
            unit_rows.setdefault(fields[3], fields)

    move_labels = {'n': 1, 's': 2, 'w': 3, 'e': 4}
    label = np.zeros((32, 32), dtype=np.int8)

    for action in actions or ():
        strs = action.split(' ')
        if strs[0] == 'm' and len(strs) > 2:
            # 'c' (stay in place) and unknown directions map to None.
            a = move_labels.get(strs[2])
        elif strs[0] == 'bcity' and len(strs) > 1:
            a = 5
        else:
            a = None

        if a is None:
            continue

        unit = unit_rows.get(strs[1])
        if unit is None:
            # The action refers to a unit that is not in this observation.
            continue

        if int(unit[1]) == 0:
            label[int(unit[4]) + x_shift, int(unit[5]) + y_shift] = a

    return label


def create_dataset_from_json(sub_ids, team_name='Toad Brigade'):
    """Convert replay JSON files into (state, action) samples inside datasets.zip.

    Episodes are the unique EpisodeId values in EpisodeAgents.csv whose
    SubmissionId is in `sub_ids`. An episode is used only when the player
    with the highest reward (missing rewards count as 0) has the team name
    `team_name`. For every step where that player's status is 'ACTIVE', one
    '<episode_id>_<step>.npz' file holding `state` (make_input) and `action`
    (to_label of the following step's actions) is added to the archive.
    The number of archived samples is printed at the end.

    Args:
        sub_ids: collection of submission ids to select episodes by.
        team_name: team name the top-reward player must have.
    """
    episode_agents = pd.read_csv(
        '../input/lux-ai-meta-kaggle/EpisodeAgents.csv',
        usecols=['EpisodeId', 'SubmissionId']
    )
    episode_agents = episode_agents[episode_agents['SubmissionId'].isin(sub_ids)]

    # The context manager finalises the archive even if an episode fails midway.
    with zipfile.ZipFile('datasets.zip', mode='w') as archive:
        for episode_id in tqdm(episode_agents['EpisodeId'].unique()):
            with open(f'../input/lux-ai-simulations-episode-scraper/{episode_id}.json') as fp:
                episode = json.load(fp)

            index = np.argmax([r or 0 for r in episode['rewards']])
            if episode['info']['TeamNames'][index] != team_name:
                continue

            steps = episode['steps']
            for i in range(len(steps) - 1):
                if steps[i][index]['status'] != 'ACTIVE':
                    continue

                # The actions stored at step i+1 are paired with the
                # observation of step i.
                actions = steps[i + 1][index]['action']
                # The observation is always read from player slot 0 and then
                # tagged with the perspective of the selected player.
                obs = steps[i][0]['observation']
                obs['player'] = index

                sample_name = f'{episode_id}_{i}.npz'
                try:
                    np.savez_compressed(
                        sample_name,
                        state=make_input(obs),
                        action=to_label(obs, actions)
                    )
                    archive.write(sample_name)
                finally:
                    # Never leave the temporary sample behind, even on failure.
                    if os.path.exists(sample_name):
                        os.remove(sample_name)

        print(len(archive.namelist()))

In [None]:
# Build datasets.zip from the episodes of these two SubmissionId values, keeping
# only episodes whose top-reward player is the default team ('Toad Brigade').
create_dataset_from_json(sub_ids=[23281649, 23297953])