In [None]:
import sys
# Put the repository root -- the directory that CONTAINS the `src/` package --
# on the path so the `from src....` imports below resolve. Appending
# `.../src` itself would expose `train`, `models`, ... at top level instead of
# making `src` importable as a package.
sys.path.append('/path/to/repo')

import torch
from tqdm import tqdm
import time, os, json, pickle

# Star import: presumably supplies prepare_vqmodel / prepare_vaemodel used in
# the model-setup cell (they are not imported anywhere else in this notebook).
from src.train.util import *
from src.models.unet import set_unet
from src.models.diffusion import Diffusion
from src.utils.aux import unscale_tensor

from pprint import pprint

# Load models

In [None]:
# Path to the sampling config (JSON); edit to point at your file.
fname = '/path/to/sampling/config.json'
# JSON interchange is UTF-8 (RFC 8259); be explicit so the platform's default
# text encoding is not used instead.
with open(fname, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Set single GPU (falls back to CPU when CUDA is unavailable)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Set the UNet
# Whether the UNet is configured for self-conditioning; forwarded to the
# sampler as `x_self_cond` in the sampling cells.
unet_self_cond = config['unet_config'].get('self_condition', False)
unet = set_unet(config['unet_config'])
unet = unet.to(device)
# This notebook only samples: switch to eval mode so layers such as dropout or
# batch-norm (if the UNet has any) use their inference behavior.
unet.eval()


# Set the AutoEncoder model (decodes sampled latents back to images)
flag_compile = False
autoencoder_type = config['autoencoder']['type']
if 'vq' in autoencoder_type:
    autoencoder, autoencoder_eager = prepare_vqmodel(config, device, flag_compile, 'autoencoder')
    encoder_tanh = config['autoencoder'].get('encode_tanh_out', False)
    vq_model = True
elif 'vae' in autoencoder_type:
    autoencoder, autoencoder_eager = prepare_vaemodel(config, device, flag_compile, config_key = 'autoencoder')
    vq_model = False
else:
    # Fail here rather than leaving `autoencoder` undefined, which would only
    # surface later as a confusing NameError in the sampling cells.
    raise ValueError(
        f"Unsupported autoencoder type {autoencoder_type!r}: "
        "expected a type containing 'vq' or 'vae'"
    )

In [None]:
# Diffusion: wrap the UNet in the sampler, configured from `config`.
print('Setting the DDPM class')

sampling_cfg = config['sampling']
# Note: the cells below reassign sampling_batch and eta; grid_rows is not used
# below (the preview grid uses `nrow` instead).
sampling_batch = sampling_cfg.get('sampling_batch', 4)
grid_rows = sampling_cfg.get('grid_rows', 2)
eta = sampling_cfg.get('eta', 1.0)

noise_dict = config['noise']
diffusion_cfg = config['diffusion']
timesteps = diffusion_cfg['timesteps']
ddim_skip = diffusion_cfg['skip']
loss_type = diffusion_cfg['loss']

diffusion = Diffusion(
    noise_dict,
    unet,
    timesteps,
    loss=loss_type,
    sample_every=ddim_skip,
    device=device,
)

# Map class codes to actual art styles

This step is not necessary, but it helps to know which art style each class code corresponds to.

In [None]:
from src.datasets.artbench import im_dataset
from torchvision import utils
import torchvision

In [None]:
# Converts a (C, H, W) image tensor into a PIL image for display.
t2img = torchvision.transforms.ToPILImage()

# The dataset is built only to recover its class names, so class codes can be
# mapped back to art styles. flip_prob=0 presumably disables random flipping.
dataset_cfg = config['dataset']
image_size = dataset_cfg['image_size']
root = dataset_cfg['location']
img_resize = dataset_cfg['img_resize']
dataset = im_dataset(root, resize=img_resize, image_size=image_size, flip_prob=0)

# Class index -> style name
styles = dict(enumerate(dataset.classes))
num_classes = len(dataset.classes)

print(f'Num classes: {num_classes}')
pprint(styles)

# Sample

In [None]:
from src.utils.aux import save_grid_imgs

In [None]:
# Sampling parameters (these override the config-derived sampling_batch above)
nrow = 4              # images per row in the preview grid
sampling_batch = 16   # number of images to generate
latent_ch = 4         # number of latent channels
latent_hw = (64, 64)  # latent spatial size (H, W); decoded images are presumably
                      # larger by the autoencoder's downsampling factor

# Full shape of the latent batch handed to the sampler: (B, C, H, W)
sampling_size = (sampling_batch, latent_ch, *latent_hw)

In [None]:
# Random styles: one class label per sample, drawn uniformly from ALL classes.
# torch.randint's `high` is exclusive, so it must be num_classes (not
# num_classes - 1), otherwise the last class can never be drawn.
rand_sample_lbls = torch.randint(low=0, high=num_classes, size=(sampling_batch,)).to(device)
for lbl in rand_sample_lbls.tolist():
    print(f'{lbl}: {styles[lbl]}')

In [None]:
# Specific style: every sample in the batch gets the same class code.
style_code = 6
style = styles[style_code]
print(f'Style: {style}')
det_style_lbls = torch.full((sampling_batch,), style_code, dtype=torch.int64, device=device)

In [None]:
# Uncomment depending on what you are interested in:
# one fixed style (det_style_lbls) or random styles (rand_sample_lbls).

sample_lbls = det_style_lbls
#sample_lbls = rand_sample_lbls

eta = 0.8  # sampler eta; overrides the value read from the config
# Enter BOTH context managers with a comma. `a and b` evaluates to just `b`,
# so the previous `autocast(...) and torch.no_grad()` never enabled autocast.
with torch.autocast(device_type=device, dtype=torch.bfloat16), torch.no_grad():
    samples = diffusion.p_sample(sampling_size,
                                 x_self_cond=unet_self_cond,
                                 classes=sample_lbls,
                                 last=True, eta=eta)
    # Undo the latent scaling before decoding back to image space.
    Y = autoencoder.decode(samples.to(device)/autoencoder.scaling_factor)

# Under autocast the decoder output may be bfloat16, which NumPy does not
# support; cast to float32 before the image conversions below.
all_images = unscale_tensor(Y.float())
grid_img = utils.make_grid(all_images.to('cpu'), nrow = nrow)

# Display the grid
t2img(grid_img)

In [None]:
# Change the filename as you like
fname = 'sample.jpg'
# Use the same images-per-row as the preview grid instead of a hard-coded 4.
# NOTE(review): assumes save_grid_imgs' second argument is the images-per-row
# count (it was 4, equal to nrow) -- confirm against src.utils.aux.
save_grid_imgs(all_images, nrow, fname)

In [None]:
# empty cuda cache
try:
    del Y, samples, all_images
except Exception as e:
    print(e)
torch.cuda.empty_cache()

### Many images

In [None]:
# Same sampling as above, but with last=False; every element of samples[0] is
# decoded into a grid frame and collected in `images` for the GIF cells below.
sample_lbls = det_style_lbls
#sample_lbls = rand_sample_lbls

eta = 0.5  # sampler eta; overrides the value read from the config
images = []
# Comma-separated context managers (not `and`) so autocast is really active.
# Decoding also stays inside them: it previously ran outside torch.no_grad()
# and could record autograd history for every frame (if the autoencoder's
# parameters require grad).
with torch.autocast(device_type=device, dtype=torch.bfloat16), torch.no_grad():
    samples = diffusion.p_sample(sampling_size,
                                 x_self_cond=unet_self_cond,
                                 classes=sample_lbls,
                                 last=False, eta=eta)

    # NOTE(review): assumes samples[0] iterates over per-step latent batches
    # when last=False -- confirm against Diffusion.p_sample.
    for sample in samples[0]:
        Y = autoencoder.decode(sample.to(device)/autoencoder.scaling_factor)

        # Cast possible bfloat16 output to float32 before image conversion.
        all_images = unscale_tensor(Y.float())
        grid_img = utils.make_grid(all_images.to('cpu'), nrow = nrow)

        images.append(t2img(grid_img))

In [None]:
import imageio
from PIL import Image

In [None]:
# Save the collected frames to an animated GIF with imageio (must be
# installed); no frame duration is passed, so imageio's default timing applies.
gif_path = 'sample.gif'
imageio.mimsave(gif_path, images)

In [None]:
# Save using PIL: the first frame writes the file and carries the remaining
# frames via append_images. Pillow reads `duration` in milliseconds per frame,
# and loop=0 repeats the animation forever.
first_frame = images[0]
first_frame.save(
    "sample_pillow.gif",
    save_all=True,
    append_images=images[1:],
    optimize=True,
    duration=40,
    loop=0,
)

In [None]:
# Another way to save the GIF: upscale every frame (Lanczos resampling) first.
x, y = 512, 512  # output frame width and height in pixels
q = 50  # Quality -- NOTE(review): likely ignored by Pillow's GIF writer; confirm
fp_out = 'sample_4x4_rand.gif'

resized_frames = [frame.resize((x, y), Image.LANCZOS) for frame in images]
img, *imgs = resized_frames
img.save(
    fp=fp_out,
    format='GIF',
    append_images=imgs,
    quality=q,
    save_all=True,
    duration=40,
    loop=0,
    optimize=True,
)