# Linear Discriminant Analysis on Heart Failure Clinical Records
## Medical Diagnostic Modeling: Mortality Prediction

**Dataset Overview:**
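- 299 patients with heart failure
- 12 clinical features plus the binary target column `DEATH_EVENT` (13 columns in total)
- Binary classification target: `DEATH_EVENT` (0 = survived, 1 = died)
- Features: age, anaemia, creatinine phosphokinase, ejection fraction, platelets, serum creatinine, serum sodium, etc.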

**Focus:** Binary classification of mortality risk (a prognostic rather than diagnostic task) using Linear Discriminant Analysis

In [None]:
# Import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    f1_score,
    matthews_corrcoef
)
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Notebook-wide plotting defaults: seaborn's white-grid style and a wide default figure size.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 8)})

print("Libraries imported successfully!")

## 1. Data Loading

**Note:** This notebook uses the Heart Failure Clinical Records dataset.
You can download it from:
- UCI ML Repository or Kaggle
- File: `heart_failure_clinical_records_dataset.csv`

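If that file is not present in the working directory, the next cell falls back to generating a seeded synthetic dataset with the same columns, so the notebook still runs end to end. Results on the synthetic data are illustrative only.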

In [None]:
# Load the real dataset when the CSV is available; otherwise build a seeded synthetic
# stand-in with the same 13 columns (12 features + DEATH_EVENT) so later cells run unchanged.
try:
    df = pd.read_csv('heart_failure_clinical_records_dataset.csv')
    print("Real dataset loaded!")
except FileNotFoundError:
    print("Creating synthetic heart failure dataset...")
    np.random.seed(42)
    n = 299

    age = np.random.normal(60, 12, n).astype(int)
    age = np.clip(age, 40, 95)

    anaemia = np.random.choice([0, 1], n, p=[0.57, 0.43])
    creatinine_phosphokinase = np.random.lognormal(5.5, 1.2, n).astype(int)
    diabetes = np.random.choice([0, 1], n, p=[0.58, 0.42])
    ejection_fraction = np.random.normal(38, 11, n).astype(int)
    ejection_fraction = np.clip(ejection_fraction, 14, 80)
    high_blood_pressure = np.random.choice([0, 1], n, p=[0.65, 0.35])
    # An unbounded normal draw can produce negative platelet counts, which are impossible;
    # floor at a small positive (illustrative) value. Clipping after the draw leaves the
    # RNG stream, and therefore every other synthetic column, unchanged.
    platelets = np.clip(np.random.normal(263000, 97000, n), 25000, None)
    serum_creatinine = np.random.lognormal(0.17, 0.6, n)
    serum_sodium = np.random.normal(137, 4, n)
    sex = np.random.choice([0, 1], n, p=[0.35, 0.65])
    smoking = np.random.choice([0, 1], n, p=[0.68, 0.32])
    time = np.random.randint(4, 285, n)

    # Per-patient death probability built from additive risk terms. The weights sum to
    # 1.0, so risk_score always lies in [0, 1] and can be compared to a uniform draw.
    risk_score = (
        (age > 65) * 0.15 +
        (ejection_fraction < 30) * 0.25 +
        (serum_creatinine > 1.5) * 0.20 +
        (serum_sodium < 135) * 0.15 +
        (time < 100) * 0.10 +
        anaemia * 0.05 +
        high_blood_pressure * 0.05 +
        diabetes * 0.05
    )
    death_event = (np.random.random(n) < risk_score).astype(int)

    df = pd.DataFrame({
        'age': age,
        'anaemia': anaemia,
        'creatinine_phosphokinase': creatinine_phosphokinase,
        'diabetes': diabetes,
        'ejection_fraction': ejection_fraction,
        'high_blood_pressure': high_blood_pressure,
        'platelets': platelets,
        'serum_creatinine': serum_creatinine,
        'serum_sodium': serum_sodium,
        'sex': sex,
        'smoking': smoking,
        'time': time,
        'DEATH_EVENT': death_event
    })
    print("Synthetic dataset created!")

# Invariants the rest of the notebook relies on (binary target named DEATH_EVENT).
assert 'DEATH_EVENT' in df.columns, "expected a DEATH_EVENT target column"
assert set(df['DEATH_EVENT'].unique()) <= {0, 1}, "DEATH_EVENT must be encoded as 0/1"

print(f"\nDataset shape: {df.shape}")
print(f"\nTarget distribution:")
print(df['DEATH_EVENT'].value_counts())
print(f"\nMortality rate: {df['DEATH_EVENT'].mean()*100:.1f}%")

In [None]:
# Structural overview of the loaded frame: dtypes/non-null counts, summary statistics,
# and per-column missing-value counts.
print('Dataset Information:')
# DataFrame.info() writes its report itself and returns None, so it must not be wrapped
# in print() (that emitted a stray "None" line after the report).
df.info()
print('\nStatistical Summary:')
display(df.describe())
print('\nMissing Values:')
print(df.isnull().sum())

## 2. Medical Feature Analysis

In [None]:
# Clinical features by outcome: overlaid histograms for continuous features and grouped
# bar counts for the binary (0/1) features, split by DEATH_EVENT.
binary_features = ['anaemia', 'diabetes', 'high_blood_pressure', 'sex', 'smoking']
outcome_labels = ['Survived', 'Died']

# Drop the target by name rather than assuming it is the last column.
features = df.columns.drop('DEATH_EVENT')

fig, axes = plt.subplots(3, 4, figsize=(20, 15))
axes = axes.ravel()

for ax, feature in zip(axes, features):
    if feature in binary_features:
        # Previously these panels were skipped, leaving them empty (and legend() had no
        # handles). Show counts of 0/1 per outcome instead.
        counts = (
            pd.crosstab(df[feature], df['DEATH_EVENT'])
            .reindex(index=[0, 1], columns=[0, 1], fill_value=0)
        )
        bar_width = 0.4
        for outcome in (0, 1):
            offsets = np.array([0, 1]) + (outcome - 0.5) * bar_width
            ax.bar(offsets, counts[outcome].values, width=bar_width, alpha=0.6,
                   label=outcome_labels[outcome])
        ax.set_xticks([0, 1])
        ax.set_ylabel('Count')
    else:
        # Shared bin edges so the two outcome histograms are directly comparable.
        bins = np.histogram_bin_edges(df[feature], bins=20)
        for outcome in (0, 1):
            ax.hist(df.loc[df['DEATH_EVENT'] == outcome, feature], bins=bins, alpha=0.6,
                    label=outcome_labels[outcome])
        ax.set_ylabel('Frequency')
    ax.set_title(f'{feature}')
    ax.set_xlabel(feature)
    ax.legend()

# Hide any panels left over when there are fewer features than grid cells.
for ax in axes[len(features):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Correlation heatmap across every column of df (the DEATH_EVENT target included),
# annotated with two-decimal coefficients and centred at zero.
corr = df.corr()

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(corr, ax=ax, annot=True, cmap='coolwarm', center=0, square=True, fmt='.2f')
ax.set_title('Correlation Matrix - Heart Failure Clinical Features')
fig.tight_layout()
plt.show()

In [None]:
# Key clinical indicators: one boxplot per feature, grouped by outcome, on a 2x2 grid.
key_features = ['age', 'ejection_fraction', 'serum_creatinine', 'serum_sodium']

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

for ax, feature in zip(axes.flat, key_features):
    df.boxplot(column=feature, by='DEATH_EVENT', ax=ax)
    ax.set_title(f'{feature} by Outcome')
    ax.set_xlabel('Death Event (0=Survived, 1=Died)')
    ax.set_ylabel(feature)

# Blank the automatic figure-level title that the grouped boxplot adds.
plt.suptitle('')
plt.tight_layout()
plt.show()

## 3. Data Preparation and Split

In [None]:
# Prepare data.
# X and feature_names are derived from the SAME frame so the names always line up with
# X's columns. Previously names came from df.columns[:-1], which silently misaligns if
# DEATH_EVENT is not the last column.
feature_df = df.drop(columns='DEATH_EVENT')
X = feature_df.values
y = df['DEATH_EVENT'].values
feature_names = feature_df.columns.tolist()
assert X.shape[1] == len(feature_names)

# Split: stratify=y keeps the class proportions equal in train and test;
# random_state pins the split for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

print(f'Training: {X_train.shape[0]} samples')
print(f'Test: {X_test.shape[0]} samples')
print(f'\nClass balance in training:')
print(pd.Series(y_train).value_counts(normalize=True))

In [None]:
# Standardize: learn per-feature mean/std on the training split only, then apply that
# same transform to both splits so no test-set statistics leak into training.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
print('Feature scaling complete!')

## 4. Linear Discriminant Analysis

In [None]:
# Train LDA on the standardized training data and score the held-out test split.
lda = LinearDiscriminantAnalysis()
lda.fit(X_train_scaled, y_train)

y_pred = lda.predict(X_test_scaled)
# Column 1 of predict_proba is the probability of class 1 (DEATH_EVENT = 1).
y_proba = lda.predict_proba(X_test_scaled)[:, 1]

test_metrics = [
    ('Accuracy', accuracy_score(y_test, y_pred)),
    ('ROC-AUC', roc_auc_score(y_test, y_proba)),  # ranking metric, uses probabilities
    ('F1-Score', f1_score(y_test, y_pred)),
    ('Matthews Correlation', matthews_corrcoef(y_test, y_pred)),
]

print('LDA MODEL PERFORMANCE')
print('=' * 70)
for metric_name, metric_value in test_metrics:
    print(f'{metric_name}: {metric_value:.4f}')

In [None]:
# Classification report
print('\nClassification Report:')
print(classification_report(y_test, y_pred, labels=[0, 1], target_names=['Survived', 'Died']))

# Confusion matrix. labels=[0, 1] fixes the row/column order and guarantees a 2x2 matrix
# even if one class never appears in y_pred, so the four-way unpacking below cannot fail.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Survived', 'Died'], yticklabels=['Survived', 'Died'])
plt.title('Confusion Matrix - Heart Failure Prediction')
plt.ylabel('True')
plt.xlabel('Predicted')
plt.show()

tn, fp, fn, tp = cm.ravel()


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A zero denominator occurs, e.g., for PPV when the model predicts no deaths at all.
    """
    return numerator / denominator if denominator else float('nan')


print(f'\nSensitivity (True Positive Rate): {_safe_ratio(tp, tp + fn):.3f}')
print(f'Specificity (True Negative Rate): {_safe_ratio(tn, tn + fp):.3f}')
print(f'Positive Predictive Value: {_safe_ratio(tp, tp + fp):.3f}')
print(f'Negative Predictive Value: {_safe_ratio(tn, tn + fn):.3f}')

## 5. Clinical Feature Importance

In [None]:
# Feature coefficients of the fitted LDA model. The model was trained on standardized
# features, so coefficient magnitudes are on a common scale. Rows are ranked by |coefficient|.
coefs = lda.coef_[0]
feature_importance = (
    pd.DataFrame({'Feature': feature_names, 'Coefficient': coefs})
    .assign(Abs_Coefficient=lambda frame: frame['Coefficient'].abs())
    .sort_values('Abs_Coefficient', ascending=False)
)

print('Feature Importance:')
display(feature_importance)

# Bar colours: red for negative coefficients, blue for positive (in ranked order).
colors = ['red' if value < 0 else 'blue' for value in feature_importance['Coefficient']]

# Plot in reversed order so the largest |coefficient| ends up at the top of the chart.
ranked_bottom_up = feature_importance.iloc[::-1]
bar_positions = np.arange(len(ranked_bottom_up))

fig, ax = plt.subplots(figsize=(12, 8))
ax.barh(bar_positions, ranked_bottom_up['Coefficient'].to_numpy(), color=colors[::-1], alpha=0.7)
ax.set_yticks(bar_positions)
ax.set_yticklabels(ranked_bottom_up['Feature'].to_numpy())
ax.set_xlabel('LDA Coefficient')
ax.set_title('Clinical Feature Importance for Mortality Prediction')
ax.axvline(x=0, color='black', linestyle='--')
ax.grid(True, alpha=0.3, axis='x')
fig.tight_layout()
plt.show()

## 6. ROC and Medical Decision Making

In [None]:
# ROC curve (left) and precision-recall curve (right) for the LDA test-set probabilities.
fpr, tpr, thresholds = roc_curve(y_test, y_proba)
precision, recall, _ = precision_recall_curve(y_test, y_proba)
test_auc = roc_auc_score(y_test, y_proba)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
roc_ax, pr_ax = axes

roc_ax.plot(fpr, tpr, linewidth=2, label=f'LDA (AUC={test_auc:.3f})')
roc_ax.plot([0, 1], [0, 1], 'k--', label='Random')  # chance-level diagonal
roc_ax.set_xlabel('False Positive Rate (1 - Specificity)')
roc_ax.set_ylabel('True Positive Rate (Sensitivity)')
roc_ax.set_title('ROC Curve')
roc_ax.legend()
roc_ax.grid(True, alpha=0.3)

# Precision-Recall
pr_ax.plot(recall, precision, linewidth=2)
pr_ax.set_xlabel('Recall (Sensitivity)')
pr_ax.set_ylabel('Precision (PPV)')
pr_ax.set_title('Precision-Recall Curve')
pr_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 7. Summary and Clinical Implications

In [None]:
# Text summary: the three largest-|coefficient| predictors, headline test-set rates
# (from the confusion-matrix counts tn/fp/fn/tp), and usage notes.
divider = '=' * 70
summary_sensitivity = tp / (tp + fn)
summary_specificity = tn / (tn + fp)
summary_auc = roc_auc_score(y_test, y_proba)

print('\n' + divider)
print('KEY CLINICAL INSIGHTS')
print(divider)

top_3 = feature_importance.head(3)
print('\n1. MOST IMPORTANT PREDICTORS:')
for predictor in top_3.itertuples(index=False):
    print(f'   - {predictor.Feature}: coefficient = {predictor.Coefficient:.4f}')

print('\n2. MODEL PERFORMANCE:')
print(f'   - Sensitivity: {summary_sensitivity:.2%} (catch rate for mortality)')
print(f'   - Specificity: {summary_specificity:.2%} (correct survival predictions)')
print(f'   - ROC-AUC: {summary_auc:.3f}')

application_notes = [
    '   - Can identify high-risk patients for intensive monitoring',
    '   - Feature importance guides clinical assessment priorities',
    '   - Balance sensitivity vs specificity based on intervention costs',
]
print('\n3. CLINICAL APPLICATION:')
for note in application_notes:
    print(note)
print('\n' + divider)