In [1]:
# imports
import os
import sys
import pickle
import time
import random
import copy
import itertools

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
from peft import LoraConfig, get_peft_model, TaskType
import matplotlib.pyplot as plt

# add local paths for module imports
sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

# import project-specific utilities
from BERT.data import TextDataset
from finetune_bert_utils import (
    get_sliding_window_embeddings,
    aggregate_embeddings,
    downsample_word_vectors_torch,
    load_fmri_data,
    get_fmri_data
)

# set computation precision and device
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# set base path for data access
data_path = '/ocean/projects/mth240012p/shared/data'

# interpretation-specific packages
import shap
import lime
from lime.lime_tabular import LimeTabularExplainer
import pandas as pd
from scipy.stats import pearsonr

In [2]:
# load word sequences for each story
# NOTE(security): pickle.load executes arbitrary code from the file; only use trusted data.
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file)  # {story_id: WordSequenceObject}

# extract story IDs and split them into train/val/test sets
# Story IDs are the file names under subject2 with their last 4 characters removed
# (presumably a 4-character extension such as ".npy" -- TODO confirm).
# NOTE(review): os.listdir returns entries in arbitrary order, and both the split below
# and story_name_to_idx depend on that order. Presumably this must reproduce the split
# used when the checkpoints were trained, so do not sort here without re-checking that.
stories = [i[:-4] for i in os.listdir(f'{data_path}/subject2')]
# 60% train, then the remaining 40% is halved into validation and test (20% / 20%).
train_stories, temp_stories = train_test_split(stories, train_size=0.6, random_state=214)
val_stories, test_stories = train_test_split(temp_stories, train_size=0.5, random_state=214)

# mapping from story names to index
# The index is the story's position in `stories`; forward_pass uses it to pick rows
# of the batch-tokenized `input_ids` / `attention_mask`.
story_name_to_idx = {story: i for i, story in enumerate(stories)}

In [3]:
# tokenizer and base BERT model
# `model_name` is reused later whenever a fresh BertModel is instantiated.
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
# This instance is not moved to `device` here; the prediction and interpretation cells
# build their own model with .to(device), and `base_model` is rebound later on.
base_model = BertModel.from_pretrained(model_name)

In [4]:
# Build one whitespace-joined transcript per training story and wrap the
# transcripts in the project's TextDataset (no length cap: max_len=sys.maxsize).
train_text = [
    " ".join(wordseqs[story_id].data).strip()
    for story_id in train_stories
]
train_dataset = TextDataset(
    train_text,
    tokenizer,
    max_len=sys.maxsize,
)

In [5]:
# Slice applied to every story's downsampled feature matrix inside forward_pass:
# drops the first 5 and the last 10 rows.
trim_range = (5, -10)

# The tokenizer created earlier from `model_name` ("bert-base-uncased") is reused here;
# reloading the same pretrained tokenizer a second time was redundant. forward_pass
# relies on this same global `tokenizer` for per-word token counts.

# Full text of every story, in the same order as `stories` so that row k of the
# tensors below corresponds to story_name_to_idx == k.
# (The previous loop also tokenized each story word by word and threw the result
# away; forward_pass recomputes those counts itself, so that work is dropped.)
texts = [" ".join(wordseqs[story].data).strip() for story in stories]

# Tokenize all stories as one batch, padded to the longest story, without special
# tokens and without truncation.
tokenlized_stories = tokenizer(
    texts,
    add_special_tokens=False,
    padding="longest",
    truncation=False,
    max_length=sys.maxsize,
    return_token_type_ids=False,
    return_tensors="pt",
)
input_ids = tokenlized_stories["input_ids"].to(device)
attention_mask = tokenlized_stories["attention_mask"].to(device)

