In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import pickle
import time
import random
import copy
import itertools

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
import matplotlib.pyplot as plt

sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

from BERT.data import TextDataset
from finetune_bert_utils import get_sliding_window_embeddings, aggregate_embeddings, downsample_word_vectors_torch, load_fmri_data, get_fmri_data

# Let float32 matmuls use faster reduced-precision kernels where the hardware supports them.
torch.set_float32_matmul_precision("high")

# Prefer the GPU when one is visible to torch; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Root of the shared course data (holds raw_text.pkl and per-subject folders such as subject2/).
data_path = '/ocean/projects/mth240012p/shared/data'

In [2]:
# %% Load preprocessed word sequences
# NOTE: pickle.load can execute arbitrary code; only load this file from the trusted shared data directory.
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file)  # dict: {story_id: object whose .data attribute holds the story's words}

# %% Story IDs are the '.npy' filenames in subject2's folder with the extension removed.
# Non-.npy entries (e.g. hidden/system files) are skipped so they cannot turn into bogus story IDs.
# NOTE: os.listdir order is filesystem-dependent and determines the split below. It is kept as-is
# (not sorted) so the split stays identical to the one the saved classifiers were built with — TODO
# confirm, and consider sorting in every notebook at once if the split ever needs to be re-derived.
stories = [name[:-4] for name in os.listdir(f'{data_path}/subject2') if name.endswith('.npy')]

# Three-way split with a fixed random_state: 60% train, then the remaining 40% is halved
# into 20% validation and 20% test.
train_stories, temp_stories = train_test_split(stories, train_size=0.6, random_state=214)
val_stories, test_stories = train_test_split(temp_stories, train_size=0.5, random_state=214)

# Position of each story in `stories`; used to pick that story's row from the batched tokenization.
story_name_to_idx = {story: i for i, story in enumerate(stories)}

  wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}


In [3]:
# Single checkpoint identifier shared by the tokenizer and the encoder so the two always match.
model_name = "bert-base-uncased"

# Fast tokenizer for the checkpoint above.
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path=model_name)

# Pretrained encoder used by forward_pass.
# NOTE(review): unlike input_ids/attention_mask, the model is not moved to `device` here;
# presumably get_sliding_window_embeddings handles device placement — TODO confirm.
base_model = BertModel.from_pretrained(pretrained_model_name_or_path=model_name)

In [4]:
# One whitespace-joined transcript per training story, in train_stories order.
train_text = [" ".join(wordseqs[story_id].data).strip() for story_id in train_stories]

# max_len=sys.maxsize presumably disables any length cap inside TextDataset — TODO confirm.
# NOTE(review): train_dataset is not referenced in the cells shown here.
train_dataset = TextDataset(
    train_text,
    tokenizer,
    max_len=sys.maxsize,
)

In [5]:
# Rows to drop at the start/end of every story's downsampled features;
# forward_pass applies it as features[story][trim_range[0]:trim_range[1]].
trim_range = (5, -10)

# One whitespace-joined transcript per story, in `stories` order, so row i of the batched
# tensors below belongs to the story whose story_name_to_idx value is i.
# The tokenizer loaded earlier is reused; per-word token counts are computed inside
# forward_pass, where they are actually needed.
texts = [" ".join(wordseqs[story].data).strip() for story in stories]

# Tokenize all stories in one batch, padded to the longest story and never truncated;
# attention_mask marks the real (non-padding) tokens of each row.
tokenlized_stories = tokenizer(texts, add_special_tokens=False, padding="longest", truncation=False, max_length=sys.maxsize,
                               return_token_type_ids=False, return_tensors="pt")
input_ids = tokenlized_stories["input_ids"].to(device)
attention_mask = tokenlized_stories["attention_mask"].to(device)

