In [1]:
import argparse
import os.path
import torch
from lib.models.ddsm import noise_factory
from lib.config.config_mnist import get_config
from lib.sampling.sampling_utils import importance_sampling
import torch
import numpy as np
from functools import partial
import os
from lib.datasets.datasets import get_mnist_dataset
from lib.sampling.sampling_utils import importance_sampling
# Main module containing all DDSM logic
from lib.models.ddsm import *

# Pre-sample diffusion noise with `noise_factory` for MNIST and cache the
# result to `<noise_sample.out_path>/<filename>.pth`.  The in-memory tensors
# (v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints, alpha, beta)
# are also used directly by the next cell.
import sys

config = get_config()

# Make sure the output location exists and is a directory.
if not os.path.exists(config.noise_sample.out_path):
    os.makedirs(config.noise_sample.out_path)
elif not os.path.isdir(config.noise_sample.out_path):
    # Report the path that was actually checked above.
    print(f"{config.noise_sample.out_path} is already exists and it is not a directory")
    # sys.exit instead of the site-provided `exit`, which is meant for the
    # interactive shell only (and shuts down a Jupyter kernel).
    sys.exit(1)

# The cache filename encodes steps, category count, speed balancing, max time
# and sample count.
# NOTE(review): order, logspace, mode and steps_per_tick are also passed to
# noise_factory but are not part of the name — runs differing only in those
# would map to the same file. Confirm this is intended.
str_speed = ".speed_balance" if config.noise_sample.speed_balance else ""
filename = (
    f"mnist_steps{config.noise_sample.n_time_steps}.cat{config.noise_sample.num_cat}{str_speed}.time{config.noise_sample.max_time}."
    f"samples{config.noise_sample.n_samples}"
)
filepath = os.path.join(config.noise_sample.out_path, filename + ".pth")

# Never overwrite an existing cache file.
if os.path.exists(filepath):
    print("File is already exists.")
    sys.exit(1)

# Global torch setting: tensors created from here on (in this and later
# cells) default to float64.
torch.set_default_dtype(torch.float64)

# Per-dimension parameters for noise_factory (and later diffusion_factory):
# alpha is all ones, beta counts down from num_cat-1 to 1.
# NOTE(review): these use config.data.num_cat while the filename and the print
# below use config.noise_sample.num_cat — confirm the two are always equal.
alpha = torch.ones(config.data.num_cat - 1)
beta = torch.arange(config.data.num_cat - 1, 0, -1)
print("n_samples", config.noise_sample.n_samples, config.noise_sample.num_cat)
v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints = noise_factory(
    config.noise_sample.n_samples,
    config.noise_sample.n_time_steps,
    alpha,
    beta,
    total_time=config.noise_sample.max_time,
    order=config.noise_sample.order,
    time_steps=config.noise_sample.steps_per_tick,
    logspace=config.noise_sample.logspace,
    speed_balanced=config.noise_sample.speed_balance,
    mode=config.noise_sample.mode,
)

# Keep CPU copies so the saved file does not require a GPU to load.
v_one = v_one.cpu()
v_zero = v_zero.cpu()
v_one_loggrad = v_one_loggrad.cpu()
v_zero_loggrad = v_zero_loggrad.cpu()
# Timepoints are stored as float32 despite the float64 default set above.
timepoints = torch.FloatTensor(timepoints)

torch.save((v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints), filepath)


ModuleNotFoundError: No module named 'cooltools'

In [None]:
# Compute time-dependent importance-sampling weights on the MNIST training
# set and save them to <saving.time_dep_weights_path>/<filename>.pth.
# Relies on globals from the previous cell: v_one, v_zero, v_one_loggrad,
# v_zero_loggrad, alpha and beta.
sb = UnitStickBreakingTransform()

# Select the diffusion function; `partial` binds the pre-sampled noise
# tensors (and, for the general variant, alpha/beta/device) as keywords.
if config.use_fast_diff:
    diffuser_func = partial(
        diffusion_fast_flatdirichlet,
        noise_factory_one=v_one,
        v_one_loggrad=v_one_loggrad,
    )
else:
    diffuser_func = partial(
        diffusion_factory,
        noise_factory_one=v_one,
        noise_factory_zero=v_zero,
        noise_factory_one_loggrad=v_one_loggrad,
        noise_factory_zero_loggrad=v_zero_loggrad,
        alpha=alpha,
        beta=beta,
        device=config.device,
    )


# Per-dimension factor `s` passed to importance_sampling:
# 2 / (1 + [num_cat-1, ..., 1]) when speed balancing is on, otherwise ones.
if config.speed_balanced:
    s = 2 / (
        torch.ones(config.data.num_cat - 1, device=config.device)
        + torch.arange(config.data.num_cat - 1, 0, -1, device=config.device).float()
    )
else:
    s = torch.ones(config.data.num_cat - 1, device=config.device)

# exist_ok avoids the check-then-create race, and still raises if the path
# exists but is not a directory (instead of failing later in torch.save).
os.makedirs(config.saving.time_dep_weights_path, exist_ok=True)
str_speed = ".speed_balance" if config.speed_balanced else ""
str_random_order = ".random_order" if config.random_order else ""
filename = f"time_depend_weights_steps{config.n_time_steps}.cat{config.data.num_cat}{str_speed}{str_random_order}"
filepath = os.path.join(config.saving.time_dep_weights_path, filename + ".pth")


# Only the training loader is used; the validation/test loaders are unused here.
train_dataloader, valid_dataloader, test_dataloader = get_mnist_dataset(config)
time_dependent_weights = importance_sampling(config, train_dataloader, diffuser_func, sb, s)
# NOTE(review): unlike the noise cell, an existing weights file is overwritten.
torch.save(time_dependent_weights, filepath)