In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam,SGD
from tensorflow.keras.applications import VGG19,ResNet152
from tensorflow.keras.layers import Dense,Dropout,BatchNormalization,Input,Resizing,Flatten,Concatenate,TimeDistributed,Softmax,Multiply,Lambda,GlobalAveragePooling2D
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.utils import Sequence,to_categorical
from tensorflow.keras import Model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import nibabel as nib
import os
import re
import xml.etree.ElementTree as ET
from tqdm import tqdm
from scipy.stats import entropy
from skimage.exposure import histogram
import cv2
from tensorflow.keras.utils import plot_model

In [4]:
df = pd.read_csv("/kaggle/input/adni-processed/ADNI1_Complete_1Yr_1.5T_6_20_2025.csv")
folder_path = "/kaggle/input/adni-processed/ADNI1_Processed/ADNI1_Processed"
paths = []
folder_2 = "/kaggle/input/cleaded-metadata"

for root_dir, dirs, files in tqdm(os.walk(folder_path), desc="Scanning files"):
    for file in files:
        if file.endswith(".nii") or file.endswith(".nii.gz"):
            final_path = os.path.join(root_dir, file)
            rel_path = os.path.relpath(final_path, folder_path)
            
            # Extract subject and image ID using regex
            match = re.search(r'_S(\d+)_I(\d+)', file)
            if match:
                s_num = match.group(1)
                i_num = match.group(2)
                new_filename = f"S{s_num}I{i_num}.xml"

                nii_dir = os.path.dirname(rel_path)
                xml_path = os.path.join(folder_2, new_filename)

                if os.path.exists(xml_path):
                    try:
                        tree = ET.parse(xml_path)
                        xml_root = tree.getroot()
                        id = xml_root[3].attrib.get('uid', None)

                        if id:
                            row = df[df['Image Data ID'].astype(str).str.strip() == str(id).strip()]
                            if not row.empty:
                                label = row.iloc[0, 2]
                                paths.append((label, final_path))
                            else:
                                print(f"[!] ID {id} not found in DataFrame")
                        else:
                            print(f"[!] UID not found in XML: {xml_path}")
                    except Exception as e:
                        print(f"[!] Failed to parse XML: {xml_path} — {e}")
                else:
                    print(f"[!] XML file missing: {xml_path}")
            else:
                print(f"[!] Failed to extract subject/image ID from: {file}")

filtered_paths = []
for path in paths:
    if path[0] in ('AD', 'CN'):
        filtered_paths.append(path)

Scanning files: 815it [00:08, 96.69it/s] 


In [5]:
X = []
y = []

In [6]:
def center_crop(image, crop_size=128):
    h, w = image.shape
    if h < crop_size or w < crop_size:
        return None
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    return image[top:top+crop_size, left:left+crop_size]

def image_entropy(img):
    hist, _ = histogram(img)
    hist = hist / np.sum(hist)
    return entropy(hist, base=2)

In [7]:
N = 100  # number of entropy-based slices (top-k slices selection)
crop_size = 128

for label, path in tqdm(filtered_paths):
    scan = nib.load(path)
    data = scan.get_fdata()
    label = 0 if label == 'AD' else 1

    slice_info = []

    for axis in [0, 1, 2]:  
        for i in range(data.shape[axis]):
            # Extract 2D slice along the given axis
            if axis == 0:
                slice_ = data[i, :, :]
            elif axis == 1:
                slice_ = data[:, i, :]
            else:
                slice_ = data[:, :, i]

            # Crop and skip empty ones
            cropped = center_crop(slice_, crop_size=crop_size)
            if cropped is None:
                continue

            # Compute entropy
            ent = image_entropy(cropped)
            slice_info.append((ent, cropped))

    # Sort slices by entropy
    slice_info.sort(reverse=True, key=lambda x: x[0])
    top_slices = slice_info[:N]

    # If not enough valid slices, skip this subject
    if len(top_slices) < N:
        print(f"[!] Skipped subject: only {len(top_slices)} slices")
        continue

    # Build per-subject volume
    subject_volume = [s[1][..., np.newaxis] for s in top_slices]  # (128, 128, 1)
    subject_volume = np.stack(subject_volume, axis=0)  # (N, 128, 128, 1)

    X.append(subject_volume)
    y.append(label)

