<a href="https://colab.research.google.com/github/udayameister/sc-DGNN/blob/main/sc_DGNN_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#import packages
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data

In [None]:
# Define the SchizophreniaDataset class holding one graph sample (e.g. a PyG Data object) per item
class SchizophreniaDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of graph samples.

    The template's ``__init__(self, ...)`` was not valid Python, and its
    ``__len__`` returned ``None`` (which makes ``len()`` raise). This minimal
    version stores the samples it is given, so the class is usable as-is;
    extend ``__init__`` with real loading / preprocessing logic as needed.
    Samples are presumably ``torch_geometric.data.Data`` objects carrying
    ``x``, ``edge_index`` and ``y`` -- confirm against your loader.

    Args:
        graphs: Optional iterable of samples. ``None`` yields an empty dataset.
    """

    def __init__(self, graphs=None):
        super().__init__()
        # Materialise into a private list so len()/indexing are O(1) and later
        # mutation of a caller-owned sequence does not leak into the dataset.
        self.graphs = list(graphs) if graphs is not None else []

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` (list indexing semantics)."""
        return self.graphs[idx]


In [None]:
# Define Edge Processing module (message passing over neighbour feature differences)
class EdgeProcessingModule(MessagePassing):
    """Message-passing layer whose messages are linear maps of feature differences.

    Every edge sends ``lin(x_j - x_i)``, and the messages arriving at a node are
    averaged (``aggr='mean'``). Under PyG's default ``source_to_target`` flow,
    ``x_j`` is the source node's features and ``x_i`` the target node's.

    Before propagation, one self loop per node is appended to ``edge_index``.
    The self loop's difference is zero, so each node also receives ``lin(0)``
    (the linear layer's bias) from itself.

    Args:
        in_channels: Width of each input node feature vector.
        out_channels: Width of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Propagate ``x`` over ``edge_index`` augmented with self loops."""
        node_count = x.size(0)
        looped_edges = add_self_loops(edge_index, num_nodes=node_count)[0]
        return self.propagate(looped_edges, x=x)

    def message(self, x_i, x_j, edge_index):
        # edge_index is accepted (PyG fills it in by name) but not used here.
        delta = x_j - x_i
        return self.lin(delta)

In [None]:
# Define Node Processing module (per-node linear projection)
class NodeProcessingModule(nn.Module):
    """Apply a single ``nn.Linear`` projection to node features.

    Args:
        in_channels: Input feature width (last dimension of ``x``).
        out_channels: Output feature width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """Return ``self.lin(x)`` -- an affine map over the last dimension."""
        projected = self.lin(x)
        return projected

In [None]:
# Define Enhanced Graph Processing module
class EnhancedGraphProcessingModule(nn.Module):
    """Linear layer applied to the node features produced by the node stage.

    Despite its name it uses no graph structure: ``forward`` is one
    ``nn.Linear`` acting on the last dimension of ``x``.

    Args:
        in_channels: Input feature width.
        out_channels: Output feature width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, x):
        layer = self.lin
        return layer(x)


In [None]:
# Define Pooling Strategy module (per-graph mean pooling)
class PoolingModule(nn.Module):
    """Average node features separately for each graph in a batch.

    Forward args:
        x: Node features, assumed shape ``(num_nodes, channels)``.
        batch: Optional graph-assignment vector of shape ``(num_nodes,)``
            holding integer graph ids ``0..num_graphs-1`` (as in PyG's
            ``Batch.batch``). ``None`` means every node belongs to one graph.

    Returns:
        Tensor of shape ``(num_graphs, channels)``. Row ``g`` is the mean of
        the features of the nodes with ``batch == g``, or zeros if no node has
        that id. With ``batch=None`` the result is ``(1, channels)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, batch=None):
        # Previously this was torch.mean(x, dim=0): it ignored `batch`, merged
        # all graphs into one 1-D vector, and so broke per-graph losses and
        # argmax(dim=1) downstream. Always return one row per graph instead.
        if batch is None:
            return x.mean(dim=0, keepdim=True)

        num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
        sums = x.new_zeros((num_graphs, x.size(-1)))
        sums.index_add_(0, batch, x)
        # clamp(min=1) keeps graphs with no nodes at zero instead of NaN.
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
        return sums / counts.unsqueeze(-1)

In [None]:
# Define Readout module (final linear classifier head)
class ReadoutModule(nn.Module):
    """Map pooled features to one unnormalised score per output channel.

    Args:
        in_channels: Width of the pooled feature vector.
        out_channels: Number of output scores (classes in this notebook).
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.lin(x)
        return logits

