In [279]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
    
import os, sys, glob
import json
import re
import numpy as np
import pandas as pd
from natsort import natsorted
import tqdm
from manual_spellchecker import spell_checker

sys.path.append('/dartfs/rc/lab/F/FinnLab/tommy/isc_asynchrony_behavior/code/utils/')

from config import *
import dataset_utils as utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# --- experiment identifiers --------------------------------------------------
# These three values select which meta file, presentation orders and results
# directories the rest of the notebook operates on.
EXPERIMENT_NAME = 'next-word-prediction'
EXPERIMENT_VERSION = 'final-multimodal-01'
TASK = 'wheretheressmoke'

# --- directories (all rooted at BASE_DIR, provided by the `config` module) ---
preproc_dir = os.path.join(BASE_DIR, 'stimuli', 'preprocessed')
gentle_dir = os.path.join(BASE_DIR, 'stimuli', 'gentle')
models_dir = os.path.join(BASE_DIR, 'derivatives/model-predictions')
results_dir = os.path.join(
    BASE_DIR, 'experiments', EXPERIMENT_NAME, 'results', EXPERIMENT_VERSION
)


## Functions for loading data

In [4]:
# importing shutil module  
import shutil
from pathlib import Path

def check_used_files(experiment_name, experiment_version, task, clean_errors=False, max_missing_responses=5):
    '''
    Audit the subjects marked as used in the experiment meta file.

    For every row of ``{experiment_version}-{task}.csv`` whose ``used`` flag is
    truthy, the subject's jsPsych parameter file is compared with the subject's
    results file and the subject is sorted into one of:

    * ``complete``   -- results load, the trial word indices match the parameter
                        file, and at most ``max_missing_responses`` responses are null.
    * ``incomplete`` -- results load but fail one of the checks above.
    * ``error``      -- results could not be loaded (missing directory, malformed
                        file, failed demographics/experience assertion, ...).
    * ``missing``    -- not populated by this function.

    Parameters
    ----------
    experiment_name, experiment_version, task : str
        Identify the meta file and the results directory tree.
    clean_errors : bool, default False
        If True, pass the errors to ``clean_meta_errors`` (which moves their
        results aside and clears their ``used`` flag) and re-run the audit once.
    max_missing_responses : int, default 5
        Maximum number of null responses tolerated for a ``complete`` subject.

    Returns
    -------
    checker : dict[str, list[tuple]]
        ``error`` entries are ``(meta_row_index, modality, subject)``;
        ``complete``/``incomplete`` entries additionally carry the prolific id.
    approve_ids : list
        Prolific ids of every subject whose results loaded (complete or incomplete).
    '''
    checker = {
        'complete': [],
        'incomplete': [],
        'error': [],
        'missing': [],
    }

    meta_dir = os.path.join('/dartfs/rc/lab/F/FinnLab/tommy/jspsych_experiments/utils/experiment_meta/', experiment_name)
    meta_file = pd.read_csv(os.path.join(meta_dir, f'{experiment_version}-{task}.csv'))

    source_dir = os.path.join(BASE_DIR, 'stimuli', 'presentation_orders', experiment_version, task, 'jspsych')
    results_dir = os.path.join(BASE_DIR, 'experiments', experiment_name, 'results', experiment_version)

    # rows flagged as used (a missing flag counts as unused)
    used_fns = meta_file[meta_file['used'].fillna(0).astype(bool)]

    approve_ids = []

    for i, fn in used_fns.iterrows():
        # subject name comes from the file stem, modality from its parent folder
        curr_path = Path(fn['subject_fns'])
        sub = curr_path.stem.split('_')[0]
        modality = curr_path.parents[0].stem

        # the parameter file defines which trials the subject should have seen
        parameter_fn = glob.glob(os.path.join(source_dir, f'{sub}*.json'))
        assert len(parameter_fn) == 1, f'Expected one parameter file for {sub}, found {len(parameter_fn)}'

        df_parameters = pd.read_json(parameter_fn[0], orient='records').dropna()
        expected_indices = df_parameters['word_index'].astype(int).to_numpy()

        sub_results_dir = os.path.join(results_dir, task, modality, sub)

        try:
            current_id, demographics, experience, responses = load_participant_results(sub_results_dir, sub)
        except Exception:
            # Catch ordinary failures only: a bare `except` also swallowed
            # KeyboardInterrupt, which (with clean_errors=True) would move a
            # perfectly good subject's results into the error folder.
            checker['error'].append((i, modality, sub))
            continue

        approve_ids.append(current_id)

        # Compare positionally: df_parameters keeps its original (possibly
        # gapped) index after dropna while responses is re-indexed from 0, and
        # pandas refuses to compare differently-labelled Series. array_equal
        # also returns False (instead of raising) when the trial counts differ.
        all_trials_complete = np.array_equal(responses['word_index'].to_numpy(), expected_indices)
        n_missing = int(pd.isnull(responses['response']).sum())
        missing_response_threshold = n_missing <= max_missing_responses

        # demographics/experience lengths are already asserted by the loader
        if all_trials_complete and missing_response_threshold:
            checker['complete'].append((i, modality, sub, current_id))
        else:
            checker['incomplete'].append((i, modality, sub, current_id))

    if clean_errors:
        clean_meta_errors(checker, experiment_name, experiment_version, task)

        # re-audit the updated meta file, keeping the caller's tolerance
        return check_used_files(experiment_name, experiment_version, task,
                                clean_errors=False, max_missing_responses=max_missing_responses)

    return checker, approve_ids

