In [1]:
import os
import math

import umap
import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier

import torch
import torch.nn as nn

from datasets import load_from_disk, concatenate_datasets
from brainlm_mae.modeling_brainlm import BrainLMForPretraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if not os.path.exists("inference_plots"):
    os.mkdir("inference_plots")

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## Load Entire Dataset

In [4]:
train_ds = load_from_disk("/home/sr2464/palmer_scratch/datasets/UKBioBank1000_Arrow_v4/train_ukbiobank1000")
print(train_ds)
val_ds = load_from_disk("/home/sr2464/palmer_scratch/datasets/UKBioBank1000_Arrow_v4/val_ukbiobank1000")
print(val_ds)
test_ds = load_from_disk("/home/sr2464/palmer_scratch/datasets/UKBioBank1000_Arrow_v4/test_ukbiobank1000")
print(test_ds)
coords_ds = load_from_disk("/home/sr2464/palmer_scratch/datasets/UKBioBank1000_Arrow_v4/Brain_Region_Coordinates")
print(coords_ds)

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 800
})
Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patien

In [5]:
concat_ds = concatenate_datasets([train_ds, val_ds, test_ds])
concat_ds

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 1000
})

In [6]:
example0 = concat_ds[0]
print(example0['Filename'])
print(example0['Patient ID'])
print(example0['Order'])
print(example0['eid'])
print(example0['Gender'])
print(example0['Age.At.MHQ'])
print(example0['Depressed.At.Baseline'])
print(example0['Neuroticism'])
print(example0['Self.Harm.Ever'])
print(example0['Not.Worth.Living'])
print(example0['PCL.Score'])
print(example0['GAD7.Severity'])

1000011.dat

16248
1000011
1.0
76.0
0.0
3.0
0.0
0.0
10.0
3.0


## Reload CLS tokens and raw data if needed

In [42]:
all_cls_tokens = np.load("inference_plots/all_cls_tokens_200recordinglength.npy")
all_cls_tokens.shape

(1000, 256)

In [43]:
recording_col_name = "Subtract_Mean_Divide_Global_STD_Normalized_Recording"
all_recordings = np.load("inference_plots/all_{}.npy".format(recording_col_name))
all_recordings.shape

(1000, 84800)

In [44]:
raw_data_pca_reduced = np.load("inference_plots/pca_reduced_raw_data_50components.png.npy")
cls_tokens_pca_reduced = np.load("inference_plots/pca_reduced_cls_tokens_50components.png.npy")
print(raw_data_pca_reduced.shape)
print(cls_tokens_pca_reduced.shape)

(1000, 50)
(1000, 50)


In [45]:
raw_data_pca_reduced_10components = raw_data_pca_reduced[:, :10]
cls_tokens_pca_reduced_10_components = cls_tokens_pca_reduced[:, :10]
print(raw_data_pca_reduced_10components.shape)
print(cls_tokens_pca_reduced_10_components.shape)

(1000, 10)
(1000, 10)


In [46]:
print(raw_data_pca_reduced_10components[0])
print(cls_tokens_pca_reduced_10_components[0])

[  3.6791153  43.49763    -8.094716   12.79285   -18.954952   22.041622
 -36.941082  -28.353394  -10.202997    4.0116863]
[-0.4117697  -0.01193752 -0.03786373 -0.06954364 -0.15222041  0.02873909
 -0.01182826  0.08061408 -0.08338942 -0.05359326]


In [47]:
print(concat_ds["Gender"][:10])
print(concat_ds["Age.At.MHQ"][:10])
print(concat_ds["PHQ9.Severity"][:10])
print(concat_ds["Depressed.At.Baseline"][:10])
print(concat_ds["Neuroticism"][:10])
print(concat_ds["Self.Harm.Ever"][:10])
print(concat_ds["Not.Worth.Living"][:10])
print(concat_ds["PCL.Score"][:10])
print(concat_ds["GAD7.Severity"][:10])

