In [2]:
import numpy as np
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import nibabel as nib
import os
import re
import xml.etree.ElementTree as ET
from tqdm import tqdm
from scipy.stats import entropy
from skimage.exposure import histogram
import cv2

In [3]:
# Read the ADNI1 metadata CSV; a later cell looks images up by its 'Image Data ID' column.
df = pd.read_csv("/kaggle/input/adni-processed/ADNI1_Complete_1Yr_1.5T_6_20_2025.csv")

In [4]:
import os

def get_all_i_folders_kaggle(root_path):
    """
    Walk ``root_path`` and collect the names of all directories named
    ``I<digits>`` (for example ``I12345``).

    Only the bare directory names are returned (not full paths), in the order
    ``os.walk`` visits them. A directory called just ``I`` is not matched,
    because the part after the ``I`` must consist of digits.
    """
    return [
        name
        for _, subdirs, _ in os.walk(root_path)
        for name in subdirs
        if name.startswith('I') and name[1:].isdigit()
    ]

# Count the I<digits> image folders under the processed ADNI1 tree
kaggle_root = "/kaggle/input/adni-processed/ADNI1_Processed/ADNI1_Processed"
i_folders = get_all_i_folders_kaggle(root_path=kaggle_root)

print(f"Found {len(i_folders)} I folders:")

Found 459 I folders:


In [5]:
# Build (label, nifti_path) pairs: each NIfTI file is matched to an XML file in
# folder_2 via the subject/image numbers in its name, the XML supplies a UID,
# and the UID is looked up in the metadata CSV to obtain the label.
df = pd.read_csv("/kaggle/input/adni-processed/ADNI1_Complete_1Yr_1.5T_6_20_2025.csv")
folder_path = "/kaggle/input/adni-processed/ADNI1_Processed/ADNI1_Processed"
paths = []
folder_2 = "/kaggle/input/metadata"

# Normalise the ID column once. Doing this inside the loop repeated the same
# string conversion over the whole table for every file.
image_ids = df['Image Data ID'].astype(str).str.strip()

for root_dir, dirs, files in tqdm(os.walk(folder_path), desc="Scanning files"):
    for file in files:
        if not file.endswith((".nii", ".nii.gz")):
            continue
        final_path = os.path.join(root_dir, file)

        # File names carry the subject and image numbers as ..._S<digits>_I<digits>...
        match = re.search(r'_S(\d+)_I(\d+)', file)
        if not match:
            print(f"[!] Failed to extract subject/image ID from: {file}")
            continue

        xml_path = os.path.join(folder_2, f"S{match.group(1)}I{match.group(2)}.xml")
        if not os.path.exists(xml_path):
            print(f"[!] XML file missing: {xml_path}")
            continue

        # Only XML reading is guarded, so the message below is accurate and
        # errors in the DataFrame lookup are not mislabelled as parse failures.
        try:
            xml_root = ET.parse(xml_path).getroot()
            # NOTE(review): the UID is read positionally from the 4th child of
            # the root element — confirm this holds for every metadata file.
            image_uid = xml_root[3].attrib.get('uid', None)
        except Exception as e:
            print(f"[!] Failed to parse XML: {xml_path} — {e}")
            continue

        if not image_uid:
            print(f"[!] UID not found in XML: {xml_path}")
            continue

        row = df[image_ids == str(image_uid).strip()]
        if row.empty:
            print(f"[!] ID {image_uid} not found in DataFrame")
            continue

        # NOTE(review): the label is taken positionally from the third CSV
        # column of the first matching row — confirm that column is the group.
        paths.append((row.iloc[0, 2], final_path))

# Keep only the two classes used for binary classification.
filtered_paths = [entry for entry in paths if entry[0] in ('AD', 'CN')]

Scanning files: 815it [00:04, 190.99it/s]


In [6]:
# Per-subject slice volumes (X) and their integer labels (y), filled by a later cell
X, y = [], []

