In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, confusion_matrix
from torch.utils.data import  TensorDataset, random_split
from torch_geometric.loader import DataLoader 
import numpy as np
from tqdm import tqdm
import os
import argparse
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import scipy
from torch_geometric.data import Data


In [2]:
def read_adni_data_normalized():
    argss = {
        'dataset_dir': '../../data/ADNI/'
    }
    args = argparse.Namespace(**argss)
    fMRI_path = args.dataset_dir + "fmri_signal.mat"
    ICV_path = args.dataset_dir + "ICV.mat"
    AGE_path = args.dataset_dir + "AGE.mat"
    DX_path = args.dataset_dir + "DX.mat"
    gender_path = args.dataset_dir + "gender.mat"
    fMRI_data_path = args.dataset_dir + "fMRIdata_ADNI2_ADNI3.csv"
    # participants_path = r'./data/ADNI/participants.tsv'

    # read fMRI_path
    fmri_data = scipy.io.loadmat(fMRI_path)['fmri_signal']
    fMRI_data = [fmri_data[i][0] for i in range(len(fmri_data))]

    # read ICV_path
    icv_data = scipy.io.loadmat(ICV_path)['ICV']
    ICV_data = pd.DataFrame([icv_data[i][0] for i in range(len(icv_data))])

    # read AGE_path
    age_data = scipy.io.loadmat(AGE_path)['AGE']
    AGE_data = pd.DataFrame([age_data[i][0] for i in range(len(age_data))])

    # read gender_path
    gender_data = scipy.io.loadmat(gender_path)['gender']
    gender_data = pd.DataFrame([gender_data[i][0] for i in range(len(gender_data))])

    # read DX_path
    dx_data = scipy.io.loadmat(DX_path)['DX']
    DX_data = pd.DataFrame([dx_data[i][0] for i in range(len(dx_data))])

    # for all above variable, add a df.insert(0, 'Image_ID', range(1, 1 + len(fMRI_data))) to add Image_ID column
    for df in [ICV_data, AGE_data, gender_data, DX_data]:
        df.insert(0, 'Image_ID', range(1, 1 + len(fMRI_data)))

    # give their column names, EstimatedTotalIntraCranialVol, Age, Gender, Diagnosis
    ICV_data.columns = ['Image_ID', 'EstimatedTotalIntraCranialVol']
    AGE_data.columns = ['Image_ID', 'Age']
    gender_data.columns = ['Image_ID', 'Gender']
    DX_data.columns = ['Image_ID', 'Diagnosis']
    Image_ID = ICV_data['Image_ID']

    data_dict = {
        'fMRI_data': fMRI_data,
        'ICV_data': ICV_data,
        'AGE_data': AGE_data,
        'gender_data': gender_data,
        'DX_data': DX_data
    }
    return data_dict

In [3]:
adni_data = read_adni_data_normalized()

In [4]:
import torch
import numpy as np
import pandas as pd

labels_file = '../../data/ADNI/y.csv'
labels_df = pd.read_csv(labels_file)
fMRI_data = adni_data['fMRI_data']
ICV_data = adni_data['ICV_data']
AGE_data = adni_data['AGE_data']
gender_data = adni_data['gender_data']
DX_data = adni_data['DX_data']

# only keep healthy control and AD. namely 2 and 0
labels_df = labels_df[labels_df['Diagnosis'].isin([2, 0])].reset_index(drop=True)
# change all 2 to 1
labels_df['Diagnosis'] = labels_df['Diagnosis'].replace({2: 1})

X = []
ys = []

# Traverse the labels_df by index i
for i in range(len(labels_df)):
    IID = labels_df['IID'][i]
    y = labels_df['Diagnosis'][i]
    
    # Get the fMRI data for the subject
    subject_data = fMRI_data[IID]
    
    # Replace zeros with ones
    subject_data[subject_data == 0] = 1
    
    # Z-score normalization for each column of each subject
    subject_data = (subject_data - np.mean(subject_data, axis=0)) / np.std(subject_data, axis=0)

    # each get first 140 time points
    subject_data = subject_data[:50, :]
    
    # Convert the subject data to a tensor and append it to the list X
    X.append(torch.tensor(subject_data, dtype=torch.float))
    
    # Append the label to the ys list
    ys.append(y)

