In [1]:
import os
import math
from random import randint, seed
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_from_disk, concatenate_datasets
from transformers import ViTImageProcessor, ViTMAEConfig
from brainlm_mae.modeling_vit_mae_with_padding import ViTMAEForPreTraining 
from brainlm_mae.replace_vitmae_attn_with_flash_attn import replace_vitmae_attn_with_flash_attn
from toolkit.BrainLM_Toolkit import convert_fMRIvols_to_A424, convert_to_arrow_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on, not just Python's `random`:
# - `random` picks the crop-window start in preprocess_images (randint),
# - torch's RNG (CPU and CUDA) presumably drives the MAE random masking in the encoder,
# - numpy is seeded for completeness.
seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Input/output locations for the sample dataset.
# NOTE(review): absolute, user-specific paths — adjust for your checkout.
sample_dataset_dir = "/home/mt2286/project/BrainLM/toolkit/sample_dataset"
raw_data_dir = os.path.join(sample_dataset_dir, "raw_fMRI_data")
save_data_dir = os.path.join(sample_dataset_dir, "a424_fMRI_data")
os.makedirs(save_data_dir, exist_ok=True)  # output dir of the A424 conversion must exist

args = {
    # Directory containing the A424 .dat files, the A424 coordinates file and the A424
    # excel sheet; this is the same directory the A424 conversion writes to.
    "uk_biobank_dir": save_data_dir,
    # Directory where the output arrow datasets are saved.
    "arrow_dataset_save_directory": os.path.join(save_data_dir, "arrow_form"),
    "dataset_name": "Test_data_arrow_norm",
}

# Parcellate raw fMRI volumes into the A424 atlas (commented out; uncomment to regenerate).
# convert_fMRIvols_to_A424(data_path=raw_data_dir, output_path=save_data_dir)

# Package the A424 recordings as arrow datasets.
# NOTE(review): this writes under `<save_data_dir>/arrow_form`, while the loading cell reads
# `<save_data_dir>/Arrow_Datasets2` — confirm both refer to the same output.
convert_to_arrow_datasets(args, args["arrow_dataset_save_directory"])

FMRI Data Arrow Conversion Starting...
There's no A24 Coordinates dat file


100%|██████████| 1/1 [00:00<00:00, 12.44it/s]


Not processing 0 files due to insufficient fMRI data


100%|██████████| 1/1 [00:00<00:00, 12.23it/s]
Getting normalization stats: 100%|██████████| 1/1 [00:00<00:00, 12.47it/s]
Normalizing Data:   0%|          | 0/1 [00:00<?, ?it/s]

Normalizing Data: 100%|██████████| 1/1 [00:00<00:00,  9.79it/s]


(424, 645)
Print data array:  (424, 645)


