# Class-Conditional Synthesis with Latent Diffusion Models

Install all the requirements

Load the model.

In [1]:
# Switch the kernel's working directory to the latent-diffusion checkout.
# NOTE(review): presumably needed so the `ldm` package imported in later cells
# resolves; the absolute path is machine-specific and must be adapted.
%cd /home/panzy/latent-diffusion

/home/panzy/latent-diffusion


In [2]:
import os
# Restrict this process to GPU index 7. This is set here, before torch is
# imported in the following cell, so the restriction is already in place when
# CUDA is first used.
os.environ["CUDA_VISIBLE_DEVICES"] = '7'


In [3]:
#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt):
    """Build the model described by ``config`` and load its weights from ``ckpt``.

    Args:
        config: Configuration object exposing a ``model`` entry that
            ``instantiate_from_config`` can turn into a model.
        ckpt: Path to a checkpoint file whose top-level object contains a
            ``"state_dict"`` entry.

    Returns:
        The instantiated model, moved to the GPU and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first: without map_location the tensors are
    # restored to whatever device they were saved from, which may not be
    # visible in this process and would hold a second copy of the weights on
    # the GPU until the checkpoint dict is released.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them instead of silently
    # discarding the information.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if missing_keys:
        print(f"Warning: {len(missing_keys)} missing keys while loading {ckpt}")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} unexpected keys while loading {ckpt}")
    model.cuda()
    model.eval()
    return model


def get_model(
    config_path="/n/owens-data1/mnt/big2/data/panzy/latent/imgnet/cin256-v2.yaml",
    ckpt_path="/n/owens-data1/mnt/big2/data/panzy/latent/imgnet/model.ckpt",
):
    """Load the model from a config file and a checkpoint.

    Args:
        config_path: Path to the YAML model configuration. Defaults to the
            location this notebook was originally run with.
        ckpt_path: Path to the checkpoint file. Defaults to the location this
            notebook was originally run with.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    return model

In [4]:
from ldm.models.diffusion.ddim import DDIMSampler

# Load the model once; `model` and `sampler` are reused by the sampling cells below.
model = get_model()
# DDIM sampler wrapping the loaded model.
sampler = DDIMSampler(model)

Loading model from /n/owens-data1/mnt/big2/data/panzy/latent/imgnet/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 400.92 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels


And go. Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` variables. As a rule of thumb, higher values of `scale` produce better samples at the cost of a reduced output diversity. Furthermore, increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250. Fast sampling (i.e. low values of `ddim_steps`) while retaining good quality can be achieved by using `ddim_eta = 0.0`.

In [None]:
# import numpy as np 
# from PIL import Image
# from einops import rearrange
# from torchvision.utils import make_grid


# classes = [847, 895]   # define classes to be sampled here
# n_samples_per_class = 10

# ddim_steps = 20
# ddim_eta = 0.0
# scale = 3.0   # for unconditional guidance


# all_samples = list()

# with torch.no_grad():
#     with model.ema_scope():
#         uc = model.get_learned_conditioning(
#             {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
#             )
        
#         for class_label in classes:
#             print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
#             xc = torch.tensor(n_samples_per_class*[class_label])
#             c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            
#             samples_ddim, _ = sampler.sample(S=ddim_steps,
#                                              conditioning=c,
#                                              batch_size=n_samples_per_class,
#                                              shape=[3, 64, 64],
#                                              verbose=False,
#                                              unconditional_guidance_scale=scale,
#                                              unconditional_conditioning=uc, 
#                                              eta=ddim_eta)

#             x_samples_ddim = model.decode_first_stage(samples_ddim)
#             x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, 
#                                          min=0.0, max=1.0)
#             all_samples.append(x_samples_ddim)


# # display as grid
# grid = torch.stack(all_samples, 0)
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
# grid = make_grid(grid, nrow=n_samples_per_class)

# # to image
# grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
# Image.fromarray(grid.astype(np.uint8))

In [6]:
!mkdir /home/panzy/latent-diffusion/outputs/full_dataset

In [6]:
import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid

# --- Sampling configuration ---
img_root = '/n/owens-data1/mnt/big2/data/panzy/diffusion_fake_file/latent_diffusion'
exp_name = 'full_dataset'
classes = [847, 895]   # define classes to be sampled here
n_samples_per_class = 2500
batch_size = 10

# Number of batches needed to reach n_samples_per_class for each class.
total_iter = int(np.ceil(n_samples_per_class / batch_size))

# Resume window: batches start_iter .. start_iter + num_iter - 1 are generated,
# each saved as 'iter_<index>.pt'. The values below reproduce what this cell
# previously did with its hard-coded filename (a single batch written to
# 'iter_140.pt'). Use start_iter = 0 and num_iter = total_iter for a full run.
start_iter = 140
num_iter = 1

ddim_steps = 20
ddim_eta = 0.0
scale = 3.0   # unconditional guidance scale (passed as unconditional_guidance_scale)

# Make sure the destination exists before sampling so a long run cannot fail
# at the first torch.save.
out_dir = os.path.join(img_root, exp_name)
os.makedirs(out_dir, exist_ok=True)

# CPU copies of every generated batch, consumed by the later cells that stack
# `all_samples` into `all_images`.
# NOTE(review): for a full run this keeps every decoded image in host memory.
all_samples = list()

with torch.no_grad():
    with model.ema_scope():
        # Conditioning used for the unconditional branch of the guidance.
        # NOTE(review): label 1000 is presumably the extra "no class" index - confirm.
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(batch_size*[1000]).to(model.device)}
            )
        for iter_idx in range(start_iter, min(start_iter + num_iter, total_iter)):
            samples = []
            for class_label in classes:
                xc = torch.tensor(batch_size*[class_label])
                c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})

                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                conditioning=c,
                                                batch_size=batch_size,
                                                shape=[3, 64, 64],
                                                verbose=False,
                                                unconditional_guidance_scale=scale,
                                                unconditional_conditioning=uc,
                                                eta=ddim_eta)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                # Map the decoder output from [-1, 1] to [0, 1].
                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,
                                            min=0.0, max=1.0)
                samples.append(x_samples_ddim)

            # Fold the per-class batches into one (len(classes) * batch_size, C, H, W) tensor.
            samples = torch.stack(samples)
            samples = samples.reshape(-1, samples.shape[-3], samples.shape[-2], samples.shape[-1])
            # One file per batch index so successive iterations do not overwrite each other.
            filename = os.path.join(out_dir, f'iter_{iter_idx}.pt')
            torch.save(samples, filename)
            all_samples.append(samples.cpu())


Data shape for DDIM sampling is (10, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


Data shape for DDIM sampling is (10, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


In [8]:
!mkdir latent

In [26]:
# Remove the `latent` scratch directory together with everything inside it.
!rm -rf latent

In [23]:
# Combine the collected batches into a single tensor: every leading dimension
# is folded into one image axis, the trailing three dimensions are kept as-is.
stacked = torch.stack(all_samples)
all_images = stacked.reshape(-1, *stacked.shape[-3:])

In [24]:
# import cv2
# for i in range(all_images.shape[0]):
#   image = all_images[i].squeeze().flip([0])
#   cv2.imwrite('latent/'+str(i).zfill(2)+'.jpg', 255*image.permute(1,2,0).cpu().detach().numpy())

In [27]:
# Persist the combined image tensor; the relative path resolves against the
# kernel's current working directory.
torch.save(all_images, 'img.pt')