# Setup

In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# ensure dependencies are installed
#!pip install polars
#!pip install pandas
#!pip install numpy
#!pip install matplotlib
#!pip install flax
#!pip install scikit-image
#!pip install git+https://github.com/wcarvalho/nicewebrl
#!pip install git+https://github.com/wcarvalho/JaxHouseMaze.git
#!pip install --upgrade git+https://github.com/wcarvalho/JaxHouseMaze.git
#!pip install --upgrade --force --no-cache-dir git+https://github.com/wcarvalho/nicewebrl

In [5]:
# import packages
import os
from os.path import join as opj
import sys

# add parent directory to current path
cdir=os.getcwd()
parent_dir = os.path.dirname(cdir)
# Guard the append: notebook cells get re-run, and an unconditional append
# would stack a duplicate entry on sys.path each time.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
print(parent_dir)

from typing import List
import numpy as np

import json
import copy
import load_data
from pprint import pprint
import matplotlib.pyplot as plt
import polars as pl
import pandas as pd
import experiment_1 as experiment

from collections import defaultdict
import jax.numpy as jnp
import jax.tree_util as jtu
import jax

from load_data import get_timestep, compute_reaction_time, EpisodeData
from load_data import render_episode, get_task_object, object_idx_to_name, success_fn, get_task_room, create_maps
from load_data import filter_episodes


/Users/hall-mcmaster/Documents/Projects/Maze/human-dyna-web


# Define functions

In [6]:
def separate_data_by_block_stage(data: List[dict]):
    """Group raw rows into episodes using their block/stage description.

    Every row is described by ``load_data.get_block_stage_description``; rows
    whose descriptions are equal belong to the same episode. So for example,
    each row described by

        {'stage': "'not obvious' shortcut",
         'block': 'shortcut',
         'manipulation': 1,
         'episode_idx': 1,
         'eval': True}

    will go into its own list.

    Episodes are numbered in order of first appearance in ``data``; that number
    is stored in the description under ``'global_episode_idx'``.

    Args:
        data: the raw rows to group, in recorded order.

    Returns:
        A ``(grouped_data, infos)`` tuple. ``grouped_data`` maps a string key
        (the description, including ``global_episode_idx``, passed through
        ``load_data.dict_to_string``) to the list of rows of that episode.
        ``infos`` maps the same key to the description dict itself.
    """
    grouped_data = defaultdict(list)
    infos = dict()
    # description string (before the global index is added) -> global index
    episode_index_by_key = dict()
    for datum in data:
        info = load_data.get_block_stage_description(datum)
        key = load_data.dict_to_string(info)
        # Look the index up per description instead of reusing a running
        # counter: with a counter, a description that shows up again after a
        # different one would be stamped with the newest episode's index,
        # creating an extra group that collides with another episode.
        if key not in episode_index_by_key:
            episode_index_by_key[key] = len(episode_index_by_key)
        info['global_episode_idx'] = episode_index_by_key[key]

        updated_key = load_data.dict_to_string(info)
        grouped_data[updated_key].append(datum)
        infos[updated_key] = info
    return grouped_data, infos


In [7]:
def make_episode_data(data: List[dict], sub_ID=100):
    """Group raw rows by block/stage and build per-episode data plus a summary table.

    This prepares
        (1) a list of EpisodeData objects, one per block/stage episode
        (2) a dataframe which summarizes all episode information, one row per episode.

    Both are filled at each episode's ``global_episode_idx``, so a row position
    in the dataframe can be used as an index into the list of EpisodeData for
    further computation.

    Args:
        data: raw rows; grouped with ``separate_data_by_block_stage``. The rows
            are read only -- nothing is written back into them.
        sub_ID: anonymised subject value stored under ``user_id`` in place of
            the ID found in the raw data.

    Returns:
        A ``(episode_data, episode_info)`` tuple: the list of EpisodeData and
        the summary ``pandas.DataFrame``.
    """
    gds, gd_infos = separate_data_by_block_stage(data)

    episode_data = [None] * len(gds)
    episode_info = [None] * len(gds)
    for key, rows in gds.items():
        # get actions
        actions = jnp.asarray([datum['action_idx'] for datum in rows])

        # collect one timestep per row and stack them leaf-wise into a trajectory
        timesteps = jtu.tree_map(
            lambda *v: jnp.stack(v), *[get_timestep(datum) for datum in rows])

        positions = timesteps.state.agent_pos

        reaction_times = jnp.asarray(
            [compute_reaction_time(datum) for datum in rows])

        episode_idx = gd_infos[key]['global_episode_idx']
        episode = EpisodeData(
            actions=actions,
            positions=positions,
            reaction_times=reaction_times,
            timesteps=timesteps,
        )
        episode_data[episode_idx] = episode

        ###############################################################
        # Select variables for episode level dataframe
        ###############################################################
        # Start from the user details of the episode's first row (the columns
        # later selected from them are ID, age and sex). Work on a copy:
        # aliasing the raw dict would write the anonymised ID and every summary
        # column below back into the caller's input data.
        info = copy.deepcopy(rows[0]['user_data'])
        # replace ID from the raw data with an anonymised subject value
        info['user_id'] = sub_ID
        info.update(copy.deepcopy(gd_infos[key]))
        goal_object = get_task_object(timesteps)
        info.update(
            goal_object_numeric=goal_object,
            goal_object_string=object_idx_to_name(goal_object),
            goal_object_reached=success_fn(episode.timesteps),
            actions=actions.tolist(),
            reaction_times=reaction_times.tolist(),
            room=get_task_room(timesteps),
            positions=positions.tolist(),
        )

        # TODO: add overlap?
        # should this be constructed at the episode level or can we do this later if we have the positions?
        # needed to assess manipulations 1 and 3, where it's about whether people took the old path or a new one

        # TODO: add anything from 'timesteps'? It is a pretty complex variable;
        # is any of it worth having at the episode level?

        episode_info[episode_idx] = info

    episode_info = pd.DataFrame(episode_info)
    return episode_data, episode_info

