This Colab requires the Stable Diffusion checkpoint (`sd-v1-4.ckpt`) to be on your Google Drive; if it is not there yet, the "Load the stable-diffusion model" cell below can download it for you.

In [None]:
#@title Install requirements
# Clone the neonsecret fork of stable-diffusion and switch into the checkout.
!git clone https://github.com/neonsecret/stable-diffusion.git
%cd /content/stable-diffusion
# Third-party packages, installed without version pins.
# NOTE(review): the UI cells use gr.Box and gr.Image(tool=...); confirm these exist in the gradio version that gets installed.
!pip install gradio albumentations diffusers opencv-python pudb invisible-watermark imageio imageio-ffmpeg pytorch-lightning omegaconf test-tube streamlit einops torch-fidelity transformers torchmetrics kornia 
# !pip install taming
# CLIP is installed in editable mode from its main branch; k-diffusion directly from GitHub.
!pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
!pip install git+https://github.com/crowsonkb/k-diffusion

In [None]:
#@title Install requirements step 2
# Install the cloned repository itself in editable mode.
%cd /content/stable-diffusion
!pip install -e .
# Pillow is pinned to 8.4.0 here; taming-transformers comes from the rom1504 package.
!pip install Pillow==8.4.0 taming-transformers-rom1504
# Upgrade pytorch-lightning to the latest release after the installs above.
!pip install --upgrade pytorch-lightning

In [None]:
#@title Mount colab
# Mount Google Drive at /content/drive (Colab asks for authorization).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title Stopping RUNTIME! If won't reconnect automatically, please run again.
# Sends signal 9 to this kernel's own process, which terminates it immediately.
# NOTE(review): presumably done so the freshly installed packages are picked up after the restart — confirm.
import os
os.kill(os.getpid(), 9)

In [None]:
#@markdown # Load the stable-diffusion model

#@markdown **Download the model if it isn't already in the 'models_path' folder**

#@markdown To download the model, you need to have accepted the terms [HERE](https://huggingface.co/CompVis/stable-diffusion-v1-4)
#@markdown and have copied a token from [HERE](https://huggingface.co/settings/tokens)
download_if_missing = True #@param {type:"boolean"}
token = "" #@param {type:"string"}

#@markdown **Google Drive Path Variables**
mount_google_drive = True #@param {type:"boolean"}
force_remount = False

%cd /content/
import os
mount_success = True
if mount_google_drive:
    from google.colab import drive
    try:
        drive_path = "/content/drive"
        drive.mount(drive_path,force_remount=force_remount)
        models_path_gdrive = "/content/drive/MyDrive/" #@param {type:"string"}
        output_path_gdrive = "/content/drive/MyDrive/outputs" #@param {type:"string"}
        models_path = models_path_gdrive
        output_path = output_path_gdrive
    except:
        print("...error mounting drive or with drive path variables")
        print("...reverting to default path variables")
        mount_success = False
        output_path = "/content/outputs"

os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

if download_if_missing:
    if not mount_success:
        print("Downloading model to " + models_path + " due to gdrive mount error")
    if token == "":
        print("No token provided. Assuming model is already in " + models_path)
    elif not os.path.exists(models_path + '/sd-v1-4.ckpt'):
        !git lfs install --system --skip-repo
        !mkdir sd-model
        %cd /content/sd-model/
        !git init
        !git remote add -f origin "https://USER:{token}@huggingface.co/CompVis/stable-diffusion-v-1-4-original"
        !git config core.sparsecheckout true
        !echo "sd-v1-4.ckpt" > .git/info/sparse-checkout
        !git pull origin main
        !mv '/content/sd-model/sd-v1-4.ckpt' '{models_path}/'
    else:
        print("Model already downloaded, moving to next step")

print(f"models_path: {models_path}")
print(f"output_path: {output_path}")

In [None]:
#@title Load the models
%cd /content/stable-diffusion
import torch
import argparse
import asyncio
import logging
import os
import re
import sys
import time
from contextlib import nullcontext
from itertools import islice
from random import randint
import sys
sys.path.append("optimizedSD/")

import gradio as gr
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from transformers import logging as transformers_logging
import mimetypes
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts

