In [1]:
import torch
import lib.utils.bookkeeping as bookkeeping
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import lib.models.models as models
import lib.models.model_utils as model_utils
from lib.datasets import mnist, maze, protein, synthetic
import lib.datasets.dataset_utils as dataset_utils
import lib.losses.losses as losses
import lib.losses.losses_utils as losses_utils
import lib.training.training as training
import lib.training.training_utils as training_utils
import lib.optimizers.optimizers as optimizers
import lib.optimizers.optimizers_utils as optimizers_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils
from lib.d3pm import make_diffusion
import os
from lib.datasets.maze import maze_acc
from ruamel.yaml.scalarfloat import ScalarFloat

In [2]:
# Locate the saved run: config and checkpoint sit together in <path>/<date>/.
path = "SavedModels/MAZE/"
date = '2024-02-06'  # dated sub-folder of the run
config_name = 'config_001_d3pm_256.yaml'  # alternative run: config_001_hollowMLEProb.yaml
model_name = 'model_59999_d3pm_256.pt'

# Join each file name onto the same run folder.
config_path, checkpoint_path = (
    os.path.join(path, date, file_name)
    for file_name in (config_name, model_name)
)



In [3]:
# Load the run configuration and override the sampler section for this
# evaluation. Other sampler names tried here: ExactSampling, ElboLBJF,
# CRMTauL, CRMLBJF.
cfg = bookkeeping.load_config(config_path)
cfg.sampler.name = 'LBJF'
cfg.sampler.num_corrector_steps = 0
cfg.sampler.corrector_entry_time = ScalarFloat(0.0)
cfg.sampler.num_steps = 1000
cfg.sampler.is_ordinal = False

# Diffusion process built from the model section of the config.
diffusion = make_diffusion(cfg.model)

device = torch.device(cfg.device)
print(device)

# Build the network and report its size.
model = model_utils.create_model(cfg, device)
print("number of parameters: ", sum(p.numel() for p in model.parameters()))

# The optimizer is only needed so the checkpoint state has a slot to load into.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)

sampler = sampling_utils.get_sampler(cfg)

# Restore the checkpoint into a fresh state dict and switch to eval mode.
state = dict(model=model, optimizer=optimizer, n_iter=0)
state = bookkeeping.load_state(state, checkpoint_path, device)
state['model'].eval()

# Author's notes for Maze (presumably solvable-maze rates — confirm):
#   80% TauL
#   84% LBJF
#   96%

