# NeuroSmriti - Model Testing & Evaluation

This notebook tests the trained Alzheimer's detection models and provides comprehensive evaluation metrics.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import json
import os
from datetime import datetime
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_auc_score, roc_curve,
    precision_recall_curve, average_precision_score
)
import warnings
warnings.filterwarnings('ignore')

print("Libraries loaded!")

## 1. Load Models and Test Data

In [None]:
# Load every trained model that is present on disk, plus the preprocessing
# artifacts (scaler, label encoder, feature list) the evaluation cells rely on.
# NOTE(review): joblib.load unpickles arbitrary Python objects — only point
# these paths at artifacts produced by a trusted training run.
models = {}
model_files = [
    ('Random Forest', '../models/random_forest_model.pkl'),
    ('Gradient Boosting', '../models/gradient_boosting_model.pkl'),
    ('Neural Network', '../models/neural_network_model.pkl'),
    ('Ensemble', '../models/ensemble_model.pkl')
]

# Missing model files are reported and skipped rather than treated as fatal
for name, path in model_files:
    if os.path.exists(path):
        models[name] = joblib.load(path)
        print(f"Loaded: {name}")
    else:
        print(f"Not found: {path}")

# Fail fast with a clear message: with no models there is nothing to score,
# and selecting the best model later would raise on an empty collection.
if not models:
    raise FileNotFoundError(
        "No trained model files were found under ../models/; "
        "train the models before running this notebook."
    )

# Load preprocessors
scaler = joblib.load('../models/scaler.pkl')
label_encoder = joblib.load('../models/label_encoder.pkl')

# Ordered list of feature column names used to build the feature matrix
with open('../models/feature_list.json', 'r') as f:
    feature_list = json.load(f)

print(f"\nLoaded {len(models)} models")
print(f"Features: {len(feature_list)}")
print(f"Classes: {label_encoder.classes_}")

In [None]:
# Read the complete dataset; the evaluation split is carved out of it below
df = pd.read_csv('../data/alzheimers_420k_dataset.csv')

# Feature preparation (intended to mirror the training-time preparation):
# cast each indicator column that exists in the data to an integer dtype
binary_cols = ['has_apoe4', 'family_history_ad', 'amyloid_positive', 'tau_positive',
               'hypertension', 'diabetes', 'depression']
for col in (c for c in binary_cols if c in df.columns):
    df[col] = df[col].astype(int)

# Indicator column: 1 where gender equals 'Female', 0 otherwise
df['gender_female'] = df['gender'].eq('Female').astype(int)

# Hold out the final 20% of rows as the test set
# NOTE(review): assumes the models were not trained on these trailing rows —
# confirm against the split used at training time
test_size = int(0.2 * len(df))
test_df = df.tail(test_size)

# Feature matrix (missing values replaced by 0) and the raw target labels
features_filled = test_df[feature_list].fillna(0)
X_test = features_filled.values
y_test = test_df['diagnosis_stage'].values

# Apply the already-fitted scaler and label encoder loaded earlier
X_test_scaled = scaler.transform(X_test)
y_test_encoded = label_encoder.transform(y_test)

print(f"Test set size: {len(X_test):,} samples")

## 2. Model Evaluation

In [None]:
# Evaluate each loaded model on the test set and collect its metrics,
# predictions and class probabilities in `evaluation_results`, keyed by model name.
evaluation_results = {}

# Pin the label order so the per-class arrays always have one entry per entry
# of label_encoder.classes_, even if a class never occurs in this test slice.
class_ids = np.arange(len(label_encoder.classes_))

print("="*70)
print("MODEL EVALUATION RESULTS")
print("="*70)

