In [30]:
import itertools
import numpy as np
import torch
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as geom_data_loader
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import os
import warnings
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
import torch_geometric.nn as geom_nn

CHECKPOINT_PATH = "./saved_models"
warnings.filterwarnings('ignore')
pl.seed_everything(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps")

Global seed set to 42


In [31]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [32]:
gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv
}

In [33]:
class GNNModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, num_layers=2, layer_name="GCN", dp_rate=0.1, **kwargs):
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]

        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [
                gnn_layer(in_channels=in_channels, out_channels=out_channels, **kwargs),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [gnn_layer(in_channels=in_channels, out_channels=c_out, **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        for l in self.layers:
            if isinstance(l, geom_nn.MessagePassing):
                x = l(x, edge_index)
            else:
                x = l(x)
        return x

In [34]:
class GraphGNNModel(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):
        super().__init__()
        self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden, **kwargs)
        self.head = nn.Sequential(
            nn.Dropout(dp_rate_linear),
            nn.Linear(c_hidden, c_out)
        )

    def forward(self, x, edge_index, batch_idx):
        x = self.GNN(x, edge_index)
        x = geom_nn.global_mean_pool(x, batch_idx) # Average pooling
        x = self.head(x)
        return x

In [35]:
class NodeLevelGNN(pl.LightningModule):
    def __init__(self, c_in, c_hidden, c_out, **model_kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.model = GraphGNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_out, **model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if c_out == 1 else nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x, edge_index, batch_idx)
        x = x.squeeze(dim=-1)

        if self.hparams.c_out == 1:
            preds = (x > 0).float()
            data.y = data.y.float()
        else:
            preds = x.argmax(dim=-1)
        
        loss = self.loss_module(x, data.y)
        acc = (preds == data.y).sum().float() / preds.shape[0]
        return loss, acc

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode="train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="val")
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)

In [36]:
def get_loader(df):
    sorted_df = df.sort_values(by="Subject_ID")
    node_features = sorted_df.drop(columns=['Subject_ID', 'Diagnosis', 'MMSE'])
    pd.set_option('mode.chained_assignment', None)
    x = node_features.to_numpy()
    labels = sorted_df[["Diagnosis"]]
    y = labels.to_numpy().squeeze()
    df["Subject_ID"] = df.reset_index().index
    diagnosis = df["MMSE"].unique()
    all_edges = np.array([], dtype=np.int32).reshape((0, 2))
    for diag in diagnosis:
        class_df = df[df["MMSE"] == diag]
        patients = class_df["Subject_ID"].values
        permutations = list(itertools.combinations(patients, 2))
        edges_source = [e[0] for e in permutations]
        edges_target = [e[1] for e in permutations]
        diag_edges = np.column_stack([edges_source, edges_target])
        all_edges = np.vstack([all_edges, diag_edges])
    edge_index = torch.tensor(all_edges.transpose(), dtype=torch.long)
    x_tensor = torch.from_numpy(x).float()
    y_tensor = torch.tensor(y, dtype=torch.long)
    batch = torch.zeros(x_tensor.shape[0], dtype=torch.long)  # Add batch index
    data = Data(x=x_tensor, edge_index=edge_index, y=y_tensor, batch=batch)
    data.num_node_features = x_tensor.shape[1]
    
    # Create a list of Data objects for each sample to enable proper batching
    data_list = [Data(x=torch.unsqueeze(x_tensor[i], 0), edge_index=edge_index, y=y_tensor[i].unsqueeze(0), batch=batch[i].unsqueeze(0)) for i in range(len(x_tensor))]
    loader = geom_data_loader(data_list, batch_size=32, shuffle=True)
    return loader, data.num_node_features, len(np.unique(y_tensor))

In [37]:
def train_node_classifier(df, **model_kwargs):
    pl.seed_everything(42)
    node_data_loader, num_node_features, num_classes = get_loader(df)
    root_dir = os.path.join(CHECKPOINT_PATH, "NodeLevelGNN")
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")],
                         accelerator="gpu" if torch.cuda.is_available() else "mps",
                         devices=1,
                         max_epochs=500,
                         enable_progress_bar=False)
    trainer.logger._default_hp_metric = None
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "NodeLevelGNN.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = NodeLevelGNN.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything()
        model = NodeLevelGNN(c_in=num_node_features, c_hidden=model_kwargs.pop('c_hidden'), c_out=num_classes, **model_kwargs)
        trainer.fit(model, node_data_loader, node_data_loader)
        model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    test_result = trainer.test(model, node_data_loader, verbose=False)
    batch = next(iter(node_data_loader))
    batch = batch.to(model.device)
    _, train_acc = model.forward(batch, mode="train")
    _, val_acc = model.forward(batch, mode="val")
    result = {"train": train_acc.item(),
              "val": val_acc.item(),
              "test": test_result[0]['test_acc']}
    return model, result

