<a href="https://colab.research.google.com/github/sudarshan-360/Machine-Learning/blob/main/GNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio
!pip install torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__.split('+')[0])").html
!pip install nilearn scikit-learn


Looking in links: https://data.pyg.org/whl/torch-2.8.0.html
Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting torch-scatter
  Using cached torch_scatter-2.1.2.tar.gz (108 kB)
  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herror
[1;31merror[0m: [1msubprocess-exited-with-error[0m

[31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
[31m│[0m exit code: [1;36m1[0m
[31m╰─>[0m See above for output.

[1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.


In [None]:
# Fallback install: fetch each compiled PyG extension from the wheel index for the
# torch 2.8.0 CPU build, so pip downloads prebuilt wheels rather than compiling from source.
# NOTE(review): the index URL is hard-coded to "torch-2.8.0+cpu"; presumably it has to be
# edited whenever the runtime ships a different torch version or a CUDA build - confirm.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cpu.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.8.0+cpu.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.8.0+cpu.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.8.0+cpu.html
# torch-geometric itself is installed last, from the default package index.
!pip install torch-geometric


Looking in links: https://data.pyg.org/whl/torch-2.8.0+cpu.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.8.0%2Bcpu/torch_scatter-2.1.2%2Bpt28cpu-cp312-cp312-linux_x86_64.whl (645 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m645.6/645.6 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt28cpu
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cpu.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.8.0%2Bcpu/torch_sparse-0.6.18%2Bpt28cpu-cp312-cp312-linux_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt28cpu
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cpu.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from nilearn import datasets, input_data
from torch_geometric.utils import dense_to_sparse

# Reproducibility: seed NumPy and PyTorch before any stochastic step
# (weight initialisation, dropout, DataLoader shuffling).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# 1. Fetch OASIS VBM dataset
oasis = datasets.fetch_oasis_vbm(n_subjects=150)

# 2. Fetch atlas (AAL)
# The atlas parcellates the brain into anatomical regions; each region becomes one graph node.
aal = datasets.fetch_atlas_aal()
# The masker reduces a 3D gray-matter map to one value per atlas region.
# NOTE(review): `nilearn.input_data` is reported as deprecated in favour of `nilearn.maskers`;
# switch the import once the minimum supported nilearn version allows it.
masker = input_data.NiftiLabelsMasker(labels_img=aal.maps, standardize=False)
masker.fit()  # learns the voxel -> region mapping from the atlas image

# 3. Extract region-wise features
# masker.transform averages the voxel values inside each region (one value per region);
# flatten() turns the (1, n_regions) result into a 1D vector.
node_features_list = [
    masker.transform(img_path).flatten() for img_path in oasis.gray_matter_maps
]

print("Extracted features for", len(node_features_list), "subjects")
print("Each subject has", node_features_list[0].shape[0], "nodes")

# 4. Labels (CDR>=1 → positive)
# CDR stands for Clinical Dementia Rating. Label is 1 when CDR >= 1, otherwise 0.
# Note that class 0 therefore also contains CDR 0.5 subjects.
# NOTE(review): a missing CDR (NaN), if present, compares False and is labelled 0 as well -
# confirm this is the intended handling.
labels = np.array([1 if cdr >= 1 else 0 for cdr in oasis.ext_vars['cdr']])
# zip() below would silently truncate on a length mismatch, so check it explicitly.
assert len(labels) == len(node_features_list), "features and labels are misaligned"
print("Labels distribution:", np.bincount(labels))

# 5. Build PyG Data objects
data_list = []
for features, label in zip(node_features_list, labels):
    # Each region is a node carrying a single scalar feature.
    x = torch.tensor(features[:, None], dtype=torch.float)
    # Pairwise absolute difference between regions, mapped to a similarity in (0, 1]:
    # small difference -> weight close to 1, large difference -> weight close to 0.
    diff = features[:, None] - features[None, :]
    edge_weight_matrix = np.exp(-np.abs(diff))
    dense = torch.tensor(edge_weight_matrix, dtype=torch.float)
    # Convert the dense weighted adjacency matrix to the sparse (edge_index, edge_weight)
    # representation expected by GCNConv.
    edge_index, edge_weight = dense_to_sparse(dense)
    y = torch.tensor([label], dtype=torch.long)
    data_list.append(Data(x=x, edge_index=edge_index, edge_weight=edge_weight, y=y))

print("Graphs built:", len(data_list))

# 6. Define GCN model
class GCN(nn.Module):
    """Two-layer graph convolutional network producing 2-class logits per graph.

    Args:
        in_dim: number of input features per node.
        hidden: width of both graph-convolution layers.
        dropout: dropout probability applied after the first convolution.
    """

    def __init__(self, in_dim=1, hidden=64, dropout=0.5):
        super().__init__()
        # Graph convolutions: in_dim -> hidden, then hidden -> hidden
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        # Classification head mapping the pooled graph embedding to 2 classes
        # (CDR<1 or CDR>=1)
        self.lin = nn.Linear(hidden, 2)
        # Kept as a plain probability; applied functionally in forward()
        self.dropout = dropout

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return one row of 2-class logits for every graph in the batch.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to both convolutions.
            batch: vector assigning each node to its graph, used for pooling.
            edge_weight: optional per-edge weights passed to both convolutions.
        """
        first_layer = self.conv1(x, edge_index, edge_weight=edge_weight).relu()
        # Dropout is only active while the module is in training mode
        first_layer = F.dropout(first_layer, p=self.dropout, training=self.training)
        second_layer = self.conv2(first_layer, edge_index, edge_weight=edge_weight).relu()
        # Average node embeddings per graph to obtain one vector per graph
        graph_embedding = global_mean_pool(second_layer, batch)
        logits = self.lin(graph_embedding)
        return logits

# 7. Training loop function
def run_one_epoch(model, loader, optimizer=None, class_weights=None):
    """Run one pass over `loader`, training when an optimizer is given, else evaluating.

    Args:
        model: network called as model(x, edge_index, batch, edge_weight=...).
        loader: iterable of PyG batches; must expose `.dataset` for loss averaging.
        optimizer: when not None the model is put in train mode and updated per batch;
            when None the model is put in eval mode and no gradients are computed.
        class_weights: optional 1D tensor of per-class weights for the cross-entropy loss.

    Returns:
        (avg_loss, acc, auc): loss averaged over the graphs in `loader.dataset`,
        accuracy, and ROC-AUC. AUC is NaN when it cannot be computed
        (roc_auc_score raises ValueError, e.g. only one class present).

    Relies on the module-level `device`.
    """
    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()

    loss_weight = class_weights.to(device) if class_weights is not None else None
    criterion = nn.CrossEntropyLoss(weight=loss_weight)

    all_preds, all_probs, all_labels = [], [], []
    total_loss = 0.0

    for batch in loader:
        batch = batch.to(device)
        targets = batch.y.view(-1)

        # Only track gradients when training; evaluation previously built an autograd
        # graph for every batch that was never used.
        with torch.set_grad_enabled(is_train):
            logits = model(batch.x, batch.edge_index, batch.batch, edge_weight=batch.edge_weight)
            loss = criterion(logits, targets)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight the batch loss by its graph count so the final average is per graph
        total_loss += float(loss.item()) * batch.num_graphs

        # Metrics are computed on detached logits so they never touch the autograd graph
        detached_logits = logits.detach()
        probs = F.softmax(detached_logits, dim=1)[:, 1].cpu().numpy()
        preds = detached_logits.argmax(dim=1).cpu().numpy()
        all_probs.extend(probs.tolist())
        all_preds.extend(preds.tolist())
        all_labels.extend(targets.cpu().numpy().tolist())

    avg_loss = total_loss / len(loader.dataset)
    acc = accuracy_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # Deliberate best-effort: AUC is undefined for this split, report NaN instead of failing
        auc = float('nan')
    return avg_loss, acc, auc

# 8. Class weights

# NumPy view of the labels so the per-class counts can be computed with boolean masks
labels_np = np.array(labels)

# Number of negative (CDR<1) and positive (CDR>=1) samples
neg = (labels_np == 0).sum()
pos = (labels_np == 1).sum()

# Balanced weighting: total / (2 * class count), so the rarer class gets the larger weight.
# A class with no samples falls back to a neutral weight of 1.0 (avoids division by zero).
if neg == 0:
    w0 = 1.0
else:
    w0 = (pos + neg) / (2.0 * neg)

if pos == 0:
    w1 = 1.0
else:
    w1 = (pos + neg) / (2.0 * pos)

class_weights = torch.tensor([w0, w1], dtype=torch.float)
print(f"Class balance -> neg: {neg}, pos: {pos}, weights: {w0:.2f}, {w1:.2f}")

# 9. Cross-validation
def cross_validate_k(data_list, labels_np, k, epochs=200, patience=20, seed=42):
    """Stratified k-fold cross-validation of the GCN with early stopping on validation AUC.

    Args:
        data_list: list of PyG `Data` graphs, one per subject.
        labels_np: array of integer class labels aligned with `data_list`.
        k: number of folds.
        epochs: maximum number of training epochs per fold.
        patience: number of consecutive epochs without validation-AUC improvement
            after which training of the fold stops.
        seed: seed for the fold split and, per fold, for torch's global RNG
            (`seed + fold`). Pass None to leave both unseeded.

    Returns:
        (mean_acc, std_acc, mean_auc, std_auc) over the folds, ignoring NaN entries.

    Relies on the module-level `device` and `class_weights`.

    Note: the reported fold metrics come from the same validation fold that drives
    early stopping / checkpoint selection, so they are optimistically biased.
    """
    # `torch_geometric.data.DataLoader` is deprecated; the maintained loader lives here.
    from torch_geometric.loader import DataLoader

    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    fold_metrics = []
    in_dim = data_list[0].x.shape[1]

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.arange(len(labels_np)), labels_np), start=1):
        if seed is not None:
            # Makes weight init, dropout and batch shuffling repeatable for this fold
            torch.manual_seed(seed + fold)

        train_list = [data_list[i] for i in train_idx]
        val_list   = [data_list[i] for i in val_idx]

        train_loader = DataLoader(train_list, batch_size=8, shuffle=True)
        val_loader   = DataLoader(val_list, batch_size=8, shuffle=False)

        model = GCN(in_dim=in_dim, hidden=64, dropout=0.5).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

        # Start from a snapshot of the freshly initialised weights. Without this,
        # `best_state` is undefined when validation AUC never improves (e.g. it is NaN
        # every epoch), and in later folds the previous fold's weights would be loaded.
        best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
        best_val_auc, patience_cnt = -1.0, 0

        for epoch in range(1, epochs + 1):
            run_one_epoch(model, train_loader, optimizer, class_weights)
            _, _, val_auc = run_one_epoch(model, val_loader, optimizer=None, class_weights=class_weights)

            if val_auc > best_val_auc:
                best_val_auc = val_auc
                patience_cnt = 0
                best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
            else:
                # A NaN AUC also lands here because `nan > x` is False
                patience_cnt += 1
            if patience_cnt >= patience:
                break

        # Restore the best checkpoint of this fold before the final validation pass
        model.load_state_dict({name: tensor.to(device) for name, tensor in best_state.items()})
        val_loss, val_acc, val_auc = run_one_epoch(model, val_loader, optimizer=None, class_weights=class_weights)
        fold_metrics.append((val_acc, val_auc))

    accs = [m[0] for m in fold_metrics]
    aucs = [m[1] for m in fold_metrics]
    return np.nanmean(accs), np.nanstd(accs), np.nanmean(aucs), np.nanstd(aucs)

# 10. Evaluate with different k
# Run the full cross-validation once per candidate fold count and report
# mean ± standard deviation of accuracy and AUC across folds.
candidate_ks = [2, 3, 5]
for k in candidate_ks:
    cv_summary = cross_validate_k(data_list, labels_np, k)
    mean_acc, std_acc, mean_auc, std_auc = cv_summary
    print(f"k={k} -> ACC {mean_acc:.3f} ± {std_acc:.3f}, AUC {mean_auc:.3f} ± {std_auc:.3f}")


  from nilearn import datasets, input_data


Device: cpu


  aal = datasets.fetch_atlas_aal()


Extracted features for 150 subjects
Each subject has 116 nodes
Labels distribution: [140  10]
Graphs built: 150
Class balance -> neg: 140, pos: 10, weights: 0.54, 7.50




k=2 -> ACC 0.500 ± 0.433, AUC 0.667 ± 0.033




k=3 -> ACC 0.640 ± 0.410, AUC 0.514 ± 0.172




k=5 -> ACC 0.760 ± 0.347, AUC 0.525 ± 0.222