100%|██████████| 234/234 [03:37<00:00,  1.08it/s]


In [8]:
X = np.stack(X)  # shape: (num_subjects, N, 128, 128, 1)
y = np.array(y)  # shape: (num_subjects,)

In [9]:
y = to_categorical(y)

In [10]:
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)

In [15]:
class DiagnosisXAINetwork:
    def __init__(self, num_slices=100, input_shape=(128, 128, 3), f_dim=2048, att_dim=128, num_classes=2, lr=1e-3):
        self.num_slices = num_slices
        self.input_shape = input_shape
        self.f_dim = f_dim
        self.att_dim = att_dim
        self.num_classes = num_classes
        self.lr = lr
        
        # Feature extractor
        base_model = ResNet152(include_top=False, weights='imagenet', input_shape=input_shape)
        base_model.trainable = False
        self.feature_extractor = tf.keras.Model(inputs=base_model.input, outputs=tf.keras.layers.GlobalAveragePooling2D()(base_model.output))
        
        # Learnable weights
        self.W1 = np.random.randn(att_dim, f_dim) * 0.01
        self.W2 = np.random.randn(f_dim, f_dim) * 0.01
        self.W3 = np.random.randn(num_classes, f_dim) * 0.01

    def softmax(self, x):
        e_x = np.exp(x - np.max(x))
        return e_x / np.sum(e_x)

    def forward_single(self, X_patient):
        slice_features = []

        for j in range(self.num_slices):
            slice_ = X_patient[j, :, :, 0]
            slice_rgb = np.repeat(slice_[:, :, np.newaxis], 3, axis=-1)
            slice_rgb = preprocess_input(slice_rgb)
            slice_rgb = np.expand_dims(slice_rgb, axis=0)

            feature = self.feature_extractor(slice_rgb).numpy().reshape(-1)
            slice_features.append(feature)

        slice_features = np.stack(slice_features)  # (N, f_dim)

        # Attention
        j_scores = np.mean(np.dot(self.W1, slice_features.T), axis=0)  # (N,)
        attention_weights = self.softmax(j_scores)  # (N,)

        weighted_feature = np.sum(slice_features * attention_weights[:, np.newaxis], axis=0)  # (f_dim,)
        fused_feature = np.dot(self.W2, weighted_feature)  # (f_dim,)
        logits = np.dot(self.W3, fused_feature)  # (num_classes,)
        probs = self.softmax(logits)  # (num_classes,)

        cache = {
            "slice_features": slice_features,
            "attention_weights": attention_weights,
            "weighted_feature": weighted_feature,
            "fused_feature": fused_feature,
            "probs": probs,
            "logits": logits
        }

        return probs, cache

    def compute_loss(self, probs, y_true):
        return -np.sum(y_true * np.log(probs + 1e-8))  # cross-entropy

    def backward_single(self, X_patient, y_true, cache):
        # Unpack
        probs = cache["probs"]
        fused_feature = cache["fused_feature"]
        weighted_feature = cache["weighted_feature"]
        attention_weights = cache["attention_weights"]
        slice_features = cache["slice_features"]

        # Gradients
        dL_dlogits = probs - y_true  # (num_classes,)
        dL_dW3 = np.outer(dL_dlogits, fused_feature)  # (num_classes, f_dim)

        dL_dfused = np.dot(self.W3.T, dL_dlogits)  # (f_dim,)
        dL_dW2 = np.outer(dL_dfused, weighted_feature)  # (f_dim, f_dim)

        dL_dweighted = np.dot(self.W2.T, dL_dfused)  # (f_dim,)

        # Now propagate into attention mechanism (simplified)
        # dL/da_i = dL_dweighted * f_i
        dL_datt = np.sum(dL_dweighted * slice_features, axis=1)  # (N,)

        # Now softmax gradient (for attention weights)
        a = attention_weights.reshape(-1, 1)
        softmax_grad = np.diagflat(a) - np.dot(a, a.T)  # (N, N)
        dL_dj = np.dot(softmax_grad, dL_datt)  # (N,)

        # Propagate into W1 (this is crude; treating j_i = mean(W1 @ f_i))
        grad_W1 = np.zeros_like(self.W1)
        for i in range(self.num_slices):
            grad_W1 += (dL_dj[i] / self.att_dim) * np.outer(np.ones(self.att_dim), slice_features[i])

        # Update weights
        self.W3 -= self.lr * dL_dW3
        self.W2 -= self.lr * dL_dW2
        self.W1 -= self.lr * grad_W1

    def train_on_batch(self, X_batch, y_batch):
        total_loss = 0
        for i in range(X_batch.shape[0]):
            probs, cache = self.forward_single(X_batch[i])
            loss = self.compute_loss(probs, y_batch[i])
            self.backward_single(X_batch[i], y_batch[i], cache)
            total_loss += loss
        return total_loss / X_batch.shape[0]

