In [80]:
import pathlib
from typing import List, Optional
import os
from matplotlib.ticker import FuncFormatter

import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

In [81]:
# Experiment whose run logs are plotted by this notebook.
EXP_NAME = "25_10_17_sllm_d2"

# Directory holding one sub-directory of event files per run of EXP_NAME.
RUN_DIR = pathlib.Path("/u/marti.juanola/experiments") / EXP_NAME / "runs"
# Default destination for figures that are saved instead of shown.
OUT_DIR = pathlib.Path("/u/marti.juanola/Documents/tfm/images")

print(f"Using log {RUN_DIR}")

Using log /u/marti.juanola/experiments/25_10_17_sllm_d2/runs


In [82]:
def load_scalars(logdir):
    """Build an ``EventAccumulator`` for ``logdir`` and load its events.

    Args:
        logdir: Path of the run directory (or event file) to read.

    Returns:
        The accumulator after ``Reload()`` has been called on it.
    """
    # Per the TensorBoard docs, a size guidance of 0 keeps every scalar event
    # instead of down-sampling the series.
    guidance = {event_accumulator.SCALARS: 0}
    accumulator = event_accumulator.EventAccumulator(logdir, size_guidance=guidance)
    accumulator.Reload()
    return accumulator


def _format_steps(x, pos):
    if x >= 1_000_000:
        return f"{x / 1_000_000:.1f}M"
    elif x >= 1_000:
        return f"{x / 1_000:.0f}k"
    return str(int(x))


def plot_scalar(ea, tag, outdir, log_name: Optional[str] = None, show: bool = True, figsize=(12, 8)):
    """Plot one scalar series against its step, then show or save the figure.

    Args:
        ea: Loaded accumulator exposing ``Scalars(tag)``; each returned event
            must provide ``step`` and ``value`` attributes.
        tag: Scalar tag to plot; also used as the y-axis label.
        outdir: Directory the PNG is written to when ``show`` is False. It is
            created if it does not exist. Not used when ``show`` is True.
        log_name: Optional run name; when given, the title becomes
            ``"<log_name> - <tag>"``, otherwise just ``tag``.
        show: If True, display the figure with ``plt.show()``. If False, save
            it as ``<title>.png`` (with ``/`` replaced by ``_``) at 300 dpi and
            close the figure.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    events = ea.Scalars(tag)
    steps = [event.step for event in events]
    values = [event.value for event in events]

    title = f"{log_name} - {tag}" if log_name else tag

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(steps, values, linewidth=2)

    ax.set_xlabel("Step", fontsize=18)
    ax.set_ylabel(tag, fontsize=18)
    ax.set_title(title, fontsize=22, pad=12)

    ax.tick_params(axis="both", labelsize=14)
    ax.xaxis.set_major_formatter(FuncFormatter(_format_steps))

    ax.grid(True)
    fig.tight_layout()

    if show:
        plt.show()
        return

    # Tags usually contain "/" (e.g. "devtrain/ce"), which would otherwise be
    # interpreted as a sub-directory in the file name.
    filename = title.replace("/", "_") + ".png"
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(outdir, exist_ok=True)
    fig.savefig(os.path.join(outdir, filename), dpi=300)
    plt.close(fig)


def plot_log(log_name: str,
             scalars_to_plot: Optional[List[str]] = None,
             out_dir: pathlib.Path = OUT_DIR,
             show: bool = True,
             figsize: Optional[tuple[int, int]] = (12, 8)):
    """Plot selected scalar tags of one run found under ``RUN_DIR``.

    Args:
        log_name: Name of the run sub-directory inside ``RUN_DIR``; also passed
            on as the title prefix of each figure.
        scalars_to_plot: Tags to plot. ``None`` (the default) plots every
            scalar tag present in the log; an empty list plots nothing.
            Requested tags that are absent from the log are skipped.
        out_dir: Directory passed to ``plot_scalar`` for saved figures.
        show: Forwarded to ``plot_scalar``.
        figsize: Forwarded to ``plot_scalar``.
    """
    logdir = os.path.join(RUN_DIR, log_name)
    ea = load_scalars(logdir)
    scalar_tags = ea.Tags()["scalars"]

    for tag in scalar_tags:
        # The default None previously reached `tag not in None` and raised
        # TypeError; treat it as "no filter".
        if scalars_to_plot is not None and tag not in scalars_to_plot:
            continue
        plot_scalar(ea, tag, out_dir, log_name=log_name, show=show, figsize=figsize)

## Plot divergence

In [83]:
# Save (rather than display) the "devtrain/ce" curve of run "SLLM-d-2".
plot_log(
    "SLLM-d-2",
    scalars_to_plot=["devtrain/ce"],
    show=False,
    figsize=(12, 7),
)