# In [None]:
import numpy as np

# In [None]:
import os

from pathlib import Path

from common import GameReplay, Entity, GridView, Controls, to_entities

from typing import List
from numpy.typing import NDArray

def get_player(entities: List[Entity], index: int) -> Entity:
  """Return the first archer entity whose ``playerIndex`` equals *index*.

  Args:
    entities: entities of one frame (e.g. the result of ``to_entities``).
    index: player index to match against ``e['playerIndex']``.

  Returns:
    The matching archer entity.

  Raises:
    Exception: if no archer with that player index is present.
  """
  for e in entities:
    # Check the type first so non-archer entities are never indexed.
    if e.type == 'archer' and e['playerIndex'] == index:
      return e
  raise Exception('Player not present: {}'.format(index))

def match_shape(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:
  """Check that each value in *entry* has the shape of the last stored sample.

  Args:
    data: accumulated samples, one list per key (as built by ``extend_data``).
    entry: candidate sample; each value may be an array or a plain scalar.

  Returns:
    True if *data* is still empty or every shape agrees. Returns False, after
    printing a diagnostic, on the first mismatch.

  Raises:
    KeyError: if *data* is non-empty but lacks a key present in *entry*.
  """
  if not data:
    return True

  for k, v in entry.items():
    if k not in data:
      raise KeyError('{} not present in data'.format(k))
    # np.shape (unlike the .shape attribute) also works for plain Python
    # scalars such as the int stored under 'onledge'.
    expected = np.shape(data[k][-1])
    actual = np.shape(v)
    if expected != actual:
      print('Shape mismatch. Expected: {}. Actual: {}'.format(expected, actual))
      return False
  return True
  
def should_extend_input(data: dict[str, List[NDArray]], entry: dict[str, NDArray], max_speed: float = 30) -> bool:
  """Decide whether an input sample should be kept.

  Args:
    data: input samples accumulated so far.
    entry: candidate input sample; must contain a 'vel' entry.
    max_speed: reject the sample if any velocity component exceeds this in
      absolute value (default 30).

  Returns:
    False if shapes disagree with *data* (see ``match_shape``) or the
    velocity is too large; True otherwise.
  """
  if not match_shape(data, entry):
    return False

  # Take the largest *absolute* component. abs(max(vel)) would let large
  # negative components through (vel = [-50, 10] -> abs(10) = 10).
  if np.max(np.abs(entry['vel'])) > max_speed:
    return False

  # Optional filters kept from experimentation:
  # if np.max(entry['jump']) > 0:
  #   return False
  # if np.max(entry['wall']) > 0:
  #   return False
  return True

def should_extend_output(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:
  """Decide whether an output sample should be kept.

  The sample is rejected when its shapes disagree with previously stored
  samples (via ``match_shape``). It is also rejected when the Euclidean norm
  of ``entry['dpos']`` exceeds 50. Presumably this drops frame-to-frame jumps
  that are not ordinary movement; confirm.
  """
  shapes_ok = match_shape(data, entry)
  if shapes_ok:
    moved_too_far = np.linalg.norm(entry['dpos']) > 50
    return not moved_too_far
  return False

def is_same_as_last(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:
  """Return True if *entry* repeats the most recent sample in *data*.

  Every key of *entry* must exist in *data*, and its value must equal
  (``np.array_equal``) the last element stored under that key. An empty
  *entry* is trivially considered a repeat.
  """
  return all(
    key in data and np.array_equal(data[key][-1], value)
    for key, value in entry.items()
  )
  
def extend_data(data: dict[str, List[NDArray]], entry: dict[str, NDArray]) -> bool:
  """Append each value of *entry* to the list stored under its key in *data*.

  *data* is mutated in place; a key seen for the first time gets a new
  single-element list. Always returns True.
  """
  for key, value in entry.items():
    data.setdefault(key, []).append(value)
  return True

def process_replay(filepath, inputs: dict[str, List[NDArray]], outputs: dict[str, List[NDArray]]):
  """Convert one replay file into paired input/output samples.

  For each processed frame t > 0, the input sample built at frame t-1 is
  appended to *inputs*. The matching output sample is appended to *outputs*:
  'dpos' is the player's position at t minus its position at t-1, and 'vel'
  is its velocity at t. Frames from index ``len(replay.actions) - 3`` onward
  are not processed.

  Input sample keys: 'wall', 'vel', 'dir', 'jump', 'onledge'.
  Output sample keys: 'dpos', 'vel'.

  Args:
    filepath: path of the replay file, passed to ``GameReplay.load``.
    inputs: per-key lists of input samples, extended in place.
    outputs: per-key lists of output samples, extended in place.
  """
  replay = GameReplay()
  replay.load(filepath)
  print('actions:', len(replay.actions))
  index: int = replay.state_init['index']
  gv = GridView(1)
  gv.set_scenario(replay.state_scenario)

  control = Controls()
  # Carry-over from the previous frame; None until the first frame is processed.
  prev_input = None
  prev_pos = None
  last_frame = len(replay.actions) - 3
  for i, (state, commands) in enumerate(zip(replay.state_update, replay.actions)):
    if i >= last_frame:
      break
    control.parse_command(commands)
    entities = to_entities(state['entities'])
    me: Entity = get_player(entities, index)
    gv.update(entities, me)

    # The 'wall' input needs this view. It was previously commented out,
    # which left `immediate_wall` undefined (NameError on the first frame).
    # NOTE(review): restored as it was written; relies on me.s.tupleint() and
    # GridView.view(sight) -- confirm both still exist in `common`.
    m, n = me.s.tupleint()
    sight_margin = 5
    sight = (m // 2 + sight_margin, n // 2 + sight_margin)
    immediate_wall = gv.view(sight)
    # left_wall = immediate_wall[:sight_margin, sight_margin:-sight_margin]
    # right_wall = immediate_wall[-sight_margin:, sight_margin:-sight_margin]
    # bot_wall = immediate_wall[:, :sight_margin]
    # top_wall = immediate_wall[:, -sight_margin:]

    if prev_input is not None:
      output_state = {
        'dpos': me.p.array() - prev_pos,
        'vel': me.v.array(),
      }
      # Optional filtering, e.g.:
      # if should_extend_input(inputs, prev_input) and should_extend_output(outputs, output_state):
      extend_data(inputs, prev_input)
      extend_data(outputs, output_state)

    prev_input = {
      'wall': immediate_wall,
      'vel': me.v.array(),
      'dir': control.direction(),
      'jump': control.jump_state(),
      'onledge': 1 if me['state'] == 'ledgeGrab' else 0,
      # 'dash': control.dash_state()
    }
    prev_pos = me.p.array()


inputs = {}
outputs = {}

replay_name = 'only_side_modes'
data_name = 'side_moves_jump'
# How many replay files to process; set to None to process all of them.
max_replays = 1

replay_dir = os.path.join('replays', replay_name)
# os.listdir returns entries in arbitrary order; sort so the same replay(s)
# are selected on every run.
for f in sorted(os.listdir(replay_dir))[:max_replays]:
  print('Process replay:', f)
  process_replay(os.path.join(replay_dir, f), inputs, outputs)

import shutil

def save_data(name: str, type: str, data: dict[str, List[NDArray]]):
  """Save each key of *data* as ``data/<name>/<type>/<key>.npy``.

  Any existing ``data/<name>/<type>`` directory is removed first, so only the
  keys in *data* remain afterwards. Each value (a list of samples) is
  converted with ``np.array`` before saving; ``np.save`` adds the ``.npy``
  suffix.
  """
  out_dir = os.path.join('data', name, type)
  # Start from an empty directory so stale arrays from earlier runs don't linger.
  if os.path.exists(out_dir):
    shutil.rmtree(out_dir)
  os.makedirs(out_dir)

  for key in data:
    stacked = np.array(data[key])
    np.save(os.path.join(out_dir, key), stacked)

# Inputs and outputs are saved as separate arrays that must line up row by row,
# so every key in both dicts has to hold the same number of samples.
if not inputs or not outputs:
  raise ValueError('No samples collected (input keys: {}, output keys: {})'.format(len(inputs), len(outputs)))
len_inputs = len(next(iter(inputs.values())))
len_outputs = len(next(iter(outputs.values())))
print('inputs', len_inputs)
print('outputs', len_outputs)
sample_counts = {len(v) for v in inputs.values()} | {len(v) for v in outputs.values()}
if len(sample_counts) != 1:
  raise ValueError('Sample count mismatch. inputs: {}, outputs: {}'.format(
    {k: len(v) for k, v in inputs.items()},
    {k: len(v) for k, v in outputs.items()}))
save_data(data_name, 'input', inputs)
save_data(data_name, 'output', outputs)
# print(inputs)


# In [None]:
# Scratch check: Euclidean length of a sample 2-D vector. Presumably compared
# against the 50-unit 'dpos' threshold in should_extend_output; confirm.
np.linalg.norm([1, 318])