In [1]:
from sklearn import metrics
import torch
import numpy as np
from scipy.io import loadmat
from utils.Preprocessing import Preprocessing
from utils.metrics import compute_fc_matrix_regular
import matplotlib.pyplot as plt
from networks.Autoencoder import Autoencoder
from scipy.stats import zscore
import utils.utils as utils
import seaborn as sns
from skimage.metrics import structural_similarity as ssim
import pandas as pd
from scipy import stats

In [2]:
# Load the classifier objects saved for each data representation
# (original time series, autoencoder reconstruction, latent space).
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# files from a trusted source. On recent PyTorch releases the default of
# `weights_only` changed to True, which rejects non-tensor objects; if these
# files hold full estimator objects, pass weights_only=False there — TODO confirm
# against the PyTorch version this notebook is pinned to.
classifier_original = torch.load('./models/DT_original_data.pt')
classifier_reconstructed = torch.load('./models/DT_reconstructed_data.pt')
classifier_latent = torch.load('./models/DT_latent_data.pt')

In [3]:
# Dimensions shared by every step below. They are defined first so that the
# shortening, reshaping and tensor construction all use the same values
# instead of repeating the literals 200 and 80.
num_time_points = 200
num_brain_nodes = 80
latent_dimension = 18

# --- Load the raw time series -------------------------------------------------
preprocessor = Preprocessing()
mat_data = loadmat('./data/laufs_sleep.mat')
# All four sleep-stage arrays must have the same layout before we index them.
assert mat_data['TS_N1'].shape == mat_data['TS_N2'].shape == mat_data['TS_N3'].shape == mat_data['TS_W'].shape
# NOTE(review): [0][1:] takes the first row and drops its first entry —
# presumably a placeholder/invalid recording; confirm against the dataset.
wake_data = mat_data['TS_W'][0][1:]
n3_data = mat_data['TS_N3'][0][1:]

# --- Cut every recording to a common length -----------------------------------
wake_data_shortened, _ = preprocessor.shorten_data(wake_data, final_length=num_time_points)
n3_data_shortened, _ = preprocessor.shorten_data(n3_data, final_length=num_time_points)

# --- Normalise wake and N3 together -------------------------------------------
# axis=0 is scipy's default and is spelled out here on purpose: the statistics
# are taken along the stacked-recordings axis.
# NOTE(review): confirm that standardising across recordings (rather than
# within each recording over time) is the intended normalisation.
concatenated_data = np.concatenate([wake_data_shortened, n3_data_shortened])
concatenated_data = zscore(concatenated_data, axis=0)

# Split the normalised stack back into its two conditions.
wake_data_norm = concatenated_data[:wake_data_shortened.shape[0]]
n3_data_norm = concatenated_data[wake_data_shortened.shape[0]:]

# Flatten to (samples, nodes) so the recording count can be derived below.
wake_data_reshaped = wake_data_norm.reshape(-1, num_brain_nodes)
n3_data_reshaped = n3_data_norm.reshape(-1, num_brain_nodes)

# --- Labels: 0 = wake, 1 = N3, one row per recording --------------------------
wake_labels = torch.zeros((wake_data_shortened.shape[0], 1))
n3_labels = torch.ones((n3_data_shortened.shape[0], 1))

valid_labels = torch.cat((wake_labels, n3_labels), 0)

# --- Build (recordings, time points, nodes) tensors ---------------------------
wake_num_participants = wake_data_reshaped.shape[0] // num_time_points
n3_num_participants = n3_data_reshaped.shape[0] // num_time_points

wake_data_tensor = torch.tensor(wake_data_reshaped, dtype=torch.float32).reshape(wake_num_participants, num_time_points, num_brain_nodes)
n3_data_tensor = torch.tensor(n3_data_reshaped, dtype=torch.float32).reshape(n3_num_participants, num_time_points, num_brain_nodes)

# The round trip through reshape must not change the number of recordings.
assert wake_data_tensor.shape[0] == wake_data_shortened.shape[0]
assert n3_data_tensor.shape[0] == n3_data_shortened.shape[0]

# --- Restore the trained autoencoder ------------------------------------------
model = Autoencoder(latent_dim=latent_dimension)
# map_location='cpu': every tensor in this notebook lives on the CPU, so the
# checkpoint must load on machines without a GPU even if it was saved from one.
model.load_state_dict(
    torch.load(
        f'./models/doubleDropout/AE_results_lat{latent_dimension}.pt',
        map_location='cpu',
    )
)
# Switch to inference mode; as the last expression this also displays the
# module structure as the cell output.
model.eval()

Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=80, out_features=256, bias=True)
    (1): Dropout(p=0.2, inplace=False)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): Dropout(p=0.1, inplace=False)
    (5): ReLU()
    (6): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Linear(in_features=128, out_features=18, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=18, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=256, out_features=80, bias=True)
 

### PERMUTATION TEST FOR CLASSIFIERS

