# `synthflow` Report

In [None]:
# Parameters
# All four values are placeholders (None) in this cell; the cells below read
# them as notebook-level globals.
# NOTE(review): presumably these are overwritten by a parameter-injection tool
# before execution — confirm how the notebook is launched.

# Run identifier: passed to `wandb.Api().run(...)` in online mode, or used as a
# glob fragment to find the local run directory in offline mode.
run_path = None
# Base directory under which online artifacts are downloaded
# (`<root_path>/artifacts/<run_path with "/" replaced by "_">`).
root_path = None
# "offline" selects the local-log artifact iterator; any other value selects the
# online one. The value "online" also makes DF_JSON_LOAD append ".table.json".
mode = None
# When truthy, the Local Outlier Factor scoring loop is skipped.
is_minimal = None

In [None]:
import itertools as it
import json
import os
import pickle
import re
import string
from collections import defaultdict
from pathlib import Path
from pprint import pprint

import matplotlib.pylab as plt
import missingno as msno
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import wandb
from IPython.display import Markdown, display
from matplotlib.colors import ListedColormap
from matplotlib_venn import venn3
from scipy import stats
from sklearn.neighbors import LocalOutlierFactor
from tqdm import trange
from wandb.sdk.internal.datastore import DataStore

from synthflow.birth import EVALUATION_COLUMN_BINS
from synthflow.evaluation import numerify
from synthflow.evaluation.faithfulness import scale_columns_by_weights
from synthflow.evaluation.utility.analysis import _bin_table
from synthflow.utils import compute_complete_counts, get_single_row_by_query, n_way_gen

In [None]:
# Directory that contains the local `wandb/` run folders: taken from the
# WANDB_DIR environment variable, falling back to the current working directory.
WANDB_DIR = os.environ.get("WANDB_DIR", os.getcwd())
WANDB_PATH = Path(WANDB_DIR)
print(f"{WANDB_DIR=}")

# Two-colour map (tab10 entries 2 and 1) used for the synth/processed bar plots.
SYNTH_PROCESSED_CMAP = ListedColormap([plt.cm.tab10(2), plt.cm.tab10(1)])

# Bytes pattern used by the offline artifact iterator to find where a file path
# starts inside a raw record: a digit, an uppercase letter, then one more group.
# NOTE(review): in the compiled pattern `\|` is an escaped literal pipe, so the
# group's alternatives are "<letter>:|/" or "." (any byte except newline), i.e.
# the group matches almost anything. If a Windows drive prefix such as `C:\`
# was intended, the backslash needs one more level of escaping — confirm.
PATH_RE = re.compile(b"\\d[A-Z]([a-zA-Z]:\\|/|.)")

In [None]:
def PICKLE_LOAD(path):
    """Read the file at ``path`` and return the object unpickled from its bytes.

    ``path`` must offer ``read_bytes()`` (e.g. a ``pathlib.Path``).
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted artifacts.
    payload = path.read_bytes()
    return pickle.loads(payload)


def DF_PICKLE_LOAD(path):
    """Return the pandas object stored in the pickle file at ``path``.

    Thin wrapper around ``pd.read_pickle``; the artifact loader treats the
    result as a DataFrame.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted artifacts.
    frame = pd.read_pickle(path)
    return frame


def DF_JSON_LOAD(path):
    """Build a DataFrame from a JSON table file with ``"data"`` and ``"columns"`` keys.

    When the notebook-level ``mode`` is ``"online"``, the file actually read is
    ``path`` with ``".table.json"`` appended; otherwise ``path`` is read as is.
    """
    table_path = Path(f"{path!s}.table.json") if mode == "online" else path
    payload = json.loads(table_path.read_bytes())
    return pd.DataFrame(payload["data"], columns=payload["columns"])


# Artifact name -> loader function. Artifacts whose name is missing from this
# mapping are skipped by the loading loop; a loader of None marks a known
# artifact that is recognised but not loaded.
ARTIFACTS_METADATA = {
    "model": None,
    # DataFrames stored as pickles.
    **dict.fromkeys(
        ("real", "processed_real", "subsampled_real", "synth", "constraints"),
        DF_PICKLE_LOAD,
    ),
    # Tables stored as JSON with "data" and "columns" keys.
    **dict.fromkeys(
        ("utility", "privacy", "faithfulness", "acceptance", "dp_acceptance"),
        DF_JSON_LOAD,
    ),
    # Other pickled Python objects.
    **dict.fromkeys(("column_weights", "dp_info", "transcript"), PICKLE_LOAD),
}


def find_non(raw, base):
    """Return the positions of the items of ``raw`` that are not contained in ``base``.

    Positions are returned in ascending order; the list is empty when every
    item of ``raw`` is in ``base`` (or ``raw`` is empty).
    """
    positions = []
    for position, item in enumerate(raw):
        if item not in base:
            positions.append(position)
    return positions


