In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
import pickle
import time
import random
import copy
import itertools

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from transformers import BertTokenizerFast, BertModel
from peft import LoraConfig, get_peft_model, TaskType
import matplotlib.pyplot as plt

sys.path.append('code')
sys.path.append("/jet/home/azhang19/stat 214/stat-214-lab3-group6/code")

from BERT.data import TextDataset
from finetune_bert_utils import get_sliding_window_embeddings, aggregate_embeddings, downsample_word_vectors_torch, load_fmri_data, get_fmri_data

torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the base path for data access
data_path = '/ocean/projects/mth240012p/shared/data'

In [2]:
# %% Load preprocessed word sequences (likely includes words and their timings)
with open(f'{data_path}/raw_text.pkl', 'rb') as file:
    wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}

# %% Get list of story identifiers and split into training and testing sets
# Assumes story data for 'subject2' exists and filenames are story IDs + '.npy'
stories = [i[:-4] for i in os.listdir(f'{data_path}/subject2')] # Extract story IDs from filenames
# Split stories into train and test sets with a fixed random state for reproducibility


# First, use 60% for training and 40% for the remaining data.
train_stories, temp_stories = train_test_split(stories, train_size=0.6, random_state=214)
# Then split the remaining 40% equally to get 20% validation and 20% test.
val_stories, test_stories = train_test_split(temp_stories, train_size=0.5, random_state=214)

story_name_to_idx = {story: i for i, story in enumerate(stories)}

  wordseqs = pickle.load(file) # wordseqs is expected to be a dictionary: {story_id: WordSequenceObject}


In [3]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
original_base_model = BertModel.from_pretrained(model_name)

In [4]:
lora_rank = 8
lora_alpha = lora_rank * 2
lora_dropout = 0.1

target_modules_bert = [
    "query", "value",
    # "key",
    # "dense"
]

config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    target_modules=target_modules_bert,
    lora_dropout=lora_dropout,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION
)
base_model = get_peft_model(original_base_model, config).to(device)

In [5]:
train_text = [" ".join(wordseqs[i].data).strip() for i in train_stories]
train_dataset = TextDataset(train_text, tokenizer, max_len=sys.maxsize)

In [6]:
trim_range = (5, -10)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

texts = []

for story in stories:
    words = wordseqs[story].data
    texts.append(" ".join(words).strip())
    tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
    token_per_word = [len(i) for i in tokens]
tokenlized_stories = tokenizer(texts, add_special_tokens=False, padding="longest", truncation=False, max_length=sys.maxsize,
                               return_token_type_ids=False, return_tensors="pt")
input_ids = tokenlized_stories["input_ids"].to(device)
attention_mask = tokenlized_stories["attention_mask"].to(device)

In [7]:
def forward_pass(current_stories):
    idx = [story_name_to_idx[story] for story in current_stories]
    embeddings = get_sliding_window_embeddings(base_model, input_ids[idx], attention_mask[idx])

    features = {}
    for i, story in enumerate(current_stories):
        words = wordseqs[story].data
        tokens = tokenizer(words, add_special_tokens=False, truncation=False, max_length=sys.maxsize)['input_ids']
        token_per_word = [len(i) for i in tokens]
        story_embeddings = embeddings[i]
        word_embeddings = []
        start = 0
        for i in token_per_word:
            end = start + i
            if i != 0:
                word_embedding = story_embeddings[start:end].mean(dim=0)
            else:
                word_embedding = torch.zeros(story_embeddings.size(1), device=device)
            word_embeddings.append(word_embedding)
            start = end
        
        features[story] = torch.stack(word_embeddings)#.cpu().numpy()

    features = downsample_word_vectors_torch(current_stories, features, wordseqs)
    for story in current_stories:
        features[story] = features[story][trim_range[0]:trim_range[1]]

    aggregated_features = aggregate_embeddings(features, current_stories)
    return aggregated_features

In [8]:
fmri_data = load_fmri_data(stories, data_path)

In [9]:
# No matter training or validation
def get_loss(classifiers, sample_stories):
    features = forward_pass(sample_stories)
    pred_fmri = {}
    loss = []
    current_fmri = get_fmri_data(sample_stories, fmri_data)
    for subj in fmri_data.keys():
        pred_fmri[subj] = classifiers[subj](features)
        obj = torch.from_numpy(current_fmri[subj]).float().to(device)
        # Handle NaN values in obj
        obj = torch.nan_to_num(obj, nan=0.0)
        loss.append(nn.functional.mse_loss(pred_fmri[subj], obj))
    return loss

In [10]:
def train_step(classifiers, sample_stories):
    base_model.train()
    loss = get_loss(classifiers, sample_stories)
    loss_for_backprop = loss[0] + loss[1]

    optim.zero_grad(set_to_none=True)
    loss_for_backprop.backward()
    optim.step()
    return loss_for_backprop.item(), loss[0].item(), loss[1].item()

In [11]:
epochs = 2
loss_record = np.zeros((epochs, 2, 2)) # [epoch, subject_idx (0 for S2, 1 for S3), metric (0 for train, 1 for val)]

best_val_loss_subject2 = float('inf')
best_val_loss_subject3 = float('inf')

best_classifier_s2 = None
best_classifier_s3 = None
best_lora_weights_for_s2 = None # LoRA weights when S2 validation was best
best_lora_weights_for_s3 = None # LoRA weights when S3 validation was best


