In [1]:
import os
import math
import statistics

import umap
import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn import svm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score

import torch
import torch.nn as nn

from datasets import load_from_disk, concatenate_datasets
from brainlm_mae.modeling_brainlm import BrainLMForPretraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make sure the output directory for plots exists.  `exist_ok=True` makes the cell
# safe to re-run and avoids the check-then-create race of exists() + mkdir().
os.makedirs("inference_plots", exist_ok=True)

In [3]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

## Load Entire Dataset

In [4]:
# Dataset locations are named once here so that relocating the data means editing a
# single line instead of every load call.
UKB_DATASET_ROOT = "/home/sr2464/palmer_scratch/datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata"
BRAIN_REGION_COORDS_PATH = "/home/sr2464/palmer_scratch/datasets/UKBioBank1000_Arrow_v4/Brain_Region_Coordinates"

# Load the three recording splits plus the brain-region coordinate table, printing
# each dataset's summary (feature names and row count) as it is loaded.
train_ds = load_from_disk(os.path.join(UKB_DATASET_ROOT, "train_ukbiobank"))
print(train_ds)
val_ds = load_from_disk(os.path.join(UKB_DATASET_ROOT, "val_ukbiobank"))
print(val_ds)
test_ds = load_from_disk(os.path.join(UKB_DATASET_ROOT, "test_ukbiobank"))
print(test_ds)
coords_ds = load_from_disk(BRAIN_REGION_COORDS_PATH)
print(coords_ds)

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 61038
})
Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 7629
})
Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 7628
})
Dataset({
    features: ['Index', 'X', 'Y', 'Z'],
    num_rows: 424
})


In [5]:
# Stack the train, validation and test splits (in that order) into one dataset so the
# analyses below run over every recording.
concat_ds = concatenate_datasets([train_ds, val_ds, test_ds])
concat_ds

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 76295
})

In [6]:
# Spot-check the identifiers and metadata attached to a single record (row 10),
# printing one value per line.
example0 = concat_ds[10]
for field_name in (
    "Filename",
    "Patient ID",
    "Order",
    "eid",
    "Gender",
    "Age.At.MHQ",
    "Depressed.At.Baseline",
    "Neuroticism",
    "Self.Harm.Ever",
    "Not.Worth.Living",
    "PCL.Score",
    "GAD7.Severity",
):
    print(example0[field_name])

1191089.dat_tf
_tf
18416.0
1191089.0
nan
nan
nan
nan
nan
nan
nan
nan


## Reload PCA components of CLS tokens and raw data

In [7]:
# Precomputed inputs used in this section.
#
# Best BrainLM model so far: /home/mr2238/BrainLM/inference_plots/dataset_v3/2023-07-17-19_00_00_ckpt-500/
#   all_cls_200recordinglength.npy            - full CLS tokens (loaded in this cell)
#   pca_reduced_cls_tokens_200components.npy  - PCA-reduced CLS tokens
# Raw recordings:
#   normalized recordings: stored in a column of concat_ds (column name was not recorded
#     here; presumably 'Voxelwise_RobustScaler_Normalized_Recording' - TODO confirm)
#   PCA of raw recordings: /home/mr2238/BrainLM/inference_plots/dataset_v3/pca_reduced_raw_data_200length_200components.npy
#
# NOTE(review): the analyses below assume row i of each .npy file corresponds to row i
# of concat_ds - confirm against the script that produced these files.
all_cls_tokens = np.load("/home/mr2238/BrainLM/inference_plots/dataset_v3/2023-07-17-19_00_00_ckpt-500/all_cls_200recordinglength.npy")
all_cls_tokens.shape

(76295, 512)

In [8]:
# Load the precomputed PCA-reduced CLS tokens and show the array shape.
# NOTE(review): rows are assumed to be in the same order as concat_ds - TODO confirm.
cls_token_pca_components = np.load("/home/mr2238/BrainLM/inference_plots/dataset_v3/2023-07-17-19_00_00_ckpt-500/pca_reduced_cls_tokens_200components.npy")
cls_token_pca_components.shape

(76295, 200)

