In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import json
import itertools

In [None]:
# Experiment grid. Each combination names a run directory of the form
# outputs/{db}_{dim}_{do}_{on}_{std}_{lyr}_1; trailing comments keep the wider
# value lists from the full sweep.
dbs = "cifar-10".split()            # full sweep: mnist cifar-10 cifar-100 svhn
dims = "d200".split()               # full sweep: d100 d200 d500
dos = "norm reg".split()
ons = "x w".split()
stds = "s1o4 s1o2 s1o1".split()
lyrs = "no l1 l2 l3 all".split()    # "no" is the baseline setting

In [None]:
# Per-run summary norms read from each run's analysis.pt. One row per grid
# point; columns w{i}/b{i}/x{i} hold one value per layer index i.
data = []
for db, dim, do, on, std, lyr in itertools.product(dbs, dims, dos, ons, stds, lyrs):
  label = f"{db}_{dim}_{do}_{on}_{std}_{lyr}_1"
  try:
    # torch.load unpickles; only use it on trusted, locally produced files.
    state = torch.load(f"outputs/{label}/analysis.pt")
  except FileNotFoundError:
    # Not every grid point has been run; skip the missing ones.
    continue
  data.append({
    "db": db, "dim": dim, "do": do, "on": on, "std": std, "lyr": lyr,
    # Frobenius (flattened L2) norm of each weight / bias tensor.
    **{f"w{i}": w.norm().item() for i, w in enumerate(state["ml"]["w"])},
    **{f"b{i}": b.norm().item() for i, b in enumerate(state["ml"]["b"])},
    # sqrt(trace(m2)) per layer; m2 is presumably a second-moment matrix of the
    # layer states -- TODO confirm against the code that writes analysis.pt.
    **{f"x{i}": m.diagonal().sum().sqrt().item() for i, m in enumerate(state["ml"]["m2"])},
  })
df = pd.DataFrame(data)
assert not df.empty, "no outputs/*/analysis.pt found for the configured grid"

# Bar charts of every norm divided by the same setting's lyr="no" run.
# One figure per (db, dim, std); one row per quantity; one column per (on, do).
mets = {"state": "x", "weight": "w", "bias": "b"}
for db, dim, std in itertools.product(dbs, dims, stds):
  ncols = len(dos)*len(ons)
  nrows = len(mets)
  # squeeze=False keeps axes 2-D so axes[k][j] works for any grid size.
  _, axes = plt.subplots(
    nrows=nrows, ncols=ncols, sharex="col", sharey="row", squeeze=False,
    figsize=(ncols*2.5, nrows*2), dpi=100, facecolor="w")
  for j, (on, do) in enumerate(itertools.product(ons, dos)):
    sdf = df[(df["db"]==db)&(df["dim"]==dim)&(df["do"]==do)&(df["on"]==on)&(df["std"]==std)]
    # The ratio needs the lyr="no" baseline; skip settings that lack it.
    if len(sdf) == 0 or not (sdf["lyr"] == "no").any():
      continue
    sdf = sdf.set_index("lyr").drop(columns=["db", "dim", "do", "on", "std"])
    sdf = sdf.divide(sdf.loc["no"]).drop("no")
    if sdf.empty:
      # Only the baseline was run for this setting; nothing to compare.
      continue
    for k, (met, prefix) in enumerate(mets.items()):
      cols = [col for col in sdf.columns if col.startswith(prefix)]
      ax = axes[k][j]
      sdf[cols].plot.bar(ax=ax, legend=False, rot=0)
      ax.set_xlabel(f"{do} {on} {std}")
      ax.set_ylabel(met)
      ax.set_ylim((0, 1.5))
      ax.grid()

In [None]:
def read(config):
  """Load per-epoch training logs for every combination of values in ``config``.

  ``config`` maps field names to lists of candidate values. Each combination is
  one run whose logs are read from ``outputs/<label>/logs.json``, where
  ``<label>`` is the combination's values joined by ``_`` plus a ``_1`` suffix.
  When the combination has ``detach == "ans"``, the run is instead looked up
  under ``{db}_{dim}_norm_{on}_{std}_{lyr}_1`` (the ``do="norm"`` run, with no
  detach component); fields are matched by name, so key order in ``config``
  does not matter.

  Runs whose log file is missing or is not valid JSON are skipped.

  Returns:
    DataFrame with one row per (run, epoch): the config fields, a 1-based
    ``epoch`` column, and every key of that epoch's log entry.
  """
  data = []
  for values in itertools.product(*config.values()):
    title = dict(zip(config.keys(), values))
    if title.get("detach") == "ans":
      label = "{db}_{dim}_norm_{on}_{std}_{lyr}_1".format(**title)
    else:
      label = "_".join(values) + "_1"
    try:
      with open(f"outputs/{label}/logs.json") as f:
        logs = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
      # Not every combination has been run (or its log is incomplete); skip it.
      continue
    for epoch, log in enumerate(logs, start=1):
      data.append({**title, "epoch": epoch, **log})
  return pd.DataFrame(data)