In [7]:
def center_crop(image, crop_size=128):
    """
    Return the central ``crop_size`` x ``crop_size`` window of a 2-D array.

    Returns ``None`` when either dimension is smaller than ``crop_size``.
    Offsets are floored, so with an odd margin the extra row/column is left
    on the bottom/right side. The result is obtained by slicing ``image``.
    """
    rows, cols = image.shape
    if min(rows, cols) < crop_size:
        return None
    row0, col0 = (rows - crop_size) // 2, (cols - crop_size) // 2
    return image[row0:row0 + crop_size, col0:col0 + crop_size]

def image_entropy(img):
    """Return the Shannon entropy, in bits, of the intensity histogram of ``img``."""
    counts, _bin_centers = histogram(img)
    probabilities = counts / counts.sum()
    return entropy(probabilities, base=2)

Only axial slices (taken along the third array axis) are used; for each scan, the N highest-entropy slices are kept.

In [8]:
N = 25  # number of entropy-based slices
crop_size = 256

# Start from empty containers so that re-running this cell rebuilds the
# dataset instead of appending a second copy of every subject.
X = []
y = []

for label, path in tqdm(filtered_paths):
    scan = nib.load(path)
    data = scan.get_fdata()
    # AD -> 0; everything else -> 1 (filtered_paths only contains AD and CN)
    label = 0 if label == 'AD' else 1

    slice_info = []
    for i in range(data.shape[2]):  # only axial
        # center_crop returns None for slices smaller than crop_size; skip those
        cropped = center_crop(data[:, :, i], crop_size=crop_size)
        if cropped is None:
            continue
        slice_info.append((image_entropy(cropped), cropped))

    # If not enough valid slices, skip this subject
    if len(slice_info) < N:
        print(f"[!] Skipped subject: only {len(slice_info)} slices")
        continue

    # Keep the N highest-entropy slices. They stay in descending-entropy
    # order, not in anatomical order.
    slice_info.sort(reverse=True, key=lambda x: x[0])
    top_slices = slice_info[:N]

    # Build per-subject volume of shape (N, crop_size, crop_size, 1)
    subject_volume = np.stack([s[1][..., np.newaxis] for s in top_slices], axis=0)

    X.append(subject_volume)
    y.append(label)

100%|██████████| 234/234 [03:47<00:00,  1.03it/s]


In [9]:
# Number of subject volumes collected in X
len(X)

234

In [10]:
# Shape of the first subject's stacked slice volume
X[0].shape

(25, 256, 256, 1)

In [11]:
import numpy as np
from sklearn.model_selection import train_test_split

# Flatten the per-subject volumes into one list of 2-D slices, repeating the
# subject label for every slice taken from that subject.
X_flat = []
y_flat = []

for volume, label in zip(X, y):
    # Iterate over however many slices the volume actually holds rather than
    # a hard-coded 25, so this stays correct if N is changed upstream.
    for slice_2d in volume[:, :, :, 0]:
        X_flat.append(slice_2d)
        y_flat.append(label)

# Convert to NumPy arrays
X_flat = np.array(X_flat)
y_flat = np.array(y_flat)

assert len(X_flat) == len(y_flat), "every slice must have exactly one label"


In [12]:
# Shape of the flattened slice array
X_flat.shape

(5850, 256, 256)

Build a graph for each slice to feed the SGCNN branch.

In [15]:
# --- CONFIG ---
# Directory (relative to the working directory) that receives one .pt graph file per slice
GRAPH_DIR = "brain_graphs"

# --- 1. GRAPH GENERATION FROM FLAT SLICE LIST ---
import os
import numpy as np
import torch
from tqdm import tqdm
from skimage.measure import regionprops
from skimage.exposure import histogram
from scipy.stats import entropy
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj
from sklearn.cluster import KMeans


# NOTE(review): CROP_SIZE is not referenced by the code visible in this section,
# and the slices in X_flat were cropped with crop_size = 256 — confirm whether
# this constant is still needed.
CROP_SIZE = 128
# Create the output directory up front; exist_ok=True makes re-running the cell safe.
os.makedirs(GRAPH_DIR, exist_ok=True)

