# Group 18 — Dementia Prediction Project
## Step 2: Complete EDA + Export Results

> **Prerequisites:** Run Step 1 first to generate `dementia_clean.csv`; the raw file `OPTIMAL_combined_3studies_6feb2020.csv` is also required (Figure 1 reads it).  
> **Dataset:** [Dementia Dataset on Kaggle](https://www.kaggle.com/datasets/fatemehmehrparvar/dementia)

In [None]:
# Uncomment if running in Colab:
# !pip install pandas numpy matplotlib seaborn scipy

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy import stats
import warnings
# Silence every warning category for the remainder of the session so the
# notebook output stays readable.
# NOTE(review): this is a blanket filter — any pandas / scipy / matplotlib
# warning raised later is hidden too; narrow it to specific categories if
# diagnostics are needed.
warnings.simplefilter('ignore')

print("✓ Libraries loaded successfully")

## 0. Load Data

In [None]:
# Load both inputs. Colab users: upload the two CSV files below before running.
#   - the raw combined dataset feeds the Figure 1 class-balance chart and the
#     shape printout;
#   - the cleaned file from Step 1 drives the remaining figures and tables.
RAW_CSV = 'OPTIMAL_combined_3studies_6feb2020.csv'
CLEAN_CSV = 'dementia_clean.csv'
df_raw = pd.read_csv(RAW_CSV)
df = pd.read_csv(CLEAN_CSV)

# Shared palette and legend text, indexed by the dementia code (0, 1).
colors = ['#4C72B0', '#DD8452']
labels = ['No Dementia', 'Dementia']

for caption, frame in (("Raw data shape:   ", df_raw), ("Clean data shape: ", df)):
    print(f"{caption}{frame.shape}")
df.head()

## Figure 1 — Target Variable & Continuous Feature Distributions

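Kernel density estimate (KDE) curves, each normalised to unit area, are used for the continuous variables so the two groups can be compared despite the class imbalance (~95% No Dementia vs ~5% Dementia). Dashed vertical lines mark the group means. The class-distribution bar chart is drawn from the raw dataset; all other panels use the cleaned data.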

In [None]:
# Continuous features profiled below: (column name, display title).
# The outlier figure and the exported statistics tables reuse this list.
cont_vars = [
    ('age',            'Age'),
    ('EF',             'Executive Function (EF)'),
    ('PS',             'Processing Speed (PS)'),
    ('Global',         'Global Cognition'),
    ('educationyears', 'Education (Years)'),
]

fig1, axes = plt.subplots(2, 3, figsize=(16, 10))
fig1.suptitle('Figure 1 — Target Variable & Continuous Features Distribution',
              fontsize=14, fontweight='bold', y=1.01)

# ── Class distribution bar chart (raw dataset) ───────────────
ax = axes[0, 0]
counts = df_raw['dementia'].value_counts().sort_index()
total = counts.sum()
# Name and colour each bar from its actual class code instead of assuming
# both classes are present in a fixed order.
bar_names  = [labels[int(code)] for code in counts.index]
bar_colors = [colors[int(code)] for code in counts.index]
bars = ax.bar(bar_names, counts.values,
              color=bar_colors, edgecolor='white', linewidth=1.5)
# Label offset and y-limit scale with the tallest bar, so the two-line
# annotations always fit regardless of dataset size.
label_pad = counts.max() * 0.01
for bar, val in zip(bars, counts.values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_pad,
            f'n={val}\n({val/total*100:.1f}%)',
            ha='center', fontsize=10)
ax.set_title('Dementia Class Distribution', fontweight='bold')
ax.set_ylabel('Count')
ax.set_ylim(0, counts.max() * 1.2)

# ── KDE plots for continuous variables (clean dataset) ───────
ax_positions = [(0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
for (var, title), (r, c) in zip(cont_vars, ax_positions):
    ax = axes[r, c]
    for i, (label, color) in enumerate(zip(labels, colors)):
        subset = df[df['dementia'] == i][var].dropna()
        # A Gaussian KDE cannot be fitted to fewer than two distinct values.
        if subset.nunique() < 2:
            continue
        subset.plot.kde(ax=ax, color=color, label=f'{label} (n={len(subset)})', linewidth=2)
        # Dashed line marks the group mean.
        ax.axvline(subset.mean(), color=color, linestyle='--', alpha=0.6, linewidth=1)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel('Density')
    ax.legend(fontsize=8)
    ax.set_xlabel(var)

plt.tight_layout()
plt.savefig('EDA_fig1_distributions.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig1_distributions.png")

## Figure 2 — Outlier Detection

Box plots grouped by dementia status. Outlier counts (values beyond 1.5 × IQR from each group's own quartiles) are labelled in red at the bottom of each panel.

In [None]:
fig2, axes = plt.subplots(1, 5, figsize=(18, 6))
fig2.suptitle('Figure 2 — Outlier Detection (Box Plots by Dementia Status)',
              fontsize=14, fontweight='bold')


def _tukey_outlier_count(values):
    """Return how many of *values* fall outside [Q1 - 1.5·IQR, Q3 + 1.5·IQR].

    The quartiles are taken from *values* itself, so each dementia group is
    judged against its own spread.
    """
    lower_q = values.quantile(0.25)
    upper_q = values.quantile(0.75)
    spread = upper_q - lower_q
    is_outlier = (values < lower_q - 1.5*spread) | (values > upper_q + 1.5*spread)
    return len(values[is_outlier])


for panel, (column, heading) in zip(axes, cont_vars):
    # One series per dementia code: code 0 → first box, code 1 → second box.
    groups = [df.loc[df['dementia'] == code, column].dropna() for code in (0, 1)]
    box = panel.boxplot(groups, patch_artist=True,
                        medianprops=dict(color='black', linewidth=2))
    for face, shade in zip(box['boxes'], colors):
        face.set_facecolor(shade)
        face.set_alpha(0.7)
    panel.set_xticklabels(['No\nDementia', 'Dementia'])
    panel.set_title(heading, fontweight='bold', fontsize=10)
    panel.set_ylabel(column)
    # Boxes sit at x = 1 and x = 2; write each group's count at the axis floor.
    for x_pos, group in enumerate(groups, start=1):
        panel.text(x_pos, panel.get_ylim()[0], f'{_tukey_outlier_count(group)} outliers',
                   ha='center', fontsize=8, color='red')

plt.tight_layout()
plt.savefig('EDA_fig2_outliers.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig2_outliers.png")

## Figure 3 — Categorical Features vs Dementia Status

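For each category level, the bars show the percentage of participants with and without dementia (row-normalised), which allows direct comparison across levels despite the class imbalance. The chi-squared p-value and significance stars for each feature are shown in its subplot title.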

In [None]:
# Categorical features: (column, {code: tick label}, panel title).
# The chi-squared export table below reuses this list.
cat_info = [
    ('hypertension',         {0:'No',    1:'Yes'},                   'Hypertension'),
    ('diabetes',             {0:'No',    1:'Yes'},                    'Diabetes'),
    ('hypercholesterolemia', {0:'No',    1:'Yes'},                    'Hypercholesterolaemia'),
    ('smoking',              {0:'Never', 1:'Ex',    2:'Current'},     'Smoking Status'),
    ('gender',               {0:'Male',  1:'Female'},                 'Gender'),
    ('Fazekas',              {0:'0',     1:'1',     2:'2',  3:'3'},   'Fazekas Score'),
    ('lac_count',            {0:'Zero',  1:'1-2',   2:'3-5', 3:'>5'},'Lacune Count'),
    ('CMB_count',            {0:'None',  1:'\u22651'},               'CMB Count'),
]


def _format_p(p):
    """Return a p-value label; values below 0.001 read 'p<0.001', not 'p=0.000'."""
    return 'p<0.001' if p < 0.001 else f'p={p:.3f}'


fig3, axes = plt.subplots(2, 4, figsize=(20, 10))
fig3.suptitle('Figure 3 — Categorical Features vs Dementia Status (%)',
              fontsize=14, fontweight='bold')

for ax, (var, tick_map, title) in zip(axes.flat, cat_info):
    # Rows: category levels of `var`; columns: dementia codes.
    ct = pd.crosstab(df[var], df['dementia'])
    # Row-normalise: % of each category level with / without dementia.
    ct_pct = ct.div(ct.sum(axis=1), axis=0) * 100
    # The test runs on raw counts; scipy applies Yates' correction when dof == 1.
    chi2, p, _, _ = stats.chi2_contingency(ct)
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'ns'
    ct_pct.plot(kind='bar', ax=ax, color=colors, edgecolor='white', linewidth=1, legend=False)
    ax.set_title(f'{title}\n({_format_p(p)} {sig})', fontweight='bold', fontsize=10)
    ax.set_ylabel('% within group')
    ax.set_xticklabels([tick_map.get(x, x) for x in ct_pct.index], rotation=15, ha='right')
    ax.set_ylim(0, 115)
    for container in ax.containers:
        ax.bar_label(container, fmt='%.1f%%', fontsize=7, padding=2)

# Hide any grid cells left over if cat_info has fewer entries than axes.
for spare_ax in axes.ravel()[len(cat_info):]:
    spare_ax.set_visible(False)

handles = [plt.Rectangle((0,0),1,1, color=c) for c in colors]
fig3.legend(handles, labels, loc='upper right', fontsize=10)
plt.tight_layout()
plt.savefig('EDA_fig3_categorical.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig3_categorical.png")

## Figure 4 — Feature Relationships

Three panels: (1) Age vs EF scatter with trend lines, (2) EF vs PS scatter, (3) Fazekas score vs Global Cognition boxplots — all coloured by dementia status.

In [None]:
fig4, axes = plt.subplots(1, 3, figsize=(18, 6))
fig4.suptitle('Figure 4 — Feature Relationships', fontsize=14, fontweight='bold')

# ── Panel 1: Age vs EF with a least-squares trend line per group ──
ax = axes[0]
for i, (label, color) in enumerate(zip(labels, colors)):
    subset = df[df['dementia'] == i]
    ax.scatter(subset['age'], subset['EF'], alpha=0.4, color=color, label=label, s=20)
    # Fit only on complete (age, EF) pairs; a line needs at least two distinct ages.
    fit = subset.dropna(subset=['age', 'EF'])
    if fit['age'].nunique() >= 2:
        p_line = np.poly1d(np.polyfit(fit['age'], fit['EF'], 1))
        x_line = np.linspace(fit['age'].min(), fit['age'].max(), 100)
        ax.plot(x_line, p_line(x_line), color=color, linewidth=2, linestyle='--')
ax.set_xlabel('Age')
ax.set_ylabel('Executive Function (EF)')
ax.set_title('Age vs EF by Dementia Status', fontweight='bold')
ax.legend()

# ── Panel 2: EF vs PS ─────────────────────────────────────────
ax = axes[1]
for i, (label, color) in enumerate(zip(labels, colors)):
    subset = df[df['dementia'] == i]
    ax.scatter(subset['EF'], subset['PS'], alpha=0.4, color=color, label=label, s=20)
ax.set_xlabel('Executive Function (EF)')
ax.set_ylabel('Processing Speed (PS)')
ax.set_title('EF vs PS by Dementia Status', fontweight='bold')
ax.legend()

# ── Panel 3: Fazekas vs Global Cognition ─────────────────────
ax = axes[2]
fazekas_vals = sorted(df['Fazekas'].dropna().unique())
positions_offset = [-0.2, 0.2]   # place the two groups' boxes side by side
for i, (label, color) in enumerate(zip(labels, colors)):
    subset = df[df['dementia'] == i]
    data_by_fazekas = [subset[subset['Fazekas'] == f]['Global'].dropna() for f in fazekas_vals]
    positions = [f + positions_offset[i] for f in fazekas_vals]
    bp = ax.boxplot(data_by_fazekas, positions=positions, widths=0.35, patch_artist=True,
                    medianprops=dict(color='black', linewidth=1.5))
    for patch in bp['boxes']:
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
ax.set_xlabel('Fazekas Score')
ax.set_ylabel('Global Cognition Score')
ax.set_title('Fazekas vs Global Cognition\nby Dementia Status', fontweight='bold')
ax.set_xticks(fazekas_vals)
# ':g' drops a trailing '.0' so float-typed scores label as '1', not '1.0'.
ax.set_xticklabels([f'{v:g}' for v in fazekas_vals])
handles = [plt.Rectangle((0,0),1,1, color=c, alpha=0.7) for c in colors]
ax.legend(handles, labels)

plt.tight_layout()
plt.savefig('EDA_fig4_relationships.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig4_relationships.png")

## Figure 5 — Cohort Comparison

Checks consistency across the three studies (ASPS, rundmc, scans) by comparing dementia rates, age distributions, EF distributions, comorbidity rates (hypertension, diabetes) and sample sizes.

In [None]:
# Reconstruct a single study label from the one-hot study columns
# (rows with neither flag set are labelled ASPS). Later cells read
# df['study_label'], so it is stored on df.
df['study_label'] = 'ASPS'
df.loc[df['study1_rundmc'] == 1, 'study_label'] = 'rundmc'
df.loc[df['study1_scans']  == 1, 'study_label'] = 'scans'
study_colors = {'ASPS': '#4C72B0', 'rundmc': '#DD8452', 'scans': '#2ca02c'}
# One fixed display order for every panel; unique(), groupby() and
# value_counts() would otherwise each order the studies differently.
present_studies = set(df['study_label'])
study_list = [s for s in study_colors if s in present_studies]


def _bar_by_study(ax, values, fmt, pad, title, ylabel):
    """Draw one annotated bar per study, in `study_list` order.

    values : Series indexed by study label.
    fmt    : str.format pattern for the annotation above each bar.
    pad    : vertical offset (data units) of the annotation.
    """
    values = values.reindex(study_list)
    bars = ax.bar(values.index, values.values,
                  color=[study_colors[s] for s in values.index], edgecolor='white')
    for bar, val in zip(bars, values.values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + pad,
                fmt.format(val), ha='center', fontsize=10)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel(ylabel)


fig5, axes = plt.subplots(2, 3, figsize=(16, 10))
fig5.suptitle('Figure 5 — Cohort Comparison (study1)\nChecking consistency across studies',
              fontsize=14, fontweight='bold')

# Percentage-rate panels: (axes position, column, label pad, title, y-label).
rate_panels = [
    ((0, 0), 'dementia',     0.2, 'Dementia Rate by Study',     'Dementia Rate (%)'),
    ((1, 0), 'hypertension', 0.5, 'Hypertension Rate by Study', 'Hypertension Rate (%)'),
    ((1, 1), 'diabetes',     0.2, 'Diabetes Rate by Study',     'Diabetes Rate (%)'),
]
for pos, col, pad, title, ylabel in rate_panels:
    rate = df.groupby('study_label')[col].mean() * 100
    _bar_by_study(axes[pos], rate, '{:.1f}%', pad, title, ylabel)

# Density panels: (axes position, column, title, x-label).
kde_panels = [
    ((0, 1), 'age', 'Age Distribution by Study', 'Age'),
    ((0, 2), 'EF',  'EF Distribution by Study',  'EF Score'),
]
for pos, col, title, xlabel in kde_panels:
    ax = axes[pos]
    for study in study_list:
        values = df.loc[df['study_label'] == study, col].dropna()
        # A Gaussian KDE cannot be fitted to fewer than two distinct values.
        if values.nunique() >= 2:
            values.plot.kde(ax=ax, label=study, color=study_colors[study], linewidth=2)
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel(xlabel); ax.set_ylabel('Density'); ax.legend()

# Sample size
_bar_by_study(axes[1, 2], df['study_label'].value_counts(), 'n={}', 5,
              'Sample Size by Study', 'Count')

plt.tight_layout()
plt.savefig('EDA_fig5_cohorts.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig5_cohorts.png")

## Figure 6 — Correlation Matrix

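Lower-triangle heatmap of pairwise Pearson correlations across all key features (binary and ordinal variables such as smoking status and Fazekas score enter as their numeric codes).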

In [None]:
# Features entering the heatmap (outcome first).
corr_vars = ['dementia', 'age', 'educationyears', 'EF', 'PS', 'Global',
             'diabetes', 'hypertension', 'hypercholesterolemia',
             'smoking', 'Fazekas', 'lac_count', 'CMB_count']

# Pairwise Pearson correlation (pandas default method).
corr_matrix = df[corr_vars].corr()
# The matrix is symmetric: hide the diagonal and upper triangle.
mask = np.triu(np.ones(corr_matrix.shape, dtype=bool))

fig6, ax = plt.subplots(figsize=(12, 10))
heatmap_style = dict(annot=True, fmt='.2f',
                     cmap='RdBu_r', center=0, vmin=-1, vmax=1,
                     square=True, linewidths=0.5, cbar_kws={'shrink': 0.8})
sns.heatmap(corr_matrix, mask=mask, ax=ax, **heatmap_style)
ax.set_title('Figure 6 — Correlation Matrix (All Key Features)',
             fontsize=14, fontweight='bold', pad=15)

plt.tight_layout()
plt.savefig('EDA_fig6_correlation.png', dpi=150, bbox_inches='tight', facecolor='white')
plt.show()
print("✓ Saved: EDA_fig6_correlation.png")

## Export Statistical Results to CSV

Five CSV files are generated:

| File | Contents |
|------|----------|
| `EDA_1_numerical_stats.csv` | t-test results for continuous variables |
| `EDA_2_outlier_summary.csv` | IQR-based outlier counts per group |
| `EDA_3_categorical_stats.csv` | Chi-squared results for categorical variables |
| `EDA_4_cohort_comparison.csv` | Key metrics broken down by study |
| `EDA_5_correlation_matrix.csv` | Full correlation matrix |

In [None]:
# ── 1. Numerical statistics (Welch's t-test) ─────────────────
# Welch's test (equal_var=False) is used because the groups are highly
# unbalanced (~95% vs ~5%); Student's pooled-variance t-test is not robust
# to unequal group variances when sample sizes differ this much.
num_rows = []
for var, title in cont_vars:
    g0 = df[df['dementia'] == 0][var].dropna()
    g1 = df[df['dementia'] == 1][var].dropna()
    t_stat, p = stats.ttest_ind(g0, g1, equal_var=False)
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'ns'
    num_rows.append({
        'Variable':         title,
        'No_Dementia_Mean': round(g0.mean(), 3),
        'No_Dementia_SD':   round(g0.std(),  3),
        'No_Dementia_N':    len(g0),
        'Dementia_Mean':    round(g1.mean(), 3),
        'Dementia_SD':      round(g1.std(),  3),
        'Dementia_N':       len(g1),
        't_statistic':      round(t_stat, 3),
        # 3 significant figures: very small p-values stay non-zero in the CSV.
        'p_value':          float(f'{p:.3g}'),
        'Significance':     sig,
    })
df_num = pd.DataFrame(num_rows)
df_num.to_csv('EDA_1_numerical_stats.csv', index=False)
print("✓ EDA_1_numerical_stats.csv")
df_num

In [None]:
# ── 2. Outlier summary (per-group 1.5 × IQR rule) ────────────
# Quartiles and IQR are computed within each dementia group, so each row's
# Q1/Q3/IQR describe that group and the counts match the red outlier labels
# in Figure 2 (which also uses each group's own quartiles).
out_rows = []
for var, title in cont_vars:
    for grp, grp_label in [(0, 'No Dementia'), (1, 'Dementia')]:
        g = df[df['dementia'] == grp][var].dropna()
        q1, q3 = g.quantile(0.25), g.quantile(0.75)
        iqr = q3 - q1
        outliers = g[(g < q1 - 1.5*iqr) | (g > q3 + 1.5*iqr)]
        out_rows.append({
            'Variable':     title,
            'Group':        grp_label,
            'Q1':           round(q1,  3),
            'Q3':           round(q3,  3),
            'IQR':          round(iqr, 3),
            'N_Outliers':   len(outliers),
            # Guard against an empty group instead of dividing by zero.
            'Pct_Outliers': round(len(outliers) / len(g) * 100, 1) if len(g) else float('nan'),
        })
df_out = pd.DataFrame(out_rows)
df_out.to_csv('EDA_2_outlier_summary.csv', index=False)
print("✓ EDA_2_outlier_summary.csv")
df_out

In [None]:
# ── 3. Categorical statistics (chi-squared) ──────────────────
cat_rows = []
for var, tick_map, title in cat_info:
    # Rows: category levels of `var`; columns: dementia codes.
    ct = pd.crosstab(df[var], df['dementia'])
    # Row-normalised percentages within each category level.
    ct_pct = ct.div(ct.sum(axis=1), axis=0) * 100
    # scipy applies Yates' continuity correction when dof == 1 (2×2 tables).
    chi2_val, p, dof, _ = stats.chi2_contingency(ct)
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 else 'ns'
    # 3 significant figures: very small p-values stay non-zero in the CSV.
    p_out = float(f'{p:.3g}')
    for cat_val in sorted(ct.index):
        cat_rows.append({
            'Variable':         title,
            'Category':         tick_map.get(cat_val, str(cat_val)),
            'N_Total':          int(ct.loc[cat_val].sum()),
            'Pct_No_Dementia':  round(ct_pct.loc[cat_val, 0.0], 1),
            'Pct_Dementia':     round(ct_pct.loc[cat_val, 1.0], 1),
            'Chi2':             round(chi2_val, 3),
            'Chi2_dof':         int(dof),
            'Chi2_p_value':     p_out,
            'Significance':     sig,
        })
df_cat = pd.DataFrame(cat_rows)
df_cat.to_csv('EDA_3_categorical_stats.csv', index=False)
print("✓ EDA_3_categorical_stats.csv")
df_cat

In [None]:
# ── 4. Cohort comparison ─────────────────────────────────────
def _cohort_summary(study):
    """Build one summary row for `study`: size, % rates and mean age / EF."""
    members = df.loc[df['study_label'] == study]
    return {
        'Study':                 study,
        'N':                     len(members),
        'Dementia_Rate_pct':     round(members['dementia'].mean() * 100, 1),
        'Mean_Age':              round(members['age'].mean(), 1),
        'Hypertension_Rate_pct': round(members['hypertension'].mean() * 100, 1),
        'Diabetes_Rate_pct':     round(members['diabetes'].mean() * 100, 1),
        'Mean_EF':               round(members['EF'].mean(), 2),
    }


cohort_rows = [_cohort_summary(name) for name in ('ASPS', 'rundmc', 'scans')]
df_cohort = pd.DataFrame(cohort_rows)
df_cohort.to_csv('EDA_4_cohort_comparison.csv', index=False)
print("✓ EDA_4_cohort_comparison.csv")
df_cohort

In [None]:
# ── 5. Correlation matrix ─────────────────────────────────────
# Same feature list as the Figure 6 heatmap, exported in full (both triangles)
# and rounded to 3 decimals.
corr_out_vars = ['dementia', 'age', 'educationyears', 'EF', 'PS', 'Global',
                 'diabetes', 'hypertension', 'hypercholesterolemia',
                 'smoking', 'Fazekas', 'lac_count', 'CMB_count']
corr_out = df.loc[:, corr_out_vars].corr(method='pearson').round(decimals=3)
corr_out.to_csv('EDA_5_correlation_matrix.csv')
print("✓ EDA_5_correlation_matrix.csv")
corr_out

## Summary

| Output | Description |
|--------|-------------|
| `EDA_fig1_distributions.png` | Class balance + KDE plots |
| `EDA_fig2_outliers.png` | Box plots with outlier counts |
| `EDA_fig3_categorical.png` | % bar charts for categorical features |
| `EDA_fig4_relationships.png` | Scatter & box plots between features |
| `EDA_fig5_cohorts.png` | Cross-study comparison |
| `EDA_fig6_correlation.png` | Correlation heatmap |
| `EDA_1_numerical_stats.csv` | t-test results |
| `EDA_2_outlier_summary.csv` | Outlier summary |
| `EDA_3_categorical_stats.csv` | Chi-squared results |
| `EDA_4_cohort_comparison.csv` | Cohort breakdown |
| `EDA_5_correlation_matrix.csv` | Full correlation matrix |

**Next step:** Run Step 3 modelling notebook.