def clean_name(name):
    """Reduce an artifact name to its bare identifier.

    Two steps, applied in this order:

    1. keep only the text after the last ``"-"`` (if any);
    2. keep only the text before the first ``":"`` (if any).

    For example ``"run-1abc-utility:v3"`` becomes ``"utility"``. Names with
    more than one ``":"`` are handled too (everything from the first colon on
    is dropped) instead of raising ``ValueError`` on tuple unpacking.
    """
    # `rsplit(..., 1)[-1]` returns the whole string when there is no "-".
    name = name.rsplit("-", 1)[-1]
    # `partition` returns the whole string as the head when there is no ":".
    name = name.partition(":")[0]
    return name


def artificat_online_iter(run_path):
    """Download the artifacts logged by a W&B run and yield ``(name, path)`` pairs.

    Every artifact is downloaded into
    ``<root_path>/artifacts/<run_path with "/" replaced by "_">`` (created if
    missing). ``name`` is the cleaned artifact name and ``path`` is that
    directory joined with ``name``.
    """
    download_dir = Path(root_path) / "artifacts" / run_path.replace("/", "_")
    download_dir.mkdir(parents=True, exist_ok=True)

    run = wandb.Api().run(run_path)

    for artifact in run.logged_artifacts():
        short_name = clean_name(artifact.name)
        artifact.download(root=download_dir)
        yield short_name, download_dir / short_name


def artifact_offline_hack_iter(run_path):
    """Yield ``(name, path)`` pairs for artifacts found in a local W&B run log.

    Looks under ``WANDB_PATH / "wandb"`` for exactly one ``*.wandb`` file in a
    directory whose name contains ``run_path``, scans its raw records, and for
    every record containing ``b"storageLayout"`` extracts an artifact name and
    a file path by byte-level slicing.

    NOTE(review): the slicing relies on how these records happen to be laid
    out (exactly two spaces per record, a name followed by digits, a path
    preceded by a digit and an uppercase letter). It is a hack, not a parser.

    Raises:
        AssertionError: if the number of matching ``*.wandb`` files is not one;
            the message lists the files that were found.
    """
    run_ds_paths = list((WANDB_PATH / "wandb").glob(f"*{run_path}*/*.wandb"))
    # Put the candidates into the assertion message itself; `print(...)` as the
    # message would evaluate to None and leave the AssertionError empty.
    assert len(run_ds_paths) == 1, run_ds_paths
    run_ds_path = run_ds_paths[0]

    ds = DataStore()
    ds.open_for_scan(run_ds_path)

    while True:
        data = ds.scan_data()
        # `None` marks the end of the log.
        if data is None:
            break
        if b"storageLayout" in data:
            # Expects exactly three space-separated chunks; the middle one is unused.
            name_raw, _, path_raw = data.split(b" ")

            # Drop trailing ASCII digits, then keep the run of letters and
            # underscores at the end of the chunk: the name starts right after
            # the last byte that is neither a letter nor "_".
            name_raw = name_raw.rstrip(string.digits.encode("ascii"))
            name_start_index = (
                find_non(name_raw, (string.ascii_letters + "_").encode("ascii"))[-1] + 1
            )
            name = name_raw[name_start_index:].decode("ascii")
            name = clean_name(name)

            # TODO: refactor
            # The path starts two bytes after the start of the PATH_RE match
            # (i.e. after the digit and the uppercase letter) ...
            path_start_index = PATH_RE.search(path_raw).start() + 2
            path_raw = path_raw[path_start_index:]
            # ... and ends one byte before the first non-printable byte.
            path_end_index = find_non(path_raw, string.printable.encode("ascii"))[0] - 1
            path = path_raw[:path_end_index].decode("ascii")

            yield name, Path(path)

In [None]:
# Pick the artifact iterator that matches the run mode.
artificat_iter = (
    artifact_offline_hack_iter if mode == "offline" else artificat_online_iter
)

for name, path in artificat_iter(run_path):
    # Only artifacts listed in ARTIFACTS_METADATA are considered.
    if name not in ARTIFACTS_METADATA:
        print(f"[SKIPPED] {name}")
        continue

    load_fn = ARTIFACTS_METADATA[name]

    # A loader of None marks a known artifact that is not loaded.
    if load_fn is not None:
        # NOTE: `df` holds whatever the loader returns, not necessarily a DataFrame.
        df = load_fn(path)
        # Objects from the DataFrame loaders get a `_df` suffix (e.g. `synth_df`);
        # everything else keeps the artifact name (e.g. `dp_info`).
        if load_fn in (DF_PICKLE_LOAD, DF_JSON_LOAD):
            obj_name = f"{name}_df"
        else:
            obj_name = name
        # A USEFUL DANGEROUS HACK!
        # Binds the loaded object as a notebook-level variable. Later cells use
        # these names directly, so a missing artifact shows up as a NameError there.
        globals()[obj_name] = df

        print(obj_name)


