In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt

from prior_networks.priornet.dpn import PriorNet, dirichlet_prior_network_uncertainty
from prior_networks.priornet.run.synth_model import SynthModel
from prior_networks.plot_util import visualise_uncertainty, plot_contourf
from prior_networks.util_pytorch import get_grid_eval_points, categorical_entropy_torch
from prior_networks.datasets.toy.classification.mog import MixtureOfGaussiansDataset

In [None]:
# Build the training set and its out-of-distribution counterpart from the same
# generation settings; only the OOD flag differs between the two.
mog_settings = dict(size=1000, noise=1, scale=4)
train_dataset = MixtureOfGaussiansDataset(**mog_settings)
ood_dataset = MixtureOfGaussiansDataset(**mog_settings, OOD=True)

# Overlay both datasets on one equal-aspect axes so distances are comparable
# in x and y.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_aspect('equal')
train_dataset.plot(ax=ax)
ood_dataset.plot(ax=ax)

In [None]:
# `std` selects which checkpoint directory is read; it is also reused further
# down to tag the names of the saved figures.
std = 4
checkpoint_path = f"./checkpoints/dpn_synth_std{std}/checkpoint.tar"

# Map all stored tensors to the CPU so no GPU is needed to load the checkpoint.
chkpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))

# Restore the checkpointed weights into a freshly constructed model.
model = SynthModel()
model.load_state_dict(chkpt["model_state_dict"])

# Inference only from here on: eval mode, and gradient tracking switched off
# for the rest of the session.
model.eval()
torch.set_grad_enabled(False)

In [None]:
# Evaluation grid settings: both axes span [-extent, extent]; `res` is
# presumably the number of grid points per axis (confirm in util_pytorch).
extent = 20
res = 200

axis_range = (-extent, extent)
points = get_grid_eval_points(axis_range, axis_range, res)

# These are the alphas (concentration parameters) of the Dirichlet distribution.
model_out = model(points)

In [None]:
# Derive uncertainty measures from the model outputs. The result is indexed by
# measure name in the plotting cells below ("expected_entropy",
# "differential_entropy").
metrics = dirichlet_prior_network_uncertainty(model_out)

In [None]:
# Plot two of the computed uncertainty measures over the evaluation grid, using
# the same `extent` and `res` the grid was built with.
# NOTE(review): presumably displayed inline by default, since the next cell
# passes show=False explicitly -- confirm in plot_util.
plot_contourf(points, metrics["expected_entropy"], extent, res)
plot_contourf(points, metrics["differential_entropy"], extent, res)


In [None]:
# Same two plots again, this time with show=False and a name tagged with `std`.
# NOTE(review): presumably the figure is written out under `name` instead of
# being displayed -- confirm the output location/format in plot_util.
plot_contourf(points, metrics["expected_entropy"], extent, res, show=False, name=f"expected_entropy_std{std}")
plot_contourf(points, metrics["differential_entropy"], extent, res, show=False, name=f"differential_entropy_std{std}")

In [None]:
# Leftover post-mortem debugger call removed: `%debug` has nothing to inspect on
# a clean Restart & Run All and opens a blocking interactive prompt otherwise.
# Run `%debug` manually in a scratch cell right after a failing cell when needed.