In [1]:
# Optional setup (disabled): copy the replays into a writable ./episodes
# directory and delete the *_info.json files there. The dataset cell below
# reads ../input/lux-ai-top-episodes in place instead.
# !mkdir ./episodes
# !cp ../input/lux-ai-top-episodes/* ./episodes
# !rm ./episodes/*_info.json
# !ls -1 ./episodes| wc -l

# Count the files available in the input replay dataset.
!ls -1 ../input/lux-ai-top-episodes | wc -l

663


In [2]:
import numpy as np
import json
from pathlib import Path
import os
import random
import pandas
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [3]:
def to_label(action):
    """Map one raw Lux AI action string to a ``(target, label)`` pair.

    Unit actions return the unit id (second token) as the target:
        'm <unit> <dir>'  -> label 0..3 for n/s/w/e; direction 'c' -> None
        'bcity <unit>'    -> label 4
    City-tile actions return the integer tile coordinates as the target:
        'r <x> <y>'       -> label 5
        'bw <x> <y>'      -> label 6

    Any other command, and any action with too few tokens (including the
    empty string), returns ``(None, None)`` so callers can skip it. An unknown
    move direction still raises ``KeyError``.
    """
    tokens = action.split(' ')
    command = tokens[0]

    if command == 'm' and len(tokens) >= 3:
        direction_labels = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}
        return tokens[1], direction_labels[tokens[2]]

    if command == 'bcity' and len(tokens) >= 2:
        return tokens[1], 4

    if command in ('r', 'bw') and len(tokens) >= 3:
        label = 5 if command == 'r' else 6
        return (int(tokens[1]), int(tokens[2])), label

    # Unsupported command or truncated action: not part of the label space.
    return None, None
    


def depleted_resources(obs):
    """Report whether the observation contains no resource entries.

    An update line counts as a resource entry when its first space-separated
    token is exactly ``'r'``.

    Returns:
        True if none of ``obs['updates']`` is a resource line (including when
        the list is empty), otherwise False.
    """
    return all(update.split(' ', 1)[0] != 'r' for update in obs['updates'])


def create_dataset_from_json(episode_dir, team_name=('Toad Brigade',)):
    """Build imitation-learning samples from Lux AI episode replay JSON files.

    For every replay whose highest-reward player belongs to a team in
    ``team_name``, each ACTIVE step of that player is recorded. Recording stops
    at the first observation that has no resource ('r') update lines.

    Args:
        episode_dir: iterable of directories holding replay ``*.json`` files.
            Files with 'output' in their name are ignored. A single directory
            (str or path-like) is also accepted.
        team_name: collection of team names whose winning episodes are kept.

    Returns:
        obses: object array of shape (n_obs, 5). Each row is
            ``[step, updates, player, width, height]``, where ``updates`` is
            an array of the raw update strings.
        samples: array of ``(obs_id, unit_id, label)`` rows for unit actions,
            with label in 0..4 (n, s, w, e, bcity). Because the rows mix ints
            and strings, NumPy stores them with a string dtype, so consumers
            must cast with ``int()``.
        city_samples: object array of ``(obs_id, (x, y), label)`` rows for
            city-tile actions, with label in 0..1 (research, build worker).
    """
    if isinstance(episode_dir, (str, os.PathLike)):
        episode_dir = [episode_dir]

    episodes = [
        path
        for directory in episode_dir
        for path in Path(directory).glob('*.json')
        if 'output' not in path.name
    ]

    obses = []
    samples = []
    city_samples = []
    for filepath in tqdm(episodes):
        with open(filepath) as f:
            episode = json.load(f)

        # Player with the highest reward (a None reward counts as 0).
        index = int(np.argmax([r or 0 for r in episode['rewards']]))
        if episode['info']['TeamNames'][index] not in team_name:
            continue

        steps = episode['steps']
        # Observation i is paired with the action stored at step i+1
        # (assumes the replay logs the action taken from obs i one step
        # later), so the last step has no action to pair with.
        for i in range(len(steps) - 1):
            if steps[i][index]['status'] != 'ACTIVE':
                continue
            # Guard against a missing action list (e.g. None) in the replay.
            actions = steps[i + 1][index]['action'] or []
            # The observation is always read from player 0's entry; presumably
            # that is where the replay stores the shared map updates - confirm.
            obs = steps[i][0]['observation']

            if depleted_resources(obs):
                break

            obs_id = len(obses)
            obses.append([
                obs['step'],
                np.array(obs['updates']),
                index,
                obs['width'],
                obs['height'],
            ])

            for action in actions:
                info, label = to_label(action)
                if label is None:
                    continue
                if label <= 4:
                    samples.append((obs_id, info, label))
                else:
                    city_samples.append((obs_id, info, label - 5))

    # obses and city_samples rows are ragged (scalars mixed with sequences).
    # NumPy >= 1.24 raises ValueError for that unless dtype=object is given.
    return (
        np.array(obses, dtype=object),
        np.array(samples),
        np.array(city_samples, dtype=object),
    )

In [4]:
# Build the observation / unit-action / city-action datasets from the replays.
episode_dir = ['../input/lux-ai-top-episodes']
obses, samples, city_samples = create_dataset_from_json(episode_dir)
print(f'obses: {len(obses)} samples: {len(samples)} city samples: {len(city_samples)}')

  0%|          | 0/663 [00:00<?, ?it/s]



obses: 201186 samples: 1467183 city samples: 131962


In [5]:
print(np.shape(samples))  # (n_samples, 3): obs_id, unit_id, label

(1467183, 3)


In [6]:
# Distribution of unit-action labels (0..4 = north, south, west, east, bcity).
labels = [row[-1] for row in samples]
actions = ['north', 'south', 'west', 'east', 'bcity']
label_values, label_counts = np.unique(labels, return_counts=True)
for value, count in zip(label_values, label_counts):
    name = actions[int(value)]
    print(f'{name:^5}: {count:>3}')

north: 363929
south: 355338
west : 338863
east : 319100
bcity: 89953


In [7]:
# Distribution of city-tile action labels (0 = research, 1 = build worker).
city_labels = [row[-1] for row in city_samples]
actions = ['r', 'bw']
city_label_values, city_label_counts = np.unique(city_labels, return_counts=True)
for value, count in zip(city_label_values, city_label_counts):
    name = actions[int(value)]
    print(f'{name:^5}: {count:>3}')

  r  : 95121
 bw  : 36841


In [8]:
def split_and_save(rows, stratify_labels, prefix='', test_size=0.1, seed=42):
    """Stratified train/val split of ``rows``, saved to disk.

    Writes ``{prefix}train.npy`` and ``{prefix}val.npy`` in the working
    directory.

    Args:
        rows: array of samples to split.
        stratify_labels: per-row labels used to stratify the split.
        prefix: filename prefix, e.g. 'city_'.
        test_size: fraction of rows that go to the validation split.
        seed: random_state passed to train_test_split, for reproducibility.

    Returns:
        The ``(train, val)`` arrays.
    """
    train_rows, val_rows = train_test_split(
        rows, test_size=test_size, random_state=seed, stratify=stratify_labels
    )
    np.save(f'{prefix}train.npy', train_rows)
    np.save(f'{prefix}val.npy', val_rows)
    return train_rows, val_rows


np.save('samples.npy', samples)
np.save('obses.npy', obses)
train, val = split_and_save(samples, labels)
# NOTE: as before, `train`/`val` are rebound to the city split here. Reload
# train.npy / val.npy if the unit-action split is needed afterwards.
train, val = split_and_save(city_samples, city_labels, prefix='city_')