In [6]:
# embedding and feature extraction
def forward_pass(current_stories, base_model):
    """
    Extract downsampled and aggregated word embeddings for a set of stories.

    Steps: select the pre-tokenized rows for the requested stories, run them through
    `get_sliding_window_embeddings`, average the token embeddings belonging to each
    word, downsample with `downsample_word_vectors_torch`, trim each story with
    `trim_range`, and combine the stories with `aggregate_embeddings`.

    Args:
        current_stories (list of str): List of story IDs (keys of `story_name_to_idx`
            and `wordseqs`).
        base_model (transformers.BertModel): Model used to embed the tokens. It should
            live on the same device as the global `input_ids`.

    Returns:
        torch.Tensor: Aggregated feature representations, as returned by
        `aggregate_embeddings`.
    """
    story_rows = torch.tensor(
        [story_name_to_idx[story] for story in current_stories],
        device=input_ids.device,
    )
    # Indexing keeps the result on the source tensor's device, so no extra .to() is needed.
    selected_input_ids = input_ids[story_rows]
    selected_attention_mask = attention_mask[story_rows]

    embeddings = get_sliding_window_embeddings(base_model, selected_input_ids, selected_attention_mask)

    features = {}
    for story_pos, story in enumerate(current_stories):
        words = wordseqs[story].data
        # Tokenizing word by word gives the number of sub-word tokens per word, which
        # is how the flat token embeddings are mapped back onto words.
        tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
        story_embeddings = embeddings[story_pos]

        word_embeddings = []
        start = 0
        # A distinct loop name is used here: the original reused `i` and silently
        # overwrote the enumerate index of the outer loop.
        for n_tokens in map(len, tokens):
            end = start + n_tokens
            if n_tokens != 0:
                word_embedding = story_embeddings[start:end].mean(dim=0)
            else:
                # Words that produce no tokens get a zero vector with the same
                # dtype and device as the embeddings they are stacked with.
                word_embedding = story_embeddings.new_zeros(story_embeddings.size(1))
            word_embeddings.append(word_embedding)
            start = end

        features[story] = torch.stack(word_embeddings)

    features = downsample_word_vectors_torch(current_stories, features, wordseqs)
    for story in current_stories:
        features[story] = features[story][trim_range[0]:trim_range[1]]

    aggregated_features = aggregate_embeddings(features, current_stories)
    return aggregated_features

In [7]:
# load fMRI data
# Only the test stories are loaded here. Later cells iterate `fmri_data.keys()` as
# subject IDs and pass `fmri_data` to get_fmri_data to obtain ground-truth responses.
fmri_data = load_fmri_data(test_stories, data_path)

In [8]:
# LoRA config and classifier loading
weight_decay = 1e-2
lora_model = True
lora_rank = 8

# Single place to change where checkpoints live (the path used to be repeated in
# both branches below).
# was /ocean/projects/mth240012p/azhang19/lab3/classifier_ckpts
ckpt_dir = '/jet/home/azhang19/stat 214/stat-214-lab3-group6/code/classifier_ckpt'

# NOTE(security): weights_only=False unpickles arbitrary Python objects (here, whole
# classifier modules). Only load checkpoint files you trust.
# map_location=device places the loaded tensors on the device used everywhere else in
# this notebook, so a checkpoint saved on GPU can also be opened on a CPU-only machine.
if lora_model:
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank * 2,
        target_modules=['query', 'value'],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    ckpt = torch.load(
        f'{ckpt_dir}/best_lora_wd{weight_decay}_r{lora_rank}.pth',
        map_location=device,
        weights_only=False,
    )
    # Each checkpoint entry holds the fitted classifier module and the state dict of
    # the LoRA-adapted BERT for that key (keys are used as subject IDs downstream).
    classifiers = {key: entry['classifier_module'] for key, entry in ckpt.items()}
    lora_weights = {key: entry['lora_state_dict'] for key, entry in ckpt.items()}
else:
    classifiers = torch.load(
        f'{ckpt_dir}/best_classifiers{weight_decay}.pth',
        map_location=device,
        weights_only=False,
    )

