## Setup

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

In [None]:
# name of the Lightning log directory to analyse; also used for output file names
version = "version_X"

# load metrics and merge train and val rows
# (groupby(...).first() keeps the first non-null value of every column per epoch)
metrics = pd.read_csv(f"lightning_logs/{version}/metrics.csv").groupby("epoch").first()

# fill in missing values for DIB when l=0
for v_info in ("vsuff", "vmin"):
    for dataset in ("train", "val"):
        column = f"{v_info}_0_{dataset}"
        metrics[column] = metrics[column].ffill()

# add generalisation error
metrics["gen_error"] = metrics["val_loss"] - metrics["train_loss"]

# make sure the per-version output directory exists (target of the optional
# plt.savefig calls); exist_ok avoids the separate check-then-create step
os.makedirs(version, exist_ok=True)

# save cleaned metrics
metrics.to_csv(f"../Overleaf/data/{version}.csv")
display(metrics)

# initial legend elements: gray proxy lines for the solid and dashed line styles
base_handles = [
    Line2D([], [], color="gray", linestyle=linestyle) for linestyle in ("-", "--")
]

In [None]:
def get_cols(metric):
    """Return the column labels of ``metrics`` that belong to one metric family.

    Args:
        metric: prefix of the metric columns (matched as
            ``{metric}_[layer_]<digit>_{dataset}``), or the special value
            ``"acc-loss"``, which selects every column containing
            ``"{dataset}_"``.

    Returns:
        A two-element list ``[train_columns, val_columns]`` of pandas Index
        objects, read from the module-level ``metrics`` frame.
    """

    def _columns_for(dataset):
        # "acc-loss" is not a column prefix, so it is matched by substring
        if metric == "acc-loss":
            return metrics.filter(like=f"{dataset}_").columns
        return metrics.filter(regex=f"{metric}_(layer_)?[0-9]_{dataset}").columns

    return [_columns_for(dataset) for dataset in ("train", "val")]


def plot_metrics(name, labels, ylabel, xlabel="epoch of main network training", dummy=False):
    """Plot the train and validation curves of one metric family over the epochs.

    The train columns are drawn with solid lines and the validation columns
    with dashed lines. The property cycle is reset after each group, so both
    groups start from the same colour.

    Args:
        name: metric family handed to ``get_cols``.
        labels: legend entries for the individual curves; they follow the
            fixed "Train" / "Test" entries.
        ylabel: label of the y-axis.
        xlabel: label of the x-axis.
        dummy: if true, an empty line is drawn before each group so that the
            first colour of the cycle is skipped.
    """
    ax = plt.gca()
    for linestyle, dataset_cols in zip(("-", "--"), get_cols(name)):
        if dummy:  # empty line that only advances the colour cycle
            ax.plot([])
        ax.plot(metrics.index, metrics[dataset_cols], linestyle=linestyle)
        ax.set_prop_cycle(None)

    ax.set_xlabel(xlabel)
    ax.set_xticks(range(0, metrics.shape[0], 2))
    ax.set_ylabel(ylabel)
    ax.grid(True, "major", linestyle="--", linewidth=0.5)

    # skip the leading dummy line (if any) when collecting the curve handles
    curve_handles = ax.get_lines()[int(dummy) :]
    ax.legend(base_handles + curve_handles, ["Train", "Test"] + labels)
    # plt.savefig(f"{version}/{name}.pdf", bbox_inches="tight")
    plt.show()

## Loss and Accuracy Curves

In [None]:
# accuracy and loss curves for both data splits
plot_metrics(
    name="acc-loss",
    labels=["Accuracy", "Loss"],
    ylabel="cross-entropy loss & accuracy",
    xlabel="epoch",
)

## Neural Collapse

In [None]:
# neural-collapse metric per layer; dummy=True skips the first colour of the cycle
plot_metrics(
    name="nc",
    labels=[rf"$l = {layer}$" for layer in range(1, 4)],
    ylabel=r"$\operatorname{tr}(Σ_W^l(Σ_B^l)⁺)$",
    xlabel="epoch",
    dummy=True,
)

### Generalisation Gap vs. Compression Gap

In [None]:
# sort metrics by generalisation error for cleaner plots
sorted_metrics = metrics.sort_values(by="gen_error")

# only plot positive gen. error
gen_error = sorted_metrics["gen_error"]
pos_mask = gen_error > 0

plt.plot([])  # dummy plot to iterate colour cycle
for train_col, test_col in zip(*get_cols("nc")):
    # compression gap: NC metric on the validation split minus the train split
    comp_diff = sorted_metrics[test_col] - sorted_metrics[train_col]
    plt.plot(gen_error[pos_mask], comp_diff[pos_mask])

# label the axes so the figure can be read on its own
plt.xlabel("generalisation gap (val. loss - train loss)")
plt.ylabel("compression gap (val. - train NC metric)")
plt.legend(plt.gca().get_lines()[1:], [rf"$l = {layer}$" for layer in (1, 2, 3)])
plt.show()

### Pearson Correlation Between $\{\log(\operatorname{tr}(Σ_W^l(Σ_B^l)⁺))\}_{l=1}^L$ and $\{l\}_{l=1}^L$