def segment_slice_kmeans(slice_img, n_regions=5, random_state=0):
    """
    Partition a slice into intensity clusters with k-means.

    Every pixel is treated as a one-dimensional sample (its intensity), so the
    clusters group pixels by intensity only, not by position.

    Parameters
    ----------
    slice_img : array
        Image to segment.
    n_regions : int
        Number of k-means clusters.
    random_state : int or None
        Seed passed to ``KMeans``. The fixed default makes the segmentation,
        and therefore the saved graphs, reproducible between runs; pass
        ``None`` for the previous unseeded behaviour.

    Returns
    -------
    Array with the same shape as ``slice_img`` holding each pixel's cluster
    index. Cluster indices are arbitrary and carry no intensity ordering.
    """
    flat = slice_img.reshape(-1, 1)
    km = KMeans(n_clusters=n_regions, n_init='auto', random_state=random_state).fit(flat)
    return km.labels_.reshape(slice_img.shape)

def slice_to_graph(slice_img, mask, label):
    """
    Turn a segmented slice into a graph with one node per segment label.

    Parameters
    ----------
    slice_img : 2-D array
        Intensity image, same shape as ``mask``.
    mask : 2-D integer array
        Segment label of each pixel, using labels ``0 .. mask.max()``.
    label : int
        Class label stored on the graph as ``y``.

    Returns
    -------
    torch_geometric.data.Data
        ``x`` has shape ``(mask.max() + 1, 6)``. Row ``k`` holds mean
        intensity, area, eccentricity, solidity, perimeter and extent of
        label ``k`` (all zeros if the label does not occur in ``mask``).
        ``edge_index`` lists both directions of every pair of labels that
        share a horizontal or vertical pixel border.
    """
    mask = np.asarray(mask)
    num_nodes = int(mask.max()) + 1

    # regionprops treats label 0 as background and skips it. Shifting the
    # labels by one keeps segment 0 as a node; writing each row at its own
    # label keeps x aligned with the node ids used in edge_index.
    features = np.zeros((num_nodes, 6), dtype=np.float32)
    for region in regionprops(mask + 1, intensity_image=slice_img):
        features[region.label - 1] = [
            region.mean_intensity,
            region.area,
            region.eccentricity,
            region.solidity,
            region.perimeter,
            region.extent,
        ]
    x = torch.tensor(features, dtype=torch.float)

    # Compare each pixel with its neighbour using shifted slices. Unlike
    # np.roll this does not wrap around the image border, and symmetrising
    # afterwards covers both neighbour directions for every label pair.
    adj = np.zeros((num_nodes, num_nodes), dtype=bool)
    neighbour_pairs = (
        (mask[1:, :], mask[:-1, :]),  # vertical neighbours
        (mask[:, 1:], mask[:, :-1]),  # horizontal neighbours
    )
    for a, b in neighbour_pairs:
        boundary = a != b
        adj[a[boundary], b[boundary]] = True
    adj = adj | adj.T

    edge_index = torch.tensor(np.array(np.nonzero(adj)), dtype=torch.long)
    return Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))

# Segment every slice, convert it to a graph and save it as slice_<index>.pt in GRAPH_DIR
slice_label_pairs = zip(X_flat, y_flat)
for idx, (slice_img, label) in enumerate(tqdm(slice_label_pairs, total=len(X_flat))):
    # squeeze drops a singleton channel axis if one is present
    img = slice_img.squeeze()
    graph_path = os.path.join(GRAPH_DIR, f"slice_{idx}.pt")
    mask = segment_slice_kmeans(img)
    g = slice_to_graph(img, mask, label)
    torch.save(g, graph_path)


100%|██████████| 5850/5850 [05:40<00:00, 17.19it/s]


In [None]:

# Same value as in the graph-generation cell, which writes slice_<idx>.pt files into this directory
GRAPH_DIR = "brain_graphs"


import os
import numpy as np
import torch
from tqdm import tqdm
from skimage.measure import regionprops
from skimage.exposure import histogram
from scipy.stats import entropy
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import to_dense_adj
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch.utils.data import DataLoader as ImageDataLoader, Dataset as TorchDataset