def plot_curve(config, x, y, g, ylims=None):
  """Plot per-epoch training curves for the runs described by ``config``.

  Args:
    config: mapping of field name -> list of values, loaded via ``read``.
    x: config fields; each combination of their values gets one column of
      subplots. Each field should appear once.
    y: logged metric names; one row of subplots per metric.
    g: field whose values are drawn as separate lines within each subplot.
    ylims: optional mapping metric -> (bottom, top) y-limits. Defaults to
      ``{"eval_top1": (0.45, 0.6)}``.

  Does nothing when no logs are found or when ``x``/``y`` select no panels.
  """
  if ylims is None:
    ylims = {"eval_top1": (0.45, 0.6)}
  df = read(config)
  values = list(itertools.product(*[config[k] for k in x]))
  fields = y
  if df.empty or not values or not fields:
    return

  ncols = len(values)
  nrows = len(fields)
  # squeeze=False keeps axes 2-D so axes[j][i] works with a single row/column.
  _, axes = plt.subplots(
    nrows=nrows, ncols=ncols, sharex=True, sharey="row", squeeze=False,
    figsize=(ncols*2.5, nrows*2), dpi=100, facecolor="w")

  for i, vs in enumerate(values):
    sdf = df
    for k, v in zip(x, vs):
      sdf = sdf[sdf[k] == v]
    if len(sdf) == 0:
      continue
    grouped = sdf.set_index("epoch").groupby(g)
    label = " ".join(vs)
    for j, field in enumerate(fields):
      ax = axes[j][i]
      # One line per value of g.
      grouped[field].plot(ax=ax, legend=False)
      ax.set_xlabel(label)
      ax.set_ylabel(field)
      ax.grid()
      if j == 0:
        ax.legend()
      if field in ylims:
        ax.set_ylim(ylims[field])

# Compare the `detach` variants of each run (one line per variant).
config = {
  "db": ["cifar-10"],
  "dim": ["d200"],
  "do": ["reg"],
  "on": ["x", "w"],
  "detach": ["ans", "no"],
  "std": ["s1o1"],
  "lyr": ["l1", "l2", "l3"],
}

metsets = [
  ["train_loss", "eval_loss", "eval_top1"],
  [f"train_model_{i}.state.l1" for i in (4, 6, 8)] + ["train_model.output.l1"],
  [f"eval_model_{i}.weight.l1" for i in (1, 3, 5, 7)],
]
# One figure per metric set; one column per (do, std, on, lyr) combination.
# Each x field is listed once: a repeated field would add empty columns for
# every mismatched pair of its values.
for metset in metsets:
  plot_curve(config, x=["do", "std", "on", "lyr"], y=metset, g="detach")


In [None]:
# Per-epoch logs for every grid point, one row per (run, epoch).
data = []
for db, dim, do, on, std, lyr in itertools.product(dbs, dims, dos, ons, stds, lyrs):
  label = f"{db}_{dim}_{do}_{on}_{std}_{lyr}_1"
  title = {"db": db, "dim": dim, "do": do, "std": std, "lyr": lyr, "on": on}
  try:
    with open(f"outputs/{label}/logs.json") as f:
      logs = json.load(f)
  except (FileNotFoundError, json.JSONDecodeError):
    # Not every grid point has been run (or its log is incomplete); skip it.
    continue
  for epoch, log in enumerate(logs, start=1):
    data.append({**title, "epoch": epoch, **log})
df = pd.DataFrame(data)
assert not df.empty, "no outputs/*/logs.json found for the configured grid"

