# Import

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import statsmodels.api as sm
import scipy.stats as stats

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering, DBSCAN
from sklearn.mixture import GaussianMixture
from tqdm import tqdm

from statsmodels.formula.api import ols
from scipy.stats import kruskal, spearmanr, mannwhitneyu
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import silhouette_score, adjusted_rand_score
from statsmodels.discrete.discrete_model import MNLogit
from statsmodels.tools import add_constant
from statsmodels.stats.outliers_influence import variance_inflation_factor

import umap.umap_ as umap  # this sometimes doesn't load, try different ways

# Model and data

### Configurations

In [None]:
# CONFIGURATION
# Input locations.
CSV_PATH = 'main.csv'
IMAGE_DIR = 'PREPROCESSED_IMAGES'

# Training hyperparameters.
NUM_CLASSES = 3
BATCH_SIZE = 32
EPOCHS = 50
LR = 1e-4
PATIENCE = 8  # epochs without validation-loss improvement before stopping

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DATASET
class ADNIdataset(Dataset):
    """Image dataset driven by a CSV table with one row per image.

    Each item is the tuple ``(image, label, education, genotype, age, sex)``.
    The image is looked up as ``<ImageID>.png`` inside any direct
    sub-directory of ``image_dir`` and converted to single-channel ('L').

    The covariate arrays (``education``, ``genotype``, ``age``, ``sex``) are
    rebuilt every time ``data`` is assigned, so replacing ``data`` with a
    subset of rows keeps row ``idx`` of the table and row ``idx`` of the
    covariates in sync. ``label_map`` and the one-hot column layout are fixed
    from the table read at construction time, so subsets share the same
    label encoding and one-hot width.

    Args:
        csv_path: Path of the CSV file; it must provide the columns
            'ImageID', 'Group', 'PTEDUCAT', 'GENOTYPE', 'Age' and 'Sex'.
        image_dir: Directory whose sub-directories contain the PNG images.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, csv_path, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # image file name -> resolved path; filled lazily by _find_image
        self._image_paths = {}

        table = pd.read_csv(csv_path)

        # map labels (AD, CN, MCI) to integers, in order of first appearance
        self.label_map = {label: idx for idx, label in enumerate(table['Group'].unique())}
        # One-hot layouts are frozen here so later subsets keep the same columns.
        self._genotype_columns = pd.get_dummies(table['GENOTYPE']).columns
        self._sex_columns = pd.get_dummies(table['Sex'], drop_first=True).columns

        self.data = table  # goes through the setter below and builds the covariates

    @property
    def data(self):
        """The table backing this dataset (one row per item)."""
        return self._data

    @data.setter
    def data(self, table):
        # Recompute every covariate array so positional indices always refer
        # to the rows of the table that is currently assigned.
        self._data = table
        self.education = table['PTEDUCAT'].values.astype(float).reshape(-1, 1)
        self.genotype = self._one_hot(table['GENOTYPE'], self._genotype_columns)
        self.age = table['Age'].values.astype(float).reshape(-1, 1)
        self.sex = self._one_hot(table['Sex'], self._sex_columns)

    @staticmethod
    def _one_hot(column, levels):
        """One-hot encode *column* using exactly the columns in *levels*."""
        encoded = pd.get_dummies(column, dtype=float)
        return encoded.reindex(columns=levels, fill_value=0.0).values.astype(float)

    def _find_image(self, image_name):
        """Return the path of *image_name*, searching each sub-directory once.

        Raises:
            FileNotFoundError: If no sub-directory of ``image_dir`` contains
                the file.
        """
        cached = self._image_paths.get(image_name)
        if cached is not None:
            return cached
        for subdir in os.listdir(self.image_dir):
            candidate = os.path.join(self.image_dir, subdir, image_name)
            if os.path.exists(candidate):
                self._image_paths[image_name] = candidate
                return candidate
        raise FileNotFoundError(f"Image {image_name} not found.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_name = row['ImageID'] + '.png'
        image_path = self._find_image(image_name)

        # The context manager closes the underlying file once the
        # single-channel copy has been made.
        with Image.open(image_path) as source_image:
            image = source_image.convert('L')
        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.label_map[row['Group']], dtype=torch.long)
        education_vector = torch.tensor(self.education[idx], dtype=torch.float)
        genotype_vector = torch.tensor(self.genotype[idx], dtype=torch.float)
        age_vector = torch.tensor(self.age[idx], dtype=torch.float)
        sex_vector = torch.tensor(self.sex[idx], dtype=torch.float)
        return image, label, education_vector, genotype_vector, age_vector, sex_vector

# DATA AUGMENTATION
# Training images get a small random affine jitter; validation images are
# only resized and converted to tensors.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomAffine(degrees=3, translate=(0.01, 0.01)),  # Very minor augmentation
    transforms.ToTensor(),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# DATASET (with stratified split)
full_dataset = ADNIdataset(CSV_PATH, IMAGE_DIR, transform=val_transform)
train_indices, val_indices = train_test_split(
    range(len(full_dataset)),
    test_size=0.2,
    stratify=full_dataset.data['Group'],
    random_state=42
)

# Apply different transforms.
# The splits are index views (Subset) over complete datasets rather than
# datasets whose `.data` table is overwritten: every per-row array held by the
# dataset therefore stays aligned with the row that is actually served.
train_dataset = torch.utils.data.Subset(
    ADNIdataset(CSV_PATH, IMAGE_DIR, transform=train_transform),
    train_indices,
)
val_dataset = torch.utils.data.Subset(full_dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# CLASS BALANCING
# Use the SAME group -> integer mapping that the dataset uses for the training
# targets. CrossEntropyLoss applies weight[i] to target class i, so a second,
# differently ordered mapping would attach the weights to the wrong classes.
label_map = full_dataset.label_map
full_dataset.data['Label'] = full_dataset.data['Group'].map(label_map)

# Compute class weights (entry i belongs to integer label i)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(full_dataset.data['Label']),
    y=full_dataset.data['Label']
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

### Training

In [None]:
# MODEL
# ResNet-18 created with pretrained=True; its first convolution is replaced
# by a freshly initialised one that accepts a single input channel.
model = models.resnet18(pretrained=True)
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# ADD DROPOUT
# The classification head is dropout followed by a linear layer.
classifier_head = [
    nn.Dropout(p=0.2),
    nn.Linear(in_features=512, out_features=NUM_CLASSES),
]
model.fc = nn.Sequential(*classifier_head)
model.to(device)

# Class-weighted cross-entropy and an Adam optimiser over all parameters.
criterion = CrossEntropyLoss(weight=class_weights_tensor)
optimizer = Adam(params=model.parameters(), lr=LR)

# TRAINING LOOP WITH EARLY STOPPING AND PLOTTING
# Each epoch runs one optimisation pass over train_loader and one evaluation
# pass over val_loader. The weights are checkpointed whenever the validation
# loss improves, and training stops after PATIENCE consecutive epochs without
# improvement. Per-epoch losses and accuracies are kept for plotting.
best_val_loss = float('inf')
best_epoch = -1  # initialised once, outside the loop, so it survives until the final report
epochs_no_improve = 0
train_losses, val_losses, train_accs, val_accs = [], [], [], []

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_train_loss = 0
    correct_train = 0
    total_train = 0

    for batch in train_loader:
        # batch[0] = images, batch[1] = labels; the remaining entries are unused here
        images = batch[0].to(device)
        labels = batch[1].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct_train += predicted.eq(labels).sum().item()
        total_train += labels.size(0)

    # Loss is averaged over batches, accuracy over samples.
    train_loss = total_train_loss / len(train_loader)
    train_acc = correct_train / total_train
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # ---- validation pass ----
    model.eval()
    total_val_loss = 0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()
            _, predicted = outputs.max(1)
            correct_val += predicted.eq(labels).sum().item()
            total_val += labels.size(0)

    val_loss = total_val_loss / len(val_loader)
    val_acc = correct_val / total_val
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train loss: {train_loss:.4f}, Train accuracy: {train_acc:.4f}, Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}")

    # ---- checkpointing / early stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        epochs_no_improve = 0
        torch.save(model.state_dict(), 'final_best_model_end.pth')  # saved path to best final model
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("Early stopping triggered.")
            break

print(f"The best model was from epoch {best_epoch} with val_loss = {best_val_loss:.4f}")

# PLOT TRAINING/VALIDATION LOSS AND ACCURACY
# One panel per metric; each panel overlays the train and validation curves.
curve_panels = [
    ('Loss', train_losses, val_losses),
    ('Accuracy', train_accs, val_accs),
]

plt.figure(figsize=(12, 5))
for panel_number, (metric, train_curve, val_curve) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_number)
    plt.plot(train_curve, label=f'Train {metric}')
    plt.plot(val_curve, label=f'Val {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.legend()
    plt.title(f'{metric} over Epochs')

plt.tight_layout()
plt.savefig("training_plot.png")
plt.show()


### Load best model

In [None]:
# Rebuild the architecture used during training so the checkpoint fits it.
# pretrained=False: the weights come from the checkpoint loaded below.
model = models.resnet18(pretrained=False)
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(in_features=512, out_features=NUM_CLASSES),
)

# Load the trained weights
best_state = torch.load("final_best_model_end.pth", map_location=device)
model.load_state_dict(best_state)
model.to(device)
model.eval()  # Important for disabling dropout, etc.

### Feature extraction

In [None]:
# FEATURE EXTRACTION
# With the classification head replaced by an identity, the model returns the
# activations that used to feed the final linear layer.
model.fc = nn.Identity()  # removes the classification head
model.eval()

feature_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)

# One list of per-batch arrays for each returned quantity
feature_batches = []
label_batches, education_batches, genotype_batches = [], [], []
age_batches, sex_batches = [], []

# Extract feature vectors for all images
with torch.no_grad():
    for images, labels, education, genotype, age, sex in feature_loader:
        activations = model(images.to(device))
        # move to CPU / NumPy so the arrays can be used for clustering and plotting
        feature_batches.append(activations.cpu().numpy())
        label_batches.append(labels.numpy())        # class labels
        education_batches.append(education.numpy())  # education values, etc.
        genotype_batches.append(genotype.numpy())
        age_batches.append(age.numpy())
        sex_batches.append(sex.numpy())

# Stack the batches along the sample axis
all_features = np.vstack(feature_batches)
all_labels = np.concatenate(label_batches)
all_education = np.concatenate(education_batches)
all_genotype = np.concatenate(genotype_batches)
all_age = np.concatenate(age_batches)
all_sex = np.concatenate(sex_batches)


### Genotype values

Unique genotype values:
'2/2', '2/3', '3/3', '2/4', '3/4', '4/4'

In [None]:
# CREATE GENOTYPE LABELS TO USE FOR STATISTICAL TESTS
# all_genotype is the one-hot encoding produced by pd.get_dummies inside the
# dataset, so one-hot column k corresponds to the k-th get_dummies column.
# get_dummies orders its columns by sorted category value, which is NOT the
# order '2/2', '2/3', '3/3', '2/4', '3/4', '4/4'; taking the names from the same
# encoding keeps each index attached to the right genotype string.
genotype_levels = pd.get_dummies(full_dataset.data['GENOTYPE']).columns.to_numpy().astype(str)
# NOTE(review): a row whose genotype is missing would be all zeros and argmax
# would then report the first level — confirm the table has no missing genotypes.
genotype_labels = genotype_levels[np.argmax(all_genotype, axis=1)]

# Convert genotype_labels (strings) into integer codes (nominal category → int code)
genotype_labels_series = pd.Series(genotype_labels)
genotype_categorical = genotype_labels_series.astype('category')
genotype_cat = genotype_categorical.cat.codes.values  # e.g. "2/2" → 0

# Here is the mapping
label_mapping = dict(enumerate(genotype_categorical.cat.categories))
print("Genotype label mapping:", label_mapping)


Genotype label mapping: {0: '2/2', 1: '2/3', 2: '2/4', 3: '3/3', 4: '3/4', 5: '4/4'}


# Sensitivity checks

Sensitivity checks for the clustering algorithm, the number of clusters, the number of PCs and their explained variance, and the UMAP hyperparameters (not fully cleared).

### Finding number of PCs

In [None]:
# PCA components
# For each candidate number of principal components: project the CNN features,
# cluster the projection with a 3-component Gaussian mixture, and record the
# silhouette score and the ARI against the diagnosis labels.

# Input data
X = all_features

label_map_reverse = {idx: name for name, idx in full_dataset.label_map.items()}
diagnosis_names = [label_map_reverse[label] for label in all_labels]
diagnosis_labels = diagnosis_names

# Range of PCA components to test
pca_components = [10, 15, 20, 25, 30, 35, 40, 50, 60, 80, 100]
sil_scores = []
ari_scores = []

for n_comp in pca_components:
    # PCA
    pca = PCA(n_components=n_comp)
    pca_result = pca.fit_transform(X)

    # GMM (a K-Means variant with n_clusters=3 is left disabled here)
    gmm = GaussianMixture(n_components=3, random_state=42)
    cluster_labels = gmm.fit_predict(pca_result)

    # Silhouette scores
    sil_scores.append(silhouette_score(pca_result, cluster_labels))

    if diagnosis_labels is not None:
        ari_scores.append(adjusted_rand_score(diagnosis_labels, cluster_labels))

# Plotting
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(pca_components, sil_scores, marker='o')
plt.title('Silhouette Score vs PCA Components')
plt.xlabel('Number of PCA Components')
plt.ylabel('Silhouette Score')
plt.grid(True)

has_reference_labels = diagnosis_labels is not None
if has_reference_labels:
    plt.subplot(1, 2, 2)
    plt.plot(pca_components, ari_scores, marker='o', color='darkorange')
    plt.title('ARI vs PCA Components (Compared to Diagnosis)')
    plt.xlabel('Number of PCA Components')
    plt.ylabel('Adjusted Rand Index (ARI)')
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
# Check PCA variance
# Cumulative explained-variance ratio of the first 100 principal components.

pca = PCA(n_components=100).fit(all_features)
explained_var = pca.explained_variance_ratio_.cumsum()

plt.figure(figsize=(8, 4))
plt.plot(range(1, 101), explained_var, marker='o')
plt.xlabel("Number of PCA Components")
plt.ylabel("Cumulative Explained Variance")
plt.title("Explained Variance vs PCA Components")
plt.grid(True)
# Reference lines at the two variance thresholds of interest
for threshold, threshold_label in ((0.90, '90% Variance'), (0.85, '85% Variance')):
    plt.axhline(y=threshold, color='r', linestyle='--', label=threshold_label)
plt.xticks(np.arange(0, 101, 10))  # Show every 10th tick
plt.legend()
plt.show()

# Index 29 is the 30th component (0-based indexing)
variance_at_30 = explained_var[29]
print(f"Cumulative explained variance for 30 components: {variance_at_30*100:.2f}%")


### Finding UMAP hyperparameters

In [None]:
# UMAP hyperparameters
# Grid over n_neighbors x min_dist: embed the 30-component PCA projection with
# UMAP, cluster the 2-D embedding with K-means (k=3), and record the silhouette
# score and the ARI against diagnosis for every combination.

X = all_features
diagnosis_labels = diagnosis_names

# Fixed PCA
pca = PCA(n_components=30)  # 30 according to previous test
pca_result = pca.fit_transform(X)

# UMAP parameter ranges
n_neighbors_list = [5, 10, 15, 30, 50]
min_dist_list = [0.0, 0.001, 0.01, 0.1, 0.3, 0.5]

# Store results
results = []

for n_neighbors in n_neighbors_list:
    for min_dist in min_dist_list:
        umap_model = umap.UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            n_components=2,
            random_state=42
        )
        umap_result = umap_model.fit_transform(pca_result)

        # K-means clustering on UMAP output
        kmeans = KMeans(n_clusters=3, random_state=42)
        cluster_labels = kmeans.fit_predict(umap_result)

        record = {
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "silhouette": silhouette_score(umap_result, cluster_labels),
            "ari": adjusted_rand_score(diagnosis_labels, cluster_labels),
        }
        results.append(record)

df = pd.DataFrame(results)

# Plot silhouette Score: one line per n_neighbors value
plt.figure(figsize=(7, 5))

for n in n_neighbors_list:
    rows_for_n = df.loc[df["n_neighbors"].eq(n)]
    plt.plot(rows_for_n["min_dist"], rows_for_n["silhouette"], marker='o', label=f'n={n}')

plt.title("Silhouette Score vs min_dist (by n_neighbors)")
plt.xlabel("min_dist")
plt.ylabel("Silhouette Score")
plt.legend(title="n_neighbors")
plt.grid(True)
plt.tight_layout()
plt.show()


### Clustering algorithms

In [None]:
# Testing clustering algorithms
# Every method is fitted on the same 30-component PCA projection and scored
# with the silhouette score and the ARI against the diagnosis labels.

# INPUT
X = all_features  # CNN features
true_labels = diagnosis_names  # ['CN', 'AD', 'MCI']

# Dimensionality reduction (PCA)
pca = PCA(n_components=30)
X_pca = pca.fit_transform(X)

# Clustering methods
clustering_methods = {
    "KMeans (k=3)": KMeans(n_clusters=3, random_state=42),
    "Agglomerative (k=3)": AgglomerativeClustering(n_clusters=3),
    "DBSCAN": DBSCAN(eps=3, min_samples=5),
    "GMM (k=3)": GaussianMixture(n_components=3, random_state=42)
}

# Evaluation
sil_scores = []
ari_scores = []
labels_used = []

# Loop through clustering methods
for name, method in clustering_methods.items():
    try:
        # Every estimator above exposes fit_predict
        cluster_labels = method.fit_predict(X_pca)

        # Both scores are reported as NaN when fewer than two distinct labels come back
        found_multiple_groups = len(np.unique(cluster_labels)) >= 2
        if found_multiple_groups:
            sil = silhouette_score(X_pca, cluster_labels)
            ari = adjusted_rand_score(true_labels, cluster_labels)
        else:
            sil = np.nan
            ari = np.nan

        sil_scores.append(sil)
        ari_scores.append(ari)
        labels_used.append(name)

        print(f"{name}")
        print(f"Silhouette Score: {sil:.4f}")
        print(f"Adjusted Rand Index vs Diagnosis: {ari:.4f}\n")

    except Exception as e:
        # Keep the bar chart aligned: a failed method still gets an (empty) entry
        print(f"Error with {name}: {e}")
        sil_scores.append(np.nan)
        ari_scores.append(np.nan)
        labels_used.append(name)

# Plot: two bars per method, side by side
x = np.arange(len(labels_used))
width = 0.35
half_width = width/2

plt.figure(figsize=(12, 6))
plt.bar(x - half_width, sil_scores, width, label='Silhouette Score')
plt.bar(x + half_width, ari_scores, width, label='ARI vs Diagnosis')

plt.xticks(x, labels_used, rotation=15)
plt.ylabel("Score")
plt.title("Clustering Method Comparison (Silhouette score and ARI)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


### Final: optimal number of K for K-means

In [None]:
# Optimal number of K for K-means
# Elbow method: fit K-means for K = 2..10 and plot the within-cluster sum of
# squares (WCSS) against K.

X = all_features

K_range = range(2, 11)  # test from 2 to 10 clusters
wcss = []  # store within-cluster sum of squares

for k in K_range:
    kmeans = KMeans(n_clusters=k, random_state=42).fit(X)
    wcss.append(kmeans.inertia_)  # inertia_ = WCSS

# Plot WCSS (for elbow)
plt.figure(figsize=(8, 4))
plt.plot(K_range, wcss, marker='o', linestyle='-')
plt.title("Optimal number of clusters")
plt.xlabel("Number of Clusters (K)")
plt.ylabel("Within-Cluster Sum of Squares (WCSS)")
plt.xticks(K_range)  # show integer ticks for K
plt.grid(True)
plt.show()


# Dimensionality reduction, clustering and visualisation

### PCA AND K-MEANS

In [None]:
# PCA + K-means
n_clusters = 3

pca = PCA(n_components=30)  # new n components
pca_result = pca.fit_transform(all_features)

# K-means on the PCA projection
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(pca_result)
cluster_centers_pca = kmeans.cluster_centers_

# UMAP
# The cluster centres are appended as extra rows so that they are embedded
# together with the samples and can be drawn in the same 2-D space.
combined = np.concatenate([pca_result, cluster_centers_pca], axis=0)

# Fit UMAP on the combined data
umap_model = umap.UMAP(n_components=2, n_neighbors=30, min_dist=0.0, random_state=42)  # new UMAP after sensitivity check
umap_result_combined = umap_model.fit_transform(combined)

# Separate points and cluster centers again: the samples come first, the
# n_clusters centre rows last.
n_samples = pca_result.shape[0]
umap_data = umap_result_combined[:n_samples]
umap_centers = umap_result_combined[n_samples:]

### PLOT USING UMAP

In [None]:
# PLOT EVERYTHING (UMAP)
# Six views of the same 2-D embedding, coloured by cluster, diagnosis,
# education, genotype, age and sex. The embedded cluster centres are drawn as
# black crosses on every figure.

embedding_x = umap_data[:, 0]
embedding_y = umap_data[:, 1]


def _draw_cluster_centers(marker_size):
    """Overlay the embedded cluster centres on the current figure."""
    plt.scatter(umap_centers[:, 0], umap_centers[:, 1],
                c='black', s=marker_size, marker='X', label='Cluster Centers')


# Plot clusters
plt.figure(figsize=(10, 8))
cluster_colors = ['sandybrown', 'teal', 'mediumpurple']

# Plot each cluster with a fixed color
num_clusters = len(np.unique(cluster_labels))
for cluster_id in range(num_clusters):
    members = np.flatnonzero(cluster_labels == cluster_id)
    plt.scatter(embedding_x[members], embedding_y[members],
                c=cluster_colors[cluster_id], label=f'Cluster {cluster_id}', alpha=0.6, s=40)

_draw_cluster_centers(200)

# Annotate each cluster center with its cluster number
for center_id, (center_x, center_y) in enumerate(umap_centers):
    plt.text(center_x + 0.3, center_y, f'Cluster {center_id}',
             fontsize=12, fontweight='bold', color='black')

plt.title('UMAP of CNN activations coloured by Cluster')
plt.grid(True)
plt.legend()
plt.show()

# Map integer labels back to original diagnosis group names
label_map_reverse = {idx: name for name, idx in full_dataset.label_map.items()}
diagnosis_names = [label_map_reverse[label] for label in all_labels]
color_map = {
    'AD': 'indianred',
    'MCI': 'mediumseagreen',
    'CN': 'cornflowerblue'
}

# Plot diagnosis
diagnosis_array = np.asarray(diagnosis_names)
plt.figure(figsize=(10, 8))
for group in np.unique(diagnosis_names):
    members = np.flatnonzero(diagnosis_array == group)
    plt.scatter(embedding_x[members], embedding_y[members],
                label=group, alpha=0.6, s=40, c=color_map[group])
_draw_cluster_centers(150)
plt.legend()
plt.title('UMAP of CNN activations coloured by Diagnosis')
plt.grid(True)
plt.show()

# Plot education
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding_x, embedding_y, c=all_education, cmap='viridis', alpha=0.7)
_draw_cluster_centers(200)
cbar = plt.colorbar(scatter)
cbar.set_label('Years of Education')
plt.title('UMAP of CNN activations coloured by Education')
plt.grid(True)
plt.legend()
plt.show()

# Plot genotype
# encode labels as categorical integer codes (0 to n_categories-1)
unique_labels, genotype_numeric = np.unique(genotype_labels, return_inverse=True)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding_x, embedding_y, c=genotype_numeric, cmap='plasma', alpha=0.7)
_draw_cluster_centers(200)
# Create colourbar with one tick per genotype
cbar = plt.colorbar(scatter, ticks=np.arange(len(unique_labels)))
cbar.ax.set_yticklabels(unique_labels)
cbar.set_label('Genotype APOE')
plt.title('UMAP of CNN activations coloured by Genotype')
plt.grid(True)
plt.legend()
plt.show()


# Plot age
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding_x, embedding_y, c=all_age, cmap='cividis', alpha=0.7)
_draw_cluster_centers(200)
cbar = plt.colorbar(scatter)
cbar.set_label('Age')
plt.title('UMAP of CNN activations coloured by Age')
plt.grid(True)
plt.legend()
plt.show()

# Plot sex
sex_labels = np.where(all_sex == 1, 'M', 'F')  # convert to string labels for plotting
plt.figure(figsize=(10, 8))
# M is drawn first, then F, each with its own colour
for sex_code, sex_colour in (('M', 'mediumseagreen'), ('F', 'orchid')):
    members = np.where(sex_labels == sex_code)[0]
    plt.scatter(embedding_x[members], embedding_y[members],
                label=sex_code, alpha=0.6, s=40, c=sex_colour)
# Plot cluster centers
_draw_cluster_centers(150)
plt.legend()
plt.title('UMAP of CNN activations coloured by Sex')
plt.grid(True)
plt.show()


### Distribution plots by diagnosis

In [None]:
# Recover the diagnosis group name for every sample from its integer label
label_map_reverse = dict((idx, name) for name, idx in full_dataset.label_map.items())
diagnosis_names = list(map(label_map_reverse.__getitem__, all_labels))

In [None]:
# PLOT KDE DISTRIBUTIONS (FOR CONTINUOUS VARIABLES EDUCATION AND AGE)
# One density figure per variable, with one curve per diagnosis group.

diagnosis_palette = {"CN": "cornflowerblue", "MCI": "mediumseagreen", "AD": "red"}

# (column name, values, x-axis label) — EDUCATION first, then AGE
kde_variables = [
    ('Education', all_education, "Years of Education"),
    ('Age', all_age, "Age"),
]

for column_name, column_values, axis_label in kde_variables:
    df_plot = pd.DataFrame({
        column_name: np.array(column_values).flatten(),
        'Diagnosis': diagnosis_names
    })

    # Plot
    plt.figure(figsize=(10, 6))
    sns.kdeplot(
        data=df_plot,
        x=column_name,
        hue="Diagnosis",
        fill=True,
        common_norm=False,  # each group's density is normalised on its own
        alpha=0.4,
        linewidth=2,
        palette=diagnosis_palette
    )

    plt.xlabel(axis_label)
    plt.ylabel("Density")
    plt.title(f"Distribution of {column_name} by Diagnosis group")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


In [None]:
# BAR PLOTS DISTRIBUTIONS (FOR CATEGORICAL VARIABLES GENOTYPE AND SEX)
# For each diagnosis group, the bars show the share of that group falling
# into every category.

diagnosis_palette = {"CN": "cornflowerblue", "MCI": "mediumseagreen", "AD": "red"}

# SEX
sex_map = {0: 'F', 1: 'M'}
df_plot_sex = pd.DataFrame({
    'Sex': np.array(all_sex).flatten(),  # 0 = one sex, 1 = the other
    'Diagnosis': diagnosis_names
})
df_plot_sex['Sex'] = df_plot_sex['Sex'].map(sex_map)

# Count and normalise within each diagnosis group
df_plot_sex['count'] = 1
sex_counts = df_plot_sex.groupby(['Diagnosis', 'Sex'])['count'].sum().reset_index()
total_per_diagnosis = sex_counts.groupby('Diagnosis')['count'].transform('sum')
sex_counts['proportion'] = sex_counts['count'] / total_per_diagnosis

# Plot
plt.figure(figsize=(8, 5))
sns.barplot(
    data=sex_counts,
    x='Sex',
    y='proportion',
    hue='Diagnosis',
    palette=diagnosis_palette
)
plt.ylabel("Proportion")
plt.title("Distribution of Sex by Diagnosis Group")
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()


# GENOTYPE
# An ordered categorical fixes the bar order on the x-axis.
ordered_genotypes = ['2/2', '2/3', '3/3', '2/4', '3/4', '4/4']
genotype_cat = pd.Categorical(genotype_labels, categories=ordered_genotypes, ordered=True)
genotype_numeric = genotype_cat.codes  # these are 0–5, matching the order above

df_plot_genotype = pd.DataFrame({
    'Genotype': genotype_cat,
    'Diagnosis': diagnosis_names
})

# Count and normalise within each diagnosis group
df_plot_genotype['count'] = 1
geno_counts = df_plot_genotype.groupby(['Diagnosis', 'Genotype'])['count'].sum().reset_index()
total_per_diag = geno_counts.groupby('Diagnosis')['count'].transform('sum')
geno_counts['proportion'] = geno_counts['count'] / total_per_diag

# Plot
plt.figure(figsize=(10, 5))
sns.barplot(
    data=geno_counts,
    x='Genotype',
    y='proportion',
    hue='Diagnosis',
    palette=diagnosis_palette
)
plt.ylabel("Proportion")
plt.title("Distribution of Genotype by Diagnosis Group")
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

In [None]:
# Information about distribution of genotype and diagnosis
# Number of samples for every (diagnosis, genotype) combination.

genotype_diagnosis_columns = {
    'GenotypeNumeric': genotype_numeric,
    'GenotypeLabel': genotype_labels,
    'Diagnosis': diagnosis_names,
}
df_plot = pd.DataFrame(genotype_diagnosis_columns)

df_plot.groupby(['Diagnosis', 'GenotypeLabel']).size()

# Statistical tests and evaluation analysis

### Silhouette score, Kruskal-Wallis, clusterwise summaries

In [None]:
# STATISTICAL ANALYSIS
# Silhouette score of the clustering, a Kruskal–Wallis test of every variable
# across the clusters, and per-cluster summary statistics.

cluster_array = np.asarray(cluster_labels)
cluster_ids = np.unique(cluster_array)

# Integer genotype codes derived directly from the genotype strings, so this
# cell does not depend on what another cell last stored in `genotype_cat`.
genotype_codes = np.asarray(pd.Categorical(genotype_labels).codes)


def _values_by_cluster(values):
    """Split *values* into one flattened array per cluster id."""
    values = np.asarray(values)
    return [values[cluster_array == cluster_id].ravel() for cluster_id in cluster_ids]


# SILHOUETTE SCORE
sil_score = silhouette_score(pca_result, cluster_labels)
print(f"Silhouette Score: {sil_score:.4f}")

# KRUSKAL–WALLIS TEST
# NOTE(review): the test ranks the values, so for the integer-coded nominal
# variables (diagnosis, genotype) the result depends on the chosen coding order.
# diagnosis differences between clusters
diagnosis_by_cluster = _values_by_cluster(all_labels)
kruskal_stat_d, kruskal_p_d = kruskal(*diagnosis_by_cluster)
print(f"Kruskal–Wallis Test (Diagnosis): H={float(kruskal_stat_d):.4f}, p={float(kruskal_p_d):.4e}")

# education differences between clusters
education_by_cluster = _values_by_cluster(all_education)
kruskal_stat_ed, kruskal_p_ed = kruskal(*education_by_cluster)
print(f"Kruskal–Wallis Test (Education): H={float(kruskal_stat_ed):.4f}, p={float(kruskal_p_ed):.4e}")

# genotype differences between clusters as NOMINAL (coded genotypes grouped by cluster)
genotype_by_cluster = _values_by_cluster(genotype_codes)
kruskal_stat_g, kruskal_p_g = kruskal(*genotype_by_cluster)
print(f"Kruskal–Wallis Test (Genotype as Nominal): H={float(kruskal_stat_g):.4f}, p={float(kruskal_p_g):.4e}")

# age differences between clusters
age_by_cluster = _values_by_cluster(all_age)
kruskal_stat_age, kruskal_p_age = kruskal(*age_by_cluster)
print(f"Kruskal–Wallis Test (Age): H={float(kruskal_stat_age):.4f}, p={float(kruskal_p_age):.4e}")

# sex differences between clusters
sex_by_cluster = _values_by_cluster(all_sex)
kruskal_stat_sex, kruskal_p_sex = kruskal(*sex_by_cluster)
print(f"Kruskal–Wallis Test (Sex): H={float(kruskal_stat_sex):.4f}, p={float(kruskal_p_sex):.4e}")

# CLUSTERWISE SUMMARY STATISTICS
df = pd.DataFrame({
    'Cluster': cluster_array.ravel(),
    'Diagnosis': all_labels.ravel(),
    'Education': all_education.ravel(),
    'GenotypeIndex': genotype_codes,
    'GenotypeNumeric': np.asarray(genotype_numeric).ravel(),
    'Age': all_age.ravel(),
    'Sex': all_sex.ravel()
})

summary_stats = ['count', 'mean', 'median', 'std']

# Summary diagnosis
summary_d = df.groupby('Cluster')['Diagnosis'].agg(summary_stats)
print("\nDiagnosis Stats per Cluster:")
print(summary_d)
# Summary education
summary_ed = df.groupby('Cluster')['Education'].agg(summary_stats)
print("\nEducation Stats per Cluster:")
print(summary_ed)
# Summary genotype
summary_g = df.groupby('Cluster')['GenotypeNumeric'].agg(summary_stats)
print("\nGenotype Stats per Cluster:")
print(summary_g)
# Summary age
summary_age = df.groupby('Cluster')['Age'].agg(summary_stats)
print("\nAge Stats per Cluster:")
print(summary_age)
# Summary sex
summary_sex = df.groupby('Cluster')['Sex'].agg(summary_stats)
print("\nSex Stats per Cluster:")
print(summary_sex)


### Post-hoc Mann-Whitney U test

In [None]:
# Post-hoc pairwise test
# Specifically if there are any differences between any 2 pairs of clusters
# Since we proved that all except age show differences, these are tested


def run_posthoc_mannwhitney(values, variable_name, labels):
    """Run Bonferroni-corrected pairwise Mann–Whitney U tests between clusters.

    Every pair of distinct cluster ids in *labels* is compared with a
    two-sided Mann–Whitney U test on *values*; the raw p-values are then
    corrected with the Bonferroni method and printed as a table.

    Args:
        values: One value per sample; flattened before use.
        variable_name: Name shown in the table heading.
        labels: Cluster id of every sample.

    Returns:
        Tuple ``(comparisons, p_values, pvals_corrected, reject)``; all four
        are empty when fewer than two clusters are present.
    """
    values = np.asarray(values).ravel()
    labels = np.asarray(labels).ravel()
    unique_clusters = np.unique(labels)

    comparisons = []
    p_values = []

    # Pairwise Mann–Whitney U tests (each unordered pair once)
    for position, c1 in enumerate(unique_clusters):
        for c2 in unique_clusters[position + 1:]:
            group1 = values[labels == c1]
            group2 = values[labels == c2]

            stat, p = mannwhitneyu(group1, group2, alternative='two-sided')
            comparisons.append(f"{c1} vs {c2}")
            p_values.append(p)

    print(f"\nPost-hoc Mann–Whitney U Tests for {variable_name} (Bonferroni corrected):\n")

    p_values = np.array(p_values).flatten()
    if p_values.size == 0:
        # Nothing to compare (and nothing to correct) with a single cluster.
        print("Fewer than two clusters; no pairwise comparisons were run.")
        return comparisons, p_values, np.array([]), np.array([], dtype=bool)

    # Apply Bonferroni correction
    reject, pvals_corrected, _, _ = multipletests(p_values, method='bonferroni')

    # Print (the header uses the same column separator as the rows)
    print(f"{'Comparison':<10} | {'Raw p-value':<12} | {'Bonferroni p':<14} | {'Significant'}")
    print("-" * 60)
    for comparison, raw_p, corrected_p, is_significant in zip(
            comparisons, p_values, pvals_corrected, reject):
        print(f"{comparison:<10} | {raw_p:<12.4e} | {corrected_p:<14.4e} | {is_significant}")

    return comparisons, p_values, pvals_corrected, reject


# ANALYSIS OF DIAGNOSIS
run_posthoc_mannwhitney(all_labels, 'Diagnosis', cluster_labels)

# ANALYSIS OF EDUCATION
run_posthoc_mannwhitney(all_education, 'Education', cluster_labels)

# ANALYSIS OF GENOTYPE
# Integer codes are taken straight from the genotype strings so the test gets
# numeric input whatever `genotype_cat` currently holds.
run_posthoc_mannwhitney(pd.Categorical(genotype_labels).codes, 'Genotype', cluster_labels)

# ANALYSIS OF SEX
run_posthoc_mannwhitney(all_sex, 'Sex', cluster_labels)

### ARI

In [None]:
# Adjusted Rand Index (ARI) between clusters and diagnosis
ari = adjusted_rand_score(all_labels, cluster_labels)
print(f"Adjusted Rand Index (Clusters vs Diagnosis): {ari:.4f}")

# Just for curiosity: the same comparison for the other variables.
# Each score compares the CLUSTER assignment with the variable, as the printed
# label says (comparing against all_labels would measure diagnosis instead).
curiosity_variables = [
    ('Education', all_education),
    ('Genotype', pd.Categorical(genotype_labels).codes),  # integer genotype codes
    ('Age', all_age),
    ('Sex', all_sex),
]
for variable_name, variable_values in curiosity_variables:
    ari = adjusted_rand_score(cluster_labels, np.asarray(variable_values).ravel())
    print(f"Adjusted Rand Index (Clusters vs {variable_name}): {ari:.4f}")

For analysing the above:
Most Common Diagnosis per Cluster:
Cluster 0 → CN
Cluster 1 → AD
Cluster 2 → MCI

### MLR

In [None]:
# Multinomial Logistic Regression for predicting Clusters

df_multi = pd.DataFrame({
    'Cluster': cluster_labels,
    'Diagnosis': all_labels.ravel(),
    'Education': all_education.ravel(),
    'GenotypeIndex': pd.Categorical(genotype_cat).remove_unused_categories(),
    'Age': all_age.ravel(),
    'Sex': all_sex.ravel()
})

# Convert to categorical for regression
df_multi['Diagnosis'] = pd.Categorical(df_multi['Diagnosis'])

# The model needs a purely numeric design matrix, so the categorical
# predictors are expanded into 0/1 indicator columns. Unused categories are
# removed first and the first level of each variable is dropped as its
# reference, which keeps the indicators from being collinear with the constant.
predictor_columns = ['Diagnosis', 'Education', 'GenotypeIndex', 'Age', 'Sex']
categorical_predictors = ['Diagnosis', 'GenotypeIndex']
predictors = df_multi[predictor_columns].copy()
for column_name in categorical_predictors:
    predictors[column_name] = pd.Categorical(predictors[column_name]).remove_unused_categories()
design_matrix = pd.get_dummies(
    predictors, columns=categorical_predictors, drop_first=True, dtype=float
).astype(float)

# Cluster 0 (CN) as baseline
logit_model = sm.MNLogit(df_multi['Cluster'], sm.add_constant(design_matrix))
result = logit_model.fit(disp=False)
print("\nMultinomial Logistic Regression (Cluster ~ Diagnosis + Education + Genotype + Age + Sex):")
print(result.summary())

In [None]:
# Listing cluster 1 first makes it the first outcome category
df_multi['Cluster'] = pd.Categorical(df_multi['Cluster'], categories=[1, 0, 2])  # now Cluster 1 is baseline

# The model needs a purely numeric design matrix, so the categorical
# predictors are expanded into 0/1 indicator columns. Unused categories are
# removed first and the first level of each variable is dropped as its
# reference, which keeps the indicators from being collinear with the constant.
predictor_columns = ['Diagnosis', 'Education', 'GenotypeIndex', 'Age', 'Sex']
categorical_predictors = ['Diagnosis', 'GenotypeIndex']
predictors = df_multi[predictor_columns].copy()
for column_name in categorical_predictors:
    predictors[column_name] = pd.Categorical(predictors[column_name]).remove_unused_categories()
design_matrix = pd.get_dummies(
    predictors, columns=categorical_predictors, drop_first=True, dtype=float
).astype(float)

# Refit MLR with new baseline
logit_model = sm.MNLogit(df_multi['Cluster'], sm.add_constant(design_matrix))
result = logit_model.fit(disp=False)

print("\nMultinomial Logistic Regression (Cluster ~ Diagnosis + Education + Genotype + Age + Sex) [Cluster 1 as baseline]:")
print(result.summary())


### PCA component correlations

In [None]:
# PCA Component Correlations

# Only the leading principal components are correlated with the covariates
leading_pc_count = 10
pca_features = pca_result[:, :leading_pc_count]  # first 10 principal components

def correlate_pcs(pca_features, external_variable, variable_name):
    """Print the Spearman correlation of every PCA column with one variable.

    Args:
        pca_features: 2-D array whose columns are the components to test.
        external_variable: Array of values to correlate against; it is
            flattened with ``ravel()`` before each correlation.
        variable_name: Label used in the printed header.

    Returns:
        None. One header line and one line per component are printed.
    """
    print(f"\nTop PCA Components Correlated with {variable_name}:")
    # Iterating over the transpose walks the columns in order.
    for pc_number, component in enumerate(pca_features.T, start=1):
        r, p = spearmanr(component, external_variable.ravel())
        print(f"PC{pc_number}: Spearman r={r:.4f}, p={p:.2e}")

# Report the component correlations for each external variable in turn.
# NOTE(review): `diagnosis_numeric` is assigned in a cell further down the
# notebook, so that cell has to be executed before this one.
pc_correlation_targets = [
    (diagnosis_numeric, 'Diagnosis'),
    (all_education, 'Education'),
    (all_age, 'Age'),
    (genotype_cat, 'GenotypeIndex'),
    (all_sex, 'Sex'),
]
for external_values, external_name in pc_correlation_targets:
    correlate_pcs(pca_features, external_values, external_name)

### Feature maps

In [None]:

def plot_feature_maps(model, image_tensor, layer, num_maps=8, layer_name=""):
    """Plot the first few activation maps that `layer` produces for an input.

    A forward hook is attached to `layer`, `model` is run once on
    `image_tensor` with gradients disabled, and up to `num_maps` channels of
    the captured output are drawn side by side in one figure.

    Args:
        model: Network to run. It is switched to eval mode and left there.
        image_tensor: Input for `model`; moved to the module-level `device`
            before the forward pass. Assumed to hold a single image with a
            leading batch dimension of 1 (TODO confirm against callers).
        layer: Submodule of `model` whose output is captured.
        num_maps: Maximum number of channels to draw.
        layer_name: Label inserted into the figure title.

    Raises:
        KeyError: If `layer` was not executed during the forward pass, so
            nothing was captured.
    """
    activation = {}

    def hook_fn(module, inputs, output):
        # Keep a detached CPU copy so plotting does not touch the graph/GPU.
        activation['features'] = output.detach().cpu()

    # Register forward hook
    handle = layer.register_forward_hook(hook_fn)
    try:
        # Forward pass
        model.eval()
        with torch.no_grad():
            _ = model(image_tensor.to(device))
    finally:
        # Always detach the hook, even when the forward pass raises; otherwise
        # it would stay on the layer and fire on every later forward call.
        handle.remove()

    # Get feature maps
    features = activation['features'].squeeze(0)  # shape: (C, H, W)
    num_maps = min(num_maps, features.shape[0])

    # Plot
    plt.figure(figsize=(15, 5))
    for i in range(num_maps):
        plt.subplot(1, num_maps, i + 1)
        plt.imshow(features[i], cmap='viridis')
        plt.title(f'Map {i}')
        plt.axis('off')
    plt.suptitle(f'Feature Maps (Layer {layer_name})')
    plt.tight_layout()
    plt.show()

# RUN EXAMPLE
# Draw one image at random: a shuffled loader with batch size 1 yields a
# different sample each time this cell is run.
sample_loader = DataLoader(full_dataset, shuffle=True, batch_size=1)
sample_img_batch = next(iter(sample_loader))
# Element 0 of the batch is the image; the remaining elements are the label
# and the extra covariates.
sample_img = sample_img_batch[0].to(device)

# Key layers of ResNet-18: the stem convolution, then the first convolution
# of the first block in each of layer1..layer4.
layers_to_plot = [(model.conv1, "0")]
layers_to_plot += [
    (getattr(model, f"layer{stage}")[0].conv1, str(stage))
    for stage in range(1, 5)
]

for layer, name in layers_to_plot:
    plot_feature_maps(model, sample_img, layer, layer_name=name, num_maps=6)


# Other stuff

## Overall cluster information and different sanity checks

In [None]:
# Sanity check: how are the diagnoses distributed over the clusters?

df_cluster_diagnosis = pd.DataFrame(
    {'Cluster': cluster_labels, 'Diagnosis': diagnosis_labels}
)

# Rows = clusters, columns = diagnoses, cells = number of samples
crosstab = pd.crosstab(
    df_cluster_diagnosis['Cluster'],
    df_cluster_diagnosis['Diagnosis'],
)

# Divide every row by its total to express the counts as row percentages
cluster_sizes = crosstab.sum(axis=1)
crosstab_percent = crosstab.div(cluster_sizes, axis=0) * 100

# Report both tables
print("Diagnosis Counts per Cluster:")
print(crosstab)
print("\nDiagnosis Percentages per Cluster:")
print(crosstab_percent.round(2))


In [None]:
# Invert the dataset's label map so it can be looked up by integer index
label_map_reverse = {
    index: diagnosis for diagnosis, index in full_dataset.label_map.items()
}
print("Label Map:", label_map_reverse)

# Translate every diagnosis label into its integer code via `label_map`
diagnosis_numeric = np.array(
    [label_map[diagnosis] for diagnosis in diagnosis_labels]
)

# WARNING: the integer codes in the label map are NOT in the same order as the
# cluster ids. Do not read a cluster id as a diagnosis code.

In [None]:
df_mapping = pd.DataFrame({
    'Cluster': cluster_labels,
    'Diagnosis': diagnosis_numeric.ravel()
})

# Count how many of each diagnosis per cluster
counts = pd.crosstab(df_mapping['Cluster'], df_mapping['Diagnosis'])

# Find most common diagnosis per cluster
dominant_labels = counts.idxmax(axis=1)  # gives the diagnosis index

# `label_map` goes name -> index, so converting an index back to a name needs
# the inverted map. Mapping the indices through `label_map` itself finds no
# matching keys and yields NaN for every cluster.
label_map_reverse = {v: k for k, v in label_map.items()}
dominant_labels_named = dominant_labels.map(label_map_reverse)  # convert to names

# Print
print("Most Common Diagnosis per Cluster:\n")
for cluster_id, diag_index in dominant_labels.items():
    label_name = label_map_reverse[int(diag_index)]
    print(f"Cluster {cluster_id} → {label_name}")


In [None]:
# Numeric sex code -> letter (0 is 'F', 1 is 'M').
# NOTE(review): presumably matches the encoding used in `all_sex` — verify.
sex_map = dict(enumerate(('F', 'M')))

## ANCOVA (not used anymore)

In [None]:
# ANCOVA

# Flatten just in case
education_flat = all_education.ravel()
cluster_labels_flat = np.array(cluster_labels).ravel()
diagnosis_flat = np.array(all_labels).ravel()
age_flat = np.array(all_age).ravel()
sex_flat = np.array(all_sex).ravel()

df_ancova = pd.DataFrame({
    'Education': education_flat,
    'Cluster': cluster_labels_flat.astype(str),
    'Diagnosis': diagnosis_flat.astype(str),
    'Age': age_flat,
    'GenotypeIndex': genotype_strength,
    'Sex': sex_flat
})

# Each entry is (formula, printed title). Every model asks whether the outcome
# differs by cluster after accounting for the listed covariates.
ancova_specs = [
    # ... after accounting for diagnosis
    ('Education ~ C(Cluster) + C(Diagnosis)',
     "ANCOVA (Education ~ Cluster + Diagnosis):"),
    ('GenotypeIndex ~ C(Cluster) + C(Diagnosis)',
     "ANCOVA (GenotypeIndex ~ Cluster + Diagnosis):"),
    ('Age ~ C(Cluster) + C(Diagnosis)',
     "ANCOVA (Age ~ Cluster + Diagnosis):"),
    ('Sex ~ C(Cluster) + C(Diagnosis)',
     "ANCOVA (Sex ~ Cluster + Diagnosis):"),
    # ... after accounting for age
    ('Education ~ C(Cluster) + Age',
     "ANCOVA (Education ~ Cluster + Age):"),
    ('GenotypeIndex ~ C(Cluster) + Age',
     "ANCOVA (GenotypeIndex ~ Cluster + Age):"),
    # ... after accounting for sex
    ('Education ~ C(Cluster) + Sex',
     "ANCOVA (Education ~ Cluster + Sex):"),
    ('GenotypeIndex ~ C(Cluster) + Sex',
     "ANCOVA (GenotypeIndex ~ Cluster + Sex):"),
    # ... after accounting for age + diagnosis + sex
    ('Education ~ C(Cluster) + Age + C(Diagnosis) + Sex',
     "ANCOVA (Education ~ Cluster + Age + Diagnosis + Sex):"),
    ('GenotypeIndex ~ C(Cluster) + Age + C(Diagnosis) + Sex',
     "ANCOVA (GenotypeIndex ~ Cluster + Age + Diagnosis + Sex):"),
]

for ancova_formula, ancova_title in ancova_specs:
    # Use a dedicated name for the fit: the global `model` is the CNN that the
    # feature-map cell relies on, and rebinding it here would break that cell
    # when it is re-run afterwards.
    ancova_fit = ols(ancova_formula, data=df_ancova).fit()
    anova_table = sm.stats.anova_lm(ancova_fit, typ=2)
    print(ancova_title)
    print(anova_table.round(4))

## Assumption checks for MLR

### Linearity

In [None]:
# Linearity check: plot each predictor against the fitted class probabilities

# Predictors to check, paired with the arrays that supply their values
predictors = ['Diagnosis', 'Education', 'GenotypeIndex', 'Age', 'Sex']
predictor_arrays = (all_labels, all_education, genotype_cat, all_age, all_sex)

X = pd.DataFrame({
    column: values.ravel()
    for column, values in zip(predictors, predictor_arrays)
})
y = pd.Series(cluster_labels, name='Cluster')

# Earlier experiments, kept for reference (disabled):
#X['GenotypeIndex'] = X['GenotypeIndex'].astype('category') # to use categorical in MLR
#X['Age_log'] = np.log(X['Age'] + 1)
#X['Genotype_log'] = np.log(X['GenotypeIndex'] + 1)

X_model = sm.add_constant(X[predictors])

# Fit the MLR quietly and get the predicted probability of every cluster
mlr_model = sm.MNLogit(y, X_model).fit(disp=False)
probs = mlr_model.predict(X_model)
probs.columns = ['P(Cluster=0)', 'P(Cluster=1)', 'P(Cluster=2)']

# One figure per predictor, one panel per cluster: raw scatter plus a LOWESS
# trend line in red
for predictor in predictors:
    plt.figure(figsize=(14, 4))

    for i, prob_column in enumerate(probs.columns):
        cluster_prob = probs[prob_column]
        plt.subplot(1, 3, i + 1)
        sns.scatterplot(x=X[predictor], y=cluster_prob, alpha=0.4)
        sns.regplot(x=X[predictor], y=cluster_prob, scatter=False, color='red', lowess=True)
        plt.title(f'{predictor} vs P(Cluster={i})')
        plt.xlabel(predictor)
        plt.ylabel(f'P(Cluster={i})')

    plt.tight_layout()
    plt.show()


### Residuals

In [None]:

# Design matrix: the five predictors plus a constant column
X = add_constant(pd.DataFrame({
    'Diagnosis': all_labels.ravel(),
    'Education': all_education.ravel(),
    'GenotypeIndex': genotype_cat.ravel(),
    'Age': all_age.ravel(),
    'Sex': all_sex.ravel()
}))
y = cluster_labels.ravel()

# Fit MLR and predict the class probabilities for every sample
mlr_model = MNLogit(y, X).fit()
predicted_probs = mlr_model.predict(X)

# Raw residual per sample: 1 minus the probability predicted for the sample's
# observed class (the class label is used as the column position)
residuals = np.array([
    1 - predicted_probs.iloc[row, observed_class]
    for row, observed_class in enumerate(y)
])

# Histogram of the residuals with a KDE overlay
plt.figure(figsize=(6, 4))
sns.histplot(residuals, kde=True, color='steelblue')
plt.title("Histogram of Residuals")
plt.xlabel("Residual")
plt.ylabel("Count")
plt.tight_layout()
plt.show()


In [None]:
# Q–Q plot of the residuals against a normal distribution
qq_fig, qq_ax = plt.subplots(figsize=(6, 4))
stats.probplot(residuals, dist="norm", plot=qq_ax)
qq_ax.set_title("Q–Q Plot of Residuals")
qq_fig.tight_layout()
plt.show()


In [None]:
# Kolmogorov–Smirnov test of the residuals against a normal distribution whose
# mean and standard deviation are taken from the residuals themselves.
# NOTE(review): the original remark here said residuals are "not required, for
# this is not a classification task"; presumably it meant that normally
# distributed residuals are not required for this model — confirm.
residual_mean = np.mean(residuals)
residual_std = np.std(residuals)
ks_stat, ks_p = stats.kstest(residuals, 'norm', args=(residual_mean, residual_std))
print(f"Kolmogorov–Smirnov Test: D = {ks_stat:.4f}, p = {ks_p:.4e}")

ks_verdict = (
    "Residuals appear to follow a normal distribution (fail to reject H0)"
    if ks_p > 0.05
    else "Residuals deviate from normal distribution (reject H0)"
)
print(ks_verdict)


### No multicollinearity

In [None]:

# Predictors only (no outcome, no constant)
vif_predictors = {
    'Diagnosis': all_labels,
    'Education': all_education,
    'GenotypeIndex': genotype_cat,
    'Age': all_age,
    'Sex': all_sex,
}
X = pd.DataFrame({
    column: values.ravel() for column, values in vif_predictors.items()
})

# Pairwise correlations between the predictors, shown as an annotated heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(X.corr(), annot=True, cmap="coolwarm", fmt=".2f")
plt.title("Correlation Matrix of Predictors")
plt.tight_layout()
plt.show()

# Variance Inflation Factor (VIF): computed on the predictors plus a constant
# column named 'Intercept'
X_with_const = X.assign(Intercept=1)
vif_design = X_with_const.values

vif_data = pd.DataFrame({
    'Feature': X_with_const.columns,
    'VIF': [
        variance_inflation_factor(vif_design, column_idx)
        for column_idx in range(vif_design.shape[1])
    ],
})

# The intercept's own VIF is not of interest, so leave it out of the report
vif_data = vif_data.loc[vif_data['Feature'] != 'Intercept']

print("\nVariance Inflation Factor (VIF):")
print(vif_data)