In [6]:
def forward_pass(current_stories):
    """Compute the aggregated BERT features for a list of stories.

    For each story:
      1. Token embeddings come from ``get_sliding_window_embeddings`` applied to
         ``base_model`` and the story's rows of the global ``input_ids`` /
         ``attention_mask``.
      2. The story's words are tokenized one by one to find how many tokens each
         word spans; a word's vector is the mean of its token embeddings, or a zero
         vector when the word produces no tokens.
      3. The word vectors go through ``downsample_word_vectors_torch`` and are then
         trimmed with the global ``trim_range`` slice.
    Finally, all stories are combined with ``aggregate_embeddings``.

    Args:
        current_stories: list of story IDs (keys of ``story_name_to_idx`` and
            ``wordseqs``). It is iterated several times, so pass a list rather
            than a one-shot iterator.

    Returns:
        The result of ``aggregate_embeddings(features, current_stories)``.

    Raises:
        ValueError: if a story's per-word token counts do not add up to the number of
            real tokens in its joined-text tokenization (word/token alignment would
            otherwise be silently wrong).
    """
    idx = [story_name_to_idx[story] for story in current_stories]
    token_embeddings = get_sliding_window_embeddings(base_model, input_ids[idx], attention_mask[idx])

    features = {}
    for story_pos, story in enumerate(current_stories):
        words = wordseqs[story].data
        tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
        token_per_word = [len(word_tokens) for word_tokens in tokens]
        story_embeddings = token_embeddings[story_pos]  # indexed as (token, hidden) below

        # Embeddings come from tokenizing the joined text, while word boundaries come from
        # tokenizing each word separately; if the counts disagree every later word is shifted.
        n_word_tokens = sum(token_per_word)
        n_story_tokens = int(attention_mask[story_name_to_idx[story]].sum())
        if n_word_tokens != n_story_tokens:
            raise ValueError(
                f"Story {story!r}: per-word tokenization gives {n_word_tokens} tokens but the "
                f"joined text has {n_story_tokens}; word/token alignment would be wrong."
            )

        # One chunk of token rows per word (padding rows are excluded by the slice).
        # The zero vector uses the embeddings' own device and dtype so torch.stack never
        # mixes devices or dtypes.
        zero_vector = story_embeddings.new_zeros(story_embeddings.size(1))
        word_chunks = torch.split(story_embeddings[:n_word_tokens], token_per_word)
        word_embeddings = [chunk.mean(dim=0) if chunk.size(0) > 0 else zero_vector for chunk in word_chunks]

        features[story] = torch.stack(word_embeddings)

    features = downsample_word_vectors_torch(current_stories, features, wordseqs)
    for story in current_stories:
        features[story] = features[story][trim_range[0]:trim_range[1]]

    aggregated_features = aggregate_embeddings(features, current_stories)
    return aggregated_features

In [7]:
# Load fMRI responses for every story (train, validation and test) from data_path.
# NOTE(review): the prediction cell iterates fmri_data.keys() and indexes classifiers[subj]
# with them, so the result appears to be keyed by subject — TODO confirm against load_fmri_data.
fmri_data = load_fmri_data(stories, data_path)

In [10]:
# Weight-decay value that selects which checkpoint file to load (the f-string renders 1e-2 as "0.01").
weight_decay = 1e-2

# Per-subject classifiers (indexed by subject in the prediction cell).
# weights_only=False unpickles arbitrary Python objects, which can execute code: only load
# checkpoints produced by this project. map_location puts every stored tensor on `device`, the
# same device the features are computed on, so loading also works on machines without the
# GPU the checkpoint was saved from.
classifiers = torch.load(
    f'/ocean/projects/mth240012p/azhang19/lab3/classifier_ckpts/best_classifiers{weight_decay}.pth',
    map_location=device,
    weights_only=False,
)

In [11]:
with torch.inference_mode():
    # Features for all training stories, computed once and shared by every subject's classifier.
    features = forward_pass(train_stories)
    # The ground-truth lookup takes no subject argument, so call it once instead of once per
    # subject inside the loop (assumes get_fmri_data has no side effects — TODO confirm).
    train_fmri = get_fmri_data(train_stories, fmri_data)
    pred_fmri = {}
    true_fmri = {}
    for subj in fmri_data.keys():
        # This subject's classifier prediction alongside the recorded responses it is compared to.
        pred_fmri[subj] = classifiers[subj](features)
        true_fmri[subj] = train_fmri[subj]
    print(pred_fmri['subject2'].shape)
    print(pred_fmri['subject3'].shape)
    print(true_fmri['subject2'].shape)
    print(true_fmri['subject3'].shape)


torch.Size([20263, 94251])
torch.Size([20263, 95556])
(20263, 94251)
(20263, 95556)
