In [None]:
import json
import torch
from tqdm import tqdm

from datasets import get_dataset
from models import get_model
from tools import get_loss_fn

import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
def load(tag, n_samples=100):
    """Collect deterministic, regularised and stochastic losses for a saved run.

    Reads ``outputs/{tag}/config.json`` and ``outputs/{tag}/checkpoint.pt``,
    builds two models that share the checkpoint weights -- one from the config
    as saved and one with ``config["model"]["dropout"]["name"]`` overridden to
    ``"reg"`` -- and makes a single pass over ``dataset.train_loader`` with
    gradients disabled.

    Args:
        tag: Name of the run directory under ``outputs/``.
        n_samples: Number of forward passes of the stochastic model that are
            averaged for every batch. Defaults to 100, the value that used to
            be hard-coded.

    Returns:
        A tuple ``(det_losses, reg_losses, sto_losses)`` of NumPy arrays, each
        the concatenation over all batches of, respectively: the loss of the
        ``"reg"`` model's output, that loss plus
        ``reg_model.reg_loss(output, target)``, and the mean over
        ``n_samples`` passes of the stochastic model's loss.

    Note:
        ``loss_fn`` must return a tensor with at least one dimension, since
        the results are joined with ``torch.cat`` (presumably one loss per
        sample -- TODO confirm against ``get_loss_fn``).
    """
    root = f"outputs/{tag}"
    with open(f"{root}/config.json") as f:
        config = json.load(f)

    dataset = get_dataset(config["dataset"])
    loss_fn = get_loss_fn(config["fit"]["loss_fn"])
    # The stochastic model is built first: the config is mutated right after
    # so that the second model is created with the "reg" dropout variant.
    sto_model = get_model(dataset, config["model"])
    config["model"]["dropout"]["name"] = "reg"
    reg_model = get_model(dataset, config["model"])

    # map_location="cpu" lets a checkpoint written on a GPU machine be opened
    # on a CPU-only one; load_state_dict then copies the values onto whatever
    # device each model's parameters live on.
    # NOTE(review): torch.load unpickles the file -- only open trusted
    # checkpoints.
    state = torch.load(f"{root}/checkpoint.pt", map_location="cpu")
    sto_model.load_state_dict(state["model"])
    reg_model.load_state_dict(state["model"])

    # Both models are switched to training mode (presumably so the dropout
    # layers stay active -- TODO confirm against the model definitions).
    sto_model.train()
    reg_model.train()

    det_losses = []
    reg_losses = []
    sto_losses = []
    # Per-batch (stochastic - regularised) gaps, kept only for the progress
    # readout so that a single concatenation per step is enough.
    gaps = []

    with torch.no_grad():
        progress = tqdm(dataset.train_loader, leave=False)
        for inputs, target in progress:

            output = reg_model(inputs)
            det_loss = loss_fn(output, target)
            reg_loss = det_loss + \
                reg_model.reg_loss(output, target)

            # Average the loss over repeated passes of the stochastic model.
            sto_loss = torch.stack([
                loss_fn(sto_model(inputs), target)
                for _ in range(n_samples)
            ]).mean(0)

            det_losses.append(det_loss)
            reg_losses.append(reg_loss)
            sto_losses.append(sto_loss)
            gaps.append(sto_loss - reg_loss)

            # Running statistics over everything seen so far.
            diffs = torch.cat(gaps)
            progress.set_postfix({
                "mean": f"{diffs.mean().item():.4f}",
                "std": f"{diffs.std().item():.4f}",
            })

        det_losses = torch.cat(det_losses).numpy()
        reg_losses = torch.cat(reg_losses).numpy()
        sto_losses = torch.cat(sto_losses).numpy()

    return det_losses, reg_losses, sto_losses

In [None]:
# Evaluate the three loss arrays for the run under inspection.
(det_losses, reg_losses, sto_losses) = load(
    tag="cifar-10_d200_norm_w_s1o1_l1_1",
)

In [None]:
# 2-D histogram (log colour scale) of the offsets of the averaged stochastic
# loss and of the regularised loss from the deterministic loss, one point per
# entry of the loss arrays.
r = 0.5  # half-width of the plotted window on both axes
_, ax = plt.subplots(figsize=(4, 4), dpi=100, facecolor="w")
ax.hist2d(
    sto_losses - det_losses, reg_losses - det_losses,
    range=((-r, r), (-r, r)), bins=(100, 100),
    norm=mpl.colors.LogNorm())
# Label the axes with the plotted quantities so the figure reads on its own.
ax.set_xlabel("sto_loss - det_loss")
ax.set_ylabel("reg_loss - det_loss")
None

In [None]:
def hess(tag):
    """Print, per ``Regularization`` layer, how far a propagated Hessian is
    from the one computed by autograd.

    Reads ``outputs/{tag}/config.json`` and ``outputs/{tag}/checkpoint.pt``,
    builds the model with ``config["model"]["dropout"]["name"]`` overridden to
    ``"reg"``, runs one training batch through it in eval mode, and then walks
    the model's modules from last to first. At every ``Regularization`` module
    it prints the norm of the difference between:

    * ``expect`` -- the per-sample diagonal blocks of
      ``torch.autograd.functional.hessian`` of the summed loss with respect to
      that module's output, and
    * ``actual`` -- the two tensors carried in ``ctx`` combined as
      ``einsum("bij,bik,bjl->bkl", H, J, J)``.

    Args:
        tag: Name of the run directory under ``outputs/``.

    Returns:
        None. The model and the per-layer discrepancies are printed.
    """

    root = f"outputs/{tag}"
    with open(f"{root}/config.json") as f:
        config = json.load(f)
    config["model"]["dropout"]["name"] = "reg"

    dataset = get_dataset(config["dataset"])
    loss_fn = get_loss_fn(config["fit"]["loss_fn"])
    model = get_model(dataset, config["model"])

    # map_location="cpu" lets a checkpoint written on a GPU machine be opened
    # on a CPU-only one; load_state_dict copies the values onto the model's
    # own parameters afterwards.
    # NOTE(review): torch.load unpickles the file -- only open trusted
    # checkpoints.
    state = torch.load(f"{root}/checkpoint.pt", map_location="cpu")
    model.load_state_dict(state["model"])

    print(model)

    model.eval()
    with torch.no_grad():
        inputs, target = next(iter(dataset.train_loader))
        output = model(inputs)
        batch = torch.arange(inputs.shape[0])

        modules = list(model.modules())
        # Seed the propagated context from the second-to-last module
        # (presumably the final regularisation/output layer -- TODO confirm).
        ctx = modules[-2]._init(output)
        for i, m in reversed(list(enumerate(modules))):
            # NOTE(review): Regularization is not imported in the visible
            # import cell; unless it is defined elsewhere, a fresh-kernel run
            # raises NameError here.
            if not isinstance(m, Regularization):
                continue

            # m.state is presumably the activation cached by the forward
            # pass above -- TODO confirm.
            input_ = m(m.state)
            # Everything listed after this module forms the downstream map.
            model_ = torch.nn.Sequential(*modules[i+1:])
            expect = torch.autograd.functional.hessian(
                lambda x: loss_fn(model_(x), target).sum(), input_)
            # Keep only the same-sample blocks of the full Hessian.
            expect = expect[batch, :, batch, :]

            # Named so the local no longer shadows the enclosing function.
            hessian, jacob = ctx
            actual = torch.einsum(
                "bij,bik,bjl->bkl", hessian, jacob, jacob)
            print((expect - actual).norm())

            ctx = m._next(ctx)