# In offline mode the run configuration is left empty; otherwise it is fetched
# through the W&B API.
config = {} if mode == "offline" else wandb.Api().run(run_path).config

## Datasets

In [None]:
# Per-column completeness (non-missing values) bar chart of the real dataset.
msno.bar(real_df)

In [None]:
# Per-column completeness (non-missing values) bar chart of the synthetic dataset.
msno.bar(synth_df)

In [None]:
# The four dataset variants compared throughout the report, keyed by the short
# label used in plots and tables.
datasets = {
    "real": real_df,
    "synth": synth_df,
    "processed": processed_real_df,
    "subsampled": subsampled_real_df,
}
# Integer-typed version of every dataset: the result of `numerify`, cast to int.
# NOTE(review): `numerify` comes from synthflow.evaluation; presumably it maps
# non-numeric columns to numbers — confirm.
numerified_datasets = {
    name: numerify(dataset).astype(int) for name, dataset in datasets.items()
}
# Numerified datasets passed through `_bin_table` with EVALUATION_COLUMN_BINS.
binnified_datasets = {
    name: _bin_table(dataset, EVALUATION_COLUMN_BINS)
    for name, dataset in numerified_datasets.items()
}

In [None]:
# Stack all numerified datasets into one long frame; the added `name` column
# records which dataset each row came from.
unified_df = pd.concat(
    [dataset.assign(name=name) for name, dataset in numerified_datasets.items()],
    ignore_index=True,
)
# Same stacking for the binned datasets.
# NOTE: later cells apply row masks built from one of these two frames to the
# other, which only works if both have the same rows in the same order.
binned_unified_df = pd.concat(
    [dataset.assign(name=name) for name, dataset in binnified_datasets.items()],
    ignore_index=True,
)

# Columns of the synthetic dataset, reused by the per-column loops below.
columns = synth_df.columns

In [None]:
# Synth and processed-real rows stacked together (original, non-numerified
# values), labelled by the `name` column.
synth_processed_unifed_df = pd.concat(
    [synth_df.assign(name="synth"), processed_real_df.assign(name="processed")],
    ignore_index=True,
)
# The same two datasets stacked in their numerified form.
synth_processed_numerified_unifed_df = pd.concat(
    [
        numerified_datasets["synth"].assign(name="synth"),
        numerified_datasets["processed"].assign(name="processed"),
    ],
    ignore_index=True,
)

## Differential Privacy Info

In [None]:
# Show the `dp_info` object bound by the artifact-loading cell (unpickled as is).
dp_info

## Model

In [None]:
# Render the model settings stored in the run configuration.
# `config` is an empty dict in offline mode, so each value then renders as None.
display(Markdown(f"### ε = {config.get('epsilon')}"))
display(Markdown(f"### {config.get('model')}"))
display(Markdown(f"Hyperparameters: {config.get('hparams')}"))

## Transformations

### Generation Configurations

In [None]:
# Show the transformation id and the transformation settings from the run
# configuration (both are None when `config` is empty, i.e. in offline mode).
display(Markdown(f"#### {config.get('trans_id')}"))
pprint(config.get("transformations"))

### Unique per Column

In [None]:
# Number of distinct values per column, one table column per dataset variant.
pd.DataFrame({name: dataset.nunique() for name, dataset in datasets.items()})

## Dataset Projection 

In [None]:
# Dataset projection setting from the run configuration (None when absent).
config.get("dataset_projection")

## Acceptance

In [None]:
# Show the `acceptance` artifact table (bound as `acceptance_df` by the loading cell).
acceptance_df

## DP Acceptance

In [None]:
# Show the `dp_acceptance` artifact table (bound as `dp_acceptance_df` by the loading cell).
dp_acceptance_df

## Utility - Marginals

In [None]:
# Keep the utility rows whose metric name mentions "frequencies", "way" and
# "max". `.copy()` makes the selection an independent frame, so the column
# assignment below does not write into a chained slice of `utility_df`.
n_way_df = utility_df[
    utility_df["name"].str.contains("frequencies")
    & utility_df["name"].str.contains("way")
    & utility_df["name"].str.contains("max")
][["val_pr_s", "name"]].copy()

# The metric name is a "/"-separated path with exactly five parts.
n_way_df[["marginals", "agg", "metric", "mode", "binning"]] = n_way_df[
    "name"
].str.split("/", expand=True)
n_way_df = n_way_df.drop(["name", "metric"], axis=1)

# One row per (marginals, binning), one column per mode. The arguments are
# passed by keyword because `DataFrame.pivot` no longer accepts them
# positionally in pandas >= 2.0.
n_way_df = n_way_df.pivot(
    index=["marginals", "binning"], columns="mode", values="val_pr_s"
)
# Rescale the "diff" mode column by 100.
# NOTE(review): presumably to show it as a percentage — confirm.
n_way_df["diff"] *= 100

