In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import pickle
import time
import random
import copy
import itertools

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
import matplotlib.pyplot as plt

sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

from BERT.data import TextDataset
from finetune_bert_utils import get_sliding_window_embeddings, aggregate_embeddings, downsample_word_vectors_torch, load_fmri_data, get_fmri_data

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the base path for data access
data_path = '/ocean/projects/mth240012p/shared/data'

In [2]:
# %% Load preprocessed word sequences (likely includes words and their timings)
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}

# %% Get list of story identifiers and split into training and testing sets
# Assumes story data for 'subject2' exists and filenames are story IDs + '.npy'
stories = [i[:-4] for i in os.listdir(f'{data_path}/subject2')] # Extract story IDs from filenames
# Split stories into train and test sets with a fixed random state for reproducibility


# First, use 60% for training and 40% for the remaining data.
train_stories, temp_stories = train_test_split(stories, train_size=0.6, random_state=214)
# Then split the remaining 40% equally to get 20% validation and 20% test.
val_stories, test_stories = train_test_split(temp_stories, train_size=0.5, random_state=214)

story_name_to_idx = {story: i for i, story in enumerate(stories)}

  wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}


In [3]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
base_model = BertModel.from_pretrained(model_name)

In [4]:
train_text = [" ".join(wordseqs[i].data).strip() for i in train_stories]
train_dataset = TextDataset(train_text, tokenizer, max_len=sys.maxsize)

In [5]:
trim_range = (5, -10)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
embeddings = {}

texts = []

for story in stories:
    words = wordseqs[story].data
    texts.append(" ".join(words).strip())
    tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
    token_per_word = [len(i) for i in tokens]
tokenlized_stories = tokenizer(texts, add_special_tokens=False, padding="longest", truncation=False, max_length=sys.maxsize,
                               return_token_type_ids=False, return_tensors="pt")
input_ids = tokenlized_stories["input_ids"].to(device)
attention_mask = tokenlized_stories["attention_mask"].to(device)

In [6]:
def forward_pass(current_stories):
    idx = [story_name_to_idx[story] for story in current_stories]
    embeddings = get_sliding_window_embeddings(base_model, input_ids[idx], attention_mask[idx])

    features = {}
    for i, story in enumerate(current_stories):
        words = wordseqs[story].data
        tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
        token_per_word = [len(i) for i in tokens]
        story_embeddings = embeddings[i]
        word_embeddings = []
        start = 0
        for i in token_per_word:
            end = start + i
            if i != 0:
                word_embedding = story_embeddings[start:end].mean(dim=0)
            else:
                word_embedding = torch.zeros(story_embeddings.size(1), device=device)
            word_embeddings.append(word_embedding)
            start = end
        
        features[story] = torch.stack(word_embeddings)#.cpu().numpy()

    features = downsample_word_vectors_torch(current_stories, features, wordseqs)
    for story in current_stories:
        features[story] = features[story][trim_range[0]:trim_range[1]]

    aggregated_features = aggregate_embeddings(features, current_stories)
    return aggregated_features

In [7]:
fmri_data = load_fmri_data(stories, data_path)

In [8]:
def get_loss(classifiers, sample_stories):
    with torch.inference_mode():
        features = forward_pass(sample_stories)
    features = features.clone()
    pred_fmri = {}
    loss = []
    current_fmri = get_fmri_data(sample_stories, fmri_data)
    for subj in fmri_data.keys():
        pred_fmri[subj] = classifiers[subj](features)
        obj = torch.from_numpy(current_fmri[subj]).float().to(device)
        # Handle NaN values in obj
        obj = torch.nan_to_num(obj, nan=0.0)
        loss.append(nn.functional.mse_loss(pred_fmri[subj], obj))
    return loss

In [9]:
def train_step(classifiers, sample_stories):
    with torch.inference_mode():
        features = forward_pass(sample_stories)
    features = features.clone()
    loss = get_loss(classifiers, sample_stories)
    loss_for_backprop = loss[0] + loss[1]

    optim.zero_grad(set_to_none=True)
    loss_for_backprop.backward()
    optim.step()
    return loss_for_backprop.item(), loss[0].item(), loss[1].item()

In [11]:
epochs = 100
loss_record = np.zeros((epochs, 2, 2))
best_loss = np.zeros(2) + 1e9
best_classifiers = {'subject2': None, 'subject3': None}

#sample_stories = random.sample(stories, 3)
sample_stories = train_stories
#features = forward_pass(sample_stories)

In [20]:
weight_decay = 1e-2
classifiers = {'subject2': nn.Linear(768, 94251, device=device), 'subject3': nn.Linear(768, 95556, device=device)}
optim = torch.optim.AdamW(itertools.chain(*[i.parameters() for i in classifiers.values()]), lr=2e-3, weight_decay=weight_decay, fused=True)

In [12]:
for epoch in range(epochs):
    _, loss_subject2_train, loss_subject3_train = train_step(classifiers, sample_stories)
    with torch.no_grad():
        loss_subject2_val, loss_subject3_val = get_loss(classifiers, val_stories)
        loss_subject2_val, loss_subject3_val = loss_subject2_val.item(), loss_subject3_val.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss_subject2_train:.4f}, {loss_subject3_train:.4f}, Val Loss: {loss_subject2_val:.4f}, {loss_subject3_val:.4f}")
    loss_record[epoch, 0, 0] = loss_subject2_train
    loss_record[epoch, 0, 1] = loss_subject2_val
    loss_record[epoch, 1, 0] = loss_subject3_train
    loss_record[epoch, 1, 1] = loss_subject3_val
    if loss_subject2_val < best_loss[0]:
        print(f"New best validation loss for subject2: {loss_subject2_val:.4f}")
        best_loss[0] = loss_subject2_val
        best_classifiers['subject2'] = copy.deepcopy(classifiers['subject2'])
    if loss_subject3_val < best_loss[1]:
        print(f"New best validation loss for subject3: {loss_subject3_val:.4f}")
        best_loss[1] = loss_subject3_val
        best_classifiers['subject3'] = copy.deepcopy(classifiers['subject3'])

Epoch 1/100, Loss: 2.9351, 2.9298, Val Loss: 3.0720, 3.0702
New best validation loss for subject2: 3.0720
New best validation loss for subject3: 3.0702
Epoch 2/100, Loss: 3.1365, 3.1355, Val Loss: 2.3179, 2.3153
New best validation loss for subject2: 2.3179
New best validation loss for subject3: 2.3153
Epoch 3/100, Loss: 2.3588, 2.3569, Val Loss: 2.2215, 2.2183
New best validation loss for subject2: 2.2215
New best validation loss for subject3: 2.2183
Epoch 4/100, Loss: 2.2589, 2.2562, Val Loss: 2.3439, 2.3405
Epoch 5/100, Loss: 2.3842, 2.3811, Val Loss: 2.0447, 2.0420
New best validation loss for subject2: 2.0447
New best validation loss for subject3: 2.0420
Epoch 6/100, Loss: 2.0787, 2.0762, Val Loss: 1.7070, 1.7059
New best validation loss for subject2: 1.7070
New best validation loss for subject3: 1.7059
Epoch 7/100, Loss: 1.7325, 1.7314, Val Loss: 1.6964, 1.6963
New best validation loss for subject2: 1.6964
New best validation loss for subject3: 1.6963
Epoch 8/100, Loss: 1.7186, 1

KeyboardInterrupt: 

In [22]:
torch.save(best_classifiers, f'/ocean/projects/mth240012p/azhang19/lab3/classifier_ckpts/best_classifiers{weight_decay}.pth')