In [None]:
%matplotlib inline

In [None]:
%config InlineBackend.figure_format = "retina"

In [None]:
%cd ~/projects/ip-is-all-you-need

In [None]:
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm

from ip_is_all_you_need.plots import get_phase_transition_data

c = pl.col
sns.set()

In [None]:
# Raw per-trial results for the two problem sizes.
# The large-n run is restricted to fewer than 120 measurements.
df_large = (
    pl.read_parquet("./results_large/results.parquet")
    .filter(c("m") < 120)
)
df_small = pl.read_parquet("./results_small/results.parquet")

In [None]:
def get_phase_transition_data(df, algorithm, eps=1e-14):
    """Summarise per-experiment recovery statistics for one algorithm.

    NOTE: this notebook-local definition shadows the ``get_phase_transition_data``
    imported from ``ip_is_all_you_need.plots`` in the imports cell.

    Only the final iteration of each trial (``iter == sparsity - 1``) is kept.
    A trial counts as an exact recovery when ``mse_x / norm_x**2 < eps``.

    Parameters
    ----------
    df : pl.DataFrame
        Raw results with (at least) the columns ``algorithm``, ``iter``,
        ``sparsity``, ``mse_x``, ``norm_x``, ``experiment_number``, ``m``,
        ``n``, ``noise_std`` and ``iou``.
    algorithm : str
        Value of the ``algorithm`` column to keep (e.g. ``"omp"`` or ``"ip"``).
    eps : float, default 1e-14
        Threshold on the normalized reconstruction error below which a trial
        is counted as a success.

    Returns
    -------
    pl.DataFrame
        One row per ``experiment_number`` (in first-appearance order) with the
        experiment settings (``m``, ``n``, ``sparsity``, ``noise_std``), the mean
        ``iou`` and its 5th/95th percentiles (``iou_lo``/``iou_hi``),
        ``success_rate``, ``measurement_rate`` (m / n) and ``sparsity_rate`` (s / m).
    """
    df_pt = (
        # filter to only the last iteration
        df.filter((c("algorithm") == algorithm) & (c("iter") == c("sparsity") - 1))
        # define success as relative reconstruction error < eps
        .with_columns(
            (c("mse_x") / (c("norm_x") ** 2) < eps).alias("success"),
        )
        # one group per experiment; maintain_order keeps the output deterministic
        .groupby("experiment_number", maintain_order=True)
        # record the settings, success rate, and iou statistics
        .agg(
            c("m").first(),
            c("n").first(),
            c("sparsity").first(),
            c("noise_std").first(),
            c("iou").mean(),
            c("iou").quantile(0.05).alias("iou_lo"),
            c("iou").quantile(0.95).alias("iou_hi"),
            c("success").mean().alias("success_rate"),
        )
        # derived rates are computed from the aggregated settings, so the input
        # does not need to carry a ``measurement_rate`` column of its own
        .with_columns(
            (c("m") / c("n")).alias("measurement_rate"),
            (c("sparsity") / c("m")).alias("sparsity_rate"),
        )
    )
    return df_pt

def plot_phase_transition(df, algorithm, normalize_axes=False):
    """Heatmap of the exact-recovery rate of ``algorithm`` over the experiment grid.

    Draws on the current matplotlib axes.

    Parameters
    ----------
    df : pl.DataFrame
        Raw results accepted by ``get_phase_transition_data``; ``n`` is read from
        its first row and shown in the title.
    algorithm : str
        Algorithm to plot (``"omp"`` or ``"ip"``).
    normalize_axes : bool, default False
        If True, rows are indexed by m / n and columns by s / m instead of by
        m and s.  FIXME: s / m takes different values for every m, so the
        columns do not form a shared grid and the heatmap is mostly empty.
    """
    n = df["n"][0]
    df_pt = get_phase_transition_data(df, algorithm)

    if normalize_axes:
        row_key, col_key = "measurement_rate", "sparsity_rate"
        xlabel, ylabel = "s / m", "m / n"
    else:
        row_key, col_key = "m", "sparsity"
        xlabel, ylabel = "Sparsity $s$", "Number of measurements $m$"

    tbl = (
        # largest row value first so it appears at the top of the heatmap
        df_pt.sort(by=[row_key, col_key], descending=[True, False])
        .pivot(
            values="success_rate",
            index=row_key,
            columns=col_key,
            aggregate_function="first",
        )
        .to_pandas()
        # use the row key as the index so the y tick labels show its values
        # (the normalized branch previously dropped it and showed row numbers)
        .set_index(row_key)
    )
    sns.heatmap(tbl)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(f"Phase Transition for {algorithm.upper()} (n={n})")