In [9]:
# Attach the PCA-reduced CLS tokens to the dataset as a new per-row column.
# Iterating a 2-D array yields its rows, giving one vector per example.
total_num_ex = len(cls_token_pca_components)
cls_token_pca_components_list = list(cls_token_pca_components)
concat_ds = concat_ds.add_column(
    name="cls_token_pca_components",
    column=cls_token_pca_components_list,
)
concat_ds

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity', 'cls_token_pca_components'],
    num_rows: 76295
})

In [10]:
# Also attach the full (un-reduced) CLS token of every example as its own column.
# Iterating a 2-D array yields its rows, giving one vector per example.
total_num_ex = len(all_cls_tokens)
all_cls_tokens_list = list(all_cls_tokens)
concat_ds = concat_ds.add_column(
    name="whole_cls_token",
    column=all_cls_tokens_list,
)
concat_ds

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity', 'cls_token_pca_components', 'whole_cls_token'],
    num_rows: 76295
})

In [11]:
# recording_col_name = "Subtract_Mean_Divide_Global_STD_Normalized_Recording"
# all_recordings = np.load("inference_plots/all_{}_490len.npy".format(recording_col_name))
# all_recordings.shape

In [12]:
# Load the precomputed PCA-reduced raw recordings and show the array shape.
# NOTE(review): rows are assumed to be in the same order as concat_ds - TODO confirm.
recording_pca_components = np.load("/home/mr2238/BrainLM/inference_plots/dataset_v3/pca_reduced_raw_data_200length_200components.npy")
recording_pca_components.shape

(76295, 200)

In [13]:
# Attach the PCA-reduced raw recordings as a per-row column.  The row list is built
# from this array itself rather than from `total_num_ex`, which an earlier cell derived
# from a different array (the CLS tokens); that way this cell cannot silently pick up a
# stale or mismatched row count.
recording_pca_components_list = list(recording_pca_components)
concat_ds = concat_ds.add_column(
    name="recording_pca_components",
    column=recording_pca_components_list,
)
concat_ds

Dataset({
    features: ['Raw_Recording', 'Voxelwise_RobustScaler_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity', 'cls_token_pca_components', 'whole_cls_token', 'recording_pca_components'],
    num_rows: 76295
})

In [14]:
# Sanity check: the stored per-row value converts back to a flat float32 vector.
np.array(concat_ds[0]["cls_token_pca_components"], dtype=np.float32).shape

(200,)

In [15]:
# Preview the first five PCA components of row 0 for both feature columns.
for feature_column in ("cls_token_pca_components", "recording_pca_components"):
    print(concat_ds[0][feature_column][:5])

[-0.5865330696105957, -0.2947590947151184, 0.1324397772550583, -1.0456269979476929, 0.8647788166999817]
[-27.499038696289062, 10.373053550720215, -30.58742332458496, -25.046045303344727, 29.680246353149414]


In [16]:
# Peek at the first ten values of every metadata column, one column per line.
for metadata_column in (
    "Gender",
    "Age.At.MHQ",
    "PHQ9.Severity",
    "Depressed.At.Baseline",
    "Neuroticism",
    "Self.Harm.Ever",
    "Not.Worth.Living",
    "PCL.Score",
    "GAD7.Severity",
):
    print(concat_ds[metadata_column][:10])

[nan, nan, nan, 0.0, 1.0, 0.0, nan, nan, 0.0, 0.0]
[nan, nan, nan, 52.0, 72.0, 75.0, nan, nan, 72.0, 63.0]
[nan, nan, nan, 3.0, 1.0, 0.0, nan, nan, 0.0, 10.0]
[nan, nan, nan, 0.0, 0.0, 1.0, nan, nan, 0.0, 1.0]
[nan, nan, nan, 5.0, nan, 1.0, nan, nan, 4.0, nan]
[nan, nan, nan, 0.0, 0.0, 0.0, nan, nan, 0.0, 0.0]
[nan, nan, nan, 0.0, 1.0, 0.0, nan, nan, 0.0, 0.0]
[nan, nan, nan, 7.0, 11.0, 7.0, nan, nan, 4.0, 14.0]
[nan, nan, nan, 2.0, 3.0, 4.0, nan, nan, 0.0, 10.0]


