Simple visualizer for the log files written by the training loop.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def parse_logfile(logfile):
    """Parse a training-loop log file into per-stream (steps, values) series.

    Each non-blank line must hold exactly two whitespace-separated tokens:
    a step token of the form ``<tag>:<int>`` followed by a stream token of
    the form ``<stream name>:<float>``.

    The tricky part with these log files is that the job could crash and get
    restarted, which rewinds and starts re-logging older steps. So all data
    is first kept in per-stream dictionaries keyed by step. A later line for
    the same step overwrites the earlier one. Only at the end is everything
    compiled into sorted series.

    Args:
        logfile: path to the log file.

    Returns:
        dict mapping stream name -> (steps, values): two equal-length lists,
        sorted by step. Unlike a ``zip`` iterator, the pair can be unpacked
        any number of times.

    Raises:
        ValueError: if a non-blank line does not match the expected format.
    """
    # read raw data
    streams = {}  # stream:str -> {step:int -> val:float}
    with open(logfile, "r") as f:
        for lineno, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # tolerate blank lines instead of crashing
            if len(parts) != 2:
                raise ValueError(
                    f"{logfile}:{lineno}: expected 2 tokens, got "
                    f"{len(parts)}: {line!r}"
                )
            step_tok, stream_tok = parts
            try:
                step = int(step_tok.split(":")[1])
                stream, val_str = stream_tok.split(":")[:2]
                val = float(val_str)
            except (IndexError, ValueError) as err:
                raise ValueError(
                    f"{logfile}:{lineno}: malformed line {line!r}"
                ) from err
            # later entries for the same step (after a restart) win
            streams.setdefault(stream, {})[step] = val

    # now re-represent each stream as a (steps, values) pair of lists
    streams_xy = {}
    for name, by_step in streams.items():
        # steps are unique dict keys, so sorting by step alone is enough
        steps = sorted(by_step)
        streams_xy[name] = (steps, [by_step[s] for s in steps])
    # return the xs, ys lists
    return streams_xy

# parse_logfile("../log124M/main.log")

In [None]:
sz = "350M"
# Reference numbers from the OpenAI GPT-2 checkpoints of each size.
# A size missing from these tables gets None, so its plots are drawn
# without a baseline line instead of raising KeyError.
loss_baseline = {
    "124M": 3.424958,
    "350M": 3.083089,
    "774M": 3.000580,
    "1558M": 2.831273,
}.get(sz)
hella_baseline = {
    "124M": 0.294463,
    "350M": 0.375224,
    "774M": 0.431986,
    "1558M": 0.488946,
}.get(sz)

# assumes each model run is stored in this way
logfile = f"../log{sz}/main.log"
streams = parse_logfile(logfile)

plt.figure(figsize=(16, 6))

# Panel 1: losses: both train and val
plt.subplot(121)
xs, ys = streams["trl"]  # training loss
plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')
print("Min Train Loss:", min(ys))
xs, ys = streams["tel"]  # validation loss
plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')
print("Min Validation Loss:", min(ys))
# horizontal line at GPT-2 baseline (when one is known for this size)
if loss_baseline is not None:
    plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint val loss")
plt.xlabel("steps")
plt.ylabel("loss")
plt.yscale('log')
plt.legend()
plt.title("Loss")

# Panel 2: HellaSwag eval
plt.subplot(122)
xs, ys = streams["eval"]  # HellaSwag eval
plt.plot(xs, ys, label=f"llm.c ({sz})")
# horizontal line at GPT-2 baseline; test against None (not truthiness)
# so that it matches the loss panel's check above
if hella_baseline is not None:
    plt.axhline(y=hella_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint")
plt.xlabel("steps")
plt.ylabel("accuracy")
plt.legend()
plt.title("HellaSwag eval")
print("Max Hellaswag eval:", max(ys))