In [4]:
# standard library
import os
import sys
import pickle
# third-party numerical / plotting / ML utilities
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# add local directories to sys.path to allow module imports
# NOTE(review): the first entry is relative to the current working directory and the
# second is an absolute, user-specific path; on another machine the project imports
# below only resolve if the relative 'code' directory exists -- consider making this
# configurable.
sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

# project-local helpers (found through the sys.path entries added above)
from BERT.data import TextDataset
from finetune_bert_utils import (
    get_sliding_window_embeddings,
    aggregate_embeddings,
    downsample_word_vectors_torch,
    load_fmri_data,
    get_fmri_data
)

# define the base path for data access
# (absolute path; `raw_text.pkl` and the `subject2` directory are read from here)
data_path = '/ocean/projects/mth240012p/shared/data'

In [5]:
# Read the pickled raw text from the shared data directory.
# NOTE(review): pickle.load can execute arbitrary code; only use it on trusted files.
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file)

# Story names are the entries of the subject2 directory with their last four
# characters dropped (presumably the ".pkl" extension -- confirm every entry has it).
# NOTE(review): os.listdir returns entries in arbitrary order, so the seeded splits
# below are only reproducible when the directory listing order is stable.
stories = [filename[:-4] for filename in os.listdir(f'{data_path}/subject2')]

# 60% of the stories go to training; the remainder is halved into validation and test.
train_stories, temp_stories = train_test_split(stories, train_size=0.6, random_state=214)
val_stories, test_stories = train_test_split(temp_stories, train_size=0.5, random_state=214)

# Position of each story name within `stories`.
story_name_to_idx = {name: position for position, name in enumerate(stories)}

In [38]:
def read_values(score, prefix, subj):
    """
    Read the pickled interpretation scores for one story/subject pair.

    The file name is built as ``{score}_{prefix}_subj{subj}.pkl`` and is opened
    relative to the current working directory.

    Args:
        score (str): Score type, used as the first file-name component
            (e.g. "lime" or "shap").
        prefix (str): Story name prefix, used as the second file-name component.
        subj (int): Subject ID, used in the ``subj{subj}`` file-name component.

    Returns:
        The object stored in the pickle file (presumably an np.ndarray of
        scores -- confirm against the code that writes these files).

    Raises:
        FileNotFoundError: If no file with the derived name exists.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only read trusted files.
    with open(f"{score}_{prefix}_subj{subj}.pkl", "rb") as handle:
        return pickle.load(handle)

In [55]:
def story_to_df(story):
    """
    Build a word-level table for one story from its word chunks.

    Reads the module-level ``wordseqs`` mapping, takes ``wordseqs[story].chunks()``
    and discards the first 5 and the last 10 chunks. Every remaining chunk
    contributes one row per word; a chunk with no words contributes a single
    placeholder row whose word is None and whose word_id is 1. Falsy words
    (e.g. empty strings) are stored as None.

    Args:
        story (str): Story name (key into ``wordseqs``).

    Returns:
        pd.DataFrame: Columns ``chunk_id`` (1-based index of the chunk after
        trimming), ``word_id`` (1-based position of the word inside its chunk)
        and ``word`` (the word text or None).
    """
    # trim beginning and end for correct dimensions
    trimmed_chunks = wordseqs[story].chunks()[5:-10]

    rows = []  # (chunk_id, word_id, word) triples in output order
    for chunk_no, chunk in enumerate(trimmed_chunks, start=1):
        if chunk.size == 0:
            # keep empty chunks visible as a single placeholder row
            rows.append((chunk_no, 1, None))
            continue
        for word_no, token in enumerate(chunk, start=1):
            rows.append((chunk_no, word_no, token if token else None))

    return pd.DataFrame({
        "chunk_id": [row[0] for row in rows],
        "word_id": [row[1] for row in rows],
        "word": [row[2] for row in rows],
    })

In [56]:
def values_to_df(vals, story, av=True):
    """
    Convert interpretation scores and story text to a merged DataFrame.

    The scores are (optionally) made absolute, averaged over axis 1, and turned
    into one row per chunk with columns ``v_1 .. v_k``. That table is then
    left-joined on ``chunk_id`` with the word-level table from ``story_to_df``,
    so a chunk's averaged scores are repeated on each of its word rows.

    Args:
        vals (np.ndarray): Interpretation score values. Must be 3-D, because the
            mean over axis 1 has to leave a 2-D array; axis 0 is treated as the
            chunk axis. (Axis 1 is presumably the embedding/feature dimension --
            confirm against the code that produced the scores.)
        story (str): Story name, forwarded to ``story_to_df``.
        av (bool): Whether to take absolute values before averaging.

    Returns:
        pd.DataFrame: Columns ``v_1 .. v_k``, ``chunk_id``, ``word_id``, ``word``.

    Raises:
        AssertionError: If the number of score rows differs from the largest
            ``chunk_id`` of the story table.
    """
    vals = np.abs(vals) if av else vals
    mean_vals = np.mean(vals, axis=1)  # collapse axis 1; result is (chunks, k)

    columns = [f"v_{i}" for i in range(1, mean_vals.shape[1] + 1)]
    df_values = pd.DataFrame(mean_vals, columns=columns)
    df_values["chunk_id"] = df_values.index + 1  # 1-based, matching story_to_df

    df_story = story_to_df(story)

    # ensure alignment of chunks between story text and interpretation values.
    # The maxima are bound to locals first: nesting double quotes inside a
    # double-quoted f-string is a SyntaxError on Python versions before 3.12.
    n_value_chunks = df_values["chunk_id"].max()
    n_story_chunks = df_story["chunk_id"].max()
    assert n_value_chunks == n_story_chunks, (
        f"chunk_id mismatch: {n_value_chunks} vs. {n_story_chunks}"
    )

    return pd.merge(df_values, df_story, on="chunk_id", how="left")

In [61]:
# process and save score data for selected test stories and subjects
# (writes one CSV per story/subject pair into the current working directory)
score = "lime" # "shap" or "lime"
select_stories = test_stories[:2]
for story in select_stories:
    prefix = story[:4]  # score files are keyed by the first four characters of the story name
    for subj in [2,3]:
        # use the helpers as they are defined in this notebook (read_values /
        # values_to_df); the previous names read_scores / scores_to_df are not
        # defined here and raise NameError on a fresh kernel
        vals = read_values(score, prefix, subj)
        df = values_to_df(vals, story)
        path_df = f"{score}_{prefix}_subj{subj}.csv"
        df.to_csv(path_df, index=False)
        print(f"Saved {path_df}")

Saved lime_buck_subj2.csv
Saved lime_buck_subj3.csv
Saved lime_laws_subj2.csv
Saved lime_laws_subj3.csv