# --- Hybrid Model Definition ---
class CNNBranch(nn.Module):
    """
    Image branch: three conv -> max-pool -> ReLU stages, an adaptive average
    pool to 16x16, and a linear projection to a 128-dimensional feature vector.

    Expects single-channel input of shape (batch, 1, H, W).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.adapt_pool = nn.AdaptiveAvgPool2d((16, 16))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64 * 16 * 16, 128)

    def forward(self, x):
        """Return a (batch, 128) feature tensor for the image batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(self.pool(conv(x)))
        # The adaptive pool fixes the spatial size so the linear layer's input width is constant
        return self.fc(self.flatten(self.adapt_pool(x)))

class SGCNNBranch(nn.Module):
    """
    Graph branch: two GCN layers with ReLU, mean pooling over the nodes of each
    graph, and a linear projection to a 128-dimensional feature vector.
    """

    def __init__(self, in_features=6):
        super().__init__()
        self.conv1 = GCNConv(in_features, 32)
        self.conv2 = GCNConv(32, 64)
        self.pool = global_mean_pool
        self.fc = nn.Linear(64, 128)

    def forward(self, data):
        """Return one 128-dimensional feature row per graph in the batch ``data``."""
        node_feats, edges, graph_ids = data.x, data.edge_index, data.batch
        for gcn in (self.conv1, self.conv2):
            node_feats = F.relu(gcn(node_feats, edges))
        # Average node embeddings per graph, using the batch assignment vector
        return self.fc(self.pool(node_feats, graph_ids))

class HybridClassifier(nn.Module):
    """
    Two-branch classifier: concatenates the 128-dimensional outputs of the CNN
    and SGCNN branches and maps the 256 combined features to 2 class logits.
    """

    def __init__(self):
        super().__init__()
        self.cnn_branch = CNNBranch()
        self.sgcnn_branch = SGCNNBranch()
        self.classifier = nn.Linear(256, 2)

    def forward(self, image, graph):
        """Return class logits for a batch of images and the matching graph batch."""
        branch_features = (self.cnn_branch(image), self.sgcnn_branch(graph))
        return self.classifier(torch.cat(branch_features, dim=1))

# Separate dataset classes for the SGCNN (graph) and CNN (image) branches
class GraphOnlyDataset(Dataset):
    """
    Dataset over the saved graph files (``*.pt``) in ``graph_dir``.

    Files are ordered by the number at the end of their name, so that position
    ``i`` in the ordering is ``slice_<i>.pt`` when the directory holds a
    contiguous ``slice_0.pt .. slice_<n>.pt`` set. A plain lexicographic sort
    would put ``slice_10.pt`` before ``slice_2.pt`` and make ``indices`` select
    graphs belonging to other slices.

    Parameters
    ----------
    graph_dir : str
        Directory containing the ``.pt`` graph files.
    indices : iterable of int
        Positions (in the numeric ordering) of the graphs to include.
    """

    def __init__(self, graph_dir, indices):
        super().__init__()
        import re  # local import so the class does not rely on an earlier cell

        def slice_number(filename):
            # Files without a trailing number sort after all numbered ones, by name.
            found = re.search(r"(\d+)\.pt$", filename)
            return (0, int(found.group(1)), filename) if found else (1, 0, filename)

        filenames = sorted(
            (f for f in os.listdir(graph_dir) if f.endswith(".pt")),
            key=slice_number,
        )
        all_paths = [os.path.join(graph_dir, f) for f in filenames]
        self.graph_paths = [all_paths[i] for i in indices]

    def len(self):
        """Number of graphs selected by ``indices``."""
        return len(self.graph_paths)

    def get(self, idx):
        """Load and return the graph stored at position ``idx``."""
        # NOTE: weights_only=False unpickles arbitrary objects; only use this
        # on graph files produced by a trusted source.
        return torch.load(self.graph_paths[idx], weights_only=False)

class ImageOnlyDataset(TorchDataset):
    """
    Index-aligned (image, label) pairs.

    ``images`` and ``labels`` are stored as given; item ``idx`` is the tuple
    ``(images[idx], labels[idx])`` with no conversion applied.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        """Number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``."""
        image = self.images[idx]
        target = self.labels[idx]
        return image, target