n_way_df

### Utility 1-Way

In [None]:
# Median, mean and standard deviation of every numerified column, for the
# synth and processed datasets only. The result is transposed, so each row is
# a data column and each table column is a (statistic, dataset) pair.
(
    pd.concat(
        [
            dataset.agg(["median", "mean", "std"]).assign(name=name)
            for name, dataset in numerified_datasets.items()
            if name in ("synth", "processed")
        ]
    )
    .set_index("name", append=True)
    .T
)

In [None]:
# One sub-section per column: ECDFs, relative frequencies, box plot, summary
# statistics and binned counts.
for column in columns:
    display(Markdown(f"#### {column}"))

    # Empirical CDFs of all datasets: overlaid first, then one panel per dataset.
    sns.displot(data=unified_df, x=column, kind="ecdf", hue="name")
    plt.show()
    sns.displot(data=unified_df, x=column, kind="ecdf", col="name", hue="name")
    plt.show()

    # Relative frequency (in %) of every original value, synth vs. processed;
    # each dataset's column sums to 100.
    crosstab = (
        pd.crosstab(
            synth_processed_unifed_df[column],
            synth_processed_unifed_df["name"],
            normalize="columns",
        )
        * 100
    )
    crosstab.plot(kind="bar", colormap=SYNTH_PROCESSED_CMAP)
    plt.show()
    display(crosstab)

    sns.boxplot(x="name", y=column, data=unified_df, saturation=0.5)
    plt.show()

    # The numerified datasets are cast to int, so a plain `describe()` gives the
    # numeric summary. The former `datetime_is_numeric=True` argument had no
    # effect on integer columns and raises TypeError on pandas >= 2.0.
    display(
        pd.DataFrame(
            {
                name: dataset[column].describe()
                for name, dataset in numerified_datasets.items()
            }
        ).round(2)
    )

    # Row counts per bin, one bar per dataset.
    sns.catplot(x=column, hue="name", data=binned_unified_df, kind="count", aspect=1.5)
    plt.show()

## Utility 2 Way

### Frequencies

In [None]:
# For every pair of columns, draw side-by-side heatmaps of the joint relative
# frequencies (in %) in the synth and processed datasets.
for first, second in it.combinations(real_df.columns, r=2):
    # Joint distribution of the pair, normalised over the whole table.
    cmps2way = {
        name: (pd.crosstab(dataset[first], dataset[second], normalize=True))
        for name, dataset in numerified_datasets.items()
        if name in ("synth", "processed")
    }

    cmps2way = {name: 100 * cmps for name, cmps in cmps2way.items()}

    # Shared colour scale: twice the largest cell percentage, capped at 100.
    vmax = min(2 * max(cmps.max().max() for cmps in cmps2way.values()), 100)
    fig, axes = plt.subplots(1, 2, figsize=(15, 10), sharex=True, sharey=True)
    fig.suptitle(f"{first} $\\times$ {second}")

    for (name, cmps), ax in zip(cmps2way.items(), axes.flat):
        ax.set_title(name)
        g = sns.heatmap(
            cmps,
            vmin=0,
            vmax=vmax,
            annot=True,
            fmt=".2f",
            cmap="YlGnBu",
            # cmap='vlag',
            # square=True,
            ax=ax,
        )

### Correlations

In [None]:
# Correlation matrix of every numerified dataset.
corrs = {name: dataset.corr() for name, dataset in numerified_datasets.items()}
# Mask for the upper triangle (diagonal included), so only the lower triangle is
# drawn. Its shape comes from the "real" matrix and is reused for every dataset,
# so all datasets must have the same columns.
corr_mask = np.zeros_like(corrs["real"])
corr_mask[np.triu_indices_from(corr_mask)] = True

# 2 x 2 grid: one heatmap per dataset, on a fixed [-1, 1] colour scale.
_, axes = plt.subplots(2, 2, figsize=(15, 13))
for (name, corr), ax in zip(corrs.items(), axes.flat):
    ax.set_title(name)
    sns.heatmap(
        corr,
        mask=corr_mask,
        annot=True,
        fmt=".2f",
        vmin=-1,
        vmax=1,
        cmap="vlag",
        square=True,
        ax=ax,
    )

#### Centrals

In [None]:
# Distinct (target, by) pairs taken from the utility rows that have a non-null
# `by` value, as a dict of two parallel lists: {"target": [...], "by": [...]}.
column_pairs = (
    utility_df[~utility_df["by"].isna()][["target", "by"]]
    .drop_duplicates()
    .to_dict(orient="list")
)


# Group the pairs: `by` column -> list of target columns reported against it.
by_target = defaultdict(list)
for target, by in zip(column_pairs["target"], column_pairs["by"]):
    by_target[by].append(target)