metsets = [
  ["train_loss", "eval_loss", "eval_top1"],
  [f"train_model_{i}.state.l1" for i in (4, 6, 8)] + ["train_model.output.l1"],
  #[f"eval_model_{i}.weight.l1" for i in (1, 3, 5, 7)],
]
# One figure per (db, dim, metric set), for the first `on` value only.
# Columns are (do, std) combinations; lines within a subplot are `lyr` values.
for db, dim, on, mets in itertools.product(dbs, dims, ons[:1], metsets):
  ncols = len(dos)*len(stds)
  nrows = len(mets)
  # squeeze=False keeps axes 2-D so axes[k][j] works for any grid size.
  _, axes = plt.subplots(
    nrows=nrows, ncols=ncols, sharex=True, sharey="row", squeeze=False,
    figsize=(ncols*2.5, nrows*2), dpi=100, facecolor="w")
  for j, (do, std) in enumerate(itertools.product(dos, stds)):
    sdf = df[(df["db"]==db)&(df["dim"]==dim)&(df["do"]==do)&(df["std"]==std)&(df["on"]==on)]
    if len(sdf) == 0:
      continue
    sdf = sdf.drop(columns=["db", "dim", "do", "on", "std"])
    sdf = sdf.set_index("epoch").groupby("lyr")
    for k, met in enumerate(mets):
      ax = axes[k][j]
      sdf[met].plot(ax=ax, legend=False)
      ax.set_xlabel(f"{do} {on} {std}")
      ax.set_ylabel(met)
      ax.grid()
      if met == "eval_top1":
        ax.set_ylim((0.45, 0.6))
    axes[0][j].legend()

In [None]:
def plot(db, dim, do, on, std, lyr, layers, flip, axes, xlim=(-.25, .25), ylim=(0, 2)):
  """Draw 2-D histograms of weight Gram-matrix entries against ``m2`` entries.

  The run is loaded from ``outputs/{db}_{dim}_{do}_{on}_{std}_{lyr}_1/analysis.pt``.
  For each index ``layer`` in ``layers``, the x values are the entries of
  ``w @ w.T``. Here ``w`` is ``state["ml"]["w"][layer]`` when ``flip`` is true,
  and ``state["ml"]["w"][layer + 1].T`` otherwise. The y values are the entries
  of ``state["ml"]["m2"][layer]``. The j-th layer is drawn into ``axes[-j-1]``,
  so the first layer lands in the last axis.

  Args:
    db, dim, do, on, std, lyr: run identifiers forming the output directory name.
    layers: iterable of layer indices to plot.
    flip: selects the weight orientation described above.
    axes: indexable sequence of matplotlib Axes, at least ``len(layers)`` long.
    xlim, ylim: histogram ranges along x (Gram entries) and y (``m2`` entries).

  Returns:
    None. Draws nothing if the run's analysis.pt does not exist.
  """
  label = f"{db}_{dim}_{do}_{on}_{std}_{lyr}_1"
  try:
    # torch.load unpickles; only use it on trusted, locally produced files.
    # map_location="cpu" so .numpy() below works even for GPU-saved tensors.
    state = torch.load(f"outputs/{label}/analysis.pt", map_location="cpu")
  except FileNotFoundError:
    return
  for j, layer in enumerate(layers):
    w = state["ml"]["w"][layer] if flip else state["ml"]["w"][layer+1].T
    x = torch.matmul(w, w.T)
    y = state["ml"]["m2"][layer]
    # detach() so .numpy() works even if the saved tensors track gradients.
    axes[-j-1].hist2d(
      x.detach().flatten().numpy(), y.detach().flatten().numpy(),
      range=[xlim, ylim], bins=[50, 50],
      norm=mpl.colors.LogNorm())
  axes[-1].set_xlabel(f"{do} {on} {std} {lyr}")

# One figure per (db, dim, on, lyr), skipping the "no" baseline entry of lyrs.
# One row per layer (first layer at the bottom), one column per (do, std).
layers = range(3)
for db, dim, on, lyr in itertools.product(dbs, dims, ons, lyrs[1:]):
  nrows = len(layers)
  ncols = len(dos)*len(stds)
  # squeeze=False keeps axes 2-D so column slicing works for any grid size.
  _, axes = plt.subplots(
    nrows=nrows, ncols=ncols, sharex=True, sharey=True, squeeze=False,
    figsize=(ncols*2, nrows*2), dpi=100, facecolor="w")
  for i, (do, std) in enumerate(itertools.product(dos, stds)):
    plot(db, dim, do, on, std, lyr, layers, False, axes[:, i])