In [None]:
%matplotlib inline

In [None]:
%config InlineBackend.figure_format = "retina"

In [None]:
%cd ~/projects/ip-is-all-you-need

In [None]:
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from ip_is_all_you_need.plots import get_phase_transition_data

c = pl.col

In [None]:
# Results are read relative to the project root selected by the `%cd` cell.
# Upper bound (exclusive) on m kept from the large-n run.
# NOTE(review): the reason for this cutoff is not recorded here — document it.
M_MAX_LARGE = 120

df_large = pl.read_parquet("./results_large/results.parquet").filter(c("m") < M_MAX_LARGE)
df_small = pl.read_parquet("./results_small/results.parquet")

In [None]:
def get_phase_transition_data(df, algorithm, tol=1e-14):
    """Aggregate per-trial results into one success-rate row per experiment.

    NOTE: this definition shadows the `get_phase_transition_data` imported from
    `ip_is_all_you_need.plots` in the imports cell; the plotting helpers in this
    notebook resolve the name at call time and therefore use this version.

    Parameters
    ----------
    df : pl.DataFrame
        Raw results. Columns read: ``algorithm``, ``iter``, ``sparsity``,
        ``mse_x``, ``norm_x``, ``experiment_number``, ``m``, ``n``,
        ``noise_std``.
    algorithm : str
        Value of the ``algorithm`` column to keep (e.g. ``"omp"`` or ``"ip"``).
    tol : float, default 1e-14
        A trial counts as an exact recovery when ``mse_x / norm_x < tol``.

    Returns
    -------
    pl.DataFrame
        One row per ``experiment_number`` (in first-appearance order) with
        ``m``, ``n``, ``sparsity``, ``noise_std``, ``success_rate`` (mean of the
        per-trial success flag), ``measurement_rate`` (m / n) and
        ``sparsity_rate`` (s / m).
    """
    df_pt = (
        # Keep only the row for iteration s - 1 of each run — presumably the
        # final iteration when iterations are 0-indexed and one atom is chosen
        # per iteration (TODO confirm against the experiment code).
        df.filter((c("algorithm") == algorithm) & (c("iter") == c("sparsity") - 1))
        .with_columns(
            # NOTE(review): ratio of mse_x to norm_x as stored; confirm both are
            # on the same (squared / unsquared) scale so this is a relative error.
            (c("mse_x") / c("norm_x") < tol).alias("success"),
        )
        # `groupby` is spelled `group_by` from polars 0.19 on (the old name is
        # gone in 1.0); kept for the polars version this notebook targets.
        .groupby("experiment_number", maintain_order=True)
        .agg(
            # m, n, s and the noise level are assumed constant per experiment.
            c("m").first(),
            c("n").first(),
            c("sparsity").first(),
            c("noise_std").first(),
            c("success").mean().alias("success_rate"),
        )
        # measurement_rate is derived here rather than aggregated from `df`, so
        # the input does not need to carry that column.
        .with_columns(
            (c("m") / c("n")).alias("measurement_rate"),
            (c("sparsity") / c("m")).alias("sparsity_rate"),
        )
    )
    return df_pt

def _pivot_success_rate(df_pt, row, col):
    """Pivot `success_rate` into a pandas table indexed by `row`, one column per `col` value.

    The input is sorted by `row` descending and `col` ascending before pivoting;
    `aggregate_function="first"` keeps the first value if a (row, col) pair repeats.
    """
    return (
        df_pt.sort(by=[row, col], descending=[True, False])
        # NOTE: polars >= 1.0 renames `columns=` to `on=` (deprecated alias).
        .pivot(
            values="success_rate",
            index=row,
            columns=col,
            aggregate_function="first",
        )
        .to_pandas()
        .set_index(row, drop=True)
    )


