In [1]:
import os
import scipy.io
import numpy as np
import random
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import dense_to_sparse
import torch.nn.functional as func
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn import GATConv, ChebConv
import torch.nn as nn
from collections import Counter
import os.path as osp
import csv
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

In [2]:
seed = 89
atlas_name = "AAL"
dataset_name = "MDDvHC"

# (start, end) per atlas. In this notebook only `end - start` (the number of
# ROIs) is used, to size the model input; start/end themselves are passed to
# the graph builders but not used for slicing there.
# NOTE(review): presumably offsets into a concatenated multi-atlas ROI array.
ATLAS_ROI_RANGES = {
    "AAL": (0, 116),
    "Craddock": (228, 428),
    "Dosenbach": (1408, 1568),
}

# Fail loudly on a typo instead of calling exit(), which silently stops the
# interpreter / Jupyter kernel without saying why.
if atlas_name not in ATLAS_ROI_RANGES:
    raise ValueError(
        f"Unknown atlas_name {atlas_name!r}; expected one of {sorted(ATLAS_ROI_RANGES)}"
    )
start, end = ATLAS_ROI_RANGES[atlas_name]

In [None]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch on deterministic kernels.

    CUDA generators are seeded only when CUDA is available. cuDNN autotuning is
    disabled and PyTorch is told to use deterministic algorithms only (ops
    without a deterministic implementation will then raise).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

# Reproducibility-related environment variables, set in this order.
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects child processes, not str hashing in this kernel.
# CUBLAS_WORKSPACE_CONFIG has to be in place before cuBLAS is first used for
# deterministic CUDA matmuls under torch.use_deterministic_algorithms(True).
# CUDA_LAUNCH_BLOCKING makes CUDA kernel launches synchronous (debugging aid).
os.environ.update({
    'PYTHONHASHSEED': str(seed),
    'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    'CUDA_LAUNCH_BLOCKING': '1',
})

# Seed every RNG and enable deterministic kernels.
set_seed(seed)

