In [None]:
import torch
import lib.utils.bookkeeping as bookkeeping
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import lib.models.models as models
import lib.models.model_utils as model_utils
from lib.datasets import mnist, maze, protein, synthetic
import lib.datasets.dataset_utils as dataset_utils
import lib.losses.losses as losses
import lib.losses.losses_utils as losses_utils
import lib.training.training as training
import lib.training.training_utils as training_utils
import lib.optimizers.optimizers as optimizers
import lib.optimizers.optimizers_utils as optimizers_utils
from lib.d3pm import make_diffusion
import os
from lib.datasets.maze import maze_acc


In [None]:
# Location of the saved run: <path>/<date>/ holds both the YAML config and the
# model checkpoint that the next cell loads.
path = "SavedModels/MNIST/"
date = '2024-02-07'  # 2
config_name = 'config_001.yaml'  # alternative: config_001_hollowMLEProb.yaml
model_name = 'model_name.pt'

run_dir = os.path.join(path, date)
config_path = os.path.join(run_dir, config_name)
checkpoint_path = os.path.join(run_dir, model_name)

In [None]:
# Rebuild the diffusion process and the network described by the saved config,
# then restore the trained weights from the checkpoint.
cfg = bookkeeping.load_config(config_path)
diffusion = make_diffusion(cfg.model)

device = torch.device(cfg.device)
print(device)

model = model_utils.create_model(cfg, device)
n_params = sum(p.numel() for p in model.parameters())
print("number of parameters: ", n_params)

# The optimizer is not used for training here; it is presumably created only so
# that load_state has an optimizer object to restore into.
optimizer = torch.optim.Adam(model.parameters(), cfg.optimizer.lr)
state = bookkeeping.load_state(
    {"model": model, "optimizer": optimizer, "n_iter": 0},
    checkpoint_path,
    device,
)
state['model'].eval()


In [None]:
# Draw unconditional samples from the trained model. The per-sample shape
# depends on which dataset the model was trained on.
n_samples = 100
sample_shapes = {
    'Maze3S': (1, 15, 15),
    'DiscreteMNIST': (1, 28, 28),
    'SyntheticData': (32,),
}
if cfg.data.name not in sample_shapes:
    raise ValueError(
        f"Unsupported dataset {cfg.data.name!r} for sampling; "
        f"expected one of {sorted(sample_shapes)}"
    )
shape = (n_samples, *sample_shapes[cfg.data.name])

# Sampling is inference only: disable autograd so no graph is accumulated across
# timesteps (harmless if p_sample_loop already does this internally).
with torch.no_grad():
    samples = diffusion.p_sample_loop(state['model'], shape, cfg.model.num_timesteps).cpu().numpy()

# Keep an independent copy of the raw samples; later cells rebind `samples`
# (reshape / bin2float) and derive from this copy.
saved_samples = samples.copy()

In [None]:
# FID of a saved batch of generated samples against a same-sized batch of
# reference data from the dataset.
from lib.datasets.mnist_fid import evaluate_fid_score

sample_file = 'sample_path.npy'  # TODO: point at the .npy file of samples to score
fid_seed = 0  # fixes which reference batch is drawn when cfg.data.shuffle is True
data = np.load(sample_file)
dataset_location = "lib/datasets"
fid_values = []

# NOTE: this mutates the shared cfg in place; later cells that build datasets
# from cfg inherit train=False (presumably the evaluation split — confirm).
cfg.data.train = False
dataset = dataset_utils.get_dataset(cfg, device, dataset_location)

# One batch with exactly as many reference items as there are generated samples.
dataloader = DataLoader(
    dataset,
    batch_size=data.shape[0],
    shuffle=cfg.data.shuffle,
    generator=torch.Generator().manual_seed(fid_seed),
)
true_data = next(iter(dataloader), None)
if true_data is not None:
    print("----------------------------------")
    # Third argument passed through unchanged; its meaning is defined by
    # evaluate_fid_score (presumably a batch size — verify).
    fid = evaluate_fid_score(data, true_data.cpu().numpy(), 100)
    print("FID:", fid)
    fid_values.append(fid)
