In [1]:
import torch
import lib.utils.bookkeeping as bookkeeping
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import lib.models.models as models
import lib.models.model_utils as model_utils
from lib.datasets import mnist, maze, protein, synthetic
import lib.datasets.dataset_utils as dataset_utils
import lib.losses.losses as losses
import lib.losses.losses_utils as losses_utils
import lib.training.training as training
import lib.training.training_utils as training_utils
import lib.optimizers.optimizers as optimizers
import lib.optimizers.optimizers_utils as optimizers_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils
import os
from tqdm import tqdm
from lib.datasets.maze import maze_acc
from ruamel.yaml.scalarfloat import ScalarFloat

In [2]:

# Run selection: <path>/<date>/ is joined with the config and checkpoint file names below.
path = 'SavedModels/MNIST/'  # alternative root: 'SavedModels/MAZE/'
date = '2023-12-25'
config_name = 'config_001_unet.yaml'  # alternative: config_001_hollowMLEProb.yaml
model_name = 'model_599999_unet14Mlogits.pt'  # alternative: model_299999_hollowMLEProb.pt

# Both files live in the same dated run directory.
config_path, checkpoint_path = (
    os.path.join(path, date, file_name) for file_name in (config_name, model_name)
)

In [12]:
# Build model, optimizer and sampler from the saved config, then restore the checkpoint.
cfg = bookkeeping.load_config(config_path)

# Sampler overrides applied on top of the stored config.
cfg.sampler.name = 'ElboTauL'  # alternatives noted by the author: ExactSampling, ElboLBJF, CRMTauL, CRMLBJF
cfg.sampler.num_corrector_steps = 0
cfg.sampler.corrector_entry_time = ScalarFloat(0.0)
# num_steps used to be assigned twice (500, then 1000); only the final,
# effective value is kept so the cell states what actually runs.
cfg.sampler.num_steps = 1000
cfg.sampler.is_ordinal = False
cfg.sampler.sample_freq = 100000000

# Training / checkpointing overrides.
cfg.training.n_iters = 600000
cfg.saving.checkpoint_freq = 10000

device = torch.device(cfg.device)

model = model_utils.create_model(cfg, device)
print("number of parameters: ", sum(p.numel() for p in model.parameters()))
optimizer = optimizers_utils.get_optimizer(model.parameters(), cfg)

sampler = sampling_utils.get_sampler(cfg)

# Restore the checkpoint into a fresh state dict and switch the model to eval mode.
state = {"model": model, "optimizer": optimizer, "n_iter": 0}
state = bookkeeping.load_state(state, checkpoint_path)
state['model'].eval()

loss = losses_utils.get_loss(cfg)
training_step = training_utils.get_train_step(cfg)

number of parameters:  8040834
ema state dict function


In [13]:
# Build the dataset described by cfg and wrap it in a batched loader.
dataset = dataset_utils.get_dataset(cfg, device)
dataloader = DataLoader(
    dataset,
    batch_size=cfg.data.batch_size,
    shuffle=cfg.data.shuffle,
)


In [14]:
# Continue training from the restored iteration count until cfg.training.n_iters
# is reached. The loop used to be an unconditional `while True` with an unused
# `exit_flag`, so the cell could only be stopped by interrupting the kernel.
training_loss = []
exit_flag = False

# Budget of optimisation steps left for this run.
# NOTE(review): state["n_iter"] is not modified here; whether
# training_step.step advances it is not visible from this notebook — confirm.
remaining_steps = max(int(cfg.training.n_iters) - int(state["n_iter"]), 0)

# The model was switched to eval mode right after loading the checkpoint;
# put it back in train mode for the optimisation steps.
state["model"].train()

while remaining_steps > 0 and not exit_flag:
    steps_at_epoch_start = remaining_steps
    for minibatch in tqdm(dataloader):
        minibatch = minibatch.to(device)
        l = training_step.step(state, minibatch.long(), loss)

        training_loss.append(l.item())

        remaining_steps -= 1
        if remaining_steps <= 0:
            exit_flag = True
            break

    # An empty dataloader would otherwise make the outer loop spin forever.
    if remaining_steps == steps_at_epoch_start:
        break

# Restore inference mode so that later sampling cells see the same mode as before.
state["model"].eval()

 

torch.Size([100, 1, 15, 15])


In [15]:
# Draw one batch of samples from the model with the configured sampler.
n_samples = 100
(samples, changes) = sampler.sample(model, n_samples)

# Alias (not a copy) so later cells can keep using this batch under a stable name.
saved_samples = samples

In [16]:
# Score the sampled mazes. In the recorded run this printed a list of sample
# indices (presumably the unsolvable ones — 12 indices vs. 88.0% solvable)
# followed by the overall solvable percentage.
correct_mazes = maze_acc(saved_samples)

5
12
16
23
24
38
43
51
58
64
69
82
Accuracy: From 100 are 88.0% solvable.


In [None]:
# Switch the data config to the Maze3SComplete dataset and rebuild the loader.
cfg.data.name = 'Maze3SComplete'
cfg.data.batch_size = 100

if cfg.data.name == 'Maze3SComplete':
    # limit is set to exactly one batch; presumably it caps how many items
    # get_dataset loads — confirm in dataset_utils.
    limit = cfg.data.limit = cfg.data.batch_size

dataset = dataset_utils.get_dataset(cfg, device)
dataloader = DataLoader(
    dataset,
    batch_size=cfg.data.batch_size,
    shuffle=cfg.data.shuffle,
)