In [None]:
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
import pandas as pd

import yaml
from PIL import Image

import iterative_coupling
import density_push as dp
from tqdm.notebook import tqdm, trange

In [None]:
results_folder = Path("results")

# One record per finished run, i.e. per sub-directory that contains both the
# run's hyper-parameters (config.yaml) and its final KL score (result.txt).
data = []
for run_dir in sorted(results_folder.iterdir()):  # sorted: same row order on every machine
    hparams_file = run_dir / "config.yaml"
    results_file = run_dir / "result.txt"
    if not (run_dir.is_dir() and hparams_file.exists() and results_file.exists()):
        continue
    with hparams_file.open("r", encoding="utf-8") as file:
        # safe_load returns None for an empty file; treat that as "no hyper-parameters"
        # instead of crashing on `**None` / `.items()` of None.
        hparams = yaml.safe_load(file) or {}
    with results_file.open("r", encoding="utf-8") as file:
        kl = float(file.read())
    record = {"run_name": run_dir.name, "kl": kl}
    # A hyper-parameter that happens to be called "run_name" or "kl" must not overwrite
    # the measured values (spreading `**hparams` last would); keep it under a "config_" prefix.
    record.update({
        (f"config_{key}" if key in record else key): value
        for key, value in hparams.items()
    })
    data.append(record)

# With no finished runs, DataFrame([]) has no "kl" column and sort_values would raise KeyError,
# so fall back to an empty frame with the expected columns.
# kind="stable" keeps runs with identical KL in directory-name order.
results = (
    pd.DataFrame(data) if data else pd.DataFrame(columns=["run_name", "kl"])
).sort_values("kl", ascending=True, kind="stable")
results.head(10)

In [None]:
# The best run is the one with the lowest KL, i.e. the first row of the ascending sort.
# Fail with a readable message instead of an opaque IndexError when nothing was found.
if results.empty:
    raise RuntimeError(f"No finished runs (config.yaml + result.txt) found under {results_folder}")
ex = results.iloc[0]
ex

In [None]:
# Output directory of the selected run (the run directory is named after the run).
exp_dir = results_folder.joinpath(ex["run_name"])
exp_dir

In [None]:
class RenamingUnpickler(pickle.Unpickler):
    """
    Unpickler for objects whose classes were recorded as living in ``__main__``.

    Every global looked up in ``__main__`` is resolved in ``iterative_coupling``
    instead; lookups in any other module are passed through unchanged.
    """

    # Source module -> module that should provide its globals when loading.
    _module_aliases = {"__main__": "iterative_coupling"}

    def find_class(self, module, name):
        resolved_module = self._module_aliases.get(module, module)
        return super().find_class(resolved_module, name)

In [None]:
# NOTE: unpickling can execute arbitrary code; only load model files you produced yourself.
model_path = exp_dir / "model.p"
with model_path.open("rb") as model_file:
    model = RenamingUnpickler(model_file).load()

In [None]:
# The pickled model keeps the per-step rotations and layers under separate keys.
rots = model["rots"]
layers = model["layers"]

In [None]:
# Assemble the learned rotations and layers into one chain of transport steps.
transport = iterative_coupling.centered_chain(rots, layers)

# Density named in the run's config (pushed through `transport` below) ...
density_name = ex["density"]
dens0 = dp.get_density_by_name(density_name)
# ... and the Gaussian reference density in latent space.
latent = dp.GaussianDensity()

In [None]:
# Evaluate the first step's spline. `p` (presumably the spline's knot positions)
# also fixes the extent of the grids drawn in the "Transform" column below.
spline = transport[0][1]._s_t_fun
p = spline.s_spline.x
s, t = spline(torch.from_numpy(p))
mean = -t * s
std = 1 / s

# Chain depths (indices into `transport`) that are shown as figure rows.
selection = [0, 1, 2, 19, 100]

# Layout: one row per selected depth, except that each of the last two depths is
# preceded by a thin spacer row (height `gap`) which is labelled "…" at the end.
width_ratios = [1, 1.5, 1, 1]
gap = 0.2
height_ratios = [1] * (len(selection) - 2) + [gap, 1, gap, 1]
size_factor = 1
fig, axes = plt.subplots(len(selection) + 2, 4, figsize=(sum(width_ratios) * size_factor, sum(height_ratios) * size_factor),
                         constrained_layout=True, width_ratios=width_ratios,
                         height_ratios=height_ratios)