In [38]:
df = pd.read_csv("./data/combined_data_subset_selected.csv")
df['MMSE'] = pd.cut(df['MMSE'], bins=[0, 0.25, 0.5, 0.75, 1], labels=['very low', 'low', 'medium', 'high'])

In [39]:
df

Unnamed: 0,Subject_ID,Age,Sex,MMSE,KAAgkL3gvR34_RUV7I,9kp4KA69O4w7qes13I,fj3VaHdbTqbufvj1KE,Bioe2e40i6vCeFKBFQ,xXp1ae5nin_1Xtf6YI,WddX7Uru50c_dQOwqo,...,TquC6ANSpC.booLjl4,oop4pqiNEK0ojYhSXg,WF_ziugieiGlHiZxeQ,oLUKIIIpcl4lAiIsak,0pKeK4gjT6L8oPop5I,NuSUe8JTu9QpP8mqKo,KrSBbx7g4uCNeihXqU,cs6LoBn7Cn0Uo1F3js,Diagnosis,intercept
0,DCR00025,78,0,medium,0.562187,0.331808,0.442700,0.166206,0.445339,0.580395,...,0.272891,0.914376,0.500827,0.463428,0.269195,0.706103,0.224217,0.412062,0,1
1,DCR00028,76,0,medium,0.364965,0.296833,0.293662,0.261484,0.179618,0.000000,...,0.532119,0.823368,0.207573,0.635436,0.250871,0.783603,0.085135,0.530617,0,1
2,DCR00031,76,0,medium,0.146130,0.125352,0.178385,0.353928,0.430409,0.468388,...,0.426785,0.873960,0.439801,0.789685,0.168997,0.908566,0.255001,0.556696,0,1
3,DCR00032,76,1,medium,0.309759,0.320664,0.313261,0.440292,0.525652,0.174752,...,0.408014,0.905325,0.165924,0.444634,0.214378,0.793530,0.296840,0.400357,0,1
4,DCR00037,76,1,medium,0.151492,0.345653,0.153025,0.376373,0.383213,0.339325,...,0.405914,0.822862,0.307349,0.778732,0.294955,0.702799,0.327312,0.447230,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
309,THSMCI033,77,1,high,0.219104,0.065413,0.150580,0.333030,0.421250,0.301582,...,0.307607,0.880877,0.285851,0.555279,0.364249,0.787017,0.420947,0.509086,2,1
310,THSMCI060,77,1,medium,0.587870,0.654465,0.626304,0.801285,0.727783,0.582114,...,0.763878,0.357892,0.702414,0.102165,0.769098,0.122849,0.624014,0.685961,2,1
311,THSMCI061,77,1,high,0.524451,0.739365,0.682732,0.820718,0.880425,0.711450,...,0.679939,0.489043,0.884783,0.346317,0.607627,0.376494,0.955124,0.867451,2,1
312,THSMCI064,77,0,high,0.468280,0.526385,0.560796,0.912711,0.933397,0.865602,...,0.693156,0.288844,0.829169,0.425817,0.530701,0.470798,0.970048,0.672451,2,1


In [None]:
model_kwargs = {
    "c_hidden": 64,
    "num_layers": 4,
    "layer_name": "GCN",
    "dp_rate": 0.1
}

model, result = train_node_classifier(df, **model_kwargs)

In [43]:
print(result)

{'train': 94.3057, 'val': 90.4321, 'test': 0.8913456}
