<a href="https://colab.research.google.com/github/tzzcl/JAX_CLIP/blob/main/JAX_CLIP_Guided_Diffusion_v2_7_(huemin_edit%2C_Apr_2022).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generates images from text prompts with CLIP guided diffusion.

[huemin](https://twitter.com/huemin_art) implementation based on [nshepperd's](https://twitter.com/nshepperd1) original notebook.
 - [nshepperd's JAX CLIP Guided Diffusion v2.7](https://colab.research.google.com/drive/1Z5kK1WXTkYoMAVN6FqkQg0Fa4bE5BnxG?usp=sharing)
 - [nshepperd's JAX CLIP Guided Diffusion v2.6](https://colab.research.google.com/drive/1fW_tPEX7iD3xZK3VBDQ_Y2WnfdSzpacM?usp=sharing)

**Note the UI is minimal, use *Show code* for more info**


In [None]:
#@title Changelog + Licensed under the MIT License{ display-mode: "form" }
change_log = """
Change Log (last updated 2022.03.26)
 - made UI changes
 - added QOL stuff for google colab
 - added clip guided cc12m1 model from jax 2.6
 - added prompt weights
 - added batch prompts, random prompts, and random settings
 - added video outputs
 - added support for image prompts
 - added modified cuts
 - added init skip steps
 - added batch inits
 - added simple symmetry 
 - added symmetry conds
 - added mean and variance conds
 - added database feature
 - added Prof RJ lerp models

Planned Changes (not implemented currently)
 - animations
 - animation key frames
 """
print(change_log)

#@title Licensed under the MIT License { display-mode: "form" }

# Copyright (c) 2021 Katherine Crowson; nshepperd; huemin

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# Show which GPU this runtime was assigned and how much of its memory is free.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

In [None]:
#@markdown Mount to Google Drive
# Colab form fields (the #@param comments render as form widgets):
#   googleDrive     - write outputs under outputFolderStatic on Drive
#   modelsOnDrive   - cache model checkpoints in /content/drive/MyDrive/models
#   initImageFolder - create/use an init/ folder for init images on Drive
#   promptFolder    - create/use a prompts/ folder of random-prompt CSVs on Drive
googleDrive = True #@param {type:"boolean"}
modelsOnDrive = True #@param {type:"boolean"}
initImageFolder = True #@param {type:"boolean"}
promptFolder = True #@param {type:"boolean"}

# Name of the folder under MyDrive that holds all outputs (empty -> "AI").
outputFolder = "AI" #@param {type:"string"}
# Per-notebook subfolder inside the output folder.
v2 = "nshepv2g/"
outputFolderStatic = outputFolder
if outputFolderStatic == "":
    outputFolderStatic = '/content/drive/MyDrive/AI/'
else:
    outputFolderStatic = '/content/drive/MyDrive/'+outputFolderStatic+'/'

# Drive is mounted whenever outputs or the model cache live there.
# NOTE(review): the init/ and prompts/ folders are also placed under
# /content/drive/MyDrive even when neither flag above mounts Drive.
if googleDrive or modelsOnDrive:
    from google.colab import drive
    drive.mount('/content/drive')

# Setup & Definitions

In [None]:
import os
# Install a jaxlib wheel matching the GPU. When `nvidia-smi | grep A100` finds a
# match (exit status 0), use the CUDA 11.1 build (jaxlib 0.1.72 / jax 0.2.25);
# otherwise use the CUDA 11 / cuDNN 8.0.5 build (jaxlib 0.1.75 / jax 0.3.0).
# Both wheels are cp37 builds.
if os.system("nvidia-smi | grep A100") == 0:
  !pip install -U https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.72+cuda111-cp37-none-manylinux2010_x86_64.whl "jax==0.2.25"
else:
  !pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.1.75%2Bcuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl "jax==0.3.0"
!pip install dm-haiku==0.0.5 cbor2 ftfy einops braceexpand
# NOTE(review): CLIP_JAX and v-diffusion-jax are cloned at their default branch
# (unpinned); only jax-guided-diffusion is pinned (tag v2.7). Re-running this
# cell prints "already exists" errors for the clones, which is harmless.
!git clone https://github.com/nshepperd/CLIP_JAX
!git clone https://github.com/nshepperd/jax-guided-diffusion -b v2.7
!git clone https://github.com/crowsonkb/v-diffusion-jax

In [None]:
import sys, os
sys.path.append('./CLIP_JAX')
sys.path.append('./jax-guided-diffusion')
sys.path.append('./v-diffusion-jax')
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

from PIL import Image
from braceexpand import braceexpand
from dataclasses import dataclass
from functools import partial
from subprocess import Popen, PIPE
import functools
import io
import math
import re
import requests
import time
import json
import pandas as pd
import shutil

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxtorch
from jaxtorch import PRNG, Context, Module, nn, init
from tqdm import tqdm

from IPython import display
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
import torch.utils.data
import torch

from diffusion_models.common import DiffusionOutput, Partial, make_partial, blur_fft, norm1, LerpModels
from diffusion_models.lazy import LazyParams
from diffusion_models.schedules import cosine, ddpm, ddpm2, spliced
from diffusion_models.perceptor import get_clip, clip_size, normalize

from diffusion_models.aesthetic import AestheticLoss, AestheticExpected
from diffusion_models.secondary import secondary1_wrap, secondary2_wrap
from diffusion_models.antijpeg import anti_jpeg_cfg, jpeg_classifier, jpeg_classifier_wrap, jpeg_classifier_params
from diffusion_models.pixelart import pixelartv4_wrap, pixelartv6_wrap
from diffusion_models.pixelartv7 import pixelartv7_ic_attn
from diffusion_models.cc12m_1 import cc12m_1_wrap, cc12m_1_cfg_wrap, cc12m_1_classifier_wrap
from diffusion_models.openai import openai_256, openai_512, openai_512_finetune
from diffusion_models.kat_models import danbooru_128, wikiart_128, wikiart_256, imagenet_128
from diffusion_models import sampler

In [None]:
# Query the devices visible to JAX's default backend once and report them.
devices = jax.devices()
n_devices = len(devices)
print('Using device:', devices)

In [None]:
# Drive location for caching model parameters
if modelsOnDrive:
    model_location = '/content/drive/MyDrive/models'
else:
    model_location = '/content/models'

os.makedirs(model_location, exist_ok=True)

# Drive location for inits
if initImageFolder:
  init_location = outputFolderStatic+v2+"init/"
  os.makedirs(init_location, exist_ok=True)
else:
  init_location = ''

# Drive location for prompt CSVs; always defined so later cells can reference it.
if promptFolder:
  prompt_location = outputFolderStatic+v2+"prompts/"
  os.makedirs(prompt_location, exist_ok=True)
else:
  prompt_location = ''

# Video output folder on Drive, next to init/ and prompts/. Built from
# outputFolderStatic (an absolute Drive path) rather than the bare form value
# outputFolder, which would give a relative path such as "AIvideos/".
videoOutputFolder = outputFolderStatic+v2+"videos/"

In [None]:
# Define necessary functions

def fetch(url_or_path):
    """Open a URL or local path for binary reading.

    http:// and https:// URLs are downloaded with requests (raising on HTTP
    errors) and returned as an in-memory BytesIO positioned at the start;
    anything else is opened as a local file in 'rb' mode.
    """
    source = str(url_or_path)
    if not source.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)

def fetch_model(url_or_path):
    """Return a local path for a model checkpoint, downloading it on first use.

    Files are cached in `model_location` under their URL basename. A download
    goes to `model_location/tmp/` first and is moved into place only after curl
    reports success, so a failed or interrupted download never poisons the cache.

    Raises:
      RuntimeError: if curl exits with a non-zero status.
    """
    basename = os.path.basename(url_or_path)
    local_path = os.path.join(model_location, basename)
    if os.path.exists(local_path):
        return local_path
    tmp_dir = os.path.join(model_location, 'tmp')
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = os.path.join(tmp_dir, basename)
    # -f: fail on HTTP errors instead of saving the error page as the "model".
    # -L: follow redirects to the real file.
    status = Popen(['curl', '-fL', url_or_path, '-o', tmp_path]).wait()
    if status != 0:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise RuntimeError(f'curl exited with status {status} while fetching {url_or_path}')
    os.replace(tmp_path, local_path)
    return local_path

LazyParams.fetch = fetch_model

def grey(image):
    """Desaturate: replace every channel with the per-pixel mean over channels.

    Channels are axis -3, so `image` must have rank >= 3 ([..., c, h, w]); the
    result has the same shape as `image`.
    """
    *_, channels, height, width = image.shape  # also enforces rank >= 3
    channel_mean = jnp.mean(image, axis=-3, keepdims=True)
    return jnp.broadcast_to(channel_mean, image.shape)

def cutout_image(image, offsetx, offsety, size, output_size=224):
    """Cut a square of side `size` at (offsetx, offsety) out of a [c, h, w] image.

    The square is resampled (Lanczos-3) to output_size x output_size, giving
    an array of shape (c, output_size, output_size). x is the width axis and
    y the height axis.
    """
    (c, h, w) = image.shape
    target_shape = (c, output_size, output_size)
    scale = jnp.stack([output_size / size, output_size / size])
    # Translations are ordered (y, x) to match spatial_dims (h, w).
    translation = jnp.stack([-offsety * output_size / size, -offsetx * output_size / size])
    return jax.image.scale_and_translate(
        image,
        shape=target_shape,
        spatial_dims=(1, 2),
        scale=scale,
        translation=translation,
        method='lanczos3',
    )

def cutouts_images(image, offsetx, offsety, size, output_size=224):
    """Batched cutout_image: k cutouts for each of n images.

    Shapes: image [n c h w]; offsetx, offsety, size [k n]
    -> result [k n c output_size output_size].
    """
    single = partial(cutout_image, output_size=output_size)
    # Inner vmap pairs each image with its own offsets/size: [n c h w] [n] [n] [n] -> [n c h w]
    over_images = jax.vmap(single, in_axes=(0, 0, 0, 0), out_axes=0)
    # Outer vmap repeats that for each of the k cut sets, sharing the images.
    over_cuts = jax.vmap(over_images, in_axes=(None, 0, 0, 0), out_axes=0)
    return over_cuts(image, offsetx, offsety, size)

@jax.tree_util.register_pytree_node_class
class MakeCutouts(object):
    """Random square cutouts of a batch of images, with colour/flip augmentation.

    cutn // 2 "small" cuts are crops no larger than min(h, w). The remaining
    cuts are "large": side max(h, w) + border, centred on the image centre
    jittered by up to +/- border pixels. Random greyscale, optional partial-grey
    mixing and horizontal flips are then applied per cutout.

    NOTE(review): cutouts are resampled to cutouts_images' default
    output_size=224; cut_size only bounds the minimum crop size here. Confirm
    this matches the perceptor's expected input resolution.
    """
    def __init__(self, cut_size, cutn, cut_pow=1.0, p_grey=0.2, p_mixgrey=None, p_flip=0.5):
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.p_grey = p_grey
        self.p_mixgrey = p_mixgrey
        self.p_flip = p_flip

    def __call__(self, input, key):
        """Return cutouts of `input` ([n, c, h, w]) as [cutn, n, c, 224, 224].

        The order of rng.split() calls below is load-bearing: changing it
        changes every cutout produced for a given key.
        """
        [n, c, h, w] = input.shape
        rng = PRNG(key)

        small_cuts = self.cutn//2
        large_cuts = self.cutn - self.cutn//2

        max_size = min(h, w)
        min_size = min(h, w, self.cut_size)
        # u ** cut_pow with u ~ U[0, 1): cut_pow > 1 biases sizes toward min_size.
        cut_us = jax.random.uniform(rng.split(), shape=[small_cuts, n])**self.cut_pow
        sizes = (min_size + cut_us * (max_size - min_size)).clamp(min_size, max_size)
        offsets_x = jax.random.uniform(rng.split(), [small_cuts, n], minval=0, maxval=w - sizes)
        offsets_y = jax.random.uniform(rng.split(), [small_cuts, n], minval=0, maxval=h - sizes)
        cutouts = cutouts_images(input, offsets_x, offsets_y, sizes)

        # Large cuts: border is uniform in [B1, B1 + B2) pixels.
        B1 = 40
        B2 = 40
        lcut_us = jax.random.uniform(rng.split(), shape=[large_cuts, n])
        border = B1 + lcut_us * B2
        lsizes = (max(h,w) + border).astype(jnp.int32)
        loffsets_x = jax.random.uniform(rng.split(), [large_cuts, n], minval=w/2-lsizes/2-border, maxval=w/2-lsizes/2+border)
        loffsets_y = jax.random.uniform(rng.split(), [large_cuts, n], minval=h/2-lsizes/2-border, maxval=h/2-lsizes/2+border)
        lcutouts = cutouts_images(input, loffsets_x, loffsets_y, lsizes)

        # Stack small and large cuts along the cut axis: [cutn, n, c, 224, 224].
        cutouts = jnp.concatenate([cutouts, lcutouts], axis=0)

        greyed = grey(cutouts)

        # With probability p_mixgrey, blend each cutout toward greyscale by a random amount.
        if self.p_mixgrey is not None:
          grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          grey_rs = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(grey_us < self.p_mixgrey, grey_rs * greyed + (1 - grey_rs) * cutouts, cutouts)

        # With probability p_grey, replace the cutout with its fully greyscale version.
        if self.p_grey is not None:
          grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(grey_us < self.p_grey, greyed, cutouts)

        # With probability p_flip, mirror the cutout along the last (width) axis.
        if self.p_flip is not None:
          flip_us = jax.random.bernoulli(rng.split(), self.p_flip, [self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(flip_us, jnp.flip(cutouts, axis=-1), cutouts)

        return cutouts

    def tree_flatten(self):
        # cut_size and cutn determine array shapes, so they are static aux data;
        # the pow/probability fields are leaves (None leaves round-trip as None).
        return ([self.cut_pow, self.p_grey, self.p_mixgrey, self.p_flip], (self.cut_size, self.cutn))

    @staticmethod
    def tree_unflatten(static, dynamic):
        (cut_size, cutn) = static
        return MakeCutouts(cut_size, cutn, *dynamic)

@jax.tree_util.register_pytree_node_class
class MakeCutouts_huemin(object):
    """MakeCutouts variant with a randomised cut power and border size.

    Like MakeCutouts, but on every call:
      - the small-cut size exponent is cut_pow * g with g ~ Gamma(1, 1), and
      - the large-cut border parameters B1, B2 are each ~ Gamma(1, max_size / 4).

    These draws use jax.random with the supplied key. Under jax.jit (as used by
    MainCondFn), numpy's global RNG would run only once, at trace time, and
    its values would be frozen into the compiled function. jax.random draws
    instead vary with the key on every call and are reproducible from the seed.

    NOTE(review): cutouts are resampled to cutouts_images' default
    output_size=224; cut_size only bounds the minimum crop size here.
    """
    def __init__(self, cut_size, cutn, cut_pow=1.0, p_grey=0.2, p_mixgrey=None, p_flip=0.5):
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.p_grey = p_grey
        self.p_mixgrey = p_mixgrey
        self.p_flip = p_flip

    def __call__(self, input, key):
        """Return cutouts of `input` ([n, c, h, w]) as [cutn, n, c, 224, 224]."""
        [n, c, h, w] = input.shape
        rng = PRNG(key)

        small_cuts = self.cutn//2
        large_cuts = self.cutn - self.cutn//2

        max_size = min(h, w)
        min_size = min(h, w, self.cut_size)

        # Gamma(1, 1) == Exponential(1), the same distribution as np.random.gamma(1, 1).
        cut_power = jax.random.gamma(rng.split(), 1.0) * self.cut_pow

        cut_us = jax.random.uniform(rng.split(), shape=[small_cuts, n])**cut_power
        sizes = (min_size + cut_us * (max_size - min_size)).clamp(min_size, max_size)
        offsets_x = jax.random.uniform(rng.split(), [small_cuts, n], minval=0, maxval=w - sizes)
        offsets_y = jax.random.uniform(rng.split(), [small_cuts, n], minval=0, maxval=h - sizes)
        cutouts = cutouts_images(input, offsets_x, offsets_y, sizes)

        # Gamma(1, scale=max_size/4) border parameters for the large cuts.
        B1 = jax.random.gamma(rng.split(), 1.0) * (max_size / 4)
        B2 = jax.random.gamma(rng.split(), 1.0) * (max_size / 4)

        lcut_us = jax.random.uniform(rng.split(), shape=[large_cuts, n])
        border = B1 + lcut_us * B2
        lsizes = (max(h,w) + border).astype(jnp.int32)
        loffsets_x = jax.random.uniform(rng.split(), [large_cuts, n], minval=w/2-lsizes/2-border, maxval=w/2-lsizes/2+border)
        loffsets_y = jax.random.uniform(rng.split(), [large_cuts, n], minval=h/2-lsizes/2-border, maxval=h/2-lsizes/2+border)
        lcutouts = cutouts_images(input, loffsets_x, loffsets_y, lsizes)

        cutouts = jnp.concatenate([cutouts, lcutouts], axis=0)

        greyed = grey(cutouts)

        # Partial greyscale blend with probability p_mixgrey.
        if self.p_mixgrey is not None:
          grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          grey_rs = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(grey_us < self.p_mixgrey, grey_rs * greyed + (1 - grey_rs) * cutouts, cutouts)

        # Full greyscale with probability p_grey.
        if self.p_grey is not None:
          grey_us = jax.random.uniform(rng.split(), shape=[self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(grey_us < self.p_grey, greyed, cutouts)

        # Horizontal flip with probability p_flip.
        if self.p_flip is not None:
          flip_us = jax.random.bernoulli(rng.split(), self.p_flip, [self.cutn, n, 1, 1, 1])
          cutouts = jnp.where(flip_us, jnp.flip(cutouts, axis=-1), cutouts)

        return cutouts

    def tree_flatten(self):
        # cut_size and cutn fix array shapes -> static aux data; the rest are leaves.
        return ([self.cut_pow, self.p_grey, self.p_mixgrey, self.p_flip], (self.cut_size, self.cutn))

    @staticmethod
    def tree_unflatten(static, dynamic):
        (cut_size, cutn) = static
        return MakeCutouts_huemin(cut_size, cutn, *dynamic)

@jax.tree_util.register_pytree_node_class
class MakeCutoutsPixelated(object):
    """Wrap a cutout maker so it sees a nearest-neighbour-upscaled image.

    The input [n, c, h, w] is enlarged by `factor` in both spatial dimensions
    with nearest-neighbour resampling (keeping pixels hard-edged) before being
    passed to `make_cutouts`. `cutn` mirrors the wrapped maker's value.
    """
    def __init__(self, make_cutouts, factor=4):
        self.make_cutouts = make_cutouts
        self.factor = factor
        self.cutn = make_cutouts.cutn

    def __call__(self, input, key):
        [n, c, h, w] = input.shape
        input = jax.image.resize(input, [n, c, h*self.factor, w * self.factor], method='nearest')
        return self.make_cutouts(input, key)

    def tree_flatten(self):
        # JAX requires pytree aux data to be hashable, so `factor` goes in a tuple.
        return ([self.make_cutouts], (self.factor,))

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MakeCutoutsPixelated(*dynamic, *static)

def spherical_dist_loss(x, y):
    """Squared spherical distance between x and y along the last axis.

    Both inputs go through norm1 (presumably unit normalisation). The chord
    length d between them is then turned into 2 * arcsin(d / 2) ** 2.
    """
    diff = norm1(x) - norm1(y)
    chord = diff.square().sum(axis=-1).sqrt()
    return chord.div(2).arcsin().square().mul(2)


In [None]:
# Define combinators.

# These (ab)use the jax pytree registration system to define parameterised
# objects for doing various things, which are compatible with jax.jit.

# For jit compatibility an object needs to act as a pytree, which means implementing two methods:
#  - tree_flatten(self): returns two lists of the object's fields:
#       1. 'dynamic' parameters: things which can be jax tensors, or other pytrees
#       2. 'static' parameters: arbitrary python objects, will trigger recompilation when changed
#  - tree_unflatten(static, dynamic): reconstitutes the object from its parts

# With these tricks, you can simply define your cond_fn as an object, as is done
# below, and pass it into the jitted sample step as a regular argument. JAX will
# handle recompiling the jitted code whenever a control-flow affecting parameter
# is changed (such as cut_batches).

# A wrapper that causes the diffusion model to generate tileable images, by
# randomly shifting the image with wrap around.

def xyroll(x, shifts):
  """Roll each batch element along its axes 1 and 2 by its own pair of shifts.

  For element i, axis 1 of x[i] rolls by shifts[i, 0] and axis 2 by
  shifts[i, 1], wrapping around. The callers below pass [n, c, h, w] images,
  so this is a per-image (dy, dx) wrap-around shift.
  """
  def roll_one(image, shift):
    return jnp.roll(image, shift, axis=[1, 2])
  return jax.vmap(roll_one, in_axes=(0, 0))(x, shifts)

@make_partial
def TilingModel(model, x, cosine_t, key):
  """Run `model` on a randomly wrap-shifted copy of x, then undo the shift.

  Each example gets its own (dy, dx) shift drawn from [-50, 50). Because the
  shift wraps around, the model is pushed toward tileable images. Every leaf
  of the model output is shifted back.
  """
  rng = PRNG(key)
  n, _c, _h, _w = x.shape
  shift = jax.random.randint(rng.split(), [n, 2], -50, 50)
  out = model(xyroll(x, shift), cosine_t, rng.split())
  return jax.tree_util.tree_map(lambda val: xyroll(val, -shift), out)

@make_partial
def PanoramaModel(model, x, cosine_t, key):
  """Like TilingModel, but only wraps horizontally (for seamless panoramas).

  The vertical shift is drawn from [0, 1), i.e. always 0; the horizontal
  shift is drawn from [0, w). The model output is shifted back leaf by leaf.
  """
  rng = PRNG(key)
  n, _c, _h, w = x.shape
  shift = jax.random.randint(rng.split(), [n, 2], 0, [1, w])
  out = model(xyroll(x, shift), cosine_t, rng.split())
  return jax.tree_util.tree_map(lambda val: xyroll(val, -shift), out)


# Models & Parameters

In [None]:
# Pixel art model
# There are many checkpoints supported with this model, so maybe better to provide choice in the notebook
# Select a checkpoint by leaving exactly one URL uncommented in each list.
# key='params_ema' presumably selects the EMA weights stored under that key in
# the checkpoint; LazyParams presumably defers the download/load until first use.
pixelartv4_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_34.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_63.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v4_150.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_50.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_65.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_97.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v5_173.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_344.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_432.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_600.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_700.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_800.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_1000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_2000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-fgood_3000.pt'
    , key='params_ema'
)

pixelartv6_params = LazyParams.pt(
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-1000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-2000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-3000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-4000.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-900.pt'
    # 'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-1300.pt'
    'https://set.zlkj.in/models/diffusion/pixelart/pixelart-v6-aug-3000.pt'
    , key='params_ema'
)

In [None]:
# cc12m_1
# Checkpoint for the cc12m_1 model; LazyParams presumably fetches it (via
# fetch_model, which caches in model_location) only when first needed.
cc12m_1_params = LazyParams.pt('https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1.pth')

In [None]:
# Losses and cond fn.

def filternone(xs):
  """Return a new list of the items of `xs` that are not None, in order."""
  return list(filter(lambda item: item is not None, xs))

@jax.tree_util.register_pytree_node_class
class CondCLIP(object):
    """Backward a loss function through CLIP.

    Calling the object with (x_in, key) returns the gradient w.r.t. x_in of the
    summed CLIP losses, averaged over `cut_batches` independent cutout batches.
    x_in is mapped from [-1, 1] to [0, 1] before cutting. Each loss receives
    image embeddings laid out as [k, n, c] (cutouts, batch, embedding dim).
    `None` losses are dropped.
    """
    def __init__(self, perceptor, make_cutouts, cut_batches, *losses):
        self.perceptor = perceptor
        self.make_cutouts = make_cutouts
        self.cut_batches = cut_batches
        self.losses = filternone(losses)

    def __call__(self, x_in, key):
        n = x_in.shape[0]
        def main_clip_loss(x_in, key):
            cutouts = normalize(self.make_cutouts(x_in.add(1).div(2), key)).rearrange('k n c h w -> (k n) c h w')
            image_embeds = self.perceptor.embed_cutouts(cutouts).rearrange('(k n) c -> k n c', n=n)
            return sum(loss_fn(image_embeds) for loss_fn in self.losses)
        num_cuts = self.cut_batches
        keys = jnp.stack(jax.random.split(key, num_cuts))
        # Accumulate gradients sequentially to bound memory; no per-step outputs are needed.
        def accumulate(total, key):
            return total + jax.grad(main_clip_loss)(x_in, key), None
        total, _ = jax.lax.scan(accumulate, jnp.zeros_like(x_in), keys)
        return total / num_cuts

    def tree_flatten(self):
        # cut_batches sets the scan length, so it is static; JAX requires aux data to be hashable.
        return [self.perceptor, self.make_cutouts, self.losses], (self.cut_batches,)

    @classmethod
    def tree_unflatten(cls, static, dynamic):
        [perceptor, make_cutouts, losses] = dynamic
        [cut_batches] = static
        return cls(perceptor, make_cutouts, cut_batches, *losses)

@make_partial
def SphericalDistLoss(text_embed, clip_guidance_scale, image_embeds):
    """Guidance-weighted spherical distance between image and text embeddings.

    Distances are averaged over axis 0 (the cutout axis in CondCLIP's
    [k, n, c] layout), multiplied by clip_guidance_scale and summed.
    """
    per_target = spherical_dist_loss(image_embeds, text_embed).mean(0)
    weighted = clip_guidance_scale * per_target
    return weighted.sum()

@make_partial
def InfoLOOB(text_embed, clip_guidance_scale, inv_tau, lm, image_embeds):
    """InfoLOOB contrastive loss between image and text embeddings.

    image_embeds is averaged over axis 0 (cutouts) and then normalised with
    norm1, as is text_embed. Row i of each is treated as a matching pair, so
    the number of text embeddings must equal the batch size. The result is
    scaled by mean(clip_guidance_scale) / inv_tau.

    NOTE: with a batch of 1 there are no negatives; logsumexp over an all -inf
    row gives -inf, so the loss is degenerate for n == 1.
    """
    all_image_embeds = norm1(image_embeds.mean(0))
    all_text_embeds = norm1(text_embed)
    sim_matrix = inv_tau * jnp.einsum('nc,mc->nm', all_image_embeds, all_text_embeds)
    xn = sim_matrix.shape[0]
    on_diag = jnp.eye(xn, dtype=bool)
    def loob(sim_matrix):
      positives = jnp.where(on_diag, sim_matrix, 0.0).sum()
      # Exclude positives from the logsumexp by masking with -inf via where;
      # eye * -inf would give 0 * -inf = NaN on every off-diagonal entry.
      negatives = jnp.where(on_diag, -jnp.inf, sim_matrix)
      return -positives + lm * jsp.special.logsumexp(negatives, axis=-1).sum()
    losses = loob(sim_matrix) + loob(sim_matrix.transpose())
    return losses.sum() * clip_guidance_scale.mean() / inv_tau

@make_partial
def CondTV(tv_scale, x_in, key):
    """Gradient of a multi-scale L2 total-variation penalty on x_in.

    The penalty is evaluated at full resolution and after cubic downscaling by
    2 and by 4. The three gradients (each weighted by tv_scale) are summed.
    """
    def tv_loss(img):
        """L2 total variation loss, as in Mahendran et al."""
        dx = img[..., :, 1:] - img[..., :, :-1]
        dy = img[..., 1:, :] - img[..., :-1, :]
        return dx.square().mean([1,2,3]) + dy.square().mean([1,2,3])

    def scaled_tv(img, factor=None):
        if factor is not None:
            [b, ch, h, w] = img.shape
            img = jax.image.resize(img, [b, ch, h//factor, w//factor], method='cubic')
        return tv_loss(img).sum() * tv_scale

    full_res, half_res, quarter_res = (
        jax.grad(partial(scaled_tv, factor=f))(x_in) for f in (None, 2, 4)
    )
    return full_res + half_res + quarter_res

@make_partial
def CondRange(range_scale, x_in, key):
    """Gradient of the mean distance by which x_in lies outside [-1, 1], times range_scale."""
    out_of_range = lambda y: jnp.abs(y - y.clamp(minval=-1,maxval=1)).mean()
    grad = jax.grad(out_of_range)(x_in)
    return range_scale * grad

@make_partial
def CondHorizontalSymmetry(horizontal_symmetry_scale, x_in, key):
    """Gradient of a left/right mirror-symmetry penalty, times horizontal_symmetry_scale.

    The penalty is the mean absolute difference between the left w//2 columns
    and the mirrored right w//2 columns of each [n, c, h, w] image. For odd
    widths the centre column is excluded, so both halves have equal width
    (previously the shapes mismatched and the call failed).
    """
    def horizontal_symmetry_loss(x_in):
        [n, c, h, w] = x_in.shape
        half = w // 2
        left = x_in[:, :, :, :half]
        right = x_in[:, :, :, w - half:]
        return jnp.abs(left - jnp.flip(right, -1)).mean()
    return horizontal_symmetry_scale * jax.grad(horizontal_symmetry_loss)(x_in)

@make_partial
def CondVerticalSymmetry(vertical_symmetry_scale, x_in, key):
    """Gradient of a top/bottom mirror-symmetry penalty, times vertical_symmetry_scale.

    The penalty is the mean absolute difference between the top h//2 rows and
    the mirrored bottom h//2 rows of each [n, c, h, w] image. For odd heights
    the centre row is excluded, so both halves have equal height.
    """
    def vertical_symmetry_loss(x_in):
        [n, c, h, w] = x_in.shape
        half = h // 2
        top = x_in[:, :, :half, :]
        bottom = x_in[:, :, h - half:, :]
        return jnp.abs(top - jnp.flip(bottom, -2)).mean()
    return vertical_symmetry_scale * jax.grad(vertical_symmetry_loss)(x_in)

@make_partial
def CondMean(mean_scale, x_in, key):
    """Gradient of mean(|x_in|), times mean_scale; pulls pixel values toward 0."""
    grad = jax.grad(lambda y: jnp.abs(y).mean())(x_in)
    return mean_scale * grad

@make_partial
def CondVar(var_scale, x_in, key):
    """Gradient of the variance over all elements of x_in, times var_scale."""
    grad = jax.grad(lambda y: y.var())(x_in)
    return var_scale * grad

@make_partial
def CondMSE(target, mse_scale, x_in, key):
    """Gradient of the mean squared error between x_in and `target`, times mse_scale."""
    grad = jax.grad(lambda y: (y - target).square().mean())(x_in)
    return mse_scale * grad

@jax.tree_util.register_pytree_node_class
class MaskedMSE(object):
    """Masked MSE toward a target image.

    Calling the object with (x_in, key) returns mse_scale * d/dx_in of
    mean(mask * (x_in - target) ** 2). When `grey` is True, the difference is
    first desaturated with the module-level grey() (channel mean), so only
    luminance is matched.
    """
    def __init__(self, target, mse_scale, mask, grey=False):
        self.target = target
        self.mse_scale = mse_scale
        self.mask = mask
        self.grey = grey

    def __call__(self, x_in, key):
        def mse_loss(x_in):
            diff = x_in - self.target
            if self.grey:
              return (self.mask * jnp.square(grey(diff))).mean()
            else:
              return (self.mask * jnp.square(diff)).mean()
        return self.mse_scale * jax.grad(mse_loss)(x_in)

    def tree_flatten(self):
        # `grey` picks a code path, so it is static; JAX requires aux data to be hashable.
        return [self.target, self.mse_scale, self.mask], (self.grey,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MaskedMSE(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class MainCondFn(object):
    """Main cond_fn: turns per-condition gradients into a guidance gradient for x.

    The diffusion model denoises x. Each condition returns the gradient of its
    loss w.r.t. the denoised image. These are summed, optionally blurred, and
    pulled back to x through the denoiser's VJP. The result is negated and
    rescaled so its per-sample RMS is at most 0.2.

    Args:
      diffusion: callable (x, cosine_t, key) whose result has a `.pred` field.
      conditions: callables (x_in, key) -> gradient; None entries are dropped.
      blur_amount: if not None, the summed gradient is blurred with radius
        mean(clamp(blur_amount * sigma / alpha, 0.05, 512)).
      use: 'pred' to differentiate w.r.t. the prediction itself, or 'x_in' to
        use pred * sigma + x * alpha.

    Raises:
      ValueError: if `use` is not 'pred' or 'x_in'.
    """
    _USES = ('pred', 'x_in')

    def __init__(self, diffusion, conditions, blur_amount=None, use='pred'):
        if use not in MainCondFn._USES:
            raise ValueError(f"use must be 'pred' or 'x_in', got {use!r}")
        self.diffusion = diffusion
        self.conditions = [c for c in conditions if c is not None]
        self.blur_amount = blur_amount
        self.use = use

    @jax.jit
    def __call__(self, x, cosine_t, key):
        if not self.conditions:
          return jnp.zeros_like(x)

        rng = PRNG(key)

        alphas, sigmas = cosine.to_alpha_sigma(cosine_t)

        def denoise(key, x):
            pred = self.diffusion(x, cosine_t, key).pred
            if self.use == 'pred':
                return pred
            # use == 'x_in' (validated in __init__)
            return pred * sigmas + x * alphas
        (x_in, backward) = jax.vjp(partial(denoise, rng.split()), x)

        total = jnp.zeros_like(x_in)
        for cond in self.conditions:
            total += cond(x_in, rng.split())
        if self.blur_amount is not None:
          blur_radius = (self.blur_amount * sigmas / alphas).clamp(0.05,512)
          total = blur_fft(total, blur_radius.mean())
        final_grad = -backward(total)[0]

        # Rescale samples whose RMS gradient exceeds 0.2 down to exactly 0.2.
        magnitude = final_grad.square().mean(axis=(1,2,3), keepdims=True).sqrt()
        final_grad = final_grad * jnp.where(magnitude > 0.2, 0.2 / magnitude, 1.0)
        return final_grad

    def tree_flatten(self):
        # `use` selects a code path, so it is static; JAX requires aux data to be hashable.
        return [self.diffusion, self.conditions, self.blur_amount], (self.use,)

    @staticmethod
    def tree_unflatten(static, dynamic):
        return MainCondFn(*dynamic, *static)


@jax.tree_util.register_pytree_node_class
class CondFns(object):
    """Sum of several conditions, each called as cond(x, t, key).

    Each condition receives its own key split from `key`, so stochastic
    conditions do not share random draws. This mirrors how MainCondFn
    distributes keys to its conditions.
    """
    def __init__(self, *conditions):
        self.conditions = conditions

    def __call__(self, x, t, key):
        rng = PRNG(key)
        total = jnp.zeros_like(x)
        for cond in self.conditions:
          total += cond(x, t, rng.split())
        return total

    def tree_flatten(self):
        # No static fields; an empty tuple keeps the aux data hashable.
        return [self.conditions], ()

    @staticmethod
    def tree_unflatten(static, dynamic):
        [conditions] = dynamic
        return CondFns(*conditions)

def clamp_score(score, max_magnitude=0.1):
  """Rescale each sample of `score` so its RMS magnitude is at most `max_magnitude`.

  The RMS is taken over axes (1, 2, 3), i.e. per sample for [n, c, h, w]
  input. Samples already within the limit are returned unchanged.

  Args:
    score: array with at least 4 dimensions.
    max_magnitude: RMS cap (default 0.1, the previously hard-coded value).
  """
  magnitude = jnp.sqrt(jnp.mean(jnp.square(score), axis=(1, 2, 3), keepdims=True))
  return score * jnp.where(magnitude > max_magnitude, max_magnitude / magnitude, 1.0)

@make_partial
def BlurRangeLoss(scale, x, cosine_t, key):
    """Score pushing a blurred estimate of the clean image back into [-1, 1].

    x is blurred with radius 2 * sigma / alpha and divided by alpha (floored at
    0.01). The squared amount by which that estimate lies outside [-1, 1] is
    summed, and its negated, scale-weighted gradient is passed through
    clamp_score.
    """
    def overshoot_loss(x):
        alpha, sigma = cosine.to_alpha_sigma(cosine_t)
        estimate = blur_fft(x, (sigma / alpha * 2)) / alpha.clamp(0.01)
        overshoot = estimate - estimate.clamp(minval=-1,maxval=1)
        return overshoot.square().sum()
    grad = jax.grad(overshoot_loss)(x)
    return clamp_score(-scale * grad)

In [None]:
def _split_prompt_weight(prompt):
  """Split an optional trailing ':weight' off one prompt; return (text, weight).

  The text after the last ':' counts as a weight only if it parses as a float,
  so URLs ('https://...') and prompts that merely contain a colon keep their
  text intact. The weight defaults to 1.
  """
  text, sep, suffix = prompt.rpartition(':')
  if sep:
    try:
      return text.strip(), float(suffix)
    except ValueError:
      pass
  return prompt.strip(), 1

def _load_prompt_image(source):
  """Open `source` (URL or local path) as a fully decoded PIL image and close the handle."""
  with fetch(source) as fd:
    image = Image.open(fd)
    image.load()  # decode now so the file handle can be closed safely
  return image

def _embed_prompt_image(clip, image):
  """CLIP-embed a PIL image; if embed_image returns more than 1-D, keep the last row."""
  embed = clip.embed_image(image)
  return embed if len(embed.shape) == 1 else embed[-1]

def process_prompt(clip,all_prompt):
  """Embed a '|'-separated multi-prompt into one normalised CLIP vector.

  Each part may end in ':weight' (a float, default 1):
    - http:// or https:// URLs are image prompts (download errors propagate);
    - otherwise, a readable local image path is an image prompt;
    - anything else is a text prompt.
  Empty parts are skipped. The weighted embeddings are summed and passed
  through norm1.

  Raises:
    ValueError: if `all_prompt` contains no non-empty prompt.
  """
  embeds = []
  for part in all_prompt.split("|"):
    part = part.strip()
    if not part:
      continue  # tolerate stray or trailing '|' separators
    text, weight = _split_prompt_weight(part)
    if text.startswith(('http://', 'https://')):
      embeds.append(weight * _embed_prompt_image(clip, _load_prompt_image(text)))
      continue
    try:
      image = _load_prompt_image(text)
    except (OSError, ValueError):
      # Not an openable image file (missing path, not an image, invalid path):
      # treat it as a text prompt. CLIP errors are not swallowed here.
      embeds.append(weight * clip.embed_text(text))
    else:
      embeds.append(weight * _embed_prompt_image(clip, image))
  if not embeds:
    raise ValueError(f"no usable prompts in {all_prompt!r}")
  return norm1(sum(embeds))

def process_prompts(clip, prompts):
  """Embed each prompt string with process_prompt and stack the results along a new axis 0."""
  embeddings = [process_prompt(clip, text) for text in prompts]
  return jnp.stack(embeddings)

def expand(xs, batch_size):
  """Cycle or truncate the sequence `xs` to exactly `batch_size` items.

  The result has the same sequence type as `xs`; an empty `xs` stays empty.
  """
  repeated = xs * batch_size
  return repeated[:batch_size]

In [None]:
def get_output_folder(outputFolder, choose_diffusion_model, batch_outputFolder, use_batch_outputFolder):
    """Return (and create) the folder to save images in.

    Without googleDrive, `outputFolder` is returned unchanged. Otherwise the
    path is outputFolderStatic/v2/<model>/<YYYY-MM>/, optionally followed by
    `batch_outputFolder` when use_batch_outputFolder is set and the name is
    non-empty. That directory is created if missing.
    """
    if not googleDrive:
        return outputFolder
    folder = outputFolderStatic + v2 + choose_diffusion_model + time.strftime('/%Y-%m/')
    if use_batch_outputFolder and batch_outputFolder != "":
        folder += batch_outputFolder + "/"
    os.makedirs(folder, exist_ok=True)
    return folder

def save_still_settings(local_seed,path,tag):
  """Write the current run's settings as pretty-printed JSON to f"{path}{tag}.txt".

  Apart from `local_seed`, every value is read from notebook-level globals.
  The file is written as UTF-8 explicitly. Because ensure_ascii=False keeps
  non-ASCII prompt text as-is, relying on the locale's default encoding can
  fail on runtimes whose locale is not UTF-8.
  """
  setting_list = {
      'seed': local_seed,
      'image_size' : image_size,
      'batch_size' : batch_size,
      'n_batches' : n_batches,
      'use_batch_outputFolder' : use_batch_outputFolder,
      'batch_outputFolder' : batch_outputFolder,
      'choose_diffusion_model' : choose_diffusion_model,
      'use_secondary_model' : use_secondary_model,
      'use_anti_jpeg' : use_antijpeg,
      'clips' : clips,
      'all_title' : title,
      'cfg_guidance_scale' : cfg_guidance_scale,
      'aesthetic_loss_scale' : aesthetic_loss_scale,
      'ic_cond' : ic_cond,
      'ic_guidance_scale' : ic_guidance_scale,
      'clip_guidance_scale' : clip_guidance_scale,
      'antijpeg_guidance_scale' : antijpeg_guidance_scale,
      'tv_scale' : tv_scale,
      'range_scale' : range_scale,
      'mean_scale' : mean_scale,
      'var_scale' : var_scale,
      'horizontal_symmetry_scale' : horizontal_symmetry_scale,
      'vertical_symmetry_scale' : vertical_symmetry_scale,
      'cutn' : cutn,
      'cut_batches' : cut_batches,
      'cut_pow' : cut_pow,
      'cut_p_mixgrey' : cut_p_mixgrey,
      'cut_p_grey' : cut_p_grey,
      'cut_p_flip' : cut_p_flip,
      'sample_mode' : sample_mode,
      'steps' : steps,
      'eta' : eta,
      'starting_noise' : starting_noise,
      'ending_noise' : ending_noise,
      'skip_percent' : skip_percent,
      'init_image' : init_image,
      'init_weight_mse' : init_weight_mse,
      'use_vertical_symmetry' : use_vertical_symmetry,
      'use_horizontal_symmetry' : use_horizontal_symmetry,
      'transformation_schedule' : transformation_schedule
      }
  with open(f"{path}{tag}.txt", "w", encoding="utf-8") as f:
    json.dump(setting_list, f, ensure_ascii=False, indent=4)

def simple_symmetry(x_in):
  """Make each image left/right symmetric by mirroring its left half onto the right.

  x_in is indexed as [n, c, h, w]. For odd widths the centre column is kept,
  so the output always has the same shape as the input (previously odd widths
  lost one column).
  """
  [n, c, h, w] = x_in.shape
  left = x_in[:, :, :, :(w + 1) // 2]
  mirrored = jnp.flip(x_in[:, :, :, :w // 2], -1)
  return jnp.concatenate([left, mirrored], -1)

def load_image(url, size=None):
    """Load an image (URL or local path) as a [1, 3, H, W] array scaled to [-1, 1].

    Args:
      url: anything accepted by fetch().
      size: size passed to PIL's resize, which expects (width, height).
        Defaults to the notebook-level `image_size`.

    The fetched file handle is closed once the pixels have been decoded.
    """
    if size is None:
        size = image_size
    with fetch(url) as fd:
        init_pil = Image.open(fd).convert('RGB')  # convert() decodes and returns a new image
    init_pil = init_pil.resize(size, Image.LANCZOS)
    # to_tensor gives [3, H, W] in [0, 1]; add a batch axis and map to [-1, 1].
    return jnp.array(TF.to_tensor(init_pil)).unsqueeze(0).mul(2).sub(1)

def display_images(images):
  """Show a batch of images as a 4-per-row grid in the notebook output.

  Values are mapped from [-1, 1] to [0, 1] (and clamped) before display.
  """
  unit_range = images.add(1).div(2).clamp(0, 1)
  as_torch = torch.tensor(np.array(unit_range))
  display.display(TF.to_pil_image(utils.make_grid(as_torch, 4).cpu()))

if promptFolder:
  # The first time the prompt folder is used (i.e. it is empty), seed it with
  # one-column example CSVs for random prompts; otherwise report its contents.
  existing_files = os.listdir(prompt_location)
  if existing_files:
    print(f"{len(existing_files)} files in {prompt_location}")
  else:
    templates = {
        "subjects.csv": pd.DataFrame({"subject" : ["rifle","sword"]}),
        "modifiers.csv": pd.DataFrame({"modifier" : ["cosmic","void"]}),
        "artists.csv": pd.DataFrame({"artist" : ["steven belledin","dan mumford"]}),
    }
    for filename, frame in templates.items():
      frame.to_csv(prompt_location+filename,index=False)
    print("creating random prompt csv files")

In [None]:
def get_save_every(steps):
    """Pick the frame-save interval and the target video length for a run.

    Returns ``(saveEvery, secondsOfVideo)``:
      * every step is saved for runs of up to 250 steps; longer runs save every
        ``steps // 200``-th step, which keeps the video around 200 frames;
      * below 250 steps the video lasts about ``6 + 0.04 * steps`` seconds,
        and from 250 steps on it lasts 16 seconds.
    """
    if steps > 250:
        interval = steps // 200
    else:
        interval = 1

    if steps < 250:
        seconds = round(0.064 * steps + 6 - 0.024 * steps)
    else:
        seconds = 16

    return interval, seconds

# Make a video after diffusion is complete.
def make_video(batchNum, numinbatch, saveEvery, secondsOfVideo, batch_size, timestring):
    """Encode the saved step frames of one image (or of the grid) into an mp4.

    Frames ``0 .. totalFrames-1`` are read from
    ``/content/imagesteps/{batchNum}/{numinbatch}/``, or from
    ``/content/imagesteps/toGridVideo/`` when ``batch_size >= 4``. They are piped
    to ffmpeg, and then a 5 second still of frame ``totalFrames`` is appended.
    Afterwards the PNG frames of ``{batchNum}/{numinbatch}`` are deleted. When
    ``googleDrive`` is set, the video is also copied to ``videoOutputFolder``.

    Args:
        batchNum: index of the batch the frames belong to.
        numinbatch: index of the image within that batch.
        saveEvery: step interval at which frames were saved.
        secondsOfVideo: target length, used to derive the frame rate.
        batch_size: images per batch; ``>= 4`` selects the grid frames and name.
        timestring: timestamp embedded in the video file name.

    Reads the notebook globals ``steps``, ``googleDrive`` and ``videoOutputFolder``.
    The video and the temporary still/concat files are written to the current
    working directory (``/content`` on Colab).
    """
    import glob
    import subprocess

    useGrid = batch_size >= 4
    totalFrames = steps//saveEvery - 1
    videoName = (f'video-grid-{timestring}.mp4' if useGrid else f'video-{batchNum}-{timestring}.mp4').replace(" ", "_")
    frameDir = "/content/imagesteps/toGridVideo/" if useGrid else f"/content/imagesteps/{batchNum}/{numinbatch}/"
    # Integer division yields 0 fps for very short runs, which ffmpeg rejects.
    fps = max(1, totalFrames//secondsOfVideo)

    tqdm.write(f'Generating video for batch {batchNum}...')
    p = Popen(['ffmpeg', '-y', '-f', 'image2pipe', '-vcodec', 'png', '-r', str(fps), '-i', '-', '-vcodec', 'libx264', '-r', str(fps), '-pix_fmt', 'yuv420p', '-crf', '17', '-preset', 'fast', videoName], stdin=PIPE)
    try:
        # Stream one frame at a time and close it, instead of holding every file open.
        for i in tqdm(range(totalFrames)):
            with Image.open(f"{frameDir}{i}.png") as im:
                im.save(p.stdin, 'PNG')
    finally:
        p.stdin.close()
        p.wait()

    # Append 5 seconds of the last frame. ffmpeg output is discarded (as before);
    # a failed step is detected below by the missing output.mp4.
    finalFrame = f"imagesteps/toGridVideo/{totalFrames}.png" if useGrid else f"imagesteps/{batchNum}/{numinbatch}/{totalFrames}.png"
    quiet = {'stdin': subprocess.DEVNULL, 'stdout': subprocess.DEVNULL, 'stderr': subprocess.DEVNULL}
    subprocess.run(['ffmpeg', '-y', '-loop', '5', '-i', finalFrame, '-pix_fmt', 'yuv420p', '-t', '5', 'stillframes.mp4'], **quiet)
    with open('stillframes.txt', 'w') as f:
        f.write(f"file '{videoName}' \nfile 'stillframes.mp4'")
    subprocess.run(['ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', 'stillframes.txt', '-c', 'copy', 'output.mp4'], **quiet)

    for leftover in ('stillframes.mp4', 'stillframes.txt'):
        if os.path.exists(leftover):
            os.remove(leftover)
    # Remove this image's frames so they cannot end up in a later video.
    for frame in glob.glob(f"/content/imagesteps/{batchNum}/{numinbatch}/*.png"):
        os.remove(frame)

    if os.path.exists('output.mp4'):
        os.replace('output.mp4', videoName)
    else:
        tqdm.write(f'Could not append the final still frame; keeping {videoName} without it.')

    if googleDrive:
        os.makedirs(videoOutputFolder, exist_ok=True)
        shutil.copy(f"/content/{videoName}", f"{videoOutputFolder}{videoName}")

def make_grid_video(batchNum, numInBatch, saveEvery, secondsOfVideo, videoBatchNColumns, n_batches, steps, batch_size=None, timestring=None):
    """Tile frame ``t`` of every batch into one grid image per step, then encode a video.

    For each saved frame index ``t``, the files ``imagesteps/{n}/0/{t}.png`` of all
    ``n_batches`` batches are combined with the external ``image-grid`` CLI into
    ``imagesteps/toGridVideo/{t}.png`` (``videoBatchNColumns`` columns). The
    result is then handed to ``make_video``.

    Args:
        batchNum, numInBatch, saveEvery, secondsOfVideo: forwarded to ``make_video``.
        videoBatchNColumns: number of columns in the grid.
        n_batches: number of batches whose frames are tiled.
        steps: sampling steps of the run; ``steps // saveEvery`` frames are tiled.
        batch_size: forwarded to ``make_video``; defaults to the notebook global.
        timestring: video name tag. Defaults to a global ``timestring`` if one
            exists, otherwise to the current time in run()'s
            ``%Y%m%d%H%M%S`` format.

    NOTE(review): make_video only reads ``toGridVideo`` frames when
    ``batch_size >= 4``; confirm the intended batch_size for grid videos.
    """
    import subprocess

    if batch_size is None:
        batch_size = globals()['batch_size']
    if timestring is None:
        timestring = globals().get('timestring') or time.strftime('%Y%m%d%H%M%S')

    os.makedirs('imagesteps/toGridVideo', exist_ok=True)
    # make_video reads frames 0 .. steps//saveEvery - 1, one per saved step.
    for t in range(steps // saveEvery):
        frame_paths = [f"imagesteps/{n}/0/{t}.png" for n in range(n_batches)]
        output = f"imagesteps/toGridVideo/{t}.png"
        subprocess.run(['image-grid', '-i', *frame_paths, '-o', output,
                        '-bs', '2', '-bc', '0', '-c', str(videoBatchNColumns)])
    make_video(batchNum, numInBatch, saveEvery, secondsOfVideo, batch_size, timestring)

def make_imagestep_folders(saveVideo):
    """Create the folders that run() writes intermediate video frames into.

    One folder per image, ``/content/imagesteps/{batch}/{image}``, for every one
    of the ``n_batches`` batches. When ``batch_size > 3`` only ``{batch}/0`` is
    created, because run() then saves each step as one grid image there.
    Nothing is created when ``saveVideo`` is False, since no frames are written.
    """
    if not saveVideo:
        return
    os.makedirs('imagesteps/toGrid', exist_ok=True)
    images_per_batch = 1 if batch_size > 3 else batch_size
    for i in range(n_batches):
        for k in range(images_per_batch):
            os.makedirs(f'/content/imagesteps/{i}/{k}', exist_ok=True)

def filternone(xs):
  """Return the items of ``xs`` that are not None, as a new list in the same order."""
  return list(filter(lambda item: item is not None, xs))

class LerpWeightError(Exception):
  """Raised when the non-zero Lerp model weights do not add up to 1.0."""

# Prof. R.J. Lerp Models

In [None]:
#@markdown Lerp Settings
# Combines the outputs of different models, used if LerpedModels is chosen as the diffusion model.
# The `cond_model` is a secondary model used to help diffuse, `secondary2` is best for speed.
choose_cond_model = "secondary2" #@param ["secondary2", "OpenAI256", "PixelArtv6", "PixelArtv7", "PixelArtv4", "cc12m", "cc12m_cfg", "WikiArt", "Danbooru", "Imagenet128"] 

#---
#The total sum of weights must add up to 1.0.
###### `use_antijpeg` will include the antijpeg model in the lerp, resulting in clearer results. `use_MakeCutoutsPixelated` will use the cutout method meant for the pixelart models.
use_MakeCutoutsPixelated = False #@param {type:"boolean"}

OpenAI512_weight = 0 #@param {type:"number"}
OpenAI256_weight = 0 #@param {type:"number"}
OpenAIFinetune_weight = 0.3 #@param {type:"number"}
PixelArtv4_weight = 0 #@param {type:"number"}
PixelArtv6_weight = 0 #@param {type:"number"}
PixelArtv7_weight =  0#@param {type:"number"}
cc12m_weight = 0 #@param {type:"number"}
cc12m_cfg_weight = 0 #@param {type:"number"}
WikiArt_weight = 0.7 #@param {type:"number"}
Danbooru_weight = 0 #@param {type:"number"}
Imagenet128_weight = 0 #@param {type:"number"}
secondary2_weight = 0 #@param {type:"number"}

# Non-zero weights only. config() appends the lerped models in exactly this
# order and zips them with lerpWeights, so the two orders must stay in sync.
lerpWeights = [weight for weight in (
    OpenAI512_weight, OpenAI256_weight, OpenAIFinetune_weight,
    PixelArtv4_weight, PixelArtv6_weight, PixelArtv7_weight,
    cc12m_weight, cc12m_cfg_weight, WikiArt_weight,
    Danbooru_weight, Imagenet128_weight, secondary2_weight,
) if weight != 0]

totalWeight = sum(lerpWeights)
# Compare with a tolerance: float sums such as ten weights of 0.1 give
# 0.9999999999999999, not exactly 1.0.
if not math.isclose(totalWeight, 1.0, abs_tol=1e-9):
    raise LerpWeightError("Total weights must add up to 1.0.")

# Image Settings

In [None]:
#@markdown Diffusion and CLIP Settings
choose_diffusion_model = "cc12m" #@param ["LerpedModels","OpenAI", "OpenAIFinetune", "OpenAI256", "cc12m_cfg", "cc12m", "PixelArtv4", "WikiArt", "PixelArtv7_ic_attn", "PixelArtv6","Danbooru", "Imagenet"]
# use_secondary_model: guide with secondary2 instead of the diffusion model itself
# (not used for LerpedModels, whose guidance model is choose_cond_model).
use_secondary_model = False #@param {type:"boolean"}
use_antijpeg = False #@param {type:"boolean"}
use_vitb16 = True #@param {type:"boolean"}
use_vitb32 = False #@param {type:"boolean"}
use_vitl14 = False #@param {type:"boolean"}
# Names of the enabled CLIP models; disabled entries (None) are filtered out.
clips = ['ViT-B/16' if use_vitb16 else None, 'ViT-B/32' if use_vitb32 else None, 'ViT-L/14' if use_vitl14 else None]
clips = filternone(clips)

#@markdown Run Settings
use_batch_outputFolder = True #@param {type:"boolean"}
batch_outputFolder = "test" #@param {type:"string"}
seed = None # if None, uses the current time in seconds.
# (width, height): run() creates the noise as [batch_size, 3, image_size[1], image_size[0]].
image_size = (256,512) #@param {type:"raw"}
batch_size = 1 #@param {type:"integer"}
n_batches = 1 #@param {type:"integer"}

#@markdown Prompt
# Prompts are separated by "|" with an optional ":weight" suffix, as in the
# examples; the parsing itself happens in process_prompts/expand (defined elsewhere).
all_title = "landscape | white:-1" #@param {type:"string"}

#@markdown Cond Settings
ic_cond = "https://irc.zlkj.in/uploads/eebeaf1803e898ac/88552154_p0%20-%20Coral.png" #@param {type:"string"}
ic_guidance_scale = 2.0 #@param {type:"raw"} # For pixelartv7_ic_attn
cfg_guidance_scale = 0 #@param {type:"raw"} # For cc12m_1_cfg
aesthetic_loss_scale = 0.0 #@param {type:"raw"} # For aesthetic loss, requires ViT-B/16
clip_guidance_scale = 10000 #@param {type:"raw"} # Note: with two perceptors, effective guidance scale is ~2x because they are added together.
antijpeg_guidance_scale =  0 #@param {type:"raw"}
# The tv/range/mean/var/symmetry scales only take effect when > 0;
# config() skips the corresponding condition otherwise.
tv_scale = 0  #@param {type:"raw"} # Smooths out the image
range_scale = 0 #@param {type:"raw"} # Tries to prevent pixel values from going out of range
mean_scale = 0 #@param {type:"raw"} # trends towards middle grey
var_scale =  0#@param {type:"raw"} # lowers image variation
horizontal_symmetry_scale = 0 #@param {type:"raw"}
vertical_symmetry_scale = 0 #@param {type:"raw"}

#@markdown Cut Settings
cutn = 8        #@param {type:"raw"} # Effective cutn is cut_batches * this
cut_batches = 8 #@param {type:"raw"} 
cut_pow = 1.0   #@param {type:"raw"} # Affects the size of cutouts. Larger cut_pow -> smaller cutouts (down to the min of 224x224)
cut_p_mixgrey = None #@param {type:"raw"} # Partially greyscale some cuts. Has weird effect.
cut_p_grey = 0.2     #@param {type:"raw"} # Fully greyscale some cuts. Tends to improve coherence.
cut_p_flip = 0.5     #@param {type:"raw"} # Flip 50% of cuts to make clip effectively horizontally equivariant. Improves coherence.
# use_huemin_cuts selects MakeCutouts_huemin instead of MakeCutouts (hidden from the form).
use_huemin_cuts = False #test@param {type:"boolean"}

#@markdown Noise Settings
# sample_mode:
#  prk : high quality, 3x slow (eta=0)
#  plms : high quality, about as fast as ddim (eta=0)
#  ddim : traditional, accepts eta for different noise levels which sometimes have nice aesthetic effect
sample_mode = 'ddim' #@param ["ddim", "plms", "prk"]
steps = 100     #@param {type:"raw"} # Number of steps for sampling. Generally, more = better.
eta = 0.5       #@param {type:"raw"} # Only applies to ddim sample loop: 0.0: DDIM | 1.0: DDPM | -1.0: Extreme noise (q_sample)
starting_noise = 1.0   #@param {type:"raw"} # Between 0 and 1. When using init image, generally 0.5-0.8 is good. Lower starting noise makes the result look more like the init.
ending_noise = 0.0     #@param {type:"raw"} # Usually 0.0 for high detail. Can set a little higher like 0.05 to end early for smoother looking result.
# Fraction of the noise schedule to skip at the start; the remaining range is
# re-spaced over int(steps*(1-skip_percent)) steps.
skip_percent = 0.0   #@param {type:"raw"}

#@markdown Init Settings
use_init = False   #@param {type:"boolean"}
init_image = None      #@param {type:"string"} 
# Diffusion will start with a mixture of this image with noise.
init_weight_mse = 0    #@param {type:"raw"} # MSE loss between the output and the init makes the result look more like the init (should be between 0 and width*height*3).

#@markdown Transformation Settings
# The enabled symmetries are applied at the schedule points given by
# transformation_schedule: comma-separated fractions of `steps`.
use_vertical_symmetry = False #@param {type:"boolean"}
use_horizontal_symmetry = False #@param {type:"boolean"}
transformation_schedule = "0.1,0.2,0.3" #@param {type:"string"} 

# These flags are set in the later "Batch Settings" cell. If that cell has not
# been run yet, default them to False so the run does not hit a NameError.
# Either way, report the value that will be used.
for _flag_name in ("use_random_settings", "use_batch_prompts", "use_random_prompts",
                   "use_random_init", "use_batch_init"):
  globals().setdefault(_flag_name, False)
  print(f"{_flag_name}: {globals()[_flag_name]}")
del _flag_name

# Batch Settings


In [None]:
#@markdown Batch Prompts
use_batch_prompts = False #@param {type:"boolean"}
duplicate_all_title = False #@param {type:"boolean"}
duplicate_n_times =  100#@param {type:"integer"}

# Either repeat the main prompt duplicate_n_times times, or use the hand-written list.
batch_prompts = (
  [all_title]*duplicate_n_times
  if duplicate_all_title
  else ["concept art of a void rifle by steven belledin",
        "concept art of a void rifle by steven belledin:1 | purple:-0.5",
        "concept art of a void rifle by steven belledin:1 | colorful:0.5"]
)

#@markdown Batch Random Prompts
use_random_prompts = False #@param {type:"boolean"}
num_random_prompts =  100#@param {type:"integer"}
subjects_fin = "subjects.csv" #@param {type:"string"}
modifiers_fin = "modifiers.csv" #@param {type:"string"}
artists_fin = "artists.csv" #@param {type:"string"}

def getTitleSimple():
  """Build one random concept-art prompt from the prompt tables.

  Samples one entry each from the module-level DataFrames ``subjects``,
  ``modifiers`` and ``artists`` (in that order). It appends negative-weight
  "text" and "watermark" terms.
  """
  subject = subjects["subject"].sample(1).item()
  modifier = modifiers["modifier"].sample(1).item()
  artist = artists["artist"].sample(1).item()
  return (f"concept art of a {modifier} {subject} by {artist}"
          ":1 | text:-1.5 | watermark:-1")

if use_random_prompts:
    modifiers = pd.read_csv(prompt_location+modifiers_fin)
    artists = pd.read_csv(prompt_location+artists_fin)
    subjects = pd.read_csv(prompt_location+subjects_fin)
    batch_prompts = [getTitleSimple() for _ in range(num_random_prompts)]

# Show the prompts the diffusion cell will use. batch_prompts (which holds the
# random prompts when use_random_prompts is set) is used only when
# use_batch_prompts is on; otherwise the single all_title prompt is used.
prompt_df = pd.DataFrame({"prompts": batch_prompts if use_batch_prompts else [all_title]})
print(prompt_df)

#@markdown Batch Random Cond Settings
use_random_settings = False #@param {type:"boolean"}

def random_settings():
  """Overwrite the sampling/guidance globals with random values.

  Uses numpy's global RNG (``np.random.randint``; the upper bound is exclusive).
  Values are drawn in this fixed order: steps, eta, clip_guidance_scale,
  tv_scale, range_scale, mean_scale, var_scale, horizontal_symmetry_scale,
  vertical_symmetry_scale. ``skip_percent`` is not randomized.
  """
  global steps, eta, clip_guidance_scale, tv_scale, range_scale
  global mean_scale, var_scale
  global horizontal_symmetry_scale, vertical_symmetry_scale

  draw = np.random.randint
  steps = draw(100, 200)
  eta = draw(-5, 10) / 10.0
  clip_guidance_scale = draw(1000, 100000)
  tv_scale = draw(0, 10000)
  range_scale = draw(0, 1000)
  mean_scale = draw(-1000, 1000)
  var_scale = draw(-1000, 1000)
  horizontal_symmetry_scale = draw(-10000, 10000)
  vertical_symmetry_scale = draw(-10000, 10000)

#@markdown Batch Init Settings
use_batch_init = False #@param {type:"boolean"}
use_random_init = False #@param {type:"boolean"}
batch_init_path = '' #@param {type:"string"}

# List the init images available for batch/random init mode. If the folder
# can't be read or holds no images, both init modes are switched off with a
# message, instead of failing later in the run.
if (use_random_init) or (use_batch_init):
  ext = [".jpg", ".jpeg", ".png", ".bmp", ".tiff"]
  try:
    filenames = [filename for filename in os.listdir(init_location+batch_init_path)
                 if os.path.splitext(filename)[1].lower() in ext]
  except OSError as err:
    # Only filesystem errors are expected here; anything else (e.g. an
    # undefined init_location) should surface instead of being hidden.
    use_random_init = False
    use_batch_init = False
    print("incorrect init path")
    print(err)
  else:
    file_df = pd.DataFrame({"init_filenames":sorted(filenames)})
    print(file_df)
    if file_df.empty:
      use_random_init = False
      use_batch_init = False
      print(f"no init images found in {init_location+batch_init_path}")

# Diffuse

In [None]:
#@markdown Display Rate
# Show (and optionally save) the current prediction every display_rate steps,
# excluding the first and the last step.
use_display_rate = False #@param {type:"boolean"}
save_display_rate = False #@param {type:"boolean"}
display_rate = 20 #@param {type:"integer"}

#@markdown Display Percent
# Comma-separated fractions of `steps` at which to show (and optionally save) the prediction.
use_display_percent = False #@param {type:"boolean"}
save_display_percent = False #@param {type:"boolean"}
display_percent = "0.6,0.8" #@param {type:"string"}

#@markdown Display Init
# Show the step-0 prediction when an init image is used.
use_display_init = False #@param {type:"boolean"}

#@markdown Save Video
# Save intermediate frames while sampling; with batch_size < 4, each image is
# encoded into an mp4 after its batch finishes.
saveVideo = False #@param {type:"boolean"}
# Create the /content/imagesteps/{batch}/{image} frame folders for the current batch settings.
make_imagestep_folders(saveVideo)

def _load_ic_cond():
    """Load ``ic_cond`` as the conditioning image for PixelArtv7_ic_attn.

    The image is scaled from [0, 1] to [-1, 1] and tiled along height and width
    until it covers ``image_size`` (width, height). It is then cropped to that
    size and broadcast to ``[batch_size, 3, height, width]``.
    """
    cond = jnp.array(TF.to_tensor(Image.open(fetch(ic_cond)).convert('RGB'))) * 2 - 1
    cond = jnp.concatenate([cond]*(image_size[1]//cond.shape[-2]+1), axis=-2)[:, :image_size[1], :]
    cond = jnp.concatenate([cond]*(image_size[0]//cond.shape[-1]+1), axis=-1)[:, :, :image_size[0]]
    return cond.broadcast_to([batch_size, 3, image_size[1], image_size[0]])


def _make_jpeg_classifier_fn():
    """Build the anti-JPEG classifier guidance that accompanies ``anti_jpeg_cfg``."""
    return jpeg_classifier_wrap(jpeg_classifier_params(),
                                guidance_scale=antijpeg_guidance_scale, # will generally depend on image size
                                flood_level=0.7, # Prevent over-optimization
                                blur_size=3.0)


def config():
    """Build the diffusion model and the guidance function from the notebook settings.

    With ``choose_diffusion_model == "LerpedModels"``, every model with a non-zero
    ``*_weight`` is built and combined by ``LerpModels`` using ``lerpWeights``
    (plus weight 1.0 for the anti-JPEG model when ``use_antijpeg`` is set), and
    ``choose_cond_model`` selects the guidance model. Otherwise a single model is
    built. Its guidance model is ``secondary2_wrap()`` when ``use_secondary_model``
    is set, and the diffusion model itself otherwise.

    Returns:
        ``(diffusion, cond_fn)``: the (possibly lerped) diffusion model and the
        guidance function built from the CLIP and image-statistics conditions.

    Raises:
        ValueError: if ``choose_diffusion_model`` is not a supported choice.
    """
    vitb32 = lambda: get_clip('ViT-B/32')
    vitb16 = lambda: get_clip('ViT-B/16')
    vitl14 = lambda: get_clip('ViT-L/14')

    print(f"Loading {choose_diffusion_model}...")

    jpeg_classifier_fn = None

    if choose_diffusion_model == "LerpedModels":
        # -- Combine different models to a single output --
        modelsToLerp = []
        cond_model = None

        def add_model(weight, cond_name, build):
            # Build a model only if it takes part in the lerp (non-zero weight)
            # or is the selected cond model. Models must be added in the same
            # order as the weights in lerpWeights, because the two are zipped.
            nonlocal cond_model
            is_cond = cond_name is not None and choose_cond_model == cond_name
            if weight == 0 and not is_cond:
                return
            model = build()
            if weight != 0:
                modelsToLerp.append(model)
            if is_cond:
                cond_model = model

        add_model(OpenAI512_weight, None, openai_512)
        add_model(OpenAI256_weight, "OpenAI256", openai_256)
        add_model(OpenAIFinetune_weight, None, openai_512_finetune)
        add_model(PixelArtv4_weight, "PixelArtv4", lambda: pixelartv4_wrap(pixelartv4_params()))
        # Pass params explicitly, as the single-model PixelArtv6 branch does.
        add_model(PixelArtv6_weight, "PixelArtv6", lambda: pixelartv6_wrap(pixelartv6_params()))
        add_model(PixelArtv7_weight, "PixelArtv7", lambda: pixelartv7_ic_attn(_load_ic_cond(), ic_guidance_scale))
        add_model(cc12m_weight, "cc12m", lambda: cc12m_1_wrap(cc12m_1_params(), clip_embed=process_prompts(vitb16(), title)))
        add_model(cc12m_cfg_weight, "cc12m_cfg", lambda: cc12m_1_cfg_wrap(clip_embed=process_prompts(vitb16(), title), cfg_guidance_scale=cfg_guidance_scale))
        add_model(WikiArt_weight, "WikiArt", wikiart_256)
        add_model(Danbooru_weight, "Danbooru", danbooru_128)
        add_model(Imagenet128_weight, "Imagenet128", imagenet_128)
        add_model(secondary2_weight, "secondary2", secondary2_wrap)

        # Local copy: config() runs once per prompt, and appending to the global
        # lerpWeights would grow it on every call.
        weights = list(lerpWeights)
        if use_antijpeg:
            modelsToLerp.append(anti_jpeg_cfg())
            weights.append(1.0)
            jpeg_classifier_fn = _make_jpeg_classifier_fn()

        diffusion = LerpModels(list(zip(modelsToLerp, weights)))

    else:
        if choose_diffusion_model == 'OpenAI':
            diffusion = openai_512()
        elif choose_diffusion_model == 'OpenAI256':
            diffusion = openai_256()
        elif choose_diffusion_model == 'OpenAIFinetune':
            diffusion = openai_512_finetune()
        elif choose_diffusion_model == 'WikiArt':
            diffusion = wikiart_256()
        elif choose_diffusion_model == 'Danbooru':
            diffusion = danbooru_128()
        elif choose_diffusion_model == 'Imagenet':
            diffusion = imagenet_128()
        elif choose_diffusion_model == 'PixelArtv7_ic_attn':
            diffusion = pixelartv7_ic_attn(_load_ic_cond(), ic_guidance_scale)
        elif choose_diffusion_model == 'PixelArtv6':
            diffusion = pixelartv6_wrap(pixelartv6_params())
        elif choose_diffusion_model == 'PixelArtv4':
            diffusion = pixelartv4_wrap(pixelartv4_params())
        elif choose_diffusion_model == 'cc12m':
            diffusion = cc12m_1_wrap(cc12m_1_params(), clip_embed=process_prompts(vitb16(), title))
        elif choose_diffusion_model == 'cc12m_cfg':
            diffusion = cc12m_1_cfg_wrap(clip_embed=process_prompts(vitb16(), title), cfg_guidance_scale=cfg_guidance_scale)
        else:
            raise ValueError(f"Unknown choose_diffusion_model: {choose_diffusion_model!r}")

        cond_model = secondary2_wrap() if use_secondary_model else diffusion

        if use_antijpeg:
            diffusion = LerpModels([(diffusion, 1.0),
                                    (anti_jpeg_cfg(), 1.0)])
            jpeg_classifier_fn = _make_jpeg_classifier_fn()

    # Guidance terms; disabled terms are passed as None, as before.
    # NOTE(review): the tv/range/mean/var/symmetry terms require scale > 0, so
    # negative values (which random_settings can draw) are ignored.
    cond_terms = [
        CondCLIP(vitb32(), make_cutouts, cut_batches,
                 SphericalDistLoss(process_prompts(vitb32(), title), clip_guidance_scale) if clip_guidance_scale > 0 else None)
        if use_vitb32 and clip_guidance_scale > 0 else None,

        CondCLIP(vitb16(), make_cutouts, cut_batches,
                 SphericalDistLoss(process_prompts(vitb16(), title), clip_guidance_scale) if clip_guidance_scale > 0 else None,
                 AestheticExpected(aesthetic_loss_scale) if aesthetic_loss_scale > 0 else None)
        if use_vitb16 and (clip_guidance_scale > 0 or aesthetic_loss_scale > 0) else None,

        CondCLIP(vitl14(), make_cutouts, cut_batches,
                 SphericalDistLoss(process_prompts(vitl14(), title), clip_guidance_scale) if clip_guidance_scale > 0 else None)
        if use_vitl14 and clip_guidance_scale > 0 else None,

        CondTV(tv_scale) if tv_scale > 0 else None,
        CondMSE(init_array, init_weight_mse) if init_weight_mse > 0 else None,
        CondRange(range_scale) if range_scale > 0 else None,
        CondMean(mean_scale) if mean_scale > 0 else None,
        CondVar(var_scale) if var_scale > 0 else None,
        CondHorizontalSymmetry(horizontal_symmetry_scale) if horizontal_symmetry_scale > 0 else None,
        CondVerticalSymmetry(vertical_symmetry_scale) if vertical_symmetry_scale > 0 else None,
    ]
    cond_fn = MainCondFn(cond_model, cond_terms)
    if use_antijpeg and (antijpeg_guidance_scale > 0):
        cond_fn = CondFns(cond_fn, jpeg_classifier_fn)

    return diffusion, cond_fn

def sanitize(title):
  """Make a prompt safe to use in a file name.

  Keeps the first 100 characters and replaces the path separators '/' and
  '\\' with underscores.
  """
  separators_to_underscore = str.maketrans({'/': '_', '\\': '_'})
  return title[:100].translate(separators_to_underscore)

def _save_pred_images(pred, path_for):
    """Save the first ``batch_size`` images of ``pred`` (values in [-1, 1]) as PNGs.

    ``path_for(k)`` returns the file path for the k-th image.
    """
    images = torch.tensor(np.array(pred.add(1).div(2).clamp(0, 1)))
    for k in range(batch_size):
        TF.to_pil_image(images[k]).save(path_for(k))


@torch.no_grad()
def run():
    """Sample ``n_batches`` batches for the current prompt and save the results.

    Uses the notebook globals set up by the driver cell (``schedule``,
    ``diffusion``, ``cond_fn``, ``init_array``, ``x_transformation``, ...).
    Intermediate predictions are shown or saved according to the display and
    video settings. The final images are written to ``outputFolder`` as
    ``{timestring}_{batch}_{image}_{steps}.png``.

    Raises:
        ValueError: if ``sample_mode`` is not 'ddim', 'prk' or 'plms'.
    """
    local_seed = int(time.time()) if seed is None else seed
    timestring = time.strftime('%Y%m%d%H%M%S')

    print(f'\nStarting run of ({all_title}) with seed {local_seed}...')
    save_still_settings(local_seed,outputFolder,timestring)

    rng = PRNG(jax.random.PRNGKey(local_seed))

    # Select the sampler once and actually use it. Only the DDIM loop is given
    # eta and the symmetry hook x_fn; prk/plms are called with their defaults.
    if sample_mode == 'ddim':
        sample_loop = partial(sampler.ddim_sample_loop, eta=eta, x_fn=x_transformation)
    elif sample_mode == 'prk':
        sample_loop = sampler.prk_sample_loop
    elif sample_mode == 'plms':
        sample_loop = sampler.plms_sample_loop
    else:
        raise ValueError(f"Unknown sample_mode: {sample_mode!r}")
    if sample_mode != 'ddim' and (use_horizontal_symmetry or use_vertical_symmetry):
        print(f" symmetry transformations are only applied with ddim, not {sample_mode}")

    # Columns of the grid frame saved for videos when batch_size >= 4.
    nColumns = 4 if batch_size in (7, 8) else math.ceil(math.sqrt(batch_size))
    last_step = len(schedule) - 1

    for i in range(n_batches):
        alphas, sigmas = cosine.to_alpha_sigma(schedule)
        print(schedule[0], sigmas[0], alphas[0])

        x = jax.random.normal(rng.split(), [batch_size, 3, image_size[1], image_size[0]])
        if init_array is not None:
            # Start from the init image mixed with noise at the first schedule level.
            x = sigmas[0] * x + alphas[0] * init_array

        frameIteration = 0
        for output in sample_loop(diffusion, cond_fn, x, schedule, rng.split()):
            j = output['step']
            pred = output['pred']
            # Check the current prediction (x itself never changes inside this loop).
            assert pred.isfinite().all().item()

            # display init
            if use_display_init and j == 0 and init_array is not None:
                display_images(pred)

            # rate
            if use_display_rate and j % display_rate == 0 and j not in (0, last_step):
                display_images(pred)
                if save_display_rate:
                    _save_pred_images(pred, lambda k: f'{outputFolder}{timestring}_{i}_{k}_{j}.png')
                    print(f" saving at step {j}")

            # percent
            if use_display_percent and j in display_steps and j != last_step:
                display_images(pred)
                if save_display_percent:
                    _save_pred_images(pred, lambda k: f'{outputFolder}{timestring}_{i}_{k}_{j}.png')
                    print(f" saving at step {j}")

            # video
            if saveVideo and j % saveEvery == 0:
                if batch_size < 4:
                    _save_pred_images(pred, lambda k: f'/content/imagesteps/{i}/{k}/{frameIteration}.png')
                else:
                    images = torch.tensor(np.array(pred.add(1).div(2).clamp(0, 1)))
                    TF.to_pil_image(utils.make_grid(images, nColumns).cpu()).save(f'/content/imagesteps/{i}/0/{frameIteration}.png')
                frameIteration += 1

        # save samples
        display_images(pred)
        images = torch.tensor(np.array(pred.add(1).div(2).clamp(0, 1)))
        for k in range(batch_size):
            TF.to_pil_image(images[k]).save(f'{outputFolder}{timestring}_{i}_{k}_{steps}.png')
            if saveVideo and batch_size < 4:
                make_video(i, k, saveEvery, secondsOfVideo, batch_size, timestring)
        print(f'\nFinished run of ({all_title}) with seed {local_seed}...')

# main
# Start from False so a stale value from an earlier execution of this cell
# cannot satisfy the assert at the bottom.
success = False
try:

  # check if batch prompts
  batch_titles = batch_prompts if use_batch_prompts else [all_title]

  # loop over prompts in batch_titles (all_title is rebound to the current prompt)
  for ii, all_title in enumerate(batch_titles):
    print(f"prompt {ii}/{len(batch_titles)}")
    title = expand([all_title], batch_size)

    # preparation
    if use_random_settings:
      random_settings()
    if use_init:
      try:
        if type(init_image) is list:
          # Several init images are averaged into one.
          init_array = sum(load_image(url) for url in init_image) / len(init_image)
        elif type(init_image) is str:
          # A brace pattern can expand to several images, stacked on the batch axis.
          init_array = jnp.concatenate([load_image(it) for it in braceexpand(init_image)], axis=0)
        else:
          init_array = None
      except Exception:
        # Fall back to init_image as a file name inside init_location.
        init_array = load_image(init_location+init_image)
    else:
      init_array = None
    if use_batch_init:
      init_image = init_location+batch_init_path+file_df.loc[ii].values[0]
      init_array = load_image(init_image)
    if use_random_init:
      init_image = init_location+batch_init_path+file_df.sample(1).values[0][0]
      init_array = load_image(init_image)
    if use_display_percent:
      temp = [float(vals) for vals in display_percent.split(",")]
      display_steps = [int(steps*percent) for percent in temp]
    else:
      display_steps = []
    saveEvery, secondsOfVideo = get_save_every(steps)
    schedule = jnp.linspace(starting_noise, ending_noise, steps+1)
    if skip_percent > 0:
      # Start from the noise level skip_percent of the way in and re-space the rest.
      skip_starting_noise = schedule[int(steps*skip_percent)].item()
      skip_steps = int(steps*(1-skip_percent))
      schedule = jnp.linspace(skip_starting_noise, ending_noise, skip_steps+1)
    schedule = spliced.to_cosine(schedule)
    if use_huemin_cuts:
      make_cutouts = MakeCutouts_huemin(clip_size, cutn, cut_pow=cut_pow, p_grey=cut_p_grey, p_flip=cut_p_flip, p_mixgrey=cut_p_mixgrey)
    else:
      make_cutouts = MakeCutouts(clip_size, cutn, cut_pow=cut_pow, p_grey=cut_p_grey, p_flip=cut_p_flip, p_mixgrey=cut_p_mixgrey)
    outputFolder = get_output_folder(outputFolderStatic, choose_diffusion_model, batch_outputFolder, use_batch_outputFolder)
    videoOutputFolder = outputFolder+"videos/"
    
    # transformation functions
    transformation_percent = [float(vals) for vals in transformation_schedule.split(",")]
    transformation_steps = [int(steps*i) for i in transformation_percent]
    t_schedule = [schedule[i] for i in transformation_steps]

    def x_transformation(x,t):
      """Mirror x horizontally and/or vertically when t is one of the t_schedule points."""
      if use_horizontal_symmetry:
        if t in t_schedule:
          [n, c, h, w] = x.shape
          x = jnp.concatenate([x[:, :, :, :w//2], jnp.flip(x[:, :, :, :w//2],-1)], -1)
          print(" horizontal symmetry applied")
      if use_vertical_symmetry:
        if t in t_schedule:
          [n, c, h, w] = x.shape
          x = jnp.concatenate([x[:, :, :h//2, :], jnp.flip(x[:, :, :h//2, :],-2)], -2)
          print(" vertical symmetry applied")
      return x

    # initialize
    diffusion, cond_fn = config()
    
    # run
    run()

  success = True

except Exception:
  # KeyboardInterrupt/SystemExit are no longer swallowed and turned into an assert.
  import traceback
  traceback.print_exc()
  success = False
assert success

# Database

In [None]:
#@markdown Create/Load Database

import matplotlib.pyplot as plt

def make_database(outputFolder):
  """Build a DataFrame indexing every PNG image found directly in ``outputFolder``.

  Files are matched on a case-insensitive ``.png`` extension and sorted by name.
  Returned columns:
    img  -- image filename
    json -- settings filename: the image name up to its first "_" followed by ".txt"
    path -- ``outputFolder`` itself (full paths are built as path + img / path + json)
  """
  png_names = [name for name in os.listdir(outputFolder)
               if os.path.splitext(name)[1].lower() == ".png"]
  images = np.array(sorted(png_names))
  settings_files = np.array([name.split("_")[0] + ".txt" for name in images])

  frame = pd.DataFrame({"img": images, "json": settings_files})
  frame["path"] = outputFolder
  return frame

outputFolder = get_output_folder(outputFolderStatic, choose_diffusion_model, batch_outputFolder, use_batch_outputFolder)
db_path = outputFolder + "0000_database.pkl"

if not os.path.exists(db_path):
  # First run for this folder: index every image from scratch.
  print(f"database does not exist: {db_path}")
  print("creating database...")
  db_df = make_database(outputFolder)
else:
  # Reload the saved database (keeping any scores / expanded columns) and outer-merge
  # in rows for images written since it was last saved; rows match on img/json/path.
  print(f"database exist: {db_path}")
  print("loading database...")
  db_df = pd.read_pickle(db_path)
  print("updating database...")
  fresh_df = make_database(outputFolder)
  db_df = pd.merge(db_df, fresh_df, on=["img","json","path"], how="outer").reset_index(drop=True)

db_df.to_pickle(db_path)

In [None]:
#@markdown Score Images
# Which images the scoring loop below draws from:
#   all_random     - any image in db_df
#   noscore_random - images whose score is still null
#   score_random   - images that already have a score (re-scoring)
search_type = "all_random" #@param ["all_random","noscore_random", "score_random"]
score_flag = True  # the scoring loop runs while this is True
count = 0  # number of scoring-loop iterations so far

def score_image(index=None):
  """Display one database image, ask for a 1-5 score and store it in ``db_df``.

  Args:
    index: row label in the global ``db_df``, either a plain label or a one-element
      pandas Index such as ``db_df.sample(1).index``. When omitted, the global ``ind``
      set by the scoring loop is used (the original calling convention).

  The score is stored as the (stripped) string the user typed, which keeps it
  compatible with databases written by earlier versions. Input is requested again
  until it is 1-5. Previously a blank or mistyped entry was stored, counted as
  "scored" (notnull), and made the numeric conversion in the plotting cell fail.
  """
  if index is None:
    index = ind  # legacy behaviour: read the global set by the scoring loop
  k = index.item() if hasattr(index, "item") else index

  img_path = db_df.loc[k,"path"]+db_df.loc[k,"img"]
  img = load_image(img_path)
  display_images(img)
  # NOTE(review): presumably gives the notebook time to render the image before
  # input() blocks; confirm.
  time.sleep(2)
  print(db_df.loc[k,"img"])
  print("1-Very Poor; 2-Poor, 3-Fair, 4-Good, 5-Excellent")
  score = input("score: ").strip()
  while score not in ("1", "2", "3", "4", "5"):
    print("please enter a whole number from 1 to 5")
    score = input("score: ").strip()
  print()
  db_df.loc[k,"score"] = score

# Interactive scoring loop. Each iteration picks one random image from the pool chosen
# by search_type, sets the global `ind` that score_image() reads, asks for a score and
# checkpoints the database to disk.

# A freshly created database has no "score" column yet, so the noscore_random /
# score_random filters below raised AttributeError on db_df.score; create it first.
if "score" not in db_df.columns:
  db_df["score"] = None

while score_flag:
  count+=1
  if search_type == "all_random":
    # len(db_df) draws from the whole database (with replacement)
    pool = db_df
    finished = count > len(db_df)
  elif search_type == "noscore_random":
    # The unscored pool shrinks by one each iteration while `count` grows, so the old
    # `count > len(pool)` test stopped after about half of the unscored images.
    # Continue until nothing is left unscored.
    pool = db_df[pd.isnull(db_df.score)]
    finished = len(pool) == 0
  elif search_type == "score_random":
    # re-scoring: as many draws as there are scored images (with replacement)
    pool = db_df[pd.notnull(db_df.score)]
    finished = count > len(pool)
  else:
    # previously an unrecognised value looped forever
    raise ValueError(f"unknown search_type: {search_type!r}")

  if finished:
    score_flag = False
  else:
    ind = pool.sample(1).index
    score_image()
  db_df.to_pickle(outputFolder+"0000_database.pkl")

In [None]:
#@markdown Expand Scored Images
def expand_scored():
  """Copy each scored image's settings into ``db_df`` columns, then save the database.

  For every row of the global ``db_df`` with a non-null ``score``, the file at
  ``path + json`` is parsed with ``json.load``. Each top-level key is written to a
  column of the same name as ``str(value)``. The database is then pickled to
  ``outputFolder + "0000_database.pkl"`` and a progress line is printed.
  """
  scored_index = db_df[pd.notnull(db_df.score)].index

  for k in scored_index:
    # .loc instead of attribute access: a column literally named "json" is easy to
    # confuse with the json module.
    json_path = db_df.loc[k,"path"]+db_df.loc[k,"json"]
    # `with` closes the settings file; the original leaked one open handle per image.
    with open(json_path) as f:
      data = json.load(f)

    for key, value in data.items():
      db_df.loc[k,key] = str(value)

  db_df.to_pickle(outputFolder+"0000_database.pkl")
  print(f"expanded... {len(scored_index)}/{len(db_df)} images scored")

expand_scored()

In [None]:
#@markdown Plot Parameters
parameter = "tv_scale" #@param ['seed', 'image_size', 'batch_size', 'n_batches', 'use_batch_outputFolder', 'batch_outputFolder', 'main_model', 'use_secondary_model', 'use_anti_jpeg', 'clips', 'all_title', 'cfg_guidance_scale', 'aesthetic_loss_scale', 'ic_cond', 'ic_guidance_scale', 'clip_guidance_scale', 'tv_scale', 'range_scale', 'mean_scale', 'var_scale', 'horizontal_symmetry_scale', 'vertical_symmetry_scale', 'cutn', 'cut_batches', 'cut_pow', 'cut_p_mixgrey', 'cut_p_grey', 'cut_p_flip', 'sample_mode', 'steps', 'eta', 'starting_noise', 'ending_noise', 'skip_percent', 'init_image', 'init_weight_mse', 'use_vertical_symmetry', 'use_horizontal_symmetry', 'transformation_schedule']
# NOTE(review): log_plot is not read anywhere below; the x axis is always linear.
log_plot = True # param {type:"boolean"}

# Only scored rows can be plotted.
plot_df = db_df[pd.notnull(db_df.score)].reset_index(drop=True).copy()
print(f"plotting... {len(plot_df)} points")

if parameter not in plot_df.columns:
  # Settings columns are only added by the "Expand Scored Images" cell.
  raise KeyError(f"column {parameter!r} not in database; run 'Expand Scored Images' first")

# Scores are stored as typed strings. They are converted separately so that a
# non-numeric parameter below can no longer skip this step (which used to leave the
# scores as unordered string categories). Unparseable scores become NaN, and
# matplotlib leaves NaN points out of the scatter.
plot_df["score"] = pd.to_numeric(plot_df["score"], errors="coerce")

# Settings were stored as str(value). Convert when every value is numeric; otherwise
# keep the strings (e.g. sample_mode, init_image) and matplotlib plots them as categories.
try:
  plot_df[parameter] = pd.to_numeric(plot_df[parameter])
except (ValueError, TypeError):
  print(f"error converting {parameter} to numeric; plotting it as categories")

# rc_context scopes the larger font to this figure. The previous global
# plt.rcParams.update ran after the axis labels were already created, and it
# permanently changed the font size for every later plot.
with plt.rc_context({"font.size": 20}):
  fig, ax = plt.subplots(figsize=(12,12))
  ax.scatter(plot_df[parameter], plot_df["score"], s=80, alpha=0.5)
  ax.set_xlabel(parameter)
  ax.set_ylabel("score")
  ax.set_title(f"score vs {parameter}")
  fig.set_facecolor("white")
  plt.show()