def clean_meta_errors(checker, experiment_name, experiment_version, task):
    '''
    Quarantine errored subjects and clear their ``used`` flag in the meta file.

    For every ``(meta_row_index, modality, subject)`` in ``checker['error']`` the
    subject's results directory (if it exists) is moved into
    ``<results_dir>/<task>/<modality>/error/batch_<N>``. A new batch folder is
    started whenever the most recent one already contains files. The matching
    meta-file rows then get ``used`` set to missing and the meta CSV is
    rewritten in place. Does nothing when there are no errors.

    Parameters
    ----------
    checker : dict
        Output of ``check_used_files``; only the ``'error'`` list is read.
    experiment_name, experiment_version, task : str
        Identify the meta file and the results directory tree.
    '''
    errors = checker['error']
    if not errors:
        return

    meta_dir = os.path.join('/dartfs/rc/lab/F/FinnLab/tommy/jspsych_experiments/utils/experiment_meta/', experiment_name)
    meta_fn = os.path.join(meta_dir, f'{experiment_version}-{task}.csv')
    meta_file = pd.read_csv(meta_fn)

    results_dir = os.path.join(BASE_DIR, 'experiments', experiment_name, 'results', experiment_version)

    # only touch modalities that actually have errors
    for modality in sorted({error[1] for error in errors}):
        errors_dir = os.path.join(results_dir, task, modality, 'error')
        curr_error_dir = _next_error_batch_dir(errors_dir)

        for _, _, sub in (error for error in errors if error[1] == modality):
            sub_results_dir = os.path.join(results_dir, task, modality, sub)
            if os.path.exists(sub_results_dir):
                print(f'Moving {sub_results_dir} -> {curr_error_dir}')
                shutil.move(sub_results_dir, curr_error_dir)

    # the error tuples carry the meta-file row index as their first element
    remove_idxs = [error[0] for error in errors]
    meta_file.loc[remove_idxs, 'used'] = None
    meta_file['used'] = meta_file['used'].astype('Int64')
    meta_file.to_csv(meta_fn, index=False)
    print('Cleaned meta file!')


def _next_error_batch_dir(errors_dir):
    '''
    Return (creating it if needed) the batch folder for newly errored subjects.

    Reuses the highest-numbered ``batch_<N>`` folder if it is empty, otherwise
    uses ``batch_<N+1>``; starts at ``batch_1``. Batches are ordered
    numerically, so ``batch_10`` correctly follows ``batch_9``.
    '''
    batch_nums = []
    for path in glob.glob(os.path.join(errors_dir, 'batch_*')):
        match = re.fullmatch(r'batch_(\d+)', os.path.basename(path))
        if match and os.path.isdir(path):
            batch_nums.append(int(match.group(1)))

    last_num = max(batch_nums, default=1)
    batch_dir = os.path.join(errors_dir, f'batch_{last_num}')
    if glob.glob(os.path.join(batch_dir, '*')):
        batch_dir = os.path.join(errors_dir, f'batch_{last_num + 1}')

    # Create the folder up front: shutil.move into a non-existent destination
    # renames the source *to* that path instead of moving it inside.
    os.makedirs(batch_dir, exist_ok=True)
    return batch_dir