def plot_centrals(datasets, agg):
    """Plot a central value of each target column against its binned ``by`` column.

    For every ``by`` column in ``by_target`` one figure row is drawn, with one
    panel per target. Each panel is a seaborn point plot with one line per
    dataset, whose points are the mean or median of the target inside each bin
    of ``by``.

    Args:
        datasets: labels (values of the ``name`` column) of the datasets to
            include, e.g. ``["synth", "processed"]``.
        agg: ``"mean"`` or ``"median"`` — the statistic to plot.

    Raises:
        AssertionError: if ``agg`` is neither ``"mean"`` nor ``"median"``.
    """
    assert agg in ("mean", "median")

    # Filter each frame with a mask built from its own `name` column.
    binned_choosen_df = binned_unified_df[binned_unified_df["name"].isin(datasets)]
    choosen_df = unified_df[unified_df["name"].isin(datasets)]

    for by, targets in by_target.items():
        # squeeze=False keeps `axes` two-dimensional even for a single target.
        _, axes = plt.subplots(
            1, len(targets), figsize=(7 * len(targets), 5), squeeze=False
        )
        for target, ax in zip(targets, axes[0]):
            g = sns.pointplot(
                x=binned_choosen_df[by],
                y=choosen_df[target],
                hue=choosen_df["name"],
                linestyles=["solid", "dashed", "dashdot", "dotted"],
                errwidth=1,
                estimator=np.median if agg == "median" else np.mean,
                ax=ax,
            )
            # Upper y-limit: 10% above the highest plotted point, i.e. the largest
            # per-(bin, dataset) value of the statistic actually shown. Using the
            # pooled median regardless of `agg` could clip the mean plots.
            max_ = (
                choosen_df[target]
                .groupby([binned_choosen_df[by], choosen_df["name"]])
                .agg(agg)
                .max()
            )
            # Lower y-limit: 10% below the overall minimum of the target column.
            ax.set_ylim(unified_df[target].min() * 0.9, max_ * 1.1)
            ax.tick_params(axis="x", rotation=45)
            # Make markers and lines translucent so overlapping datasets stay visible.
            plt.setp(g.collections, alpha=0.7)
            plt.setp(g.lines, alpha=0.7)

#### Medians

In [None]:
# Medians per bin for all four dataset variants.
plot_centrals(["real", "synth", "processed", "subsampled"], "median")

In [None]:
# Medians per bin, restricted to the synth vs. processed comparison.
plot_centrals(["synth", "processed"], "median")

#### Means

In [None]:
# Means per bin for all four dataset variants.
plot_centrals(["real", "synth", "processed", "subsampled"], "mean")

In [None]:
# Means per bin, restricted to the synth vs. processed comparison.
plot_centrals(["synth", "processed"], "mean")

## Utility - Distributions

In [None]:
# Row mask selecting the processed and synth rows of `unified_df`. It is also
# applied to `binned_unified_df`, which therefore must share the same row order.
mask_sp = unified_df["name"].isin(["processed", "synth"])


def plot_2way_dist(x, y, is_scaled, ax, is_binned):
    """Draw split violins of column ``y`` per value of column ``x``, synth vs. processed.

    Args:
        x: column placed on the x axis; read from the binned frame when
            ``is_binned`` is true, otherwise from the numerified frame.
        y: column whose distribution is drawn (always from the numerified frame).
        is_scaled: forwarded to seaborn's ``scale_hue``.
        ax: matplotlib axes to draw on.
        is_binned: whether ``x`` uses binned values.
    """
    x_source = binned_unified_df if is_binned else unified_df
    violin_options = dict(
        x=x_source.loc[mask_sp, x],
        y=unified_df.loc[mask_sp, y],
        hue=unified_df.loc[mask_sp, "name"],
        scale_hue=is_scaled,
        scale="count",
        inner="quartile",
        split=True,
        bw=0.2,
        palette="muted",
        ax=ax,
    )
    sns.violinplot(**violin_options)


