In [None]:
import matplotlib.pyplot as plt
import torch

from coolchic.enc.io.io import load_frame_data_from_tensor
from coolchic.hypernet.hypernet import DeltaWholeNet
from coolchic.utils.paths import CONFIG_DIR, COOLCHIC_REPO_ROOT, DATA_DIR
from coolchic.utils.tensors import load_img_from_path
from coolchic.utils.types import HypernetRunConfig, load_config

# High-resolution inline figures so image crops stay sharp.
plt.rcParams.update({"figure.dpi": 300})

In [None]:
# Load one CLIC20 professional-validation image and keep only its top-left
# square crop so the network runs on a small, fixed-size input.
CROP_SIZE = 512  # side length (pixels) of the square crop
img_path = DATA_DIR / "clic20-pro-valid" / "gian-reto-tarnutzer-45212.png"
img = load_frame_data_from_tensor(load_img_from_path(img_path)).data
img = img[..., :CROP_SIZE, :CROP_SIZE]  # top-left CROP_SIZE x CROP_SIZE crop
# Slicing silently returns a smaller crop if the image is too small; fail loudly instead.
assert tuple(img.shape[-2:]) == (CROP_SIZE, CROP_SIZE), (
    f"image too small for a {CROP_SIZE}px crop: got shape {tuple(img.shape)}"
)

In [None]:
# Plot image in tensor
import matplotlib.pyplot as plt


def plot_img(img):
    """Display an image tensor with matplotlib and no axes.

    Args:
        img: Channel-first image tensor. Singleton dimensions (e.g. a batch
            dimension of size 1) are squeezed away; the result must be
            ``(C, H, W)`` or a 2-D single-channel ``(H, W)`` map. Floating-point
            values are assumed to be in ``[0, 1]`` (TODO confirm for decoder
            outputs); they are clamped to that range before display.
    """
    # detach() so tensors that still track gradients can be converted to numpy.
    img = img.detach().squeeze().cpu()
    is_float = img.is_floating_point()
    if is_float:
        # imshow clips float RGB data outside [0, 1] anyway (with a warning);
        # clamping explicitly keeps slight overshoots from reconstructions quiet.
        img = img.clamp(0.0, 1.0)
    if img.dim() == 2:
        # Single-channel map: permute(1, 2, 0) would fail on 2-D input, so show it
        # as grayscale with a fixed range instead of matplotlib's auto-stretch.
        plt.imshow(img.numpy(), cmap="gray", vmin=0, vmax=1.0 if is_float else 255)
    else:
        plt.imshow(img.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.show()

In [None]:
plot_img(img)  # plot the original (uncompressed) crop that is fed to the network

In [None]:
# Instantiate a DeltaWholeNet from a saved experiment config and restore its weights.
cfg_path = CONFIG_DIR / "exps" / "delta-hn" / "ups-best-orange" / "config_04.yaml"
checkpoint_path = COOLCHIC_REPO_ROOT / "best_delta_config04.pt"
cfg = load_config(cfg_path, HypernetRunConfig)

net = DeltaWholeNet(cfg.hypernet_cfg)
# map_location="cpu" loads the checkpoint without requiring a GPU;
# weights_only=True restricts unpickling to tensors and plain containers.
net.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu", weights_only=True)
)

In [None]:
# Run the whole model once on the crop, without building an autograd graph.
net.eval()  # switch layers with train/eval behaviour to eval mode
with torch.no_grad():
    # Call the module itself (not .forward) so any registered hooks still run.
    compre_img = net(
        img,  # NOTE(review): img appears to already carry a batch dim (see img[0] below) — confirm
        quantizer_noise_type="none",
        quantizer_type="true_ste",
    )

In [None]:
# Model output first, then the original crop, for side-by-side visual comparison.
# NOTE(review): compre_img[0] is assumed to be the decoded image (or its first batch item) — confirm.
for shown in (compre_img[0].squeeze(), img[0]):
    plot_img(shown)

In [None]:
from coolchic.hypernet.common import add_deltas

net.eval()
with torch.no_grad():
    # Hypernetwork pass: latents plus per-module weight deltas for this image.
    latents, synth_deltas, arm_deltas, ups_deltas = net.hypernet.forward(img)
    deltas = {"synthesis": synth_deltas, "arm": arm_deltas}
    # References (not copies) to the decoder's own parameters, grouped per sub-module.
    model_params = {
        part: dict(getattr(net.mean_decoder, part).named_parameters())
        for part in ("synthesis", "arm")
    }
    # Merge the deltas into the decoder parameters for a single image.
    # NOTE(review): the empty dict is presumably the upsampling deltas, left unapplied — confirm.
    added_params = add_deltas(
        net.mean_decoder.named_parameters(),
        synth_deltas,
        arm_deltas,
        {},
        batch_size=1,
    )