[1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, nan, 0.0]
[76.0, 61.0, 52.0, 55.0, 53.0, 55.0, 72.0, 71.0, nan, 66.0]
[2.0, 0.0, 1.0, 13.0, 4.0, 1.0, 1.0, 0.0, nan, 2.0]
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, nan, nan]
[3.0, 4.0, 1.0, 6.0, 5.0, 0.0, nan, 1.0, nan, nan]
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, nan, 1.0]
[0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, nan, 0.0]
[10.0, 4.0, 9.0, 14.0, 5.0, 4.0, 4.0, 4.0, nan, 15.0]
[3.0, 0.0, 1.0, 5.0, 2.0, 0.0, 0.0, 0.0, nan, 4.0]


In [48]:
print("Gender:", np.unique(np.array(concat_ds["Gender"])))
print("Age.At.MHQ:", np.unique(np.array(concat_ds["Age.At.MHQ"])))
print("PHQ9.Severity:", np.unique(np.array(concat_ds["PHQ9.Severity"])))
print("Depressed.At.Baseline:", np.unique(np.array(concat_ds["Depressed.At.Baseline"])))
print("Neuroticism:", np.unique(np.array(concat_ds["Neuroticism"])))
print("Self.Harm.Ever:", np.unique(np.array(concat_ds["Self.Harm.Ever"])))
print("Not.Worth.Living:", np.unique(np.array(concat_ds["Not.Worth.Living"])))
print("PCL.Score:", np.unique(np.array(concat_ds["PCL.Score"])))
print("GAD7.Severity:", np.unique(np.array(concat_ds["GAD7.Severity"])))

Gender: [ 0.  1. nan]
Age.At.MHQ: [47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64.
 65. 66. 67. 68. 69. 70. 71. 72. 73. 74. 75. 76. 77. 78. 79. nan]
PHQ9.Severity: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 17. 18. 21.
 25. nan]
Depressed.At.Baseline: [ 0.  1. nan]
Neuroticism: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. nan]
Self.Harm.Ever: [ 0.  1. nan]
Not.Worth.Living: [ 0.  1. nan]
PCL.Score: [ 2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19.
 20. 25. nan]
GAD7.Severity: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 13. 15. 17. 18. 19. 21.
 nan]


In [49]:
np.unique(np.array(concat_ds["Not.Worth.Living"]), return_counts=True)

(array([ 0.,  1., nan]), array([503, 224, 273]))

## Run KNN Classifier on Gender

In [79]:
variable_of_interest = "Gender"

In [80]:
# Select rows of concat_dataset where label is not nan
full_label_list = concat_ds[variable_of_interest]
non_nan_indices = [idx for idx in range(len(full_label_list)) if not math.isnan(full_label_list[idx])]
# print(non_nan_indices[:10])
# print(full_label_list[:10])
non_nan_ds = concat_ds.select(non_nan_indices)
non_nan_ds

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 727
})

In [81]:
raw_data_pca_nonnan = raw_data_pca_reduced_10components[np.array(non_nan_indices)]
raw_data_pca_nonnan.shape

(727, 10)

In [82]:
cls_tokens_pca_nonnan = cls_tokens_pca_reduced_10_components[np.array(non_nan_indices)]
cls_tokens_pca_nonnan.shape

(727, 10)

In [83]:
labels = non_nan_ds[variable_of_interest]
labels = [int(num) for num in labels]
labels[:10]

[1, 1, 0, 0, 0, 1, 0, 1, 0, 0]

In [86]:
# Run on raw data
split_idx = int(raw_data_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = raw_data_pca_nonnan[:split_idx]
test_X_raw_data = raw_data_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(581, 10)
(146, 10)


In [87]:
train_X_raw_data[0]

array([  3.6791153,  43.49763  ,  -8.094716 ,  12.79285  , -18.954952 ,
        22.041622 , -36.941082 , -28.353394 , -10.202997 ,   4.0116863],
      dtype=float32)

In [94]:
num_neighbors = 3
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)

