By [Katherine Crowson](https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional [ImageNet diffusion model](https://github.com/openai/guided-diffusion) together with [CLIP](https://github.com/openai/CLIP) to connect text prompts with images.

Modified by [Daniel Russell](https://github.com/russelldc) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.

Creates 1024x1024px images by means of a 4x super-resolution step added by Thomash. With super-resolution turned on, generation is a little slower and may occasionally run out of memory, but it usually works.


##### Example: *The [Fermi Paradox](https://en.wikipedia.org/wiki/Fermi_paradox)*

[![The Fermi Paradox](https://pollinations.ai/ipfs/QmVzXi6oWygPuFKqYj3858E1nciXf8YrcM9EQURBEurYv5?filename=fermi.jpg)]()



In [None]:
# Text prompt. Several prompts can be separated with "|"; each prompt may end
# in ":<weight>" (e.g. 'the fermi paradox:2') to change its relative weight.
text_input = 'the fermi paradox' #@param {type: "string"}

# Perform 4x neural super-resolution (from 256x256px to 1024x1024px)
super_resolution = True   #@param {type: "boolean"}

# Directory that receives the progress frames and the final video
output_path = "/content/output"


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@title Upscale images/video frames


loaded_upscale_model = False

def upscale(filepath):
  global loaded_upscale_model
  if not super_resolution:
    return
  if not loaded_upscale_model:
    # Clone Real-ESRGAN and enter the Real-ESRGAN
    !git clone https://github.com/xinntao/Real-ESRGAN.git
    %cd /content/Real-ESRGAN
    # Set up the environment
    !pip install basicsr
    !pip install facexlib
    !pip install gfpgan
    !pip install -r requirements.txt
    !python setup.py develop
    # Download the pre-trained model
    !wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
    %cd -
    loaded_upscale_model = True 
  
  %cd /content/Real-ESRGAN
  !python inference_realesrgan.py --model_path experiments/pretrained_models/RealESRGAN_x4plus.pth --input $filepath --netscale 4 --outscale 4 --half --output $output_path
  filepath_out = filepath.replace(".jpg","_out.jpg")
  !mv -v $filepath_out $filepath
  %cd -

In [None]:
# Check the GPU status (prints the nvidia-smi report for this runtime)

!nvidia-smi

In [None]:
# Install dependencies

# CLIP and the guided-diffusion fork are cloned into the working directory and
# installed in editable mode, so they can also be imported from ./CLIP and
# ./guided-diffusion.
!git clone https://github.com/openai/CLIP
!git clone https://github.com/crowsonkb/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./guided-diffusion
# NOTE(review): `datetime` is part of the Python standard library; installing
# the PyPI package of that name is presumably unnecessary — confirm and drop.
!pip install lpips kornia datetime

In [None]:
# Download the diffusion model
# aria2c fetches the checkpoint with up to 5 connections (-x 5); with
# --auto-file-renaming=false an already existing file is not saved again
# under a new name.
!sudo apt install aria2
!aria2c -x 5 --auto-file-renaming=false 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'

In [None]:
import torch
# Select the compute device: the first CUDA GPU when available, else the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# Check the GPU status
!nvidia-smi

In [None]:
import gc
import io
import math
import sys
from IPython import display
import lpips
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from datetime import datetime
import kornia.color as KC
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
# Define necessary functions
# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869

def interp(t):
    """Evaluate the smoothstep polynomial 3*t^2 - 2*t^3 at `t`."""
    squared = t ** 2
    cubed = t ** 3
    return 3 * squared - 2 * cubed

def perlin(width, height, scale=10, device=None):
    """Generate a single octave of 2D Perlin-style noise.

    Random gradients are drawn on a (width + 1) x (height + 1) lattice and
    blended inside every lattice cell at `scale` sample points per axis.

    Returns:
        A 2D tensor of shape (width * scale, height * scale).
    """
    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    # Positions inside one cell, shaped (scale, 1) and (1, scale) for broadcasting
    frac_x = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
    frac_y = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
    fade_x = 1 - interp(frac_x)
    fade_y = 1 - interp(frac_y)

    # Weighted gradient contribution of each of the four cell corners
    corner_terms = (
        fade_x * fade_y * (grad_x[:-1, :-1] * frac_x + grad_y[:-1, :-1] * frac_y),
        (1 - fade_x) * fade_y * (-grad_x[1:, :-1] * (1 - frac_x) + grad_y[1:, :-1] * frac_y),
        fade_x * (1 - fade_y) * (grad_x[:-1, 1:] * frac_x - grad_y[:-1, 1:] * (1 - frac_y)),
        (1 - fade_x) * (1 - fade_y) * (-grad_x[1:, 1:] * (1 - frac_x) - grad_y[1:, 1:] * (1 - frac_y)),
    )
    noise = sum(corner_terms)
    # (width, height, scale, scale) -> (width, scale, height, scale) -> 2D image
    return noise.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)

def perlin_ms(octaves, width, height, grayscale, device=device):
    """Sum several Perlin octaves into one (grayscale) or three (colour) planes.

    Each entry of `octaves` is the amplitude of one octave. From one octave to
    the next the lattice is doubled and the per-cell sample count is halved,
    so every octave has the same output size. All planes start at 0.5.

    Returns:
        The planes concatenated along dimension 0.
    """
    plane_count = 1 if grayscale else 3
    planes = []
    for _ in range(plane_count):
        plane = 0.5
        samples_per_cell = 2 ** len(octaves)
        lattice_w, lattice_h = width, height
        for amplitude in octaves:
            layer = perlin(lattice_w, lattice_h, samples_per_cell, device)
            plane = plane + layer * amplitude
            samples_per_cell //= 2
            lattice_w *= 2
            lattice_h *= 2
        planes.append(plane)
    return torch.cat(planes)

def create_perlin_noise(octaves=(1, 1, 1, 1), width=2, height=2, grayscale=True):
    """Create a multi-octave Perlin noise image sized to the diffusion model.

    Args:
        octaves: Amplitude of each noise octave (any sequence; the default is
            an immutable tuple so it cannot be shared and mutated across calls).
        width: Lattice width of the coarsest octave.
        height: Lattice height of the coarsest octave.
        grayscale: When True a single noise plane is generated and expanded to
            RGB; otherwise three independent planes are used as R, G and B.

    Returns:
        An RGB PIL image of side_x by side_y pixels with autocontrast applied.
    """
    out = perlin_ms(octaves, width, height, grayscale)
    # torchvision expects sizes as (height, width)
    target_size = (side_y, side_x)
    if grayscale:
        out = TF.resize(size=target_size, img=out.unsqueeze(0))
        out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
    else:
        # perlin_ms stacks the three planes along dim 0; split them into channels
        out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])
        out = TF.resize(size=target_size, img=out)
        out = TF.to_pil_image(out.clamp(0, 1).squeeze())

    out = ImageOps.autocontrast(out)
    return out

In [None]:
def fetch(url_or_path):
    """Return a binary file-like object for a URL or a local path.

    http(s) URLs are downloaded completely into memory (an HTTP error status
    raises); anything else is opened from disk in binary mode. Closing the
    returned object is left to the caller.
    """
    location = str(url_or_path)
    is_remote = location.startswith(('http://', 'https://'))
    if not is_remote:
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)


def parse_prompt(prompt):
    """Split a prompt of the form "text[:weight]" into (text, weight).

    The weight is whatever follows the last colon, provided it parses as a
    float. When there is no colon, or the trailing part is not a number (for
    example the "//host/img.jpg" of a URL, a port, or ordinary text containing
    a colon), the whole prompt is kept as text and the weight defaults to 1.

    Args:
        prompt: Text prompt, image path or http(s) URL, optionally weighted.

    Returns:
        A (text, weight) tuple where weight is a float.
    """
    text, colon, tail = prompt.rpartition(':')
    if colon:
        try:
            return text, float(tail)
        except ValueError:
            # The colon belongs to the prompt itself, not to a weight suffix
            pass
    return prompt, 1.0

def sinc(x):
    """Normalized sinc, sin(pi*x) / (pi*x), with the value 1 where x == 0."""
    scaled = math.pi * x
    ratio = torch.sin(scaled) / scaled
    one = x.new_ones([])
    return torch.where(x != 0, ratio, one)

def lanczos(x, a):
    """Lanczos window of radius `a` sampled at `x`, normalized to sum to 1."""
    inside = torch.logical_and(-a < x, x < a)
    window = sinc(x) * sinc(x/a)
    kernel = torch.where(inside, window, x.new_zeros([]))
    return kernel / kernel.sum()

def ramp(ratio, width):
    """Return symmetric sample positions spaced `ratio` apart around 0.

    ceil(width / ratio + 1) non-negative positions are generated by repeated
    addition, mirrored to the negative side, and the two outermost positions
    are dropped.
    """
    count = math.ceil(width / ratio + 1)
    half = torch.empty([count])
    position = 0
    for idx in range(count):
        half[idx] = position
        position += ratio
    mirrored = -half[1:].flip([0])
    return torch.cat([mirrored, half])[1:-1]

def resample(input, size, align_corners=True):
    """Resize an (n, c, h, w) batch to `size`, low-pass filtering when shrinking.

    Along every axis that gets smaller, the image is first convolved with a
    Lanczos kernel (reflect-padded so the shape is preserved); the final
    resize uses bicubic interpolation.

    Args:
        input: Tensor of shape (n, c, h, w).
        size: Target (height, width).
        align_corners: Forwarded to F.interpolate.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Treat every channel as its own single-channel image for the 1D filters
    planes = input.reshape([n * c, 1, h, w])

    shrink_rows = dh < h
    if shrink_rows:
        taps_h = lanczos(ramp(dh / h, 2), 2).to(planes.device, planes.dtype)
        margin_h = (taps_h.shape[0] - 1) // 2
        planes = F.pad(planes, (0, 0, margin_h, margin_h), 'reflect')
        planes = F.conv2d(planes, taps_h[None, None, :, None])

    shrink_cols = dw < w
    if shrink_cols:
        taps_w = lanczos(ramp(dw / w, 2), 2).to(planes.device, planes.dtype)
        margin_w = (taps_w.shape[0] - 1) // 2
        planes = F.pad(planes, (margin_w, margin_w, 0, 0), 'reflect')
        planes = F.conv2d(planes, taps_w[None, None, None, :])

    restored = planes.reshape([n, c, h, w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)

class MakeCutouts(nn.Module):
    """Produce random, augmented square cutouts of an image batch for CLIP.

    Args:
        cut_size: Side length in pixels that every cutout is resampled to.
        cutn: Number of cutouts produced per forward pass.
        skip_augs: When True the torchvision augmentations are not applied.
        grayscale_cuts: When True each cutout is converted to grayscale (and
            back to three channels) before augmentation.
    """

    def __init__(self, cut_size, cutn, skip_augs=False, grayscale_cuts=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.skip_augs = skip_augs
        self.grayscale_cuts = grayscale_cuts
        # Geometric augmentations interleaved with a little Gaussian noise.
        # NOTE(review): p=1.0 mirrors every cutout; confirm this is intended
        # rather than a random flip (p=0.5).
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=1.0),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        """Return `self.cutn` cutouts per input image, concatenated on dim 0."""
        # Zero-pad every side by a fifth of the image height
        input = T.Pad(input.shape[2]//5, fill=0)(input)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)

        cutouts = []
        # Use the instance's own cutn (not the notebook global) so the module
        # honours the value it was constructed with.
        for ch in range(self.cutn):
            if ch > self.cutn - self.cutn//4:
                # The last cutouts cover the whole padded image
                cutout = input.clone()
            else:
                # Random crop size: a fraction of max_size drawn from
                # N(0.8, 0.3), limited to [cut_size / max_size, 1]
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]

            if self.grayscale_cuts:
                cutout = KC.rgb_to_grayscale(cutout)
                cutout = KC.grayscale_to_rgb(cutout)

            if not self.skip_augs:
                cutout = self.augs(cutout)
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            del cutout

        cutouts = torch.cat(cutouts, dim=0)
        return cutouts


def spherical_dist_loss(x, y):
    """Return 2 * arcsin(|x^ - y^| / 2)**2 for the unit-normalized inputs.

    Both inputs are L2-normalized along the last dimension; the result is
    half the squared angle between them, reduced over that dimension.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    chord = (unit_x - unit_y).norm(dim=-1)
    half_angle = chord.div(2).arcsin()
    return half_angle.pow(2).mul(2)


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Squared differences to the right and lower neighbour (edges replicated)
    are averaged over all non-batch dimensions, giving one value per sample.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    base = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - base
    dy = padded[..., 1:, :-1] - base
    return (dx**2 + dy**2).mean([1, 2, 3])


def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], per sample."""
    excess = input - input.clamp(-1, 1)
    return excess.pow(2).mean([1, 2, 3])


def unitwise_norm(x):
    """Compute the L2 norm of `x` per unit, depending on its rank.

    Tensors that squeeze to at most one dimension get a single norm; rank 2
    and 3 tensors are reduced over dim 1, rank 4 tensors over dims (1, 2, 3),
    keeping the reduced dimensions. Other ranks raise ValueError.
    """
    rank = len(x.shape)
    if len(x.squeeze().shape) <= 1:
        dim, keepdim = None, False
    elif rank in (2, 3):
        dim, keepdim = 1, True
    elif rank == 4:
        dim, keepdim = (1, 2, 3), True
    else:
        raise ValueError(f'got a parameter with shape not in (1, 2, 3, 4) {x}')
    return x.norm(dim = dim, keepdim = keepdim, p = 2)


def adaptive_clip_grad(parameters, clipping = 0.01, eps = 1e-3):
    """Clip each gradient, in place, relative to the norm of its parameter.

    Parameters without a gradient are skipped. For the others, wherever the
    unit-wise gradient norm exceeds `clipping` times the unit-wise parameter
    norm (the latter floored at `eps`), the gradient is rescaled to that limit.
    """
    with_grad = [p for p in parameters if p.grad is not None]
    if not with_grad:
        return
    for param in with_grad:
        weight_norm = unitwise_norm(param).clamp_(min = eps)
        grad_norm = unitwise_norm(param.grad)
        limit = weight_norm * clipping
        rescaled = param.grad * (limit / grad_norm.clamp(min = 1e-6))
        selected = torch.where(grad_norm > limit, rescaled, param.grad)
        param.grad.detach().copy_(selected)


In [None]:
def do_run():
    """Run CLIP-guided diffusion sampling and save the intermediate frames.

    All configuration is read from notebook-level variables (the prompts,
    loss scales, `cutn`, `seed`, `n_batches`, `batch_size`, `display_rate`,
    init/perlin options, ...) and from the previously loaded `model`,
    `diffusion`, `clip_model` and `lpips_model`. Every `display_rate`-th step
    the current prediction is written to `output_path` as progress_NNNNN.jpg
    and upscaled in place when `super_resolution` is on. The CLIP loss history
    is plotted after each batch.

    Raises:
        RuntimeError: If the prompt weights sum to (almost) zero.
    """
    import os  # local import: the notebook's import cell does not provide it

    loss_values = []

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    # Frames are saved into output_path below; make sure it exists
    os.makedirs(output_path, exist_ok=True)

    make_cutouts = MakeCutouts(clip_size, cutn, skip_augs=skip_augs, grayscale_cuts=grayscale_cuts)
    target_embeds, weights = [], []

    # One CLIP embedding per text prompt
    for prompt in text_prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
        weights.append(weight)

    # cutn CLIP embeddings per image prompt, sharing the prompt's weight
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        img = Image.open(fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = clip_model.encode_image(normalize(batch)).float()
        target_embeds.append(embed)
        weights.extend([weight / cutn] * cutn)

    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        init = Image.open(fetch(init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        # Scale from [0, 1] to [-1, 1]
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)

    # A perlin init replaces any init image loaded above
    if perlin_init:
        if perlin_mode == 'color':
            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
        elif perlin_mode == 'gray':
            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
        else:
            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)

        # Average the two noise images and scale from [0, 1] to [-1, 1]
        init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
        del init2

    # Current timestep; updated by the sampling loop and read by cond_fn
    cur_t = None

    def cond_fn(x, t, y=None):
        """Return the negated gradient of the guidance loss with respect to x."""
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            # Blend the predicted clean image with the noisy input
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
            image_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
            dists = dists.view([cutn, n, -1])
            losses = dists.mul(weights).sum(2).mean(0)

            loss_values.append(losses.sum().item())

            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale + range_losses.sum() * range_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale

            # if clip_grad and int(timestep_respacing) - cur_t < int(timestep_respacing)//4:
            # NOTE(review): x is a fresh leaf whose .grad is still None here,
            # and adaptive_clip_grad skips tensors without a gradient, so this
            # call appears to be a no-op. Left unchanged because enabling real
            # clipping would alter the generated images — confirm intent.
            if clip_grad:
                adaptive_clip_grad([x])
            return -1 * torch.autograd.grad(loss, x)[0]

    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1

        samples = sample_fn(
            model,
            (batch_size, 3, side_y, side_x),
            clip_denoised=clip_denoised,
            model_kwargs={},
            cond_fn=cond_fn,
            progress=True,
            skip_timesteps=skip_timesteps,
            init_image=init,
            randomize_class=randomize_class,
        )

        for j, sample in enumerate(samples):
            display.clear_output(wait=True)
            cur_t -= 1
            if j % display_rate != 0:
                continue
            for k, image in enumerate(sample['pred_xstart']):
                tqdm.write(f'Batch {i}, step {j}, output {k}:')
                # NOTE(review): this index is only unique while batch_size
                # and n_batches are both 1; larger values overwrite frames.
                filename = f'{output_path}/progress_{i * batch_size + k + j:05}.jpg'
                TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
                if super_resolution:
                    upscale(filename)
                #display.display(display.Image(filename))

        plt.plot(np.array(loss_values), 'r')

In [None]:
# Build the 256x256 unconditional diffusion model and load its checkpoint.

# timestep_respacing = '25' # Modify this value to decrease the number of timesteps.
# A 'ddim' prefix makes do_run use the DDIM sampling loop.
timestep_respacing = 'ddim100' # Modify this value to decrease the number of timesteps.
diffusion_steps = 1000

# Start from the library defaults and override what this checkpoint needs
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': diffusion_steps,
    'rescale_timesteps': True,
    'timestep_respacing': timestep_respacing,
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_fp16': True,
    'use_scale_shift_norm': True,
})
# Generated images are square, with the model's native side length
side_x = side_y = model_config['image_size']

model, diffusion = create_model_and_diffusion(**model_config)
# Load the weights on the CPU first, then move the model to the target device
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device)
# Re-enable gradients for parameters whose name contains 'qkv', 'norm' or 'proj'
for name, param in model.named_parameters():
    if 'qkv' in name or 'norm' in name or 'proj' in name:
        param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()

In [None]:
# Load CLIP in eval mode with gradients disabled for its parameters
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
# Side length that cutouts are resampled to before being fed to CLIP
clip_size = clip_model.visual.input_resolution
# Channel-wise normalization applied to images before clip_model.encode_image
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
# Perceptual (LPIPS) model used for the init-image loss in do_run
lpips_model = lpips.LPIPS(net='vgg').to(device)

## Settings for this run:

In [None]:
# Settings read by do_run().
# NOTE(review): in the trailing comments the value before the dash is
# presumably the upstream default, which may differ from the value set here.

# Prompts are separated by "|"; each may carry a ":<weight>" suffix
text_prompts = text_input.split("|")

# Image prompts: URLs or local paths, optionally with a ":<weight>" suffix
image_prompts = [
    # 'mona.jpg',
]

# 350/50/50/32 and 500/0/0/64 have worked well for 25 timesteps on 256px
# Also, sometimes 1 cutn actually works out fine

clip_guidance_scale = 1000 # 1000 - Controls how much the image should look like the prompt.
tv_scale = 0 # 150 - Controls the smoothness of the final output.
range_scale = 0 # 50 - Controls how far out of range RGB values are allowed to be.
cutn = 32 # 16 - Controls how many crops to take from the image.

init_image = None # None - URL or local path
init_scale = 0 # 0 - This enhances the effect of the init image, a good value is 1000
skip_timesteps = 0 # 0 - Controls the starting point along the diffusion timesteps
perlin_init = False # False - Option to start with random perlin noise
perlin_mode = 'mixed' # 'mixed' ('gray', 'color')

skip_augs = False # False - Controls whether to skip torchvision augmentations
grayscale_cuts = False # False - Controls whether CLIP discriminates a grayscale version of the image
randomize_class = True # True - Controls whether the imagenet class is randomly changed each iteration
clip_denoised = True # False - Determines whether CLIP discriminates a noisy or denoised image

clip_grad = True # False - Experimental: Using adaptive clip grad in the cond_fn

# Overwritten with a random value in the run cell
seed = None

### Actually do the run...

In [None]:
# Choose a random seed and print it for reproduction.
# random.randint includes its upper bound and np.random.seed (called in
# do_run) only accepts values up to 2**32 - 1, so the bound must be 2**32 - 1.
seed = random.randint(0, 2**32 - 1)
print('seed:', seed)

display_rate = 1 # Save a frame every `display_rate` sampling steps
n_batches = 1 # 1 - Controls how many consecutive batches of images are generated
batch_size = 1 # 1 - Controls how many images are generated in parallel in a batch

# Free as much GPU memory as possible before and after the run
gc.collect()
torch.cuda.empty_cache()
try:
    do_run()
except KeyboardInterrupt:
    # Interrupting the cell is the supported way to stop a run early
    pass
finally:
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Assemble the saved frames into a video at output_path/video.mp4
out_file=output_path+"/video.mp4"

# Work on a copy of the frames in a scratch directory
!mkdir -p /tmp/ffmpeg
!cp $output_path/*.jpg /tmp/ffmpeg
# Most recently modified copy; taken to be the final frame
last_frame=!ls -t /tmp/ffmpeg/*.jpg | head -1
last_frame = last_frame[0]

# Copy last frame to start and duplicate at end so it sticks around longer
end_still_seconds = 4
!cp -v $last_frame /tmp/ffmpeg/0000.jpg
# 10 padding frames per second of still, matching the -r 10 frame rate below
for i in range(end_still_seconds * 10):
  pad_file = f"/tmp/ffmpeg/zzzz_pad_{i:05}.jpg"
  !cp -v $last_frame $pad_file

# Encode the frames with libx264, then add a silent AAC audio track
!ffmpeg  -r 10 -i /tmp/ffmpeg/%*.jpg -y -c:v libx264 /tmp/vid_no_audio.mp4
!ffmpeg -i /tmp/vid_no_audio.mp4 -f lavfi -i anullsrc -c:v copy -c:a aac -shortest -y "$out_file"

print("Written", out_file)
!sleep 2
# Remove the scratch directory
!rm -r /tmp/ffmpeg