In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to project modules (e.g. config, experiment_structure)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [32]:
import os
import json
import glob
import config
import re
import jax.numpy as jnp
import nicewebrl
from typing import List
from collections import defaultdict
from flax import serialization
import polars as pl

import jax
from experiment_structure import jax_web_env, env_params

In [3]:
# Identifier used to select files in config.DATA_DIR by name. It equals the
# user_id that appears in the resulting dataframe.
pattern = '2547624190'
# Both lists are sorted so that the i-th metadata file lines up with the i-th
# data file when they are zipped in the loading loop.
# NOTE(review): the metadata glob allows extra text after the identifier while
# the data glob does not -- confirm this matches the on-disk naming scheme.
metadata_files = sorted(glob.glob(f"{config.DATA_DIR}/*{pattern}*.json"))
data_files = sorted(glob.glob(f"{config.DATA_DIR}/*{pattern}.msgpack"))
# Only the counts are compared; this does not prove each pair belongs to the same user.
assert len(data_files)==len(metadata_files), "either data or metadata is missing for a user"

In [17]:
# Fixed PRNG key (seed 0) for the environment reset below.
rng = jax.random.PRNGKey(0)
# Reset the environment once to obtain a reference timestep. deserialize_timestep
# passes it to serialization.from_bytes as the target whose structure the stored
# bytes are restored into.
example_timestep = jax_web_env.reset(rng, env_params)

In [40]:
def get_user_id(filepath):
    """Return the first run of digits found in ``filepath``.

    Args:
        filepath: Path (or any string) expected to contain a numeric user id.

    Returns:
        The first contiguous sequence of digits as a string, or ``None`` when
        the string contains no digits. The whole string is searched left to
        right, so digits in a directory name are matched before digits in the
        file name.
    """
    match = re.search(r'\d+', filepath)
    return match.group() if match else None


def seperate_data_into_episodes(data: List[dict]):
  """Group logged datums into episodes.

  Two datums belong to the same episode when they share both ``stage_idx`` and
  ``metadata["nepisodes"]``.

  Args:
    data: Datums in logging order. Each must provide ``datum["stage_idx"]`` and
      ``datum["metadata"]["nepisodes"]``.

  Returns:
    A pair ``(key_to_episodes_unprocessed, all_episode_information)``.
    ``key_to_episodes_unprocessed`` is a ``defaultdict(list)`` mapping the key
    ``"stage={stage_idx}_episode={nepisodes}"`` to that episode's datums in
    their original order. ``all_episode_information`` maps the same key to a
    dict with ``episode_idx`` (0-based order in which the key first appeared),
    ``stage_episode_idx`` and ``stage_idx``.
  """
  key_to_episodes_unprocessed = defaultdict(list)
  all_episode_information = dict()
  for datum in data:
    stage_episode_idx = datum["metadata"]["nepisodes"]
    stage_idx = datum['stage_idx']
    key = f"stage={stage_idx}_episode={stage_episode_idx}"
    if key not in all_episode_information:
      # Record the episode's info once, when its key is first seen. Rewriting it
      # on every datum would stamp a later episode's index onto this one if the
      # key ever reappears after another episode has started.
      all_episode_information[key] = dict(
          episode_idx=len(all_episode_information),
          stage_episode_idx=stage_episode_idx,
          stage_idx=stage_idx)
    key_to_episodes_unprocessed[key].append(datum)
  return key_to_episodes_unprocessed, all_episode_information

def deserialize_timestep(datum):
  """Rebuild the timestep object stored in a logged datum.

  Args:
    datum: Logged record whose ``datum["data"]["timestep"]`` holds the
      serialized timestep.

  Returns:
    The result of ``serialization.from_bytes`` applied to those bytes, using
    the module-level ``example_timestep`` as the target structure.
  """
  serialized = datum["data"]["timestep"]
  return serialization.from_bytes(example_timestep, serialized)

# Build one row per feedback form and one row per episode, for every user.
df_data = []
# Both lists were sorted and checked to have equal length when they were built;
# zipping them assumes the sorted orders line up user-by-user.
for metadatafile, data_file in zip(metadata_files, data_files):

    # extract relevant metadata
    with open(metadatafile, 'r') as f:
        metadata = json.load(f)
    user_data = dict(
        user_id = metadata['user_storage']['seed'],
        model_chosen = metadata['user_storage']['selected_model'],
    )

    data = nicewebrl.load_data(data_file)

    # The feedback form, when present, is the final datum and is named
    # 'Feedback'. Check for it explicitly so that a session without feedback
    # does not lose its last environment step to a bogus feedback row.
    has_feedback = bool(data) and data[-1].get('name') == 'Feedback'
    if has_feedback:
        feedback_data = data[-1]['data']
        df_data.append(dict(**user_data, stage_type='feedback', **feedback_data))
        interaction_data = data[:-1]
    else:
        interaction_data = data

    # get episode information
    key_to_episodes_unprocessed, episode_information = seperate_data_into_episodes(interaction_data)

    for raw_episode_trials_data in key_to_episodes_unprocessed.values():
        # Reaction time for every step of the episode.
        reaction_times = jnp.asarray([
            nicewebrl.compute_reaction_time(datum['data'])
            for datum in raw_episode_trials_data
        ])

        # Deserialize each step, then stack the steps leaf-by-leaf so every
        # field of the timestep gains a leading per-step axis.
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
        timesteps = [
          deserialize_timestep(datum) for datum in raw_episode_trials_data
        ]
        timesteps = jax.tree_util.tree_map(lambda *v: jnp.stack(v), *timesteps)

        df_data.append(dict(
            total_reward=timesteps.reward.sum(),
            # An episode counts as a success if any step's reward exceeds 0.5.
            success=(timesteps.reward > .5).any(-1),
            # Stored as the array's string form, so this column is text, not numbers.
            reaction_times=str(reaction_times),
            stage_type='interaction',
            **user_data
        ))


dataframe = pl.DataFrame(df_data)
dataframe

user_id,model_chosen,stage_type,How helpful was the AI?,How human-like was the AI?,total_reward,success,reaction_times
i64,str,str,i64,i64,f64,f64,str
2547624190,"""claude""","""feedback""",2.0,4.0,,,
2547624190,"""claude""","""interaction""",,,0.0,0.0,"""[820. 499. 86. 84. 82. 83.…"
2547624190,"""claude""","""interaction""",,,0.0,0.0,"""[59. 81. 86. 84. 80. 86. 79. 8…"
2547624190,"""claude""","""interaction""",,,0.0,0.0,"""[42. 86. 85. 81. 85. 86. 89. 7…"


In [21]:
# Inspect the payload of the final logged datum. `data` here is whatever the
# last iteration of the loading loop left behind.
data[-1]['data']

{'How helpful was the AI?': 2, 'How human-like was the AI?': 4}

In [24]:
# List the fields stored in the payload of the second-to-last datum
# (`data` is left over from the loading loop).
data[-2]['data'].keys()

dict_keys(['image_seen_time', 'action_taken_time', 'computer_interaction', 'action_name', 'action_idx', 'timelimit', 'timestep'])

In [11]:
# Check the `name` field of the final datum (presumably the stage that produced it).
data[-1]['name']

'Feedback'

In [20]:
# List the top-level fields of the second-to-last logged datum.
data[-2].keys()

dict_keys(['stage_idx', 'session_id', 'data', 'user_data', 'metadata', 'name', 'body'])