In [5]:
def load_participant_results(sub_dir, sub):
    '''
    Load one participant's next-word-prediction results file.

    Reads ``{sub_dir}/{sub}_next-word-prediction.csv`` and splits it into the
    demographics answers, the story-experience answers and the test trials.

    Parameters
    ----------
    sub_dir : str
        Directory containing the subject's results CSV.
    sub : str
        Subject identifier used in the file name.

    Returns
    -------
    prolific_id
        First non-missing ``prolific_id`` in the file (``False`` if none is recorded).
    demographics : DataFrame
        ``experiment_phase``/``response`` of rows whose phase contains "demographics".
    experience : DataFrame
        ``experiment_phase``/``response`` of rows whose phase contains "experience".
    responses : DataFrame
        Test trials with columns ``critical_word``, ``word_index``,
        ``entropy_group``, ``accuracy_group`` and lower-cased ``response``
        (missing responses are NaN).

    Raises
    ------
    AssertionError
        If there are not exactly 4 demographics rows and 2 experience rows.
    '''
    raw = pd.read_csv(os.path.join(sub_dir, f'{sub}_next-word-prediction.csv'))

    # Take the id before filling NaNs: afterwards missing ids become False and
    # picking from a set could return False instead of the real id.
    ids = raw['prolific_id'].dropna().unique()
    prolific_id = ids[0] if len(ids) else False

    df_results = raw.fillna(False)
    df_results['word_index'] = df_results['word_index'].astype(int)
    phase = df_results['experiment_phase']

    # age, race, ethnicity, gender
    demographics = df_results[phase.str.contains('demographics').fillna(False)]
    demographics = demographics[['experiment_phase', 'response']].reset_index(drop=True)
    assert len(demographics) == 4

    # questions about moth experience + story experience
    experience = df_results[phase.str.contains('experience').fillna(False)]
    experience = experience[['experiment_phase', 'response']].reset_index(drop=True)
    assert len(experience) == 2

    # test trials; copy() so lower-casing writes to an independent frame, not a slice
    response_cols = ['critical_word', 'word_index', 'entropy_group', 'accuracy_group', 'response']
    responses = df_results.loc[phase == 'test', response_cols].copy()
    # non-string responses (missing answers filled with False) become NaN here
    responses['response'] = responses['response'].str.lower()
    responses = responses.reset_index(drop=True)

    return prolific_id, demographics, experience, responses

def add_word_response(dict, key, value):
    '''
    Append ``value`` to the list stored under ``key``, creating the list on
    first use. The mapping is modified in place and also returned.
    '''
    dict.setdefault(key, []).append(value)
    return dict

def aggregate_participant_responses(results_dir, task, sub_mod_list):
    '''
    Stack the test-trial responses of several participants into one long table.

    Parameters
    ----------
    results_dir : str
        Root results directory for the experiment version.
    task : str
        Task name; results are read from ``results_dir/task/<modality>/<subject>``.
    sub_mod_list : iterable of (subject, modality)
        Participants to include. Subjects whose directory does not exist are
        reported and skipped.

    Returns
    -------
    DataFrame
        One row per (participant, trial) with columns ``prolific_id``,
        ``modality``, ``subject``, ``word_index``, ``response`` (missing -> ''),
        ``ground_truth`` (lower-cased critical word), ``entropy_group`` and
        ``accuracy_group``.
    '''
    columns = ['prolific_id', 'modality', 'subject', 'word_index', 'response',
               'ground_truth', 'entropy_group', 'accuracy_group']
    records = []

    for sub, mod in sub_mod_list:
        sub_dir = os.path.join(results_dir, task, mod, sub)
        print(sub, mod)
        if not os.path.exists(sub_dir):
            print(f'File not exists: {mod}, {sub}')
            continue

        # only the responses are aggregated; demographics/experience are ignored
        current_id, _, _, responses = load_participant_results(sub_dir, sub)
        responses['response'] = responses['response'].fillna('')

        trial_cols = ['word_index', 'response', 'critical_word', 'entropy_group', 'accuracy_group']
        for index, response, critical_word, entropy_group, accuracy_group in responses[trial_cols].values:
            records.append({
                'prolific_id': current_id,
                'modality': mod,
                'subject': sub,
                'word_index': index,
                'response': response,
                'ground_truth': critical_word.lower(),
                'entropy_group': entropy_group,
                'accuracy_group': accuracy_group,
            })

    # Build the frame once: appending with df.loc[len(df)] copies on every row (quadratic).
    return pd.DataFrame(records, columns=columns)