# Bucket each merged parameter by the sub-module named in its key.
# "synthesis" is tested first, so a key containing both substrings goes there;
# keys matching neither (e.g. other sub-modules) are dropped.
effective_params = {"synthesis": {}, "arm": {}}
for param_name in added_params:
    if "synthesis" in param_name:
        bucket = "synthesis"
    elif "arm" in param_name:
        bucket = "arm"
    else:
        continue
    effective_params[bucket][param_name] = added_params[param_name]

In [None]:
def hist_weights(module, deltas, effective):
    """Plot overlaid histograms of predicted deltas vs. effective weights.

    Args:
        module: Sub-module key (e.g. ``"synthesis"`` or ``"arm"``); must be a key
            of both ``deltas`` and ``effective``.
        deltas: Mapping ``module -> {param_name: tensor}`` of weight deltas.
        effective: Mapping ``module -> {param_name: tensor}`` of the weights
            after the deltas were added.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns

    def _flatten(params):
        # Concatenate every tensor of one module into a single 1-D numpy array.
        # reshape (unlike view) also accepts non-contiguous tensors; detach before
        # the device copy so no autograd history is carried along.
        return torch.cat([t.detach().reshape(-1) for t in params.values()]).cpu().numpy()

    flat_deltas = _flatten(deltas[module])
    flat_effective = _flatten(effective[module])
    # Build the long-format frame from arrays directly; converting millions of
    # numpy scalars into Python lists is needlessly slow and memory hungry.
    data = pd.concat(
        [
            pd.DataFrame({"type": "delta", "value": flat_deltas}),
            pd.DataFrame({"type": "effective", "value": flat_effective}),
        ],
        ignore_index=True,
    )

    fig, ax = plt.subplots(figsize=(4, 3))
    sns.histplot(data=data, x="value", hue="type", bins=25, ax=ax)
    ax.set(
        title=f"Weight distribution for {module}",
        xlabel="Weight Value",
        ylabel="Frequency",
    )


# One histogram figure per decoder sub-module.
for module_key in ("synthesis", "arm"):
    hist_weights(module_key, deltas, effective_params)

In [None]:
def print_weight_stats(
    module_name: str,
    deltas_by_module: "dict[str, dict[str, torch.Tensor]] | None" = None,
    params_by_module: "dict[str, dict[str, torch.Tensor]] | None" = None,
) -> None:
    """Print summary statistics of the weight deltas of one sub-module.

    For every delta tensor, prints its mean / std / min / max, then the size of
    its standard deviation relative to that of the same-named decoder weight.

    Args:
        module_name: Sub-module key, e.g. ``"synthesis"`` or ``"arm"``.
        deltas_by_module: ``module -> {param_name: delta}``. Defaults to the
            notebook-level ``deltas``.
        params_by_module: ``module -> {param_name: weight}`` keyed by the same
            parameter names as the deltas. Defaults to the notebook-level
            ``model_params``.

    Raises:
        KeyError: if a delta has no weight with the same name.
    """
    # Explicit arguments remove the hidden dependency on notebook globals; the
    # defaults keep existing ``print_weight_stats("arm")`` calls working.
    if deltas_by_module is None:
        deltas_by_module = deltas
    if params_by_module is None:
        params_by_module = model_params

    weights = params_by_module[module_name]
    for name, delta in deltas_by_module[module_name].items():
        delta = delta.detach()
        delta_std = delta.std().item()
        eq_std = weights[name].detach().std().item()
        print(
            f"{name}, delta.mean()={delta.mean().item():.3f}, "
            f"delta.std()={delta_std:.3f}, delta.min()={delta.min().item():.3f}, "
            f"delta.max()={delta.max().item():.3f}"
        )
        # `eq_std > 0` is False for both 0 and NaN (e.g. constant or single-element
        # weights), where the ratio would otherwise print as inf% / nan%.
        if eq_std > 0:
            print(f"    {100 * delta_std / eq_std:.1f}% of the original weight std")
        else:
            print("    original weight std is 0 or undefined; relative size not reported")

In [None]:
print_weight_stats(module_name="synthesis")  # per-tensor delta statistics for the synthesis network

In [None]:
print_weight_stats(module_name="arm")  # per-tensor delta statistics for the ARM