In [24]:
from sklearn.model_selection import permutation_test_score

# Settings for the permutation tests, named so they are easy to find and tune.
N_PERMUTATIONS = 1000
PERMUTATION_SEED = 42


def _fc_matrices(timeseries, n_features):
    """Compute one functional-connectivity matrix per recording.

    Parameters
    ----------
    timeseries : torch.Tensor
        Stack of recordings; the first axis indexes recordings.
    n_features : int
        Number of channels per recording. Forwarded to
        ``compute_fc_matrix_regular`` as ``num_brain_nodes`` and used as the
        side length of each output matrix.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(len(timeseries), n_features, n_features)``.
    """
    matrices = torch.empty(timeseries.shape[0], n_features, n_features)
    for idx, recording in enumerate(timeseries):
        # Indexing yields a view of the stacked tensor; clone so the helper
        # cannot modify the shared data through it.
        matrices[idx] = compute_fc_matrix_regular(recording.clone(), num_brain_nodes=n_features)
    return matrices


def _lower_triangles(matrices):
    """Stack the lower-triangle feature vector of every matrix into one tensor."""
    return torch.stack([torch.tensor(utils.extract_lower_triangle(mat.numpy())) for mat in matrices])


def _run_permutation_test(estimator, features, labels):
    """Run ``permutation_test_score`` with this notebook's shared settings.

    ``permutation_test_score`` clones and refits ``estimator`` under
    cross-validation, so the result describes the estimator's configuration
    on this data, not the already-fitted object that was loaded from disk.

    Returns the ``(score, permutation_scores, p_value)`` tuple unchanged.
    """
    return permutation_test_score(
        estimator,
        features.numpy(),  # scikit-learn expects array-likes, not torch tensors
        labels,
        n_permutations=N_PERMUTATIONS,
        n_jobs=-1,
        random_state=PERMUTATION_SEED,
    )


# --- Encode / decode both conditions with the trained autoencoder -------------
with torch.no_grad():
    num_wake = wake_data_tensor.shape[0]
    reconstructed_wake, latent_wake = model(wake_data_tensor.view(-1, num_brain_nodes))
    latent_wake_data = latent_wake.view(num_wake, num_time_points, latent_dimension)
    reconstructed_wake = reconstructed_wake.view(num_wake, num_time_points, num_brain_nodes)

    num_n3 = n3_data_tensor.shape[0]
    reconstructed_n3, latent_n3 = model(n3_data_tensor.view(-1, num_brain_nodes))
    latent_n3_data = latent_n3.view(num_n3, num_time_points, latent_dimension)
    reconstructed_n3 = reconstructed_n3.view(num_n3, num_time_points, num_brain_nodes)

num_participants = num_wake + num_n3

# Wake recordings first, then N3 — the same order as `valid_labels`.
data_tensor = torch.cat([wake_data_tensor, n3_data_tensor], dim=0)
latent_data = torch.cat([latent_wake_data, latent_n3_data], dim=0)
reconstructed_data = torch.cat([reconstructed_wake, reconstructed_n3], dim=0)

# --- Connectivity matrices for each representation ----------------------------
original_correlation_matrices = _fc_matrices(data_tensor, num_brain_nodes)
latent_correlation_matrices = _fc_matrices(latent_data, latent_dimension)
reconstructed_correlation_matrices = _fc_matrices(reconstructed_data, num_brain_nodes)

# --- Feature vectors: lower triangle of every matrix --------------------------
original_lower_triangles = _lower_triangles(original_correlation_matrices)
latent_lower_triangles = _lower_triangles(latent_correlation_matrices)
reconstructed_lower_triangles = _lower_triangles(reconstructed_correlation_matrices)

# scikit-learn expects a 1-D label vector; `valid_labels` is an (N, 1) tensor.
permutation_labels = valid_labels.numpy().ravel()

# --- Permutation tests ---------------------------------------------------------
score_original, permutation_scores_original, p_value_original = _run_permutation_test(
    classifier_original, original_lower_triangles, permutation_labels
)

score_reconstructed, permutation_scores_reconstructed, p_value_reconstructed = _run_permutation_test(
    classifier_reconstructed, reconstructed_lower_triangles, permutation_labels
)

score_latent, permutation_scores_latent, p_value_latent = _run_permutation_test(
    classifier_latent, latent_lower_triangles, permutation_labels
)

print(f'Original data: (score) {score_original:.3f}, (p-value) {p_value_original:.3f}')
print(f'Reconstructed data: (score) {score_reconstructed:.3f}, (p-value) {p_value_reconstructed:.3f}')
print(f'Latent data: (score) {score_latent:.3f}, (p-value) {p_value_latent:.3f}')

Original data: (score) 0.860, (p-value) 0.003
Reconstructed data: (score) 0.893, (p-value) 0.001
Latent data: (score) 0.827, (p-value) 0.007