def get_human_probs(responses):
    '''
    Empirical distribution of participants' responses.

    Returns ``(probs, unique)`` where ``unique`` is the sorted array of distinct
    responses produced by ``np.unique`` and ``probs[i]`` is the fraction of
    responses equal to ``unique[i]``.
    '''
    values, tallies = np.unique(responses, return_counts=True)
    total = sum(tallies)
    return tallies / total, values

def strip_punctuation(text):
    '''
    Remove every character that is not an ASCII letter or digit.

    Note this also removes whitespace and accented/non-ASCII letters,
    e.g. ``"it's a dog!" -> "itsadog"``.
    '''
    non_alnum = re.compile('[^A-Za-z0-9]+')
    return non_alnum.sub('', text)

## Functions for cleaning data

In [None]:
import enchant, string, time
from IPython.display import clear_output

def clean_participant_responses(df_results, df_transcript, pause=5):
    '''
    Interactively review and correct a participant's free-text responses.

    Operates on the rows of ``df_results`` whose ``experiment_phase`` is
    ``'test'`` and runs three passes, each prompting via ``prompt_correct_response``:

    1. Spellcheck   -- responses the en_US enchant dictionary does not recognise,
                       or that are a substring of ``string.punctuation``.
    2. Find phrases -- remaining responses containing more than one word.
    3. Final check  -- every response not yet reviewed; the user is asked
                       whether it needs correcting.

    Each response is reviewed at most once. Non-string (e.g. missing values
    filled with ``False``) or blank responses are skipped entirely: enchant
    cannot check them and the prompt cannot display them.

    Parameters
    ----------
    df_results : DataFrame
        A participant's full results table; corrections are written back into it.
    df_transcript : DataFrame
        Story transcript with a ``Word_Written`` column, used to show context.
    pause : float, default 5
        Seconds to wait after each stage banner before prompting.

    Returns
    -------
    DataFrame
        ``df_results`` with the corrected ``response`` values for the test rows.
    '''
    # positions of the test trials in df_results -> used to write corrections back
    response_indices = np.where(df_results['experiment_phase'] == 'test')[0]
    df_responses = df_results.iloc[response_indices, :].reset_index(drop=True)

    # blank / non-string responses count as already handled so no stage prompts for them
    checked_indices = {
        index for index, response in df_responses['response'].items()
        if not isinstance(response, str) or not response.strip()
    }

    enc_dict = enchant.Dict("en_US")

    def _review(index, row, prompt_correction):
        # prompt for one row, store the (possibly) corrected row, mark it done
        corrected = prompt_correct_response(row, df_transcript, enc_dict, prompt_correction=prompt_correction)
        df_responses.iloc[corrected.name] = corrected
        checked_indices.add(index)

    # 1. spellcheck
    _show_stage_banner('### Running spellcheck ###', pause, clear=False)
    for index, row in df_responses.iterrows():
        if index in checked_indices:
            continue
        response = row['response']
        if (not enc_dict.check(response)) or (response in string.punctuation):
            _review(index, row, prompt_correction=False)

    # 2. multi-word responses
    _show_stage_banner('####### Find phrases #####', pause)
    for index, row in df_responses.iterrows():
        if index not in checked_indices and len(row['response'].split()) > 1:
            _review(index, row, prompt_correction=False)

    # 3. everything not yet reviewed (asks y/n first)
    _show_stage_banner('####### Final check ######', pause)
    for index, row in df_responses.iterrows():
        if index not in checked_indices:
            _review(index, row, prompt_correction=True)

    # iloc assignment is positional, so the re-indexed frame maps back row-for-row
    df_results.iloc[response_indices] = df_responses

    return df_results