In [16]:
X_dummy = X_train
y_dummy = y_train

In [None]:
model = DiagnosisXAINetwork()

for epoch in tqdm(range(50)):
    total_loss = 0
    correct = 0

    for i in tqdm(range(X_dummy.shape[0])):
        # Forward pass
        probs, cache = model.forward_single(X_dummy[i])
        loss = model.compute_loss(probs, y_dummy[i])
        model.backward_single(X_dummy[i], y_dummy[i], cache)

        total_loss += loss

        # Accuracy
        pred_class = np.argmax(probs)
        true_class = np.argmax(y_dummy[i])
        if pred_class == true_class:
            correct += 1

    acc = correct / X_dummy.shape[0]
    avg_loss = total_loss / X_dummy.shape[0]
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {acc:.2%}")

  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/187 [00:00<?, ?it/s][A
  1%|          | 1/187 [00:58<2:59:51, 58.02s/it][A
  1%|          | 2/187 [01:55<2:58:29, 57.89s/it][A
  2%|▏         | 3/187 [02:53<2:57:38, 57.92s/it][A
  2%|▏         | 4/187 [03:51<2:56:19, 57.81s/it][A
  3%|▎         | 5/187 [04:49<2:56:06, 58.06s/it][A
  3%|▎         | 6/187 [05:48<2:55:22, 58.13s/it][A
  4%|▎         | 7/187 [06:46<2:54:17, 58.10s/it][A
  4%|▍         | 8/187 [07:44<2:53:18, 58.09s/it][A
  5%|▍         | 9/187 [08:42<2:52:14, 58.06s/it][A
  5%|▌         | 10/187 [09:40<2:51:36, 58.18s/it][A
  6%|▌         | 11/187 [10:38<2:50:34, 58.15s/it][A
  6%|▋         | 12/187 [11:36<2:49:31, 58.12s/it][A
  7%|▋         | 13/187 [12:34<2:48:27, 58.09s/it][A
  7%|▋         | 14/187 [13:32<2:47:23, 58.05s/it][A
  8%|▊         | 15/187 [14:30<2:46:28, 58.07s/it][A
  9%|▊         | 16/187 [15:29<2:45:30, 58.07s/it][A
  9%|▉         | 17/187 [16:27<2:44:31, 58.07s/it][A
 10%|▉   

Epoch 1: Loss = 0.7008, Accuracy = 65.78%



  0%|          | 0/187 [00:00<?, ?it/s][A
  1%|          | 1/187 [00:57<2:58:53, 57.71s/it][A
  1%|          | 2/187 [01:55<2:57:41, 57.63s/it][A
  2%|▏         | 3/187 [02:52<2:56:29, 57.55s/it][A
  2%|▏         | 4/187 [03:49<2:55:09, 57.43s/it][A
  3%|▎         | 5/187 [04:47<2:54:07, 57.40s/it][A
  3%|▎         | 6/187 [05:44<2:53:09, 57.40s/it][A
  4%|▎         | 7/187 [06:42<2:52:16, 57.43s/it][A
  4%|▍         | 8/187 [07:39<2:51:07, 57.36s/it][A
  5%|▍         | 9/187 [08:36<2:50:15, 57.39s/it][A
  5%|▌         | 10/187 [09:34<2:49:11, 57.36s/it][A
  6%|▌         | 11/187 [10:31<2:48:12, 57.35s/it][A
  6%|▋         | 12/187 [11:28<2:47:21, 57.38s/it][A
  7%|▋         | 13/187 [12:26<2:46:48, 57.52s/it][A
  7%|▋         | 14/187 [13:24<2:45:52, 57.53s/it][A
  8%|▊         | 15/187 [14:21<2:44:48, 57.49s/it][A
  9%|▊         | 16/187 [15:19<2:43:42, 57.44s/it][A
  9%|▉         | 17/187 [16:16<2:42:48, 57.46s/it][A
 10%|▉         | 18/187 [17:14<2:41:52, 57.47s/

Epoch 2: Loss = 0.5931, Accuracy = 69.52%



  0%|          | 0/187 [00:00<?, ?it/s][A
  1%|          | 1/187 [00:57<2:57:49, 57.36s/it][A
  1%|          | 2/187 [01:54<2:56:51, 57.36s/it][A
  2%|▏         | 3/187 [02:52<2:56:08, 57.44s/it][A
  2%|▏         | 4/187 [03:49<2:55:18, 57.48s/it][A
  3%|▎         | 5/187 [04:47<2:54:22, 57.49s/it][A
  3%|▎         | 6/187 [05:44<2:53:14, 57.43s/it][A
  4%|▎         | 7/187 [06:42<2:52:17, 57.43s/it][A
  4%|▍         | 8/187 [07:39<2:51:36, 57.52s/it][A
  5%|▍         | 9/187 [08:37<2:50:41, 57.54s/it][A
  5%|▌         | 10/187 [09:34<2:49:36, 57.50s/it][A
  6%|▌         | 11/187 [10:32<2:48:36, 57.48s/it][A
  6%|▋         | 12/187 [11:29<2:47:32, 57.44s/it][A
  7%|▋         | 13/187 [12:27<2:46:37, 57.46s/it][A
  7%|▋         | 14/187 [13:24<2:45:40, 57.46s/it][A
  8%|▊         | 15/187 [14:21<2:44:45, 57.47s/it][A
  9%|▊         | 16/187 [15:19<2:43:56, 57.52s/it][A
  9%|▉         | 17/187 [16:17<2:42:58, 57.52s/it][A
 10%|▉         | 18/187 [17:14<2:41:47, 57.44s/

Epoch 3: Loss = 0.5455, Accuracy = 75.40%



  0%|          | 0/187 [00:00<?, ?it/s][A
  1%|          | 1/187 [00:57<2:58:13, 57.49s/it][A
  1%|          | 2/187 [01:54<2:57:03, 57.42s/it][A
  2%|▏         | 3/187 [02:52<2:56:30, 57.56s/it][A
  2%|▏         | 4/187 [03:49<2:55:21, 57.50s/it][A
  3%|▎         | 5/187 [04:47<2:54:21, 57.48s/it][A
  3%|▎         | 6/187 [05:44<2:53:25, 57.49s/it][A
  4%|▎         | 7/187 [06:42<2:52:20, 57.44s/it][A
  4%|▍         | 8/187 [07:39<2:51:23, 57.45s/it][A
  5%|▍         | 9/187 [08:37<2:50:15, 57.39s/it][A
  5%|▌         | 10/187 [09:34<2:49:12, 57.36s/it][A
  6%|▌         | 11/187 [10:31<2:48:14, 57.35s/it][A
  6%|▋         | 12/187 [11:28<2:47:15, 57.34s/it][A
  7%|▋         | 13/187 [12:26<2:46:12, 57.31s/it][A
  7%|▋         | 14/187 [13:23<2:45:12, 57.30s/it][A
  8%|▊         | 15/187 [14:20<2:44:14, 57.29s/it][A
  9%|▊         | 16/187 [15:18<2:43:31, 57.38s/it][A
  9%|▉         | 17/187 [16:15<2:42:40, 57.42s/it][A
 10%|▉         | 18/187 [17:13<2:41:34, 57.36s/