for name, model in models.items():
    print(f"\n--- {name} ---")

    # Predictions; probabilities only for models that expose predict_proba
    y_pred = model.predict(X_test_scaled)
    y_pred_proba = model.predict_proba(X_test_scaled) if hasattr(model, 'predict_proba') else None

    # Aggregate metrics (average='weighted')
    accuracy = accuracy_score(y_test_encoded, y_pred)
    precision = precision_score(y_test_encoded, y_pred, average='weighted')
    recall = recall_score(y_test_encoded, y_pred, average='weighted')
    f1 = f1_score(y_test_encoded, y_pred, average='weighted')

    # Per-class metrics, one value per encoded class in encoder order
    precision_per_class = precision_score(y_test_encoded, y_pred, labels=class_ids, average=None)
    recall_per_class = recall_score(y_test_encoded, y_pred, labels=class_ids, average=None)
    f1_per_class = f1_score(y_test_encoded, y_pred, labels=class_ids, average=None)

    # ROC AUC (one-vs-rest, weighted). Still best-effort, but the reason for a
    # failure is reported instead of being silently discarded, and
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    roc_auc = None
    if y_pred_proba is not None:
        try:
            roc_auc = roc_auc_score(y_test_encoded, y_pred_proba, multi_class='ovr', average='weighted')
        except Exception as exc:
            print(f"ROC AUC unavailable: {exc}")

    evaluation_results[name] = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'roc_auc': roc_auc,
        'precision_per_class': precision_per_class,
        'recall_per_class': recall_per_class,
        'f1_per_class': f1_per_class,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba
    }

    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    # Compare against None: a legitimate score of 0.0 must still be printed
    if roc_auc is not None:
        print(f"ROC AUC:   {roc_auc:.4f}")

In [None]:
# Detailed classification report for the best model, where "best" means the
# highest weighted F1 score among the evaluated models.
best_model_name = max(evaluation_results, key=lambda k: evaluation_results[k]['f1_score'])
best_results = evaluation_results[best_model_name]

print(f"\n{'='*70}")
print(f"DETAILED REPORT: {best_model_name}")
print(f"{'='*70}\n")

# Pass `labels` explicitly so the report rows line up with `target_names`
# even when a class is missing from both the true and predicted labels
# (without it scikit-learn raises on the length mismatch).
print(classification_report(
    y_test_encoded,
    best_results['y_pred'],
    labels=np.arange(len(label_encoder.classes_)),
    target_names=label_encoder.classes_,
))

## 3. Confusion Matrix Analysis

In [None]:
# Confusion matrices for the best model: raw counts (left) and
# row-normalized rates (right). The figure is also saved as a PNG.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts. Labels are pinned to the encoder's classes so the matrix always
# has one row/column per tick label, even if a class is absent from the data.
cm = confusion_matrix(
    y_test_encoded,
    best_results['y_pred'],
    labels=np.arange(len(label_encoder.classes_)),
)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
axes[0].set_title(f'{best_model_name} - Confusion Matrix (Counts)')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')

# Normalized by the number of actual samples per class (row totals).
# Rows with no samples are left at 0 instead of producing NaN from 0/0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues', ax=axes[1],
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
axes[1].set_title(f'{best_model_name} - Confusion Matrix (Normalized)')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('Actual')

plt.tight_layout()
plt.savefig('../data/test_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Misclassification analysis: for every class, report how many test samples
# the best model got right/wrong and which wrong class it predicted most often.
print("Misclassification Analysis:")
print("-" * 50)

for i, class_name in enumerate(label_encoder.classes_):
    class_mask = y_test_encoded == i
    class_total = class_mask.sum()

    print(f"\n{class_name.capitalize()}:")
    print(f"  Total: {class_total:,}")

    # A class with no test samples has no rates to report (avoids 0/0 -> nan%)
    if class_total == 0:
        print("  No samples of this class in the test set")
        continue

    correct = (best_results['y_pred'][class_mask] == i).sum()
    incorrect = class_total - correct

    print(f"  Correct: {correct:,} ({correct/class_total*100:.1f}%)")
    print(f"  Misclassified: {incorrect:,} ({incorrect/class_total*100:.1f}%)")

    # Most common misclassification: the wrong label predicted most often.
    # `incorrect > 0` already guarantees that `misclass` is non-empty.
    if incorrect > 0:
        misclass = best_results['y_pred'][class_mask & (best_results['y_pred'] != i)]
        most_common = np.bincount(misclass).argmax()
        print(f"  Most often confused with: {label_encoder.classes_[most_common]}")

## 4. Per-Class Performance

In [None]:
# Per-class precision / recall / F1 of the best model, as a table and as a
# grouped bar chart (saved as a PNG).
metrics_df = pd.DataFrame({
    'Class': label_encoder.classes_,
    'Precision': best_results['precision_per_class'],
    'Recall': best_results['recall_per_class'],
    'F1 Score': best_results['f1_per_class']
})

print("Per-Class Metrics:")
print(metrics_df.to_string(index=False))

# Visualize: three bars per class, offset by one bar width around each tick
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(label_encoder.classes_))
width = 0.25

