In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Request "retina" (high-resolution) figures from the inline backend.
%config InlineBackend.figure_format = "retina"

In [None]:
# Change into the project directory so the relative ./results_* paths used
# later in this notebook resolve.
# NOTE(review): this is a hardcoded, machine-specific location — adjust it
# (or make it configurable) when running elsewhere.
%cd ~/projects/ip-is-all-you-need

In [None]:
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns

from ip_is_all_you_need.plots import get_phase_transition_data

In [None]:
# Load the experiment results; the large run is restricted to rows with m < 145.
# `pl.col` is spelled out here because the `c` shorthand is only bound inside
# plot_phase_transition, so using it at cell level fails on a fresh kernel.
df_large = pl.read_parquet("./results_large/results.parquet").filter(
    pl.col("m") < 145
)
df_small = pl.read_parquet("./results_small/results.parquet")

In [None]:
def plot_phase_transition(
    df, algorithm, normalize_axes=False, success_threshold=1e-15, ax=None
):
    """Draw a heatmap of the per-experiment success rate for one algorithm.

    Rows of ``df`` are kept when ``algorithm`` matches and ``iter`` equals
    ``sparsity - 1``. A row counts as a success when ``mse_x / norm_x`` is
    below ``success_threshold``. Rows are grouped by ``experiment_number`` and
    the mean of the success flag becomes ``success_rate``, which is pivoted
    into a 2-D table and rendered with ``sns.heatmap``.

    Parameters
    ----------
    df : polars.DataFrame
        Results table. The columns read here are ``algorithm``, ``iter``,
        ``sparsity``, ``mse_x``, ``norm_x``, ``experiment_number``, ``m``,
        ``n``, ``measurement_rate`` and ``noise_std``.
    algorithm : str
        Value matched against the ``algorithm`` column; its upper-cased form
        appears in the title.
    normalize_axes : bool, default False
        If True, plot ``m / n`` (rows) against ``sparsity / m`` (columns);
        otherwise plot ``m`` (rows) against ``sparsity`` (columns).
    success_threshold : float, default 1e-15
        Upper bound (exclusive) on ``mse_x / norm_x`` for a success.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current pyplot axes.

    Returns
    -------
    None
    """
    c = pl.col
    # The title reports `n` taken from the first row of the *unfiltered* frame.
    # Assumes every row shares the same n — TODO confirm.
    n = df["n"][0]
    df_pt = (
        df.filter((c("algorithm") == algorithm) & (c("iter") == c("sparsity") - 1))
        .with_columns(
            (c("mse_x") / c("norm_x") < success_threshold).alias("success"),
        )
        # NOTE(review): newer polars releases rename `groupby` to `group_by`
        # and pivot's `columns=` to `on=` — confirm against the installed version.
        .groupby("experiment_number")
        .agg(
            c("m").first(),
            c("n").first(),
            c("measurement_rate").first(),
            c("sparsity").first(),
            c("noise_std").first(),
            c("success").mean().alias("success_rate"),
        )
        .with_columns(
            # Replaces the aggregated measurement_rate with a recomputed m / n.
            (c("m") / c("n")).alias("measurement_rate"),
            (c("sparsity") / c("m")).alias("sparsity_rate"),
        )
    )

    # Only the pivot keys and the axis labels differ between the two modes.
    if normalize_axes:
        row_key, col_key = "measurement_rate", "sparsity_rate"
        x_label, y_label = "s / m", "m / n"
    else:
        row_key, col_key = "m", "sparsity"
        x_label, y_label = "s", "m"

    # Sorting before the pivot fixes the layout: rows descending, columns ascending.
    tbl = (
        df_pt.sort(by=[row_key, col_key], descending=[True, False])
        .pivot(
            values="success_rate",
            index=row_key,
            columns=col_key,
            aggregate_function="first",
        )
        .to_pandas()
        .set_index(row_key, drop=True)
    )

    if ax is None:
        ax = plt.gca()
    sns.heatmap(tbl, ax=ax)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f"Phase Transition for {algorithm.upper()} (n={n})")


In [None]:
# Phase-transition heatmap for the "omp" rows of the small results set.
plot_phase_transition(df_small, algorithm="omp")

In [None]:
# Phase-transition heatmap for the "ip" rows of the small results set.
plot_phase_transition(df_small, algorithm="ip")

In [None]:
# Phase-transition heatmap for the "omp" rows of the large results set.
plot_phase_transition(df_large, algorithm="omp")

In [None]:
# Phase-transition heatmap for the "ip" rows of the large results set.
plot_phase_transition(df_large, algorithm="ip")