print(fid_values)


In [None]:
# Visualise the generated samples. Image datasets are drawn as a grid of
# grayscale images; SyntheticData samples are decoded with synthetic.bin2float
# and plotted with synthetic.plot_samples.
# Derive the branch from the dataset rather than a hand-set flag, so
# SyntheticData samples (shape (n, 32)) never go down the image path.
is_img = cfg.data.name != 'SyntheticData'
plot_tag = f"{cfg.model.name}{state['n_iter']}_{cfg.sampler.name}{cfg.sampler.num_steps}"

if is_img:
    # Start from the untouched raw samples so re-running this cell is idempotent.
    samples = saved_samples.reshape(-1, 1, cfg.data.image_size, cfg.data.image_size)
    # NOTE(review): computed but unused — the figure is saved to 'image_samples.pdf' below.
    saving_train_path = os.path.join(cfg.saving.sample_plot_path, f"{plot_tag}.png")

    # Near-square grid that holds n_plot images even when n_plot is not a perfect square.
    n_plot = min(n_samples, len(samples))
    n_rows = max(1, int(np.ceil(np.sqrt(n_plot))))
    n_cols = max(1, int(np.ceil(n_plot / n_rows)))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(9, 9), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, img in zip(axes.flat, samples[:n_plot]):
        # img is (1, H, W) after the reshape above; draw its single channel.
        ax.imshow(img[0], cmap="gray")

    fig.savefig('image_samples.pdf', transparent=True)
    plt.show()
    plt.close(fig)
else:
    bm, inv_bm = synthetic.get_binmap(cfg.model.concat_dim, cfg.data.binmode)
    # Decode from the raw samples (not a previously decoded `samples`) so re-running is safe.
    samples = synthetic.bin2float(saved_samples.astype(np.int32), inv_bm, cfg.model.concat_dim, cfg.data.int_scale)

    # NOTE(review): computed but unused — the plot is written to 'synthetic_samples.pdf'.
    saving_plot_path = os.path.join(path, f"{plot_tag}.png")
    saving_np_path = os.path.join(path, f"samples_{plot_tag}.npy")
    synthetic.plot_samples(samples, 'synthetic_samples.pdf', im_size=cfg.data.plot_size, im_fmt="pdf")

In [None]:
# Maze-validity metric (as computed by maze_acc) — only meaningful when the
# samples are mazes; skip it for e.g. DiscreteMNIST samples.
if cfg.data.name.startswith('Maze3S'):
    correct_mazes = maze_acc(saved_samples)
else:
    correct_mazes = None
    print(f"Skipping maze accuracy: dataset {cfg.data.name!r} is not a maze dataset")

In [None]:
# Build a reference loader of complete mazes, one batch the size of the sample
# batch. NOTE: this rewrites fields of the shared cfg in place.
cfg.data.name = 'Maze3SComplete'
cfg.data.batch_size = n_samples

# The name was just set to 'Maze3SComplete', so the limit always applies.
limit = cfg.data.batch_size
cfg.data.limit = limit

dataset = dataset_utils.get_dataset(cfg, device)
dataloader = DataLoader(
    dataset,
    batch_size=cfg.data.batch_size,
    shuffle=cfg.data.shuffle,
)

# Iterate over the reference mazes. With batch_size == limit == n_samples this
# presumably yields a single batch — TODO confirm get_dataset honours cfg.data.limit.
# `true_dl` and `c_i` are rebound on every iteration; only the last batch's values survive.
for i in dataloader:
    true_dl = i
    # maze_acc on ground-truth mazes — presumably a sanity check of the metric; verify.
    c_i = maze_acc(i.cpu().numpy())
    # Flatten each maze in the batch to one row.
    true_dl = true_dl.reshape(cfg.data.batch_size, -1) #.flatten()