# Setup

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
import create_models

# Models

## Initialization

In [None]:
image_size = 64

# Hyper-parameters for each component, grouped in one place so they are easy to
# inspect. They have to match the architectures of the checkpoints loaded later.
mean_variance_config = dict(
    image_size=image_size,
    num_channels=128,
    num_res_blocks=3,
    # NOTE(review): if create_models parses this as a comma-separated list, "4"
    # gives a single resolution level — confirm it matches the checkpoint.
    channel_mult="4",
    learn_sigma=True,
    use_checkpoint=False,
    attention_resolutions="16,8",
    num_heads=4,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    dropout=0.3,
    resblock_updown=False,
    use_fp16=True,
    use_new_attention_order=False,
)
diffusion_config = dict(
    steps=1000,
    learn_sigma=True,
    sigma_small=False,
    noise_schedule="cosine",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="100",
)
regressor_config = dict(
    image_size=image_size,
    # 8 input channels in total; presumably the image channel plus the
    # conditioning channels (constraints / loads / BCs) — TODO confirm breakdown.
    in_channels=1 + 3 + 2 + 2,
    regressor_use_fp16=False,
    regressor_width=128,
    regressor_depth=4,
    regressor_attention_resolutions="32,16,8",
    regressor_use_scale_shift_norm=True,
    regressor_resblock_updown=True,
    regressor_pool="spatial",
)
classifier_config = dict(
    image_size=image_size,
    in_channels=1,
    classifier_use_fp16=False,
    classifier_width=128,
    classifier_depth=2,
    classifier_attention_resolutions="32,16,8",
    classifier_use_scale_shift_norm=True,
    classifier_resblock_updown=True,
    classifier_pool="attention",
)

mean_variance = create_models.mean_variance(**mean_variance_config)
diffusion = create_models.gaussian_diffusion(**diffusion_config)
regressor = create_models.regressor(**regressor_config)
classifier = create_models.classifier(**classifier_config)

## Loading pretrained checkpoints

In [None]:
checkpoints_path = Path("./checkpoints")
# Checkpoints are deserialized onto the CPU first, then each model is moved to
# the current CUDA device.
cpu = torch.cpu.current_device()
gpu = torch.cuda.current_device()


def load_checkpoint(model, checkpoint_path, device, fp16=False):
    """Load weights from ``checkpoint_path`` into ``model`` and prepare it for inference.

    Args:
        model: module whose ``load_state_dict`` receives the stored state dict.
        checkpoint_path: path to a file written with ``torch.save``.
        device: device the model is moved to after loading.
        fp16: if True, call ``model.convert_to_fp16()`` after the move.

    Returns:
        The same ``model`` instance, in eval mode.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=cpu)
    model.load_state_dict(state_dict)
    model.to(device)
    if fp16:
        model.convert_to_fp16()
    model.eval()
    return model


mean_variance_path = checkpoints_path / "diff_checkpoints" / "model_180000.pt"
load_checkpoint(mean_variance, mean_variance_path, gpu, fp16=True)

regressor_path = checkpoints_path / "reg_checkpoint" / "model_350000.pt"
load_checkpoint(regressor, regressor_path, gpu)

classifier_path = checkpoints_path / "class_checkpoint" / "model_299999.pt"
load_checkpoint(classifier, classifier_path, gpu);

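# Data

**TODO:** define `cons`, `loads` and `BCs` here — the sampling cell below passes them to `diffusion.p_sample_loop` and fails without them.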

# Sampling

In [None]:
# Layout of the tensor to sample: (batch size, channel count, height, width),
# i.e. one single-channel image of image_size x image_size pixels.
shape = (1, 1) + (image_size, image_size)

def cond_fn_1(
    x: torch.Tensor,
    time_steps: torch.Tensor,
    *,
    scale: float = 4.0,
    model=None,
) -> torch.Tensor:
    """Regressor guidance: negated, scaled gradient of the regressor output.

    Args:
        x: noisy sample, indexed as (batch, channel, height, width).
        time_steps: timesteps forwarded unchanged to the regressor.
        scale: guidance strength multiplied into the gradient (4.0 by default).
        model: callable ``model(x, time_steps)``; defaults to the notebook's
            ``regressor``.

    Returns:
        ``-scale * d(sum(model(x, t))) / dx`` restricted to channel 0, with shape
        (batch, 1, height, width). Assuming the sampler adds this to the mean,
        the negation steers samples toward a lower regressor output.
    """
    model = regressor if model is None else model
    with torch.enable_grad():
        x_in = x.detach().requires_grad_()
        prediction = model(x_in, time_steps)
        grad = torch.autograd.grad(prediction.sum(), x_in)[0]
    # `:1` keeps the channel axis, so the result is (B, 1, H, W) for any batch size
    # instead of relying on a hard-coded global shape.
    return -scale * grad[:, :1]

def cond_fn_2(
    x: torch.Tensor,
    time_steps: torch.Tensor,
    *,
    scale: float = 3.0,
    target_class: int = 1,
    model=None,
) -> torch.Tensor:
    """Classifier guidance: scaled gradient of log p(target_class | x).

    Args:
        x: noisy sample, indexed as (batch, channel, height, width).
        time_steps: timesteps forwarded unchanged to the classifier.
        scale: guidance strength multiplied into the gradient (3.0 by default).
        target_class: class index whose log-probability is increased; 1 is the
            class the notebook originally hard-coded (meaning not visible here —
            confirm against the classifier's training labels).
        model: callable ``model(x, time_steps)`` returning logits of shape
            (batch, num_classes); defaults to the notebook's ``classifier``.

    Returns:
        ``scale * d(sum_b log_softmax(logits)[b, target_class]) / dx``
        restricted to channel 0, with shape (batch, 1, height, width).
    """
    model = classifier if model is None else model
    with torch.enable_grad():
        x_in = x.detach().requires_grad_()
        logits = model(x_in, time_steps)
        log_probabilities = F.log_softmax(logits, dim=-1)
        # Same target class for every batch item; replaces the numpy-array
        # fancy indexing, which also required the never-imported `np`.
        selected = log_probabilities[:, target_class]
        grad = torch.autograd.grad(selected.sum(), x_in)[0]
    # `:1` keeps the channel axis, so the result is (B, 1, H, W) for any batch size.
    return scale * grad[:, :1]

# Draw one guided sample. `model` and `shape` are passed positionally because
# the conditioning inputs that follow are positional, and Python forbids
# positional arguments after keyword arguments.
# NOTE(review): `cons`, `loads` and `BCs` are not defined anywhere in this
# notebook yet (the Data section is empty); they must exist before this runs.
sample = diffusion.p_sample_loop(
    mean_variance,
    shape,
    cons,
    loads,
    BCs,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    # Use the guidance functions defined above; passing None disabled guidance
    # and left them as dead code.
    cond_fn_1=cond_fn_1,
    cond_fn_2=cond_fn_2,
    # An empty dict instead of None, so a sampler that calls
    # `cond_fn(x, t, **model_kwargs)` does not fail on `**None` — confirm
    # against the diffusion implementation.
    model_kwargs={},
    device=gpu,
    progress=False,
)