In [None]:
!pip uninstall basicpy -y
!pip install -e ../

In [None]:
!pip install pandas

In [None]:
from basicpy import BaSiC
from basicpy import datasets as bdata
import numpy as np
from basicpy.tools.dct_tools import JaxDCT
from matplotlib import pyplot as plt
import pandas as pd

In [None]:
# Load the `wsi_brain` example dataset from basicpy.datasets.
# NOTE(review): `data` is rebound to each test dataset by the benchmark loop in
# the next cell, so this load appears unused there — consider removing it.
data = bdata.wsi_brain()

In [None]:
# Benchmark the "approximate" and "ladmap" fitting modes, each with and without
# darkfield estimation, on every rescaled test dataset. For every fit the same
# objective is evaluated (weighted L1 residual + flatfield/darkfield DCT
# smoothness + darkfield sparsity) so the two modes can be compared directly.


def make_basic_models():
    """Build fresh, unfitted BaSiC models for a single dataset.

    Returns
    -------
    tuple of (list, list)
        ``(approximates, ladmaps)``; each list holds a model with
        ``get_darkfield=False`` followed by one with ``get_darkfield=True``.
    """
    approximates = [
        BaSiC(
            fitting_mode="approximate",
            working_size=None,
            get_darkfield=d,
            max_reweight_iterations=1,
        )
        for d in [False, True]
    ]
    ladmaps = [
        BaSiC(
            fitting_mode="ladmap",
            working_size=None,
            get_darkfield=d,
            # NOTE(review): fixed regularisation weights for the ladmap runs;
            # their provenance is not documented here — confirm before tuning.
            smoothness_flatfield=100 / 80000,
            smoothness_darkfield=0.2,
            sparse_cost_darkfield=0.2,
            max_reweight_iterations=1,
        )
        for d in [False, True]
    ]
    return approximates, ladmaps


def lagrangian_terms(b, images):
    """Evaluate each weighted term of the objective for a fitted BaSiC model.

    Parameters
    ----------
    b : BaSiC
        A model already fitted on ``images``.
    images : array
        The image stack ``b`` was fitted on.

    Returns
    -------
    list of float
        ``[residual, flatfield smoothness, darkfield smoothness,
        darkfield sparsity]``, each multiplied by its weight.
    """
    res = (
        images
        - b.baseline[:, np.newaxis, np.newaxis] * b.flatfield[np.newaxis]
        - b.darkfield
    )
    # Sanity check that the recomputed residual agrees with the one stored by
    # the fit. NOTE(review): atol=1e5 makes this a very loose check.
    assert np.allclose(res, b._residual, atol=100000, rtol=1e-2)
    terms = [
        np.sum(np.abs(res)),
        b._smoothness_flatfield * np.sum(np.abs(JaxDCT.dct2d(b._S))),
        b._smoothness_darkfield * np.sum(np.abs(JaxDCT.dct2d(b._D_R))),
        b._sparse_cost_darkfield * np.sum(np.abs(b._D_R)),
    ]
    # float() normalises numpy / jax scalars to plain Python floats.
    return [float(t) for t in terms]


records = []
for data_key in bdata.RESCALED_TEST_DATA_PROPS.keys():
    data = bdata.fetch(data_key)
    # Fresh models per dataset so no fitted state carries over between
    # datasets and results do not depend on iteration order.
    basic_approximates, basic_ladmaps = make_basic_models()
    for b in basic_approximates + basic_ladmaps:
        b.fit(data)
        lagrangians = lagrangian_terms(b, data)
        records.append(
            {
                "data_key": data_key,
                "lagrangian_value": sum(lagrangians),
                "fitting_mode": b.fitting_mode,
                "get_darkfield": b.get_darkfield,
                "smoothness_flatfield": b._smoothness_flatfield,
                "smoothness_darkfield": b._smoothness_darkfield,
                "sparse_cost_darkfield": b._sparse_cost_darkfield,
                # Individual weighted objective terms, kept for inspection.
                "residual_term": lagrangians[0],
                "flatfield_smoothness_term": lagrangians[1],
                "darkfield_smoothness_term": lagrangians[2],
                "darkfield_sparsity_term": lagrangians[3],
            }
        )
