This notebook is a modified version of [this Kaggle notebook](https://www.kaggle.com/code/litevex/lite-s-latent-diffusion-v9-with-gradio), simplified and made more configurable.
**Known issues**
* Selecting a `clip_variant` other than the default together with the `jack` model configuration leads to an error.
* ViT-L/14 may cause out-of-memory errors on GPUs other than an A100.

In [None]:
!pip install -q omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops
!sudo apt -y -qq install imagemagick 
!pip install -qq timm
!pip install -q gradio
!pip install -q git+https://github.com/openai/CLIP.git


In [None]:
# Install location for the repositories cloned and installed by the setup cell below.
# NOTE(review): a later configuration cell reassigns base_dir to a different path
# (used for checkpoints and imports); confirm both locations are what you expect.
base_dir = "/content/"

In [None]:
%cd $base_dir
!git clone -qq https://github.com/CompVis/latent-diffusion
!git clone -qq https://github.com/CompVis/taming-transformers

!pip install -e -qq ./taming-transformers

%cd $base_dir/latent-diffusion
!git clone -qq https://github.com/Lin-Sinorodin/SwinIR_wrapper.git
!git clone https://github.com/Jack000/glid-3-xl
!pip install -qq -e ./glid-3-xl

!pip install -qq git+https://github.com/lucidrains/DALLE-pytorch


!mkdir -p $base_dir/working
!wget https://cdn.discordapp.com/attachments/932425906847359016/968184632841486336/lite.css -O $base_dir/working/lite.css

print("Restarting runtime, continue running next cells afterwards")

import os

os.kill(os.getpid(), 9)

In [None]:
import torch
import cv2

In [None]:
# Root for checkpoints and code used from here on (this replaces the base_dir set in the setup cells).
base_dir = "/home/ubuntu/vangap/glid-3-xl"
model_base_dir = f"{base_dir}/checkpoints"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Checkpoint files that the next cell downloads into model_base_dir.
finetune_path = f"{model_base_dir}/finetune.pt"
base_path = f"{model_base_dir}/diffusion.pt"
ldm_first_stage_path = f"{model_base_dir}/kl-f8.pt"
inpaint_path = f"{model_base_dir}/inpaint.pt"
bert_path = f"{model_base_dir}/bert.pt"

# CLIP model used for text embeddings and (optionally) CLIP guidance during sampling.
clip_variant = 'ViT-L/14'

# Some options that need to be set BEFORE pressing Run All (run > restart to change later on)
# Models:
'''
"jack": This is the base model finetuned on a clean dataset of photographs by Jack000. It will produce better, higher resolution realistic images without watermarks,
but might not be as good at flat illustrations, some prompts and writing text

"base": This is the base 1.6B model released by CompVis trained on LAION-400M. It is better at illustrations but will sometimes produce blurry and watermarked images,
write text even if unwanted and follow the prompt less.

"inpaint": This is an inpainting model trained by jack0. If you use this, you have to set a mask image and use the Kaggle GUI.
The mask should be the image size and black for spots to fill in, and white for areas to keep. (also try to avoid antialiasing)
'''
# Checkpoint selection: "jack" -> finetune.pt, "base" -> diffusion.pt, "inpaint" -> inpaint.pt.
which_model = "inpaint" # jack, base, inpaint

# GUIs:
'''
Kaggle: GUI using Jupyter Forms. It will show up in the notebook and have a small progress preview if you're generating a single image, but the layout is simpler,
there's no API or queue and you can't share it with others

Gradio: [Does not support the inpaint model] GUI using Gradio. It will give you a gradio.app link (as well as embed in the notebook) with a better layout
that you can share with others, as well as an inbuilt API, but there's no progress preview.
'''
# NOTE(review): which_gui is not read by any code in this notebook; only the Gradio UI is implemented below.
which_gui = "gradio" # kaggle, gradio

# Also used as the upper bound of the "Step Skips" slider in the Gradio UI.
steps = 25 # How many steps diffusion should run for. Not much improvement above 25, lower values might lose detail.

In [None]:
!mkdir -p $base_dir/checkpoints/
%cd $base_dir/checkpoints/
!wget –quiet https://dall-3.com/models/glid-3-xl/bert.pt
!wget –quiet https://dall-3.com/models/glid-3-xl/kl-f8.pt
!wget –quiet https://dall-3.com/models/glid-3-xl/diffusion.pt
!wget –quiet https://dall-3.com/models/glid-3-xl/finetune.pt
!wget –quiet https://dall-3.com/models/glid-3-xl/inpaint.pt

%cd $base_dir/latent-diffusion/
from SwinIR_wrapper.SwinIR_wrapper import SwinIR_SR
import urllib.request
import matplotlib.pyplot as plt

#@title Setup Super Resolution Model { run: "auto" }
pretrained_model = "real_sr x4" #@param ["real_sr x4", "classical_sr x2", "classical_sr x3", "classical_sr x4", "classical_sr x8", "lightweight x2", "lightweight x3", "lightweight x4"]

model_type, scale = pretrained_model.split(' ')
scale = int(scale[1])

# initialize super resolution model
sr = SwinIR_SR(model_type, scale)

print(f'Loaded {pretrained_model} successfully')

#### methods

In [None]:
import gc
import io
import math
import sys

sys.path.append(f"{base_dir}/latent-diffusion/glid-3-xl")

from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

import numpy as np

from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults

from dalle_pytorch import DiscreteVAE, VQGanVAE

from einops import rearrange
from math import log2, sqrt

import argparse
import pickle

import os

from encoders.modules import BERTEmbedder

import clip

class Args:
    """Mutable bag of generation settings shared by model loading and ``do_run``.

    Checkpoint paths and the step count come from the notebook configuration
    (``which_model``, ``finetune_path``, ``base_path``, ``inpaint_path``,
    ``ldm_first_stage_path``, ``bert_path``, ``steps``). The remaining attributes
    are defaults that the GUI callback overwrites for each request.

    Raises:
        ValueError: if ``which_model`` is not ``"jack"``, ``"base"`` or ``"inpaint"``
            (previously any other value silently loaded the inpaint checkpoint).
    """

    def __init__(self):
        checkpoints = {
            "jack": finetune_path,
            "base": base_path,
            "inpaint": inpaint_path,
        }
        if which_model not in checkpoints:
            raise ValueError(
                f"which_model must be one of {sorted(checkpoints)}, got {which_model!r}"
            )
        self.model_path = checkpoints[which_model]
        self.kl_path = ldm_first_stage_path
        self.bert_path = bert_path
        # Prompt and negative prompt.
        self.text = ''
        self.edit = ''
        # Placement/size of the edit image inside the canvas (pixels).
        self.edit_x = 0
        self.edit_y = 0
        self.edit_width = 256
        self.edit_height = 256
        self.mask = ''
        self.negative = ''
        self.init_image = None
        self.skip_timesteps = 0
        # Prefix of the saved output file names.
        self.prefix = ''
        self.num_batches = 1
        self.batch_size = 1
        self.width = 256
        self.height = 256
        # Negative seed means "do not call torch.manual_seed".
        self.seed = -1
        # Classifier-free guidance strength.
        self.guidance_scale = 5.0
        # Honour the notebook-level `steps` setting instead of a hard-coded 25.
        self.steps = steps
        self.cpu = False
        self.clip_score = False
        self.clip_guidance = False
        self.clip_guidance_scale = 150
        # Number of random cutouts scored by CLIP when clip_guidance is on.
        self.cutn = 16
        self.ddim = False
        self.ddpm = False

args = Args()

def fetch(url_or_path, timeout=60):
    """Return the bytes at *url_or_path* as a readable, seekable binary stream.

    ``http://`` and ``https://`` locations are downloaded with ``requests``;
    anything else is treated as a local filesystem path.

    Args:
        url_or_path: URL string or local path (anything whose ``str()`` is one).
        timeout: seconds to wait on the HTTP server before giving up; unused for
            local paths.

    Returns:
        An ``io.BytesIO`` positioned at the start of the data. Local files are
        read fully and closed right away, so callers such as ``Image.open`` do
        not leak file handles.

    Raises:
        requests.HTTPError: if the server returns an error status.
        OSError: if a local path cannot be opened.
    """
    if str(url_or_path).startswith(('http://', 'https://')):
        response = requests.get(url_or_path, timeout=timeout)
        response.raise_for_status()
        return io.BytesIO(response.content)
    with open(url_or_path, 'rb') as handle:
        return io.BytesIO(handle.read())


class MakeCutouts(nn.Module):
    """Sample ``cutn`` random square crops from an image batch, each pooled to ``cut_size``.

    A crop's side is ``rand() ** cut_pow`` scaled into
    ``[min(H, W, cut_size), min(H, W)]``, and its top-left corner is uniform over
    all positions where the crop fits. All crops are concatenated along dim 0.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _one_cutout(self, image, height, width, smallest, largest):
        """Draw a single random square crop of *image* and pool it to ``cut_size``."""
        side = int(torch.rand([]) ** self.cut_pow * (largest - smallest) + smallest)
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        crop = image[:, :, top:top + side, left:left + side]
        return F.adaptive_avg_pool2d(crop, self.cut_size)

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        crops = [
            self._one_cutout(input, height, width, smallest, largest)
            for _ in range(self.cutn)
        ]
        return torch.cat(crops)


def spherical_dist_loss(x, y):
    """Squared great-circle distance between *x* and *y* after L2-normalizing the last dim.

    With unit vectors the chord length ``c`` relates to the angle ``θ`` by
    ``c = 2 sin(θ/2)``, so the result is ``2 * (θ/2) ** 2``.
    """
    chord = (F.normalize(x, dim=-1) - F.normalize(y, dim=-1)).norm(dim=-1)
    half_angle = torch.arcsin(chord / 2)
    return 2 * half_angle ** 2


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The last two dims are edge-padded by one pixel (right/bottom) so every pixel
    has a right and a lower neighbour. Returns the per-sample mean of squared
    horizontal plus vertical differences over dims 1-3.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    horizontal = padded[..., :-1, 1:] - anchor
    vertical = padded[..., 1:, :-1] - anchor
    squared = horizontal.pow(2) + vertical.pow(2)
    return squared.mean(dim=(1, 2, 3))

print('Using device:', device)

# Raw checkpoint weights; also inspected below to infer which model variant this is.
model_state_dict = torch.load(args.model_path, map_location='cpu')

model_params = {
    'attention_resolutions': '32,16,8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': True,
    'timestep_respacing': '27',  # Modify this value to decrease the number of
                                 # timesteps.
    'image_size': 32,
    'learn_sigma': False,
    'noise_schedule': 'linear',
    'num_channels': 320,
    'num_heads': 8,
    'num_res_blocks': 2,
    'resblock_updown': False,
    'use_fp16': False,
    'use_scale_shift_norm': False,
    # Checkpoints with a CLIP projection layer expect CLIP text embeddings.
    'clip_embed_dim': 768 if 'clip_proj.weight' in model_state_dict else None,
    # An 8-channel first conv presumably means 4 latent + 4 image-condition channels
    # (the inpaint model); do_run passes a 4-channel image_embed in that case.
    'image_condition': model_state_dict['input_blocks.0.0.weight'].shape[1] == 8,
    'super_res_condition': 'external_block.0.0.weight' in model_state_dict,
}

# Choose the timestep respacing for the selected sampler. The precedence matches
# do_run's sampler selection (ddpm, then ddim, then the default). Previously the
# ddpm value was an int and was always overwritten by the trailing `elif args.steps`.
if args.ddpm:
    model_params['timestep_respacing'] = '1000'
elif args.ddim:
    model_params['timestep_respacing'] = f'ddim{args.steps}' if args.steps else 'ddim50'
elif args.steps:
    model_params['timestep_respacing'] = str(args.steps)

model_config = model_and_diffusion_defaults()
model_config.update(model_params)

if args.cpu:
    model_config['use_fp16'] = False

# Load models
model, diffusion = create_model_and_diffusion(**model_config)
# NOTE(review): strict=False ignores missing/unexpected keys; a mismatched checkpoint loads silently.
model.load_state_dict(model_state_dict, strict=False)
# Gradients through the diffusion model are only needed for CLIP guidance.
model.requires_grad_(args.clip_guidance).eval().to(device)

if model_config['use_fp16']:
    model.convert_to_fp16()
else:
    model.convert_to_fp32()

def set_requires_grad(model, value):
    """Set ``requires_grad`` to *value* on every parameter of *model*, in place."""
    for _name, parameter in model.named_parameters():
        parameter.requires_grad = value

# vae
# kl-f8.pt is unpickled as a complete module (it is used directly as a model below),
# so only load this checkpoint from a trusted source.
ldm = torch.load(args.kl_path, map_location="cpu")
ldm.to(device)
ldm.eval()
# Gradients through the autoencoder are needed only for CLIP guidance (cond_fn decodes latents).
ldm.requires_grad_(args.clip_guidance)
# NOTE(review): redundant with the requires_grad_ call above; both set every parameter's flag.
set_requires_grad(ldm, args.clip_guidance)

# BERT text encoder providing the diffusion model's text context; weights come from bert.pt.
bert = BERTEmbedder(1280, 32)
sd = torch.load(args.bert_path, map_location="cpu")
bert.load_state_dict(sd)

bert.to(device)
# Half precision to save memory; do_run casts the resulting embeddings back with .float().
bert.half().eval()
set_requires_grad(bert, False)

# clip
# Frozen CLIP model: text embeddings for conditioning and image embeddings for guidance.
clip_model, clip_preprocess = clip.load(clip_variant, device=device, jit=False)
clip_model.eval().requires_grad_(False)
# CLIP's image normalization statistics, applied to the cutouts in cond_fn.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])


In [None]:
from einops import rearrange
from torchvision.utils import make_grid
from PIL import ImageFilter

def do_run(ds, use_ds = True):
    """Run latent-diffusion sampling with the current ``args`` and save the images.

    For each of ``args.num_batches`` batches, ``args.batch_size`` images are sampled.
    Each final decoded image is written to
    ``{base_dir}/output/{args.prefix}_progress_NNNNN.png``.

    Args:
        ds: preview display handle with an ``update`` method. It is only used when
            ``use_ds`` is true; pass ``None`` otherwise.
        use_ds: if true, also save pixelated intermediate previews every 5 steps
            (single-image batches only) and push them to ``ds``.

    Returns:
        None. Results are delivered through the saved PNG files.

    Raises:
        NotImplementedError: if ``args.edit`` is a ``.npy`` file.
        ValueError: if ``args.edit`` is set without ``args.mask``, or if
            ``args.skip_timesteps`` leaves no sampling steps to run.
    """
    if args.seed >= 0:
        torch.manual_seed(args.seed)

    # bert context
    text_emb = bert.encode([args.text]*args.batch_size).to(device).float()
    text_blank = bert.encode([args.negative]*args.batch_size).to(device).float()

    text = clip.tokenize([args.text]*args.batch_size, truncate=True).to(device)
    text_clip_blank = clip.tokenize([args.negative]*args.batch_size, truncate=True).to(device)

    # clip context
    text_emb_clip = clip_model.encode_text(text)
    text_emb_clip_blank = clip_model.encode_text(text_clip_blank)

    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, args.cutn)

    image_embed = None

    # image context
    if args.edit:
        if args.edit.endswith('.npy'):
            # Previously this only printed and then crashed on an undefined `input_image`.
            raise NotImplementedError("ERROR: npy not supported (temp)")
        if not args.mask:
            # Previously this surfaced as an UnboundLocalError on `mask`.
            raise ValueError("args.mask must be set when args.edit is used")

        w = args.edit_width if args.edit_width else args.width
        h = args.edit_height if args.edit_height else args.height

        input_image_pil = Image.open(fetch(args.edit)).convert('RGB')
        input_image_pil = ImageOps.fit(input_image_pil, (w, h))

        # Encode the edit image into latent space ([-1, 1] input range).
        im = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
        im = 2*im-1
        im = ldm.encode(im).sample()

        # Pixel offsets -> latent offsets (the autoencoder downsamples by 8).
        y = args.edit_y//8
        x = args.edit_x//8

        input_image = torch.zeros(1, 4, args.height//8, args.width//8, device=device)

        # How far the edit image overhangs the canvas on the bottom/right.
        ycrop = y + im.shape[2] - input_image.shape[2]
        xcrop = x + im.shape[3] - input_image.shape[3]

        ycrop = ycrop if ycrop > 0 else 0
        xcrop = xcrop if xcrop > 0 else 0

        # Paste the visible part of the encoded edit image into the canvas.
        input_image[0,:,y if y >=0 else 0:y+im.shape[2],x if x >=0 else 0:x+im.shape[3]] = im[:,:,0 if y > 0 else -y:im.shape[2]-ycrop,0 if x > 0 else -x:im.shape[3]-xcrop]

        input_image *= 0.18215

        mask_image = Image.open(fetch(args.mask)).convert('L')
        # LANCZOS replaces Image.ANTIALIAS (same filter), which Pillow 10 removed.
        mask_image = mask_image.resize((args.width//8,args.height//8), Image.LANCZOS)
        mask = transforms.ToTensor()(mask_image).unsqueeze(0).to(device)

        # Binarize: white areas keep the latent, black areas are zeroed for inpainting.
        mask1 = (mask > 0.5).float()

        input_image *= mask1

        image_embed = torch.cat(args.batch_size*2*[input_image], dim=0).float()
    elif model_params['image_condition']:
        # using inpaint model but no image is provided
        image_embed = torch.zeros(args.batch_size*2, 4, args.height//8, args.width//8, device=device)

    # The batch is doubled: first half conditional (prompt), second half unconditional (negative).
    kwargs = {
        "context": torch.cat([text_emb, text_blank], dim=0).float(),
        "clip_embed": torch.cat([text_emb_clip, text_emb_clip_blank], dim=0).float() if model_params['clip_embed_dim'] else None,
        "image_embed": image_embed
    }

    # Create a classifier-free guidance sampling function
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        # NOTE(review): guidance is applied to the first 3 output channels only, while the
        # latent has 4 channels; the 4th passes through unguided. Confirm this is intended.
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + args.guidance_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    # Current timestep, read by cond_fn and decremented by the sampling loop below.
    cur_t = None

    def cond_fn(x, t, context=None, clip_embed=None, image_embed=None):
        with torch.enable_grad():
            x = x[:args.batch_size].detach().requires_grad_()

            n = x.shape[0]

            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t

            kw = {
                'context': context[:args.batch_size],
                'clip_embed': clip_embed[:args.batch_size] if model_params['clip_embed_dim'] else None,
                'image_embed': image_embed[:args.batch_size] if image_embed is not None else None
            }

            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs=kw)

            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)

            x_in /= 0.18215

            x_img = ldm.decode(x_in)

            clip_in = normalize(make_cutouts(x_img.add(1).div(2)))
            clip_embeds = clip_model.encode_image(clip_in).float()
            dists = spherical_dist_loss(clip_embeds.unsqueeze(1), text_emb_clip.unsqueeze(0))
            dists = dists.view([args.cutn, n, -1])

            losses = dists.sum(2).mean(0)

            loss = losses.sum() * args.clip_guidance_scale

            return -torch.autograd.grad(loss, x)[0]

    if args.ddpm:
        sample_fn = diffusion.ddpm_sample_loop_progressive
    elif args.ddim:
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.plms_sample_loop_progressive

    def save_sample(i, sample, clip_score=False, final = False):
        # NOTE(review): clip_score is accepted but not used here.
        for k, image in enumerate(sample['pred_xstart'][:args.batch_size]):
            if args.batch_size == 1 or final:
                # Out-of-place: the old `image /= ...` mutated the sampler's pred_xstart tensor.
                image = image / 0.18215
                im = image.unsqueeze(0)
                out = ldm.decode(im)

                out = TF.to_pil_image(out.squeeze(0).add(1).div(2).clamp(0, 1))
                if use_ds and not final:
                    # kaggle lags if you try to load a 256x256 image once a second so we pixelate it
                    # also looks cooler
                    out = out.resize((64, 64), Image.LANCZOS)
                    out = out.resize((256, 256), Image.NEAREST)
                filename = f'{base_dir}/output/{args.prefix}_progress_{i * args.batch_size + k:05}.png'
                out.save(filename)
                print("saved " + filename)
                if use_ds:
                    if args.batch_size == 1:
                        nImg = PImage(filename=filename)
                        ds.update(nImg)
                    else:
                        ds.update("[no batch preview]")

    if args.init_image:
        init = Image.open(args.init_image).convert('RGB')
        init = init.resize((int(args.width),  int(args.height)), Image.LANCZOS)
        init = TF.to_tensor(init).to(device).unsqueeze(0).clamp(0,1)
        h = ldm.encode(init * 2 - 1).sample() *  0.18215
        init = torch.cat(args.batch_size*2*[h], dim=0)
    else:
        init = None

    for i in range(args.num_batches):
        cur_t = diffusion.num_timesteps - 1

        samples = sample_fn(
            model_fn,
            (args.batch_size*2, 4, int(args.height/8), int(args.width/8)),
            clip_denoised=False,
            model_kwargs=kwargs,
            cond_fn=cond_fn if args.clip_guidance else None,
            device=device,
            progress=True,
            init_image=init,
            skip_timesteps=args.skip_timesteps,
        )

        sample = None
        for j, sample in enumerate(samples):
            cur_t -= 1
            if j % 5 == 0 and j != diffusion.num_timesteps - 1  and args.batch_size == 1 and use_ds:
                save_sample(i, sample)

        if sample is None:
            raise ValueError(
                "no sampling steps were run; args.skip_timesteps must be smaller "
                "than the number of diffusion steps"
            )

        # Save this batch's final images and continue with the next batch (the old
        # `return` here meant only the first of num_batches batches was generated).
        save_sample(i, sample, args.clip_score, True)

def swinUpscale(path, showLarger):
    """Sharpen the image file at *path* in place with the SwinIR model ``sr``.

    The image is upscaled by ``sr`` and then resized back down. The result looks
    sharper than the input at the input's size, or at twice that size.

    Args:
        path: image file to read and overwrite.
        showLarger: if true, write the result at 2x the input width/height;
            otherwise at the input size.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    smallImg = cv2.imread(path, cv2.IMREAD_COLOR)
    if smallImg is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {path}")
    hqImg = sr.upscale(smallImg)
    # now downscale again, so it looks sharp. The target is derived from the *input*
    # size, so it works for any SwinIR scale; the old fx=0.25 / 0.5 assumed an x4 model.
    height, width = smallImg.shape[:2]
    factor = 2 if showLarger else 1
    resized_image = cv2.resize(hqImg, (width * factor, height * factor))
    cv2.imwrite(path, resized_image)


gc.collect()

#### Gradio

In [None]:
%%html
<style>
.jupyter-widgets input[type="text"]{
    min-width: 90% !important;
    border-color: #00000026;
}
.jupyter-widgets input[type="number"]{
    border-color: #00000026;
}
.jupyter-widgets input[type="text"]:focus, .jupyter-widgets input[type="number"]:focus{
    border-bottom: 2px solid #0073ff7a;
}
.widget-button{
    font-family: "Segoe UI";
    background: #f0f0f0;
    border-radius: 23px;
    width: 130px;
    height: 28px;
    
}
.widget-button:hover{
    box-shadow: none !important;
    background: #d9d9d9;
}
</style>

In [None]:
# import ipywidgets as widgets
import time
# from IPython.display import display
from IPython.display import clear_output
from IPython.display import Image as PImage
# from IPython.display import display as PDisplay
from os.path import exists
import shutil
import glob
import gradio as gr
class printer(str):
    """``str`` subclass whose ``repr`` is its raw text, so displayed output has no quotes."""

    def __repr__(self):
        # Hand back the string itself: no quoting and no escape sequences.
        return self

def adv_run(prompt,negative,init_image,skips,guidance,batches,amount_per_batch,width,height,clip_rerank,swin_input,show_large):
        args.text = prompt
        args.negative = negative
        if init_image != None:
            args.init_image = init_image
        else:
            args.init_image = None
        args.skip_timesteps = skips
        args.guidance_scale = guidance
        args.num_batches = batches
        args.batch_size = amount_per_batch
        args.width = width
        args.height = height
        args.clip_score = clip_rerank
        shutil.rmtree(f'{base_dir}/output/', True)
        os.makedirs(f"{base_dir}/output/", exist_ok=True)
        win = do_run(None, False)
        if args.batch_size > 1:
            if swin_input == True:
                for file in tqdm(glob.glob(f"{base_dir}/output/*.png")):
                    swinUpscale(file,show_large)
            !montage -geometry +1+1 -background black $base_dir/output/*.png $base_dir/grid.png
            return Image.open(f"{base_dir}/grid.png")
        if swin_input == True and args.batch_size == 1:
            swinUpscale(f"{base_dir}/output/_progress_00000.png",show_large)
            return Image.open(f"{base_dir}/output/_progress_00000.png")
        if swin_input == False and args.batch_size == 1:
            return Image.open(f"{base_dir}/output/_progress_00000.png")

# Gradio UI. The input widgets map positionally onto adv_run's parameters:
# prompt, negative, init_image, skips, guidance, batches, amount_per_batch,
# width, height, clip_rerank, swin_input, show_large.
# NOTE(review): `gr.inputs.*` is the legacy Gradio component namespace; newer Gradio
# releases may require `gr.Image`, `gr.Slider`, `gr.Checkbox` instead.
# NOTE(review): the setup cell downloads lite.css to $base_dir/working/lite.css (using the
# earlier base_dir); confirm this css path exists, otherwise the stylesheet may not apply.
iface = gr.Interface(fn=adv_run, inputs=["text","text",gr.inputs.Image(shape=(256, 256), optional=True, type="filepath"),gr.inputs.Slider(0, steps,1,default=0,label="Step Skips (required for init image)"),gr.inputs.Slider(1, 15,1,default=5),gr.inputs.Slider(1, 32,1,default=1),gr.inputs.Slider(1, 16,1,default=1),gr.inputs.Slider(16, 512, 16,default=256),gr.inputs.Slider(16, 512, 16,default=256), gr.inputs.Checkbox(default=False, label="Clip Rerank (for batch images)", optional=False),gr.inputs.Checkbox(default=True, label="Increase sharpness using SwinIR", optional=False),gr.inputs.Checkbox(default=False, label="Show SwinIR results as 512x512 (less sharp)", optional=False)
], outputs="image", css=f"{base_dir}/lite.css")
# share=True requests a public gradio.app link in addition to the local server.
iface.launch(share=True,debug=True, inline=False, enable_queue = True)