mesh_kwargs = dict(
    pos_min=-3,
    pos_max=3,
    resolution=150,
    fallback=False,
    mesh_mode=dp.vis.MESH_MODE_IMAGE,
    levels=9
)

# Geometry of the "Transform" panel. It depends only on `p`, so compute it once
# instead of in every row.
offset = (p.max() - p.min()) * 1.5         # horizontal shift of the left (input) grid
center = (p.min() + p.max() - offset) / 2  # midpoint between the left and right grid
width = (p.min() - p.max() + offset) / 2   # size of the arrow head
border_factor = 1.1
inner_factor = 0.9
height_factor = 0.7

# Rows that hold plots: everything except the two spacer rows (axes[-4] and axes[-2]).
plot_rows = list(axes[:-4]) + [axes[-3], axes[-1]]
assert len(plot_rows) == len(selection), "figure rows and `selection` are out of sync"

for idx, axs in enumerate(tqdm(plot_rows)):
    stop_idx = selection[idx]
    part_transport = transport[:stop_idx + 1]

    if idx == 0:
        axs[0].set_title("Rotated $p_{n-1}(z)$", fontsize=10)
        axs[1].set_title("Transform", fontsize=10)
        axs[2].set_title("Latent $p_{n}(z)$", fontsize=10)
        axs[3].set_title("Data $p_{n}(x)$", fontsize=10)

    depth = len(part_transport)
    axs[0].set_ylabel(f"$n={depth}$")

    # Densities for this depth: the model's data estimate (latent pulled back through
    # the chain), and dens0 pushed through all earlier steps plus either only the first
    # element of the last step (incoming) or all but its last element (outgoing).
    data_estimate = dp.PullBackwardDensity(latent, part_transport)
    in_latent_estimate = dp.PushForwardDensity(dens0, part_transport[:-1] + part_transport[-1][:1])
    new_latent_estimate = dp.PushForwardDensity(dens0, part_transport[:-1] + part_transport[-1][:-1])

    # Incoming latent
    plt.sca(axs[0])
    dp.vis.density_mesh(in_latent_estimate, **mesh_kwargs)

    # Transform: grid under an empty slice of the last step (left, shifted by `offset`)
    # next to the grid deformed by element [1] of the last step (right).
    plt.sca(axs[1])
    dp.vis.deformed_grid(part_transport[-1][:0],
                         [p.min() - offset, p.min()], [p.max() - offset, p.max()],
                         resolution=5, sub_resolution=100)
    dp.vis.deformed_grid(part_transport[-1][1], [p.min(), p.min()], [p.max(), p.max()],
                         resolution=5, sub_resolution=100)

    # White hexagonal backdrop, then the C0 triangle arrow head pointing left -> right.
    plt.fill(
        [
            center - offset * border_factor, center - offset * inner_factor, center + offset * inner_factor,
            center + offset * border_factor, center + offset * inner_factor, center - offset * inner_factor,
            center - offset * border_factor
        ],
        [
            0, height_factor * offset, height_factor * offset,
            0, -height_factor * offset, -height_factor * offset,
            0
        ], fc="white", ec="w", lw=1.5
    )
    plt.fill(
        [center - width / 2, center + width / 2, center - width / 2],
        [0 - width, 0, 0 + width],
        fc="C0"
    )
    # Order matters: plt.axis("equal") re-enables autoscaling and recomputes the view,
    # which discarded any xlim set before it. Fix the aspect first, then the x-range.
    plt.axis("equal")
    plt.xlim(center - offset * border_factor * 1.5, center + offset * border_factor * 1.5)

    # Outgoing latent
    plt.sca(axs[2])
    dp.vis.density_mesh(new_latent_estimate, **mesh_kwargs)

    # Resulting data estimate
    # NOTE(review): unlike the latent panels this call passes no `levels`; confirm intended.
    plt.sca(axs[3])
    dp.vis.density_mesh(data_estimate, -3, 3, 150, fallback=False, mesh_mode=dp.vis.MESH_MODE_IMAGE)


# The spacer rows stand for the depths that are skipped between selected rows.
axes[-2, 0].set_ylabel("…")
axes[-4, 0].set_ylabel("…")

# Strip ticks and frames from every panel.
for ax in axes.flat:
    plt.sca(ax)
    plt.xticks([])
    plt.yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
# plt.savefig("iterative_spline.pdf", bbox_inches="tight")