In [1]:
import numpy as np

In [21]:
import os

from pathlib import Path

from common import GameReplay, Entity, GridView, Controls, to_entities

from typing import List
from numpy.typing import NDArray

vels: set = set()

def get_player(entities: List[Entity], index) -> Entity:
  for e in entities:
    if e.type != 'archer':
      continue
    if e['playerIndex'] != index:
      continue
    return e
  raise Exception('Player not present: ()'.format(index))

def extend_data(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:
  for k, v in entry.items():
    if k in data and data[k][-1].shape != v.shape:
      # raise Exception('Shape mismatch. Expected: {}. Actual: {}'.format(data[k][-1].shape, v.shape))
      print('Shape mismatch. Expected: {}. Actual: {}'.format(data[k][-1].shape, v.shape))
      return False
    
  for k, v in entry.items():
    if k not in data:
      data[k] = [v]
    else:
      data[k].append(v)
      
  return True

def process_replay(filepath, inputs: dict[str, List[NDArray]], outputs: dict[str, List[NDArray]]):
  replay = GameReplay()
  replay.load(filepath)
  index: int = replay.state_init['index']
  gv = GridView(1)
  gv.set_scenario(replay.state_scenario)
  
  control = Controls()
  for i, (state, commands) in enumerate(zip(replay.state_update, replay.actions)):
    if i >= len(replay.actions)-3:
      break
    # print(commands)
    control.parse_command(commands)
    entities = to_entities(state['entities'])
    me = get_player(entities, index)
    gv.update(entities, me)
    m, n = me.s.tupleint()
    sight_margin = 5
    sight = (m//2 + sight_margin, n//2 + sight_margin)
    immediate_wall = gv.view(sight)
    # left_wall = immediate_wall[:sight_margin, sight_margin:-sight_margin]
    # right_wall = immediate_wall[-sight_margin:, sight_margin:-sight_margin]
    # bot_wall = immediate_wall[:, :sight_margin]
    # top_wall = immediate_wall[:, -sight_margin]

    if i > 0:
      vels.add(me.v)
      pos_curr = me.p.copy()
      pos_curr.sub(pos_prev)
      output_state = {
        'dpos': pos_curr.array(),
        'vel': me.v.array()
      }
      if extend_data(inputs, input_state):
        extend_data(outputs, output_state)

    input_state = {
      'wall': immediate_wall,
      'vel': me.v.array(),
      'dir': control.direction(),
      # 'jump': control.jump_state(),
      # 'dash': control.dash_state()
    }
    pos_prev = me.p.copy()


inputs = {}
outputs = {}

experiment_name = 'only_side_moves'

replay_dir = os.path.join('replays', experiment_name)
for f in os.listdir(replay_dir):
  print('Process replay:', f)
  process_replay(os.path.join(replay_dir, f), inputs, outputs)
  break

def save_data(name: str, type: str, data: dict[str, List[NDArray]]):
  for k, v in data.items():
    dir_path = Path(os.path.join('data', name, type))
    dir_path.mkdir(parents=True, exist_ok=True)
    np.save(dir_path.joinpath(k), np.array(v))
    
save_data(experiment_name, 'inputs', inputs)
save_data(experiment_name, 'outputs', outputs)
# print(inputs)


Process replay: 16784799700.json


AttributeError: 'NoneType' object has no attribute 'shape'