def _show_stage_banner(title, pause, clear=True):
    '''Print a boxed stage title (optionally clearing the cell output first), then wait ``pause`` seconds.'''
    if clear:
        clear_output(wait=False)
    print('##########################\n' + title + '\n##########################\n\n')
    time.sleep(pause)

def prompt_correct_response(df_response, df_transcript, enc_dict, range_display=7, prompt_correction=False):
    '''
    Show one response in its story context and let the user correct it.

    The response is printed highlighted in place of the target word, surrounded
    by up to ``range_display`` transcript words (``Word_Written``) on each side,
    followed by the ground-truth ``critical_word``.

    If ``prompt_correction`` is truthy the user is first asked
    ``Needs correction? [y/n]`` and any answer other than ``y`` leaves the
    response unchanged. Otherwise (or after ``y``) enchant's suggestions are
    listed, numbered from 1, and the user's input is interpreted as:

    * empty            -- keep the response unchanged;
    * ``-999``         -- abort the session (``sys.exit(0)``);
    * a number 1..N    -- replace the response with that suggestion;
    * ``''`` or ``""`` -- replace the response with an empty string;
    * anything else    -- replace the response with the typed text.

    Parameters
    ----------
    df_response : Series
        One test-trial row with ``word_index``, ``response`` and ``critical_word``.
    df_transcript : DataFrame
        Story transcript with a ``Word_Written`` column.
    enc_dict : enchant.Dict
        Dictionary used for suggestions.
    range_display : int, default 7
        Number of context words to show on each side.
    prompt_correction : bool, default False
        Ask whether the response needs correcting before prompting for it.

    Returns
    -------
    Series
        ``df_response`` (modified in place if the user corrected it).
    '''
    word_index = df_response['word_index']
    response = df_response['response']
    ground_truth = df_response['critical_word']
    n_words = len(df_transcript)

    # Clamp the context window to the transcript. end_index used to become None
    # near the end of the story, and the `end_index < n_words` check raised TypeError.
    start_index = max(word_index - range_display, 0)
    end_index = min(word_index + range_display, n_words)

    start_context = df_transcript['Word_Written'].iloc[start_index:word_index]
    end_context = df_transcript['Word_Written'].iloc[word_index + 1:end_index]

    string_to_print = ".... " if start_index > 0 else ""
    string_to_print += " ".join(start_context) + " " + "\033[43;30m" + response + "\033[m" + " " + " ".join(end_context)
    if end_index < n_words:
        string_to_print += " ...."

    clear_output(wait=False)

    print("\n\nCurrent Word: " + string_to_print)
    print("Ground Truth: ", ground_truth)

    if prompt_correction:
        needs_correction = input('\nNeeds correction? [y/n]: ')
        if needs_correction != 'y':
            return df_response

    # enchant cannot produce suggestions for an empty/blank string
    suggestions = enc_dict.suggest(response) if response.strip() else []
    print("\nSuggestions: " + ", ".join(f"{i}: {s}" for i, s in enumerate(suggestions, start=1)))
    correct_word = input("\nCorrect Version: ")

    if correct_word == "-999":
        sys.exit(0)
    elif len(correct_word) == 0:  # skip
        return df_response
    elif correct_word.isdecimal() and 1 <= int(correct_word) <= len(suggestions):  # 1-based suggestion
        df_response['response'] = suggestions[int(correct_word) - 1]
    elif correct_word in ("''", '""'):  # remove the word
        df_response['response'] = ""
    else:
        df_response['response'] = correct_word

    return df_response

## Load and clean data files

In [6]:
# check all the files
checker, approve_ids = check_used_files(EXPERIMENT_NAME, EXPERIMENT_VERSION, TASK, clean_errors=True)

# find the subject list based on complete files
file_idx, modality, sub_list, prolific_ids = zip(*checker['complete'])
sub_mod_list = list(zip(sub_list, modality))

print (len(set(approve_ids)))
print ('\n'.join(approve_ids))

