# BrainLM Tutorial Notebook
BrainLM is a foundation model for brain activity recordings. The GitHub repo can be found here: https://github.com/vandijklab/BrainLM

**Using this tutorial notebook, you will learn how to apply BrainLM to your own fMRI data!**

**Capabilities include:**
> - **Easy-to-use Preprocessing functions** to format your fMRI dataset for BrainLM input
> - **Finetuned BrainLM models** for clinical metadata prediction (Age, Anxiety, PTSD, Neuroticism)
> - **Pretrained BrainLM models** for Zero-shot (clinical) Regression and downstream finetuning
> - "Putting your dataset in perspective" via joint-embedding with 40k+ UKBiobank subjects



# 0. Before You Begin 
- Your fMRI data must be minimally preprocessed using standard fMRI preprocessing procedures. Here are the preprocessing [scripts](https://www.fmrib.ox.ac.uk/ukbiobank/fbp/) and [documentation](https://biobank.ctsu.ox.ac.uk/crystal/crystal/docs/brain_mri.pdf) for the UKBiobank Imaging Pipeline.
- **Conda Environment:** Make sure you're using the correct conda environment.
>- For now: `/vast/palmer/home.mccleary/sr2464/.conda/envs/cell_lm_flash_attn`
>- Eventually: the default BrainLM conda environment on GitHub
>- Before camera-ready: these two environments need to be merged (e.g., nilearn and other dependencies)

In [1]:
#Importing Libraries
import os
import math
from random import randint, seed
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import argparse

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_from_disk, concatenate_datasets
from transformers import ViTImageProcessor, ViTMAEConfig
# from brainlm_mae.modeling_brainlm import BrainLMForPretraining
# from utils.brainlm_trainer import BrainLMTrainer
# from brainlm_mae.configuration_brainlm import BrainLMConfig
# from brainlm_mae.brainlm_finetuning_mlp_pred_head import BrainLMForFinetuning
sys.path.append("../")
from brainlm_mae.vit_image_finetuning_mlp_pred_head import ViTMAEForFinetuning
from brainlm_mae.vit_image_finetune_config import ViTMAEFinetuneConfig
from brainlm_mae.modeling_vit_mae_with_padding import ViTMAEForPreTraining 
from brainlm_mae.replace_vitmae_attn_with_flash_attn import replace_vitmae_attn_with_flash_attn
from utils.brainlm_trainer_log_wandb_only import BrainLMTrainer
from BrainLM_Toolkit import convert_fMRIvols_to_A424, convert_fMRI_dat_files_to_arrow_dataset
from trial import *

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'brainlm_mae.vit_image_finetuning_mlp_pred_head'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed(42)

# 1. Convert data to Arrow Format

In [3]:
args = {
    "uk_biobank_dir": "/home/jo548/palmer_scratch/EMERGE/matrices",     # "Path to directory containing dat files, A424 coordinates file, and A424 excel sheet.",
    "arrow_dataset_save_directory": "/home/mt2286/palmer_scratch/test",     # "The directory where you want to save the output arrow datasets."
    "dataset_name": "EMERGE",
}
convert_to_arrow_datasets(args, args["arrow_dataset_save_directory"])


UKBioBank FMRI Data Arrow Conversion Starting...
There's no A24 Coordinates dat file


Getting normalization stats: 100%|██████████| 1513/1513 [00:34<00:00, 43.92it/s]
Normalizing Data:   0%|          | 0/1513 [00:00<?, ?it/s]


Print data array:  (645, 424)


ValueError: operands could not be broadcast together with shapes (645,424) (424,1) 
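
The `ValueError` above is a NumPy broadcasting failure: an array of shape `(645, 424)` cannot be combined with one of shape `(424, 1)`. The sketch below reproduces the mismatch on synthetic data. It assumes the recording is stored as `[timepoints, regions]` and that the statistics hold one value per region; both are guesses from the printed shapes, not confirmed from the toolkit code.

```python
import numpy as np

# Synthetic stand-ins; the shapes are copied from the error message above.
recording = np.random.default_rng(0).normal(size=(645, 424))  # assumed [timepoints, regions]
region_mean = recording.mean(axis=0).reshape(424, 1)          # assumed one statistic per region

try:
    # Trailing dims (424 vs 1) are compatible, leading dims (645 vs 424) are not.
    recording - region_mean
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# Two ways to make the shapes compatible:
centered_a = recording - region_mean.T   # (645, 424) - (1, 424)
centered_b = recording.T - region_mean   # (424, 645) - (424, 1)
```

If the assumption holds, transposing either the recording or the statistics before normalizing may resolve the error; which one is appropriate depends on the layout the rest of the pipeline expects.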

# 3. Clinical Variable Prediction using Finetuned BrainLM models

**Current Capabilities:** Age, PTSD (PCL-5), Anxiety (GAD-7), Neuroticism

- Ask Syed/Antonio which notebooks are most up to date.
- My best guess is `inference_01_finetuned_model_ViT_image.ipynb` on branch `antonio_dev_syed_working`.

In [4]:
#Loading the model
modality = "Age.At.MHQ" # Choose from ["Age.At.MHQ", "PHQ9.Severity", "PCL.Score", "GAD7.Severity", "Neuroticism", "Depressed.At.Baseline", "Self.Harm.Ever", "Not.Worth.Living", "Gender"]
ft_model = f"/gpfs/gibbs/pi/dijk/BrainLM_mihir_files/BrainLM_fine_{modality}_mihir_run/111M/"
ft_modelname="finetune_age_FromScratch_Nov24th_1225_111M"
# config = ViTMAEFinetuneConfig.from_pretrained(ft_model)

In [5]:
config = ViTMAEFinetuneConfig.from_pretrained("vandijklab/brainlm_finetuned", subfolder="age")
model = ViTMAEForFinetuning.from_pretrained("vandijklab/brainlm_finetuned", config=config, subfolder="age")

In [7]:
#Loading the dataset
# coords_ds = load_from_disk("/home/mt2286/project/BrainLM-dev/BrainLM_Toolkit/output/Arrow_Datasets2/Test_data_arrow_norm/Brain_Region_Coordinates")
# train_ds = load_from_disk("/home/mt2286/project/BrainLM-dev/BrainLM_Toolkit/output/Arrow_Datasets2/Test_data_arrow_norm/train")
# hcp_ds = load_from_disk("/gpfs/gibbs/pi/dijk/HCP_Arrow_WithRegression/test_hcp")
coords_ds = load_from_disk("/home/mt2286/palmer_scratch/EMBARC_prep/Brain_Region_Coordinates")
# train_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/val_ukbiobank")
# val_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/val_ukbiobank")
train_ds = load_from_disk("/home/mt2286/palmer_scratch/EMBARC_prep/train_ukbiobank1000")
dataset_v = "embarc"
split = "train"

In [8]:
task_to_i = {"rs": 0, "tf": 1}
task_type = pd.Series(train_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
train_ds = train_ds.add_column(name="Task_Type", column=task_type)

# task_type = pd.Series(val_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
# val_ds = val_ds.add_column(name="Task_Type", column=task_type)

# task_type = pd.Series(test_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
# test_ds = test_ds.add_column(name="Task_Type", column=task_type)
concat_ds = train_ds
# concat_ds = concatenate_datasets([train_ds, val_ds, test_ds])
index = pd.Series(np.arange(concat_ds.num_rows))
concat_ds = concat_ds.add_column(name="Index", column=index)
# index = pd.Series(np.arange(test_ds.num_rows))
# test_ds = test_ds.add_column(name="Index", column=index)
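
The `Task_Type` column above is derived from the filename suffix. A minimal sketch of what the pandas expression does, assuming filenames follow the `<id>.dat_<task>` pattern used by the UK Biobank recordings in this notebook (the filenames below are made up for illustration):

```python
import pandas as pd

# Hypothetical filenames following the "<id>.dat_<task>" pattern.
filenames = ["5446541.dat_rs", "1234567.dat_tf", "7654321.dat_rs"]

task_to_i = {"rs": 0, "tf": 1}
# expand=True yields a DataFrame: column 0 is "<id>.dat", column 1 is the task suffix.
parts = pd.Series(filenames).str.split("_", expand=True)
# Suffixes missing from task_to_i would map to NaN.
task_type = parts[1].map(task_to_i)

task_codes = task_type.tolist()
```

Filenames from other datasets that contain additional underscores, or no task suffix, would need a different parsing rule.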

In [9]:
#Saving directory
sav_dir = "/gpfs/gibbs/pi/dijk/embarc_inf"
if not os.path.exists(sav_dir):
    os.makedirs(sav_dir)

# dataset_split = {"train": train_ds, "val": val_ds, "test": test_ds, "concat": concat_ds}
dataset_split = {"train": concat_ds}
ds_used = dataset_split[split]

In [75]:
#No need to run this for inference.

variable_of_interest_col_name = modality
# recording_col_name = "Voxelwise_RobustScaler_Normalized_Recording"
length = 200

# Processing for metadata variable
full_label_list = ds_used[variable_of_interest_col_name]
non_nan_indices = [idx for idx in range(len(full_label_list)) if not math.isnan(full_label_list[idx])]

non_nan_ds = ds_used.select(non_nan_indices)  # select samples which have non-nan values for metadata variable
non_nan_ds = non_nan_ds.shuffle(seed=42)  # shuffle reproducibly to remove any ordering samples may have had
print(f"After selecting non-nan metadata, have {non_nan_ds.num_rows} samples to finetune on.\n\n")
labels_nonnan = non_nan_ds[variable_of_interest_col_name]  # get labels

#Normalize variables: log(variable + 1) / max_val
if variable_of_interest_col_name in ["PCL.Score", "GAD7.Severity"]:
    labels_nonnan_log1p = np.log(np.array(labels_nonnan, dtype=np.float32) + 1.0)  # log base_e(value + 1)
    labels_nonnan_log1p_divmax = np.divide(labels_nonnan_log1p, labels_nonnan_log1p.max())  # bring into range [0, 1]
    labels_normalized = labels_nonnan_log1p_divmax.tolist()
    pos_weight = None
elif variable_of_interest_col_name == "Age.At.MHQ":
    z_score_transform = StandardScaler()
    labels_normalized_np = z_score_transform.fit_transform(np.expand_dims(np.array(labels_nonnan), axis=1))
    labels_normalized_np = np.squeeze(labels_normalized_np, axis=1)
    labels_normalized = labels_normalized_np.tolist()
    np.save('mean.npy', z_score_transform.mean_)
    np.save('std.npy', z_score_transform.scale_)
    pos_weight = None
elif variable_of_interest_col_name == "Neuroticism":
    labels_nonnan_log1p = np.array(labels_nonnan, dtype=np.float32)  # if is Neuroticism, then no log transform, distribution is already good
    labels_nonnan_log1p_divmax = np.divide(labels_nonnan_log1p, labels_nonnan_log1p.max())  # bring into range [0, 1]
    labels_normalized = labels_nonnan_log1p_divmax.tolist()
    pos_weight = None
elif variable_of_interest_col_name == "PHQ9.Severity":
    labels_normalized = [1 if num > 4.0 else 0 for num in labels_nonnan]
    sum_ones = sum(labels_normalized)
    sum_zeros = len(labels_normalized) - sum_ones
    # pos_weight = sum_zeros / sum_ones
    pos_weight = None
elif variable_of_interest_col_name in ["Depressed.At.Baseline", "Self.Harm.Ever", "Not.Worth.Living"]:
    labels_normalized = labels_nonnan
    sum_ones = sum(labels_normalized)
    sum_zeros = len(labels_normalized) - sum_ones
    # pos_weight = sum_zeros / sum_ones
    pos_weight = None
elif variable_of_interest_col_name == "Gender":
    labels_normalized = labels_nonnan
    pos_weight = torch.tensor([2027 / 1687], dtype=torch.float32)
else:
    raise NotImplementedError("Unknown variable of interest specified.")

# Replace labels in non_nan_ds with normalized labels
non_nan_ds = non_nan_ds.remove_columns(variable_of_interest_col_name)
non_nan_ds = non_nan_ds.add_column(name=variable_of_interest_col_name, column=labels_normalized)
ds_used = non_nan_ds

After selecting non-nan metadata, have 3728 samples to finetune on.




Flattening the indices:   0%|          | 0/3728 [00:00<?, ? examples/s]

Flattening the indices: 100%|██████████| 3728/3728 [04:23<00:00, 14.15 examples/s]
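
The `PCL.Score` and `GAD7.Severity` branch above applies `log(x + 1)` and then divides by the maximum. A minimal sketch on made-up scores, showing the forward transform and the inverse that would be needed to map predictions back to the original scale:

```python
import numpy as np

# Hypothetical questionnaire scores (non-negative, right-skewed).
scores = np.array([0, 1, 3, 7, 21], dtype=np.float32)

log_scores = np.log(scores + 1.0)            # natural log of (value + 1); maps 0 to 0
normalized = log_scores / log_scores.max()   # rescale into [0, 1]

# Inverting the transform requires the maximum log score from the data used to normalize.
recovered = np.exp(normalized * log_scores.max()) - 1.0
```

Note that the maximum is dataset-dependent, so predictions from a finetuned model can only be mapped back faithfully with the maximum from its training labels.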


In [10]:
ds_used


Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Task_Type', 'Index'],
    num_rows: 1865
})

In [11]:
subset = ds_used.select(range(10))

# Now, extract the 'Voxelwise_RobustScaler_Normalized_Recording' column from this subset
first_five_instances = subset['Voxelwise_RobustScaler_Normalized_Recording']
print(len(first_five_instances))
print(len(first_five_instances[0]))
print(len(first_five_instances[1][1]))

# labels = subset["Age.At.MHQ"]
# print(labels)

10
424
350


In [14]:
subset = embarc_ds.select(range(5))

# Now, extract the 'All_Patient_All_Voxel_Normalized_Recording' column from this subset
first_five_instances = subset['All_Patient_All_Voxel_Normalized_Recording']
print(len(first_five_instances))
print(len(first_five_instances[0]))
print(len(first_five_instances[1][1]))


5
424
350
True


In [12]:
fmri_recs = ds_used['Voxelwise_RobustScaler_Normalized_Recording']
data_indices = [idx for idx in tqdm(range(len(ds_used)), desc="Filtering indices") if len(fmri_recs[idx][0]) >= 200]
embarc_ds = ds_used.select(data_indices)  
print(embarc_ds)

Filtering indices: 100%|██████████| 1865/1865 [00:00<00:00, 988547.57it/s]

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Task_Type', 'Index'],
    num_rows: 1819
})





In [16]:
#Preprocessing fMRI images
# NOTE: despite the "x_coord" variable names, regions are sorted by the "Y" column of coords_ds.
voxel_x_coords_list = coords_ds["Y"]
reorder_idxs_by_x_coord = sorted(range(len(voxel_x_coords_list)), key=lambda k: voxel_x_coords_list[k])
reorder_idxs_by_x_coord = np.array(reorder_idxs_by_x_coord)
max_val_to_scale = 5.6430855
num_timepoints_per_voxel = 200 #model.config.num_timepoints_per_voxel
image_column_name = "Voxelwise_RobustScaler_Normalized_Recording"

def preprocess_images(examples):
    """Preprocess a batch of images by applying transforms."""
    fmri_images_list = []
    # label_list = []
    # print('examples.keys(): ',examples.keys())
    # print('len(examples[image_column_name][0]): ',len(examples[image_column_name][0]))
    # print('examples[variable_of_interest_col_name][0]: ',examples[variable_of_interest_col_name][0])
    for idx in range(len(examples[image_column_name])):
        signal_window = torch.tensor(examples[image_column_name][idx], dtype=torch.float32).t()
        # label = examples[variable_of_interest_col_name][0] # Original 
        # label = examples[variable_of_interest_col_name][idx]
        # label = torch.tensor(label, dtype=torch.float32)
        # print('[in preprocess_images] label: ',label)
        # print(aaa)

        # Choose random starting index, take window of moving_window_len points for each region
        start_idx = randint(0, signal_window.shape[0] - num_timepoints_per_voxel)
        end_idx = start_idx + num_timepoints_per_voxel
        signal_window = signal_window[start_idx: end_idx, :]
        signal_window = torch.movedim(signal_window, 0, 1)  # --> [num_voxels, moving_window_len]

        # reorder voxels according to x-coordinate
        signal_window = signal_window[reorder_idxs_by_x_coord, :]
        signal_window = signal_window / max_val_to_scale

        # Repeat tensor for 3 channels (R,G,B)
        signal_window = signal_window.unsqueeze(0).repeat(3, 1, 1)

        fmri_images_list.append(signal_window)
        # label_list.append(label)
    
    # examples["pixel_values"] = [transforms(image) for image in fmri_images_list]
    examples["pixel_values"] = fmri_images_list  # No transformation or resizing; model will do padding
    # examples["label"] = label_list
    return examples

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     labels = torch.stack([example["label"] for example in examples])
#     # print('labels: ',labels)
#     # print(aaa)
#     # labels = torch.tensor([1 for _ in range(len(pixel_values))])
    
#     return {
#         "pixel_values": pixel_values,
#         "input_ids": pixel_values,
#         "labels": labels
#     }
embarc_ds.set_transform(preprocess_images)
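
The transform above turns each recording into a 3-channel "image". The sketch below is a NumPy analogue of the same steps (random window, region reordering, scaling, channel repeat) on synthetic data, not the PyTorch code itself. The coordinates and the recording are made up; only `max_val_to_scale` and the shapes come from this notebook.

```python
import numpy as np
from random import randint, seed

seed(42)
rng = np.random.default_rng(0)

num_regions, num_timepoints, window_len = 424, 350, 200
recording = rng.normal(size=(num_regions, num_timepoints)).astype(np.float32)  # synthetic [regions, time]
coords = rng.uniform(-70, 70, size=num_regions)   # hypothetical coordinate per region
max_val_to_scale = 5.6430855                      # constant used in the cell above

order = np.argsort(coords, kind="stable")         # same ordering as sorted(range(n), key=...)
start = randint(0, num_timepoints - window_len)   # randint is inclusive on both ends
window = recording[:, start:start + window_len]   # [regions, window_len]
window = window[order, :] / max_val_to_scale      # reorder rows, then scale
image = np.repeat(window[None, :, :], 3, axis=0)  # copy into 3 identical channels
```

Because the window start is random, repeated passes over the same recording can yield different inputs unless the seed is reset before each pass.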

In [14]:
#Only if modality is Age
def reverse_z_score(
    normalized_data,
    mean_path="/home/mt2286/project/BrainLM-dev/BrainLM_Toolkit/mean.npy",
    std_path="/home/mt2286/project/BrainLM-dev/BrainLM_Toolkit/std.npy",
):
    # Load saved mean and standard deviation from the given paths
    mean = np.load(mean_path)
    std = np.load(std_path)

    if isinstance(normalized_data, list):
        normalized_data_np = np.array(normalized_data)
    else:
        normalized_data_np = normalized_data

    # Reshape if the input is a 1D array
    if normalized_data_np.ndim == 1:
        normalized_data_np = normalized_data_np.reshape(-1, 1)
    
    # Manually reverse the z-score normalization
    reversed_np = (normalized_data_np * std) + mean

    # If the output is expected to be 1D, squeeze the array
    if reversed_np.shape[1] == 1:
        reversed_np = np.squeeze(reversed_np, axis=1)
    
    reversed_data = reversed_np.tolist()
    
    return reversed_data
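
As a sanity check on the inverse transform, the following sketch round-trips made-up ages through a z-score and back. It uses plain NumPy rather than `StandardScaler`; the population standard deviation (`ddof=0`) is used because that is the convention `StandardScaler` follows.

```python
import numpy as np

ages = np.array([52.0, 61.0, 67.0, 70.0, 75.0])  # hypothetical ages in years

mean, std = ages.mean(), ages.std()   # population std (ddof=0)
z = (ages - mean) / std               # normalized target, as used for finetuning
restored = z * std + mean             # the operation reverse_z_score applies
```

The saved `mean.npy` and `std.npy` must come from the same label set the model was finetuned on; statistics from a different cohort would shift and rescale the predicted ages.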


In [19]:
if model.config.patch_size == 16:
    resize_target_size = 432
elif model.config.patch_size == 14:
    resize_target_size = 434
else:
    raise RuntimeError("New patch size encountered, check init() of VitMAEForPreTraining")

res_arr = []
i = 0
model = model.to(device)
with torch.no_grad():  # No gradient tracking for inference
    for batch in tqdm(embarc_ds, desc="Running inference"):
        pixel_values = batch["pixel_values"].unsqueeze(0)
        batch_size, channels, height, width = pixel_values.shape
        
        height_pad_total = resize_target_size - height
        height_pad_total_half = height_pad_total // 2

        width_pad_total = resize_target_size - width
        width_pad_total_half = width_pad_total // 2

        # Padding for the whole batch
        pixel_values_padded = F.pad(pixel_values, (width_pad_total_half, width_pad_total_half, height_pad_total_half, height_pad_total_half), "constant", -1)
        
        encoder_output = model(
            pixel_values_padded.to(device),
            output_attentions=False,
            output_hidden_states=True,
            is_train=False
        )
        
        reversed_data = [reverse_z_score(output.cpu())[0] for output in encoder_output]
        res_arr.extend(reversed_data)
        if i == 10:
            break
        i += 1

Running inference:   0%|          | 0/1819 [00:00<?, ?it/s]

Running inference:   1%|          | 10/1819 [00:06<20:10,  1.49it/s]
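
The padding arithmetic in the loop above can be checked in isolation. The sketch below is a NumPy analogue of the `F.pad` call (the helper name `pad_to_square` is made up for illustration): it pads a `[3, 424, 200]` input symmetrically with `-1` up to the target size for each patch size.

```python
import numpy as np

def pad_to_square(image, target, fill=-1.0):
    """Pad the last two axes of [channels, height, width] symmetrically to target x target."""
    _, height, width = image.shape
    pad_h = (target - height) // 2
    pad_w = (target - width) // 2
    return np.pad(image, ((0, 0), (pad_h, pad_h), (pad_w, pad_w)), constant_values=fill)

image = np.zeros((3, 424, 200), dtype=np.float32)
padded_14 = pad_to_square(image, 434)   # patch size 14
padded_16 = pad_to_square(image, 432)   # patch size 16
```

Symmetric padding with floor division only reaches the target exactly when the size difference is even, which holds for 424 and 200 against both 432 and 434; other recording lengths would come out one pixel short.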


In [36]:
res_arr[:10]

[58.777088008524586,
 67.4934525273299,
 55.103782586868576,
 67.39009272091751,
 59.04440160839704,
 70.8334976345175,
 54.453155665135796,
 61.39124787648072,
 72.17388075858094,
 66.18419450638217]

In [20]:
res_arr[:10]

[70.69828823933477,
 70.59515816939594,
 70.58661955696951,
 70.59520931993517,
 70.56480032436855,
 70.64159227613803,
 70.49952870869973,
 70.66545047506102,
 70.64024825162471,
 70.67109467249226]

In [59]:
preds_name = os.path.join(sav_dir, f'{modality}_pred.npy')
print("Saving inference results to: ", preds_name)
np.save(preds_name, res_arr)

# 4. Zero-shot Clinical Regression using BrainLM mean-pool Tokens

In [23]:
replace_vitmae_attn_with_flash_attn()  # Flash Attention
seed(42)
dataset_v = "embarc"
split = "test"

In [24]:
# Load model and specify params
params = "650M" #Choose between 650M and 111M
model_name = f"brainlm_vitmae_{params}_100pc_data_mr75"
model_path = f"/gpfs/gibbs/pi/dijk/BrainLM_runs/huggingface_best_models/{model_name}"
config = ViTMAEConfig.from_pretrained(model_path)
config.update({
    "mask_ratio": 0.75,
    "timepoint_patching_size": 20,
    "num_timepoints_per_voxel": 200,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "output_attentions": True,
})

model = ViTMAEForPreTraining.from_pretrained(
        model_path,
        from_tf=bool(".ckpt" in model_path),
        config=config,
        # cache_dir=model_args.cache_dir,
        # revision=model_args.model_revision,
        # use_auth_token=True if model_args.use_auth_token else None,
    ).to(device)

model = model.half()
model.eval()
# print(model.dtype)
# print(model.config.mask_ratio)
# print(model.vit.embeddings.config.mask_ratio)

do_r2 = True
do_inference = True
aggregation_mode = "cls" # 'cls', 'mean', or 'max'

variable_of_interest_col_name = "Index"
image_column_name = "All_Patient_All_Voxel_Normalized_Recording"
length = 200
num_voxels = 424

# need this if running on matteo's branch, due to multiple train modes (auto-encoder, causal attention, predict last, etc)
try:
    print(model.config.train_mode)
except AttributeError:
    model.config.train_mode = "auto_encode"

In [14]:
# Dataset

coords_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/Brain_Region_Coordinates")
train_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/train_ukbiobank")

val_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/val_ukbiobank")

test_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/test_ukbiobank")


print(train_ds[0]['Filename'])

# Make Resting state vs Task column
task_to_i = {"rs": 0, "tf": 1}
# code is a bit wordy but it is the fastest
task_type = pd.Series(train_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
train_ds = train_ds.add_column(name="Task_Type", column=task_type)

task_type = pd.Series(val_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
val_ds = val_ds.add_column(name="Task_Type", column=task_type)

task_type = pd.Series(test_ds['Filename']).str.split("_", expand=True)[:][1].map(task_to_i)
test_ds = test_ds.add_column(name="Task_Type", column=task_type)

concat_ds = concatenate_datasets([train_ds, val_ds, test_ds])
concat_ds

# make order label
index = pd.Series(np.arange(concat_ds.num_rows))
concat_ds = concat_ds.add_column(name="Index", column=index)

index = pd.Series(np.arange(test_ds.num_rows))
test_ds = test_ds.add_column(name="Index", column=index)

# index = pd.Series(np.arange(HCP_ds.num_rows))
# HCP_ds = HCP_ds.add_column(name="Index", column=index)

example0 = concat_ds[0]
print(example0['Filename'])
print(example0['Patient ID'])
print(example0['Task_Type'])
print(example0['Index'])

5446541.dat_rs
5446541.dat_rs
_rs
0
0


In [15]:
# Saving directory
dir_name = f"/gpfs/gibbs/pi/dijk/BrainLM_zero_inf/{model_name}/dataset_{dataset_v}/"
if not os.path.exists(dir_name) and do_inference:
    os.makedirs(dir_name)
dataset_split = {"train": train_ds, "val": val_ds, "test": test_ds, "concat": concat_ds}
ds_used = dataset_split[split]
print(ds_used)

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity', 'Task_Type', 'Index'],
    num_rows: 7628
})


In [25]:
#Forward pass through model, passing whole fMRI recording
image_processor = ViTImageProcessor(size={"height": model.config.image_size[0], "width": model.config.image_size[1]})
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])
voxel_x_coords_list = coords_ds["Y"]
reorder_idxs_by_x_coord = sorted(range(len(voxel_x_coords_list)), key=lambda k: voxel_x_coords_list[k])
reorder_idxs_by_x_coord = np.array(reorder_idxs_by_x_coord)
max_val_to_scale = 5.6430855

In [27]:
def preprocess_images(examples):
    """Preprocess a batch of images by applying transforms."""
    fmri_images_list = []
    for idx in range(len(examples[image_column_name])):
        signal_window = torch.tensor(examples[image_column_name][idx], dtype=torch.float32).t()

        # Choose random starting index, take window of moving_window_len points for each region
        start_idx = randint(0, signal_window.shape[0] - length)
        end_idx = start_idx + length
        signal_window = signal_window[start_idx: end_idx, :]
        signal_window = torch.movedim(signal_window, 0, 1)  # --> [num_voxels, moving_window_len]

        # reorder voxels according to x-coordinate
        signal_window = signal_window[reorder_idxs_by_x_coord, :]
        signal_window = signal_window / max_val_to_scale

        # Repeat tensor for 3 channels (R,G,B)
        signal_window = signal_window.unsqueeze(0).repeat(3, 1, 1)

        fmri_images_list.append(signal_window) 


    examples["pixel_values"] = fmri_images_list  # No transformation or resizing; model will do padding
    return examples

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     labels = torch.tensor([1 for _ in range(len(pixel_values))])
#     return {
#         "pixel_values": pixel_values,
#         "input_ids": pixel_values,
#         "labels": labels
#     }
def get_attention_cls_token(attn_probs):
    attn_probs_heads = attn_probs[-1].squeeze(0)  # last encoder layer (index 31 for the 650M model)
    attn_probs_avg = attn_probs_heads.mean(dim=0, keepdim=True)
    cls_attn = attn_probs_avg[:, 0, :].cpu().numpy()
    return cls_attn
embarc_ds.set_transform(preprocess_images)

In [21]:
model

ViTMAEForPreTraining(
  (vit): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-31): 32 x ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=1280, out_feature

In [29]:
#Save CLS Tokens + All Embeddings
model_type="pad"
list_cls_tokens = []
list_attn_cls_tokens = []
all_embeddings = []
all_index = []
with torch.no_grad():
    for recording in tqdm(embarc_ds, desc="Getting CLS tokens"):

        pixel_values = recording["pixel_values"].unsqueeze(0).half().to(device)
        if model_type == "pad":
            # pixel_values is [batch, channels=3, 424, 200]. Pad to [batch, channels=3, image_size[0], image_size[1]]
            height_pad_total = model.config.image_size[0] - pixel_values.shape[2]
            height_pad_total_half = height_pad_total // 2
            width_pad_total = model.config.image_size[1] - pixel_values.shape[3]
            width_pad_total_half = width_pad_total // 2
            pixel_values = F.pad(pixel_values, (width_pad_total_half, width_pad_total_half, height_pad_total_half, height_pad_total_half), "constant", -1)

        encoder_output = model.vit(
            pixel_values=pixel_values,
            output_hidden_states=True
        )

        cls_token = encoder_output.last_hidden_state[:,0,:]  # [1, hidden_size], e.g. [1, 1280] for the 650M model
        embedding = encoder_output.last_hidden_state[:,1:,:]
        all_embeddings.append(embedding.detach().cpu().numpy())
        list_cls_tokens.append(cls_token.detach().cpu().numpy())
        # all_index.append(recording["labels"].detach().numpy())
        attn_cls_token = get_attention_cls_token(encoder_output.attentions)
        list_attn_cls_tokens.append(attn_cls_token)
print(all_embeddings[0].shape)

Getting CLS tokens: 100%|██████████| 1819/1819 [12:06<00:00,  2.50it/s]

(1, 240, 1280)
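
The 240 tokens in the printed shape are consistent with random masking still being active during this forward pass. The arithmetic below is a sketch under two assumptions not confirmed in this notebook: that the padded input is 434 x 434 for patch size 14, and that the padded model keeps `int(num_patches * (1 - mask_ratio))` patches, as the standard ViTMAE implementation does.

```python
patch_size = 14
image_size = 434            # assumed padded side length for patch size 14
mask_ratio = 0.75           # value set in the config above

patches_per_side = image_size // patch_size          # 31
num_patches = patches_per_side ** 2                  # 961
num_visible = int(num_patches * (1 - mask_ratio))    # patches that survive masking
```

If this reading is right, each embedding covers a random quarter of the patches, so mean- or max-pooled embeddings may vary between runs, and reshaping the tokens into a square grid would require a mask ratio of 0.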





In [31]:
preds_name = os.path.join(sav_dir, f'{params}_cls_token.npy')
print("Saving inference results to: ", preds_name)
np.save(preds_name, list_attn_cls_tokens)

Saving inference results to:  /gpfs/gibbs/pi/dijk/embarc_inf/650M_cls_token.npy


In [None]:
if aggregation_mode == "cls":
    print("cls aggregation")
    all_embeds = np.concatenate(list_cls_tokens, axis=0)
elif aggregation_mode == "mean":
    print("mean pool aggregation")
    all_mean_embeddings = [e.mean(axis=1) for e in all_embeddings]
    all_embeds = np.concatenate(all_mean_embeddings, axis=0)
elif aggregation_mode == "max":
    print("max pool aggregation")
    all_sum_embeddings = [e.max(axis=1) for e in all_embeddings]
    all_embeds = np.concatenate(all_sum_embeddings, axis=0)   


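print(all_embeds.shape)

# all_index is only populated if indices were collected during the forward pass;
# np.concatenate raises on an empty list, so skip the ordering check in that case.
if len(all_index) > 0:
    all_index = np.concatenate(all_index, axis=0)

    # check if dataloader messed up order in any way
    if not np.all(all_index[:-1] <= all_index[1:]):
        # reorder everything
        print("reordering")
        all_embeds = all_embeds[all_index, :]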
    
np.save(f"{dir_name}{split}_{aggregation_mode}_all_{length}recordinglength.npy", all_embeds)


# extract patch tokens with data
if aggregation_mode != "cls":
    print(length)
    hor_img_start_idx = (model.config.image_size[1] - length) // 2
    hor_token_start = hor_img_start_idx // model.config.patch_size
    hor_token_end = hor_token_start + np.ceil(length / model.config.patch_size).astype(int)

    vert_img_start_idx = (model.config.image_size[0] - num_voxels) // 2
    vert_token_start = vert_img_start_idx // model.config.patch_size
    vert_token_end = vert_token_start + np.ceil(num_voxels / model.config.patch_size).astype(int)
    print(hor_token_start, hor_token_end, vert_token_start, vert_token_end)

    for i, e in enumerate(all_embeddings):
        e = e.reshape(e.shape[0], int(np.sqrt(e.shape[1])), int(np.sqrt(e.shape[1])), -1)
        e = e[:, vert_token_start:vert_token_end, hor_token_start:hor_token_end]
        all_embeddings[i] = e.reshape(e.shape[0], -1, e.shape[-1])
    
    print(all_embeddings[0].shape)

    if aggregation_mode == "mean":
        print("mean pool aggregation")
        all_mean_embeddings = [e.mean(axis=1) for e in all_embeddings]
        all_embeds = np.concatenate(all_mean_embeddings, axis=0)
    if aggregation_mode == "max":
        print("max pool aggregation")
        all_sum_embeddings = [e.max(axis=1) for e in all_embeddings]
        all_embeds = np.concatenate(all_sum_embeddings, axis=0)  

    np.save(f"{dir_name}{split}_{aggregation_mode}_only_data_{length}recordinglength.npy", all_embeds)

# Save raw recordings as well
all_recordings = []
for idx, batch in enumerate(tqdm(dataloader_batched)):
    signal = batch["pixel_values"]
    recording = signal.flatten(start_dim = 1)
    recording = np.array(recording, dtype=np.float32)
    all_recordings.append(recording)
all_recordings = np.vstack(all_recordings)
all_recordings.shape

np.save(f"{dir_name}all_recordings_{length}length.npy", all_recordings)
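
The three aggregation modes above differ only in how the token axis is reduced. A minimal sketch on synthetic arrays with a small stand-in hidden size, assuming each stored output has shape `[1, 1 + num_patches, hidden]` with the CLS token at position 0:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 8  # small stand-in for the real hidden size
# One array per recording, shaped [1, 1 + num_patches, hidden]; token 0 is the CLS token.
outputs = [rng.normal(size=(1, 1 + 5, hidden)) for _ in range(4)]

cls_embeds = np.concatenate([o[:, 0, :] for o in outputs], axis=0)                  # CLS token only
mean_embeds = np.concatenate([o[:, 1:, :].mean(axis=1) for o in outputs], axis=0)   # mean over patches
max_embeds = np.concatenate([o[:, 1:, :].max(axis=1) for o in outputs], axis=0)     # max over patches
```

All three produce one `[hidden]`-sized vector per recording, so they can be swapped freely in the downstream PCA and regression cells.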

In [None]:
#Plotting CLS Tokens and Raw Recordings With PCA
aggregation_mode = "cls" # 'cls', 'mean', or 'max'
extraction_mode = "all" 
load_path = f"{dir_name}{split}_{aggregation_mode}_{extraction_mode}_200recordinglength.npy"
all_cls_tokens = np.load(load_path)
all_cls_tokens.shape

In [None]:
from sklearn.decomposition import PCA
n_components = 200
filename = f"{dir_name}pca_obj_cls_tokens_{length}length_{n_components}components.pkl"

pca = PCA(n_components=n_components)
pca.fit(all_cls_tokens)
print("pca.n_components_:", pca.n_components_)
print("pca.n_features_in_:", pca.n_features_in_)
print("pca.components_.shape:", pca.components_.shape)
print("pca.explained_variance_.shape:", pca.explained_variance_.shape)
print("pca.explained_variance_ratio_.shape:", pca.explained_variance_ratio_.shape)
print("pca.singular_values_.shape:", pca.singular_values_.shape)

In [None]:
cls_tokens_pca_reduced = pca.transform(all_cls_tokens)
print(cls_tokens_pca_reduced.shape)
np.save(f"{dir_name}{split}_pca_reduced_cls_tokens_200components.npy", cls_tokens_pca_reduced)
with open(filename, 'wb') as pickle_file:
    pickle.dump(pca, pickle_file)
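
The fit, transform, and pickle steps above can be tried on synthetic data first. The sketch below uses a random matrix in place of the CLS tokens and pickles to memory instead of a file; the sizes are arbitrary.

```python
import pickle
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
embeds = rng.normal(size=(50, 16))   # synthetic stand-in for [num_recordings, hidden_size]

pca = PCA(n_components=4)            # must not exceed min(n_samples, n_features)
reduced = pca.fit(embeds).transform(embeds)

# Round-trip the fitted object through pickle, as the cell above does via a file.
restored = pickle.loads(pickle.dumps(pca))
reduced_again = restored.transform(embeds)
```

With `n_components = 200` as in this notebook, the input needs at least 200 recordings and at least 200 features, otherwise scikit-learn raises an error at fit time.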

In [None]:
load_path = f"{dir_name}{split}_{aggregation_mode}_{extraction_mode}_200recordinglength.npy"
all_embeds = np.load(load_path)
all_embeds.shape

In [None]:
total_num_ex = cls_tokens_pca_reduced.shape[0]
embed_pca_components_list = [cls_tokens_pca_reduced[idx] for idx in range(total_num_ex)]
test_ds = test_ds.add_column(name="embed_pca_components", column=embed_pca_components_list)
total_num_ex = all_embeds.shape[0]
all_embeds_list = [all_embeds[idx] for idx in range(total_num_ex)]
test_ds = test_ds.add_column(name="whole_embed", column=all_embeds_list)

In [None]:
def target_variable_normalization(labels_nonnan, variable_of_interest_col_name):
    """Normalize (or binarize) the labels of one metadata variable and return them as a NumPy array."""
    print("Normalizing variable:", variable_of_interest_col_name)
    if variable_of_interest_col_name in ["PCL.Score", "GAD7.Severity"]:
        # Skewed scores: apply log_e(value + 1), then divide by the max to bring into range [0, 1]
        labels_log1p = np.log1p(np.array(labels_nonnan, dtype=np.float32))
        labels_normalized = np.divide(labels_log1p, labels_log1p.max()).tolist()
    elif variable_of_interest_col_name == "Age.At.MHQ":
        # Z-score the ages (zero mean, unit variance)
        z_score_transform = StandardScaler()
        labels_normalized_np = z_score_transform.fit_transform(np.expand_dims(np.array(labels_nonnan), axis=1))
        labels_normalized = np.squeeze(labels_normalized_np, axis=1).tolist()
    elif variable_of_interest_col_name == "Neuroticism":
        # No log transform for Neuroticism, the distribution is already good; only divide by the max
        labels_np = np.array(labels_nonnan, dtype=np.float32)
        labels_normalized = np.divide(labels_np, labels_np.max()).tolist()
    elif variable_of_interest_col_name == "PHQ9.Severity":
        # Binarize: scores above 4 become the positive class
        labels_normalized = [1 if num > 4.0 else 0 for num in labels_nonnan]
    elif variable_of_interest_col_name in ["Depressed.At.Baseline", "Self.Harm.Ever", "Not.Worth.Living",
                                           "Gender", "Task_Type"]:
        # Labels are used as-is
        labels_normalized = labels_nonnan
    else:
        raise NotImplementedError("Unknown variable of interest specified.")

    return np.array(labels_normalized)

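import statistics  # mean/stdev of the per-fold scores

from sklearn import svm  # LinearSVR and SVC used below
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_validate
from sklearn.model_selection import train_test_split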

def run_svm_regression(variable_of_interest, use_whole=False):
    """Cross-validate a linear SVR that predicts a metadata variable from BrainLM embeddings."""
    assert variable_of_interest in ["Age.At.MHQ", "PHQ9.Severity", "Neuroticism", "PCL.Score", "GAD7.Severity"], \
        "Please specify a metadata variable with a range of continuous values."
    # Note: target_variable_normalization binarizes "PHQ9.Severity" (score > 4),
    # so run_svm_classification is usually the better fit for that variable.

    # Scoring metrics reported for each cross-validation fold
    scoring = ["r2", "neg_mean_squared_error"]

    # Select non-nan samples for metadata variable
    full_label_list = test_ds[variable_of_interest]
    non_nan_indices = [idx for idx, value in enumerate(full_label_list) if not math.isnan(value)]
    non_nan_ds = test_ds.select(non_nan_indices)

    # Shuffle dataset reproducibly
    non_nan_ds = non_nan_ds.shuffle(seed=42)

    # Get features after shuffling: whole embeddings or PCA-reduced CLS tokens
    feature_col = "whole_embed" if use_whole else "embed_pca_components"
    embed_nonnan = np.array(non_nan_ds[feature_col], dtype=np.float32)

    # Get labels
    labels = [int(num) for num in non_nan_ds[variable_of_interest]]

    # Normalize target variable for regression
    labels_normalized_np = target_variable_normalization(
        labels_nonnan=labels,
        variable_of_interest_col_name=variable_of_interest
    )
    labels_normalized = labels_normalized_np.tolist()
    print("Max and min:", labels_normalized_np.max(), labels_normalized_np.min())

    #--- Fit a linear SVR with 5-fold cross-validation ---#
    regr = svm.LinearSVR()
    results_dict = cross_validate(regr, embed_nonnan, labels_normalized, cv=5, scoring=scoring)
    r2_scores = results_dict["test_r2"]
    # sklearn reports negated MSE so that higher is better; flip the sign back
    mse_scores = [-1 * num for num in results_dict["test_neg_mean_squared_error"]]
    data_label = "CLS Token whole data" if use_whole else "CLS Token PCA Component"
    print(f"{data_label} MSE: {statistics.mean(mse_scores):.3f} +/- {statistics.stdev(mse_scores):.3f}")
    print(f"{data_label} R2: {statistics.mean(r2_scores):.3f} +/- {statistics.stdev(r2_scores):.3f}")

    # Return the per-fold results so callers can store them
    return results_dict

def run_svm_classification(variable_of_interest, use_whole=False):
    """Cross-validate an SVM classifier that predicts a binary metadata variable from BrainLM embeddings."""
    assert variable_of_interest in ["Gender", "PHQ9.Severity", "Task_Type"], \
        "Please specify a metadata variable with a binary value."

    # Scoring metrics reported for each cross-validation fold
    scoring = ["accuracy", "balanced_accuracy", "roc_auc", "f1"]

    # Select non-nan samples for metadata variable
    full_label_list = test_ds[variable_of_interest]
    non_nan_indices = [idx for idx, value in enumerate(full_label_list) if not math.isnan(value)]
    non_nan_ds = test_ds.select(non_nan_indices)

    # Shuffle dataset reproducibly
    non_nan_ds = non_nan_ds.shuffle(seed=42)

    # Get features after shuffling: whole embeddings or PCA-reduced CLS tokens
    feature_col = "whole_embed" if use_whole else "embed_pca_components"
    embed_nonnan = np.array(non_nan_ds[feature_col], dtype=np.float32)

    # Get labels
    labels = [int(num) for num in non_nan_ds[variable_of_interest]]

    # Map labels to class targets (PHQ9.Severity is binarized; the others are used as-is)
    labels_normalized_np = target_variable_normalization(
        labels_nonnan=labels,
        variable_of_interest_col_name=variable_of_interest
    )
    labels_normalized = labels_normalized_np.tolist()
    print("Max and min:", labels_normalized_np.max(), labels_normalized_np.min())

    #--- Fit an SVM classifier with 5-fold cross-validation ---#
    clf = svm.SVC()
    results_dict = cross_validate(clf, embed_nonnan, labels_normalized, cv=5, scoring=scoring)

    # Report mean +/- standard deviation across folds for every metric
    data_label = "whole data" if use_whole else "PCA data"
    for metric in scoring:
        fold_scores = results_dict[f"test_{metric}"]
        print(f"CLS Token {data_label} {metric}: {statistics.mean(fold_scores):.3f} +/- {statistics.stdev(fold_scores):.3f}")

    # Return the per-fold results so callers can store them
    return results_dict
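
To make the structure of the `cross_validate` output concrete, here is a minimal sketch on synthetic data. The feature matrix, the targets, and the `max_iter` setting are assumptions for illustration; they are not BrainLM embeddings or settings from the functions above.

```python
import statistics

import numpy as np
from sklearn.model_selection import cross_validate
from sklearn.svm import LinearSVR

# Synthetic regression problem (hypothetical): 200 samples, 10 features, linear signal plus noise.
rng = np.random.default_rng(0)
demo_features = rng.normal(size=(200, 10))
true_weights = rng.normal(size=10)
demo_targets = demo_features @ true_weights + 0.1 * rng.normal(size=200)

demo_results = cross_validate(
    LinearSVR(max_iter=10000),
    demo_features,
    demo_targets,
    cv=5,
    scoring=["r2", "neg_mean_squared_error"],
)

# One "test_<metric>" array per scoring entry, each with one value per fold.
print(sorted(demo_results.keys()))

r2_per_fold = demo_results["test_r2"]
# The MSE is reported negated (higher is better), so flip the sign back.
mse_per_fold = [-score for score in demo_results["test_neg_mean_squared_error"]]
print(f"R2:  {statistics.mean(r2_per_fold):.3f} +/- {statistics.stdev(r2_per_fold):.3f}")
print(f"MSE: {statistics.mean(mse_per_fold):.3f} +/- {statistics.stdev(mse_per_fold):.3f}")
```

The classification case has the same shape: each entry of `scoring` yields a `test_<metric>` array with one score per fold.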



In [None]:
metadata_variable = "Age.At.MHQ"
age_results = run_svm_regression(metadata_variable, use_whole=False)

In [None]:
metadata_variable = "GAD7.Severity"
gad7_results = run_svm_regression(metadata_variable, use_whole=False)

In [None]:
metadata_variable = "Neuroticism"
neuroticism_results = run_svm_regression(metadata_variable, use_whole=False)

In [None]:
metadata_variable = "PCL.Score"
pcl_results = run_svm_regression(metadata_variable, use_whole=False)

In [None]:
metadata_variable = "PHQ9.Severity"
phq_results = run_svm_classification(metadata_variable, use_whole=True)

In [None]:
metadata_variable = "Gender"
gender_results = run_svm_classification(metadata_variable, use_whole=True)

In [None]:
metadata_variable = "Task_Type"
task_results = run_svm_classification(metadata_variable, use_whole=True)