def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Meant to be passed as ``worker_init_fn``. Each worker derives its seed from
    the torch seed the DataLoader gave it (``torch.initial_seed()``), reduced to
    the 32-bit range NumPy accepts. ``worker_id`` is unused.
    """
    worker_seed = torch.initial_seed() % 2**32
    # NumPy is imported as `np` in this notebook; the bare name `numpy` is not
    # defined and raised NameError as soon as a loader used num_workers > 0.
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Generator shared by the DataLoaders for reproducible shuffling.
# NOTE: seeded with 0, independently of `seed` above.
g = torch.Generator().manual_seed(0)


In [4]:
from sklearn.preprocessing import StandardScaler

def normalize(matrix):
    """Z-score every column of ``matrix`` using scikit-learn's ``StandardScaler``.

    Each column is centred to mean 0 and divided by its population standard
    deviation. Columns with zero variance are centred but left unscaled.
    A fresh scaler is fitted per call, so no statistics are shared between
    matrices.

    NOTE(review): if rows are ROIs and columns are time points (as the
    sliding-window code later assumes), this standardises each time point
    across ROIs rather than each ROI's time series — confirm that is intended.
    """
    return StandardScaler().fit_transform(matrix)

In [None]:
import numpy as np
from scipy.sparse import coo_matrix
import torch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import InMemoryDataset, Data

def fisher_z_transform(correlation_matrix, epsilon=1e-5):
    """Fisher z-transform, 0.5 * ln((1 + r) / (1 - r)), stabilised by ``epsilon``.

    ``epsilon`` is added to the denominator so that r = 1 (e.g. the diagonal of
    a correlation matrix) maps to 0.5 * ln(2 / epsilon), about 6.10 for the
    default, instead of +inf. The numerator is floored at ``epsilon`` so that
    r = -1 also gives a finite value instead of -inf. For r >= -1 + epsilon
    the result is unchanged by that floor. Inputs are clipped to [-1, 1] first
    to absorb floating-point overshoot.

    Args:
        correlation_matrix: array-like of correlation coefficients (any shape).
        epsilon: small positive stabiliser (default 1e-5).

    Returns:
        Array (or NumPy scalar) with the same shape as the input.
    """
    r = np.clip(np.asarray(correlation_matrix), -1.0, 1.0)
    numerator = np.maximum(1 + r, epsilon)
    return 0.5 * np.log(numerator / (1 - r + epsilon))

def to_tensor(X_featgraph, X_adjgraph, Y):
    """Pack per-sample NumPy arrays into a list of PyG ``Data`` graphs.

    For sample ``i``: ``X_featgraph[i]`` becomes the float node-feature tensor
    ``x``. ``X_adjgraph[i]`` (a dense adjacency) is converted to float and split
    into ``edge_index`` / ``edge_attr`` via ``dense_to_sparse``. ``Y[i]`` becomes
    a 1-element long tensor ``y``. One ``Data`` is produced per entry of ``Y``.
    """
    graphs = []
    for i, target in enumerate(Y):
        label_tensor = torch.tensor([target]).long()
        node_features = torch.from_numpy(X_featgraph[i]).float()
        dense_adj = torch.from_numpy(X_adjgraph[i]).float()
        edge_index, edge_attr = dense_to_sparse(dense_adj)
        graphs.append(
            Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=label_tensor)
        )
    return graphs


def compute_KNN_graph(matrix, k_degree=10):
    """Turn a square connectivity matrix into a symmetric k-NN weighted adjacency.

    For every row, the ``k_degree`` entries with the largest absolute value
    become edges weighted by that absolute value. ``adjacency`` then zeroes
    self-loops and symmetrises by keeping the larger of the two directed
    weights.

    NOTE(review): on the Fisher-z matrices built in this notebook the diagonal
    (r = 1, about 6.1) is normally the largest entry of its row. It therefore
    occupies one of the k slots and is then zeroed, leaving each node about
    k - 1 neighbours of its own choosing. Confirm this is intended.

    Returns:
        float32 dense matrix (``np.matrix``) of shape (n, n).
    """
    magnitudes = np.abs(matrix)
    neighbour_idx = np.argsort(-magnitudes)[:, :k_degree]
    # Descending sort of each row truncated to k: the same values, in the same
    # order, as magnitudes[row, neighbour_idx[row]].
    neighbour_weights = np.sort(magnitudes)[:, ::-1][:, :k_degree]
    return adjacency(neighbour_weights, neighbour_idx).astype(np.float32)


def adjacency(dist, idx):
    """Build a symmetric weighted adjacency matrix from per-row neighbour lists.

    Row ``i`` of ``idx`` holds the column indices of node ``i``'s neighbours.
    Row ``i`` of ``dist`` holds the matching non-negative edge weights.
    Self-loops are removed, and the graph is made undirected by keeping, for
    every node pair, the larger of the two directed weights.

    Args:
        dist: (m, k) array of non-negative weights.
        idx: (m, k) integer array of column indices in ``[0, m)``.

    Returns:
        Dense ``np.matrix`` of shape (m, m).

    Raises:
        AssertionError: if ``dist`` and ``idx`` have different shapes or any
            weight is negative.
    """
    m, k = dist.shape
    # Parenthesised on purpose: `assert m, k == idx.shape` only checked m != 0
    # and used the comparison as the (never shown) message.
    assert (m, k) == idx.shape, f"dist shape {dist.shape} != idx shape {idx.shape}"
    assert dist.min() >= 0

    # Weight matrix: entry (i, idx[i, j]) = dist[i, j].
    I = np.arange(0, m).repeat(k)
    J = idx.reshape(m * k)
    V = dist.reshape(m * k)
    W = coo_matrix((V, (I, J)), shape=(m, m))

    # No self-connections.
    W.setdiag(0)

    # Non-directed graph: element-wise maximum of W and its transpose.
    W = W.maximum(W.T)

    return W.todense()

In [None]:
#sliding window
def create_graph_sliding_window_demographics(X, D, Y, start, end, region=True,
                                             window_size=60, step=30):
    """Build one graph per sliding window of each subject's signal matrix.

    The columns of ``X[i]`` are cut into windows of ``window_size`` consecutive
    columns, with window starts ``step`` apart. A trailing partial window is
    dropped, and a subject with fewer than ``window_size`` columns yields no
    graphs. Every window becomes one sample carrying the subject's label.

    Args:
        X: sequence of 2-D arrays. Windows run along axis 1 (presumably
            ROIs x time points — confirm).
        D: per-subject demographic vectors. They are appended to every node's
            features when ``region`` is True.
        Y: per-subject labels. ``len(Y)`` subjects are processed.
        start, end: accepted for signature compatibility; not used here.
        region: True means nodes are the rows of ``X[i]``; node features are
            the row-by-row Fisher-z correlation matrix plus demographics.
            False means nodes are the window's columns; the transposed raw
            window is used as node features and demographics are not used.
        window_size: window length in columns (default 60).
        step: distance between consecutive window starts (default 30).

    Returns:
        (X_featgraph, X_adjgraph, Y_list, num_samples_per_subject). The first
        three lists hold node features, k-NN adjacencies and labels, one entry
        per window. The last list holds the number of windows per subject.
    """
    X_adjgraph = []
    X_featgraph = []
    Y_list = []
    num_samples_per_subject = []

    for i in range(len(Y)):
        bold_matrix = X[i]
        _, num_cols = bold_matrix.shape
        temp_y = Y[i]
        # Demographics repeated once per row of X[i]; only used when the rows
        # are the graph nodes (region=True).
        demog_expanded = np.repeat(np.expand_dims(D[i], axis=0), bold_matrix.shape[0], axis=0)

        num_samples = 0
        for start_idx in range(0, num_cols - window_size + 1, step):
            window = bold_matrix[:, start_idx:start_idx + window_size]
            window_data = window if region == True else np.transpose(window)

            correlation_matrix_fisher = fisher_z_transform(np.corrcoef(window_data))
            correlation_matrix_fisher = np.around(correlation_matrix_fisher, 8)  # up to 8 decimal places
            knn_graph = compute_KNN_graph(correlation_matrix_fisher)

            if region == True:
                # Concatenate only here. With region=False the correlation
                # matrix is window_size x window_size and does not have one row
                # per row of X[i], so joining it with demog_expanded would fail.
                X_featgraph.append(np.concatenate((correlation_matrix_fisher, demog_expanded), axis=1))
            else:
                X_featgraph.append(window_data)

            X_adjgraph.append(knn_graph)
            Y_list.append(temp_y)
            num_samples += 1

        num_samples_per_subject.append(num_samples)

    return X_featgraph, X_adjgraph, Y_list, num_samples_per_subject

In [None]:
def create_graph_demographics(X, D, Y, start, end, region=True):
    """Build one graph per subject from its whole signal matrix (no windowing).

    Args:
        X: sequence of 2-D arrays, one per subject.
        D: per-subject demographic vectors.
        Y: labels. ``len(Y)`` subjects are processed and ``Y`` is returned as is.
        start, end: accepted for signature compatibility; not used here.
        region: True means nodes are the rows of ``X[i]``; node features are
            the row-wise Fisher-z correlation matrix with demographics
            appended. False means ``X[i]`` is transposed first; its rows are
            the nodes and the transposed matrix itself is the node-feature
            matrix.

    Returns:
        (X_featgraph, X_adjgraph, Y)
    """
    X_featgraph, X_adjgraph = [], []

    for i in range(len(Y)):
        node_signals = X[i] if region == True else np.transpose(X[i])
        demog_rows = np.repeat(np.expand_dims(D[i], axis=0), node_signals.shape[0], axis=0)

        fisher_corr = np.around(fisher_z_transform(np.corrcoef(node_signals)), 8)
        knn_graph = compute_KNN_graph(fisher_corr)
        # Computed in both modes, but only kept when region is True.
        with_demographics = np.concatenate((fisher_corr, demog_rows), axis=1)

        X_featgraph.append(with_demographics if region == True else node_signals)
        X_adjgraph.append(knn_graph)

    return X_featgraph, X_adjgraph, Y

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import ChebConv, global_mean_pool,GATConv



class SkipConnModel(nn.Module):
    """Graph classifier built from stacked ChebConv layers with scaled residuals.

    Each layer applies ChebConv -> BatchNorm1d -> ReLU. From the second layer
    on, ``residual_scale`` times the previous layer's output is added. Dropout
    is applied between layers but not after the last one. The per-layer node
    embeddings are combined with softmax-normalised learnable weights,
    mean-pooled per graph and passed to a linear classifier.

    Args:
        num_features_R: number of input features per node.
        num_classes: number of output logits.
        k_order: Chebyshev filter size ``K`` used by every ChebConv layer.
        dropout_prob: dropout probability between layers.
        hidden_channels: width of every hidden layer (default 128).
        num_layers: number of ChebConv layers (default 3).
        residual_scale: weight of the residual term (default 0.8).
    """

    def __init__(self, num_features_R, num_classes, k_order, dropout_prob=0.5,
                 hidden_channels=128, num_layers=3, residual_scale=0.8):
        super(SkipConnModel, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.dropout_prob = dropout_prob
        self.residual_scale = residual_scale
        # Number of ChebConv layers actually built.
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # Created in the order conv, bn, conv, bn, ... so parameter
        # initialisation consumes the global RNG in a fixed order.
        in_channels = num_features_R
        for _ in range(num_layers):
            self.convs.append(ChebConv(in_channels, hidden_channels, K=k_order, normalization='sym'))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
            in_channels = hidden_channels

        self.out_fc = nn.Linear(hidden_channels, num_classes)
        # One mixing logit per layer; softmax-normalised in forward().
        self.weights = torch.nn.Parameter(torch.randn(len(self.convs)))

    def reset_parameters(self):
        """Re-initialise all learnable parameters in place."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.out_fc.reset_parameters()
        torch.nn.init.normal_(self.weights)

    def forward(self, data_R):
        """Return ``(logits, pooled_emb)`` for a PyG batch.

        ``data_R`` must provide ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``. Both outputs have one row per graph in the batch.
        """
        x, edge_index, edge_attr = data_R.x, data_R.edge_index, data_R.edge_attr
        batch = data_R.batch

        layer_outputs = []
        last = len(self.convs) - 1
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            x = conv(x, edge_index, edge_attr)
            x = bn(x)
            x = F.relu(x, inplace=True)
            if i > 0:
                # Residual from the previous layer's (pre-dropout) output.
                x = x + self.residual_scale * layer_outputs[i - 1]
            layer_outputs.append(x)
            if i < last:
                x = F.dropout(x, p=self.dropout_prob, training=self.training)

        weight = F.softmax(self.weights, dim=0)
        emb = sum(out * weight[i] for i, out in enumerate(layer_outputs))
        pooled_emb = global_mean_pool(emb, batch)
        x = self.out_fc(pooled_emb)

        return x, pooled_emb


