In [None]:
!pip install git+https://github.com/fbcotter/pytorch_wavelets PyWavelets open_clip_torch

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import open_clip
from glob import glob
from PIL import Image
import numpy as np

from arroz import Diffuzz, VQModel, DiffusionModel, PriorModel, to_latent, from_latent

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Checkpoint files loaded further down into the VQGAN, the prior and the
# generator (paths are relative to the working directory).
vqgan_checkpoint_path = "checkpoints/vqwatercolor_v1.pt"
prior_checkpoint_path = "checkpoints/prior_v1_352k_ema.pt"
generator_checkpoint_path = "checkpoints/clip2img_v1_288k_ema.pt"

# OpenCLIP model as an (architecture name, pretrained tag) pair.
clip_model_name = ("ViT-H-14", "laion2b_s32b_b79k")

# Prepare Models & Tools 
Running the next cell might take a while, as every model is loaded into memory.
We need:
- **diffuzz**: To handle the sampling process
- **vqmodel**: To decode the sampled image latents
- **clip_model**: To generate the CLIP text embeddings from our prompt
- **generator**: To convert our sampled CLIP image embeddings into image latents
- **prior**: To sample CLIP image embeddings from CLIP text embeddings 

In [None]:
def _load_frozen(model, checkpoint_path):
    """Load a checkpoint into ``model`` and freeze it for inference.

    Args:
        model: Module that is already placed on ``device``.
        checkpoint_path: Path of the state-dict file to load.

    Returns:
        The same ``model`` object, switched to eval mode and with gradients
        disabled on all of its parameters.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source; consider
    # passing weights_only=True if the installed torch supports it and the
    # checkpoint is a plain state dict.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval().requires_grad_(False)
    return model


# - diffusion tool -
diffuzz = Diffuzz(device=device)

# - vqgan -
vqmodel = _load_frozen(VQModel().to(device), vqgan_checkpoint_path)

# - openclip -
clip_model, _, _ = open_clip.create_model_and_transforms(
    clip_model_name[0], pretrained=clip_model_name[1], device=device
)
clip_model.eval().requires_grad_(False)
clip_tokenizer = open_clip.get_tokenizer(clip_model_name[0])

# - diffusion models -
generator = _load_frozen(DiffusionModel().to(device), generator_checkpoint_path)
prior = _load_frozen(PriorModel().to(device), prior_checkpoint_path)

# Text2Image Sampling
This requires 2 sampling stages:
1. Sampling the CLIP image embedding from the CLIP text embedding using the prior
2. Sampling image latents from the CLIP image embedding using the generator

Then the latents are decoded into images using the VQGAN decoder.

In [None]:
# --- Sampling settings -------------------------------------------------------
prompt = 'Closeup studio photography of an old vietnamese woman'
batch_size = 4
seed = 42  # fixed seed so repeated runs produce the same samples

# Stage 1: prior (CLIP text embedding -> CLIP image embedding)
prior_timesteps = 60
prior_cfg = 3.0
prior_sampler = 'ddpm'

# Stage 2: generator (CLIP image embedding -> image latents)
clip2img_timesteps = 20
clip2img_cfg = 7.0
clip2img_sampler = 'ddim'

# NOTE(review): 1024 is presumably the embedding width of the configured CLIP
# model, and (4, 64, 64) the latent shape the VQGAN expects — confirm.
clip_embedding_shape = (batch_size, 1024)
image_latent_shape = (batch_size, 4, 64, 64)

# --- Sampling ----------------------------------------------------------------
# Autocast follows the selected device instead of hard-coding "cuda". It is
# only enabled on CUDA, so a CPU run keeps full precision (as before) without
# triggering the "CUDA is not available" autocast warning.
use_autocast = device.type == "cuda"

with torch.inference_mode():
    with torch.autocast(device_type=device.type, enabled=use_autocast):
        # fork_rng restores the global RNG state on exit, so seeding here does
        # not leak into cells that run afterwards.
        with torch.random.fork_rng():
            torch.manual_seed(seed)

            # prompt to CLIP embeddings (the same prompt repeated per sample)
            captions = clip_tokenizer([prompt] * batch_size).to(device)
            text_embeddings = clip_model.encode_text(captions).float()

            # sample image embedding with prior; [-1] keeps the last entry of
            # what sample() returns (presumably the final step — confirm)
            sampled_image_embeddings = diffuzz.sample(
                prior, {'c': text_embeddings}, clip_embedding_shape,
                timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
            )[-1]

            # sample image latents conditioned on the sampled image embedding
            sampled = diffuzz.sample(
                generator, {'c': sampled_image_embeddings}, image_latent_shape,
                timesteps=clip2img_timesteps, cfg=clip2img_cfg, sampler=clip2img_sampler
            )[-1]

            # decode sampled latents and clamp to the displayable [0, 1] range
            sampled_images = from_latent(sampled, vqmodel).clamp(0, 1)

# --- Display -----------------------------------------------------------------
# Lay the batch out side by side along the width axis, then move channels last
# for matplotlib. The float32 cast keeps plotting independent of whatever
# dtype autocast produced.
image_row = torch.cat(tuple(sampled_images.float().cpu()), dim=-1).permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(32, 32))
ax.axis("off")
ax.imshow(image_row)
plt.show()