In [95]:
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

56.164

In [96]:
# Run on CLS tokens
split_idx = int(cls_tokens_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = cls_tokens_pca_nonnan[:split_idx]
test_X_raw_data = cls_tokens_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(581, 10)
(146, 10)


In [97]:
train_X_raw_data[0]

array([-0.4117697 , -0.01193752, -0.03786373, -0.06954364, -0.15222041,
        0.02873909, -0.01182826,  0.08061408, -0.08338942, -0.05359326],
      dtype=float32)

In [101]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

50.685

## Run KNN Classifier on Depression_at_Baseline

In [15]:
variable_of_interest = "Depressed.At.Baseline"

In [16]:
# Select rows of concat_dataset where label is not nan
full_label_list = concat_ds[variable_of_interest]
non_nan_indices = [idx for idx in range(len(full_label_list)) if not math.isnan(full_label_list[idx])]
# print(non_nan_indices[:10])
# print(full_label_list[:10])
non_nan_ds = concat_ds.select(non_nan_indices)
non_nan_ds

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 702
})

In [17]:
raw_data_pca_nonnan = raw_data_pca_reduced_10components[np.array(non_nan_indices)]
print(raw_data_pca_nonnan.shape)
cls_tokens_pca_nonnan = cls_tokens_pca_reduced_10_components[np.array(non_nan_indices)]
print(cls_tokens_pca_nonnan.shape)

(702, 10)
(702, 10)


In [18]:
labels = non_nan_ds[variable_of_interest]
labels = [int(num) for num in labels]
labels[:10]