In [None]:
def DECOV(embeddings):
    """DeCov-style decorrelation penalty for a batch of embeddings.

    Computes the covariance ``C`` between the columns of ``embeddings`` (rows
    are observations, via ``torch.cov(embeddings.T)``). Returns
    ``||C||_F^2 - ||diag(C)||^2``, i.e. the sum of squared off-diagonal
    covariances.

    Args:
        embeddings: 2-D tensor with one row per sample.

    Returns:
        0-d tensor. With fewer than two rows the covariance is undefined and
        ``torch.cov`` would yield NaN, which would poison the training loss.
        In that case a zero of the same dtype/device is returned instead.
    """
    if embeddings.shape[0] < 2:
        return embeddings.new_zeros(())
    C = torch.cov(embeddings.T)
    C_fro_norm = torch.norm(C, p='fro')
    C2_l2norm_diag = torch.norm(torch.diag(C, 0))
    return (C_fro_norm ** 2) - (C2_l2norm_diag ** 2)


In [None]:
def GCN_train(loader, alpha=1, beta=1e-8):
    """Train the notebook-level ``model`` for one epoch over ``loader``.

    Uses the globals ``model``, ``optimizer`` and ``device``. The loss for each
    batch is ``alpha * cross_entropy + beta * DECOV(pooled embeddings)``.

    Args:
        loader: iterable of PyG batches exposing ``y`` and ``num_graphs``.
        alpha: weight of the cross-entropy term (default 1).
        beta: weight of the DECOV decorrelation term (default 1e-8).

    Returns:
        (sensitivity, specificity, accuracy, f1, mean_loss). The metrics are
        computed from predictions made during the epoch, before each optimizer
        step. ``mean_loss`` is the average loss per graph.
    """
    model.train()

    pred = []
    label = []
    loss_all = 0.0
    num_graphs = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, pooled = model(data)

        # DECOV is evaluated on a CPU copy of the pooled embeddings and the
        # result moved back to `device`; autograd tracks both transfers.
        loss_decov = DECOV(pooled.to('cpu')).to(device)

        loss_ce = func.cross_entropy(output, data.y)
        loss = alpha * loss_ce + beta * loss_decov
        loss.backward()
        loss_all += data.num_graphs * loss.item()
        num_graphs += data.num_graphs
        optimizer.step()

        pred.append(func.softmax(output, dim=1).max(dim=1)[1])
        label.append(data.y)

    y_pred = torch.cat(pred, dim=0).cpu().detach().numpy()
    y_true = torch.cat(label, dim=0).cpu().detach().numpy()
    # Targets are class indices 0/1 (cross_entropy over a 2-logit output).
    # Fixing labels keeps the matrix 2x2 even when an epoch sees only one
    # class, so the 4-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    epoch_sen = tp / (tp + fn)
    epoch_spe = tn / (tn + fp)
    epoch_acc = (tn + tp) / (tn + tp + fn + fp)
    f1 = f1_score(y_true, y_pred)

    # Per-graph average: loss_all sums loss * graphs-per-batch, so divide by
    # the number of graphs, not the number of batches.
    return epoch_sen, epoch_spe, epoch_acc, f1, loss_all / max(num_graphs, 1)