In [9]:
# make predictions on fMRI data
def make_prediction(pred_stories):
    """
    Generate fMRI predictions and return ground-truth values.

    For every subject in `fmri_data`, a fresh BERT model is instantiated (with that
    subject's LoRA weights loaded when `lora_model` is true), features are extracted
    with `forward_pass`, and the subject's classifier is applied to them.

    Args:
        pred_stories (list of str): List of story IDs to evaluate. They must be
            available in `fmri_data`, which is loaded for the test stories only.

    Returns:
        tuple of dict: (predicted_fmri, true_fmri) keyed by subject ID.
    """
    with torch.inference_mode():
        pred_fmri = {}
        true_fmri = {}

        # The ground truth does not depend on the subject-specific model, so it is
        # fetched once instead of once per subject.
        true_by_subject = get_fmri_data(pred_stories, fmri_data)

        for subj in fmri_data.keys():
            # Named `subj_model` so the notebook-level `base_model` is not shadowed.
            # A new model per subject presumably keeps one subject's adapters from
            # leaking into the next subject's model.
            subj_model = BertModel.from_pretrained(model_name).to(device).eval()
            if lora_model:
                subj_model = get_peft_model(subj_model, config).to(device)
                subj_model.load_state_dict(lora_weights[subj])
                subj_model.eval()

            # The model is passed explicitly so it and the token tensors share a device.
            features = forward_pass(pred_stories, subj_model)
            pred_fmri[subj] = classifiers[subj](features)
            true_fmri[subj] = true_by_subject[subj]

    return pred_fmri, true_fmri

In [None]:
# test make_predictions()
# Smoke run over all test stories; the results are kept as notebook-level variables.
# (top_voxels later uses the same two names, but only as its own local variables.)
test_pred_fmri, test_true_fmri = make_prediction(test_stories)

In [10]:
# identify top correlated voxels
def top_voxels(pred_stories, top_perc=1):
    """
    Select top-percentile voxels based on prediction correlation.

    Each story is predicted on its own with `make_prediction`; for every subject the
    Pearson correlation between predicted and true responses is computed per voxel,
    and the voxels at or above the (100 - top_perc)-th percentile are kept.

    Voxels whose predicted or true signal is constant have an undefined (NaN)
    correlation. They are ignored when computing the threshold and are never
    selected. Previously a single NaN turned the threshold into NaN and silently
    produced an empty selection.

    Args:
        pred_stories (list of str): Stories to evaluate.
        top_perc (float): Top percentile of voxels to keep, in [0, 100].

    Returns:
        dict: Mapping from story → subject → voxel indices.

    Raises:
        ValueError: If `top_perc` is outside [0, 100]. This is checked before any
            model is run.
    """
    if not 0 <= top_perc <= 100:
        raise ValueError(f"top_perc must be within [0, 100], got {top_perc}")

    result = {}

    def as_float_array(values):
        # Accept both torch tensors and array-likes.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.float64)

    def voxelwise_corr(y_pred, y_true):
        # Column-wise Pearson r, vectorised over all voxels at once.
        y_true = as_float_array(y_true)
        y_pred = as_float_array(y_pred)[:, :y_true.shape[1]]

        pred_centered = y_pred - y_pred.mean(axis=0)
        true_centered = y_true - y_true.mean(axis=0)
        covariance = (pred_centered * true_centered).sum(axis=0)
        norm = np.sqrt((pred_centered ** 2).sum(axis=0) * (true_centered ** 2).sum(axis=0))

        # 0 / 0 for constant columns is expected and yields NaN, as pearsonr does.
        with np.errstate(divide="ignore", invalid="ignore"):
            corr = covariance / norm
        return np.clip(corr, -1.0, 1.0)

    for story in pred_stories:
        story_dict = {}
        story_pred_fmri, story_true_fmri = make_prediction([story])
        for subj in fmri_data.keys():
            corr = voxelwise_corr(story_pred_fmri[subj], story_true_fmri[subj])
            # nanpercentile skips undefined correlations; NaN >= thresh is False,
            # so those voxels are excluded from the selection below.
            thresh = np.nanpercentile(corr, 100 - top_perc)
            print(f"{story} ({subj}): {thresh}")
            indices = np.where(corr >= thresh)[0]
            story_dict[subj] = indices

        result[story] = story_dict
    return result