In [None]:
for i, dataset_cols in enumerate(get_cols("nc")):
    col_name = f"{version}_{['train', 'test'][i]}"

    # data frame where every row is (1,2,…,L); it carries the same epoch index
    # and the same column labels as the NC columns so that corrwith (which
    # aligns on both axes) pairs each layer's value with its own layer index
    # NOTE(review): assumes get_cols returns the NC columns in layer order
    layer_idcs = pd.DataFrame(
        np.tile(np.arange(1, len(dataset_cols) + 1), (metrics.shape[0], 1)),
        index=metrics.index,
        columns=dataset_cols,
    )

    # per-epoch Pearson correlation between log NC metric and layer index
    log_metrics = np.log(metrics[dataset_cols])
    corrs = log_metrics.corrwith(layer_idcs, axis=1).rename(col_name)

    # save correlations to csv; the first csv column is always the epoch index,
    # so new columns are aligned by epoch rather than by row position
    if os.path.exists("corrs.csv"):
        saved_corrs = pd.read_csv("corrs.csv", index_col=0)
        saved_corrs[col_name] = corrs
    else:
        saved_corrs = corrs.to_frame()
    saved_corrs.to_csv("corrs.csv")

# plot every correlation series stored so far (all versions and splits)
corrs = pd.read_csv("corrs.csv", index_col=0)
corr_lines = plt.plot(corrs.index, corrs)
plt.xticks(range(0, metrics.shape[0], 2))
plt.xlabel("epoch")
plt.ylabel("Pearson correlation")
plt.legend(corr_lines, list(corrs.columns))
plt.show()

## 𝒱-Information

### Sufficiency

In [None]:
# sufficiency curves for layers 0 to 3 (default x-axis label)
plot_metrics(
    name="vsuff",
    labels=[rf"$l = {layer}$" for layer in range(4)],
    ylabel=r"$\operatorname{I}_{\mathcal{V}ˡ}(𝗵ˡ → y)$ in bits",
)

### Minimality

In [None]:
# minimality curves for layers 0 to 3 (default x-axis label)
plot_metrics(
    "vmin",
    [rf"$l = {layer}$" for layer in (0, 1, 2, 3)],
    # closing parenthesis of I(… → Dec(…)) added; the label was unbalanced
    r"$\operatorname{I}_{\mathcal{V}ˡ}(𝗵ˡ → \operatorname{Dec}(𝘅, \mathcal{Y}))$ in bits",
)

### Information Plane

In [None]:
# one information-plane figure per data split: minimality on the x-axis,
# sufficiency on the y-axis, points coloured by epoch
for dataset, vmin_cols, vsuff_cols in zip(
    ("Train", "Test"), get_cols("vmin"), get_cols("vsuff")
):
    for vmin, vsuff in zip(vmin_cols, vsuff_cols):
        x = metrics[vmin]
        y = metrics[vsuff]
        plt.scatter(x, y, cmap="viridis", c=metrics.index)
        # thin gray line connecting consecutive epochs of the same layer
        plt.plot(x, y, color="gray", linewidth=0.5)
    # closing parenthesis of I(… → Dec(…)) added; the label was unbalanced
    plt.xlabel(r"$\operatorname{I}_{\mathcal{V}ˡ}(𝗵ˡ → \operatorname{Dec}(𝘅, \mathcal{Y}))$ in bits")
    plt.ylabel(r"$\operatorname{I}_{\mathcal{V}ˡ}(𝗵ˡ → y)$ in bits")
    plt.colorbar(label="epoch of main network training", ticks=range(0, metrics.shape[0], 5))
    plt.title(f"{dataset} Set")
    # plt.savefig(f"{version}/ip-{dataset.lower()}.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Disabled snippet, kept as a bare string literal so that none of it is executed.
# The code inside the string draws three groups of 3-D points around randomly
# placed means, redraws until every point has norm <= 1, writes points_{i}.csv,
# mu.csv and simplex.csv, and shows the points, their means and the triangle
# vertices in a 3-D scatter plot.
# NOTE(review): the snippet starts with the IPython magic `%matplotlib ipympl`, so
# it presumably has to be pasted back into a notebook cell to run — confirm.
# NOTE(review): as the last expression of the cell, the string itself is likely
# echoed as the cell output.
""" # Code for generating NC visualisation data
%matplotlib ipympl
import numpy as np

rng = np.random.default_rng(999)


def gen_mu(rng=rng, r=0.8):
    theta = rng.uniform(0, 2 * np.pi, 3)
    phi = rng.uniform(0, np.pi, 3)
    return (
        r * np.c_[np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)]
    )


def gen_points(mu, n=5):
    return rng.multivariate_normal(mu.flatten(), 0.1 * np.eye(mu.size), n).reshape(
        (n, *mu.shape)
    )
    # return rng.uniform(-1, 1, (n, *mu.shape))


mu = gen_mu()
points = gen_points(mu)
while np.linalg.norm(points, axis=1).max() > 1:
    mu = gen_mu()
    points = gen_points(mu)
mu = points.mean(axis=0)

triangle = [[0, 0, 1], [-np.sqrt(3) / 2, 0, -0.5], [np.sqrt(3) / 2, 0, -0.5]]

for i in range(3):
    np.savetxt(f"points_{i}.csv", points[:, i])
np.savetxt("mu.csv", mu)
np.savetxt("simplex.csv", np.array(triangle))

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
for i in range(3):
    ax.scatter(points[:, i, 0], points[:, i, 1], points[:, i, 2])
    mu = points[:, i].mean(axis=0)
    ax.scatter(*mu, marker="x", color="k")
    ax.scatter(*triangle[i], marker="+", color="k") """