In [1]:
import numpy as np
import json
from pathlib import Path
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [3]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (including all CUDA devices when CUDA is available), exports
    ``PYTHONHASHSEED`` and configures cuDNN for deterministic execution.

    Args:
        seed_value: Integer seed applied to all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # PYTHONHASHSEED is read at interpreter start-up, so this assignment only
    # influences child processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # Plain process-wide flags: safe to set with or without a GPU.
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick a different convolution
    # algorithm from run to run, which defeats the purpose of seeding, so it
    # must be disabled for reproducible results.
    torch.backends.cudnn.benchmark = False

# Seed all random number generators once, before any data is loaded or split.
seed = 42
seed_everything(seed)

In [50]:
def to_label(action, obs):
    """Convert one action string into ``(unit_id, label, unit_pos)``.

    The action is split on single spaces; the second token is taken as the
    unit id. Labels are assigned as follows:

    * ``m <unit> n|s|w|e`` -> 0, 1, 2, 3
    * ``bcity <unit>``     -> 4
    * ``m <unit> c`` and every other command -> ``None``

    The position comes from the ``"u ..."`` entries of ``obs["updates"]``
    whose fourth field equals the unit id (fields five and six are x and y),
    shifted so that a ``width`` x ``height`` map sits centred in a 32x32 grid.
    If several entries match, the last one wins; if none matches, ``(0, 0)``
    is returned.
    """
    tokens = action.split(' ')
    command, unit_id = tokens[0], tokens[1]

    label = None
    if command == 'm':
        label = {'c': None, 'n': 0, 's': 1, 'w': 2, 'e': 3}[tokens[2]]
    elif command == 'bcity':
        label = 4

    map_w, map_h = obs['width'], obs['height']
    x_shift, y_shift = (32 - map_w) // 2, (32 - map_h) // 2

    unit_pos = (0, 0)
    for fields in (line.split(" ") for line in obs["updates"]):
        # Only unit entries carrying the requested id are of interest.
        if fields[0] != "u" or fields[3] != unit_id:
            continue
        unit_pos = (int(fields[4]) + x_shift, int(fields[5]) + y_shift)

    return unit_id, label, unit_pos

def depleted_resources(obs):
    """Return True when no entry of ``obs['updates']`` starts with the token ``'r'``."""
    leading_tokens = (entry.split(' ')[0] for entry in obs['updates'])
    return 'r' not in leading_tokens

def create_dataset_from_json(main_dir, team_name='Toad Brigade'):
    """Build training samples from the episode replays stored under ``main_dir``.

    ``main_dir`` is expected to hold one sub-directory per episode, each
    containing a replay file named ``<sub-directory>.json``. Entries that are
    not directories, or that lack that file, are skipped. An episode is used
    only when its highest-reward player belongs to ``team_name``; its steps
    are consumed until the first observation that reports no ``r`` update.

    Args:
        main_dir: Directory containing the per-episode sub-directories.
        team_name: Team whose winning episodes are converted into samples.

    Returns:
        A tuple ``(obses, samples)``:

        * ``obses`` maps ``"<EpisodeId>_<step index>"`` to the observation dict
          reduced to the keys ``step``, ``updates``, ``player``, ``width`` and
          ``height``.
        * ``samples`` is a list of ``(obs_id, action_map, mask)`` tuples, where
          both arrays are boolean with shape ``(5, 32, 32)``. ``action_map``
          marks the chosen label at each acting unit's cell, ``mask`` marks
          all five channels at that cell.
    """
    kept_keys = ('step', 'updates', 'player', 'width', 'height')
    obses = {}
    samples = []

    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and anything seeded downstream) reproducible across machines.
    for episode_dir in sorted(os.listdir(main_dir)):
        episode_path = os.path.join(main_dir, episode_dir)
        if not os.path.isdir(episode_path):
            continue
        filepath = os.path.join(episode_path, f"{episode_dir}.json")
        # Directories without a replay (e.g. ``.ipynb_checkpoints``) are not
        # episodes; skip them instead of aborting the whole build.
        if not os.path.isfile(filepath):
            continue

        with open(filepath) as f:
            json_load = json.load(f)

        ep_id = json_load['info']['EpisodeId']
        # Missing rewards (None) count as 0 when picking the winner.
        index = np.argmax([r or 0 for r in json_load['rewards']])
        if json_load['info']['TeamNames'][index] != team_name:
            continue

        steps = json_load['steps']
        for i in range(len(steps) - 1):
            if steps[i][index]['status'] != 'ACTIVE':
                continue

            # The action recorded at step i+1 is paired with the observation
            # of step i, which is always read from the first agent's entry.
            actions = steps[i + 1][index]['action']
            obs = steps[i][0]['observation']

            if depleted_resources(obs):
                break

            obs['player'] = index
            obs = {k: v for k, v in obs.items() if k in kept_keys}

            obs_id = f'{ep_id}_{i}'
            obses[obs_id] = obs

            # Allocate as bool directly instead of float64 followed by a cast.
            action_map = np.zeros((5, 32, 32), dtype=bool)
            mask = np.zeros((5, 32, 32), dtype=bool)

            for action in actions:
                _unit_id, label, unit_pos = to_label(action, obs)
                if label is not None:
                    action_map[label, unit_pos[0], unit_pos[1]] = True
                    mask[:, unit_pos[0], unit_pos[1]] = True

            samples.append((obs_id, action_map, mask))

    return obses, samples

In [51]:
# NOTE: despite its name, `episode_dir` is the parent directory passed as
# `main_dir`; it holds one sub-directory per episode.
episode_dir = '23281649'
obses, samples = create_dataset_from_json(episode_dir)
print('obses:', len(obses), 'samples:', len(samples))

obses: 128499 samples: 128499
