In [None]:
import torch
import numpy as np
import lib.utils.bookkeeping as bookkeeping
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
import os
from lib.datasets.datasets import get_mnist_dataset
from lib.models.networks import MNISTScoreNet
from torch.optim import Adam
from lib.sampling.sampling import Euler_Maruyama_sampler
# Main file which contains all DDSM logic
from lib.models.ddsm import *
#from lib.utils.utils import binary_to_onehot
import warnings
from functools import partial
# Install a filter that ignores every DeprecationWarning for the rest of this kernel session.
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [None]:
# Resolve where the saved run lives and load its configuration.
# Both the config file and the checkpoint sit in <path>/<date>/.
path = 'SavedModels/Bin_MNIST/'
date = '2023-09-11'
config_name = 'config_001.yaml'
model_name = 'model_name.pt'  # alternative checkpoint name: 'ckpt_0000004999.pt'

run_dir = os.path.join(path, date)
config_path = os.path.join(run_dir, config_name)
checkpoint_path = os.path.join(run_dir, model_name)

# `config` is the object every later cell reads its settings from.
config = bookkeeping.load_config(config_path)

In [None]:
# Sample from the checkpointed score network and save the samples as an image grid.

# Build the network and optimizer from the config, then hand both to
# bookkeeping.load_state together with the checkpoint path.
# NOTE(review): load_state is opaque from here; it presumably restores the
# model/optimizer weights and `n_iter` from `checkpoint_path` -- confirm.
time_dependent_weights = torch.load(config.loading.time_dep_weights_path)
model = MNISTScoreNet(
    ch=config.model.ch,
    ch_mult=config.model.ch_mult,
    attn=config.model.attn,
    num_res_blocks=config.model.num_res_blocks,
    dropout=0.1,
    time_dependent_weights=time_dependent_weights,
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.optimizer.lr,
    weight_decay=config.optimizer.weight_decay,
)
state = {"model": model, "optimizer": optimizer, "n_iter": 0}
state = bookkeeping.load_state(state, checkpoint_path)
state['model'].eval()  # inference mode (e.g. disables dropout)

sampler = Euler_Maruyama_sampler
samples = sampler(
    state['model'],
    config.data.shape,
    batch_size=config.sampler.n_samples,
    max_time=4,
    min_time=0.01,
    num_steps=100,
    eps=1e-5,
    random_order=config.random_order,
    speed_balanced=config.speed_balanced,
    device=config.device,
)

# NOTE(review): the upper clamp bound is num_cat; confirm this is the intended
# value range for the channel that is plotted below.
samples = samples.clamp(0.0, config.data.num_cat)
# Keep only index 0 of the last axis and insert a singleton channel axis for make_grid.
sample_grid = make_grid(
    samples[:, None, :, :, 0].detach().cpu(),
    nrow=int(np.sqrt(config.sampler.n_samples)),
)

plt.figure(figsize=(6,6))
plt.axis('off')
# make_grid returns channels-first; imshow wants channels-last.
plt.imshow(sample_grid.permute(1, 2, 0).cpu())

# Create the output directory if needed so savefig does not fail on a fresh
# checkout. An empty directory string means "current directory" and is skipped.
plot_dir = config.saving.sample_plot_path
if plot_dir:
    os.makedirs(plot_dir, exist_ok=True)
saving_plot_path = os.path.join(plot_dir, f"samples_epoch_eval{state['n_iter']}.png")
plt.savefig(saving_plot_path)
plt.close()


In [None]:
# eval ELBO
sb = UnitStickBreakingTransform()

# Load the five precomputed diffusion tables straight onto the CPU.
# map_location="cpu" lets a file that was saved from a CUDA device be read on a
# machine without a GPU; the explicit .cpu() keeps the original guarantee that
# every table ends up on the CPU.
v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints = (
    table.cpu()
    for table in torch.load(config.loading.diffusion_weights_path, map_location="cpu")
)

torch.set_default_dtype(torch.float32)
# alpha is all ones and beta counts down from num_cat - 1 to 1; both have
# num_cat - 1 entries and are passed on to the diffusion and ELBO routines below.
alpha = torch.ones(config.data.num_cat - 1).float()
beta = torch.arange(config.data.num_cat - 1, 0, -1).float()

# Pick the forward-diffusion routine and bind the precomputed tables to it with
# functools.partial, so later code can call `diffuser_func` without passing them again.
if config.use_fast_diff:
    diffuser_func = partial(diffusion_fast_flatdirichlet, noise_factory_one=v_one, v_one_loggrad=v_one_loggrad)
else:
    diffuser_func = partial(
        # The callable must be the first positional argument of partial; the
        # original had it fused with the next keyword (missing comma), which
        # left partial without a function and raised TypeError.
        diffusion_factory,
        noise_factory_one=v_one,
        noise_factory_zero=v_zero,
        noise_factory_one_loggrad=v_one_loggrad,
        noise_factory_zero_loggrad=v_zero_loggrad,
        alpha=alpha,
        beta=beta,
        device=config.device,
    )

# Per-component factor `s` (num_cat - 1 entries) handed to the ELBO routine as
# `speed_balanced`: 2 / (1 + k) for k = num_cat - 1, ..., 1 when speed balancing
# is enabled, otherwise all ones.
n_dims = config.data.num_cat - 1
unit = torch.ones(n_dims, device=config.device)
if config.speed_balanced:
    countdown = torch.arange(n_dims, 0, -1, device=config.device).float()
    s = 2 / (unit + countdown)
else:
    s = unit

# Estimate the test-set ELBO, reported both in bits and in nats.
min_time = 0.001
max_time = 4
all_bpds = 0.   # running sum of per-item values, in bits
all_items = 0   # number of items accumulated so far
_, _, test_dataloader = get_mnist_dataset(config)

# NOTE(review): `s` was created on config.device while the ELBO routine is told
# device="cpu"; confirm these agree when config.device is not the CPU.
# The test set is traversed 50 times and a running average is printed after every batch.
for _ in range(50):
    for x_test in tqdm(test_dataloader):
        # Store the result under its own name: assigning it to `elbo` would
        # overwrite the function and make the next iteration's call fail.
        batch_elbo = elbo(x_test, model, diffuser_func, min_time, max_time, sb, alpha, beta, speed_balanced=s, device="cpu", elbo_only=True)

        # Negate and divide by ln(2) to convert from nats to bits.
        bpd = -(batch_elbo.cpu().detach().numpy()) / np.log(2)
        all_bpds += bpd.sum()
        all_items += bpd.shape[0]
        print("Average bits: {:5f}".format(all_bpds / all_items))
        print("Average nats: {:5f}".format(all_bpds / all_items * np.log(2)))
