In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report,
    roc_auc_score, roc_curve, auc
)
from sklearn.preprocessing import label_binarize

# 1️⃣ Load dataset
digits = datasets.load_digits()
X = digits.data    # feature matrix
y = digits.target  # class labels

# One-vs-Rest indicator matrix for ROC-AUC: one column per distinct label,
# columns ordered as np.unique(y) returns them.
y_bin = label_binarize(y, classes=np.unique(y))

# 2️⃣ Stratified 80/20 train-test split. The indicator matrix is split in
# lockstep with X and y so its rows stay aligned with y_train / y_test.
(
    X_train, X_test,
    y_train, y_test,
    y_bin_train, y_bin_test,
) = train_test_split(
    X, y, y_bin,
    stratify=y,
    test_size=0.2,
    random_state=42,
)

# 3️⃣ Define parameter grid.
# Two sub-grids because shrinkage is not supported by the 'svd' solver;
# it is only searched together with 'lsqr' and 'eigen'.
param_grid = [
    {"solver": ["svd"]},
    {"solver": ["lsqr", "eigen"], "shrinkage": ["auto"]},
]

# 4️⃣ Hyperparameter tuning with 5-fold CV on the training split only.
# refit=True (the default, made explicit here) retrains the winning
# configuration on the whole of X_train once the search finishes.
grid_search = GridSearchCV(
    LinearDiscriminantAnalysis(),
    param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
    refit=True,
)
grid_search.fit(X_train, y_train)

# 5️⃣ Best LDA model. It is already fitted on the full training set by the
# refit step above, so calling .fit() on it again would only repeat that work.
best_lda = grid_search.best_estimator_

# 6️⃣ Predictions: hard labels plus per-class probabilities. The columns of
# y_prob follow the order of best_lda.classes_.
y_pred = best_lda.predict(X_test)
y_prob = best_lda.predict_proba(X_test)

# 7️⃣ Evaluation Metrics
acc = accuracy_score(y_test, y_pred)
bal_acc = balanced_accuracy_score(y_test, y_pred)
# Macro-averaged One-vs-Rest ROC-AUC. Pass the raw labels rather than the
# binarized indicator matrix: given an indicator matrix, sklearn treats the
# task as multilabel and silently ignores `multi_class`. `labels` pins the
# class that each column of y_prob refers to.
roc_auc = roc_auc_score(
    y_test, y_prob, multi_class="ovr", average="macro", labels=best_lda.classes_
)

print(f"Best Parameters: {grid_search.best_params_}")
print(f"Accuracy: {acc:.4f}")
print(f"Balanced Accuracy: {bal_acc:.4f}")
print(f"ROC-AUC Score: {roc_auc:.4f}")

# 8️⃣ Confusion Matrix heatmap: rows (y axis) are actual classes,
# columns (x axis) are predicted classes.
cm_counts = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6, 5))
sns.heatmap(cm_counts, fmt="d", cmap="Blues", annot=True)  # integer cell annotations
plt.title("Confusion Matrix")
plt.ylabel("Actual")
plt.xlabel("Predicted")
plt.show()

# 9️⃣ ROC Curve: one One-vs-Rest curve per class.
plt.figure(figsize=(8, 6))
# Walk the fitted model's own class order, which defines the column order of
# y_prob. Each probability column is then scored against its matching label
# (pos_label) instead of relying on separately derived orderings lining up.
for col, label in enumerate(best_lda.classes_):
    fpr, tpr, _ = roc_curve(y_test, y_prob[:, col], pos_label=label)
    plt.plot(fpr, tpr, label=f"Digit {label} (AUC = {auc(fpr, tpr):.2f})")

plt.plot([0, 1], [0, 1], "k--")  # Random chance line
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Multiclass ROC Curve")
plt.legend(loc="lower right")
plt.show()