300
5df0932af9a01d0dabd9313d
658dc5ae8ed0c5190d385266
60bf4ae16dbab59c9fdd4218
6346f314592ee4d8c3d84e57
5dccabc026eb869389043084
65a3ec70fa2315c1f83c19af
62a999b8ed876efe5ffe9599
654658c94edb85e1de07f87a
62b2f74cf3794e252cc7d1a7
644963b432e6a701b58ea4ac
65a19b9691a3f0e753743321
5e1297656e8aab8e8a1b3b76
62682277645054f5802459b8
56a8930d7f2472000c93764e
659aba3aa97e20e203171594
5e3afe879d5f1e30b75b9ca6
596155a998cf77000106f8d8
599a9252bbe848000179676e
5d8e154af2858200171fdb95
591c70f8f399850001c51444
59d4c100078dbe0001951236
5e3723a2c0b2896ad554be73
5ecd4b104b4dc408fcf4eb16
5c4b4d5c2cfe4d00018485cc
5dcddf9f4d51e40a5292727d
62a212208e8395cfb1c4e42b
643e64c922e6c2f53d73975f
5967e9a831394d0001abdb3f
61502dee7a3a5468adbb2222
5bec4ebcf2dba6000166d420
5a1342b7f2e3460001edbfd2
65a68cb02b70524088c53064
59d4b084115096000190dbce
62f108ab1d119c891d45ae6a
5f9e03f3517e0d3ad1eab67b
62eac867210f0e2e8d322d44
6147c5874b61952e42e9b2bd
5ea4cebb8944a8495280db42
5c60501d767686000100150c
60a6b2ed58d0aaed52f1d

## Load and clean participants' data

'/dartfs/rc/lab/F/FinnLab/tommy/isc_asynchrony_behavior/experiments/next-word-prediction/results/final-multimodal-01'

In [281]:
cleaned_results_dir = os.path.join(BASE_DIR, 'experiments', EXPERIMENT_NAME, 'cleaned-results', EXPERIMENT_VERSION)

df_transcript = pd.read_csv(os.path.join(preproc_dir, TASK, f'{TASK}_transcript-preprocessed.csv'))

for sub, modality in sub_mod_list:

    sub_cleaned_dir = os.path.join(cleaned_results_dir, TASK, modality, sub)
    out_fn = os.path.join(sub_cleaned_dir, f'{sub}_next-word-prediction.csv')

    utils.attempt_makedirs(sub_cleaned_dir)

    # already-cleaned subjects are skipped, so an interrupted session can be resumed
    if os.path.exists(out_fn):
        print(f'File exists: {modality} {sub}')
        continue

    print(f'Correcting: {modality} {sub}')

    ############################
    #### Load subject data #####
    ############################

    sub_dir = os.path.join(results_dir, TASK, modality, sub)

    # same loading as load_participant_results, but the full table is kept
    df_results = pd.read_csv(os.path.join(sub_dir, f'{sub}_next-word-prediction.csv')).fillna(False)
    df_results['word_index'] = df_results['word_index'].astype(int)

    ############################
    ###### Check responses #####
    ############################

    df_results = clean_participant_responses(df_results, df_transcript)

    ############################
    #### Save cleaned data #####
    ############################

    # Write to a temporary file and rename it into place: an interrupt during
    # to_csv would otherwise leave a truncated file that the exists-check above
    # would treat as finished on the next run.
    tmp_fn = out_fn + '.tmp'
    df_results.to_csv(tmp_fn, index=False)
    os.replace(tmp_fn, out_fn)

    print(f'Saved file for {modality} {sub}')



Current Word: .... as the door opens I hear the [43;30mtv[m of television come out and on ....
Ground Truth:  blare
\Suggestions:  ['TV', 't', 'v', 'ts', 'ti', 'iv', 'ta', 'av', 'tn', 'tr', 'to', 'Av']


Exception ignored in: <function Dict.__del__ at 0x2ade688c0af0>
Traceback (most recent call last):
  File "/dartfs/rc/lab/F/FinnLab/tommy/conda/envs/dark_matter/lib/python3.9/site-packages/enchant/__init__.py", line 555, in __del__
    try:
KeyboardInterrupt: 


KeyboardInterrupt: Interrupted by user