In [17]:
# List the distinct values observed in every metadata column, prefixed by its name.
for metadata_column in (
    "Gender",
    "Age.At.MHQ",
    "PHQ9.Severity",
    "Depressed.At.Baseline",
    "Neuroticism",
    "Self.Harm.Ever",
    "Not.Worth.Living",
    "PCL.Score",
    "GAD7.Severity",
):
    distinct_values = np.unique(np.array(concat_ds[metadata_column]))
    print(f"{metadata_column}:", distinct_values)

Gender: [ 0.  1. nan]
Age.At.MHQ: [47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64.
 65. 66. 67. 68. 69. 70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. nan]
PHQ9.Severity: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. nan]
Depressed.At.Baseline: [ 0.  1. nan]
Neuroticism: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. nan]
Self.Harm.Ever: [ 0.  1. nan]
Not.Worth.Living: [ 0.  1. nan]
PCL.Score: [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. nan]
GAD7.Severity: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. nan]


In [18]:
# Distinct values of Not.Worth.Living together with how many rows carry each one.
np.unique(np.array(concat_ds["Not.Worth.Living"]), return_counts=True)

(array([ 0.,  1., nan]), array([25965, 11191, 39139]))

## Run SVM Regression on continuous metadata variables

In [29]:
def _cross_validated_mse(features, targets, n_splits):
    """Return the per-fold mean squared error of a default ``svm.SVR``.

    Args:
        features: 2-D array of shape (n_samples, n_features).
        targets: 1-D array of regression targets, one per sample.
        n_splits: Number of cross-validation folds (passed to ``cross_val_score`` as ``cv``).

    Returns:
        A list with one positive MSE value per fold.
    """
    neg_mse_per_fold = cross_val_score(svm.SVR(), features, targets, cv=n_splits,
                                       scoring="neg_mean_squared_error")
    # The scorer reports *negated* MSE; flip the sign to get the usual error.
    return [-1 * fold_score for fold_score in neg_mse_per_fold]


def run_svm_regression(variable_of_interest, dataset=None, n_splits=5, seed=42):
    """Cross-validate SVM regression of one metadata variable on three feature sets.

    Rows with a missing value for ``variable_of_interest`` are dropped, the remaining
    rows are shuffled with a fixed seed, and the labels are truncated to integers and
    z-scored.  A default ``svm.SVR`` is then cross-validated separately on the
    ``recording_pca_components``, ``cls_token_pca_components`` and ``whole_cls_token``
    columns.  For each feature set the mean +/- standard deviation of the fold MSEs is
    printed.

    Because the labels are z-scored to unit variance, an MSE of roughly 1.0 is what a
    constant predictor of the mean label would achieve.

    Args:
        variable_of_interest: Name of the metadata column to predict.  Must be one of
            "Age.At.MHQ", "PHQ9.Severity", "Neuroticism", "PCL.Score", "GAD7.Severity".
        dataset: Dataset holding the label column and the three feature columns.
            Defaults to the notebook-level ``concat_ds``.
        n_splits: Number of cross-validation folds (default 5).
        seed: Seed for the reproducible shuffle applied before cross-validation
            (default 42).

    Returns:
        dict mapping "Raw_Data_PCA", "CLS_Token_PCA" and "Whole_CLS_Token" to the list
        of per-fold MSE values for that feature set.

    Raises:
        AssertionError: If ``variable_of_interest`` is not one of the supported columns.
        ValueError: If no row has a value for ``variable_of_interest``.
    """
    assert variable_of_interest in ["Age.At.MHQ", "PHQ9.Severity", "Neuroticism", "PCL.Score", "GAD7.Severity"], \
        "Please specify a metadata variable with a range of continuous values."
    if dataset is None:
        dataset = concat_ds

    # Keep only the rows that have a value for the requested variable.  Converting to a
    # float array first maps both NaN and None to NaN, so either spelling of "missing"
    # is filtered out.
    all_labels = np.array(dataset[variable_of_interest], dtype=np.float64)
    labelled_rows = np.flatnonzero(~np.isnan(all_labels)).tolist()
    if not labelled_rows:
        raise ValueError(f"No rows have a value for {variable_of_interest!r}.")
    labelled_ds = dataset.select(labelled_rows)

    # Shuffle reproducibly, in case the rows are stored in some meaningful order.
    labelled_ds = labelled_ds.shuffle(seed=seed)

    # Labels are truncated to integers and then z-scored for regression.
    # NOTE(review): the z-score statistics are computed over all labelled rows, i.e.
    # before the cross-validation split, so held-out labels contribute to the scaling.
    labels = np.array(labelled_ds[variable_of_interest], dtype=np.float64).astype(np.int64)
    targets = StandardScaler().fit_transform(labels.reshape(-1, 1)).ravel()

    # NOTE(review): folds are formed over rows.  If one subject (`eid`) can contribute
    # more than one row, the same subject may land in both the training and the test
    # fold - confirm, and consider grouping folds by subject.
    feature_sets = [
        # (key in the returned dict, label used in the printed report, dataset column)
        ("Raw_Data_PCA", "Raw Data PCA Components MSE", "recording_pca_components"),
        ("CLS_Token_PCA", "CLS Token PCA Component MSE", "cls_token_pca_components"),
        ("Whole_CLS_Token", "Whole CLS Token MSE", "whole_cls_token"),
    ]

    results = {}
    for result_key, report_label, feature_column in feature_sets:
        # Materialize one feature matrix at a time to keep peak memory down.
        features = np.array(labelled_ds[feature_column], dtype=np.float32)
        fold_mse = _cross_validated_mse(features, targets, n_splits)
        results[result_key] = fold_mse
        print(f"{report_label}: {statistics.mean(fold_mse):.3f} +/- {statistics.stdev(fold_mse):.3f}")

    return results