[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

In [19]:
# Run on raw data
split_idx = int(raw_data_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = raw_data_pca_nonnan[:split_idx]
test_X_raw_data = raw_data_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(561, 10)
(141, 10)


In [21]:
train_X_raw_data[0]

array([  3.6791153,  43.49763  ,  -8.094716 ,  12.79285  , -18.954952 ,
        22.041622 , -36.941082 , -28.353394 , -10.202997 ,   4.0116863],
      dtype=float32)

In [25]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

82.27

In [26]:
# Run on CLS tokens
split_idx = int(cls_tokens_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = cls_tokens_pca_nonnan[:split_idx]
test_X_raw_data = cls_tokens_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(561, 10)
(141, 10)


In [27]:
train_X_raw_data[0]

array([-0.4117697 , -0.01193752, -0.03786373, -0.06954364, -0.15222041,
        0.02873909, -0.01182826,  0.08061408, -0.08338942, -0.05359326],
      dtype=float32)

In [31]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

82.27

## Run KNN on Self.Harm.Ever

In [50]:
variable_of_interest = "Self.Harm.Ever"

In [51]:
# Select rows of concat_dataset where label is not nan
full_label_list = concat_ds[variable_of_interest]
non_nan_indices = [idx for idx in range(len(full_label_list)) if not math.isnan(full_label_list[idx])]
# print(non_nan_indices[:10])
# print(full_label_list[:10])
non_nan_ds = concat_ds.select(non_nan_indices)
non_nan_ds

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 725
})

In [52]:
raw_data_pca_nonnan = raw_data_pca_reduced_10components[np.array(non_nan_indices)]
print(raw_data_pca_nonnan.shape)
cls_tokens_pca_nonnan = cls_tokens_pca_reduced_10_components[np.array(non_nan_indices)]
print(cls_tokens_pca_nonnan.shape)

(725, 10)
(725, 10)


In [53]:
labels = non_nan_ds[variable_of_interest]
labels = [int(num) for num in labels]
labels[:10]

[0, 0, 0, 1, 0, 0, 0, 0, 1, 0]

In [54]:
# Run on CLS tokens
split_idx = int(cls_tokens_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = cls_tokens_pca_nonnan[:split_idx]
test_X_raw_data = cls_tokens_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(580, 10)
(145, 10)


In [55]:
train_X_raw_data[0]

array([-0.4117697 , -0.01193752, -0.03786373, -0.06954364, -0.15222041,
        0.02873909, -0.01182826,  0.08061408, -0.08338942, -0.05359326],
      dtype=float32)

In [56]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

95.862

In [57]:
# Run on raw data
split_idx = int(raw_data_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = raw_data_pca_nonnan[:split_idx]
test_X_raw_data = raw_data_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(580, 10)
(145, 10)


In [59]:
train_X_raw_data[0]

array([  3.6791153,  43.49763  ,  -8.094716 ,  12.79285  , -18.954952 ,
        22.041622 , -36.941082 , -28.353394 , -10.202997 ,   4.0116863],
      dtype=float32)

In [62]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

95.862

## Run KNN on Not.Worth.Living

In [63]:
variable_of_interest = "Not.Worth.Living"

In [64]:
# Select rows of concat_dataset where label is not nan
full_label_list = concat_ds[variable_of_interest]
non_nan_indices = [idx for idx in range(len(full_label_list)) if not math.isnan(full_label_list[idx])]
# print(non_nan_indices[:10])
# print(full_label_list[:10])
non_nan_ds = concat_ds.select(non_nan_indices)
non_nan_ds

Dataset({
    features: ['Raw_Recording', 'All_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_All_Voxel_Normalized_Recording', 'Per_Patient_Per_Voxel_Normalized_Recording', 'Per_Voxel_All_Patient_Normalized_Recording', 'Subtract_Mean_Normalized_Recording', 'Subtract_Mean_Divide_Global_STD_Normalized_Recording', 'Subtract_Mean_Divide_Global_99thPercent_Normalized_Recording', 'Filename', 'Patient ID', 'Order', 'eid', 'Gender', 'Age.At.MHQ', 'PHQ9.Severity', 'Depressed.At.Baseline', 'Neuroticism', 'Self.Harm.Ever', 'Not.Worth.Living', 'PCL.Score', 'GAD7.Severity'],
    num_rows: 727
})

In [65]:
raw_data_pca_nonnan = raw_data_pca_reduced_10components[np.array(non_nan_indices)]
print(raw_data_pca_nonnan.shape)
cls_tokens_pca_nonnan = cls_tokens_pca_reduced_10_components[np.array(non_nan_indices)]
print(cls_tokens_pca_nonnan.shape)

(727, 10)
(727, 10)


In [66]:
labels = non_nan_ds[variable_of_interest]
labels = [int(num) for num in labels]
labels[:10]

[0, 0, 0, 1, 1, 0, 1, 0, 0, 0]

In [67]:
# Run on CLS tokens
split_idx = int(cls_tokens_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = cls_tokens_pca_nonnan[:split_idx]
test_X_raw_data = cls_tokens_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(581, 10)
(146, 10)


In [68]:
train_X_raw_data[0]

array([-0.4117697 , -0.01193752, -0.03786373, -0.06954364, -0.15222041,
        0.02873909, -0.01182826,  0.08061408, -0.08338942, -0.05359326],
      dtype=float32)

In [73]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

73.288

In [74]:
# Run on raw data
split_idx = int(raw_data_pca_nonnan.shape[0] * 0.8)
train_X_raw_data = raw_data_pca_nonnan[:split_idx]
test_X_raw_data = raw_data_pca_nonnan[split_idx:]
print(train_X_raw_data.shape)
print(test_X_raw_data.shape)
train_y_raw_data = labels[:split_idx]
test_y_raw_data = labels[split_idx:]

(581, 10)
(146, 10)


In [78]:
num_neighbors = 20
neigh = KNeighborsClassifier(n_neighbors=num_neighbors)
neigh.fit(train_X_raw_data, train_y_raw_data)
round(neigh.score(test_X_raw_data, test_y_raw_data) * 100, 3)

71.233