def GCN_test(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` without updating it.

    Uses the globals ``model`` and ``device``. Puts the model in eval mode.

    Returns:
        (sensitivity, specificity, accuracy, f1, auc, mean_loss).
        - The predicted class is the argmax of the softmax output.
        - ``auc`` uses the softmax probability of class 1 and is NaN when the
          loader contains only one class.
        - ``mean_loss`` is the cross-entropy averaged per graph.
    """
    model.eval()

    pred = []
    scores = []
    label = []
    loss_all = 0.0
    num_graphs = 0

    # Evaluation needs no gradients; skipping graph construction saves memory
    # and time without changing any outputs.
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output, _ = model(data)

            loss = func.cross_entropy(output, data.y)
            loss_all += data.num_graphs * loss.item()
            num_graphs += data.num_graphs

            softmax_output = func.softmax(output, dim=1)
            scores.append(softmax_output[:, 1])
            pred.append(softmax_output.max(dim=1)[1])
            label.append(data.y)

    y_pred = torch.cat(pred, dim=0).cpu().numpy()
    y_scores = torch.cat(scores, dim=0).cpu().numpy()
    y_true = torch.cat(label, dim=0).cpu().numpy()

    # labels=[0, 1] keeps the confusion matrix 2x2 even for a single-class loader.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    epoch_sen = tp / (tp + fn)
    epoch_spe = tn / (tn + fp)
    epoch_acc = (tn + tp) / (tn + tp + fn + fp)
    epoch_f1 = f1_score(y_true, y_pred)
    # roc_auc_score raises ValueError when y_true contains a single class.
    if np.unique(y_true).size > 1:
        epoch_auc = roc_auc_score(y_true, y_scores)
    else:
        epoch_auc = float('nan')

    return epoch_sen, epoch_spe, epoch_acc, epoch_f1, epoch_auc, loss_all / max(num_graphs, 1)

In [None]:
# Load per-subject signal matrices, labels and demographics for the selected
# dataset/atlas.
# The context manager closes the .npz archive once every array has been read;
# NpzFile loads members lazily and otherwise keeps its file handle open.
with np.load(f'./{dataset_name}/{atlas_name}/X.npz') as X_new:
    X_loaded = [X_new[key] for key in X_new.files]
X_loaded = [normalize(matrix) for matrix in X_loaded]
print(X_loaded[0].shape)
Y_loaded = np.load(f'./{dataset_name}/{atlas_name}/Y.npy')
print(Y_loaded)
demographic_data = pd.read_csv(f'./{dataset_name}/{atlas_name}/demographics_data.csv')
demographics = demographic_data[['Age', 'Edu', 'Sex']].values

# Later cells index X_loaded, Y_loaded and demographics by the same position,
# so they must describe the same subjects in the same order.
assert len(X_loaded) == len(Y_loaded) == len(demographics), (
    f'subject count mismatch: X={len(X_loaded)}, Y={len(Y_loaded)}, '
    f'demographics={len(demographics)}'
)


In [None]:
# Partition subjects by the 'Sex' column: code 0 goes to the *_male lists,
# code 1 to the *_female lists, any other value is skipped.
sex_codes = demographic_data['Sex'].tolist()
male_rows = [row for row, code in enumerate(sex_codes) if code == 0]
female_rows = [row for row, code in enumerate(sex_codes) if code == 1]

X_male = [X_loaded[row] for row in male_rows]
Y_male = np.array([Y_loaded[row] for row in male_rows])
demographics_male = np.array([demographics[row] for row in male_rows])

X_female = [X_loaded[row] for row in female_rows]
Y_female = np.array([Y_loaded[row] for row in female_rows])
demographics_female = np.array([demographics[row] for row in female_rows])

In [None]:
# 10-fold stratified cross-validation over subjects, shuffled with `seed`.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
device = torch.device('cpu')

# Per-fold test metrics; columns are SEN, SPE, ACC, F1, AUC-ROC.
#   eval_metrics2: epoch with the best test accuracy
#   eval_metrics3: epoch with the best validation accuracy
eval_metrics2 = np.zeros((skf.n_splits, 5))
eval_metrics3 = np.zeros_like(eval_metrics2)

# The full (normalised) subject list and labels that the CV loop splits.
dataset, labels = X_loaded, Y_loaded

In [None]:
# Quick check of the first subject's normalised matrix shape
# (presumably ROIs x time points — confirm against the data).
dataset[0].shape

In [None]:
# Stratified K-fold CV over subjects. For each fold:
# - build a fresh model;
# - split the training subjects 90/10 (stratified) into train/val;
# - train on sliding-window graphs and evaluate on one full-length graph per
#   subject;
# - record test metrics under two epoch-selection rules.
# NOTE: statement order matters for reproducibility. Model construction draws
# from the global torch RNG, and every shuffled loader draws from `g`.
for n_fold, (train_val, test) in enumerate(skf.split(dataset, labels)):

    # Input width = (end - start) correlation columns + 3 demographic columns
    # (Age, Edu, Sex); 2 output classes; k_order = 3.
    model = SkipConnModel((end - start + 3), 2,3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    train_val_dataset = [dataset[i] for i in train_val]
    test_dataset = [dataset[i] for i in test]
    train_val_labels = labels[train_val]
    test_labels = labels[test]
    train_val_demographics = demographics[train_val]
    test_demographics = demographics[test]

    # Inner split on positions within train_val (not original subject indices).
    train_val_index = np.arange(len(train_val_dataset))
    train_idx, val_idx, _, _ = train_test_split(
        train_val_index,
        train_val_labels,
        test_size=0.1,
        shuffle=True,
        stratify=train_val_labels, 
        random_state=seed
    )

    train_dataset = [train_val_dataset[i] for i in train_idx]
    val_dataset = [train_val_dataset[i] for i in val_idx]
    train_labels = [train_val_labels[i] for i in train_idx]
    val_labels = [train_val_labels[i] for i in val_idx]
    train_demographics = [train_val_demographics[i] for i in train_idx]
    val_demographics = [train_val_demographics[i] for i in val_idx]

    # Training: overlapping windows (the helper's default window/step), so each
    # training subject contributes several graphs; per-subject counts are discarded.
    # Validation/test: one graph per subject built from the full series.
    X_train_featgraph, X_train_adjgraph, Y_train, _ = create_graph_sliding_window_demographics(train_dataset, train_demographics, train_labels, start, end, region=True)
    X_val_featgraph, X_val_adjgraph, Y_val = create_graph_demographics(val_dataset, val_demographics, val_labels, start, end, region=True)
    X_test_featgraph, X_test_adjgraph, Y_test = create_graph_demographics(test_dataset, test_demographics, test_labels, start, end, region=True)

    X_train_datalist = to_tensor(X_train_featgraph, X_train_adjgraph, Y_train)
    X_val_datalist = to_tensor(X_val_featgraph, X_val_adjgraph, Y_val)
    X_test_datalist = to_tensor(X_test_featgraph, X_test_adjgraph, Y_test)
    
           
    # NOTE(review): shuffling val/test does not affect their metrics. However,
    # all three loaders share generator `g`, so iterating val/test advances `g`
    # and changes the subsequent training order; altering this changes results.
    # With num_workers=0 no worker processes start, so seed_worker is never called.
    train_loader = DataLoader(X_train_datalist, batch_size=32, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g)
    val_loader = DataLoader(X_val_datalist, batch_size=32, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g)
    test_loader = DataLoader(X_test_datalist, batch_size=32, shuffle=True, num_workers=0, worker_init_fn=seed_worker, generator=g)

    # Variant "2": test metrics at the epoch with the best TEST accuracy.
    # This selects on the test set, so these numbers are optimistically biased.
    best_test_acc2 = 0
    best_test_f12 = 0
    best_test_sen2 = 0
    best_test_spe2 = 0
    best_test_auc2 = 0

    # Variant "3": test metrics at the first epoch reaching the best VALIDATION
    # accuracy (model selection without looking at test labels).
    best_val_acc3 = 0
    best_test_acc3 = 0
    best_test_f13 = 0
    best_test_sen3 = 0
    best_test_spe3 = 0
    best_test_auc3 = 0

    for epoch in range(50):
        # Training metrics and loss are discarded; only validation loss is logged.
        _, _, _, _, t_loss = GCN_train(train_loader)
        val_sen, val_spe, val_acc,val_f1, val_auc, v_loss = GCN_test(val_loader)
        test_sen, test_spe, test_acc,test_f1, test_auc, _ = GCN_test(test_loader)
        
        
        if test_acc > best_test_acc2: 
            best_test_acc2 = test_acc
            best_test_f12 = test_f1
            best_test_sen2, best_test_spe2,best_test_auc2 = test_sen, test_spe, test_auc
        
        # Strict '>' keeps the earliest epoch among ties in validation accuracy.
        if val_acc > best_val_acc3:
            best_val_acc3 = val_acc
            best_test_f13 = test_f1
            best_test_sen3, best_test_spe3, best_test_acc3, best_test_auc3 = test_sen, test_spe, test_acc, test_auc
            
        
        print('CV: {:03d}, Epoch: {:03d}, Val Loss: {:.5f}, Val ACC: {:.5f},Val AUC: {:.5f}, Test ACC: {:.5f}, Test F1: {:.5f}, TEST SPE: {:.5f}, '
                  'TEST SEN: {:.5f}, TEST AUC: {:.5f}'.format(n_fold +1, epoch + 1, v_loss, val_acc, val_auc, test_acc,test_f1,
                                        test_spe,test_sen, test_auc))
    
   

    # Columns: SEN, SPE, ACC, F1, AUC-ROC.
    eval_metrics2[n_fold, 0] = best_test_sen2
    eval_metrics2[n_fold, 1] = best_test_spe2
    eval_metrics2[n_fold, 2] = best_test_acc2
    eval_metrics2[n_fold, 3] = best_test_f12
    eval_metrics2[n_fold, 4] = best_test_auc2

    eval_metrics3[n_fold, 0] = best_test_sen3
    eval_metrics3[n_fold, 1] = best_test_spe3
    eval_metrics3[n_fold, 2] = best_test_acc3
    eval_metrics3[n_fold, 3] = best_test_f13
    eval_metrics3[n_fold, 4] = best_test_auc3
        


In [None]:
# Per-fold test metrics under the two epoch-selection rules, each followed by
# mean±std across folds (np.std, i.e. population standard deviation).
# NOTE: the "Best test_acc" table selects epochs on the test set and is
# therefore optimistic; the "Maximum val_acc" table is the unbiased one.
_METRIC_COLUMNS = ['SEN', 'SPE', 'ACC', 'F1', 'AUC-ROC']
_METRIC_LABELS = ['Sensitivity', 'Specificity', 'Accuracy', 'F1', 'AUC-ROC']


def _report_fold_metrics(title, metrics):
    """Print a per-fold metrics table and its mean±std summary; return the table.

    Args:
        title: heading printed first.
        metrics: array of shape (num_folds, 5) with columns SEN, SPE, ACC, F1,
            AUC-ROC.
    """
    print(title)
    table = pd.DataFrame(metrics)
    table.columns = list(_METRIC_COLUMNS)
    table.index = ['Fold_%02i' % (i + 1) for i in range(len(metrics))]
    print(table)
    for col, name in enumerate(_METRIC_LABELS):
        print('Average %s: %.4f±%.4f' % (name, metrics[:, col].mean(), metrics[:, col].std()))
    return table


eval_df2 = _report_fold_metrics("Corresponding to Best test_acc", eval_metrics2)
eval_df3 = _report_fold_metrics("\nCorresponding to Maximum val_acc", eval_metrics3)

## Grad-CAM-style saliency (gradient × input)

In [None]:
def calculate_grad_cam(model, data, verbose=True):
    """Gradient-times-input saliency of the class-1 probability w.r.t. node features.

    Despite the name, this is not classic Grad-CAM, which pools gradients over
    feature maps. It multiplies d(sum of class-1 probabilities)/d(x)
    element-wise by ``data.x``.

    Args:
        model: module returning ``(logits, pooled_embedding)`` for a PyG batch.
        data: PyG ``Data``/``Batch``. It is moved to the global ``device`` and
            ``data.x`` is marked ``requires_grad``. ``Data.to`` may return the
            same object, so the caller's ``data`` can be modified.
        verbose: print intermediate shapes/values (default True).

    Returns:
        scores: softmax probabilities from ``model``'s logits.
        grad_cam: gradient x input, same shape as ``data.x``.
        contribution: ``grad_cam`` averaged over dim 1, one value per row
            (node) of ``data.x``.

    NOTE: the model's train/eval mode is not changed here. Call ``model.eval()``
    first so dropout and batch-norm behave deterministically.
    """
    data = data.to(device)
    model.zero_grad()
    data.x.requires_grad = True
    # Leaf gradients accumulate across backward() calls. Clear anything left by
    # a previous call on the same data so it is not added to this result.
    data.x.grad = None

    output, _ = model(data)

    scores = F.softmax(output, dim=1)
    if verbose:
        print(scores.shape)
        print(scores[0])

    scores_mdd = scores[:, 1]
    if verbose:
        print('scores = ', scores_mdd.shape)

    # Backpropagate the sum of class-1 probabilities over the batch.
    scores_mdd.backward(torch.ones_like(scores_mdd))

    gradients = data.x.grad
    if verbose:
        print('gradients = ', gradients.shape)

    # Gradient x input attribution per node feature.
    grad_cam = gradients * data.x
    if verbose:
        print('data.x = ', data.x.shape)
        print('gradcam = ', grad_cam.shape)

    contribution = torch.mean(grad_cam, dim=1)
    if verbose:
        print('contribution = ', contribution.shape)

    return scores, grad_cam, contribution