In [30]:
# Regress age from each feature representation.  The result gets its own name so it is
# neither mislabeled as a PHQ-9 result nor overwritten by the PHQ-9 run that follows.
metadata_variable = "Age.At.MHQ"
age_results = run_svm_regression(metadata_variable)

Loading cached shuffled indices for dataset at /home/sr2464/palmer_scratch/datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/train_ukbiobank/cache-d8ca8d5320239fd9.arrow


Raw Data PCA Components MSE: 0.797 +/- 0.011
CLS Token PCA Component MSE: 0.776 +/- 0.013
Whole CLS Token MSE: 0.812 +/- 0.013


In [31]:
# Regress PHQ9.Severity from each feature representation and keep the per-fold MSEs.
metadata_variable = "PHQ9.Severity"
phq9_results = run_svm_regression(metadata_variable)

Loading cached shuffled indices for dataset at /home/sr2464/palmer_scratch/datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/train_ukbiobank/cache-d8ca8d5320239fd9.arrow


Raw Data PCA Components MSE: 1.075 +/- 0.049
CLS Token PCA Component MSE: 1.096 +/- 0.050
Whole CLS Token MSE: 1.092 +/- 0.051


In [None]:
# Regress PCL.Score from each feature representation.  The result gets its own name so
# it does not replace the PHQ-9 results held in `phq9_results`.
metadata_variable = "PCL.Score"
pcl_results = run_svm_regression(metadata_variable)

Loading cached shuffled indices for dataset at /home/sr2464/palmer_scratch/datasets/UKB_Large_rsfMRI_and_tffMRI_Arrow_WithRegression_v3_with_metadata/train_ukbiobank/cache-d8ca8d5320239fd9.arrow


Raw Data PCA Components MSE: 1.110 +/- 0.040
CLS Token PCA Component MSE: 1.204 +/- 0.040


In [None]:
# Regress GAD7.Severity from each feature representation.  The result gets its own name
# so it does not replace the PHQ-9 results held in `phq9_results`.
# NOTE(review): if any later cell reads `phq9_results` expecting the most recent run,
# point it at `gad7_results` instead.
metadata_variable = "GAD7.Severity"
gad7_results = run_svm_regression(metadata_variable)