[32m2024-02-06 09:04:11.500[0m | [1mINFO    [0m | [36mlib.d3pm[0m:[36m__init__[0m:[36m98[0m - [1m[compute transition matrix]: uniform[0m


in betas 256
from beta 256


[32m2024-02-06 09:04:12.837[0m | [1mINFO    [0m | [36mlib.d3pm[0m:[36m__init__[0m:[36m119[0m - [1m[trainsition matrix]: torch.Size([256, 3, 3])[0m
[32m2024-02-06 09:04:12.838[0m | [1mINFO    [0m | [36mlib.d3pm[0m:[36m__init__[0m:[36m122[0m - [1m[Construct transition matrices for q(x_t|x_start)][0m
[32m2024-02-06 09:04:14.463[0m | [1mINFO    [0m | [36mlib.d3pm[0m:[36m__init__[0m:[36m142[0m - [1m[tilde(Q)t]: torch.Size([256, 3, 3])[0m


cuda
number of parameters:  8102704
ema state dict function


In [10]:
n_samples = 1000

# Shape of the batch to generate, keyed by dataset name.
shape_by_dataset = {
    'Maze3S': (n_samples, 1, 15, 15),
    'DiscreteMNIST': (n_samples, 1, 28, 28),
    'SyntheticData': (n_samples, 32),
}
if cfg.data.name not in shape_by_dataset:
    raise ValueError("wrong")
shape = shape_by_dataset[cfg.data.name]

# Run the reverse diffusion loop and bring the result to host memory as a
# NumPy array.
samples = diffusion.p_sample_loop(
    state['model'],
    shape,
    cfg.model.num_timesteps,
).cpu().numpy()

# Second handle on the generated batch, used by the evaluation cells.
saved_samples = samples

[32m2024-02-06 09:27:28.124[0m | [1mINFO    [0m | [36mlib.d3pm[0m:[36mp_sample_loop[0m:[36m570[0m - [1mcuda[0m


In [None]:
# Visualise the generated samples: an image grid for image datasets,
# otherwise the project's synthetic-data plot.
is_img = cfg.data.is_img
if is_img:
    samples = samples.reshape(-1, 1, cfg.data.image_size, cfg.data.image_size)
    saving_train_path = os.path.join(cfg.saving.sample_plot_path, f"{cfg.model.name}{state['n_iter']}_{cfg.sampler.name}{cfg.sampler.num_steps}.png")

    # Never index past the samples that actually exist, and round the grid
    # side UP so the grid has at least n_plot cells. A floored side leaves
    # too few cells whenever the count is not a perfect square
    # (1000 -> 31 x 31 = 961), which makes plt.subplot raise on the first
    # index beyond the grid.
    n_plot = min(n_samples, samples.shape[0])
    grid_side = int(np.ceil(np.sqrt(n_plot)))

    fig = plt.figure(figsize=(9, 9))
    for i in range(n_plot):
        plt.subplot(grid_side, grid_side, 1 + i)
        plt.axis("off")
        # (C, H, W) -> (H, W, C) for imshow.
        plt.imshow(np.transpose(samples[i, ...], (1, 2, 0)), cmap="gray")

    # NOTE(review): saving_train_path is computed above but the figure is
    # written to this fixed file name instead — confirm which is intended.
    plt.savefig('crm_hollow.pdf', transparent=True)
    plt.show()
    plt.close()
else:
    bm, inv_bm = synthetic.get_binmap(cfg.model.concat_dim, cfg.data.binmode)
    print(inv_bm)
    # NOTE: `samples` is rebound to the converted values, so re-running this
    # cell without re-sampling would apply bin2float to already-converted data.
    samples = synthetic.bin2float(samples.astype(np.int32), inv_bm, cfg.model.concat_dim, cfg.data.int_scale)

    # Output locations derived from the run settings; neither is used by the
    # plot call below, which writes to a fixed file name.
    saving_plot_path = os.path.join(path, f"{cfg.model.name}{state['n_iter']}_{cfg.sampler.name}{cfg.sampler.num_steps}.png")
    saving_np_path = os.path.join(path, f"samples_{cfg.model.name}{state['n_iter']}_{cfg.sampler.name}{cfg.sampler.num_steps}.npy")

    synthetic.plot_samples(samples, 'hollow_crm_exact1000.pdf', im_size=cfg.data.plot_size, im_fmt="pdf")

In [11]:
# Score the generated mazes held in `saved_samples`. maze_acc appears to print
# its statistics (solvable rate, average path/wall/way lengths — see the cell
# output); what it returns is not inspected in the cells shown here.
# To score a previously saved batch instead, load it first:
#saved_samples = np.load('mazes3000_auxprotein1_lbjf.npy')
correct_mazes = maze_acc(saved_samples)

Accuracy: From 1000 are 84.2% solvable.
Average path length: 35.541567695961994 and prob 15.796252309316442%
Average wall length: 126.02256532066508 and prob 56.01002903140671%
Average way length: 63.43586698337292 and prob 28.19371865927685%


In [12]:
# Load one reference batch of ground-truth mazes, sized to match the number
# of generated samples.
# NOTE: this overwrites cfg.data.* in place, so any cell that branches on
# cfg.data.name sees 'Maze3SComplete' from here on until cfg is reloaded.
cfg.data.name = 'Maze3SComplete'
cfg.data.batch_size = n_samples

# Presumably read by get_dataset to cap the dataset at one batch — confirm.
limit = cfg.data.batch_size
cfg.data.limit = limit

dataset = dataset_utils.get_dataset(cfg, device)
dataloader = torch.utils.data.DataLoader(dataset,
    batch_size=cfg.data.batch_size,
    shuffle=cfg.data.shuffle)

# Only a single batch is needed; take it explicitly instead of looping and
# keeping whichever batch happens to come last.
true_dl = next(iter(dataloader), None)
if true_dl is None:
    raise ValueError("reference dataloader produced no batches")

c_i = maze_acc(true_dl.cpu().numpy())

# Flatten each maze to 1-D using the batch's real size, so a batch smaller
# than cfg.data.batch_size does not break the reshape.
true_dl = true_dl.reshape(true_dl.shape[0], -1)

1000 samples generated.
Accuracy: From 1000 are 100.0% solvable.
Average path length: 38.642 and prob 17.174222222222223%
Average wall length: 126.0 and prob 56.0%
Average way length: 60.358 and prob 26.825777777777773%


In [13]:
from scipy.stats import wasserstein_distance

# Keep a handle on the array as it is before flattening (that is what
# maze_acc receives), then flatten every maze to a 1-D vector of
# 225 = 15 x 15 cells and cap the batch at n_samples.
# Pre-generated batches can be substituted here, e.g.
#   samples = np.load('Samples/Maze/mazes_hollow_CRMTauL1000.npy')
saved_samples = samples
samples = samples.reshape(-1, 225)[:n_samples]

correct_mazes = maze_acc(saved_samples)

# Move the reference batch to host memory once rather than once per sample.
true_np = true_dl.cpu().numpy()
if true_np.shape[0] < samples.shape[0]:
    raise ValueError(
        f"need {samples.shape[0]} reference mazes but only {true_np.shape[0]} were loaded"
    )

# 1-D Wasserstein distance between the cell values of generated maze k and
# reference maze k. Each array is treated as an empirical distribution of
# values, so this compares value frequencies rather than spatial layout.
emd_dist = [
    wasserstein_distance(generated, reference)
    for generated, reference in zip(samples, true_np)
]
print("AVG", np.mean(emd_dist))

Accuracy: From 1000 are 84.2% solvable.
Average path length: 35.541567695961994 and prob 15.796252309316442%
Average wall length: 126.02256532066508 and prob 56.01002903140671%
Average way length: 63.43586698337292 and prob 28.19371865927685%
AVG 0.06582666666666666
