## 1. Load MLflow Metrics

import mlflow
import pandas as pd

# Fetch the experiment's runs as a pandas DataFrame.
#
# NOTE: `MlflowClient.search_runs` returns a paged list of `Run` objects and
# does not accept `output_format`; the fluent `mlflow.search_runs` is the API
# that returns a DataFrame with flattened `params.*` / `metrics.*` / `tags.*`
# columns, which is what the plotting cells index into.
client = mlflow.tracking.MlflowClient(mlflow.get_tracking_uri())  # retained; not needed by the search below
runs = mlflow.search_runs(
    experiment_ids=["<your_experiment_id>"],
    filter_string="",  # or filter by tags / params
    output_format="pandas",
)
# Now `runs` is a DataFrame with columns: params.ssa.n_imfs, params.feature_extraction.ds,
# params.cnn.embedding_dim, metrics.avg_test_accuracy_mat, ...
# MLflow stores param values as strings; metric values are numeric.


## 2. Accuracy Heatmap

In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns

# Mean MAT test accuracy for every (n_imfs, embedding_dim) combination.
# When several runs share the same pair of params, `pivot_table` averages them.
pivot = runs.pivot_table(
    index="params.ssa.n_imfs",
    columns="params.cnn.embedding_dim",
    values="metrics.avg_test_accuracy_mat",
    aggfunc="mean",
)


def _numeric_sort_key(labels):
    """Map axis labels to numbers for sorting; non-numeric labels become NaN (sorted last)."""
    return pd.to_numeric(labels, errors="coerce")


# MLflow logs params as strings, so pivot_table orders the axes
# lexicographically ("10" before "8"). Re-sort both axes by numeric value while
# keeping the original string labels for display.
pivot = pivot.sort_index(axis=0, key=_numeric_sort_key).sort_index(
    axis=1, key=_numeric_sort_key
)

plt.figure(figsize=(6, 4))
sns.heatmap(pivot, annot=True, fmt=".3f", cmap="viridis")
plt.title("MAT Accuracy by (n_imfs × embedding_dim)")
plt.ylabel("n_imfs")
plt.xlabel("embedding_dim")
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("results/figures", exist_ok=True)
plt.savefig("results/figures/accuracy_heatmap_mat.png", dpi=300)


## 3. Box Plot of Test Accuracy: Nested-CV vs. Single-Split

In [None]:
import os

# Compare the per-run MAT test accuracy of the two evaluation strategies.
# Runs are split on the `evaluation.strategy` param.
strategy_col = "params.evaluation.strategy"
metric_col = "metrics.avg_test_accuracy_mat"
strategy_order = ["train_test", "nested_cv"]

train_test_runs = runs[runs[strategy_col] == "train_test"]
nested_cv_runs = runs[runs[strategy_col] == "nested_cv"]

acc_tt = train_test_runs[metric_col].astype(float)
acc_nc = nested_cv_runs[metric_col].astype(float)

# Long-format frame (one row per run) for seaborn. Built with pandas rather
# than `np.concatenate`, because `np` is not imported by this notebook's
# import lines. Runs that never logged the metric (NaN) are dropped.
data = pd.concat(
    [
        pd.DataFrame({"Strategy": "train_test", "Accuracy": acc_tt}),
        pd.DataFrame({"Strategy": "nested_cv", "Accuracy": acc_nc}),
    ],
    ignore_index=True,
).dropna(subset=["Accuracy"])

plt.figure(figsize=(5, 4))
# An explicit order pins each strategy to a fixed x position, even if one
# strategy has no runs.
sns.boxplot(x="Strategy", y="Accuracy", data=data, order=strategy_order)
sns.stripplot(
    x="Strategy",
    y="Accuracy",
    data=data,
    order=strategy_order,
    color="black",
    alpha=0.5,
)
plt.title("Test Accuracy: Train/Test vs. Nested CV (MAT)")
plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("results/figures", exist_ok=True)
plt.savefig("results/figures/acc_comparison_boxplot.png", dpi=300)


## 4. Confusion Matrix Montage

In [None]:
import glob
import os
import re

import matplotlib.image as mpimg
import matplotlib.pyplot as plt

# Tile the per-outer-fold confusion-matrix images into a single-row montage.
CM_GLOB = "results/figures/mat/outer_fold_*/confusion_matrix.png"


def _natural_key(path):
    """Split *path* into text and integer chunks so outer_fold_10 sorts after outer_fold_2."""
    # re.split with a capturing group puts the digit runs at odd positions.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", path))]


cm_paths = sorted(glob.glob(CM_GLOB), key=_natural_key)
if not cm_paths:
    raise FileNotFoundError(f"No confusion-matrix images match {CM_GLOB!r}")

# One panel per fold found on disk, instead of assuming 5 folds. Extra folds
# are no longer dropped, and missing folds no longer leave empty axes.
# squeeze=False plus ravel() keeps `axes` a 1-D array even for a single fold.
n_panels = len(cm_paths)
fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), squeeze=False)
axes = axes.ravel()
for ax, img_path in zip(axes, cm_paths):
    ax.imshow(mpimg.imread(img_path))
    ax.axis("off")
    # The panel title is the fold directory name, e.g. "outer_fold_1".
    ax.set_title(os.path.basename(os.path.dirname(img_path)))
plt.tight_layout()
plt.savefig("results/figures/cm_montage_mat.png", dpi=300)


## 5. SHAP Summary Montage