def plot_double_2way_dist(x, y, is_binned):
    """Draw the ``x`` by ``y`` violin plot twice, side by side.

    The left panel is drawn with ``is_scaled=False`` and the right panel with
    ``is_scaled=True``; both share the figure title ``"<x> × <y>"``.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle(f"{x} $\\times$ {y}")
    for is_scaled, panel_ax in ((False, left_ax), (True, right_ax)):
        plot_2way_dist(x, y, is_scaled, panel_ax, is_binned)

In [None]:
# Violin plots for every (by, target) pair, with `by` taken from the binned frame.
for target, by in zip(column_pairs["target"], column_pairs["by"]):
    plot_double_2way_dist(by, target, True)

In [None]:
# The same plots with `by` taken from the numerified (unbinned) frame.
for target, by in zip(column_pairs["target"], column_pairs["by"]):
    plot_double_2way_dist(by, target, False)

## Utility - Linear Regression (y=`birth_weight`)
TODO: Move to `synthflow.evaluation.utility`

**Note: the regression does not use categorical variables**

In [None]:
# Fit one ordinary-least-squares model per numerified dataset and keep the
# fitted results keyed by dataset label.
linear_reg_results = {}

for name, dataset in numerified_datasets.items():
    # Features: every column except the last, z-scored column by column.
    X = dataset.iloc[:, :-1].apply(stats.zscore)
    # Response: the last column.
    # NOTE(review): the section title says the response is the birth weight, so
    # this relies on that column being last — confirm.
    y = dataset.iloc[:, -1]
    res = sm.OLS(y, sm.add_constant(X)).fit()

    linear_reg_results[name] = res

In [None]:
# Heatmap of the fitted coefficients: one row per dataset, one column per
# model term (the constant and each feature).
_, ax = plt.subplots(1, figsize=(10, 5))
sns.heatmap(
    pd.DataFrame(
        {name: results.params for name, results in linear_reg_results.items()}
    ).T,
    annot=True,
    fmt=".2f",
    square=True,
    ax=ax,
)

In [None]:
# Utility metrics whose name contains "lr".
utility_df[utility_df["name"].str.contains("lr")]

In [None]:
# Goodness-of-fit overview per dataset, followed by the full statsmodels
# summary of every fitted model.
try:
    display(
        pd.DataFrame(
            {
                "$R^2$": {
                    name: results.rsquared
                    for name, results in linear_reg_results.items()
                },
                "F p-value": {
                    name: results.f_pvalue
                    for name, results in linear_reg_results.items()
                },
                # Mean absolute residual on the data the model was fitted on.
                "MAE": {
                    name: np.mean(np.abs(results.resid))
                    for name, results in linear_reg_results.items()
                },
            }
        )
    )

    for name, results in linear_reg_results.items():
        display(Markdown(f"### {name}"))
        display(results.summary())

except AttributeError as error:
    # Still best-effort (the report carries on), but state that the summary was
    # skipped and why, instead of dropping the section without a trace.
    print(f"[SKIPPED] linear regression summary: {error!r}")

## Constraints

TODO: Move to synthflow.evaluation.utility

In [None]:
# Number of rows in the `constraints` artifact table.
print("# rows removed due to constraints:", len(constraints_df))

In [None]:
# For each column combination produced by `n_way_gen(columns, [2])`, compare how
# often every value combination occurs in the processed, synth and constraints
# tables, and display the comparison when it looks suspicious.
# NOTE(review): presumably `n_way_gen(columns, [2])` yields all 2-column
# combinations — confirm against synthflow.utils.
for columns_nway in n_way_gen(columns, [2]):
    columns_nway = list(columns_nway)
    processed_counts = processed_real_df[columns_nway].value_counts()
    synth_counts = synth_df[columns_nway].value_counts()
    constraints_counts = constraints_df[columns_nway].value_counts()

    # NOTE(review): presumably aligns both count series on a common index and
    # fills missing combinations with 0 — confirm against synthflow.utils.
    (processed_complete_counts, synth_complete_counts) = compute_complete_counts(
        processed_counts, synth_counts
    )

    # `&` binds tighter than `|`, so the condition reads:
    #   processed == 0  OR  (constraints > 0 AND synth == 0)
    if (
        (
            (processed_complete_counts == 0)
            | (constraints_counts > 0) & (synth_complete_counts == 0)
        )
        .any()
        .any()
    ):
        # Combinations missing from the constraints table become 0 via fillna.
        comparision_df = (
            pd.DataFrame(
                {
                    "processed": processed_complete_counts,
                    "synth": synth_complete_counts,
                    "constraints": constraints_counts,
                }
            )
            .fillna(0)
            .sort_index()
            .astype(int)
        )

        # Show every row, with zero counts in "processed" replaced by a marker string.
        with pd.option_context("display.max_rows", None):
            disp_comparision_df = comparision_df.copy()
            disp_comparision_df["processed"] = disp_comparision_df["processed"].replace(
                {0: "*** 0 ****"}
            )
            display(disp_comparision_df)

## Faithfulness

### Matching

In [None]:
# Show the `faithfulness` artifact table (bound as `faithfulness_df` by the loading cell).
faithfulness_df

In [None]:
# Show the `column_weights` object (unpickled as is by the loading cell).
column_weights

In [None]:
# Pick the faithfulness row for ɑ == 1 and the 'val_pr_s' comparison.
# NOTE(review): presumably `get_single_row_by_query` requires exactly one
# matching row — confirm against synthflow.utils.
faithfulness_alpha_1_pr_s = get_single_row_by_query(
    faithfulness_df, "ɑ == 1 & comparison == 'val_pr_s'"
)

In [None]:
# `matching` holds two parallel sequences of row positions: matching[0] indexes
# the processed dataset and matching[1] the synth dataset (both used with
# `.iloc`). After the index reset, row i of each frame is one matched pair.
matching = faithfulness_alpha_1_pr_s["matching"]
processed_matched_df = (
    numerified_datasets["processed"].iloc[matching[0]].reset_index(drop=True)
)
synth_matched_df = numerified_datasets["synth"].iloc[matching[1]].reset_index(drop=True)

# Weighted versions of the matched frames.
# NOTE(review): presumably each column is scaled by its entry in
# `column_weights` — confirm against synthflow.evaluation.faithfulness.
processed_sacled_matched_df = scale_columns_by_weights(
    processed_matched_df, column_weights
)
synth_sacled_matched_df = scale_columns_by_weights(synth_matched_df, column_weights)

# Element-wise difference (synth - processed) of the unscaled matched values.
matched_diff_df = synth_matched_df - processed_matched_df
# Boolean frame: True where the scaled values of a matched pair differ.
matched_cost_df = processed_sacled_matched_df != synth_sacled_matched_df
# Columns in which at least one matched pair differs.
columns_with_nonzero_match = list(matched_cost_df.columns[matched_cost_df.any(axis=0)])

### Matched Values & Costs

#### Overall

In [None]:
# Share of matched pairs by number of differing columns (0 = identical pair).
matched_cost_df.sum(axis=1).value_counts(normalize=True)

#### Per-column (for > 0 only)

In [None]:
# Among the matched pairs that differ in at least one column, the fraction of
# pairs that differ in each column. Despite the `_df` suffix this is a Series.
matched_cost_per_column_df = matched_cost_df[matched_cost_df.any(axis=1)].mean(axis=0)
matched_cost_per_column_df[columns_with_nonzero_match]

#### Values

In [None]:
# For every column in which some matched pair differs, draw a heatmap of the
# synth value (rows) against the processed value (columns) over all matched pairs.
for column in columns_with_nonzero_match:
    # `rename` returns a new Series, so the axis labels for the crosstab are set
    # without changing the name of the column objects held by the frames.
    sx = synth_matched_df[column].rename("synth")
    py = processed_matched_df[column].rename("processed")

    _, ax = plt.subplots(1, figsize=(10, 7))
    ax.set_title(column)
    sns.heatmap(pd.crosstab(sx, py), ax=ax)

#### Pseudo Noise Addition

In [None]:
# TODO: refactor
# Goal: one long table with a row per (matched pair, column) holding the
# processed value and the "noise" (synth - processed) that was added to it.

# Step 1: put processed values and differences side by side; the suffixes give
# columns named "<col>0" (processed) and "<col>1" (difference).
pseudo_noise_df = pd.merge(
    processed_matched_df,
    matched_diff_df,
    left_index=True,
    right_index=True,
    suffixes=("0", "1"),
)
pseudo_noise_df["id"] = pseudo_noise_df.index


# Step 2: fold the 0/1 suffix into an `is_noise` column, restoring the original
# column names as stubs.
pseudo_noise_df = pd.wide_to_long(
    pseudo_noise_df, processed_matched_df.columns, i="id", j="is_noise"
).reset_index()
# Step 3: label each row "processed" or "noise", then melt the data columns
# into (column, value) pairs.
pseudo_noise_df = (
    pseudo_noise_df.assign(
        ds=pseudo_noise_df["is_noise"].apply(lambda x: "noise" if x else "processed")
    )
    .drop("is_noise", axis=1)
    .melt(id_vars=["id", "ds"], var_name="column")
)

# Step 4: join the processed rows with their noise rows on (id, column) and keep
# only `column`, `processed` and `noise`.
pseudo_noise_df = (
    pd.merge(
        pseudo_noise_df[pseudo_noise_df["ds"] == "processed"],
        pseudo_noise_df[pseudo_noise_df["ds"] == "noise"],
        on=("id", "column"),
    )
    .drop(["id", "ds_x", "ds_y"], axis=1)
    .rename(columns={"value_x": "processed", "value_y": "noise"})
)

In [None]:
# Largest absolute synth-processed difference per column over all matched pairs.
matched_max_diff = matched_diff_df.abs().max()
# Columns in which at least one matched pair differs.
variable_columns = matched_max_diff[matched_max_diff > 0].index
matched_max_diff

## Require Inspection
TODO: Move to `synthflow.evaluation.faithfulness`

In [None]:
# See section 6 in the paper
# https://www.dbs.ifi.lmu.de/Publikationen/Papers/LOF.pdf

# Bounds handed to `trange`, i.e. n_neighbors runs over 20, 21, ..., 50.
LOF_N_NEIGHBORS_RANGE = (20, 51)

# Running minimum of the LOF decision score of every synth row. Starting at 1
# means scores are capped at 1, and stay 1 when the loop is skipped.
synth_lof_score = np.ones(len(synth_df))

if not is_minimal:
    for n_neighbors in trange(*LOF_N_NEIGHBORS_RANGE):
        # novelty=True: fit on the processed data, then score unseen (synth) rows.
        # p=1 selects the Manhattan distance.
        lof = LocalOutlierFactor(n_neighbors=n_neighbors, p=1, novelty=True)
        lof.fit(numerified_datasets["processed"])

        # It is the opposite as bigger is better,
        # i.e. large values correspond to inliers.
        # 0 is the threshold; i.e. < 0 is an outlier
        # Keeping the minimum over all n_neighbors flags a row as an outlier
        # if it is one for any neighbourhood size.
        synth_lof_score = np.minimum(
            synth_lof_score, lof.decision_function(numerified_datasets["synth"])
        )

In [None]:
# A synth row is "plausible" when the exact same combination of numerified
# values occurs as a row of the processed dataset.
existing_processed_rows = set(map(tuple, numerified_datasets["processed"].values))
plausible_mask = np.array(
    [
        row in existing_processed_rows
        for row in map(tuple, numerified_datasets["synth"].values)
    ]
)

In [None]:
# Annotate a copy of the synth dataset with three per-row flags:
# `lof` (outlier score), `unmatched` and `plausible`.
synth_outliers = synth_df.copy()
synth_outliers["lof"] = synth_lof_score.round(2)

unmatched_indices = faithfulness_alpha_1_pr_s["unmatched_indices"]
# Element 1 is used for the synth side, mirroring `matching[1]`.
unmatched_synth_indices = list(unmatched_indices[1])
synth_outliers = synth_outliers.assign(unmatched=False)
# NOTE(review): `.loc` looks rows up by index label, whereas the matched rows
# are selected by position (`.iloc`). The two agree only if `synth_df` has the
# default 0..n-1 index — confirm.
synth_outliers.loc[unmatched_synth_indices, "unmatched"] = True
synth_outliers["plausible"] = plausible_mask

# Most outlying rows (lowest score) first.
synth_outliers = synth_outliers.sort_values("lof")

In [None]:
# LOF score by matched/unmatched status, one panel per plausibility flag.
g = sns.catplot(x="unmatched", y="lof", col="plausible", data=synth_outliers, alpha=0.1)
# Dashed line at 0: rows below it are treated as outliers.
for ax in g.axes[0]:
    ax.axhline(0, color="black", linestyle="--");

In [None]:
# The three per-row warning flags as boolean Series.
A = synth_outliers["unmatched"]
B = synth_outliers["lof"] < 0
C = ~synth_outliers["plausible"]


# A 7-element `subsets` sequence is read by `venn3` as the sizes of the seven
# *disjoint* regions, in the order (Abc, aBc, ABc, abC, AbC, aBC, ABC), where a
# lowercase letter means "not in that set". Each entry therefore counts the
# rows in exactly that region, not the whole set or a whole intersection.
venn3(
    subsets=[
        sum(A & ~B & ~C),
        sum(~A & B & ~C),
        sum(A & B & ~C),
        sum(~A & ~B & C),
        sum(A & ~B & C),
        sum(~A & B & C),
        sum(A & B & C),
    ],
    # The labels report each whole set as a share of all synth rows.
    set_labels=[
        f"Unmatched ({100 * sum(A) / len(A):.1f}%)",
        f"Outlier ({100 * sum(B) / len(B):.1f}%)",
        f"Implausible ({100 * sum(C) / len(C):.1f}%)",
    ],
);

In [None]:
# Rows that raise all three flags at once: unmatched, LOF outlier (score < 0)
# and not present in the processed data.
synth_requires_inspections = synth_outliers[
    synth_outliers["unmatched"]
    & (synth_outliers["lof"] < 0)
    & (~synth_outliers["plausible"])
]
num_inspections = len(synth_requires_inspections)
# Share of all synth rows, in percent.
pct_inspections = 100 * num_inspections / len(synth_df)

print(
    f"Synth records requiring inspections: {pct_inspections:.2f}% ({num_inspections})"
)

In [None]:
# Maximum number of rows shown in each table below.
MAX_ROWS = 100

# For every column, list the flagged rows with the smallest ("bottom") and the
# largest ("top") values of that column; ties are ordered by LOF score, most
# outlying first.
with pd.option_context("display.max_rows", None):
    # NOTE: iterates over the real dataset's columns, so each of them must also
    # exist in the synth dataset for `set_index` to succeed.
    for column in real_df.columns:
        # The current column becomes the index; the previous index is kept as a column.
        by_column_synth_requires_inspections = (
            synth_requires_inspections.reset_index().set_index([column])
        )

        display(Markdown(f"#### `{column}` - bottom"))
        display(
            by_column_synth_requires_inspections.sort_values(
                [column, "lof"], ascending=[True, True]
            ).head(MAX_ROWS)
        )

        display(Markdown(f"#### `{column}` - top"))
        display(
            by_column_synth_requires_inspections.sort_values(
                [column, "lof"], ascending=[False, True]
            ).head(MAX_ROWS)
        )

## Additional Analysis