# CLIP Upscaler and Enhancer
Using OpenAI's CLIP to upscale and enhance images

[![GitHub Repo stars](https://img.shields.io/github/stars/tripplyons/clip-upscaler-and-enhancer?style=social)](https://github.com/tripplyons/clip-upscaler-and-enhancer)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tripplyons/clip-upscaler-and-enhancer/blob/main/clip-upscaler-and-enhancer.ipynb)

Based on [nshepperd's JAX CLIP Guided Diffusion v2.4](https://colab.research.google.com/drive/10YWuTxtBI7PS0xBJCLAUjhR5cB0UUXe-)

# Original Notebook Description (from nshepperd's JAX CLIP Guided Diffusion v2.4)

## Generates images from text prompts with CLIP guided diffusion.

Based on my previous jax port of Katherine Crowson's CLIP guided diffusion notebook.
 - [nshepperd's JAX CLIP Guided Diffusion 512x512.ipynb](https://colab.research.google.com/drive/1ZZi1djM8lU4sorkve3bD6EBHiHs6uNAi)
 - [CLIP Guided Diffusion HQ 512x512.ipynb](https://colab.research.google.com/drive/1V66mUeJbXrTuQITvJunvnWVn96FEbSI3)

Added multi-perceptor and pytree ~trickery~ while eliminating the complicated OpenAI gaussian_diffusion classes. Supports both 256x256 and 512x512 OpenAI models (just change the `'image_size': 256` under Model Settings).
 - Added small secondary model for clip guidance.
 - Added anti-jpeg model for clearer samples.
 - Added secondary anti-jpeg classifier.
 - Added Katherine Crowson's v diffusion models (<https://github.com/crowsonkb/v-diffusion-jax>).
 - Added pixel art model.

In [None]:
#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2021 Katherine Crowson; nshepperd

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
import os

# Mount drive for saving samples and caching model parameters
from google.colab import drive
drive.mount('/content/drive')
# Output directory for generated samples.
save_location = '/content/drive/MyDrive/samples/v2'
# Checkpoint cache: fetch_model stores downloaded files here under their basename.
model_location = '/content/drive/MyDrive/models'

# Create both directories up front; exist_ok makes re-running this cell harmless.
os.makedirs(save_location, exist_ok=True)
os.makedirs(model_location, exist_ok=True)

In [None]:
# Show which GPU (if any) this runtime has been assigned.
!nvidia-smi

# Setup

In [None]:
# Only on A100 runtimes (grep must succeed for && to run the install): install the
# jaxlib 0.1.71 wheel built for CUDA 11.1 / CPython 3.7, per the wheel filename.
!nvidia-smi | grep A100 && pip install https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.71+cuda111-cp37-none-manylinux2010_x86_64.whl

In [None]:
# Install dependencies
# NOTE(review): tensorflow is pinned to 1.15.2; the reason is not visible in this notebook.
!pip install tensorflow==1.15.2
!pip install dm-haiku cbor2 ftfy einops
# Source repositories: CLIP (JAX port), guided diffusion (branch v2) and v-diffusion.
!git clone https://github.com/kingoflolz/CLIP_JAX
!git clone https://github.com/nshepperd/jax-guided-diffusion -b v2
!git clone https://github.com/crowsonkb/v-diffusion-jax

In [None]:
import sys
# Make the cloned repositories importable. Presumably clip_jax comes from CLIP_JAX,
# lib.* from jax-guided-diffusion and diffusion from v-diffusion-jax (verify).
sys.path.append('./CLIP_JAX')
sys.path.append('./jax-guided-diffusion')
sys.path.append('./v-diffusion-jax')

import math
import io
import time
import functools
from functools import partial
from dataclasses import dataclass
import weakref

from PIL import Image
import requests

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxtorch
from jaxtorch import PRNG, Context, Module, nn, init
from tqdm import tqdm

import clip_jax
import diffusion as v_diffusion

from lib.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from lib import util

from IPython import display
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
import torch.utils.data
import torch

In [None]:
# Enumerate the devices visible to JAX and report them.
devices = jax.devices()
n_devices = len(devices)
print('Using device:', devices)

In [None]:
# Implement lazy loading and caching of model parameters for all the different models.

class WeakKey(object):
  """Identity-based, hashable dict key that holds only a weak reference to x.

  Two keys are equal when their stored id() values match and both weak references
  resolve to the same object. Comparing with anything that is not a WeakKey
  returns NotImplemented, so Python falls back to identity (False) instead of
  raising AttributeError. Does not keep x alive.
  """
  def __init__(self, x):
    self.id = id(x)
    self.weak = weakref.ref(x)

  def __hash__(self):
    return hash(self.id)

  def __eq__(self, other):
    if not isinstance(other, WeakKey):
      return NotImplemented
    a = self.weak()
    b = other.weak()
    # Matching id() alone is not enough: ids are reused after an object dies, so the
    # referents themselves must also be the same object.
    return self.id == other.id and (a is b)

class WeakCache(object):
  """Memoise f(x) by the identity of x, holding only weak references.

  The cache keeps neither x (keys are WeakKeys) nor the computed value (stored as a
  weakref). A cached value is reused only while something else still references it.
  """
  def __init__(self):
    self.cache = {}

  def lookup(self, f, x):
    """Return the still-alive cached f(x), or compute it, remember it weakly and return it."""
    key = WeakKey(x)
    ref = self.cache.get(key)
    hit = ref() if ref is not None else None
    if hit is not None:
      return hit
    fresh = f(x)
    self.cache[key] = weakref.ref(fresh)
    return fresh

gpu_cache = WeakCache()

def to_gpu(params):
  """Upload the numpy leaves of a params pytree to the default JAX device.

  Only leaves whose type is exactly np.ndarray are converted (with jnp.array); every
  other leaf is returned unchanged. Conversions go through gpu_cache, so a numpy
  array whose device copy is still alive is not uploaded again.
  """
  def upload(leaf):
    if type(leaf) is not np.ndarray:
      return leaf
    return gpu_cache.lookup(jnp.array, leaf)
  return jax.tree_util.tree_map(upload, params)

# @jax.tree_util.register_pytree_node_class
class LazyParams(object):
  """Deferred parameter loader.

  `load` is called on the first __call__. Its result is converted leaf-by-leaf to
  numpy arrays and kept in host memory. Every call returns device copies via
  to_gpu, which reuses uploads that are still alive.
  """
  def __init__(self, load):
    self.load = load
    self.params = None

  @staticmethod
  def pt(url, key=None):
    """LazyParams for a PyTorch checkpoint read with jaxtorch.pt.load.

    If `key` is given, that entry of the loaded checkpoint is returned instead of
    the whole thing.
    """
    def load():
      checkpoint = jaxtorch.pt.load(fetch_model(url))
      return checkpoint if key is None else checkpoint[key]
    return LazyParams(load)

  def __call__(self):
    if self.params is None:
      host_tree = self.load()
      self.params = jax.tree_util.tree_map(np.array, host_tree)
    return to_gpu(self.params)

  def tree_flatten(self):
    return [self()], []

  def tree_unflatten(static, dynamic):
    # Unflattening yields the materialised params themselves, not a LazyParams.
    return dynamic[0]

In [None]:
# Define necessary functions

def fetch(url_or_path, timeout=None):
    """Open a URL or local path and return a readable binary file object.

    http(s) URLs are downloaded in full with ``requests`` and returned as an
    in-memory ``io.BytesIO`` positioned at the start. Anything else is opened as a
    local file in ``'rb'`` mode. The caller owns the returned object and should close it.

    Args:
        url_or_path: URL string, path string, or path-like object.
        timeout: optional ``requests`` timeout in seconds. The default ``None``
            keeps the old behaviour of waiting indefinitely.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        OSError: if a local path cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        r = requests.get(url_or_path, timeout=timeout)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')

def fetch_model(url_or_path):
    """Return a local path to a model checkpoint, downloading it into the cache if needed.

    Files are cached in ``model_location`` under the URL's basename, and an existing
    cached file is returned without any network access. A new download first goes
    to a temporary file in the same directory. It is renamed into place only once
    complete, so an HTTP error or an interrupted transfer never leaves a corrupt file
    at the cache path that later calls would reuse. (The previous ``!curl`` call had
    no ``-f``, so it saved error pages as checkpoints.)

    Raises:
        urllib.error.URLError (an OSError; HTTPError for HTTP status errors): if the
        download fails.
    """
    import shutil
    import tempfile
    import urllib.request

    basename = os.path.basename(url_or_path)
    local_path = os.path.join(model_location, basename)
    if os.path.exists(local_path):
        return local_path
    # Temp file next to the destination so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=model_location, prefix=basename + '.', suffix='.part')
    os.close(fd)
    try:
        with urllib.request.urlopen(url_or_path) as response, open(tmp_path, 'wb') as out:
            shutil.copyfileobj(response, out)
        os.replace(tmp_path, local_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return local_path


def grey(image):
    """Return a greyscale version of `image` with the same shape.

    Channels are on axis -3 (e.g. [c, h, w] or [..., c, h, w]). Each pixel's channel
    values are replaced by their mean, broadcast back so the channel count is unchanged.
    """
    return jnp.broadcast_to(image.mean(axis=-3, keepdims=True), image.shape)

def cutout_image(image, offsetx, offsety, size, output_size=224):
    """Resample a square window of a [c, h, w] image to [c, output_size, output_size].

    The window's top-left corner is at (offsetx, offsety) and its side is `size`, in
    input pixels. Resampling uses Lanczos-3 via jax.image.scale_and_translate.
    """
    channels, _, _ = image.shape
    zoom = output_size / size
    shift_y = -offsety * output_size / size
    shift_x = -offsetx * output_size / size
    return jax.image.scale_and_translate(
        image,
        shape=(channels, output_size, output_size),
        spatial_dims=(1, 2),
        scale=jnp.stack([zoom, zoom]),
        translation=jnp.stack([shift_y, shift_x]),
        method='lanczos3',
    )

def cutouts_images(image, offsetx, offsety, size, output_size=224):
    """Take k cutouts from every image in a batch.

    image: [n, c, h, w]; offsetx, offsety, size: [k] arrays, one entry per cutout.
    Each cutout uses the same window for every image in the batch.
    Returns [k, n, c, output_size, output_size].
    """
    one_cut = partial(cutout_image, output_size=output_size)
    # Map over the batch axis with a single shared window...
    over_batch = jax.vmap(one_cut, in_axes=(0, None, None, None), out_axes=0)
    # ...then over the k windows, sharing the whole batch.
    over_windows = jax.vmap(over_batch, in_axes=(None, 0, 0, 0), out_axes=0)
    return over_windows(image, offsetx, offsety, size)

@jax.tree_util.register_pytree_node_class
class MakeCutouts(object):
    """Randomly crop a batch of images into `cutn` views for CLIP.

    cutn // 2 views are small square crops inside the image. The remaining
    cutn - cutn // 2 views are "large" square windows, bigger than the image by a
    random 40-80 px border and roughly centred. Views are then randomly blended
    towards, or replaced by, their greyscale version.

    Args:
        cut_size: lower bound on the side of the small crops (capped at min(h, w)).
            The output resolution is cutouts_images' default output_size, not cut_size.
        cutn: total number of views per image. Static: changing it recompiles.
        cut_pow: exponent applied to the uniform size draw; > 1 favours smaller crops.
        p_grey: probability that a view is replaced by its greyscale version.
        p_mixgrey: probability that a view is mixed with its greyscale version by a
            uniform random amount.

    __call__(input, key): input is [b, c, h, w]; returns [cutn, b, c, S, S].
    """
    def __init__(self, cut_size, cutn, cut_pow=1., p_grey=0.2, p_mixgrey=0.0):
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.p_grey = p_grey
        self.p_mixgrey = p_mixgrey

    def __call__(self, input, key):
        [b, c, h, w] = input.shape
        rng = PRNG(key)
        # Split cutn between small and large views. The large views take the
        # remainder, so the total is exactly cutn even when cutn is odd. (Using
        # cutn // 2 for both produced cutn - 1 views, which then failed to broadcast
        # against the [cutn, ...] grey masks below.)
        n_small = self.cutn // 2
        n_large = self.cutn - n_small

        # Small crops: side in [min_size, max_size], skewed by cut_pow.
        max_size = min(h, w)
        min_size = min(h, w, self.cut_size)
        cut_us = jax.random.uniform(rng.split(), shape=[n_small])**self.cut_pow
        sizes = (min_size + cut_us * (max_size - min_size + 1)).astype(jnp.int32).clamp(min_size, max_size)
        offsets_x = jax.random.uniform(rng.split(), [n_small], minval=0, maxval=w - sizes)
        offsets_y = jax.random.uniform(rng.split(), [n_small], minval=0, maxval=h - sizes)
        cutouts = cutouts_images(input, offsets_x, offsets_y, sizes)

        # Large views: side = max(h, w) + border, with border uniform in [B1, B1 + B2).
        # Each window is centred on the image, jittered by up to +/- border.
        B1 = 40
        B2 = 40
        lcut_us = jax.random.uniform(rng.split(), shape=[n_large])
        border = B1 + lcut_us * B2
        lsizes = (max(h,w) + border).astype(jnp.int32)
        loffsets_x = jax.random.uniform(rng.split(), [n_large], minval=w/2-lsizes/2-border, maxval=w/2-lsizes/2+border)
        loffsets_y = jax.random.uniform(rng.split(), [n_large], minval=h/2-lsizes/2-border, maxval=h/2-lsizes/2+border)
        lcutouts = cutouts_images(input, loffsets_x, loffsets_y, lsizes)

        cutouts = jnp.concatenate([cutouts, lcutouts], axis=0)  # [cutn, b, c, S, S]

        greyed = grey(cutouts)

        # With probability p_mixgrey, blend each view towards greyscale by a random amount.
        grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        grey_rs = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        cutouts = jnp.where(grey_us < self.p_mixgrey, grey_rs * greyed + (1 - grey_rs) * cutouts, cutouts)

        # With probability p_grey, replace each view with its greyscale version.
        grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, b, 1, 1, 1])
        cutouts = jnp.where(grey_us < self.p_grey, greyed, cutouts)
        return cutouts

    def tree_flatten(self):
        # Probabilities/exponent are traced; sizes and counts are static because they
        # determine array shapes.
        return ([self.p_grey, self.cut_pow, self.p_mixgrey], (self.cut_size, self.cutn))

    @staticmethod
    def tree_unflatten(static, dynamic):
        (cut_size, cutn) = static
        (p_grey, cut_pow, p_mixgrey) = dynamic
        return MakeCutouts(cut_size, cutn, cut_pow, p_grey, p_mixgrey)

def Normalize(mean, std):
    """Build a per-channel standardiser for 3-channel images.

    mean and std are 3-element sequences, reshaped to [3, 1, 1] so they broadcast
    over [..., 3, h, w] inputs. The returned function computes (image - mean) / std.
    """
    mean_arr = jnp.reshape(jnp.array(mean), (3, 1, 1))
    std_arr = jnp.reshape(jnp.array(std), (3, 1, 1))
    return lambda image: (image - mean_arr) / std_arr

def norm1(x):
    """Normalize to the unit sphere: divide x by its L2 norm along the last axis.

    Uses jnp functions (as the other numeric helpers here do) instead of the
    .square()/.sum()/.sqrt() array methods, which are not part of the standard JAX
    array API. No epsilon is added, so an all-zero vector yields NaN.
    """
    return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))

def spherical_dist_loss(x, y):
    """Squared great-circle distance between the directions of x and y (last axis).

    Both inputs are projected onto the unit sphere. For chord length d, the angle is
    2 * arcsin(d / 2), and the result is 2 * arcsin(d / 2) ** 2 (half the squared angle).
    """
    chord = jnp.sqrt(jnp.sum(jnp.square(norm1(x) - norm1(y)), axis=-1))
    return jnp.square(jnp.arcsin(chord / 2)) * 2

def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    Mean squared difference between horizontally adjacent pixels plus the same for
    vertically adjacent pixels, computed per sample (axis 0). The means are taken
    over every non-batch axis, so [n, c, h, w] behaves exactly as before
    (axes 1, 2, 3) and [n, h, w] or higher-rank inputs also work.
    """
    x_diff = input[..., :, 1:] - input[..., :, :-1]
    y_diff = input[..., 1:, :] - input[..., :-1, :]
    axes = tuple(range(1, input.ndim))
    return jnp.mean(jnp.square(x_diff), axis=axes) + jnp.mean(jnp.square(y_diff), axis=axes)

def downscale2d(image, f):
  """Bicubically shrink the last two axes of a 4-D array by integer factor f (floor division)."""
  lead0, lead1, height, width = image.shape
  target = [lead0, lead1, height // f, width // f]
  return jax.image.resize(image, target, method='cubic')

def upscale2d(image, f):
  """Bicubically enlarge the last two axes of a 4-D array by integer factor f."""
  lead0, lead1, height, width = image.shape
  target = [lead0, lead1, height * f, width * f]
  return jax.image.resize(image, target, method='cubic')

def gaussian_blur(image, sigma, radius):
    """Gaussian-blur the two trailing axes of a [c, h, w] or [n, c, h, w] array.

    The kernel is (2*radius+1) x (2*radius+1). Its 1-D profile is the standard normal
    pdf at i/sigma for i in [-radius, radius], i.e. a Gaussian with standard deviation
    `sigma` pixels truncated at `radius`. The kernel is normalised to sum to 1 and
    applied with 'same' convolution, so the output shape equals the input shape.
    """
    if len(image.shape) == 4:
        # Fold batch and channels together and blur every plane as a [c, h, w] stack.
        n, c, h, w = image.shape
        planes = image.reshape([n * c, h, w])
        return gaussian_blur(planes, sigma, radius).reshape(image.shape)
    taps = 2 * radius + 1
    profile = jsp.stats.norm.pdf(jnp.linspace(-radius / sigma, radius / sigma, taps))
    # Outer product as a (taps, 1) @ (1, taps) matmul; the leading size-1 axis keeps
    # the planes along axis 0 independent.
    kernel = (profile[:, None] @ profile[None, :])[None]
    kernel = kernel / jnp.sum(kernel)
    return jsp.signal.convolve(image, kernel, 'same')

@dataclass
@jax.tree_util.register_pytree_node_class
class DiffusionOutput:
    """The three equivalent views of a diffusion model's prediction at one step.

    Fields (JAX arrays; in the constructors in this file they are shaped like the
    model input x):
      v:    v-parameterisation output.
      pred: predicted clean image.
      eps:  predicted noise.

    Registered as a JAX pytree so it can be returned from, and passed through, jitted code.
    """
    v: jnp.ndarray
    pred: jnp.ndarray
    eps: jnp.ndarray

    def tree_flatten(self):
        # The three arrays are the children. There is no auxiliary data; JAX requires
        # aux data to be hashable, so return an empty tuple rather than a list.
        return [self.v, self.pred, self.eps], ()

    @classmethod
    def tree_unflatten(cls, static, dynamic):
        return cls(*dynamic)
  
# Noise schedule

def alpha_sigma_to_t(alpha, sigma):
    """Map an (alpha, sigma) pair to cosine-schedule time.

    Inverse of get_cosine_alphas_sigmas: t = atan2(sigma, alpha) / (pi / 2).
    """
    angle = jnp.arctan2(sigma, alpha)
    return angle * 2 / math.pi

def cosine_t_to_ddpm(t):
    """Convert cosine-schedule time to the DDPM-schedule time with the same log-SNR.

    Solves log_snr = -log(expm1(1e-4 + 10 * t_ddpm**2)) for t_ddpm. The squared
    result is clamped into [0, 1] before the square root.
    """
    cos_alpha, cos_sigma = get_cosine_alphas_sigmas(t)
    snr_log = jnp.log(cos_alpha**2 / cos_sigma**2)
    t_ddpm_squared = (jnp.log1p(jnp.exp(-snr_log)) - 1e-4) / 10
    return t_ddpm_squared.clamp(0,1).sqrt()

def get_ddpm_alphas_sigmas(t):
    """Alpha and sigma of the continuous DDPM noise schedule at time t (presumably in [0, 1]).

    log_snr(t) = -log(expm1(1e-4 + 10 * t**2)), alpha**2 = sigmoid(log_snr) and
    sigma**2 = sigmoid(-log_snr). Hence alpha**2 + sigma**2 == 1 and
    sigma**2 == 1 - exp(-(1e-4 + 10 * t**2)).

    Uses jnp functions (like the other schedule helpers) rather than the
    .log()/.sqrt() array methods, which are not part of the standard JAX array API.
    """
    log_snrs = -jnp.log(jnp.expm1(1e-4 + 10 * t**2))
    alphas_squared = jax.nn.sigmoid(log_snrs)
    sigmas_squared = jax.nn.sigmoid(-log_snrs)
    return jnp.sqrt(alphas_squared), jnp.sqrt(sigmas_squared)

def get_cosine_alphas_sigmas(t):
    """Alpha and sigma of the cosine schedule: (cos(t*pi/2), sin(t*pi/2))."""
    angle = t * math.pi / 2
    return jnp.cos(angle), jnp.sin(angle)

In [None]:
# Define combinators.

# These (ab)use the jax pytree registration system to define parameterised
# objects for doing various things, which are compatible with jax.jit.

# For jit compatibility an object needs to act as a pytree, which means implementing two methods:
#  - tree_flatten(self): returns two lists of the object's fields: 
#       1. 'dynamic' parameters: things which can be jax tensors, or other pytrees
#       2. 'static' parameters: arbitrary python objects, will trigger recompilation when changed
#  - tree_unflatten(static, dynamic): reconstitutes the object from its parts

# With these tricks, you can simply define your cond_fn as an object, as is done
# below, and pass it into the jitted sample step as a regular argument. JAX will
# handle recompiling the jitted code whenever a control-flow affecting parameter
# is changed (such as cut_batches).

@jax.tree_util.register_pytree_node_class
class CosineModel(object):
    """Adapter for jaxtorch diffusion models that take cosine-schedule time.

    The sampler's DDPM-schedule time t is converted to the cosine time with the same
    alpha/sigma. The model is then called as model(cx, x, t_cos, **kwargs) with an
    eval-mode Context. Returns whatever the model returns (the secondary models in
    this file return a DiffusionOutput).

    Args:
        model: jaxtorch module (static in the pytree).
        params: parameter tree, or a LazyParams, which is resolved immediately.
        **kwargs: extra keyword arguments forwarded to the model (traced).
    """
    def __init__(self, model, params, **kwargs):
        if isinstance(params, LazyParams):
            params = params()
        self.model = model
        self.params = params
        self.kwargs = kwargs

    @jax.jit
    def __call__(self, x, t, key):
        n = x.shape[0]
        alpha, sigma = get_ddpm_alphas_sigmas(t)
        cosine_t = alpha_sigma_to_t(alpha, sigma)
        cx = Context(self.params, key).eval_mode_()
        return self.model(cx, x, cosine_t.broadcast_to([n]), **self.kwargs)

    def tree_flatten(self):
        # JAX requires hashable aux data, so the static model goes in a tuple.
        return [self.params, self.kwargs], (self.model,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [params, kwargs] = dynamic
        [model] = static
        return CosineModel(model, params, **kwargs)

@jax.tree_util.register_pytree_node_class
class OpenaiModel(object):
    """Adapter for OpenAI guided-diffusion models, which predict noise (eps).

    Maps continuous DDPM time t to the model's timestep scale (t * 1000). It calls
    the model in eval mode and converts the noise prediction into a DiffusionOutput
    with the matching clean-image and v predictions.

    Args:
        model: jaxtorch module called as model(cx, x, timesteps) (static in the pytree).
        params: parameter tree, or a LazyParams, which is resolved immediately.
    """
    def __init__(self, model, params):
        if isinstance(params, LazyParams):
            params = params()
        self.model = model
        self.params = params

    @jax.jit
    def __call__(self, x, t, key):
        n = x.shape[0]
        alpha, sigma = get_ddpm_alphas_sigmas(t)
        cx = Context(self.params, key).eval_mode_()
        openai_t = (t * 1000).broadcast_to([n])
        # Keep the first 3 output channels as eps. NOTE(review): any further channels
        # are presumably the learned-variance output and are discarded here.
        eps = self.model(cx, x, openai_t)[:, :3, :, :]
        # From x = alpha * pred + sigma * eps and v = alpha * eps - sigma * pred,
        # simplified using alpha**2 + sigma**2 == 1.
        pred = (x - eps * sigma) / alpha
        v    = (eps - x * sigma) / alpha
        return DiffusionOutput(v, pred, eps)

    def tree_flatten(self):
        # JAX requires hashable aux data, so the static model goes in a tuple.
        return [self.params], (self.model,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [params] = dynamic
        [model] = static
        return OpenaiModel(model, params)

@jax.tree_util.register_pytree_node_class
class Perceptor(object):
    """Wraps a CLIP instance and its parameters.

    image_fn(params, images) and text_fn(params, tokens) are static in the pytree;
    clip_params is the traced part. All embeddings are L2-normalised with norm1.
    """
    def __init__(self, image_fn, text_fn, clip_params):
        self.image_fn = image_fn
        self.text_fn = text_fn
        self.clip_params = clip_params

    @jax.jit
    def embed_cutouts(self, cutouts):
        return norm1(self.image_fn(self.clip_params, cutouts))

    def embed_text(self, text):
        tokens = clip_jax.tokenize([text])
        text_embed = self.text_fn(self.clip_params, tokens)
        # Flatten the single-prompt batch to a 1-D embedding. reshape(-1) replaces the
        # hard-coded 512 so CLIP variants with other embedding widths also work.
        return norm1(text_embed.reshape(-1))

    def embed_texts(self, texts):
        return jnp.stack([self.embed_text(t) for t in texts])

    def tree_flatten(self):
        # JAX requires hashable aux data, so the static functions go in a tuple.
        return [self.clip_params], (self.image_fn, self.text_fn)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [clip_params] = dynamic
        [image_fn, text_fn] = static
        return Perceptor(image_fn, text_fn, clip_params)

@jax.tree_util.register_pytree_node_class
class LerpModels(object):
    """Linear combination of diffusion models.

    `models` is a list of (model, weight) pairs. Each model is called as
    model(x, t, key) and must return a DiffusionOutput. The v, pred and eps fields
    are combined as weighted sums; the weights are not normalised.
    """
    def __init__(self, models):
        self.models = models

    def __call__(self, x, t, key):
        outputs = [m(x, t, key) for (m, _) in self.models]
        weights = [w for (_, w) in self.models]
        v = sum(out.v * w for out, w in zip(outputs, weights))
        pred = sum(out.pred * w for out, w in zip(outputs, weights))
        eps = sum(out.eps * w for out, w in zip(outputs, weights))
        return DiffusionOutput(v, pred, eps)

    def tree_flatten(self):
        # Models and weights are all children; aux data must be hashable, hence ().
        return [self.models], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        return LerpModels(*dynamic)

@jax.tree_util.register_pytree_node_class
class KatModel(object):
    """Adapter for v-diffusion models that predict v at cosine-schedule time.

    The sampler's DDPM-schedule t is converted to the equivalent cosine time. The
    model is called as model.apply(params, key, x, t_cos, kwargs), and its v output
    is turned into a DiffusionOutput with pred = alpha*x - sigma*v and
    eps = sigma*x + alpha*v.

    Args:
        model: object with an apply(params, key, x, t, kwargs) method (static).
        params: parameter tree, or a LazyParams, which is resolved immediately.
        **kwargs: forwarded to apply as a single dict (traced).
    """
    def __init__(self, model, params, **kwargs):
        if isinstance(params, LazyParams):
            params = params()
        self.model = model
        self.params = params
        self.kwargs = kwargs

    @jax.jit
    def __call__(self, x, t, key):
        n = x.shape[0]
        alpha, sigma = get_ddpm_alphas_sigmas(t)
        cosine_t = alpha_sigma_to_t(alpha, sigma)
        v = self.model.apply(self.params, key, x, cosine_t.broadcast_to([n]), self.kwargs)
        pred = x * alpha - v * sigma
        eps = x * sigma + v * alpha
        return DiffusionOutput(v, pred, eps)

    def tree_flatten(self):
        # JAX requires hashable aux data, so the static model goes in a tuple.
        return [self.params, self.kwargs], (self.model,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [params, kwargs] = dynamic
        [model] = static
        return KatModel(model, params, **kwargs)

In [None]:
# Common nn modules.
class SkipBlock(nn.Module):
    """U-Net style skip connection.

    Runs `main` and `skip` on the same input and concatenates their outputs along
    axis 1 (channels), main output first. Attribute names are part of the
    parameter layout, so keep them stable.
    """
    def __init__(self, main, skip=None):
        super().__init__()
        # `main` is a list of modules applied in order.
        self.main = nn.Sequential(*main)
        # Falls back to Identity when no skip module is given (any falsy `skip`).
        self.skip = skip if skip else nn.Identity()

    def forward(self, cx, input):
        return jnp.concatenate([self.main(cx, input), self.skip(cx, input)], axis=1)


class FourierFeatures(nn.Module):
    """Random Fourier feature embedding.

    Maps [..., in_features] to [..., out_features] as
    concat(cos(2*pi*x @ W.T), sin(2*pi*x @ W.T)). W has shape
    (out_features // 2, in_features) and is initialised with init.normal(stddev=std).
    """
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        # The cos and sin halves must have equal width.
        assert out_features % 2 == 0
        self.weight = init.normal(out_features // 2, in_features, stddev=std)

    def forward(self, cx, input):
        # Parsed as (2 * pi * input) @ W.T: * and @ share precedence, left-associative.
        f = 2 * math.pi * input @ cx[self.weight].transpose()
        return jnp.concatenate([f.cos(), f.sin()], axis=-1)


class AvgPool2d(nn.Module):
    """Parameter-free 2x2 average pooling of an [n, c, h, w] tensor (h and w must be even)."""
    def forward(self, cx, x):
        n, c, h, w = x.shape
        # Split each spatial axis into (half, 2) and average over the two size-2 axes.
        tiles = x.reshape([n, c, h // 2, 2, w // 2, 2])
        return tiles.mean((3, 5))


def expand_to_planes(input, shape):
    """Broadcast per-sample feature vectors over a spatial grid.

    input: [..., f] array. shape: a sequence whose entries 2 and 3 are (h, w),
    typically another tensor's [n, c, h, w] shape. Returns a [..., f, h, w] array
    in which each feature is constant across its plane.

    Uses jnp.broadcast_to because .broadcast_to is not part of the standard JAX array
    method API (NOTE: presumably it is patched in by jaxtorch in this notebook).
    """
    return jnp.broadcast_to(input[..., None, None], tuple(input.shape) + (shape[2], shape[3]))


In [None]:
# Secondary Model 
class ConvBlock(nn.Sequential):
    """3x3 convolution (padding 1, so spatial size is preserved) followed by ReLU."""
    def __init__(self, c_in, c_out):
        super().__init__(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.ReLU(),
        )

class SecondaryDiffusionImageNet(nn.Module):
    """Small secondary v-diffusion U-Net with three 2x2 average-pooling levels.

    forward(cx, input, t): `input` is a batch of 3-channel images and `t` a
    per-sample cosine-schedule time of shape [n]. A 16-dim Fourier embedding of t is
    tiled over the image and concatenated as extra channels (3 + 16 inputs). Returns
    a DiffusionOutput: the network's v plus the pred/eps derived from it. Height and
    width must be divisible by 8 for AvgPool2d. NOTE(review): this assumes
    Upsample2d exactly doubles the resolution back.

    NOTE(review): the module structure presumably mirrors the loaded checkpoint's
    parameter layout; do not restructure it without checking.
    """
    def __init__(self):
        super().__init__()
        c = 64  # The base channel count

        self.timestep_embed = FourierFeatures(1, 16)

        self.net = nn.Sequential(
            ConvBlock(3 + 16, c),
            ConvBlock(c, c),
            SkipBlock([
                AvgPool2d(),
                # nn.image.Downsample2d('linear'),
                ConvBlock(c, c * 2),
                ConvBlock(c * 2, c * 2),
                SkipBlock([
                    AvgPool2d(),
                    # nn.image.Downsample2d('linear'),
                    ConvBlock(c * 2, c * 4),
                    ConvBlock(c * 4, c * 4),
                    SkipBlock([
                        AvgPool2d(),
                        # nn.image.Downsample2d('linear'),
                        ConvBlock(c * 4, c * 8),
                        ConvBlock(c * 8, c * 4),
                        nn.image.Upsample2d('linear'),
                    ]),
                    # A SkipBlock concatenates its main output with its input, so the
                    # next layer sees both: c * 4 + c * 4 = c * 8 channels.
                    ConvBlock(c * 8, c * 4),
                    ConvBlock(c * 4, c * 2),
                    nn.image.Upsample2d('linear'),
                ]),
                ConvBlock(c * 4, c * 2),
                ConvBlock(c * 2, c),
                nn.image.Upsample2d('linear'),
            ]),
            ConvBlock(c * 2, c),
            nn.Conv2d(c, 3, 3, padding=1),
        )

    def forward(self, cx, input, t):
        timestep_embed = expand_to_planes(self.timestep_embed(cx, t[:, None]), input.shape)
        v = self.net(cx, jnp.concatenate([input, timestep_embed], axis=1))
        alphas, sigmas = get_cosine_alphas_sigmas(t)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        # Recover pred/eps from v: pred = alpha*x - sigma*v, eps = sigma*x + alpha*v.
        pred = input * alphas - v * sigmas
        eps = input * sigmas + v * alphas
        return DiffusionOutput(v, pred, eps)

class SecondaryDiffusionImageNet2(nn.Module):
    """Deeper secondary v-diffusion U-Net with five 2x2 average-pooling levels.

    Same interface as SecondaryDiffusionImageNet: forward(cx, input, t) takes
    3-channel images and per-sample cosine time t of shape [n], and returns a
    DiffusionOutput. Height and width must be divisible by 32 for AvgPool2d. Level
    widths are cs = 64 * [1, 2, 2, 4, 4, 8].

    NOTE(review): the module structure presumably mirrors the loaded checkpoint's
    parameter layout; do not restructure it without checking.
    """
    def __init__(self):
        super().__init__()
        c = 64  # The base channel count
        cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]

        self.timestep_embed = FourierFeatures(1, 16)
        # Shared pooling/upsampling instances, reused at every level below.
        self.down = AvgPool2d()
        self.up = nn.image.Upsample2d('linear')

        self.net = nn.Sequential(
            ConvBlock(3 + 16, cs[0]),
            ConvBlock(cs[0], cs[0]),
            SkipBlock([
                self.down,
                ConvBlock(cs[0], cs[1]),
                ConvBlock(cs[1], cs[1]),
                SkipBlock([
                    self.down,
                    ConvBlock(cs[1], cs[2]),
                    ConvBlock(cs[2], cs[2]),
                    SkipBlock([
                        self.down,
                        ConvBlock(cs[2], cs[3]),
                        ConvBlock(cs[3], cs[3]),
                        SkipBlock([
                            self.down,
                            ConvBlock(cs[3], cs[4]),
                            ConvBlock(cs[4], cs[4]),
                            SkipBlock([
                                self.down,
                                ConvBlock(cs[4], cs[5]),
                                ConvBlock(cs[5], cs[5]),
                                ConvBlock(cs[5], cs[5]),
                                ConvBlock(cs[5], cs[4]),
                                self.up,
                            ]),
                            # Each SkipBlock doubles the channels by concatenating its
                            # input, hence the "* 2" on the first conv after it.
                            ConvBlock(cs[4] * 2, cs[4]),
                            ConvBlock(cs[4], cs[3]),
                            self.up,
                        ]),
                        ConvBlock(cs[3] * 2, cs[3]),
                        ConvBlock(cs[3], cs[2]),
                        self.up,
                    ]),
                    ConvBlock(cs[2] * 2, cs[2]),
                    ConvBlock(cs[2], cs[1]),
                    self.up,
                ]),
                ConvBlock(cs[1] * 2, cs[1]),
                ConvBlock(cs[1], cs[0]),
                self.up,
            ]),
            ConvBlock(cs[0] * 2, cs[0]),
            nn.Conv2d(cs[0], 3, 3, padding=1),
        )

    def forward(self, cx, input, t):
        timestep_embed = expand_to_planes(self.timestep_embed(cx, t[:, None]), input.shape)
        v = self.net(cx, jnp.concatenate([input, timestep_embed], axis=1))
        alphas, sigmas = get_cosine_alphas_sigmas(t)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        # Recover pred/eps from v: pred = alpha*x - sigma*v, eps = sigma*x + alpha*v.
        pred = input * alphas - v * sigmas
        eps = input * sigmas + v * alphas
        return DiffusionOutput(v, pred, eps)

# The init_weights() results are immediately replaced by LazyParams, which load the
# pretrained checkpoints on first use.
# NOTE(review): the init_weights() calls are kept in case they have side effects on
# the modules (e.g. jaxtorch parameter naming); confirm before removing them.
secondary1_model = SecondaryDiffusionImageNet()
secondary1_params = secondary1_model.init_weights(jax.random.PRNGKey(0))
secondary1_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet.pth')

secondary2_model = SecondaryDiffusionImageNet2()
secondary2_params = secondary2_model.init_weights(jax.random.PRNGKey(0))
secondary2_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth')

In [None]:
# Anti-JPEG model
class ResidualBlock(nn.Module):
    """Residual wrapper: returns main(x) + skip(x), with skip defaulting to Identity.

    `main` is a list of modules applied in order.
    """
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, cx, input):
        return self.main(cx, input) + self.skip(cx, input)


class ResConvBlock(ResidualBlock):
    """Pre-activation residual block: LeakyReLU -> 3x3 conv -> LeakyReLU -> 3x3 conv.

    The skip path is Identity when c_in == c_out, otherwise a bias-free 1x1 conv.
    NOTE(review): `dropout` is accepted for signature compatibility but unused here.
    """
    def __init__(self, c_in, c_mid, c_out, dropout=True):
        skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False)
        super().__init__([
            nn.LeakyReLU(),
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
        ], skip)
        

CHANNELS=64
class JPEGModel(nn.Module):
    """Anti-JPEG v-diffusion U-Net conditioned on time and a class label.

    forward(cx, input, ts, cond): `input` is a batch of 3-channel images, `ts` the
    per-sample cosine time [n] and `cond` per-sample labels indexing a 3-row
    embedding table. 16-dim embeddings of both are tiled over the image and
    concatenated (3 + 16 + 16 input channels). Returns a DiffusionOutput.
    NOTE(review): the meaning of the 3 classes is not visible here. Spatial size
    presumably must be divisible by 8 (three Downsample2d levels).

    NOTE(review): the structure presumably mirrors the loaded checkpoint's parameter
    layout (including the odd 2 * 2 mid width below); do not "fix" it without
    checking the checkpoint.
    """
    def __init__(self, c=CHANNELS):
        super().__init__()

        self.timestep_embed = FourierFeatures(1, 16, std=1.0)
        self.class_embed = nn.Embedding(3, 16)

        # Not referenced in the visible code; appears to be a compact description of the layout.
        self.arch = '11(22(22(2)22)22)11'

        self.net = nn.Sequential(
            nn.Conv2d(3 + 16 + 16, c, 1),
            ResConvBlock(c, c, c),
            ResConvBlock(c, c, c),
            SkipBlock([
                nn.image.Downsample2d(),
                ResConvBlock(c,     c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c * 2),
                SkipBlock([
                    nn.image.Downsample2d(),
                    ResConvBlock(c * 2, c * 2, c * 2),
                    # NOTE(review): mid width is 2 * 2 = 4, unlike the c * 2 of its neighbours.
                    ResConvBlock(c * 2, 2 * 2, c * 2),
                    SkipBlock([
                        nn.image.Downsample2d(),
                        ResConvBlock(c * 2, c * 2, c * 2),
                        nn.image.Upsample2d(),
                    ]),
                    # SkipBlock concatenates main output and input: c * 2 + c * 2 = c * 4.
                    ResConvBlock(c * 4, c * 2, c * 2),
                    ResConvBlock(c * 2, c * 2, c * 2),
                    nn.image.Upsample2d(),
                ]),
                ResConvBlock(c * 4, c * 2, c * 2),
                ResConvBlock(c * 2, c * 2, c),
                nn.image.Upsample2d(),
            ]),
            ResConvBlock(c * 2, c, c),
            ResConvBlock(c, c, 3, dropout=False),
        )

    def forward(self, cx, input, ts, cond):
        [n, c, h, w] = input.shape
        timestep_embed = expand_to_planes(self.timestep_embed(cx, ts[:, None]), input.shape)
        class_embed = expand_to_planes(self.class_embed(cx, cond), input.shape)
        v = self.net(cx, jnp.concatenate([input, timestep_embed, class_embed], axis=1))
        alphas, sigmas = get_cosine_alphas_sigmas(ts)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        # Recover pred/eps from v: pred = alpha*x - sigma*v, eps = sigma*x + alpha*v.
        pred = input * alphas - v * sigmas
        eps = input * sigmas + v * alphas
        return DiffusionOutput(v, pred, eps)

# The init_weights() result is immediately replaced by a LazyParams that loads the
# 'params_ema' entry of the checkpoint on first use.
# NOTE(review): init_weights() is kept in case it has side effects on the module;
# confirm before removing it.
jpeg_model = JPEGModel()
jpeg_params = jpeg_model.init_weights(jax.random.PRNGKey(0))
jpeg_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/jpeg-db-oi-614.pt', key='params_ema')

In [None]:
# Secondary Anti-JPEG Classifier

CHANNELS=64
class Classifier(nn.Module):
    """Time-conditioned convolutional classifier producing a 1-channel logit map.

    forward(cx, input, ts): 3-channel images plus per-sample cosine time ts [n];
    the tiled 16-dim time embedding is concatenated (3 + 16 inputs). Output has one
    channel at reduced resolution (presumably 1/8, from three Downsample2d layers).

    NOTE(review): the structure presumably mirrors the loaded checkpoint's parameter
    layout (including the 2 * 2 mid width below); do not change it without checking.
    """
    def __init__(self, c=CHANNELS):
        super().__init__()

        self.timestep_embed = FourierFeatures(1, 16, std=1.0)

        # Not referenced in the visible code; appears to be a compact description of the layout.
        self.arch = '11-22-22-22'

        self.net = nn.Sequential(
            nn.Conv2d(3 + 16, c, 1),
            ResConvBlock(c, c, c),
            ResConvBlock(c, c, c),
            nn.image.Downsample2d(),
            ResConvBlock(c,     c * 2, c * 2),
            ResConvBlock(c * 2, c * 2, c * 2),
            nn.image.Downsample2d(),
            ResConvBlock(c * 2, c * 2, c * 2),
            # NOTE(review): mid width is 2 * 2 = 4, unlike the c * 2 of its neighbours.
            ResConvBlock(c * 2, 2 * 2, c * 2),
            nn.image.Downsample2d(),
            ResConvBlock(c * 2, c * 2, c * 2),
            ResConvBlock(c * 2, c * 2, c * 2),
            ResConvBlock(c * 2, c * 2, 1, dropout=False),
        )

    def forward(self, cx, input, ts):
        [n, c, h, w] = input.shape
        timestep_embed = expand_to_planes(self.timestep_embed(cx, ts[:, None]), input.shape)
        return self.net(cx, jnp.concatenate([input, timestep_embed], axis=1))

    def score(self, cx, reals, ts, cond, flood_level, blur_size):
        """Mean per-pixel logistic loss of the blurred logit map against labels `cond`.

        Samples with cond == 0 are scored with -log_sigmoid(logits), pushing logits
        up; all others with -log_sigmoid(-logits). Logits are first Gaussian-blurred
        (sigma=blur_size, radius 6). Per-pixel losses below flood_level are raised
        to flood_level before averaging.
        """
        cond = cond[:, None, None, None]
        logits = self.forward(cx, reals, ts)
        logits = gaussian_blur(logits, blur_size, 6)
        loss = -jax.nn.log_sigmoid(jnp.where(cond==0, logits, -logits))
        loss = loss.clamp(minval=flood_level, maxval=None)
        return loss.mean()


@jax.jit
def classifier_probs(classifier_params, x, ts):
  """Sigmoid of the anti-JPEG classifier's output map for a batch x.

  ts is broadcast to one time per sample. Uses the module-level classifier_model
  in eval mode with a fixed PRNG key.
  """
  batch = x.shape[0]
  cx = Context(classifier_params, jax.random.PRNGKey(0)).eval_mode_()
  logits = classifier_model(cx, x, ts.broadcast_to([batch]))
  return jax.nn.sigmoid(logits)

# The init_weights() result is immediately replaced by a LazyParams that loads the
# 'params_ema' entry (passed positionally as `key`) on first use.
# NOTE(review): init_weights() is kept in case it has side effects on the module.
classifier_model = Classifier()
classifier_params = classifier_model.init_weights(jax.random.PRNGKey(0))
classifier_params = LazyParams.pt('https://set.zlkj.in/models/diffusion/jpeg-classifier-72.pt', 'params_ema')

In [None]:
# Pixel art model

CHANNELS=192
class PixelArtV4(nn.Module):
    """Pixel-art v-diffusion U-Net with six resolution levels.

    forward(cx, input, t): `input` is a batch of 3-channel images and `t` a
    per-sample cosine-schedule time of shape [n]. A 16-dim Fourier embedding of t is
    tiled over the image and concatenated as extra input channels. Returns a
    DiffusionOutput.

    Args:
        c: base channel count; level widths are c * [1, 2, 2, 2, 2, 2]. Defaults to
            CHANNELS as it was when this class was defined (192), like JPEGModel and
            Classifier. Previously the CHANNELS global was read at construction
            time, so the width depended on whatever value that reassigned global
            held then.

    NOTE(review): the module structure presumably mirrors the pixel-art checkpoints'
    parameter layout; do not restructure it without checking them.
    """
    def __init__(self, c=CHANNELS):
        super().__init__()

        self.timestep_embed = FourierFeatures(1, 16, std=1.0)

        # Not referenced in the visible code; appears to be a compact description of the layout.
        self.arch = '122222'

        muls = [1, 2, 2, 2, 2, 2]
        cs = [c * m for m in muls]

        def downsample(c1, c2):
            # Downsample, then a 1x1 conv only if the channel count changes.
            return nn.Sequential(nn.image.Downsample2d(), nn.Conv2d(c1, c2, 1) if c1!=c2 else nn.Identity())

        def upsample(c1, c2):
            # 1x1 conv only if the channel count changes, then upsample.
            return nn.Sequential(nn.Conv2d(c1, c2, 1) if c1!=c2 else nn.Identity(), nn.image.Upsample2d())

        # Local residual block; it shadows the anti-JPEG ResConvBlock inside this
        # method. Layers: conv -> dropout -> ReLU -> conv -> dropout -> ReLU. With
        # dropout=False both dropouts and the final ReLU become Identity, so the
        # block's output is linear.
        class ResConvBlock(ResidualBlock):
            def __init__(self, c_in, c_mid, c_out, dropout=True):
                skip = None if c_in == c_out else nn.Conv2d(c_in, c_out, 1, bias=False)
                super().__init__([
                    nn.Conv2d(c_in, c_mid, 3, padding=1),
                    nn.Dropout2d(p=0.1) if dropout else nn.Identity(),
                    nn.ReLU(),
                    nn.Conv2d(c_mid, c_out, 3, padding=1),
                    nn.Dropout2d(p=0.1) if dropout else nn.Identity(),
                    nn.ReLU() if dropout else nn.Identity(),
                ], skip)


        self.net = nn.Sequential(
            ResConvBlock(3 + 16, cs[0], cs[0]),
            ResConvBlock(cs[0], cs[0], cs[0]),
            SkipBlock([
                downsample(cs[0], cs[1]), # 2x2
                ResConvBlock(cs[1], cs[1], cs[1]),
                ResConvBlock(cs[1], cs[1], cs[1]),
                SkipBlock([
                    downsample(cs[1], cs[2]),  # 4x4
                    ResConvBlock(cs[2], cs[2], cs[2]),
                    ResConvBlock(cs[2], cs[2], cs[2]),
                    SkipBlock([
                        downsample(cs[2], cs[3]),  # 8x8
                        ResConvBlock(cs[3], cs[3], cs[3]),
                        ResConvBlock(cs[3], cs[3], cs[3]),
                        SkipBlock([
                            downsample(cs[3], cs[4]),  # 16x16
                            ResConvBlock(cs[4], cs[4], cs[4]),
                            ResConvBlock(cs[4], cs[4], cs[4]),
                            SkipBlock([
                                downsample(cs[4], cs[5]),  # 32x32
                                ResConvBlock(cs[5], cs[5], cs[5]),
                                ResConvBlock(cs[5], cs[5], cs[5]),
                                ResConvBlock(cs[5], cs[5], cs[5]),
                                ResConvBlock(cs[5], cs[5], cs[5]),
                                upsample(cs[5],cs[4]),
                            ]),
                            # SkipBlock concatenates main output and input, doubling channels.
                            ResConvBlock(cs[4]*2, cs[4], cs[4]),
                            ResConvBlock(cs[4], cs[4], cs[4]),
                            upsample(cs[4],cs[3]),
                        ]),
                        ResConvBlock(cs[3]*2, cs[3], cs[3]),
                        ResConvBlock(cs[3], cs[3], cs[3]),
                        upsample(cs[3],cs[2]),
                    ]),
                    ResConvBlock(cs[2]*2, cs[2], cs[2]),
                    ResConvBlock(cs[2], cs[2], cs[2]),
                    upsample(cs[2],cs[1]),
                ]),
                ResConvBlock(cs[1]*2, cs[1], cs[1]),
                ResConvBlock(cs[1], cs[1], cs[1]),
                upsample(cs[1],cs[0]),
            ]),
            ResConvBlock(cs[0]*2, cs[0], cs[0]),
            ResConvBlock(cs[0], cs[0], 3, dropout=False),
        )

    def forward(self, cx, input, t):
        timestep_embed = expand_to_planes(self.timestep_embed(cx, t[:, None]), input.shape)
        v = self.net(cx, jnp.concatenate([input, timestep_embed], axis=1))
        alphas, sigmas = get_cosine_alphas_sigmas(t)
        alphas = alphas[:, None, None, None]
        sigmas = sigmas[:, None, None, None]
        # Recover pred/eps from v: pred = alpha*x - sigma*v, eps = sigma*x + alpha*v.
        pred = input * alphas - v * sigmas
        eps = input * sigmas + v * alphas
        return DiffusionOutput(v, pred, eps)

# The init_weights() result is immediately replaced by a LazyParams that loads the
# 'params_ema' entry of the selected checkpoint on first use.
# NOTE(review): init_weights() is kept in case it has side effects on the module.
pixelartv4_model = PixelArtV4()
pixelartv4_params = pixelartv4_model.init_weights(jax.random.PRNGKey(0))

# There are many checkpoints supported with this model; uncomment one URL and comment
# out the active one below to switch.
pixelartv4_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_34.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_63.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_150.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_50.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_65.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_97.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_173.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_344.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_432.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_600.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_700.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_800.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_1000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_2000.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_3000.pt',
    key='params_ema'
)

In [None]:
@jax.tree_util.register_pytree_node_class
class MakeCutoutsPixelated(object):
    """Wrap a cutout maker so CLIP sees a nearest-neighbour upscaled image.

    Used with the pixel art model. Height and width of the input are enlarged
    ``factor`` times with nearest-neighbour resampling, so each pixel stays a
    crisp block. The wrapped ``make_cutouts`` then takes its crops from that
    enlarged image, which presents a more pixel-arty pred to CLIP.

    Args:
        make_cutouts: callable ``(input, key) -> cutouts`` exposing ``.cutn``.
        factor: integer upscale factor applied to height and width (default 4).
    """
    def __init__(self, make_cutouts, factor=4):
        self.make_cutouts = make_cutouts
        self.factor = factor

    @property
    def cutn(self):
        # Read through to the wrapped cutter rather than copying it in
        # __init__, so tree_unflatten never has to touch attributes of a
        # (possibly placeholder) child.
        return self.make_cutouts.cutn

    def __call__(self, input, key):
        """Upscale ``input`` (n, c, h, w) by ``factor`` and return its cutouts."""
        n, c, h, w = input.shape
        input = jax.image.resize(input, [n, c, h * self.factor, w * self.factor], method='nearest')
        return self.make_cutouts(input, key)

    def tree_flatten(self):
        # JAX requires aux_data to be hashable (it is stored in the treedef),
        # so the static factor goes in a tuple, not a list.
        return [self.make_cutouts], (self.factor,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MakeCutoutsPixelated(*dynamic, *static)

In [None]:
# Kat models
# Katherine Crowson's v-diffusion models (v_diffusion package). Each params
# object wraps a lambda, so fetch_model()/load_params() only run when the
# params are actually requested, e.g. KatModel(..., wikiart_256_params()) in
# the commented-out options inside config().

danbooru_128_model = v_diffusion.get_model('danbooru_128')
danbooru_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/danbooru_128.pkl')))

wikiart_256_model = v_diffusion.get_model('wikiart_256')
wikiart_256_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/wikiart_256.pkl')))

wikiart_128_model = v_diffusion.get_model('wikiart_128')
wikiart_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/wikiart_128.pkl')))

imagenet_128_model = v_diffusion.get_model('imagenet_128')
imagenet_128_params = LazyParams(lambda: v_diffusion.load_params(fetch_model('https://v-diffusion.s3.us-west-2.amazonaws.com/imagenet_128.pkl')))

Model Settings

In [None]:
use_checkpoint = False # Set to True to save some memory (passed to the model config below; presumably gradient checkpointing)

# Checkpoint URL per model resolution.
model_urls = {
    512: 'https://set.zlkj.in/models/diffusion/512x512_diffusion_uncond_finetune_008100.pt',
    256: 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
}

# Load models, both 256 and 512.
# The two configs below are identical except for 'image_size'. The diffusion
# object returned by create_model_and_diffusion is discarded (`_`); sampling is
# done by sample_step instead. Each init_weights() result is immediately
# replaced by a LazyParams loader.
# NOTE(review): the init_weights() calls look redundant. Confirm they have no
# required side effects before removing them.
# After this cell, `model_config` holds the 256x256 config.

model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': True,
    'timestep_respacing': '1000',
    'image_size': 512,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_scale_shift_norm': True,
    'use_checkpoint': use_checkpoint 
})


openai_512_model, _ = create_model_and_diffusion(**model_config)
openai_512_params = openai_512_model.init_weights(jax.random.PRNGKey(0))
openai_512_params = LazyParams.pt(model_urls[512])

model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': True,
    'timestep_respacing': '1000',
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_scale_shift_norm': True,
    'use_checkpoint': use_checkpoint 
})

openai_256_model, _ = create_model_and_diffusion(**model_config)
openai_256_params = openai_256_model.init_weights(jax.random.PRNGKey(0))
openai_256_params = LazyParams.pt(model_urls[256])

In [None]:
# Losses and cond fn.

@jax.tree_util.register_pytree_node_class
class CondCLIP(object):
    """CLIP guidance loss: pushes the image toward target CLIP embedding(s).

    Called as ``cond(x_in, key)`` with ``x_in`` in [-1, 1]. Returns the
    gradient of the loss w.r.t. ``x_in``, averaged over ``cut_batches``
    independent batches of cutouts.

    Args:
        text_embed: target embedding(s) compared against the cutout
            embeddings. Despite the name, config() passes the init image's
            CLIP embedding here.
        clip_guidance_scale: multiplier on the loss.
        perceptor: object with ``embed_cutouts(cutouts)``.
        make_cutouts: cutout maker ``(images, key) -> cutouts`` with ``.cutn``.
        cut_batches: number of cutout batches (static; not traced).
    """
    def __init__(self, text_embed, clip_guidance_scale, perceptor, make_cutouts, cut_batches):
        self.text_embed = text_embed
        self.clip_guidance_scale = clip_guidance_scale
        self.perceptor = perceptor
        self.make_cutouts = make_cutouts
        self.cut_batches = cut_batches

    def __call__(self, x_in, key):
        n = x_in.shape[0]

        def main_clip_loss(x_in, key):
            # Map [-1, 1] -> [0, 1], cut, CLIP-normalize, then flatten the
            # (cutout, batch) axes so the perceptor sees one big batch.
            cutouts = normalize(self.make_cutouts(x_in.add(1).div(2), key)).rearrange('k n c h w -> (k n) c h w')
            # -1 rather than a hard-coded 512, so perceptors with other
            # embedding widths also work.
            image_embeds = self.perceptor.embed_cutouts(cutouts).reshape([self.make_cutouts.cutn, n, -1])
            # Mean over cutouts, summed over the batch.
            losses = spherical_dist_loss(image_embeds, self.text_embed).mean(0)
            return losses.sum() * self.clip_guidance_scale

        num_cuts = self.cut_batches
        keys = jnp.stack(jax.random.split(key, num_cuts))

        def accumulate(total, key):
            # Only the running gradient sum is carried; no per-step output.
            return total + jax.grad(main_clip_loss)(x_in, key), None

        total, _ = jax.lax.scan(accumulate, jnp.zeros_like(x_in), keys)
        return total / num_cuts

    def tree_flatten(self):
        # aux_data must be hashable: keep the static cut_batches in a tuple.
        return [self.text_embed, self.clip_guidance_scale, self.perceptor, self.make_cutouts], (self.cut_batches,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [text_embed, clip_guidance_scale, perceptor, make_cutouts] = dynamic
        [cut_batches] = static
        return CondCLIP(text_embed, clip_guidance_scale, perceptor, make_cutouts, cut_batches)

@jax.tree_util.register_pytree_node_class
class InfoLOOB(object):
    """InfoLOOB-style contrastive CLIP guidance loss.

    Cutout embeddings are averaged per image and passed through ``norm1``. The
    targets are also passed through ``norm1``. A similarity matrix is then
    formed between images and targets. Diagonal (matched) entries are
    rewarded, and ``lm`` weights a log-sum-exp term over the off-diagonal
    entries. The loss is applied in both directions (images->targets and
    targets->images).

    Returns the gradient of the scaled loss w.r.t. ``x_in``, averaged over
    ``cut_batches`` independent batches of cutouts.
    """
    def __init__(self, text_embed, clip_guidance_scale, perceptor, make_cutouts, lm, cut_batches):
        self.text_embed = text_embed
        self.clip_guidance_scale = clip_guidance_scale
        self.perceptor = perceptor
        self.make_cutouts = make_cutouts
        self.lm = lm
        self.cut_batches = cut_batches

    def __call__(self, x_in, key):
        n = x_in.shape[0]

        def main_clip_loss(x_in, key):
            cutouts = normalize(self.make_cutouts(x_in.add(1).div(2), key)).rearrange('k n c h w -> (k n) c h w')
            image_embeds = self.perceptor.embed_cutouts(cutouts).reshape([self.make_cutouts.cutn, n, -1])

            all_image_embeds = norm1(image_embeds.mean(0))
            all_text_embeds = norm1(self.text_embed)
            sim_matrix = jnp.einsum('nc,mc->nm', all_image_embeds, all_text_embeds)

            # NOTE(review): jnp.eye(xn) assumes a square similarity matrix
            # (one target embedding per image).
            xn = sim_matrix.shape[0]

            def loob(sim_matrix):
                # The diagonal is zeroed rather than masked out, so each row's
                # sum also includes exp(0) = 1 for the positive slot.
                diag = jnp.eye(xn) * sim_matrix
                off_diag = (1 - jnp.eye(xn)) * sim_matrix
                return -diag.sum() + self.lm * off_diag.exp().sum(axis=-1).log().sum()

            losses = loob(sim_matrix) + loob(sim_matrix.transpose())
            return losses.sum() * self.clip_guidance_scale

        num_cuts = self.cut_batches
        keys = jnp.stack(jax.random.split(key, num_cuts))

        def accumulate(total, key):
            return total + jax.grad(main_clip_loss)(x_in, key), None

        total, _ = jax.lax.scan(accumulate, jnp.zeros_like(x_in), keys)
        return total / num_cuts

    def tree_flatten(self):
        # aux_data must be hashable: keep the static cut_batches in a tuple.
        return [self.text_embed, self.clip_guidance_scale, self.perceptor, self.make_cutouts, self.lm], (self.cut_batches,)

    @classmethod
    def tree_unflatten(cls, static, dynamic):
        return cls(*dynamic, *static)

@jax.tree_util.register_pytree_node_class
class CondTV(object):
    """Multiscale total-variation loss; tries to smooth out the image.

    Returns the gradient w.r.t. ``x_in`` of ``tv_loss(...).sum() * tv_scale``.
    The loss is evaluated at the input resolution and again after
    ``downscale2d`` by factors 2 and 4, and the three gradients are summed.
    """
    def __init__(self, tv_scale):
        self.tv_scale = tv_scale

    def __call__(self, x_in, key):
        # `key` is unused; it is accepted to match the cond(x_in, key)
        # interface that MainCondFn calls.
        def sum_tv_loss(x_in, f=None):
            if f is not None:
                x_in = downscale2d(x_in, f)
            return tv_loss(x_in).sum() * self.tv_scale

        # Gradients at full, 1/2 and 1/4 of the input resolution.
        grad_full = jax.grad(sum_tv_loss)(x_in)
        grad_half = jax.grad(partial(sum_tv_loss, f=2))(x_in)
        grad_quarter = jax.grad(partial(sum_tv_loss, f=4))(x_in)
        return grad_full + grad_half + grad_quarter

    def tree_flatten(self):
        # aux_data must be hashable: an empty tuple, not an empty list.
        return [self.tv_scale], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        return CondTV(*dynamic)

@jax.tree_util.register_pytree_node_class
class CondSat(object):
    """Saturation loss; tries to prevent the image from going out of [-1, 1].

    The loss is the mean absolute amount by which ``x_in`` lies outside
    [-1, 1]. Returns ``sat_scale`` times its gradient w.r.t. ``x_in``.
    """
    def __init__(self, sat_scale):
        self.sat_scale = sat_scale

    def __call__(self, x_in, key):
        # `key` is unused; kept for the cond(x_in, key) interface.
        def saturation_loss(x_in):
            return jnp.abs(x_in - x_in.clamp(minval=-1,maxval=1)).mean()
        return self.sat_scale * jax.grad(saturation_loss)(x_in)

    def tree_flatten(self):
        # aux_data must be hashable: an empty tuple, not an empty list.
        return [self.sat_scale], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        return CondSat(*dynamic)


@jax.tree_util.register_pytree_node_class
class CondMSE(object):
    """MSE loss; pulls the output toward a target image.

    Returns ``mse_scale`` times the gradient w.r.t. ``x_in`` of
    ``mean((x_in - target) ** 2)``.
    """
    def __init__(self, target, mse_scale):
        self.target = target
        self.mse_scale = mse_scale

    def __call__(self, x_in, key):
        # `key` is unused; kept for the cond(x_in, key) interface.
        def mse_loss(x_in):
            return (x_in - self.target).square().mean()
        return self.mse_scale * jax.grad(mse_loss)(x_in)

    def tree_flatten(self):
        # aux_data must be hashable: an empty tuple, not an empty list.
        return [self.target, self.mse_scale], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        return CondMSE(*dynamic)

@jax.tree_util.register_pytree_node_class
class MaskedMSE(object):
    """Masked MSE loss; pulls the output toward a target image where ``mask`` is set.

    The squared error (optionally taken on ``grey(x_in - target)`` when
    ``grey`` is True) is weighted elementwise by ``mask`` and averaged.
    Returns ``mse_scale`` times its gradient w.r.t. ``x_in``.

    Note: inside ``__call__`` the bare name ``grey`` is the module-level
    ``grey()`` helper defined elsewhere in the notebook, not ``self.grey``.
    """
    def __init__(self, target, mse_scale, mask, grey=False):
        self.target = target
        self.mse_scale = mse_scale
        self.mask = mask
        self.grey = grey

    def __call__(self, x_in, key):
        # `key` is unused; kept for the cond(x_in, key) interface.
        def mse_loss(x_in):
            if self.grey:
                return (self.mask * grey(x_in - self.target).square()).mean()
            return (self.mask * (x_in - self.target).square()).mean()
        return self.mse_scale * jax.grad(mse_loss)(x_in)

    def tree_flatten(self):
        # aux_data must be hashable: keep the static grey flag in a tuple.
        return [self.target, self.mse_scale, self.mask], (self.grey,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MaskedMSE(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class MainCondFn(object):
    """Build the main cond_fn from a denoising model and a list of loss conditions.

    The model produces a denoised estimate of ``x``. Each condition returns
    the gradient of its loss w.r.t. that estimate. Those gradients are summed
    and pulled back through the model to ``x`` with a VJP, and the result is
    negated.

    Args:
        diffusion: model called as ``diffusion(x, t, key)``, returning an
            object with a ``.pred`` attribute.
        conditions: callables ``cond(x_in, key) -> grad``; ``None`` entries
            are dropped, so disabled losses can be written inline.
        use: ``'pred'`` to feed the conditions the model's prediction, or
            ``'x_in'`` to feed them ``pred * sigmas + x * alphas``.

    Raises:
        ValueError: if ``use`` is neither ``'pred'`` nor ``'x_in'``.
    """
    def __init__(self, diffusion, conditions, use='pred'):
        if use not in ('pred', 'x_in'):
            # Previously an unknown value made denoise() silently return None.
            raise ValueError(f"use must be 'pred' or 'x_in', got {use!r}")
        self.diffusion = diffusion
        self.conditions = [c for c in conditions if c is not None]
        self.use = use

    @jax.jit
    def __call__(self, key, x, t):
        rng = PRNG(key)
        alphas, sigmas = get_ddpm_alphas_sigmas(t)

        def denoise(key, x):
            pred = self.diffusion(x, t, key).pred
            if self.use == 'pred':
                return pred
            # use == 'x_in' (validated in __init__)
            return pred * sigmas + x * alphas

        (x_in, backward) = jax.vjp(partial(denoise, rng.split()), x)

        total = jnp.zeros_like(x_in)
        for cond in self.conditions:
            total += cond(x_in, rng.split())
        final_grad = -backward(total)[0]

        # Rescale each sample's gradient so its RMS (over C, H, W) is at most 0.2.
        magnitude = final_grad.square().mean(axis=(1,2,3), keepdims=True).sqrt()
        final_grad = final_grad * jnp.where(magnitude > 0.2, 0.2 / magnitude, 1.0)
        return final_grad

    def tree_flatten(self):
        # aux_data must be hashable: keep the static `use` string in a tuple.
        return [self.diffusion, self.conditions], (self.use,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MainCondFn(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class ClassifierFn(object):
    """Classifier-guidance cond_fn.

    Called as ``fn(key, x, t)``. The DDPM timestep ``t`` is converted to
    alpha/sigma and then to the classifier's (cosine) timestep with
    ``alpha_sigma_to_t``, broadcast to the batch. Returns
    ``-grad_x(guidance_scale * model.score(cx, x, cosine_t, **kwargs))``.

    Args:
        model: classifier with a ``score(cx, x, t, **kwargs)`` method (static).
        params: its parameters.
        guidance_scale: multiplier on the score.
        **kwargs: extra arguments forwarded to ``model.score``.
    """
    def __init__(self, model, params, guidance_scale, **kwargs):
        self.model = model
        self.params = params
        self.guidance_scale = guidance_scale
        self.kwargs = kwargs

    @jax.jit
    def __call__(self, key, x, t):
        n = x.shape[0]
        alpha, sigma = get_ddpm_alphas_sigmas(t)
        cosine_t = alpha_sigma_to_t(alpha, sigma).broadcast_to([n])
        def fwd(x):
            # eval_mode_() presumably disables training-only behavior (e.g. dropout).
            cx = Context(self.params, key).eval_mode_()
            return self.guidance_scale * self.model.score(cx, x, cosine_t, **self.kwargs)
        return -jax.grad(fwd)(x)

    def tree_flatten(self):
        # aux_data must be hashable: keep the static model in a tuple.
        return [self.params, self.guidance_scale, self.kwargs], (self.model,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        [params, guidance_scale, kwargs] = dynamic
        [model] = static
        return ClassifierFn(model, params, guidance_scale, **kwargs)


@jax.tree_util.register_pytree_node_class
class CondFns(object):
    """Sum several cond_fns of the form ``fn(key, x, t)``.

    Each condition gets its own key split from ``key``, in order, and the
    returned gradients are added together.
    """
    def __init__(self, *conditions):
        self.conditions = conditions

    def __call__(self, key, x, t):
        rng = PRNG(key)
        total = jnp.zeros_like(x)
        for cond in self.conditions:
            total += cond(rng.split(), x, t)
        return total

    def tree_flatten(self):
        # aux_data must be hashable: an empty tuple, not an empty list.
        return [self.conditions], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        [conditions] = dynamic
        # Previously constructed the undefined `MixCondFn`, so any unflatten
        # (e.g. passing this object through a JAX transform) raised NameError.
        return CondFns(*conditions)

In [None]:
def sample_step(key, x, t1, t2, diffusion, cond_fn, eta):
    """Take one guided sampling step of ``x`` from time ``t1`` to ``t2``.

    The model's eps estimate is shifted by the conditioning score. The
    denoised image is re-derived from it, and the two are recombined for
    ``t2`` with the amount of fresh noise chosen by ``eta``:

    * ``eta >= 0`` interpolates between DDIM (0) and DDPM (1).
    * ``eta < 0`` interpolates between DDIM and q_sample(pred) (extreme noise).

    Returns:
        (x_next, pred0), where ``pred0`` is the model's own (unguided)
        denoised prediction at ``t1``.
    """
    rng = PRNG(key)
    # Keys are drawn from rng in a fixed order: model, conditioning, noise.
    alpha1, sigma1 = get_ddpm_alphas_sigmas(t1)
    alpha2, sigma2 = get_ddpm_alphas_sigmas(t2)

    model_out = diffusion(x, t1, rng.split())
    pred0 = model_out.pred

    # Guided noise estimate, then the denoised image implied by it.
    guided_eps = model_out.eps - sigma1 * cond_fn(rng.split(), x, t1)
    guided_pred = (x - guided_eps * sigma1) / alpha1

    ddpm_sigma = (sigma2**2 / sigma1**2).sqrt() * (1 - alpha1**2 / alpha2**2).sqrt()
    if_normal = eta * ddpm_sigma
    if_extreme = -eta * sigma2
    noise_sigma = jnp.where(eta >= 0.0, if_normal, if_extreme)
    eps_sigma = (sigma2**2 - noise_sigma**2).sqrt()

    # Deterministic part for t2, plus the requested amount of fresh noise.
    x_next = guided_pred * alpha2 + guided_eps * eps_sigma
    x_next = x_next + jax.random.normal(rng.split(), x_next.shape) * noise_sigma
    return x_next, pred0

Load CLIP Models

In [None]:
# CLIP input resolution in pixels (passed to MakeCutouts in the run configuration).
clip_size = 224
# Per-channel RGB mean/std used to normalize cutouts before they go into CLIP.
normalize = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                      std=[0.26862954, 0.26130258, 0.27577711])

# clip_jax.load returns (image_fn, text_fn, params, preprocess). The raw
# image_fn/params are also kept under vit*_embed / vit*_params so the init
# image can be embedded directly later. The second load rebinds image_fn,
# text_fn and clip_params, so use the vit32_* / vit16_* names afterwards.
image_fn, text_fn, clip_params, vit32_preprocess = clip_jax.load('ViT-B/32')
vit32 = Perceptor(image_fn, text_fn, clip_params)
vit32_embed = image_fn
vit32_params = clip_params

image_fn, text_fn, clip_params, vit16_preprocess = clip_jax.load('ViT-B/16')
vit16 = Perceptor(image_fn, text_fn, clip_params)
vit16_embed = image_fn
vit16_params = clip_params

# Settings and Run

## Configuration for the run

By default, it looks for the input image at `/content/init.png` (`init_image = 'init.png'`, resolved relative to Colab's working directory, `/content`).

With the default settings, a run takes about 45 minutes on a P100 GPU in Google Colab. To make it faster (at some cost in quality), lower `steps` or the dimensions in `image_size`.

In [None]:
seed = None # if None, uses the current time in seconds.
# Output size as (width, height): the init image is resized to this with PIL,
# and run() samples noise of shape [batch, 3, height, width].
image_size = (1024, 1024)
batch_size = 1
n_batches = 1

# not used for image generation, just the filename of the output image
all_title = "enhancer"
title = [all_title] * batch_size

clip_guidance_scale = 1000 # Note: with two perceptors, effective guidance scale is ~2x because they are added together.
tv_scale = 150  # Smooths out the image (CondTV is skipped when 0)
sat_scale = 600 # Tries to prevent pixel values from going out of range (CondSat is skipped when 0)
cutn = 8        # Cutouts per batch; effective cutn is cut_batches * this
cut_pow = 1.0   # Affects the size of cutouts. Larger cut_pow -> smaller cutouts (down to the min of 224x224)
cut_batches = 4
make_cutouts = MakeCutouts(clip_size, cutn, cut_pow=cut_pow, p_mixgrey=0.0)

steps = 1000     # Number of steps for sampling. Generally, more = better.
eta = 0.0       # 0.0: DDIM | 1.0: DDPM | -1.0: Extreme noise (q_sample)
init_image = 'init.png'      # Diffusion will start with a mixture of this image with noise.
starting_noise = 0.25  # Between 0 and 1 (first value of `schedule`). When using an init image, generally 0.5-0.8 is good. Lower starting noise makes the result look more like the init.
init_weight_mse = 0    # Weight of an MSE loss between the output and the init; higher makes the result look more like the init (should be between 0 and width*height*3). Disabled when 0.
                       # (LPIPS... will be added later)


# Embed the init image with both CLIP models; config() uses these embeddings
# as the CLIP guidance targets.
# NOTE(review): this runs before the `init_image is not None` check below, so
# in practice this cell requires an init image.
init_pil = Image.open(init_image)
vit32_embedding = vit32_embed(vit32_params, np.expand_dims(vit32_preprocess(init_pil), 0))
vit16_embedding = vit16_embed(vit16_params, np.expand_dims(vit16_preprocess(init_pil), 0))


# OpenAI used T=1000 to 0. We've just rescaled to between 1 and 0.
# Only the tail of that range is sampled: steps+1 evenly spaced times from
# starting_noise down to 0.
schedule = jnp.linspace(starting_noise, 0, steps+1)

# Init image as a [1, 3, H, W] array, rescaled from [0, 1] to [-1, 1].
if init_image is not None:
    init_array = Image.open(fetch(init_image)).convert('RGB')
    init_array = init_array.resize(image_size, Image.LANCZOS)
    init_array = jnp.array(TF.to_tensor(init_array)).unsqueeze(0).mul(2).sub(1)
else:
    init_array = None

def config():
    """Build the diffusion model and the conditioning function used by run().

    Returns:
        (diffusion, cond_fn): ``diffusion`` is called by sample_step as
        ``diffusion(x, t, key)``; ``cond_fn`` as ``cond_fn(key, x, t)``.
    """
    # Configure models and load parameters onto gpu.
    # We do this in a function to avoid leaking gpu memory.

    # -- Openai with anti-jpeg --
    # Calling each LazyParams object (e.g. openai_512_params()) yields the
    # actual parameters.
    openai = OpenaiModel(openai_512_model, openai_512_params())
    secondary2 = CosineModel(secondary2_model, secondary2_params())
    jpeg_0 = CosineModel(jpeg_model, jpeg_params(), cond=jnp.array([0]*batch_size)) # Clean class
    jpeg_1 = CosineModel(jpeg_model, jpeg_params(), cond=jnp.array([2]*batch_size)) # Noisy class

    # Classifier guidance toward the clean class.
    jpeg_classifier_fn = ClassifierFn(classifier_model, classifier_params(), 
                                      guidance_scale=10000.0, # will generally depend on image size
                                      cond=jnp.array([0]*batch_size), # Clean class
                                      flood_level=0.7, # Prevent over-optimization
                                      blur_size=3.0)

    # Models combined with weights +1, +1, -1 (presumably a weighted sum of
    # outputs, i.e. openai + (clean - noisy) anti-jpeg; confirm in LerpModels).
    diffusion = LerpModels([(openai, 1.0),
                            (jpeg_0, 1.0),
                            (jpeg_1, -1.0)])
    # The secondary model, not the main diffusion mix, supplies the denoised
    # prediction that the guidance losses below are evaluated on.
    cond_model = secondary2

    # For the enhancer, the CLIP targets are the CLIP embeddings of the init
    # image itself (vit32_embedding / vit16_embedding), not text embeddings.
    # Disabled losses evaluate to None and are dropped by MainCondFn.
    cond_fn = CondFns(MainCondFn(cond_model, [
                        CondCLIP(vit32_embedding, clip_guidance_scale, vit32, make_cutouts, cut_batches),
                        CondCLIP(vit16_embedding, clip_guidance_scale, vit16, make_cutouts, cut_batches),
                        CondTV(tv_scale) if tv_scale > 0 else None,
                        CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
                        CondSat(sat_scale) if sat_scale > 0 else None,
                        ], use='pred'),
                      jpeg_classifier_fn,
                      )

    # # -- v diffusion models --
    # # Uncomment one of the four below.
    # diffusion = KatModel(wikiart_256_model, wikiart_256_params())
    # diffusion = KatModel(wikiart_128_model, wikiart_128_params())
    # diffusion = KatModel(danbooru_128_model, danbooru_128_params())
    # diffusion = KatModel(imagenet_128_model, imagenet_128_params())
    # cond_model = diffusion

    # cond_fn = MainCondFn(cond_model, [
    #             CondCLIP(vit32.embed_texts(title), clip_guidance_scale, vit32, make_cutouts, cut_batches),
    #             CondCLIP(vit16.embed_texts(title), clip_guidance_scale, vit16, make_cutouts, cut_batches),
    #             CondTV(tv_scale) if tv_scale > 0 else None,
    #             CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
    #             CondSat(sat_scale) if sat_scale > 0 else None,
    #             ], use='pred')
    
    # # -- pixel art model --
    # diffusion = CosineModel(pixelartv4_model, pixelartv4_params())
    # cond_model = diffusion
    # cond_fn = MainCondFn(cond_model, [
    #             CondCLIP(vit32.embed_texts(title), clip_guidance_scale, vit32, MakeCutoutsPixelated(make_cutouts), cut_batches),
    #             CondCLIP(vit16.embed_texts(title), clip_guidance_scale, vit16, MakeCutoutsPixelated(make_cutouts), cut_batches),
    #             CondTV(tv_scale) if tv_scale > 0 else None,
    #             CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
    #             CondSat(sat_scale) if sat_scale > 0 else None,
    #             ], use='pred')

    return diffusion, cond_fn

# Built once at module level; run() reads these globals.
diffusion, cond_fn = config()

In [None]:
# Actually do the run

def sanitize(title):
    """Return ``title`` made safe as a filename component.

    Keeps at most the first 100 characters and turns both path separators
    ('/' and '\\') into underscores.
    """
    separators_to_underscore = str.maketrans({'/': '_', '\\': '_'})
    return title[:100].translate(separators_to_underscore)

@torch.no_grad()
def run():
    """Sample ``n_batches`` batches of images, previewing progress and saving results.

    Configuration is read from notebook globals: seed, schedule, batch_size,
    image_size, init_array, diffusion, cond_fn, eta, n_batches, title,
    all_title and save_location. Each batch's grid and individual images are
    written under ``samples/`` and under ``save_location``.

    NOTE(review): save_location is not defined in this part of the notebook.
    Confirm an earlier cell sets it.
    """
    if seed is None:
        local_seed = int(time.time())
    else:
        local_seed = seed
    print(f'Starting run with seed {local_seed}...')
    rng = PRNG(jax.random.PRNGKey(local_seed))

    def to_images(arr):
        # Map model space [-1, 1] to [0, 1] and hand torchvision a torch tensor.
        return torch.tensor(np.array(arr.add(1).div(2).clamp(0, 1)))

    for _ in range(n_batches):
        timestring = time.strftime('%Y%m%d%H%M%S')

        ts = schedule
        alphas, sigmas = get_ddpm_alphas_sigmas(ts)

        # Start from noise at the first timestep, mixed with the init image if given.
        x = sigmas[0] * jax.random.normal(rng.split(), [batch_size, 3, image_size[1], image_size[0]])
        if init_array is not None:
            x = x + alphas[0] * init_array

        # Main loop. pred/images are reset per batch, so a skipped first step
        # can neither hit an unbound `pred` nor reuse the previous batch's one.
        local_steps = schedule.shape[0] - 1
        pred = None
        images = None
        for j in tqdm(range(local_steps)):
            if ts[j] != ts[j+1]:
                # Skip steps where the ts are the same, to make it easier to
                # make complicated schedules out of cat'ing linspaces.
                x, pred = sample_step(rng.split(), x, ts[j], ts[j+1], diffusion, cond_fn, eta)
            if pred is not None and (j % 50 == 0 or j == local_steps - 1):
                images = to_images(pred)
                display.display(TF.to_pil_image(utils.make_grid(images, 4).cpu()))

        if images is None:
            # No sampling step ran (empty or constant schedule): save the
            # starting point instead of failing on an unbound name.
            images = to_images(x)

        # Save samples. The grid is rendered once and written to both locations.
        grid = TF.to_pil_image(utils.make_grid(images, 4).cpu())
        grid_name = f'{timestring}_{sanitize(all_title)}.png'
        os.makedirs('samples/grid', exist_ok=True)
        os.makedirs(f'{save_location}/grid', exist_ok=True)
        grid.save(f'samples/grid/{grid_name}')
        grid.save(f'{save_location}/grid/{grid_name}')

        os.makedirs('samples/images', exist_ok=True)
        os.makedirs(f'{save_location}/images', exist_ok=True)
        for k in range(batch_size):
            this_title = sanitize(title[k])
            image_name = f'{timestring}_{k}_{this_title}.png'
            pil_image = TF.to_pil_image(images[k])
            pil_image.save(f'samples/images/{image_name}')
            pil_image.save(f'{save_location}/images/{image_name}')

# Run, printing the full traceback on failure, then fail the cell explicitly.
# Only Exception is caught, so KeyboardInterrupt/SystemExit propagate unchanged
# instead of being turned into an AssertionError.
try:
    run()
    success = True
except Exception:
    import traceback
    traceback.print_exc()
    success = False
assert success
