In [1]:
from pathlib import Path
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import partipy as pt

# Directory used as the base location for the benchmark artefacts.
benchmark_path = Path(".")

# Load the raw benchmark results and append the log10-transformed runtime
# as an extra column in a single chained expression.
result_df = pd.read_csv("results.csv").assign(
    time_log10=lambda frame: np.log10(frame["time"])
)

In [3]:
# Display the loaded benchmark results, including the derived `time_log10` column.
result_df

Unnamed: 0,time,rss,varexpl,seed,l2_dist,l2_dist_norm,conv,n_iter,n_samples,n_features,n_archetypes,noise_std,init_alg,optim_alg,time_log10
0,0.178123,0.042563,0.999936,0,0.188833,0.034639,True,103,100,10,3,0.00,uniform,projected_gradients,-0.749280
1,0.010759,0.262285,0.999412,1,0.363272,0.069697,True,39,100,10,3,0.00,uniform,projected_gradients,-1.968214
2,0.021202,0.055419,0.999894,2,0.235944,0.043035,True,53,100,10,3,0.00,uniform,projected_gradients,-1.673626
3,0.014075,0.364150,0.999549,3,0.377265,0.060455,True,42,100,10,3,0.00,uniform,projected_gradients,-1.851550
4,0.014810,0.334923,0.999428,4,0.359623,0.067436,True,44,100,10,3,0.00,uniform,projected_gradients,-1.829449
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,0.076448,31.721642,0.962149,0,1.910686,0.280721,True,40,200,20,7,0.05,plus_plus,frank_wolfe,-1.116634
476,0.095566,35.779911,0.967970,1,2.092418,0.302452,True,47,200,20,7,0.05,plus_plus,frank_wolfe,-1.019698
477,0.095207,50.995167,0.960343,2,2.481209,0.351253,True,55,200,20,7,0.05,plus_plus,frank_wolfe,-1.021330
478,0.083730,44.376160,0.948239,3,2.249726,0.315640,True,48,200,20,7,0.05,plus_plus,frank_wolfe,-1.077118


In [4]:
# Metrics to visualise; every (noise level, metric) pair produces one figure
# saved as "<metric>_<noise_std>.png" under `benchmark_path`.
all_features = ["time_log10", "varexpl", "l2_dist_norm"]
noise_std_list = result_df["noise_std"].sort_values().unique()

# Fix the hue/style levels globally so that every panel -- and therefore the
# shared legend -- uses the same colour/marker for a given algorithm, even when
# an individual panel happens to lack one of the levels.
optim_alg_levels = sorted(result_df["optim_alg"].dropna().unique())
init_alg_levels = sorted(result_df["init_alg"].dropna().unique())

# Columns identifying one benchmark configuration; seeds are averaged out.
group_cols = ["n_samples", "n_features", "n_archetypes", "optim_alg", "init_alg"]

for noise_std in noise_std_list:

    # The grid layout depends only on the noise level, so derive it once per
    # level instead of once per metric.
    noise_df = result_df.loc[result_df["noise_std"] == noise_std, :]
    unique_features = noise_df["n_features"].sort_values().unique()
    unique_archetypes = noise_df["n_archetypes"].sort_values().unique()

    for feature in all_features:

        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid
        fig, axes = plt.subplots(
            nrows=len(unique_features),
            ncols=len(unique_archetypes),
            figsize=(3 * len(unique_archetypes) + 3, 3 * len(unique_features)),
            squeeze=False,
        )

        # handles/labels for the single figure-level legend
        legend_handles = None
        legend_labels = None

        # Re-seed per figure so the jitter is reproducible
        rng = np.random.default_rng(seed=42)
        for row_idx, n_features in enumerate(unique_features):
            for col_idx, n_archetypes in enumerate(unique_archetypes):
                ax = axes[row_idx, col_idx]
                selection_vec = (
                    (noise_df["n_features"] == n_features) &
                    (noise_df["n_archetypes"] == n_archetypes)
                )
                panel_raw = noise_df.loc[selection_vec, :]

                if panel_raw.empty:
                    # The feature/archetype grids are computed independently,
                    # so a combination may have no runs; mark it instead of
                    # crashing on a missing legend.
                    ax.text(
                        0.5, 0.5, "no data",
                        transform=ax.transAxes, ha="center", va="center",
                    )
                else:
                    # summarize across the different seeds
                    result_df_subset = (
                        panel_raw
                        .groupby(group_cols, as_index=False)[feature]
                        .agg("mean")
                    )

                    # multiplicative jitter so overlapping points stay visible
                    result_df_subset["n_samples"] = result_df_subset["n_samples"] + \
                        result_df_subset["n_samples"] * rng.normal(loc=0, scale=0.1, size=len(result_df_subset))
                    result_df_subset["n_samples_log10"] = np.log10(result_df_subset["n_samples"])

                    sns.scatterplot(
                        data=result_df_subset, x="n_samples_log10", y=feature,
                        hue="optim_alg", hue_order=optim_alg_levels,
                        style="init_alg", style_order=init_alg_levels,
                        ax=ax, alpha=0.9,
                    )

                    # capture legend entries from the first panel that has any
                    if not legend_handles:
                        legend_handles, legend_labels = ax.get_legend_handles_labels()

                    # drop the per-panel legend; it may be absent if nothing was drawn
                    panel_legend = ax.get_legend()
                    if panel_legend is not None:
                        panel_legend.remove()

                # Attach the column/row headers to the axes themselves so they
                # stay aligned with the panels after tight_layout.
                if row_idx == 0:
                    ax.set_title(f"{n_archetypes} archetypes", fontsize=12)
                if col_idx == 0:
                    ax.set_ylabel(f"{n_features} features\n{feature}", fontsize=12)

                ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7)

        # put global legend outside the figure
        if legend_handles:
            fig.legend(
                handles=legend_handles, labels=legend_labels,
                loc='center right', bbox_to_anchor=(1.1, 0.5), frameon=False
            )

        fig.suptitle(f"Comparison at {noise_std} Noise Level | {feature}", fontsize=16, y=1.02)
        fig.tight_layout(rect=[0.02, 0.00, 0.85, 0.98])  # leave room for legend and title
        fig.savefig(benchmark_path / f"{feature}_{noise_std:.2f}.png", bbox_inches="tight")
        # close this specific figure to avoid accumulating open figures in the loop
        plt.close(fig)
