In [1]:
import os
import warnings

import numpy as np
import torchvision.transforms as T
from PIL import Image

In [2]:
# Run to preprocess, and the raw-input / processed-output locations derived from it.
run_id = "1653389330"
root_path = "../data/"
in_path = f"{root_path}{run_id}/"

# Trailing "/" on processed_path is relied on by code that appends "{did}/..." directly.
processed_path = f"{root_path}processed/{run_id}/"
states_path, frames_path = f"{in_path}states", f"{in_path}images"

# makedirs also creates the intermediate "processed/" folder; exist_ok keeps re-runs
# idempotent while still surfacing real failures (permissions, bad root_path) instead
# of silently swallowing every OSError as the old try/except did.
os.makedirs(processed_path, exist_ok=True)

In [4]:
def calculate_distance(px, py, dx, dy):
    """Return the per-axis offset from a position (px, py) to a destination (dx, dy).

    Despite the name, this is a signed displacement, not a Euclidean distance.
    Works element-wise for scalars or NumPy arrays alike.

    Returns:
        tuple: ``(dx - px, dy - py)``.
    """
    return dx - px, dy - py

def process_state(did, ids, states_dir=None, out_dir=None):
    """Extract goal-offset and velocity features for the sampled timesteps of one demonstration.

    Loads ``<states_dir>/<did>.npy`` and writes ``<out_dir>/<did>/states.npy`` holding one
    float64 row per entry of ``ids``: ``[dgx, dgy, vx, vy]``, where ``(dgx, dgy)`` is
    destination minus position. Only person index 0 is used.

    Args:
        did: Demonstration id; used in the input file name and the output sub-directory.
        ids: Integer timestep indices into the first axis of the states array
            (the notebook passes 400 of them).
        states_dir: Directory holding the raw ``<did>.npy`` files. Defaults to the
            notebook-level ``states_path``.
        out_dir: Root directory for processed output. Defaults to the notebook-level
            ``processed_path``.

    Raises:
        IndexError: if any id is out of range for the loaded states array.
    """
    if states_dir is None:
        states_dir = states_path
    if out_dir is None:
        out_dir = processed_path

    # Layout (per the original notes): states -> [timesteps, people, 7],
    # one state = [px, py, vx, vy, dx, dy, tau].
    states = np.load(os.path.join(states_dir, f"{did}.npy"))
    sampled = states[np.asarray(ids, dtype=int), 0]  # (len(ids), 7), person 0 only

    new_states = np.zeros((len(sampled), 4))  # float64, same dtype as the old per-row fill
    new_states[:, 0:2] = sampled[:, 4:6] - sampled[:, 0:2]  # goal offset: (dx - px, dy - py)
    new_states[:, 2:4] = sampled[:, 2:4]                      # velocity: (vx, vy)

    target_dir = os.path.join(out_dir, str(did))
    os.makedirs(target_dir, exist_ok=True)
    np.save(os.path.join(target_dir, "states.npy"), new_states)

def process_frame(did, ids, frames_dir=None, out_dir=None):
    """Save the sampled frames of one demonstration GIF as 128px grayscale JPEGs.

    Each frame ``ids[k]`` of ``<frames_dir>/<did>.gif`` is center-cropped to 256,
    resized to 128 and converted to grayscale. It is then written to
    ``<out_dir>/<did>/<k>.jpg``, so output files are numbered 0..len(ids)-1.

    Args:
        did: Demonstration id; used in the GIF file name and the output sub-directory.
        ids: Frame indices to extract, in output order.
        frames_dir: Directory holding ``<did>.gif``. Defaults to the notebook-level
            ``frames_path``.
        out_dir: Root directory for processed output. Defaults to the notebook-level
            ``processed_path``.

    Raises:
        EOFError: presumably raised by ``Image.seek`` when an id is past the last frame.
    """
    if frames_dir is None:
        frames_dir = frames_path
    if out_dir is None:
        out_dir = processed_path

    transform = T.Compose([T.CenterCrop(256), T.Resize(128), T.Grayscale()])

    # The context manager releases the GIF file handle even if a seek/transform/save fails.
    with Image.open(os.path.join(frames_dir, f"{did}.gif")) as image:
        target_dir = os.path.join(out_dir, str(did))
        os.makedirs(target_dir, exist_ok=True)
        for new_frame_id, old_frame_id in enumerate(ids):
            image.seek(old_frame_id)
            frame = transform(image)
            frame.save(os.path.join(target_dir, f"{new_frame_id}.jpg"))

In [5]:
# Number of evenly spaced timesteps kept from every demonstration.
N_SAMPLED_STEPS = 400

# Sorted so output and any failure point are deterministic across filesystems.
for filename in sorted(os.listdir(frames_path)):
    # splitext removes the ".gif" suffix exactly; str.strip(".gif") strips a *character set*
    # from both ends (e.g. "fig1.gif" -> "1").
    stem, ext = os.path.splitext(filename)
    if ext != ".gif":
        continue
    demonstration_id = int(stem)

    with Image.open(os.path.join(frames_path, filename)) as image:
        nof_steps = image.n_frames

    if nof_steps < 2:
        # linspace(1, nof_steps - 1, ...) would request frame 1, which does not exist.
        warnings.warn(f"Skipping {filename}: only {nof_steps} frame(s), nothing to sample.")
        continue

    # Evenly spaced indices in [1, nof_steps - 1]; frame 0 is skipped, as before.
    # Indices repeat when the GIF has fewer than N_SAMPLED_STEPS frames.
    ids = np.linspace(1, nof_steps - 1, num=N_SAMPLED_STEPS, dtype=int)

    os.makedirs(os.path.join(processed_path, str(demonstration_id)), exist_ok=True)

    process_state(demonstration_id, ids)
    process_frame(demonstration_id, ids)