# Restrict the transformers library's own logging to errors.
transformers_logging.set_verbosity_error()

# Register the ".js" extension as application/javascript in the mimetypes registry.
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")


def chunk(it, size):
    """Return an iterator over consecutive tuples of up to ``size`` items from ``it``.

    The last tuple may be shorter; iteration ends at the first empty tuple.
    """
    source = iter(it)

    def take_next():
        # Pull the next slice of at most `size` items from the shared iterator.
        return tuple(islice(source, size))

    # Two-argument iter(): call take_next() until it returns the sentinel ().
    return iter(take_next, ())


async def get_logs():
    """Return the lines of ``log.txt`` followed by those of ``tqdm.txt`` as one string.

    Both files are looked up in the current working directory and read as
    UTF-8. A file that does not exist contributes no lines instead of raising,
    so the call works before anything has been logged. Lines keep their own
    line endings and are additionally joined with a newline, as before.
    """
    collected = []
    for log_name in ("log.txt", "tqdm.txt"):
        try:
            # The context manager closes the handle; previously it was left open.
            with open(log_name, "r", encoding="utf8") as log_file:
                collected.extend(log_file.readlines())
        except FileNotFoundError:
            continue
    return "\n".join(collected)


async def get_nvidia_smi():
    """Run ``nvidia-smi`` in a shell and return what it printed, as text.

    stderr is merged into stdout so that a failure (for example the command
    not being available) is visible in the returned text. The bytes are
    decoded as UTF-8 with replacement of undecodable bytes; previously the
    ``bytes`` repr (``b'...'`` with escaped newlines) was returned.
    """
    proc = await asyncio.create_subprocess_shell(
        'nvidia-smi',
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    # With stderr redirected into stdout, the second item of the pair is None.
    stdout, _ = await proc.communicate()
    return stdout.decode("utf-8", errors="replace")

def load_model_from_config(ckpt, verbose=False):
    """Load a checkpoint file onto the CPU and return its state dict.

    Args:
        ckpt: Path of the checkpoint file, passed to ``torch.load``.
        verbose: Accepted for call compatibility; not used by this function.

    Returns:
        The value stored under ``"state_dict"`` when the checkpoint has that
        key; otherwise the loaded object itself, so checkpoints saved as a
        bare state dict also work instead of raising ``KeyError``.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load is pickle-based; only load checkpoints from a
    # source you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    return pl_sd.get("state_dict", pl_sd)


# Model definition file and checkpoint location.
config = "optimizedSD/v1-inference.yaml"
ckpt = '/content/drive/MyDrive/sd-v1-4.ckpt' #@param {type:"string"}
sd = load_model_from_config(ckpt)

# Sort every "model.*" weight name into one of two groups: names containing an
# input_blocks / middle_block / time_embed component go to `li`, the rest to `lo`.
FIRST_GROUP_MARKERS = ("input_blocks", "middle_block", "time_embed")
li, lo = [], []
for key in sd:
    parts = key.split(".")
    if parts[0] != "model":
        continue
    target = li if any(marker in parts for marker in FIRST_GROUP_MARKERS) else lo
    target.append(key)

# Re-key the weights: the leading "model." (6 characters) is replaced by
# "model1." for the first group and "model2." for the second.
for new_prefix, group in (("model1.", li), ("model2.", lo)):
    for key in group:
        sd[new_prefix + key[6:]] = sd.pop(key)

config = OmegaConf.load(config)

# Build the three sub-models from their config sections, load whatever weights
# match (strict=False ignores the rest) and switch each to eval mode.
model = instantiate_from_config(config.modelUNet)
model.load_state_dict(sd, strict=False)
model.eval()

modelCS = instantiate_from_config(config.modelCondStage)
modelCS.load_state_dict(sd, strict=False)
modelCS.eval()

modelFS = instantiate_from_config(config.modelFirstStage)
modelFS.load_state_dict(sd, strict=False)
modelFS.eval()

# The raw state dict is no longer needed once all three models are loaded.
del sd

# Okay next run one of the three (img2img, txt2img, inpainting)

In [6]:
#@title Optimized img2img
from einops import rearrange, repeat
def generate(
        image,
        prompt,
        strength,
        ddim_steps,
        n_iter,
        batch_size,
        Height,
        Width,
        scale,
        ddim_eta,
        unet_bs,
        device,
        seed,
        outdir,
        img_format,
        turbo,
        full_precision,
        sampler,
        speed_mp
):
    """Run img2img for one prompt and save every generated image to disk.

    Uses the module-level ``model``, ``modelCS`` and ``modelFS`` objects.

    Args:
        image: Initial image; handed to ``load_img`` together with Height/Width.
        prompt: Text prompt. It is passed through ``split_weighted_subprompts``;
            when that yields several sub-prompts their conditionings are
            combined with normalized weights.
        strength: Value in [0.0, 1.0]; ``int(strength * ddim_steps)`` is the
            number of encode/decode steps.
        ddim_steps: Number of sampling steps.
        n_iter: Number of sampling rounds.
        batch_size: Images produced per round.
        Height, Width: Target size handed to ``load_img``.
        scale: Unconditional guidance scale; with 1.0 no unconditional
            conditioning is computed.
        ddim_eta: Passed to ``model.stochastic_encode``.
        unet_bs: Stored on ``model.unet_bs``.
        device: "cuda" or "cpu".
        seed: Integer-like seed; anything that cannot be converted to int is
            replaced by a random seed. Incremented after every saved image.
        outdir: Root output folder; a sub-folder derived from the prompt is used.
        img_format: File extension for the saved images.
        turbo: Stored on ``model.turbo``.
        full_precision: When False and the device is not "cpu", the models and
            the input image are converted to half precision.
        sampler: Sampler name forwarded to ``model.sample``.
        speed_mp: Forwarded to ``model.sample``.

    Returns:
        Tuple ``(grid_image, text)``: a PIL image with all samples arranged in
        a grid, and a summary with the elapsed time, output folder and seeds.
    """
    assert prompt is not None
    logging.info(f"prompt: {prompt}, W: {Width}, H: {Height}")

    # UI slider values may arrive as floats; they are used as counts and sizes.
    ddim_steps = int(ddim_steps)
    n_iter = int(n_iter)
    batch_size = int(batch_size)
    unet_bs = int(unet_bs)
    if Height is not None:
        Height = int(Height)
    if Width is not None:
        Width = int(Width)

    init_image = load_img(image, Height, Width).to(device)
    model.unet_bs = unet_bs
    model.turbo = turbo
    model.cdevice = device
    modelCS.cond_stage_model.device = device

    # Empty or non-numeric seed input falls back to a random seed.
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = randint(0, 1000000)

    if device != "cpu" and not full_precision:
        model.half()
        modelCS.half()
        modelFS.half()
        init_image = init_image.half()

    tic = time.time()
    os.makedirs(outdir, exist_ok=True)
    # "/" is removed so the prompt cannot create nested folders (the txt2img
    # cell does the same); the whole path is capped at 150 characters.
    sample_path = os.path.join(outdir, "_".join(re.split(":| ", prompt.replace("/", ""))))[:150]
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

    data = [batch_size * [prompt]]

    modelFS.to(device)

    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space

    if device != "cpu":
        # Move the first-stage model off the GPU and wait until the allocated
        # CUDA memory has dropped below the level measured before the move.
        mem = torch.cuda.memory_allocated() / 1e6
        modelFS.to("cpu")
        while torch.cuda.memory_allocated() / 1e6 >= mem:
            time.sleep(1)

    assert 0.0 <= strength <= 1.0, "can only work with strength in [0.0, 1.0]"
    t_enc = int(strength * ddim_steps)
    print(f"target t_enc is {t_enc} steps")

    if not full_precision and device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    all_samples = []
    seeds = ""
    with torch.no_grad():
        for _ in trange(n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if scale != 1.0:
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        # uc is None when scale == 1.0, so it cannot always
                        # serve as the shape template for the accumulator.
                        template = uc if uc is not None else modelCS.get_learned_conditioning(batch_size * [""])
                        c = torch.zeros_like(template)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            c = torch.add(c, modelCS.get_learned_conditioning(subprompt), alpha=weight / totalWeight)
                    else:
                        c = modelCS.get_learned_conditioning(prompts)

                    if device != "cpu":
                        mem = torch.cuda.memory_allocated() / 1e6
                        modelCS.to("cpu")
                        while torch.cuda.memory_allocated() / 1e6 >= mem:
                            time.sleep(1)

                    # encode (scaled latent)
                    z_enc = model.stochastic_encode(
                        init_latent, torch.tensor([t_enc] * batch_size).to(device), seed, ddim_eta, ddim_steps
                    )
                    # decode it
                    samples_ddim = model.sample(
                        t_enc,
                        c,
                        z_enc,
                        unconditional_guidance_scale=scale,
                        unconditional_conditioning=uc,
                        sampler=sampler,
                        speed_mp=speed_mp,
                        batch_size=batch_size
                    )

                    modelFS.to(device)
                    print("saving images")
                    for i in range(batch_size):
                        x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                        # Map from [-1, 1] to [0, 1] before converting to 8-bit.
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        all_samples.append(x_sample.to("cpu"))
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
                        )
                        seeds += str(seed) + ","
                        seed += 1
                        base_count += 1

                    if device != "cpu":
                        mem = torch.cuda.memory_allocated() / 1e6
                        modelFS.to("cpu")
                        while torch.cuda.memory_allocated() / 1e6 >= mem:
                            time.sleep(1)

                    del samples_ddim
                    del x_sample
                    del x_samples_ddim
                    print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

    toc = time.time()

    time_taken = (toc - tic) / 60.0
    grid = torch.cat(all_samples, 0)
    grid = make_grid(grid, nrow=n_iter)
    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()

    txt = (
            "Samples finished in "
            + str(round(time_taken, 3))
            + " minutes and exported to \n"
            + sample_path
            + "\nSeeds used = "
            + seeds[:-1]
    )
    return Image.fromarray(grid.astype(np.uint8)), txt

def load_img(image, h0, w0):
    """Convert an image object into a normalized tensor.

    Args:
        image: Image object providing ``convert``, ``size`` and ``resize``
            (the UI passes it with ``type="pil"``).
        h0, w0: Requested height and width. When both are given they replace
            the image's own size; otherwise the image's size is used.

    Returns:
        A float32 tensor laid out as (1, channels, height, width) with values
        in [-1, 1]. Each side is rounded down to a multiple of 64.

    Raises:
        ValueError: If ``image`` is None (no initial image was supplied).
    """
    if image is None:
        raise ValueError("No initial image was provided")

    image = image.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    # Round down to an integer multiple of 64; int() covers sizes that
    # arrive as floats.
    w, h = (int(side) - int(side) % 64 for side in (w, h))

    print(f"New image size ({w}, {h})")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # Add a leading batch axis and move channels in front of height/width.
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # Rescale from [0, 1] to [-1, 1].
    return 2.0 * image - 1.0

def chunk(it, size):
    """Split ``it`` lazily into tuples holding at most ``size`` items each.

    Iteration stops as soon as an empty tuple would be produced.
    """
    iterator = iter(it)
    next_group = lambda: tuple(islice(iterator, size))
    # iter(callable, sentinel) keeps calling until the sentinel () comes back.
    return iter(next_group, ())

# Gradio front end for img2img: results and buttons in the left column,
# generation settings in the right one.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Stable diffusion img2img (neonsecret's adjustments)")
        gr.Markdown("### Press 'print logs' button to get the model output logs")
        with gr.Row():
            with gr.Column():
                result_image = gr.Image(label="Output Image")
                result_text = gr.Text(label="Generation results")
                logs_text = gr.Text(label="Logs")
                smi_text = gr.Text(label="nvidia-smi")
                generate_button = gr.Button("Generate!")
                logs_button = gr.Button("Print logs")
                smi_button = gr.Button("nvidia-smi")
            with gr.Column():
                with gr.Box():
                    # Listed in the positional order of generate()'s parameters.
                    generation_inputs = [
                        gr.Image(tool="editor", type="pil", label="Initial image"),
                        gr.Text(label="Your Prompt"),
                        gr.Slider(0, 1, value=0.75, label="Generated image strength"),
                        gr.Slider(1, 1000, value=50, label="Sampling Steps"),
                        gr.Slider(1, 100, step=1, label="Number of images"),
                        gr.Slider(1, 100, step=1, label="Batch size"),
                        gr.Slider(64, 4096, value=512, step=64, label="Height"),
                        gr.Slider(64, 4096, value=512, step=64, label="Width"),
                        gr.Slider(0, 50, value=7.5, step=0.1, label="Guidance scale"),
                        gr.Slider(0, 1, step=0.01, label="DDIM sampling ETA"),
                        gr.Slider(1, 2, value=1, step=1, label="U-Net batch size"),
                        gr.Radio(["cuda", "cpu"], value="cuda", label="Device"),
                        gr.Text(label="Seed"),
                        gr.Text(value=output_path, label="Outputs path"),
                        gr.Radio(["png", "jpg"], value='png', label="Image format"),
                        gr.Checkbox(value=True, label="Turbo mode (better leave this on)"),
                        gr.Checkbox(label="Full precision mode (practically does nothing)"),
                        gr.Radio(["ddim", "plms", "k_dpm_2_a", "k_dpm_2", "k_euler_a", "k_euler", "k_heun", "k_lms"], value="plms", label="Sampler"),
                        gr.Slider(1, 100, value=100, step=1,
                                  label="%, VRAM usage limiter (100 means max speed)"),
                    ]
                    generate_button.click(generate, inputs=generation_inputs,
                                          outputs=[result_image, result_text])
                    logs_button.click(get_logs, inputs=[], outputs=[logs_text])
                    smi_button.click(get_nvidia_smi, inputs=[], outputs=[smi_text])

debug = False #@param {type:"boolean"}
# share=True asks Gradio for a public link in addition to the local one.
demo.launch(share=True, debug=debug)

Keyboard interruption in main thread... closing server.


(<gradio.routes.App at 0x7f56aa42bb10>,
 'http://127.0.0.1:7862/',
 'https://26054.gradio.app')

In [5]:
#@title Optimized txt2img
def generate(
        prompt,
        ddim_steps,
        n_iter,
        batch_size,
        Height,
        Width,
        scale,
        ddim_eta,
        unet_bs,
        device,
        seed,
        outdir,
        img_format,
        turbo,
        full_precision,
        sampler,
        speed_mp
):
    """Run txt2img for one prompt and save every generated image to disk.

    Uses the module-level ``model``, ``modelCS`` and ``modelFS`` objects.

    Args:
        prompt: Text prompt. It is passed through ``split_weighted_subprompts``;
            when that yields several sub-prompts their conditionings are
            combined with normalized weights.
        ddim_steps: Number of sampling steps (``S`` of ``model.sample``).
        n_iter: Number of sampling rounds.
        batch_size: Images produced per round.
        Height, Width: Image size; the sampled shape uses each divided by 8.
        scale: Unconditional guidance scale; with 1.0 no unconditional
            conditioning is computed.
        ddim_eta: Passed to ``model.sample`` as ``eta``.
        unet_bs: Stored on ``model.unet_bs``.
        device: "cuda" or "cpu".
        seed: Integer-like seed; anything that cannot be converted to int is
            replaced by a random seed. Incremented after every saved image.
        outdir: Root output folder; a sub-folder derived from the prompt is used.
        img_format: File extension for the saved images.
        turbo: Stored on ``model.turbo``.
        full_precision: When False and the device is not "cpu", the models are
            converted to half precision.
        sampler: Sampler name forwarded to ``model.sample``.
        speed_mp: Forwarded to ``model.sample``.

    Returns:
        Tuple ``(grid_image, text)``: a PIL image with all samples arranged in
        a grid, and a summary with the elapsed time, output folder and seeds.
    """
    assert prompt is not None
    logging.info(f"prompt: {prompt}, W: {Width}, H: {Height}")

    # UI slider values may arrive as floats; they are used as counts and sizes.
    ddim_steps = int(ddim_steps)
    n_iter = int(n_iter)
    batch_size = int(batch_size)
    unet_bs = int(unet_bs)
    Height = int(Height)
    Width = int(Width)

    C = 4  # channel count used for the sampled shape
    f = 8  # factor by which Height/Width are divided for the sampled shape
    start_code = None
    model.to(device)
    model.unet_bs = unet_bs
    model.turbo = turbo
    model.cdevice = device
    modelCS.cond_stage_model.device = device

    # Empty or non-numeric seed input falls back to a random seed instead of
    # raising (the img2img cell behaves the same way).
    try:
        seed = int(seed)
    except (TypeError, ValueError):
        seed = randint(0, 1000000)
    seed_everything(seed)

    if device != "cpu" and not full_precision:
        model.half()
        modelFS.half()
        modelCS.half()

    tic = time.time()
    os.makedirs(outdir, exist_ok=True)
    # "/" is removed so the prompt cannot create nested folders; the whole
    # path is capped at 150 characters.
    sample_path = os.path.join(outdir, "_".join(re.split(":| ", prompt.replace("/", ""))))[:150]
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

    data = [batch_size * [prompt]]

    if device != "cpu" and not full_precision:
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    seeds = ""
    with torch.no_grad():
        all_samples = list()
        for _ in trange(n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if scale != 1.0:
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        # uc is None when scale == 1.0, so it cannot always
                        # serve as the shape template for the accumulator.
                        template = uc if uc is not None else modelCS.get_learned_conditioning(batch_size * [""])
                        c = torch.zeros_like(template)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            c = torch.add(c, modelCS.get_learned_conditioning(subprompt), alpha=weight / totalWeight)
                    else:
                        c = modelCS.get_learned_conditioning(prompts)

                    shape = [batch_size, C, Height // f, Width // f]

                    if device != "cpu":
                        # Move the conditioning model off the GPU and wait
                        # until the allocated CUDA memory has dropped below
                        # the level measured before the move.
                        mem = torch.cuda.memory_allocated() / 1e6
                        modelCS.to("cpu")
                        while torch.cuda.memory_allocated() / 1e6 >= mem:
                            time.sleep(1)
                    samples_ddim = model.sample(
                        S=ddim_steps,
                        conditioning=c,
                        seed=seed,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=scale,
                        unconditional_conditioning=uc,
                        eta=ddim_eta,
                        x_T=start_code,
                        sampler=sampler,
                        speed_mp=speed_mp
                    )

                    modelFS.to(device)
                    model.cpu()
                    logging.info("saving images")
                    for i in range(batch_size):
                        x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                        # Map from [-1, 1] to [0, 1] before converting to 8-bit.
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        all_samples.append(x_sample.to("cpu"))
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        Image.fromarray(x_sample.astype(np.uint8)).save(
                            os.path.join(sample_path, "seed_" + str(seed) + "_" + f"{base_count:05}.{img_format}")
                        )
                        seeds += str(seed) + ","
                        seed += 1
                        base_count += 1

                    if device != "cpu":
                        mem = torch.cuda.memory_allocated() / 1e6
                        modelFS.to("cpu")
                        while torch.cuda.memory_allocated() / 1e6 >= mem:
                            time.sleep(1)

                    del samples_ddim
                    del x_sample
                    del x_samples_ddim
                    logging.info(str("memory_final = " + str(torch.cuda.memory_allocated() / 1e6)))

    toc = time.time()

    time_taken = (toc - tic) / 60.0
    grid = torch.cat(all_samples, 0)
    grid = make_grid(grid, nrow=n_iter)
    grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
    txt = (
            "Samples finished in "
            + str(round(time_taken, 3))
            + " minutes and exported to "
            + sample_path
            + "\nSeeds used = "
            + seeds[:-1]
    )
    return Image.fromarray(grid.astype(np.uint8)), txt


class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes each formatted record through ``tqdm.write``."""

    def __init__(self, level=logging.NOTSET):
        logging.Handler.__init__(self, level)

    def emit(self, record):
        """Format ``record``, write it via ``tqdm.write`` and flush.

        Any ``Exception`` raised on the way is handed to ``handleError``.
        """
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

# Gradio front end for txt2img: results and buttons in the left column,
# generation settings in the right one.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Stable diffusion txt2img (neonsecret's adjustments)")
        gr.Markdown("### Press 'print logs' button to get the model output logs")
        with gr.Row():
            with gr.Column():
                result_image = gr.Image(label="Output Image")
                result_text = gr.Text(label="Generation results")
                logs_text = gr.Text(label="Logs")
                smi_text = gr.Text(label="nvidia-smi")
                generate_button = gr.Button("Generate!")
                logs_button = gr.Button("Print logs")
                smi_button = gr.Button("nvidia-smi")
            with gr.Column():
                with gr.Box():
                    # Listed in the positional order of generate()'s parameters.
                    generation_inputs = [
                        gr.Text(label="Your Prompt"),
                        gr.Slider(1, 1000, value=50, label="Sampling Steps"),
                        gr.Slider(1, 100, step=1, label="Number of images"),
                        gr.Slider(1, 100, step=1, label="Batch size"),
                        gr.Slider(64, 4096, value=512, step=64, label="Height"),
                        gr.Slider(64, 4096, value=512, step=64, label="Width"),
                        gr.Slider(0, 50, value=7.5, step=0.1, label="Guidance scale"),
                        gr.Slider(0, 1, step=0.01, label="DDIM sampling ETA"),
                        gr.Slider(1, 2, value=1, step=1, label="U-Net batch size"),
                        gr.Radio(["cuda", "cpu"], value="cuda", label="Device"),
                        gr.Text(label="Seed"),
                        gr.Text(value=output_path, label="Outputs path"),
                        gr.Radio(["png", "jpg"], value='png', label="Image format"),
                        gr.Checkbox(value=True, label="Turbo mode (better leave this on)"),
                        gr.Checkbox(label="Full precision mode (practically does nothing)"),
                        gr.Radio(["ddim", "plms", "k_dpm_2_a", "k_dpm_2", "k_euler_a", "k_euler", "k_heun", "k_lms"], value="plms", label="Sampler"),
                        gr.Slider(1, 100, value=100, step=1,
                                  label="%, VRAM usage limiter (100 means max speed)"),
                    ]
                    generate_button.click(generate, inputs=generation_inputs,
                                          outputs=[result_image, result_text])
                    logs_button.click(get_logs, inputs=[], outputs=[logs_text])
                    smi_button.click(get_nvidia_smi, inputs=[], outputs=[smi_text])

debug = False #@param {type:"boolean"}
# share=True asks Gradio for a public link in addition to the local one.
demo.launch(share=True, debug=debug)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://18776.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


INFO:pytorch_lightning.utilities.seed:Global seed set to 601208
Sampling:   0%|          | 0/1 [00:00<?, ?it/s]
data:   0%|          | 0/1 [00:00<?, ?it/s][A

seeds used =  [601208]
Data shape for PLMS sampling is [1, 4, 224, 224]
Running PLMS Sampling with 1 timesteps


PLMS Sampler:   0%|          | 0/1 [00:00<?, ?it/s]


data: 100%|██████████| 1/1 [01:16<00:00, 76.17s/it]
Sampling: 100%|██████████| 1/1 [01:16<00:00, 76.19s/it]
INFO:pytorch_lightning.utilities.seed:Global seed set to 846492
Sampling:   0%|          | 0/1 [00:00<?, ?it/s]
data:   0%|          | 0/1 [00:00<?, ?it/s][A

seeds used =  [846492]
Data shape for PLMS sampling is [1, 4, 144, 256]
Running PLMS Sampling with 1 timesteps


PLMS Sampler:   0%|          | 0/1 [00:00<?, ?it/s]


data: 100%|██████████| 1/1 [00:40<00:00, 40.17s/it]
Sampling: 100%|██████████| 1/1 [00:40<00:00, 40.19s/it]
INFO:pytorch_lightning.utilities.seed:Global seed set to 22735
Sampling:   0%|          | 0/1 [00:00<?, ?it/s]
data:   0%|          | 0/1 [00:00<?, ?it/s][A

seeds used =  [22735]


  0%|          | 0/50 [00:00<?, ?it/s]


data: 100%|██████████| 1/1 [26:24<00:00, 1584.11s/it]
Sampling: 100%|██████████| 1/1 [26:24<00:00, 1584.13s/it]


Keyboard interruption in main thread... closing server.


(<gradio.routes.App at 0x7f562a29a750>,
 'http://127.0.0.1:7861/',
 'https://18776.gradio.app')