ax.bar(x - width, metrics_df['Precision'], width, label='Precision', color='#3498db')
ax.bar(x, metrics_df['Recall'], width, label='Recall', color='#2ecc71')
ax.bar(x + width, metrics_df['F1 Score'], width, label='F1 Score', color='#e74c3c')

ax.set_xlabel('Disease Stage')
ax.set_ylabel('Score')
ax.set_title('Per-Class Performance Metrics')
ax.set_xticks(x)
ax.set_xticklabels([c.capitalize() for c in label_encoder.classes_])
ax.legend()

# Keep the zoomed-in 0.8-1.0 view only when every score is above 0.8;
# otherwise bars at or below the lower limit would be invisible, so show
# the full 0-1 range instead.
lowest_score = float(metrics_df[['Precision', 'Recall', 'F1 Score']].min().min())
y_lower = 0.8 if lowest_score > 0.8 else 0.0
ax.set_ylim(y_lower, 1.0)

plt.tight_layout()
plt.savefig('../data/per_class_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. ROC Curves

In [None]:
# One-vs-rest ROC curve per class for the best model (skipped when the model
# produced no probability estimates). The figure is also saved as a PNG.
if best_results['y_pred_proba'] is None:
    print("Probability predictions not available for ROC curves")
else:
    fig, ax = plt.subplots(figsize=(10, 8))

    # One viridis color per class
    colors = plt.cm.viridis(np.linspace(0, 0.8, len(label_encoder.classes_)))

    for i, class_name in enumerate(label_encoder.classes_):
        color = colors[i]

        # Treat class i as the positive class and everything else as negative
        y_binary = (y_test_encoded == i).astype(int)
        y_score = best_results['y_pred_proba'][:, i]

        auc = roc_auc_score(y_binary, y_score)
        fpr, tpr, _ = roc_curve(y_binary, y_score)

        curve_label = f'{class_name.capitalize()} (AUC = {auc:.3f})'
        ax.plot(fpr, tpr, color=color, lw=2, label=curve_label)

    # Diagonal reference line
    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'ROC Curves - {best_model_name}')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('../data/roc_curves.png', dpi=150, bbox_inches='tight')
    plt.show()

## 6. Prediction Confidence Analysis