Saving the dataset (1/1 shards): 100%|██████████| 1/1 [00:00<00:00, 55.68 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 424/424 [00:02<00:00, 188.53 examples/s]

Done.





### Zero-shot Inference - Saving CLS Token

In [None]:
# Swap the ViT-MAE attention for the flash-attention implementation before building the model.
replace_vitmae_attn_with_flash_attn()

params = "650M"  # Choose between 650M and 111M
checkpoint_subfolder = f"vitmae_{params}"

# Inference-time overrides layered on top of the pretrained config.
# NOTE(review): mask_ratio=0.75 presumably makes the encoder see only a random 25% of
# patches per forward pass, so the CLS token depends on the torch RNG — confirm this is
# intended for zero-shot embedding extraction (vs. mask_ratio=0.0).
inference_overrides = dict(
    mask_ratio=0.75,
    timepoint_patching_size=20,
    num_timepoints_per_voxel=200,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    output_attentions=True,  # so encoder_output.attentions is populated for the CLS-attention extraction
)
config = ViTMAEConfig.from_pretrained("vandijklab/brainlm", subfolder=checkpoint_subfolder)
config.update(inference_overrides)

# NOTE(review): .half() is applied even when `device` fell back to CPU, where some fp16
# kernels may be unavailable.
model = (
    ViTMAEForPreTraining.from_pretrained(
        "vandijklab/brainlm",
        config=config,
        subfolder=checkpoint_subfolder,
    )
    .to(device)
    .half()
)
model.eval()

# Run flags and data settings read by later cells.
do_r2 = True
do_inference = True
aggregation_mode = "cls"  # 'cls', 'mean', or 'max'

variable_of_interest_col_name = "Index"
image_column_name = "All_Patient_All_Voxel_Normalized_Recording"
length = 200      # timepoints per window; same value as num_timepoints_per_voxel above
num_voxels = 424

# `train_mode` only exists in configs from the branch with several training modes
# (auto-encoder, causal attention, predict last, ...). Echo it when present, otherwise
# default to plain auto-encoding.
if hasattr(model.config, "train_mode"):
    print(model.config.train_mode)
else:
    model.config.train_mode = "auto_encode"

In [None]:
# Load the arrow datasets: per-region brain coordinates and the train recordings.
# NOTE(review): the conversion cell writes to `<save_data_dir>/arrow_form`, but these loads
# read `<save_data_dir>/Arrow_Datasets2` — confirm both point at the same output.
arrow_root = os.path.join(save_data_dir, "Arrow_Datasets2")
coords_ds = load_from_disk(os.path.join(arrow_root, "Brain_Region_Coordinates"))
train_ds = load_from_disk(os.path.join(arrow_root, "train", ""))
# Validation/test splits of the full UK Biobank data (not loaded for the sample dataset):
# val_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/val_ukbiobank")
# test_ds = load_from_disk("/gpfs/gibbs/pi/dijk/BrainLM_Datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/test_ukbiobank")


In [None]:
# Choose which split to run inference on. Only the train split is loaded above (the val/test
# loads are commented out), so it is the only option here. Add "val"/"test"/"concat" entries
# (e.g. concatenate_datasets([...])) once those datasets are loaded.
split = "train"
dataset_split = {"train": train_ds}
ds_used = dataset_split[split]

# `model_name` and `dataset_v` were referenced without ever being defined.
# NOTE(review): placeholder tags for the output path — set them to what you want.
model_name = f"vitmae_{params}"
dataset_v = "sample"
dir_name = f"/gpfs/gibbs/pi/dijk/BrainLM_zero_inf/{model_name}/dataset_{dataset_v}/"
if do_inference:
    os.makedirs(dir_name, exist_ok=True)
print(ds_used)

In [None]:
# Image processor sized to the model's expected input. preprocess_images does not resize;
# the inference loop pads instead.
image_processor = ViTImageProcessor(
    size={"height": model.config.image_size[0], "width": model.config.image_size[1]}
)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)

# Region ordering: indices that sort regions by coordinate, ascending (stable on ties).
# NOTE(review): despite the `_x_` names this reads the "Y" column of coords_ds — confirm which
# axis the pretrained model expects.
voxel_x_coords_list = coords_ds["Y"]
reorder_idxs_by_x_coord = np.argsort(np.asarray(list(voxel_x_coords_list)), kind="stable")

# Global divisor applied to every window in preprocess_images.
# NOTE(review): magic constant, presumably the max |value| of the normalized pretraining data — TODO confirm.
max_val_to_scale = 5.6430855

In [None]:
def preprocess_images(examples):
    """Convert a batch of fMRI recordings into 3-channel image tensors for the model.

    Used as a `datasets` transform (installed with ``set_transform``). It receives a batch
    dict and returns the same dict with an added ``"pixel_values"`` list.

    For each recording in ``examples[image_column_name]`` it:
      1. crops a random window of ``length`` consecutive timepoints (the start is drawn with
         Python's ``random.randint``, so it varies between calls);
      2. reorders regions by ``reorder_idxs_by_x_coord``;
      3. divides by ``max_val_to_scale``;
      4. repeats the result along a new leading channel axis -> ``[3, regions, length]``.

    No resizing happens here; the inference loop pads to the model's image size.

    Raises:
        ValueError: if a recording has fewer than ``length`` timepoints.
    """
    fmri_images_list = []
    for recording in examples[image_column_name]:
        # Assumes each recording is stored as [regions, timepoints] (the sample data prints
        # as (424, 645)) — TODO confirm for other datasets.
        recording_tensor = torch.tensor(recording, dtype=torch.float32)
        num_timepoints = recording_tensor.shape[1]
        if num_timepoints < length:
            # randint would otherwise fail with an opaque "empty range" error.
            raise ValueError(
                f"Recording has {num_timepoints} timepoints, fewer than the window length {length}."
            )

        # randint's upper bound is inclusive, so the window always fits.
        start_idx = randint(0, num_timepoints - length)
        signal_window = recording_tensor[:, start_idx:start_idx + length]  # [regions, length]

        # Reorder regions by coordinate, then apply the global scale.
        signal_window = signal_window[reorder_idxs_by_x_coord, :] / max_val_to_scale

        # Same signal in all 3 (R, G, B) channels -> [3, regions, length].
        fmri_images_list.append(signal_window.unsqueeze(0).repeat(3, 1, 1))

    examples["pixel_values"] = fmri_images_list  # no resizing; the inference loop pads
    return examples


