In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("..")

import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

from multiprocessing import Pool
import traceback
from functional import seq

from agg_utils.path_conf import figure_root_dir

In [3]:
# Reference renderings ("reference_*.png") keyed by their path as a string.
# Each image is inverted (255 - 255 * img) and stored as float32; plt.imread returns
# PNG pixels in [0, 1], so drawn (dark) pixels end up with large values.
ref_dict = dict(
    (str(ref_png), (255 - 255 * plt.imread(ref_png)).astype(np.float32))
    for ref_png in tqdm(list((figure_root_dir / "plotly_preselect").glob("reference_*.png")))
)

# Every other PNG in the same directory (name does not contain "reference") is an
# aggregated figure that will be scored against its reference.
files = [
    png
    for png in (figure_root_dir / "plotly_preselect").glob("*.png")
    if "reference" not in png.name
]

  0%|          | 0/18 [00:00<?, ?it/s]

In [10]:
from __future__ import annotations
from pathlib import Path
from agg_utils.metrics import get_png_path, _get_mse_series, _get_dssim_series, _get_or_conv_mask

def compute_dssim_plotly(
    agg_path: str | Path, ref_dict, mse: bool, **ssim_kwargs
) -> pd.Series:
    """Compute the DSSIM, MSE, and MAE for a Plotly figure.

    More specifically, this function computes the DSSIM for a reference figure with
    a line width of 1 and the same line width as the original figure.

    .. NOTE::
        This method can also be used for the `Bokeh` toolit

    parameters
    ----------
    idx_r : pd.Series
        a row of the aggregation dataframe
    ref_dict : dict
        a dictionary of reference images, with the path as key and the image as value
    ssim_kwargs : dict
        keyword arguments for the skimage.metrics.structural_similarity function

    returns
    -------
    pd.Series
        a row of the aggregation dataframe with the MSSIM and DSSIM values added
        More specifically, this function adds the following columns:
        - DSSIM_same_lw
        - DSSIM_masked_same_lw
        - SSIM_same_lw
        - SSIM_masked_same_lw

    """
    agg_path: Path = Path(agg_path)
    toolkit: str = agg_path.parent.name
    splits = (agg_path.name).split(".")[0].split("_")
    aggregator, data, n, n_out = splits[:4]
    ls, lw = splits[-2:]
    factor = None
    if len(splits) == 7:
        factor = splits[-3]
        factor = int(factor[7:])

    ls = ls[3:]
    lw = lw[3:]

    # fmt: off
    reference_path_lw_same = get_png_path(toolkit, data, n, n_out, "reference", lw, ls)

    dim = 1
    win_size = 11

    # read the images
    agg = (255 - 255 * plt.imread(agg_path)).astype(np.float32)
    ref_same_lw = ref_dict.get(str(reference_path_lw_same), None)

    # fmt: off
    or_conv_same_lw = _get_or_conv_mask(agg[:, :, dim], ref_same_lw[:, :, dim], win_size)
    mse_list = (
        [
            _get_mse_series(agg, ref_same_lw, dim, or_conv_same_lw).add_suffix("_same_lw"),
        ]
        if mse
        else []
    )

    ssim_dict_g = dict(ssim_kwargs)
    ssim_dict_g["gaussian_weights"] = True

    ssim_dict_no_g = dict(ssim_kwargs)
    ssim_dict_no_g["gaussian_weights"] = False

    # fmt: off
    return pd.concat(
        [
            pd.Series(
                index=["toolkit", "data", "aggregator", "n", "n_out", "ls", "lw", "factor"],
                data=[toolkit, data, aggregator, n, n_out, ls, lw, factor],
            ),
            _get_dssim_series( agg, ref_same_lw, dim, or_conv_same_lw, **ssim_dict_no_g).add_suffix("_same_lw"),
            _get_dssim_series( agg, ref_same_lw, dim, or_conv_same_lw, **ssim_dict_g).add_suffix("_gaussian_same_lw"),
            *mse_list,
        ],
    )

In [11]:
def wrap_compute_dssim_plotly(agg_path):
    """Single-argument adapter around ``compute_dssim_plotly`` for ``Pool.imap_unordered``.

    Scores ``agg_path`` against the module-level ``ref_dict`` with the MSE metrics
    switched on, and returns the resulting ``pd.Series``.
    """
    # imap_unordered hands exactly one item (a path) to each call
    return compute_dssim_plotly(agg_path, ref_dict, True)

# Score every aggregated figure in parallel. `imap_unordered` yields rows as soon as
# any worker finishes, so the row order in `out` does not follow `files`.
out = []
with Pool(processes=8) as pool:
    try:
        # Append row by row (instead of a single list comprehension) so results that
        # were already computed stay in `out` if a later figure fails.
        for row in tqdm(pool.imap_unordered(wrap_compute_dssim_plotly, files), total=len(files)):
            out.append(row)
    except BaseException:
        # Stop the remaining workers right away and surface the real error.
        # Swallowing it here only made the cell fail later with a misleading
        # "No objects to concatenate" from `pd.concat([])`.
        pool.terminate()
        raise
    finally:
        pool.close()
        pool.join()


# One row per scored figure: place the per-figure Series side by side, then transpose.
df_out = pd.concat(out, axis=1).T
# Optional CSV export (disabled): df_out.to_csv("../loc_data/plotly_metrics_v4.csv", index=False)

cat_cols = ['toolkit', 'data', 'aggregator', 'ls']
int_cols = ['n', 'lw', 'n_out']

# Cast all columns in one go: the metadata columns above get their dtypes, and every
# remaining column (the metric values and `factor`) becomes float.
df_out = df_out.astype({
    **dict.fromkeys(cat_cols, 'category'),
    **dict.fromkeys(int_cols, 'int'),
    **dict.fromkeys(set(df_out.columns).difference(cat_cols + int_cols), 'float'),
})
df_out.to_parquet("../loc_data/plotly_metrics_preselect.parquet")

  0%|          | 0/44694 [00:00<?, ?it/s]