In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [None]:
# name of the run directory under lightning_logs/
version = "test"

# read the metrics logged for that run
metrics = pd.read_csv(f"lightning_logs/{version}/metrics.csv")

# collapse all rows that share an epoch into a single row, indexed by epoch
# (groupby(...).first() keeps the first non-null entry of every column)
metrics = metrics.groupby("epoch").first()
# carry the last seen DIB-0 value forward over epochs that have none
for column in ("dib_0_train", "dib_0_val"):
    metrics[column] = metrics[column].ffill()
display(metrics)

# data-less proxy lines used as the first two legend entries of every plot:
# a solid gray line and a dashed gray line
solid_proxy = Line2D([], [], color="gray")
dashed_proxy = Line2D([], [], color="gray", linestyle="--")
base_handles = [solid_proxy, dashed_proxy]

In [None]:
# Loss and accuracy curves: train columns solid, validation columns dashed
train_cols = metrics.filter(like="train_").columns
val_cols = metrics.filter(like="val_").columns

ax = plt.gca()
ax.plot(metrics.index, metrics[train_cols])
# reset the property cycle so the next plot call starts from the first colour again
ax.set_prop_cycle(None)
ax.plot(metrics.index, metrics[val_cols], linestyle="--")
# ax.set_ylim(0, 1)
ax.set_xlabel("Epoch")
ax.set_ylabel("Cross-Entropy Loss & Accuracy")

# Handles and labels are paired in order: the two gray proxies first, then the
# plotted lines in column order.
# NOTE(review): the "Accuracy"/"Loss" labels presumably rely on the accuracy
# column preceding the loss column in the CSV — confirm.
legend_handles = base_handles + ax.get_lines()
legend_labels = ["Train", "Validation", "Accuracy", "Loss"]
ax.legend(legend_handles, legend_labels)
plt.show()

In [None]:
# NC1 curves: train columns solid, validation columns dashed
nc_cols = metrics.filter(like="nc_").columns
train_nc_cols = nc_cols[nc_cols.str.contains("train")]
val_nc_cols = nc_cols[nc_cols.str.contains("val")]

ax = plt.gca()
ax.plot(metrics.index, metrics[train_nc_cols])
# reset the property cycle so the next plot call starts from the first colour again
ax.set_prop_cycle(None)
ax.plot(metrics.index, metrics[val_nc_cols], linestyle="--")

# legend: the two gray proxies, then one entry per plotted line in column order
# NOTE(review): the layer labels presumably match the order of the nc_ columns — confirm
layer_labels = [rf"$l = {layer}$" for layer in (1, 2, 3, "L-1")]
legend_handles = base_handles + ax.get_lines()
ax.legend(legend_handles, ["Train", "Validation"] + layer_labels)

ax.set_xlabel("Epoch")
ax.set_ylabel(r"$\operatorname{tr}(Σ_W^l(Σ_B^l)⁺)$")
plt.show()

In [None]:
# DIB curves: train columns solid, validation columns dashed
# per-epoch gap between the validation and train DIB-0 curves; it is added to
# every train curve below
diff = metrics["dib_0_val"] - metrics["dib_0_train"]
num_epochs = metrics.shape[0]  # not used in this cell

ax = plt.gca()
# The loop variable is called `name`, not `set`: notebook cells share one global
# namespace, so rebinding `set` would break the builtin for every later cell.
for name, linestyle in (("train", "-"), ("val", "--")):
    dib_cols = metrics.filter(regex=f"dib_[0-9]_{name}").columns
    values = metrics[dib_cols]
    if name == "train":
        values = values.add(diff, axis=0)

    ax.plot(metrics.index, values, linestyle=linestyle)

    if name == "train":
        # reset the property cycle so the validation curves start from the first colour again
        ax.set_prop_cycle(None)
        # capture the legend handles now, while only the train lines exist
        handles = base_handles + ax.get_lines()

ax.set_xlabel("Epoch")
ax.set_ylabel("DIB Cross-Entropy Training Loss")
ax.legend(
    handles, ["Train", "Validation"] + [rf"$n_{{block}} = {layer}$" for layer in (0, 1, 2, 3)]
)
plt.show()

In [None]:
# IP analysis: one figure per split, plotting the negated loss against each
# negated DIB column, with points coloured by epoch.
# `metrics` is indexed by epoch (see the groupby in the loading cell), so the
# epoch values come from the index rather than from an "epoch" column.
epochs = metrics.index.to_numpy()
# NOTE(review): the y-axis uses the train loss for the "Val Set" figure as
# well — confirm this is intended rather than the validation loss.
y = -metrics["train_loss"]
for name in ("train", "val"):
    for dib in metrics.filter(regex=f"dib_[0-9]_{name}").columns:
        x = -metrics[dib]
        plt.scatter(x, y, cmap="viridis", c=epochs)
        # thin gray line joining the points in epoch order
        plt.plot(x, y, color="gray", linewidth=0.5)
    plt.xlabel(r"$I_{\mathcal{V}}[Z \to \operatorname{Dec}(X,\mathcal{Y})]$")
    plt.ylabel(r"$I_{\mathcal{V}}[Z \to Y]$")
    plt.colorbar(label="Epochs")
    plt.title(f"{name.capitalize()} Set")
    plt.show()