In [8]:
# locate relevant files
def list_files(directory):
    """Return the path of every file found anywhere beneath `directory`.

    Paths are built by joining each walked folder with the file names it
    contains, in the order `os.walk` yields them.
    """
    return [
        opj(folder, name)
        for folder, _subfolders, names in os.walk(directory)
        for name in names
    ]

# Preprocess files

In [10]:
# establish input directory
cdir=os.getcwd()
project_dir=os.path.dirname(cdir)
maze_dir=os.path.dirname(project_dir)
exp_string='pilot-1-subset'
data_dir=opj(maze_dir,'human-dyna-web-raw-data', exp_string)
# Subject numbers below are assigned by position in this list. os.walk returns
# entries in a filesystem-dependent order, so sort to give every file the same
# subject number on every run. Hidden files (e.g. macOS '.DS_Store') are not
# participant data and would make json.load fail, so leave them out.
file_list=sorted(
    path for path in list_files(data_dir)
    if not os.path.basename(path).startswith('.')
)
#print(file_list)

# establish output directory
exp_string_out='pilot-1'
output_dir=opj(project_dir,'data', exp_string_out)
os.makedirs(output_dir, exist_ok=True)

# column order (and naming) of the per-subject output table
column_order = [
    "ID",
    "age",
    "sex",
    "block",
    "manipulation",
    "stage",
    "eval",
    "episode_idx",
    "global_episode_idx",
    "goal_object_string",
    "goal_object_numeric",
    "goal_object_reached",
    "actions",
    "reaction_times",
    "room",
    "positions",
]

# open each file in the list; one file is treated as one subject
sub_count=0
for file in file_list:
    # only hold the file open while reading it
    with open(file, 'r') as f:
        data_dicts = json.load(f)

    # remove all blocks that are "practice"
    data_dicts = [row for row in data_dicts if 'practice' not in row['metadata']['block_metadata']['desc']]

    # extract data
    # NOTE: manipulation room is __always__ room 0, even if task objects differ
    sub_count+=1
    print('Extracting data from sub' + str(sub_count))
    all_episode_data, all_episode_info = make_episode_data(data_dicts, sub_ID=sub_count)

    # minor changes to dataframe order and naming
    all_episode_info = all_episode_info.rename(columns={"user_id": "ID"})
    all_episode_info = all_episode_info[column_order]

    # save output file, e.g. sub-01.csv
    filename='sub-' + str(sub_count).zfill(2) + '.csv'
    full_path = opj(output_dir, filename)
    all_episode_info.to_csv(full_path, index=False)

Extracting data from sub1
Extracting data from sub2


KeyboardInterrupt: 

In [25]:
# preview the first rows of the episode table currently bound to `all_episode_info`
all_episode_info.head()

Unnamed: 0,ID,age,sex,block,manipulation,stage,eval,episode_idx,global_episode_idx,goal_object_string,goal_object_numeric,goal_object_reached,actions,reaction_times,room,positions
0,1,38,Male,probing for planning near goal,4,training,False,1,0,orange,29,True,"[1, 2, 1, 2, 2, 1, 2, 0, 3, 0, 0, 1, 1, 1, 0, ...","[1116.0, 283.0, 251.0, 249.0, 243.0, 332.0, 37...",0,"[[0, 11], [1, 11], [1, 10], [2, 10], [2, 9], [..."
1,1,38,Male,probing for planning near goal,4,training,False,2,1,orange,29,True,"[3, 3, 2, 2, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 2, ...","[113.0, 191.0, 766.0, 177.0, 208.0, 172.0, 194...",0,"[[0, 12], [0, 12], [0, 12], [0, 11], [0, 10], ..."
2,1,38,Male,probing for planning near goal,4,training,False,3,2,lettuce,46,True,"[1, 1, 2, 2, 2, 3, 3, 3, 2, 3, 2, 3, 3, 0]","[1379.0, 237.0, 192.0, 184.0, 205.0, 230.0, 21...",2,"[[3, 7], [4, 7], [5, 7], [5, 6], [5, 5], [5, 4..."
3,1,38,Male,probing for planning near goal,4,training,False,4,3,orange,29,True,"[2, 3, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, ...","[1796.0, 418.0, 192.0, 172.0, 265.0, 202.0, 22...",0,"[[8, 11], [8, 10], [7, 10], [7, 9], [7, 8], [8..."
4,1,38,Male,probing for planning near goal,4,training,False,5,4,lettuce,46,True,"[1, 1, 2, 2, 2, 3, 3, 2, 3, 2, 3, 3, 3]","[1266.0, 202.0, 162.0, 187.0, 222.0, 250.0, 18...",2,"[[3, 7], [4, 7], [5, 7], [5, 6], [5, 5], [5, 4..."