result_df = pd.DataFrame.from_records(records)

In [None]:
# Add a human-readable label per (fitting mode, darkfield) combination for the
# bar-chart tick labels, and order rows by dataset, darkfield setting, then
# fitting mode so every panel shows its bars in the same order.
# Re-running this cell is idempotent: `label` is overwritten, not appended to.
result_df = result_df.assign(
    label=result_df["fitting_mode"]
    + result_df["get_darkfield"].apply(
        lambda x: " with darkfield" if x else " without darkfield"
    )
).sort_values(["data_key", "get_darkfield", "fitting_mode"])
result_df

In [None]:
# One bar chart per dataset comparing the total objective value reached by each
# fitting-mode / darkfield combination. The panel count follows the data rather
# than being hardcoded, so no dataset is silently dropped by zip().
dataset_groups = list(result_df.groupby("data_key"))
fig, axes = plt.subplots(
    1,
    len(dataset_groups),
    figsize=(2 * len(dataset_groups), 3),
    gridspec_kw=dict(wspace=0.5),
    squeeze=False,  # always 2-D, even for a single dataset
)
axes = axes[0]
for ax, (image_key, grp) in zip(axes, dataset_groups):
    ax.bar(
        grp["label"],
        grp["lagrangian_value"],
    )
    # Title each panel so the datasets can be told apart.
    ax.set_title(str(image_key), fontsize="small")
    ax.xaxis.set_tick_params(rotation=90)
axes[0].set_ylabel("cost function value");

In [None]:
# Ratio of the ladmap objective to the approximate objective for every dataset
# and darkfield setting (ratio < 1 means ladmap reached the lower cost).
# pivot() raises if a (dataset, darkfield, mode) combination appears twice.
lagrangian_by_mode = result_df.pivot(
    index=["data_key", "get_darkfield"],
    columns="fitting_mode",
    values="lagrangian_value",
)
# Fail loudly (instead of with an opaque IndexError) if a fit is missing.
assert not lagrangian_by_mode.isna().any().any(), "missing fit for some dataset/mode"
summarized_df = (
    (lagrangian_by_mode["ladmap"] / lagrangian_by_mode["approximate"])
    .rename("ratio")
    .reset_index()
)

In [None]:
# Display the ladmap/approximate cost ratio for each dataset and darkfield setting.
summarized_df

In [None]:
# NOTE(review): leftover inspection cell — `grp` is the last group leaked from a
# groupby loop in an earlier cell, not a meaningful result; consider removing.
grp

In [None]:
# Grouped bar chart of the ladmap/approximate cost ratio per dataset, with and
# without darkfield estimation. The dashed line marks equal cost (ratio 1).
# The x-axis extent and tick positions are derived from the data, not hardcoded.
width = 0.35
dataset_keys = sorted(summarized_df["data_key"].unique())
# Map each dataset to a fixed tick position so bars stay aligned even if one
# darkfield group lacks a dataset.
positions = {key: i for i, key in enumerate(dataset_keys)}
fig, ax = plt.subplots()
for get_darkfield, grp in summarized_df.groupby("get_darkfield"):
    # "without darkfield" bars sit left of the tick, "with darkfield" bars right.
    offset = width / 2 if get_darkfield else -width / 2
    ax.bar(
        grp["data_key"].map(positions) + offset,
        grp["ratio"],
        width,
        label=("with " if get_darkfield else "without ") + "darkfield",
    )
ax.set_xticks(np.arange(len(dataset_keys)))
ax.set_xticklabels(dataset_keys, rotation=10, ha="right")
ax.set_xlim(-0.5, len(dataset_keys) - 0.5)
ax.axhline(1.0, ls="--", color="k")
ax.set_ylabel("cost function ratio")
ax.set_xlabel("dataset")
ax.legend(loc="lower right");