#sample_stories = train_stories # Using all train stories per epoch
def minibatch_iterator(story_list, batch_size):
    stories_to_process = random.sample(story_list, len(story_list))

    num_stories = len(stories_to_process)
    for i in range(0, num_stories, batch_size):
        yield stories_to_process[i : min(i + batch_size, num_stories)]

In [12]:
weight_decay = 1e-2
#classifiers = {'subject2': nn.Linear(768, 94251, device=device), 'subject3': nn.Linear(768, 95556, device=device)}
classifiers = torch.load(f'/ocean/projects/mth240012p/azhang19/lab3/classifier_ckpts/best_classifiers{weight_decay}.pth', weights_only=False)
params_to_optimize = itertools.chain(
    *[i.parameters() for i in classifiers.values()],
    base_model.parameters()
)
optim = torch.optim.AdamW(params_to_optimize, lr=2e-3 / (len(train_stories) / 15), weight_decay=weight_decay, fused=True)

In [13]:
for epoch in range(epochs):
    loss_subject2_train = 0
    loss_subject3_train = 0
    for sampled_stories in minibatch_iterator(train_stories, 15):
        # train_step already sets base_model_peft.train()
        _, loss_subject2_train_batch, loss_subject3_train_batch = train_step(classifiers, sampled_stories)
        loss_subject2_train += loss_subject2_train_batch
        loss_subject3_train += loss_subject3_train_batch
    loss_subject2_train /= (len(train_stories) / 15)
    loss_subject3_train /= (len(train_stories) / 15)
        
    with torch.no_grad():
        base_model.eval()
        val_losses = get_loss(classifiers, val_stories)
        current_loss_subject2_val = val_losses[0].item()
        current_loss_subject3_val = val_losses[1].item()
    
    print(f"Epoch {epoch+1}/{epochs}, Train S2: {loss_subject2_train:.4f}, Train S3: {loss_subject3_train:.4f}, Val S2: {current_loss_subject2_val:.4f}, Val S3: {current_loss_subject3_val:.4f}")
    
    loss_record[epoch, 0, 0] = loss_subject2_train
    loss_record[epoch, 0, 1] = current_loss_subject2_val
    loss_record[epoch, 1, 0] = loss_subject3_train
    loss_record[epoch, 1, 1] = current_loss_subject3_val

    # Check for Subject 2
    if current_loss_subject2_val < best_val_loss_subject2:
        best_val_loss_subject2 = current_loss_subject2_val
        best_classifier_s2 = copy.deepcopy(classifiers['subject2'])
        # Save the current LoRA weights that led to this best S2 val loss
        best_lora_weights_for_s2 = copy.deepcopy(base_model.state_dict())
        print(f"  New best Val Loss for Subject 2: {best_val_loss_subject2:.4f}. Saved S2 classifier and current LoRA weights.")

    # Check for Subject 3
    if current_loss_subject3_val < best_val_loss_subject3:
        best_val_loss_subject3 = current_loss_subject3_val
        best_classifier_s3 = copy.deepcopy(classifiers['subject3'])
        # Save the current LoRA weights that led to this best S3 val loss
        best_lora_weights_for_s3 = copy.deepcopy(base_model.state_dict())
        print(f"  New best Val Loss for Subject 3: {best_val_loss_subject3:.4f}. Saved S3 classifier and current LoRA weights.")

Epoch 1/2, Train S2: 0.9946, Train S3: 0.9925, Val S2: 1.0494, Val S3: 1.0488
  New best Val Loss for Subject 2: 1.0494. Saved S2 classifier and current LoRA weights.
  New best Val Loss for Subject 3: 1.0488. Saved S3 classifier and current LoRA weights.
Epoch 2/2, Train S2: 0.9715, Train S3: 0.9691, Val S2: 1.0066, Val S3: 1.0045
  New best Val Loss for Subject 2: 1.0066. Saved S2 classifier and current LoRA weights.
  New best Val Loss for Subject 3: 1.0045. Saved S3 classifier and current LoRA weights.


In [14]:
save_obj_s2 = {
    'classifier_module': best_classifier_s2,
    'lora_state_dict': best_lora_weights_for_s2,
    'lora_config_params': { # Save LoRA config parameters for easier reloading
        'r': config.r,
        'lora_alpha': config.lora_alpha,
        'target_modules': config.target_modules,
        'lora_dropout': config.lora_dropout,
        'bias': config.bias,
        'task_type': str(config.task_type)
    },
    'base_model_name': model_name
}

save_obj_s3 = {
    'classifier_module': best_classifier_s3,
    'lora_state_dict': best_lora_weights_for_s3,
    'val_loss': best_val_loss_subject3,
    'lora_config_params': {
        'r': config.r,
        'lora_alpha': config.lora_alpha,
        'target_modules': config.target_modules,
        'lora_dropout': config.lora_dropout,
        'bias': config.bias,
        'task_type': str(config.task_type)
    },
    'base_model_name': model_name
}
filename = f'/ocean/projects/mth240012p/azhang19/lab3/classifier_ckpts/best_lora_wd{weight_decay}_r{lora_rank}.pth'
torch.save({'subject2': save_obj_s2, 'subject3': save_obj_s3}, filename)