<a href="https://colab.research.google.com/github/pollinations/hive/blob/main/interesting_notebooks/CLIP_Conditioned_CLIP_Guided_Diffusion_(cc12m_1%2C_256x256).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [CLIP-Conditioned CLIP-Guided Diffusion (cc12m_1, 256x256)](https://github.com/crowsonkb/v-diffusion-pytorch)

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings)

and JD Pressman (https://twitter.com/jd_pressman).

Notebook by BoneAmputee (https://twitter.com/BoneAmputee).

# Setup

In [None]:
# Install ftfy (presumably required by CLIP's text tokenizer — confirm).
!pip install ftfy
# Work from the Colab content root and fetch the sampling code.
%cd /content/
!git clone https://github.com/crowsonkb/v-diffusion-pytorch.git
%cd v-diffusion-pytorch
# CLIP is cloned inside the checkout so that `from CLIP import clip` resolves
# when the notebook runs from this directory.
!git clone https://github.com/openai/CLIP.git
# checkpoints/ holds the model weights; frames/ receives intermediate previews.
%mkdir -p checkpoints
%mkdir -p frames
# Download the cc12m_1 weights to the path the Run cell loads by default.
!curl -L "https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1.pth" > "checkpoints/cc12m_1.pth"

# Run

In [None]:
#!/usr/bin/env python3

"""CLIP guided sampling from a diffusion model."""

import argparse
from pathlib import Path

from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import trange
from IPython import display
from shutil import rmtree
import os

from CLIP import clip
from diffusion import get_model, get_models, utils

# Root of the v-diffusion-pytorch checkout; when no checkpoint path is given,
# main() looks for `checkpoints/<model>.pth` underneath this directory.
MODULE_DIR = Path("/content/v-diffusion-pytorch/").resolve()


@torch.no_grad()
def sample(model, x, steps, check_in, eta, extra_args):
    """Draws samples from a model given starting noise.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``; its output
            is used as ``v``, the predicted velocity.
        x: Starting noise; ``x.shape[0]`` is taken as the batch size.
        steps: Indexable sequence of timesteps, visited in order. It is also
            passed to ``utils.t_to_alpha_sigma`` to obtain the noise schedule.
        check_in: Save the current prediction to ``frames/`` and display it
            every ``check_in`` iterations. ``0`` disables these previews.
        eta: Scales the amount of fresh noise added at each step; ``0`` adds
            none, which makes the update deterministic.
        extra_args: Keyword arguments forwarded to every ``model`` call.

    Returns:
        The denoised prediction computed at the last timestep.

    Note:
        ``steps`` must be non-empty; otherwise ``pred`` is never assigned.
    """
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    alphas, sigmas = utils.t_to_alpha_sigma(steps)

    # The sampling loop
    for i in trange(len(steps)):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * steps[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # Periodic preview. The truthiness check keeps check_in == 0 from
        # raising ZeroDivisionError and lets it mean "no previews".
        if check_in and i % check_in == 0:
            outfile = f'frames/{str(i).zfill(4)}.png'
            utils.to_pil_image(pred).save(outfile)
            display.display(display.Image(outfile))

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < len(steps) - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred


@torch.no_grad()
def cond_sample(model, x, steps, check_in, eta, extra_args, cond_fn):
    """Draws guided samples from a model given starting noise.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``; its output
            is used as ``v``, the predicted velocity.
        x: Starting noise; ``x.shape[0]`` is taken as the batch size.
        steps: Indexable sequence of timesteps, visited in order. It is also
            passed to ``utils.t_to_alpha_sigma`` to obtain the noise schedule.
        check_in: Save the current (pre-guidance) prediction to ``frames/``
            and display it every ``check_in`` iterations. ``0`` disables these
            previews. Previews are only produced on guided steps.
        eta: Scales the amount of fresh noise added at each step; ``0`` adds
            none, which makes the update deterministic.
        extra_args: Keyword arguments forwarded to every ``model`` and
            ``cond_fn`` call.
        cond_fn: Callable invoked as ``cond_fn(x, t, pred, **extra_args)``
            that returns the guidance gradient used to correct ``v``.

    Returns:
        The denoised prediction computed at the last timestep.

    Note:
        ``steps`` must be non-empty; otherwise ``pred`` is never assigned.
    """
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    alphas, sigmas = utils.t_to_alpha_sigma(steps)

    # The sampling loop
    for i in trange(len(steps)):

        # Get the model output. Gradients are re-enabled here (the function is
        # decorated with no_grad) and x is made a fresh leaf so that cond_fn
        # can differentiate with respect to it.
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            with torch.cuda.amp.autocast():
                v = model(x, ts * steps[i], **extra_args)

            # Guidance is skipped when the timestep is >= 1.
            # NOTE(review): presumably because alpha is 0 there, which would
            # make sigma / alpha blow up — confirm in utils.t_to_alpha_sigma.
            if steps[i] < 1:
                pred = x * alphas[i] - v * sigmas[i]
                # Periodic preview of the unguided prediction. The truthiness
                # check keeps check_in == 0 from raising ZeroDivisionError and
                # lets it mean "no previews".
                if check_in and i % check_in == 0:
                    outfile = f'frames/{str(i).zfill(4)}.png'
                    utils.to_pil_image(pred).save(outfile)
                    display.display(display.Image(outfile))
                cond_grad = cond_fn(x, ts * steps[i], pred, **extra_args).detach()
                v = v.detach() - cond_grad * (sigmas[i] / alphas[i])
            else:
                v = v.detach()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < len(steps) - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred


class MakeCutouts(nn.Module):
    """Cuts random square windows out of an image batch and resizes them.

    Each call produces ``cutn`` windows. Every window is taken from the last
    two dimensions of the input at a random position and with a random side
    length, then average-pooled to ``cut_size`` x ``cut_size``. The windows
    are concatenated along the first dimension.

    Args:
        cut_size: Side length of every returned cutout.
        cutn: Number of cutouts drawn per call.
        cut_pow: Exponent applied to the uniform random draw that picks the
            window side length.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _random_cutout(self, input, height, width, smallest, largest):
        """Returns one randomly sized and placed window, pooled to cut_size."""
        # Draw order (side length, x offset, y offset) is kept fixed so that
        # results for a given RNG state do not change.
        side = int(torch.rand([])**self.cut_pow * (largest - smallest) + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        window = input[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(window, self.cut_size)

    def forward(self, input):
        height, width = input.shape[2:4]
        # Windows range from cut_size (or the short image side, if that is
        # smaller) up to the short image side.
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        pieces = [
            self._random_cutout(input, height, width, smallest, largest)
            for _ in range(self.cutn)
        ]
        return torch.cat(pieces)


def spherical_dist_loss(x, y):
    """Returns a squared-angle style distance between x and y on the unit sphere.

    Both inputs are L2-normalized along the last dimension. The result is
    ``2 * arcsin(||x - y|| / 2) ** 2``, reduced over the last dimension.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    # Half the chord length between the two unit vectors.
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    half_angle = half_chord.arcsin()
    return half_angle ** 2 * 2


def parse_prompt(prompt):
    """Splits a prompt into its content and its relative weight.

    The weight is an optional numeric suffix after the last colon, for
    example ``"conceptual art:0.5"`` or ``"https://host/a.png:2"``. When the
    text after the last colon is not a number (a URL scheme or port, or a
    colon that is simply part of the sentence), the whole string is treated
    as the prompt instead of raising ``ValueError``.

    Args:
        prompt: A text prompt, file path or HTTP(S) URL, optionally followed
            by ``:<weight>``.

    Returns:
        A ``(content, weight)`` tuple; ``weight`` is a float and defaults to
        ``1.0`` when no numeric suffix is present.
    """
    content, colon, suffix = prompt.rpartition(':')
    if not colon:
        return prompt, 1.0
    try:
        weight = float(suffix)
    except ValueError:
        # The trailing segment is not a weight, so the colon belongs to the
        # prompt itself (e.g. "https://..." or "host:8080/path").
        return prompt, 1.0
    return content, weight


def main():
    """Reads the form parameters, builds the CLIP target and samples images.

    The text and image prompts are embedded with the CLIP model named by the
    diffusion model, combined into one weighted, normalized target embedding,
    and used both to condition the model and (when ``clip_guidance_scale`` is
    non-zero) to guide sampling. Finished images are written to
    ``out_<index>.png`` in the working directory and displayed.

    Raises:
        RuntimeError: If the requested width or height is below 64, if no
            prompt is given, or if the prompt weights sum to (nearly) zero.
    """

    #@markdown `prompts`: the text prompts to use. Relative weights for text prompts can be specified by putting the weight after a colon. The vertical bar character can be used to denote multiple prompts.
    prompts = "an armchair in the shape of an avocado|conceptual art" #@param {type:"string"}
    #@markdown `batch_size`: sample this many images at a time (default 1)
    batch_size = 1 #@param {type:"integer"}
    #@markdown `checkpoint`: manually specify the model checkpoint file
    checkpoint = "" #@param {type:"string"}
    #@markdown `clip_guidance_scale`: how strongly the result should match the text prompt (default 500). If set to 0, the cc12m_1 model will still be CLIP conditioned and sampling will go faster and use less memory.
    clip_guidance_scale = 500 #@param {type:"number"}
    #@markdown `device`: the PyTorch device name to use (default autodetects)
    device = "cuda:0" #@param {type:"string"}
    #@markdown `eta`: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.
    eta = 1.0 #@param {type:"number"}
    #@markdown `images`: the image prompts to use (local files or HTTP(S) URLs). Relative weights for image prompts can be specified by putting the weight after a colon, for example: `"image_1.png:0.5"`.
    images = "" #@param {type:"string"}
    #@markdown `model`: specify the model to use (default cc12m_1)
    model = "cc12m_1" #@param {type:"string"}
    #@markdown `n`: sample until this many images are sampled (default 1)
    n = 1 #@param {type:"integer"}
    #@markdown `seed`: specify the random seed (default 0)
    seed = 0 #@param {type:"integer"}
    #@markdown `steps`: specify the number of diffusion timesteps (default is 1000, can lower for faster but lower quality sampling)
    steps = 1000 #@param {type:"integer"}
    #@markdown `check_in`: specify the number of steps between each image update
    check_in = 100 #@param {type:"integer"}
    #@markdown `cutn`: specify the number of cuts to observe when guiding
    cutn = 16 #@param {type:"integer"}
    #@markdown `cut_pow`: specify the cut power
    cut_pow = 1.0 #@param {type:"number"}
    #@markdown `width`: specify the width
    width = 256 #@param {type:"integer"}
    #@markdown `height`: specify the height
    height = 256 #@param {type:"integer"}

    # Both prompt fields are '|'-separated lists; blank entries are dropped.
    prompts = [x.strip() for x in prompts.split('|')]
    prompts = [x for x in prompts if x != '']
    images = [x.strip() for x in images.split('|')]
    images = [x for x in images if x != '']

    # Snapshot the form values before the local names below are rebound to
    # runtime objects (device, model, checkpoint).
    args = argparse.Namespace(
        prompts=prompts,
        batch_size=batch_size,
        checkpoint=checkpoint,
        clip_guidance_scale=clip_guidance_scale,
        device=device,
        eta=eta,
        images=images,
        model=model,
        n=n,
        seed=seed,
        steps=steps,
        check_in=check_in,
        cutn=cutn,
        cut_pow=cut_pow,
        width=width,
        height=height,
    )

    # An empty device field falls back to autodetection.
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    model = get_model(args.model)()
    # The output size comes from the form rather than from the model, rounded
    # down to a multiple of 64.
    side_y, side_x = (args.height//64)*64, (args.width//64)*64
    if min(side_x, side_y) < 64:
        # Anything below 64 rounds down to 0 and would only fail later, deep
        # inside sampling, with an unrelated-looking error.
        raise RuntimeError('The width and height must each be at least 64.')
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    if device.type == 'cuda':
        model = model.half()
    model = model.to(device).eval().requires_grad_(False)
    clip_model = clip.load(model.clip_model, jit=False, device=device)[0]
    clip_model.eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    cutn = args.cutn
    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, cutn=cutn, cut_pow=args.cut_pow)

    # Seed before the prompts are embedded: the cutouts taken from image
    # prompts are drawn from the global RNG, so seeding only afterwards left
    # the target embedding dependent on whatever had used the RNG before.
    torch.manual_seed(args.seed)

    target_embeds, weights = [], []

    for prompt in args.prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)

    for prompt in args.images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size),
                        transforms.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img)[None].to(device))
        embeds = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)
        target_embeds.append(embeds)
        # An image contributes cutn embeddings, so its weight is spread
        # evenly across them.
        weights.extend([weight / cutn] * cutn)

    if not target_embeds:
        raise RuntimeError('At least one text or image prompt must be specified.')
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    # Collapse all prompt embeddings into one normalized target, repeated once
    # per requested image.
    clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
    clip_embed = clip_embed.repeat([args.n, 1])

    # Re-seed so the sampling noise does not depend on how many random draws
    # the image prompts consumed; text-only runs behave exactly as before.
    torch.manual_seed(args.seed)

    def cond_fn(x, t, pred, clip_embed):
        """Returns the negated gradient of the CLIP loss of `pred` w.r.t. `x`."""
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([cutn, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * args.clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad

    def run(x, clip_embed):
        """Samples one batch, with CLIP guidance unless the scale is zero."""
        t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
        steps = utils.get_spliced_ddpm_cosine_schedule(t)
        extra_args = {'clip_embed': clip_embed}
        if not args.clip_guidance_scale:
            return sample(model, x, steps, args.check_in, args.eta, extra_args)
        return cond_sample(model, x, steps, args.check_in, args.eta, extra_args, cond_fn)

    def run_all(n, batch_size):
        """Samples `n` images in batches, saving and displaying each result."""
        # All starting noise is drawn up front, then consumed batch by batch.
        x = torch.randn([args.n, 3, side_y, side_x], device=device)
        for i in trange(0, n, batch_size):
            cur_batch_size = min(n - i, batch_size)
            outs = run(x[i:i+cur_batch_size], clip_embed[i:i+cur_batch_size])
            for j, out in enumerate(outs):
                outfile = f'out_{i + j:05}.png'
                utils.to_pil_image(out).save(outfile)
                display.display(display.Image(outfile))

    try:
        run_all(args.n, args.batch_size)
    except KeyboardInterrupt:
        # Interrupting the cell stops sampling quietly; images already saved
        # are kept.
        pass


if __name__ == '__main__':
    # Start every run with an empty frames directory so previews left over
    # from an earlier run are not mixed in with the new ones.
    frames_dir = "frames"
    if os.path.exists(frames_dir):
        rmtree(frames_dir)
    os.makedirs(frames_dir)
    main()