# Stack all the tensors in X into a single tensor
X = torch.stack(X)  # Shape will be (num_subjects, num_features, ...)

# Convert ys to a tensor
ys = torch.tensor(ys, dtype=torch.long)

# X and ys are now tensors
print("X shape:", X.shape)
print("ys shape:", ys.shape)


X shape: torch.Size([579, 50, 100])
ys shape: torch.Size([579])


In [5]:

class TimeSeriesEncoder(nn.Module):
    def __init__(self, num_rois, time_steps, embedding_size):
        super(TimeSeriesEncoder, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=num_rois, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=embedding_size, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Permute to (batch_size, time_steps, num_rois)
        # print('x 1 step', x.size())  # Debugging output
        x = F.relu(self.conv1(x))  # Convolution expects 'in_channels = num_rois'
        # print('x 2 step', x.size())  # Debugging output
        x = self.pool(x)
        # print('x 3 step', x.size())  # Debugging output
        x = F.relu(self.conv2(x))  # Convolution reduces channels to 'embedding_size'
        # print('x 4 step', x.size())  # Debugging output
        x = self.pool(x)
        # print('x 5 step', x.size())  # Debugging output
        x = x.permute(0, 2, 1)  # Permute back to (batch_size, reduced_num_rois, embedding_size)
        # print('x 6 step', x.size())  # Debugging output
        return x


# Graph Generator
class GraphGenerator(nn.Module):
    def __init__(self, embedding_size, num_rois):
        super(GraphGenerator, self).__init__()
        self.fc = nn.Linear(embedding_size, num_rois)

    def forward(self, x):
        hA = F.softmax(self.fc(x), dim=-1)
        A = torch.bmm(hA, hA.transpose(1, 2))
        return A

# Graph Predictor using GCN
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_channels, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return x

# End-to-end model
class FBNetGen(nn.Module):
    def __init__(self, num_rois, time_steps, embedding_size, hidden_channels, num_classes):
        super(FBNetGen, self).__init__()
        self.encoder = TimeSeriesEncoder(num_rois, time_steps, embedding_size)
        self.graph_generator = GraphGenerator(embedding_size, num_rois)
        self.gcn = GCN(embedding_size, hidden_channels, num_classes)

    def forward(self, x):
        # print('x 1', x.shape)
        x = self.encoder(x)
        # print('x 2', x.shape)
        A = self.graph_generator(x)
        # print('A', A.shape)
        node_features = x.reshape(-1, x.size(2))  # Flatten to (num_samples * num_rois, embedding_size)
        num_nodes = node_features.size(0)
        edge_index = torch.randint(0, num_nodes, (2, 500)).to(x.device)  # Ensure edge_index matches the number of nodes

        # Adjust the batch tensor to match the flattened node_features
        batch = torch.repeat_interleave(torch.arange(x.size(0)), x.size(1)).to(x.device)

        output = self.gcn(node_features, edge_index, batch)
        return output


In [6]:
# from utils import *
# X,y = generate_complex_patterned_data(1000, 10, 120)
# dataset = TensorDataset(X, y)
# batch_size = 32
# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
dataset = TensorDataset(X, ys)

In [10]:

# Hyperparameters
# num_samples = 1000
num_rois = 100
time_steps = 50
embedding_size = 16
hidden_channels = 64
num_classes = 2
batch_size = 32
epochs = 200
learning_rate = 0.0001

# Generate patterned data

train_data, test_data = train_test_split(dataset, test_size=0.2)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


# Model, Loss, Optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FBNetGen(num_rois, time_steps, embedding_size, hidden_channels, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop with tqdm
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    # for X_batch, y_batch in train_loader:
    # with tqdm(train_loader, unit="batch") as tepoch:
    #     for X_batch, y_batch in tepoch:
            # tepoch.set_description(f"Epoch {epoch+1}/{epochs}")
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # tepoch.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)
    # print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}') # with flush
    if (epoch+1) % 20 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

# Evaluation with tqdm
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    with tqdm(test_loader, unit="batch") as ttest:
        for X_batch, y_batch in ttest:
            X_batch = X_batch.to(device)
            outputs = model(X_batch)
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu())
            all_labels.append(y_batch)