In [None]:
# Top 0.5% best-predicted voxels per subject, computed for the first two test stories only.
# The interpretation cells look up topv[test_story][subj], so the story chosen there
# must be one of these two.
topv = top_voxels(test_stories[:2], top_perc=0.5)

In [16]:
# Persist the selected voxel indices so the expensive top_voxels run can be skipped later.
with open('topv.pkl', 'wb') as topv_file:
    pickle.dump(topv, topv_file)

In [11]:
# Restore the cached voxel selection (story -> subject -> voxel indices).
# Only unpickle files produced by this notebook: pickle can execute arbitrary code.
with open('topv.pkl', 'rb') as topv_file:
    topv = pickle.load(topv_file)

In [None]:
# prepare features and classifiers for interpretation
# This cell rebinds the notebook-level `base_model` to a model on `device`
# (LoRA-adapted for `subj` when lora_model is true), following the same loading
# steps as make_prediction.
subj = "subject3" # "subject2" or "subject3"
base_model = BertModel.from_pretrained(model_name).to(device).eval()
if lora_model:
    base_model = get_peft_model(base_model, config).to(device)
    base_model.load_state_dict(lora_weights[subj])
    base_model.eval()

# The story explained below; it must be a key of `topv`.
test_story = test_stories[1] # adjust index
test_features = forward_pass([test_story], base_model)
# Features of the first five training stories; used later to build the background.
train_features = forward_pass(train_stories[:5], base_model)
# The two wrapper functions below read `selected_voxels` and `classifier` as globals
# at call time, so re-running this cell changes what they explain.
selected_voxels = topv[test_story][subj]
classifier = classifiers[subj]

# Prediction wrappers that expose the classifier as a numpy-in / numpy-out function
def wrapped_shap_model(X_numpy):
    """
    Prediction function handed to the SHAP explainer.

    Converts the input to a float32 tensor on `device`, runs the global
    `classifier` without gradient tracking, and keeps only the columns listed
    in the global `selected_voxels`.

    Args:
        X_numpy (np.ndarray): Input features, one row per sample.

    Returns:
        np.ndarray: Classifier outputs restricted to the selected voxels.
    """
    with torch.no_grad():
        inputs = torch.tensor(X_numpy, dtype=torch.float32).to(device)
        predictions = classifier(inputs)
    voxel_predictions = predictions[:, selected_voxels]
    return voxel_predictions.cpu().numpy()

def wrapped_lime_model(X_numpy):
    """
    Prediction function handed to the LIME explainer.

    The input is processed in chunks of 128 rows. Each chunk receives an extra
    axis at position 1 before it is passed to the global `classifier`; that axis
    is indexed away again when the columns listed in the global `selected_voxels`
    are extracted.

    Args:
        X_numpy (np.ndarray): Input features, one row per sample.

    Returns:
        np.ndarray: Classifier outputs restricted to the selected voxels,
        concatenated over all chunks in input order.
    """
    batch_size = 128
    inputs = torch.tensor(X_numpy, dtype=torch.float32).to(device)
    n_rows = inputs.size(0)

    collected = []
    with torch.no_grad():
        start = 0
        while start < n_rows:
            chunk = inputs[start:start + batch_size].unsqueeze(1)
            chunk_out = classifier(chunk)
            collected.append(chunk_out[:, 0, selected_voxels].cpu())
            start += batch_size

    stacked = torch.cat(collected, dim=0)
    return stacked.numpy()


In [13]:
# Numpy copies of the feature tensors, as required by the SHAP and LIME APIs.
X_test, X_train = (
    feats.detach().cpu().numpy()
    for feats in (test_features, train_features)
)
# Background reduced to a single row: the per-feature mean of the training features.
# (An alternative that was tried and left disabled: stacking the per-story training
# features of train_stories[:5] into one full background matrix.)
background = np.mean(X_train, axis=0, keepdims=True)