def get_attention_cls_token(attn_probs, layer_idx=-1):
    """Return the head-averaged attention row of the CLS token for one encoder layer.

    Args:
        attn_probs: sequence of per-layer attention tensors, as in
            ``encoder_output.attentions``. Each is assumed to be ``[1, heads, seq, seq]``
            (batch of one) — TODO confirm.
        layer_idx: layer to read; defaults to the last layer. A fixed index such as 31 only
            exists for deep checkpoints (presumably the 650M model) and is out of range for
            shallower ones such as 111M.

    Returns:
        numpy array ``[1, seq]``: attention from query position 0 (CLS) to every token,
        averaged over heads.

    Raises:
        ValueError: if no attention maps were returned, e.g. output_attentions is off or the
            attention implementation does not expose probabilities.
        IndexError: if ``layer_idx`` is out of range.
    """
    if attn_probs is None or len(attn_probs) == 0:
        raise ValueError(
            "No attention maps returned; check config.output_attentions and whether the "
            "attention implementation returns attention probabilities."
        )
    num_layers = len(attn_probs)
    if not -num_layers <= layer_idx < num_layers:
        raise IndexError(f"layer_idx {layer_idx} out of range for {num_layers} attention layers")
    layer_attn = attn_probs[layer_idx]
    if layer_attn is None:
        raise ValueError(
            f"Attention layer {layer_idx} is None; the attention implementation did not return probabilities."
        )

    attn_probs_heads = layer_attn.squeeze(0)                     # drop batch dim -> [heads, seq, seq] (assumed)
    attn_probs_avg = attn_probs_heads.mean(dim=0, keepdim=True)  # average heads -> [1, seq, seq]
    return attn_probs_avg[:, 0, :].cpu().numpy()                 # CLS query row -> [1, seq]
# `embarc_ds` was never defined in this notebook (presumably left over from an earlier
# dataset). Bind it to the split selected for inference so the transform and the inference
# loop operate on the loaded data.
embarc_ds = ds_used
embarc_ds.set_transform(preprocess_images)

In [None]:
# Zero-shot inference: pass each recording through the encoder only, collecting the CLS
# token, the remaining token embeddings, and the head-averaged CLS attention row.
model_type = "pad"
list_cls_tokens = []
list_attn_cls_tokens = []
all_embeddings = []
all_index = []
target_height, target_width = model.config.image_size[0], model.config.image_size[1]
with torch.no_grad():
    for recording in tqdm(embarc_ds, desc="Getting CLS tokens"):
        # Add a batch dim; the transform yields [3, regions, length] per recording.
        pixel_values = recording["pixel_values"].unsqueeze(0).half().to(device)
        if model_type == "pad":
            # Pad height/width up to model.config.image_size with -1. When a total pad is
            # odd, the extra row/column goes to the bottom/right so the result matches
            # image_size exactly (halving both sides alone would come up one short).
            height_pad_total = target_height - pixel_values.shape[2]
            width_pad_total = target_width - pixel_values.shape[3]
            pad_top = height_pad_total // 2
            pad_left = width_pad_total // 2
            pixel_values = F.pad(
                pixel_values,
                (pad_left, width_pad_total - pad_left, pad_top, height_pad_total - pad_top),
                "constant",
                -1,
            )

        # Per-layer hidden states are not needed; only the last hidden state and attentions are used.
        encoder_output = model.vit(pixel_values=pixel_values)

        # Position 0 is the CLS token; the remaining positions are the encoded patch tokens
        # (presumably only those kept after random masking).
        cls_token = encoder_output.last_hidden_state[:, 0, :]
        embedding = encoder_output.last_hidden_state[:, 1:, :]
        all_embeddings.append(embedding.cpu().numpy())
        list_cls_tokens.append(cls_token.cpu().numpy())
        # all_index.append(recording["labels"].detach().numpy())
        list_attn_cls_tokens.append(get_attention_cls_token(encoder_output.attentions))
print(all_embeddings[0].shape)

In [None]:
# Save zero-shot outputs for this checkpoint size into the inference output directory
# (`dir_name`; the previously used `sav_dir` was never defined).
# CLS embeddings go to `<params>_cls_token.npy` (that file previously received the
# attention rows); the CLS attention rows get their own file.
os.makedirs(dir_name, exist_ok=True)
preds_name = os.path.join(dir_name, f'{params}_cls_token.npy')
attn_preds_name = os.path.join(dir_name, f'{params}_attn_cls_token.npy')
print("Saving inference results to: ", preds_name)
np.save(preds_name, list_cls_tokens)
print("Saving inference results to: ", attn_preds_name)
np.save(attn_preds_name, list_attn_cls_tokens)