In [None]:
# Prediction-confidence analysis for the best model: confidence is the highest
# class probability of each sample. Only runs when probabilities are available.
if best_results['y_pred_proba'] is not None:
    confidences = np.max(best_results['y_pred_proba'], axis=1)

    # True where the predicted label equals the true label
    correct_mask = best_results['y_pred'] == y_test_encoded

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: confidence histograms, correct first, then incorrect
    hist_specs = (
        (correct_mask, 'Correct', '#2ecc71'),
        (~correct_mask, 'Incorrect', '#e74c3c'),
    )
    for subset, hist_label, shade in hist_specs:
        axes[0].hist(confidences[subset], bins=50, alpha=0.7, label=hist_label, color=shade)
    axes[0].set_xlabel('Prediction Confidence')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Prediction Confidence Distribution')
    axes[0].legend()

    # Right panel: accuracy and coverage when only predictions at or above a
    # confidence threshold are kept; thresholds that keep nothing are skipped
    thresholds = np.linspace(0.5, 0.99, 20)
    accuracies = []
    coverages = []

    for thresh in thresholds:
        mask = confidences >= thresh
        n_kept = mask.sum()
        if n_kept == 0:
            continue
        accuracies.append(accuracy_score(y_test_encoded[mask], best_results['y_pred'][mask]))
        coverages.append(n_kept / len(mask))

    # Accuracy on the left y-axis
    ax2 = axes[1]
    ax2.plot(thresholds[:len(accuracies)], accuracies, 'b-', label='Accuracy', linewidth=2)
    ax2.set_xlabel('Confidence Threshold')
    ax2.set_ylabel('Accuracy', color='blue')
    ax2.tick_params(axis='y', labelcolor='blue')

    # Coverage on a second y-axis sharing the same x-axis
    ax3 = ax2.twinx()
    ax3.plot(thresholds[:len(coverages)], coverages, 'r--', label='Coverage', linewidth=2)
    ax3.set_ylabel('Coverage', color='red')
    ax3.tick_params(axis='y', labelcolor='red')

    axes[1].set_title('Accuracy vs Coverage Trade-off')

    plt.tight_layout()
    plt.savefig('../data/confidence_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Summary statistics of the confidence values
    mean_conf = confidences.mean()
    median_conf = np.median(confidences)
    mean_conf_correct = confidences[correct_mask].mean()
    mean_conf_incorrect = confidences[~correct_mask].mean()

    print(f"\nConfidence Statistics:")
    print(f"  Mean confidence: {mean_conf:.3f}")
    print(f"  Median confidence: {median_conf:.3f}")
    print(f"  Correct predictions avg confidence: {mean_conf_correct:.3f}")
    print(f"  Incorrect predictions avg confidence: {mean_conf_incorrect:.3f}")

## 7. Sample Predictions

In [None]:
# Show a reproducible random sample of individual predictions from the best model
print("Sample Predictions:")
print("="*80)

# A seeded generator makes a fresh "Restart & Run All" print the same rows.
# The sample size is capped so a test set with fewer than 10 rows does not
# raise when sampling without replacement.
sample_rng = np.random.default_rng(42)
sample_indices = sample_rng.choice(len(y_test), size=min(10, len(y_test)), replace=False)

for idx in sample_indices:
    # idx is a position within the test arrays
    actual = label_encoder.classes_[y_test_encoded[idx]]
    predicted = label_encoder.classes_[best_results['y_pred'][idx]]
    correct = "YES" if actual == predicted else "NO"

    if best_results['y_pred_proba'] is not None:
        # Confidence = highest class probability for this sample
        confidence = best_results['y_pred_proba'][idx].max()
        print(f"Patient {idx}: Actual={actual:<10} Predicted={predicted:<10} Correct={correct:<4} Confidence={confidence:.2%}")
    else:
        print(f"Patient {idx}: Actual={actual:<10} Predicted={predicted:<10} Correct={correct}")

## 8. Save Test Results

In [None]:
# Save a JSON test report: run metadata plus the aggregate metrics of every
# evaluated model (per-class arrays and raw predictions are not written).
test_report = {
    "test_date": datetime.now().isoformat(),
    "test_size": len(y_test),
    "models_tested": list(models.keys()),
    "best_model": best_model_name,
    "results": {}
}

for name, results in evaluation_results.items():
    test_report["results"][name] = {
        # float(...) converts NumPy scalars into plain floats for json.dump
        "accuracy": float(results['accuracy']),
        "precision": float(results['precision']),
        "recall": float(results['recall']),
        "f1_score": float(results['f1_score']),
        # Compare against None so a genuine score of 0.0 is not written as null
        "roc_auc": float(results['roc_auc']) if results['roc_auc'] is not None else None
    }

with open('../models/test_results.json', 'w') as f:
    json.dump(test_report, f, indent=2)

print("Test results saved to ../models/test_results.json")

In [None]:
# Final summary: headline metrics of the best model on the test set
print("\n" + "="*70)
print("FINAL TEST SUMMARY")
print("="*70)
print(f"\nTest Dataset Size: {len(y_test):,} samples")
print(f"Best Model: {best_model_name}")
print(f"\nPerformance Metrics:")
print(f"  Accuracy:  {best_results['accuracy']:.4f} ({best_results['accuracy']*100:.2f}%)")
print(f"  Precision: {best_results['precision']:.4f}")
print(f"  Recall:    {best_results['recall']:.4f}")
print(f"  F1 Score:  {best_results['f1_score']:.4f}")
# ROC AUC is None when it could not be computed; compare against None so a
# legitimate score of 0.0 is still reported
if best_results['roc_auc'] is not None:
    print(f"  ROC AUC:   {best_results['roc_auc']:.4f}")
print("\n" + "="*70)