all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

# Metrics Calculation
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
auroc = roc_auc_score(all_labels, all_preds)
tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)

print(f'Accuracy: {accuracy:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'AUROC: {auroc:.4f}')
print(f'Sensitivity: {sensitivity:.4f}')
print(f'Specificity: {specificity:.4f}')


Epoch 20/200, Loss: 0.5413
Epoch 40/200, Loss: 0.5355
Epoch 60/200, Loss: 0.5244
Epoch 80/200, Loss: 0.5283
Epoch 100/200, Loss: 0.5047
Epoch 120/200, Loss: 0.4390
Epoch 140/200, Loss: 0.3906
Epoch 160/200, Loss: 0.3188
Epoch 180/200, Loss: 0.2883
Epoch 200/200, Loss: 0.3005


100%|██████████| 4/4 [00:00<00:00, 71.13batch/s]

Accuracy: 0.6207
F1 Score: 0.7556
AUROC: 0.4702
Sensitivity: 0.8193
Specificity: 0.1212





In [None]:

# Training loop with tqdm
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    with tqdm(train_loader, unit="batch") as tepoch:
        # for X_batch, y_batch in tepoch:
        for i in tepoch:
            print(i)
            X_batch, y_batch, batch, ptr = i
            # tepoch.set_description(f"Epoch {epoch+1}/{epochs}")
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            tepoch.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

# Evaluation with tqdm
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    with tqdm(test_loader, unit="batch") as ttest:
        for X_batch, y_batch in ttest:
            X_batch = X_batch.to(device)
            outputs = model(X_batch)
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu())
            all_labels.append(y_batch)

all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)

# Metrics Calculation
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds)
auroc = roc_auc_score(all_labels, all_preds)
tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)

print(f'Accuracy: {accuracy:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'AUROC: {auroc:.4f}')
print(f'Sensitivity: {sensitivity:.4f}')
print(f'Specificity: {specificity:.4f}')


  0%|          | 0/15 [00:00<?, ?batch/s]


[tensor([[[-1.0790, -1.6929, -1.1463,  ...,  0.4162, -0.6378,  0.1779],
         [-0.1744, -0.0126, -0.4268,  ...,  1.0236, -0.1680,  0.6776],
         [ 0.5021,  0.7596,  0.4689,  ...,  0.7910,  0.5980,  0.7352],
         ...,
         [ 0.3452, -0.1843, -1.2420,  ...,  0.5222, -1.0048,  1.0349],
         [-0.2712,  0.1095, -0.3691,  ...,  0.8000,  0.1274,  1.2421],
         [-0.6005,  0.2217,  0.5570,  ...,  1.0520,  2.0189,  1.3398]],

        [[ 0.6050, -0.0868,  0.6044,  ...,  0.2025, -0.1426,  0.3498],
         [-2.0480, -0.4314, -1.0616,  ...,  0.1661, -0.0222,  0.1519],
         [-2.0345, -0.5252, -1.0195,  ...,  0.0949, -0.2922,  0.0774],
         ...,
         [ 0.4958, -0.4533, -0.8365,  ..., -1.0103, -0.3615, -0.8292],
         [ 1.5971,  0.2116, -0.5327,  ..., -1.2389,  0.6867, -0.5208],
         [ 0.5321, -0.0843, -0.6049,  ..., -1.4190,  0.9359, -0.6442]],

        [[ 1.2948, -2.7925, -1.1588,  ..., -2.0441, -0.6034, -2.4449],
         [-1.5527,  1.7589,  0.5280,  ...,  

ValueError: not enough values to unpack (expected 4, got 2)