def plot_phase_transition(df, algorithm, normalize_axes=False):
    """Draw a success-rate heatmap (phase diagram) for one algorithm on the current axes.

    Parameters
    ----------
    df : pl.DataFrame
        Raw results, passed to `get_phase_transition_data`; ``n`` is taken
        from its first row for the title.
    algorithm : str
        Algorithm to plot, e.g. ``"omp"`` or ``"ip"``; upper-cased in the title.
    normalize_axes : bool, default False
        If False, rows are m and columns are s. If True, rows are m / n and
        columns are s / m.
    """
    n = df["n"][0]
    df_pt = get_phase_transition_data(df, algorithm)
    if normalize_axes:
        # FIXME: s / m takes many distinct values over the (m, s) grid, so this
        # pivot is likely sparse (NaN cells); binning sparsity_rate would give
        # a regular grid. Rows are labelled by their m / n value via the index.
        tbl = _pivot_success_rate(df_pt, "measurement_rate", "sparsity_rate")
        sns.heatmap(tbl)
        plt.xlabel("s / m")
        plt.ylabel("m / n")
    else:
        tbl = _pivot_success_rate(df_pt, "m", "sparsity")
        sns.heatmap(tbl)
        plt.xlabel("Sparsity $s$")
        plt.ylabel("Number of measurements $m$")

    plt.title(f"Phase Transition for {algorithm.upper()} (n={n})")


In [None]:
# Phase diagram (m vs. s) for OMP on the small-n results.
plot_phase_transition(df_small, algorithm="omp")

In [None]:
# Phase diagram (m vs. s) for IP on the small-n results.
plot_phase_transition(df_small, algorithm="ip")

In [None]:
# Phase diagram (m vs. s) for OMP on the large-n results.
plot_phase_transition(df_large, algorithm="omp")

In [None]:
# Phase diagram (m vs. s) for IP on the large-n results.
plot_phase_transition(df_large, algorithm="ip")

In [None]:
def plot_probability_curves(df_small, df_large, save_file=None):
    """Plot exact-recovery probability versus m for OMP (lines) and IP (hollow markers).

    One panel per dataset (``df_small`` left, ``df_large`` right) with a shared
    y-axis, one OMP/IP curve pair per sparsity level ``s``, and a single figure
    legend placed below the panels.

    Parameters
    ----------
    df_small, df_large : pl.DataFrame
        Raw results, each passed to `get_phase_transition_data`.
    save_file : str or path-like, optional
        If given, the figure (including the legend) is also written here.
    """
    fig, axs = plt.subplots(1, 2, figsize=(13.0, 4.8), sharey=True)
    for k, (ax, df) in enumerate(zip(axs, [df_small, df_large])):
        df_pt_omp = get_phase_transition_data(df, "omp")
        df_pt_ip = get_phase_transition_data(df, "ip")
        n = df_pt_omp["n"][0]

        # Reset per panel: the shared legend is built from the last panel's
        # handles. Colours follow each axes' colour cycle in sorted-s order, so
        # the legend also describes the first panel only if both datasets have
        # the same sparsity levels — NOTE(review): confirm.
        labels = []
        lines = []
        for s in sorted(df_pt_omp["sparsity"].unique()):
            df_pt_at_s_omp = df_pt_omp.filter(c("sparsity") == s).sort("m")
            df_pt_at_s_ip = df_pt_ip.filter(c("sparsity") == s).sort("m")
            (omp_line,) = ax.plot(df_pt_at_s_omp["m"], df_pt_at_s_omp["success_rate"])
            # IP as hollow markers in the same colour as the OMP line for this s.
            (ip_line,) = ax.plot(
                df_pt_at_s_ip["m"],
                df_pt_at_s_ip["success_rate"],
                "o",
                fillstyle="none",
                color=omp_line.get_color(),
            )
            lines += [omp_line, ip_line]
            labels += [f"$s$={s} (OMP)", f"$s$={s} (IP)"]

        # Axis decorations do not depend on s, so set them once per panel.
        ax.set_xlabel("Number of measurements $m$")
        if k == 0:
            ax.set_ylabel("Probability of exact recovery")
        ax.set_title(f"$n$={n}")
        ax.grid(True)

    fig.suptitle("Recovery probabilities for IP and OMP")
    # Legend hangs below the axes (its anchor box lies under y = 0 in figure
    # coordinates); one column per sparsity level holds that level's OMP/IP
    # pair. `ncol` must be an int, hence floor division.
    fig.legend(
        lines,
        labels,
        loc="upper center",
        bbox_to_anchor=(0.0, -.1, 1., .102),
        ncol=max(1, len(lines) // 2),
    )
    if save_file:
        # bbox_inches="tight" grows the saved bounding box to include the
        # legend; without it the legend, which lies outside the canvas, is
        # cropped from the saved file.
        fig.savefig(save_file, bbox_inches="tight")

In [None]:
# Side-by-side OMP/IP recovery curves for both runs; also written to an EPS file.
plot_probability_curves(df_small=df_small, df_large=df_large, save_file="./recovery_probability.eps")