In [1]:
import json
import torch

from src.datasets import get_dataset
from src.modules import get_model, get_loss_fn
from src.modules.dropouts import LinearApproxDropout

In [2]:
# Rebuild the dataset, model and loss of one training run, restore a saved
# checkpoint, and grab a small probe batch for the cells below.
path = "outputs/schemes/lin-apx_all_0"
CHECKPOINT_ID = 50   # loads f"{path}/checkpoint-{CHECKPOINT_ID}.pt"
NUM_SAMPLES = 20     # size of the probe batch used by later cells
SEED = 0

# Seed torch's global RNG so the probe batch is reproducible (assumes the
# loader shuffles with the global generator rather than its own).
torch.manual_seed(SEED)

with open(f"{path}/config.json") as f:
    config = json.load(f)

dataset = get_dataset(config["dataset"])
model = get_model(config["model"], dataset.input_shape, dataset.num_classes)
loss_fn = get_loss_fn(config["fit"]["loss_fn"])

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; load_state_dict then copies the tensors into the model's parameters.
state = torch.load(f"{path}/checkpoint-{CHECKPOINT_ID}.pt", map_location="cpu")
model.load_state_dict(state["model"])
# Train mode, presumably so the LinearApproxDropout layers are active.
model.train()

# Take a single batch instead of iterating over the whole loader: a full loop
# is slow and leaves `inputs`/`targets` pointing at the last (possibly short) batch.
inputs, targets = next(iter(dataset.train_loader))
inputs, targets = inputs[:NUM_SAMPLES], targets[:NUM_SAMPLES]

Files already downloaded and verified
Files already downloaded and verified


<All keys matched successfully>

In [3]:
# Run the probe batch through the model first. The next cell reads each
# LinearApproxDropout's `.state` after this forward pass, and `reg_loss()`
# presumably relies on that same state.
outputs = model(inputs)
model_reg_loss = model.reg_loss()
model_reg_loss

tensor([0.1177, 0.4064, 0.1463, 0.1577, 0.1389, 0.3176, 0.1455, 0.0474, 0.1134,
        0.1207, 0.1704, 0.0755, 0.2161, 0.0706, 0.1244, 0.2616, 0.1149, 0.1964,
        0.0988, 0.0893], grad_fn=<AddBackward0>)

In [4]:
# Recompute the regulariser from its second-order definition, for comparison
# with `model.reg_loss()` above. For every LinearApproxDropout m with std > 0:
#     reg[b] += std**2 / 2 * sum_k H[b, k, b, k] * state[b, k]**2
# where H is the Hessian of the summed loss w.r.t. m.state, taken through every
# direct child of the model from m onwards.
# Assumptions (confirm): the model is a flat chain of its direct children (nested
# dropouts are not found), and there are no cross-sample interactions such as
# train-mode BatchNorm, so only H's per-element diagonal matters.
# NOTE: the full Hessian is (N x N) with N = state.numel(); keep the batch small.
reg_loss = 0
modules = list(model.children())
for i, m in enumerate(modules):
    if isinstance(m, LinearApproxDropout) and m.std > 0:
        # Snapshot the state before differentiating: `uppers` starts with m itself,
        # so the forward pass inside `hessian` runs m again and may overwrite `.state`.
        state = m.state
        uppers = torch.nn.Sequential(*modules[i:])
        # Bind `uppers` as a default argument instead of relying on the loop closure.
        hess = torch.autograd.functional.hessian(
            lambda x, uppers=uppers: loss_fn(uppers(x), targets).sum(), state
        )
        # `hess` has shape state.shape + state.shape; keep d^2L/dx_k^2 for each
        # element k. For a 2-D state this equals einsum("bibi->bi", hess).
        hess_diag = hess.reshape(state.numel(), state.numel()).diagonal().reshape_as(state)
        # Use each module's own std (assumed to be its multiplicative-noise std)
        # instead of a hard-coded 0.5.
        reg_loss = reg_loss + (hess_diag * state.square()).flatten(1).sum(1) * (m.std ** 2 / 2)
reg_loss

tensor([0.1177, 0.4064, 0.1463, 0.1577, 0.1389, 0.3176, 0.1455, 0.0474, 0.1134,
        0.1207, 0.1704, 0.0755, 0.2161, 0.0706, 0.1244, 0.2616, 0.1149, 0.1964,
        0.0988, 0.0893], grad_fn=<AddBackward0>)