In [None]:
# Define the Deep Graph Neural Network model
class SchizophreniaPredictionGNN(nn.Module):
    """Graph classifier: edge stage -> node stage -> graph stage -> pooling -> readout.

    Args:
        in_channels: Width of the input node features (``data.x``).
        edge_hidden_channels: Output width of the edge-processing stage.
        node_hidden_channels: Output width of the node-processing stage.
        graph_hidden_channels: Output width of the graph stage; it is also the
            readout's input width, because pooling keeps the width unchanged.
        pooling_channels: Accepted for API compatibility but currently unused
            (the pooling module has no learnable projection).
        num_classes: Number of output scores produced by the readout.
    """

    def __init__(self, in_channels, edge_hidden_channels, node_hidden_channels,
                 graph_hidden_channels, pooling_channels, num_classes):
        super().__init__()
        self.edge_module = EdgeProcessingModule(in_channels, edge_hidden_channels)
        self.node_module = NodeProcessingModule(edge_hidden_channels, node_hidden_channels)
        self.graph_module = EnhancedGraphProcessingModule(node_hidden_channels, graph_hidden_channels)
        self.pooling_module = PoolingModule()
        self.readout_module = ReadoutModule(graph_hidden_channels, num_classes)

    def forward(self, data):
        """Return the readout scores for the graph(s) in ``data``.

        ``data`` must expose ``x`` and ``edge_index``. It may also expose a
        ``batch`` vector (as a PyG ``Batch`` does); when that is absent,
        ``None`` is passed to the pooling module.
        """
        x, edge_index = data.x, data.edge_index
        # A single, un-batched graph object may not carry a `batch` attribute.
        batch = getattr(data, 'batch', None)

        # ReLU between stages: every stage is affine, so without a
        # non-linearity the whole stack collapses into one affine map and the
        # extra depth adds no expressive power. torch.relu is functional, so
        # the module's state_dict keys are unchanged.

        # Edge Processing
        edge_feats = torch.relu(self.edge_module(x, edge_index))

        # Node Processing
        node_feats = torch.relu(self.node_module(edge_feats))

        # Enhanced Graph Processing
        graph_feats = torch.relu(self.graph_module(node_feats))

        # Pooling Strategy
        pooled_feats = self.pooling_module(graph_feats, batch)

        # Readout (raw logits; CrossEntropyLoss applies log-softmax itself)
        output = self.readout_module(pooled_feats)
        return output


In [None]:
# Load the samples into PyG Data objects, for example:
# data = SchizophreniaDataset(...)

# Hold out 20% of the samples for testing (fixed seed for reproducibility):
# train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Build the model. `your_input_feature_dim` is a placeholder: define it as the
# node feature width of your graphs before running this cell.
model = SchizophreniaPredictionGNN(
    in_channels=your_input_feature_dim,
    edge_hidden_channels=64,
    node_hidden_channels=64,
    graph_hidden_channels=64,
    pooling_channels=64,
    num_classes=2,  # two classes: Schizophrenia vs. non-Schizophrenia
)


In [None]:
# Loss: cross-entropy over the readout's raw (unnormalised) class scores.
# Optimiser: Adam over all model parameters, learning rate 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Training loop: one optimisation step per epoch on the whole `data` object.
# Each step runs a forward pass, scores it against `data.y`, then clears
# stale gradients right before backpropagating and updating the weights.
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    output = model(data)
    labels = data.y
    loss = criterion(output, labels)
    optimizer.zero_grad()  # only needs to precede backward(), not the forward pass
    loss.backward()
    optimizer.step()


In [None]:
# Evaluation
# NOTE(review): this scores the same `data` object the training loop fits on,
# so the numbers measure training fit rather than generalisation. Evaluate on
# a held-out split (e.g. via the commented-out train_test_split) for real metrics.
model.eval()
with torch.no_grad():
    output = model(data)
    # sklearn works on host arrays; .cpu() also makes this work for GPU tensors.
    predicted_labels = torch.argmax(output, dim=1).cpu().numpy()
    true_labels = data.y.cpu().numpy()
    # ROC AUC needs a continuous score per sample. Hard argmax labels collapse
    # the ROC curve to a single operating point, so use the softmax probability
    # of class 1, the same positive class that f1_score(average='binary') uses.
    positive_scores = torch.softmax(output, dim=1)[:, 1].cpu().numpy()
    accuracy = accuracy_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels, average='binary')  # assumes labels are {0, 1}
    auc = roc_auc_score(true_labels, positive_scores)

print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, AUC: {auc:.4f}')