In [None]:
# OMP phase transition, small problem size
plot_phase_transition(df_small, algorithm="omp")

In [None]:
# IP phase transition, small problem size
plot_phase_transition(df_small, algorithm="ip")

In [None]:
# OMP phase transition, large problem size
plot_phase_transition(df_large, algorithm="omp")

In [None]:
# IP phase transition, large problem size
plot_phase_transition(df_large, algorithm="ip")

In [None]:
def _plot_metric_curves(df_small, df_large, column, ylabel, save_file=None):
    """Plot ``column`` versus m for OMP and IP, one panel per problem size.

    Shared implementation of ``plot_probability_curves`` and ``plot_iou_curves``.
    Every sparsity level gets one colour: OMP is drawn as a line and IP as open
    circles in the same colour, so the two algorithms can be compared directly.

    Parameters
    ----------
    df_small, df_large : pl.DataFrame
        Raw results accepted by ``get_phase_transition_data``; drawn in the left
        and right panel respectively.
    column : str
        Column of the ``get_phase_transition_data`` output to plot on the y axis.
    ylabel : str
        Label for the shared y axis.
    save_file : str or path-like, optional
        If given, the figure is saved there (tight bounding box).
    """
    fig, axs = plt.subplots(1, 2, figsize=(13.0, 4.8), sharey=True)
    lines, labels = [], []
    for k, (ax, df) in enumerate(zip(axs, [df_small, df_large])):
        df_pt_omp = get_phase_transition_data(df, "omp")
        df_pt_ip = get_phase_transition_data(df, "ip")
        n = df_pt_omp["n"][0]

        # The legend is built from the last panel only. Both panels walk the
        # same colour cycle, so it is correct as long as they share sparsities.
        lines, labels = [], []
        for s in sorted(df_pt_omp["sparsity"].unique()):
            at_s_omp = df_pt_omp.filter(c("sparsity") == s).sort("m")
            at_s_ip = df_pt_ip.filter(c("sparsity") == s).sort("m")
            (omp_line,) = ax.plot(at_s_omp["m"], at_s_omp[column])
            (ip_line,) = ax.plot(
                at_s_ip["m"],
                at_s_ip[column],
                "o",
                fillstyle="none",
                color=omp_line.get_color(),
            )
            lines += [omp_line, ip_line]
            labels += [f"$s$={s} (OMP)", f"$s$={s} (IP)"]

        # per-panel decorations, set once rather than on every sparsity level
        ax.set_xlabel("Number of measurements $m$")
        if k == 0:
            ax.set_ylabel(ylabel)
        ax.set_title(f"$n$={n}")
        ax.grid(True)

    fig.legend(
        lines,
        labels,
        loc="upper center",
        bbox_to_anchor=(0.0, -.1, 1., .102),
        # one column per sparsity level (OMP above IP); must be a positive int
        ncol=max(1, len(lines) // 2),
    )
    if save_file:
        fig.savefig(save_file, bbox_inches="tight")


def plot_probability_curves(df_small, df_large, save_file=None):
    """Probability of exact recovery versus m for OMP and IP (two panels: small / large n)."""
    _plot_metric_curves(
        df_small,
        df_large,
        column="success_rate",
        ylabel="Probability of exact recovery",
        save_file=save_file,
    )


def plot_iou_curves(df_small, df_large, save_file=None):
    """Mean IoU versus m for OMP and IP (two panels: small / large n).

    Previously this was a verbatim copy of ``plot_probability_curves`` and
    plotted ``success_rate``; it now plots the ``iou`` column.
    """
    _plot_metric_curves(
        df_small,
        df_large,
        column="iou",
        ylabel="Mean IoU",
        save_file=save_file,
    )

In [None]:
# Figure output directory (paper sources). Built from the home directory rather
# than a hard-coded absolute user path so the cell also runs on other machines.
FIGURE_DIR = Path.home() / "projects" / "NeurIPS2023-IP-OMP"
plot_probability_curves(
    df_small, df_large, save_file=FIGURE_DIR / "recovery_probability.eps"
)

In [None]:
# Per-experiment IP summary for the small problem size (for interactive inspection)
df_pt_small_ip = get_phase_transition_data(df_small, algorithm="ip")

In [None]:
def rel_mse(args):
    """Relative MSE between the IP and OMP estimates of one trial.

    ``args`` is a mapping (a row struct when used via ``pl.struct(...).apply``)
    with keys ``x_hat_ip`` and ``x_hat_omp`` (array-likes of equal length) and
    ``norm_x``.  Returns ``mean((x_hat_ip - x_hat_omp) ** 2) / norm_x ** 2``.
    """
    diff = np.array(args["x_hat_ip"]) - np.array(args["x_hat_omp"])
    mean_sq_diff = np.mean(diff ** 2)
    return mean_sq_diff / args["norm_x"] ** 2


def _pair_ip_omp_estimates(results_file):
    """Load one results file and compute the per-trial IP-vs-OMP relative MSE.

    Keeps only the final iteration (``iter == sparsity - 1``), pivots on
    ``algorithm`` so each (trial, iter) row holds both ``x_hat_ip`` and
    ``x_hat_omp``, and evaluates ``rel_mse`` row-wise.

    NOTE(review): the pivot index is (trial, iter) only, which assumes each file
    holds a single experiment; if a file held several, rows would collide and
    ``aggregate_function="first"`` would silently keep one — TODO confirm.
    """
    return (
        pl.read_parquet(results_file)
        .filter(c("iter") == c("sparsity") - 1)
        .pivot(
            [
                "experiment_number",
                "m",
                "n",
                "sparsity",
                "x_hat",
                "norm_x",
            ],
            index=["trial", "iter"],
            columns="algorithm",
            aggregate_function="first",
        )
        # settings are identical for both algorithms; keep the OMP copy only
        .drop(
            "experiment_number_algorithm_ip",
            "m_algorithm_ip",
            "n_algorithm_ip",
            "sparsity_algorithm_ip",
            "norm_x_algorithm_ip",
        )
        .rename(
            {
                "experiment_number_algorithm_omp": "experiment_number",
                "m_algorithm_omp": "m",
                "n_algorithm_omp": "n",
                "sparsity_algorithm_omp": "sparsity",
                "norm_x_algorithm_omp": "norm_x",
                "x_hat_algorithm_ip": "x_hat_ip",
                "x_hat_algorithm_omp": "x_hat_omp",
            }
        )
        .select(
            "trial",
            "iter",
            "experiment_number",
            "m",
            "n",
            "sparsity",
            pl.struct(pl.all()).apply(rel_mse).alias("rel_mse_ip_omp"),
        )
    )


def get_relative_mse_ip_omp(results_dir):
    """Aggregate the IP-vs-OMP relative MSE per experiment.

    Reads ``<sub_dir>/results.parquet`` for every sub-directory of
    ``results_dir``, computes the per-trial relative MSE between the IP and OMP
    estimates, and summarises it per ``experiment_number`` (mean plus 5th/95th
    percentiles).

    Side effect: the summary is also written to
    ``<results_dir>/rel_mse_ip_omp.parquet``.

    Parameters
    ----------
    results_dir : str or path-like
        Directory whose sub-directories each contain a ``results.parquet``.

    Returns
    -------
    pl.DataFrame
        Columns ``experiment_number``, ``m``, ``n``, ``sparsity``,
        ``rel_mse_ip_omp``, ``rel_mse_ip_omp_05p``, ``rel_mse_ip_omp_95p``,
        sorted by ``experiment_number``.

    Raises
    ------
    ValueError
        If ``results_dir`` has no sub-directories.
    """
    results_dir = Path(results_dir)
    # sorted for a deterministic processing order (iterdir order is arbitrary)
    directories = sorted(d for d in results_dir.iterdir() if d.is_dir())
    if not directories:
        raise ValueError(f"no experiment sub-directories found in {results_dir}")

    # each file is reduced to final-iteration rows before concatenation, which
    # keeps the full per-iteration estimates out of memory
    per_trial = [
        _pair_ip_omp_estimates(directory / "results.parquet")
        for directory in tqdm(directories)
    ]

    rel_mse_ip_omp = (
        pl.concat(per_trial)
        .groupby("experiment_number")
        .agg(
            c("m").first(),
            c("n").first(),
            c("sparsity").first(),
            c("rel_mse_ip_omp").mean(),
            c("rel_mse_ip_omp").quantile(0.05).alias("rel_mse_ip_omp_05p"),
            c("rel_mse_ip_omp").quantile(0.95).alias("rel_mse_ip_omp_95p"),
        )
        .sort("experiment_number")
    )

    rel_mse_ip_omp.write_parquet(results_dir / "rel_mse_ip_omp.parquet")

    return rel_mse_ip_omp

In [None]:
# IP-vs-OMP agreement for both problem sizes; each call also writes
# rel_mse_ip_omp.parquet into its results directory.
rel_mse_small, rel_mse_large = (
    get_relative_mse_ip_omp(results_dir)
    for results_dir in ("./results_small", "./results_large")
)

In [None]:
def plot_agreement(rel_mse, save_path=None):
    """Heatmap of the mean IP-vs-OMP relative MSE (in dB) over (m, sparsity).

    Parameters
    ----------
    rel_mse : pl.DataFrame
        Output of ``get_relative_mse_ip_omp`` (columns ``m``, ``n``,
        ``sparsity``, ``rel_mse_ip_omp``, ``rel_mse_ip_omp_05p``,
        ``rel_mse_ip_omp_95p``).
    save_path : str or path-like, optional
        If given, the figure is saved there (tight bounding box).

    Cells where the mean relative MSE is exactly 0 (log10 = -inf) are left
    blank rather than breaking the colour scale.
    """
    n = rel_mse["n"][0]
    wide = (
        rel_mse.pivot(
            values=["rel_mse_ip_omp", "rel_mse_ip_omp_05p", "rel_mse_ip_omp_95p"],
            index="m",
            columns="sparsity",
            aggregate_function="first",
        )
        .sort("m", descending=True)
        # keep only the mean columns; ``\d+`` (not ``\d{2}``) so single- and
        # three-digit sparsities are not silently dropped
        .select("m", c(r"^rel_mse_ip_omp_sparsity_\d+$"))
        .to_pandas()
        .set_index("m")
    )
    # label columns by their integer sparsity and order them numerically
    wide.columns = [int(name.rsplit("_", 1)[-1]) for name in wide.columns]
    wide = wide.sort_index(axis=1)

    with np.errstate(divide="ignore"):
        rel_mse_db = 10 * np.log10(wide)
    # -inf would make the colour-scale minimum -inf; mask those cells instead
    rel_mse_db = rel_mse_db.replace(-np.inf, np.nan)

    ax = sns.heatmap(rel_mse_db, cbar_kws={"label": "Relative MSE (dB)"})
    ax.tick_params(axis="x", labelrotation=0)
    ax.set_xlabel("Sparsity $s$")
    ax.set_ylabel("Number of measurements $m$")
    ax.set_title(f"IP vs. OMP agreement (n={n})")
    if save_path:
        ax.figure.savefig(save_path, bbox_inches="tight")

In [None]:
# IP-vs-OMP agreement, small problem size
plot_agreement(rel_mse=rel_mse_small)

In [None]:
# IP-vs-OMP agreement, large problem size
plot_agreement(rel_mse=rel_mse_large)