In [None]:
# initialize SHAP KernelExplainer using a wrapped prediction model and background dataset
# the model should return outputs shaped (n_samples, n_voxels)
# `background` here is the single mean row of the training features.
shap_explainer = shap.KernelExplainer(wrapped_shap_model, background)

# compute SHAP values for each sample in the test set
# NOTE(review): the layout of the result (a list of per-voxel (n_samples, n_features)
# arrays versus one stacked array) presumably depends on the installed shap version --
# confirm before indexing into shap_values.
shap_values = shap_explainer.shap_values(X_test)

In [14]:
# initialize LIME TabularExplainer on the training feature matrix
# LIME derives per-feature means and standard deviations from `training_data` and uses
# them to draw and scale its perturbed samples. The one-row `background` built for SHAP
# has zero variance in every feature, so it cannot provide those statistics; the full
# training feature matrix is used instead.
# we use regression mode since the model returns continuous values
lime_explainer = LimeTabularExplainer(
    training_data = X_train,
    mode = "regression",
    # Integer names on purpose: compute_lime uses the names returned by the explanation
    # directly as row indices.
    feature_names = list(range(X_test.shape[1])),
    discretize_continuous=False, # keep features continuous instead of binning
    random_state=214, # fixed seed so the sampled perturbations are reproducible
)

In [None]:
def compute_lime(chunk_i):
    """
    Explain one test sample with LIME, separately for every selected voxel.

    Reads the notebook-level `X_test`, `lime_explainer`, `num_features` and
    `num_voxels` at call time.

    Parameters:
    -----------
    chunk_i : int
        Row of X_test to explain.

    Returns:
    --------
    np.ndarray
        Array of shape (num_features, num_voxels). Entry [feature, voxel] is the LIME
        weight of that feature for that voxel. Only the features returned by each
        explanation (at most 10 per voxel) are filled in; all others stay 0.
    """
    def make_voxel_predictor(column):
        # LIME explains a single output at a time, so expose one voxel column.
        def predict_single_voxel(x):
            return wrapped_lime_model(x)[:, column]
        return predict_single_voxel

    weights = np.zeros((num_features, num_voxels))
    for voxel_i in range(num_voxels):
        explanation = lime_explainer.explain_instance(
            data_row=X_test[chunk_i],
            predict_fn=make_voxel_predictor(voxel_i),
            num_features=10,
        )
        # The reported feature identifiers are used directly as row indices, which
        # relies on the explainer having been created with integer feature names.
        for feature, weight in explanation.as_list():
            weights[feature, voxel_i] = weight

    # progress indicator: sample indices printed on one line
    print(chunk_i, end=" ")
    return weights

# Dimensions shared by compute_lime and the stacking step:
# test samples, input features, and explained model outputs (voxels).
num_chunks, num_features = X_test.shape[:2]
num_voxels = len(selected_voxels)

# One (num_features, num_voxels) weight matrix per test sample, in sample order.
lime_values_list = list(map(compute_lime, range(num_chunks)))

In [16]:
# Combine the per-sample matrices into one float array of shape
# (n_samples, n_features, n_voxels); the reshape also covers the zero-sample case.
lime_values = np.asarray(lime_values_list, dtype=np.float64).reshape(
    num_chunks, num_features, num_voxels
)

In [17]:
# Persist the stacked LIME weights so the slow explanation loop need not be re-run.
with open('lime.pkl', 'wb') as lime_file:
    pickle.dump(lime_values, lime_file)

In [17]:
# Persist the SHAP values computed by the KernelExplainer.
with open('shap.pkl', 'wb') as shap_file:
    pickle.dump(shap_values, shap_file)

In [2]:
# Restore the cached LIME weights.
# Only unpickle files produced by this notebook: pickle can execute arbitrary code.
with open('lime.pkl', 'rb') as lime_file:
    lime_values = pickle.load(lime_file)

In [23]:
# Restore the cached SHAP values.
# Only unpickle files produced by this notebook: pickle can execute arbitrary code.
with open('shap.pkl', 'rb') as shap_file:
    shap_values = pickle.load(shap_file)