<a href="https://colab.research.google.com/github/juanalonso/360Diffusion/blob/main/360Diffusion_Simplified-Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 360Diffusion [Demo] - Text-to-Image Generation for Beginners

## Are you new to AI image generation?

This demo has been developed with beginners in mind. Many of the advanced settings have been hidden so as not to overwhelm those unfamiliar with Google Colab. Depending on the GPU Google assigns you, you can generate 1 image every 20-60 seconds. This demo uses the 256 model, as it requires fewer resources and is therefore less prone to errors. Images are upscaled to 1024x1024px after each generation.

---

### Major components in this notebook were originally developed and/or revised by:

- `Katherine Crowson` (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings); author of the original CLIP Guided Diffusion notebook; a major figure in the AI image-generation scene
- `Daniel Russell` (https://github.com/russelldc, https://twitter.com/danielrussruss); Forked Katherine Crowson's Diffusion notebook into Quick / Fast Diffusion
- `sadnow` (https://github.com/sadnow); Forked Daniel's notebook into 360Diffusion with integration of Real-ESRGAN upscaling
- Parameter research assisted in part by the following community (https://www.patreon.com/sportsracer48); pyTTI / VQLIPSE animation
---
To begin, click the play button on the left-hand side of each cell. There are 2 steps to run. After you have run the 1st step, you do not need to run it again during the same session.

In [None]:
#@title  { form-width: "100px" }

#@markdown #Step 1 (Only needs to be run once per session)


##@markdown ##Check GPU
##@markdown ***This Notebook supports "Run All". Once everything is set to your liking, you can do `Ctrl+F9`!***
import torch
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Both branches print the GPU status table with nvidia-smi.
# NOTE(review): the "saves ram" label looks inaccurate -- only the False branch
# does anything extra, running `nvidia-smi -i 0 -e 0`, which presumably asks the
# driver to disable ECC on GPU 0. Confirm the intent before relying on it.
enable_error_checking = True #saves ram
if enable_error_checking:
  !nvidia-smi
else:
  !nvidia-smi
  !nvidia-smi -i 0 -e 0
##########

#universal functions
def dir_make(a):
  """Create directory ``a`` (including missing parents) if nothing exists at that path yet.

  Prints a notice only when the directory is actually created.
  """
  import os
  if not os.path.exists(a):
    print("Creating " + a + "...")
    # Plain-Python replacement for the former `!mkdir -p $a` shell magic: it
    # works outside IPython and is not broken by spaces or shell
    # metacharacters in the path.
    os.makedirs(a, exist_ok=True)

def dir_check(a):
  """Return True when anything (directory or file) exists at path ``a``, else False."""
  import os.path
  return os.path.exists(a)

def file_check(a):
  """Return True when ``a`` names an existing regular file, else False."""
  import os.path
  return os.path.isfile(a)

#def image_upscale(model_path,scale,input,output):
#    %cd /content/Real-ESRGAN/
#    !python inference_realesrgan.py --model_path $model_path --netscale $scale --face_enhance --input $input --output $output --ext jpg
#    %cd /content/

##@title Choose model here:
# diffusion_model = "256x256_diffusion_uncond" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
# The demo is pinned to the 256px checkpoint; the 512px form option above is disabled.
diffusion_model = "256x256_diffusion_uncond" #demo-specific


google_drive = True #@param {type:"boolean"}
if google_drive:
  from google.colab import drive

#@markdown Using Google Drive is optional but recommended. If enabled, your images will be saved in "AI/Diffusion-Demo." You can find newly-generated images in your "Recent" window when opening Google Drive.

# Drive folder used as the checkpoint cache and as the parent of the image output folder.
_drive_location = '/content/drive/MyDrive/AI/Diffusion-Demo/'


#####################################################################
##@title Google Drive & Download diffusion model

# Directory the checkpoint is downloaded into and later loaded from
# (the model is loaded from f'{model_path}{diffusion_model}.pt').
model_path = '/content/'
if google_drive:
    # Hard-coded True in the demo, so with Drive enabled the checkpoint is cached on Drive.
    yes_please = True #demo-specifc
    print("NOTE: Because you are using Google Drive, the model file will be saved in ",_drive_location," for faster reloading")
    #^ demo specific
    from google.colab import drive
    drive.mount('/content/drive')
    if yes_please:
        dir_make(_drive_location)
        model_path = _drive_location


# wget --continue resumes a partially downloaded file instead of starting over;
# -P sets the directory the file is saved into.
if diffusion_model == '256x256_diffusion_uncond':
    !wget --continue 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' -P {model_path}
elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
    !wget --continue 'https://the-eye.eu/public/AI/models/512x512_diffusion_unconditional_ImageNet/512x512_diffusion_uncond_finetune_008100.pt' -P {model_path}
    #_diffusion_int = 512
# NOTE(review): yes_please is hard-coded True above, so this branch never runs.
# When google_drive is False, short-circuiting keeps the (then undefined)
# yes_please from being evaluated.
if google_drive and not yes_please:
    model_path = _drive_location
#################################################

def install_ESRGAN():
    """Clone Real-ESRGAN into /content/Real-ESRGAN, install it, and fetch its weights.

    Installs basicsr, facexlib, gfpgan and the repo's requirements, runs
    ``setup.py develop``, then downloads RealESRGAN_x4plus.pth and
    RealESRNet_x4plus.pth into ``experiments/pretrained_models`` (``wget -nc``
    skips files that already exist). Leaves the working directory at /content/.
    """
    %cd /content/
    print("Installing libraries for Real-ESRGAN upscaling.")
    !git clone https://github.com/xinntao/Real-ESRGAN.git
    %cd /content/Real-ESRGAN
    !pip install basicsr -q
    !pip install facexlib -q
    !pip install gfpgan -q
    !pip install -r requirements.txt -q
    %cd /content/Real-ESRGAN
    !python setup.py develop -q
    !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
    # Optional extra weights (disabled):
    # !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
    # !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P experiments/pretrained_models
    !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth -P experiments/pretrained_models
    # !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth -P experiments/pretrained_models  #regular esrgan


    print("Finished Installing libraries for Real-ESRGAN upscaling.")
    # Return to the notebook's root directory for the code that follows.
    %cd /content/
    #


# Fetch CLIP and guided-diffusion and install both as editable packages, install
# LPIPS, then set up the Real-ESRGAN upscaler.
# NOTE(review): `datetime` is part of the standard library; installing a PyPI
# package by that name here looks unnecessary -- confirm before removing it.
!git clone https://github.com/openai/CLIP
!git clone https://github.com/crowsonkb/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./guided-diffusion
!pip install lpips datetime
install_ESRGAN()

#########################################################

import time
import gc
import io
import math
import sys
from IPython import display
import lpips
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from datetime import datetime #filename
import numpy as np
import matplotlib.pyplot as plt
import random

#import subprocess #future implementation
import os

###########################################################

def add_command(var,string):
  """Return ``var`` followed by ``string`` and one trailing space (for building shell commands)."""
  return ''.join((var, string, ' '))

def image_resize(filepath,width):
  """Resize the image file at ``filepath`` in place to ``width`` px wide, keeping its aspect ratio.

  Lanczos resampling is used when ``width`` is 1024, bicubic otherwise. The
  original file is overwritten with the resized image (format inferred from
  the file extension).
  """
  from PIL import Image
  # Open in a context manager so the source file handle is released before
  # the same path is overwritten below.
  with Image.open(filepath) as img:
    scale = width / float(img.size[0])
    # Clamp to 1 px so very wide images never request a zero-height resize.
    new_height = max(1, int(float(img.size[1]) * scale))
    resample_filter = Image.LANCZOS if width == 1024 else Image.BICUBIC
    resized = img.resize((width, new_height), resample_filter)
  resized.save(filepath)


def interp(t):
    """Smoothstep fade curve ``3*t**2 - 2*t**3`` used by ``perlin`` to weight corner contributions."""
    t_squared = t**2
    t_cubed = t ** 3
    return 3 * t_squared - 2 * t_cubed

def perlin(width, height, scale=10, device=None):
    """Return one octave of gradient (Perlin-style) noise.

    The lattice has ``width`` x ``height`` cells with a random 2-D gradient at
    each of its (width + 1) x (height + 1) corners; every cell is sampled on a
    ``scale`` x ``scale`` grid. The result has shape
    ``(width * scale, height * scale)``.
    """
    # Random gradient (x and y components) per lattice corner; the trailing
    # singleton dims broadcast against the per-cell sample grid.
    gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    # Sample positions in [0, 1) inside a cell: xs varies along dim 0, ys along dim 1.
    xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
    ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
    # Fade weights toward the lower-index corner on each axis.
    wx = 1 - interp(xs)
    wy = 1 - interp(ys)
    dots = 0
    # Blend the four corner contributions (corner gradient dotted with the offset to that corner).
    dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
    dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
    dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
    dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
    # dots is (width, height, scale, scale); interleave the cell and in-cell
    # axes into one (width * scale, height * scale) image.
    return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)

def perlin_ms(octaves, width, height, grayscale, device=device):
    """Sum several Perlin octaves into 1 (grayscale) or 3 (color) channels.

    Every channel starts from a constant 0.5. Octave ``k`` uses a lattice
    ``2**k`` times finer than ``width`` x ``height`` with
    ``2**len(octaves) // 2**k`` samples per cell, and is scaled by
    ``octaves[k]``. The channels are concatenated along dim 0.

    Note: the ``device`` default is the module-level ``device`` captured when
    this function is defined.
    """
    channel_count = 1 if grayscale else 3
    channels = [0.5] * channel_count
    for idx in range(channel_count):
        samples_per_cell = 2 ** len(octaves)
        cells_x, cells_y = width, height
        for amplitude in octaves:
            noise = perlin(cells_x, cells_y, samples_per_cell, device)
            channels[idx] += noise * amplitude
            samples_per_cell //= 2
            cells_x *= 2
            cells_y *= 2
    return torch.cat(channels)

def create_perlin_noise(octaves=(1, 1, 1, 1), width=2, height=2, grayscale=True):
    """Render multi-octave Perlin noise as an autocontrasted PIL RGB image.

    Args:
        octaves: per-octave amplitudes passed to ``perlin_ms``. The default is
            an immutable tuple so no state can leak between calls through a
            shared default list.
        width, height: coarsest lattice size in cells.
        grayscale: True for a single noise channel (converted to RGB),
            False for three independent channels.

    The image is resized to the global ``side_x`` x ``side_y`` (width x height).
    """
    out = perlin_ms(octaves, width, height, grayscale)
    # TF.resize expects size as (height, width).
    if grayscale:
        out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
        out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
    else:
        out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])
        out = TF.resize(size=(side_y, side_x), img=out)
        out = TF.to_pil_image(out.clamp(0, 1).squeeze())

    out = ImageOps.autocontrast(out)
    return out

#################################################################################################################
def fetch(url_or_path):
    """Return a readable binary file object for an http(s) URL or a local path.

    URLs are downloaded with ``requests`` (``raise_for_status`` raises on an
    HTTP error status) and wrapped in an in-memory buffer positioned at the
    start. Anything else is opened as a local file in ``'rb'`` mode; the
    caller is responsible for closing it.
    """
    location = str(url_or_path)
    if not location.startswith(('http://', 'https://')):
        return open(url_or_path, 'rb')
    response = requests.get(url_or_path)
    response.raise_for_status()
    return io.BytesIO(response.content)


def parse_prompt(prompt):
    """Split ``"text:weight"`` into ``(text, float(weight))``.

    The weight is whatever follows the last ``:``. When there is no colon, or
    the part after the last colon is not a number (the ``//host`` of a bare
    URL, a ``:port/path`` segment, or ordinary prose such as
    ``"Title: subtitle"``), the whole prompt is returned with weight ``1.0``
    instead of raising ``ValueError``.

    >>> parse_prompt("a cat:2")
    ('a cat', 2.0)
    >>> parse_prompt("https://example.com/img.png")
    ('https://example.com/img.png', 1.0)
    """
    text, sep, weight = prompt.rpartition(':')
    if sep:
        try:
            return text, float(weight)
        except ValueError:
            # Deliberate fallback: the colon belongs to the prompt itself.
            pass
    return prompt, 1.0

def sinc(x):
    """Normalized sinc ``sin(pi*x) / (pi*x)`` applied elementwise, with ``sinc(0) = 1``."""
    pi_x = math.pi * x
    ratio = torch.sin(pi_x) / pi_x
    one = x.new_ones([])
    return torch.where(x != 0, ratio, one)

def lanczos(x, a):
    """Lanczos window ``sinc(x) * sinc(x / a)`` on ``-a < x < a`` (zero elsewhere), normalized to sum to 1."""
    inside = torch.logical_and(x > -a, x < a)
    window = torch.where(inside, sinc(x) * sinc(x / a), x.new_zeros([]))
    return window / window.sum()

def ramp(ratio, width):
    """Symmetric 1-D grid of offsets ``ratio`` apart, used as Lanczos kernel taps by ``resample``.

    ``ceil(width / ratio + 1)`` points 0, ratio, 2*ratio, ... (built by repeated
    addition) are mirrored around 0, and the two outermost points are dropped.
    """
    count = math.ceil(width / ratio + 1)
    positions = []
    step_value = 0
    for _ in range(count):
        positions.append(step_value)
        step_value += ratio
    half = torch.tensor(positions, dtype=torch.get_default_dtype())
    mirrored = -half[1:].flip([0])
    return torch.cat([mirrored, half])[1:-1]

def resample(input, size, align_corners=True):
    """Resize a 4-D ``(n, c, h, w)`` tensor to ``size`` = ``(height, width)``.

    Any axis that is being shrunk is first low-pass filtered with a Lanczos
    kernel (a=2, taps from ``ramp``) to reduce aliasing. The final resize is
    bicubic ``F.interpolate``.
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch so one single-channel kernel filters every channel.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        # Vertical anti-alias filter; tap spacing follows the shrink ratio dh / h.
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        # Reflect-pad top/bottom so the unpadded convolution returns height h again.
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        # Horizontal anti-alias filter, same construction along the width axis.
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)

class MakeCutouts(nn.Module):
    """Generate ``cutn`` randomly sized, optionally augmented square crops for CLIP.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops produced per forward call.
        skip_augs: when True, the random torchvision augmentations are skipped.
    """

    def __init__(self, cut_size, cutn, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.skip_augs = skip_augs
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.RandomGrayscale(p=0.15),
            # Disabled alternatives: small additive noise between steps
            # (T.Lambda(lambda x: x + torch.randn_like(x) * 0.01)) and
            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1).
        ])

    def forward(self, input):
        """Return ``self.cutn`` crops of ``input`` concatenated along dim 0.

        ``input`` is a 4-D ``(N, C, H, W)`` batch (``resample`` unpacks it that
        way). Each crop is resized to ``cut_size`` x ``cut_size``, so the
        result has ``self.cutn * N`` rows.
        """
        # Zero-pad every side by H // 4 so crops can include some border.
        input = T.Pad(input.shape[2]//4, fill=0)(input)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)

        cutouts = []
        # Use the instance's cutn (previously read from a global of the same name).
        for ch in range(self.cutn):
            # The last crops (index > cutn - cutn//4) use the whole padded image.
            if ch > self.cutn - self.cutn//4:
                cutout = input.clone()
            else:
                # Crop side ~ N(0.8, 0.3) * max_size, clipped to [cut_size, max_size].
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]

            if not self.skip_augs:
                cutout = self.augs(cutout)
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            del cutout

        cutouts = torch.cat(cutouts, dim=0)
        return cutouts


def spherical_dist_loss(x, y):
    """Spherical distance loss between ``x`` and ``y`` along the last dim.

    Both inputs are L2-normalized. With chord length ``d`` between the unit
    vectors the result is ``2 * asin(d / 2) ** 2``, i.e. half the squared
    angle between them.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    chord = (x_unit - y_unit).norm(dim=-1)
    half_angle = torch.arcsin(chord / 2)
    return 2 * half_angle ** 2


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    One pixel is replicate-padded on the right and bottom. The result is the
    mean over dims 1-3 of the squared horizontal plus vertical neighbour
    differences, giving one value per sample.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - anchor
    dy = padded[..., 1:, :-1] - anchor
    return (dx ** 2 + dy ** 2).mean([1, 2, 3])


def range_loss(input):
    """Per-sample mean (over dims 1-3) of the squared amount by which ``input`` lies outside [-1, 1]."""
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])

###
# NEWEST PERLIN NOISE EDITS
def unitwise_norm(x, norm_type=2.0):
    """Norm of ``x`` per output unit.

    0-D/1-D tensors get a single norm. Higher-rank tensors are reduced over
    every dim except the first (kept as size 1), which suits weights whose
    first dim is the output dim (nn.ConvNd, nn.Linear). Other layouts (possibly
    multi-head attention) may need special handling.
    """
    if x.ndim > 1:
        reduce_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduce_dims, keepdim=True)
    return x.norm(norm_type)
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Adaptive gradient clipping, applied in place to each tensor's ``.grad``.

    For every parameter that has a gradient, the unit-wise gradient norm is
    capped at ``clip_factor * max(unitwise_norm(param), eps)``. Gradients
    already under the cap are left unchanged; parameters whose ``.grad`` is
    None are skipped. Accepts a single tensor or an iterable of tensors.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    for param in params:
        if param.grad is None:
            continue
        weights = param.detach()
        grads = param.grad.detach()
        limit = unitwise_norm(weights, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        g_norm = unitwise_norm(grads, norm_type=norm_type)
        scaled = grads * (limit / g_norm.clamp(min=1e-6))
        param.grad.detach().copy_(torch.where(g_norm < limit, grads, scaled))

def regen_perlin(): #NEWEST PERLIN UPDATE
    """Build a fresh Perlin-noise init tensor according to the global ``perlin_mode``.

    A coarse image (12 octaves on a 1x1 lattice) and a finer one (8 octaves on
    a 4x4 lattice) are generated in that order, averaged, and mapped from
    [0, 1] to [-1, 1] as a ``(1, 3, H, W)`` tensor on ``device``.
    'color' makes both images colored, 'gray' makes both grayscale, and any
    other mode pairs a colored coarse image with a grayscale fine one.
    """
    if perlin_mode == 'color':
        coarse_gray, fine_gray = False, False
    elif perlin_mode == 'gray':
        coarse_gray, fine_gray = True, True
    else:
        coarse_gray, fine_gray = False, True

    coarse = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, coarse_gray)
    fine = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, fine_gray)

    blended = TF.to_tensor(coarse).add(TF.to_tensor(fine)).div(2)
    return blended.to(device).unsqueeze(0).mul(2).sub(1)

##########################################################################################################


def do_run():
    """Generate ``n_batches`` images with CLIP-guided diffusion and save them.

    Configuration comes from notebook globals: ``text_prompts``,
    ``image_prompts``, ``init_image``, the ``*_scale`` loss weights,
    ``cutn``/``cutn_batches``, ``seed``, ``model``, ``diffusion``,
    ``clip_model``, ``lpips_model`` and others.

    Every ``display_rate`` sampling steps (and on the final step), each image
    in the batch is written to ``/content/image_storage/`` and shown inline.
    When a sample finishes and ``google_drive`` is enabled, it is passed to
    ``Image_upscale``, or saved to ``output_folder_images`` if
    ``_skip_upscaling`` is set.

    Side effects: updates the globals ``firstRun`` and ``_scale_multiplier``
    when ``_batch_genetics`` or ``_init_genetics`` is enabled.

    Raises:
        RuntimeError: if the prompt weights sum to (almost) zero.
    """
    global firstRun,_scale_multiplier
    loss_values = []

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    make_cutouts = MakeCutouts(clip_size, cutn, skip_augs=skip_augs)
    target_embeds, weights = [], []

    for prompt in text_prompts:
        txt, weight = parse_prompt(prompt)
        # Encode only the text part. The ":weight" suffix is applied through
        # ``weights`` and must not reach CLIP as part of the prompt text.
        txt = clip_model.encode_text(clip.tokenize(txt).to(device)).float()
        #----
        print('Perception is quite a ride... \n')
        time.sleep(0.2)
        print('Considering this is just the demo. \n')
        time.sleep(0.5)
        if not _debug_mode:
            display.clear_output(wait=True)
        #----
        if fuzzy_prompt:
            # 25 noisy copies of the embedding, each carrying the full prompt weight.
            for i in range(25):
                target_embeds.append(txt + torch.randn(txt.shape).cuda() * rand_mag)
                weights.append(weight)
        else:
            target_embeds.append(txt)
            weights.append(weight)

    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        img = Image.open(fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
        embed = clip_model.encode_image(normalize(batch)).float()
        # An image prompt yields one embedding row per cutout, so its weight
        # is split evenly across the ``cutn`` rows.
        if fuzzy_prompt:
            for i in range(25):
                target_embeds.append(embed + torch.randn(embed.shape).cuda() * rand_mag)
                weights.extend([weight / cutn] * cutn)
        else:
            target_embeds.append(embed)
            weights.extend([weight / cutn] * cutn)

    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()

    init = None
    if init_image is not None:
        init = Image.open(fetch(init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
    # Perlin-noise inits are regenerated per batch in the sampling loop below.

    cur_t = None

    def cond_fn(x, t, y=None):
        """Return the guidance gradient for the noisy sample ``x`` at the current step.

        Combines the CLIP spherical-distance loss over cutouts of the blended
        prediction with the TV, range and saturation losses and, when an init
        image is present, the LPIPS distance to it.
        """
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            # Blend the predicted clean image with the current noisy sample.
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            x_in_grad = torch.zeros_like(x_in)
            for i in range(cutn_batches):
                clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
                image_embeds = clip_model.encode_image(clip_in).float()
                dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
                dists = dists.view([cutn, n, -1])
                losses = dists.mul(weights).sum(2).mean(0)
                loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch
                x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()
            loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            x_in_grad += torch.autograd.grad(loss, x_in)[0]
            grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
        if clamp_grad:
            # NOTE(review): torch.autograd.grad returns gradients without
            # populating x.grad, so this call presumably skips x and has no
            # effect -- confirm before relying on it.
            adaptive_clip_grad([x]) #ADDED WITH PERLIN UPDATE
            magnitude = grad.square().mean().sqrt()
            return grad * magnitude.clamp(max=0.05) / magnitude
        return grad

    use_ddim = model_config['timestep_respacing'].startswith('ddim')
    if use_ddim:
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive

    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1

        if perlin_init: #ADDED WITH PERLIN UPDATE
            init = regen_perlin()

        sample_kwargs = {
            'clip_denoised': clip_denoised,
            'model_kwargs': {},
            'cond_fn': cond_fn,
            'progress': True,
            'skip_timesteps': skip_timesteps,
            'init_image': init,
            'randomize_class': randomize_class,
        }
        if use_ddim:
            sample_kwargs['eta'] = eta  # only the DDIM sampler takes eta
        samples = sample_fn(model, (batch_size, 3, side_y, side_x), **sample_kwargs)

        for j, sample in enumerate(samples):
            if not _debug_mode:
                display.clear_output(wait=True)
            cur_t -= 1
            if j % display_rate == 0 or cur_t == -1:  # display step, or the final step of this batch
                for k, image in enumerate(sample['pred_xstart']):
                    tqdm.write(f'Batch {i}, step {j}, output {k}:')
                    if firstRun: max_mutation_amount = 0
                    tqdm.write(f'Max mutation {max_mutation_amount}, Guidance scale {clip_guidance_scale}, TV scale {tv_scale}, Range scale {range_scale}')
                    tqdm.write(msg_runtime)
                    current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
                    filename = f'batch{i:04}_iteration{j:04}_output{k:04}_{current_time}.png'  #reduced padding
                    image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
                    image.save('/content/image_storage/' + _project_name + filename)
                    display.display(display.Image('/content/image_storage/' + _project_name + filename))

                    if _batch_genetics or _init_genetics:
                        if firstRun:
                            print(firstRun)  #debug
                            max_mutation_amount = (_max_genetic_variance) / (n_batches)
                            firstRun = False  # don't recalculate until the next run
                        actual_mutation_amount = random.uniform(0,max_mutation_amount)
                        _scale_multiplier = _scale_multiplier + actual_mutation_amount
                        # Refresh clip_guidance_scale / tv_scale / range_scale,
                        # which cond_fn reads as globals on its next call.
                        calculate_scale_multiplier()

                    if google_drive and cur_t == -1:  # image finished and Drive is enabled
                        gc.collect()
                        torch.cuda.empty_cache()
                        if not _skip_upscaling:
                            Image_upscale(filename,image)
                        else:
                            print("\n Saving to ",output_folder_images,filename)
                            image.save(os.path.join(output_folder_images,filename))

        plt.plot(np.array(loss_values), 'r')

##############################################################################################################
#Settings and generation
# Sampling schedule: a 'ddim'-prefixed value selects the DDIM sampler in do_run,
# anything else the default progressive sampler. Values offered by the full
# notebook: '25'..'1000' and 'ddim25'..'ddim1000'.
timestep_respacing = 'ddim25'
diffusion_steps = 1000

# Both supported checkpoints share one architecture; only the image size differs.
model_config = model_and_diffusion_defaults()
if diffusion_model in ('512x512_diffusion_uncond_finetune_008100', '256x256_diffusion_uncond'):
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': diffusion_steps,
        'rescale_timesteps': True,
        'timestep_respacing': timestep_respacing,
        'image_size': 512 if diffusion_model == '512x512_diffusion_uncond_finetune_008100' else 256,
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': True,
        'use_scale_shift_norm': True,
    })
side_x = side_y = model_config['image_size']

# Build the diffusion model and load the downloaded checkpoint.
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load(f'{model_path}{diffusion_model}.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device)
# Re-enable gradients only on parameters whose names contain 'qkv', 'norm' or 'proj'.
for name, param in model.named_parameters():
    if any(tag in name for tag in ('qkv', 'norm', 'proj')):
        param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()

################################################

# CLIP scorer, its input resolution, the per-channel normalization applied
# before CLIP image encoding, and the LPIPS model used for the init-image loss.
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
lpips_model = lpips.LPIPS(net='vgg').to(device)

In [None]:
########################################################################
#INITIAL VALUES
#@title  { form-width: "300px" }
#@markdown #Step 2
# firstRun is reset every time this cell runs; do_run() clears it once it has
# computed the genetic mutation step.
firstRun = True
# Notices collected below and printed under every displayed frame.
msg_runtime = ''
# NOTE(review): _keep_first_upscale and _run_upscaler are not read in the
# visible code -- possibly used by Image_upscale; confirm.
_keep_first_upscale = False
_run_upscaler = True
batch_size =  1
clamp_grad = True # True - Experimental: Using adaptive clip grad in the cond_fn
skip_augs = False # False - Controls whether to skip torchvision augmentations
randomize_class = True # True - Controls whether the imagenet class is randomly changed each iteration
#############################################################################################
_text_prompt =  "'College Campus' art piece inspired by Edward Hopper" #@param {type:"string"} 

# Genetic scale mutation: when enabled, each displayed frame increases
# _scale_multiplier by a random amount up to _max_genetic_variance / n_batches.
_batch_genetics = False #future implementation
_init_genetics = False  #not currently implemented
_max_genetic_variance =  0.1
# Weight of the saturation loss in cond_fn (0 disables it).
_saturation_scale =  0

#_enhance_upscale = True #@param{type:"boolean"}

# Number of crops (cutouts) CLIP scores per gradient batch. Increase for higher quality.
cutn = 16
# quality_preset becomes cutn_batches: the number of cutout batches whose CLIP
# gradients are accumulated (and averaged) each step. Accumulating over
# batches of cuts can help with OOM errors / low VRAM.
quality_preset = "4" #@param [2,4,8,16,32]
cutn_batches = int(quality_preset) #demo-specific

# NOTE(review): _esrgan_tilesize is not read in the visible code -- possibly used by Image_upscale.
_esrgan_tilesize = "128"
#_upscale_performance_mode = False
##@markdown `Performance Settings`

##@markdown ---
# Passed straight through to the diffusion sampler (clip_denoised=...).
clip_denoised = False
# When True, each prompt embedding is replaced by 25 copies with Gaussian noise
# of magnitude rand_mag added (multiple noisy prompts in the prompt losses).
fuzzy_prompt = False
# DDIM hyperparameter; only passed to the sampler for 'ddim*' schedules.
eta =  0.5
# Base loss weights; calculate_scale_multiplier() multiplies each by
# _scale_multiplier to produce clip_guidance_scale / tv_scale / range_scale.
_clip_guidance_scale =  5000
_tv_scale = 150
_range_scale =  100
_scale_multiplier = 1
##@markdown `Visual Settings`
# None leaves the RNGs unseeded; an int seeds numpy, random and torch and
# enables deterministic cuDNN.
seed = None #@param {type:"raw"}
# seed = random.randint(0, 2**32) # Choose a random seed and print it at end of run for reproduction

##@markdown ---
# Perlin init style: 'color', 'gray', or anything else for a mixed image.
_noise_mode = 'gray'
# A value > 0 enables the Perlin init (unless an init image is set); it is
# also used as rand_mag for fuzzy prompts.
_noise_amount = 0.1
# '' for no init image, otherwise a URL or local path.
_init_image = ''
# '' for no image prompt, otherwise one image URL/path (optionally ":weight").
_image_prompts = ""
# LPIPS loss weight toward the init image (forced to 0 without one).
_init_scale = 1000
# Controls the starting point along the diffusion timesteps (this many
# initial steps are skipped).
skip_timesteps = 5
# Number of sampling batches (of batch_size images each) per run.
n_batches =  50#@param{type:"raw"}

##@markdown `Generation Settings`

##@markdown ---
# Sub-folder of the Drive output folder; also prefixes the temporary frame
# filenames in /content/image_storage/.
_project_name = 'images' #@param{type:"string"}
# NOTE(review): a `global` statement at module level has no effect.
global _debug_mode
# When False, the output area is cleared on every sampling step.
_debug_mode = False
_upscale_model='RealESRGAN_x4plus'
_target_resolution = "1024"
_skip_upscaling = False
# Save/show an intermediate frame every this many sampling steps.
display_rate =  2
##@markdown Original Defaults: `clip_guidance_scale 5000`,`tv_scale 150`,`range scale 150`
##@markdown Recommended defaults for init_images (ddim50): `clip_guidance_scale 2000`,`tv_scale 150`,`range scale 50`, `init_scale 1000`, `skip_timesteps 16 (7-9 for ddim25)`
##@markdown There is a possibility that `tv_scale` can be set between `0` to `10000`
##@markdown `skip_timesteps` does a lot for the similarity in `init_settings`
##@markdown Special thanks to many people on the VQLIPSE Discord
##---

#--------------------------------------------------------------------------------------------------------

# Prompts scored by CLIP in do_run(); each may end in an optional ":weight".
# Other examples: "an abstract painting of 'ravioli on a plate'",
# 'cyberpunk wizard on top of a skyscraper, trending on artstation, photorealistic depiction of a cyberpunk wizard'
text_prompts = [_text_prompt]

# Native resolution of the selected diffusion checkpoint.
model_size = 512 if diffusion_model == "512x512_diffusion_uncond_finetune_008100" else 256
# Upscaling only makes sense when the target is larger than the native size.
if int(_target_resolution) == model_size:
  print("\n Due to your _target_resolution, _skip_upscaling will be enabled. \n")
  _skip_upscaling = True
if int(model_size) > int(_target_resolution):
  # The target is *below* the native size: raise it to match and skip upscaling.
  print("\n NOTICE: Your _target_resolution is lower than your model size! Setting to match and disabling upscaling... \n")
  _target_resolution = int(model_size)
  _skip_upscaling = True


# An empty _init_image means "no init image"; otherwise it is a URL or local path.
if _init_image == '':
  init_image = None
else:
  init_image = _init_image
  msg_runtime = msg_runtime + 'Init image: ' + init_image + ' \n'

# Any positive noise amount enables the Perlin-noise init...
add_random_noise = _noise_amount > 0
perlin_init = add_random_noise
# ...except with an init image: the two cannot be combined.
if init_image is not None:
  perlin_init = False
  msg_runtime = msg_runtime + 'NOTICE: You may want to disable _noise_amount when using _init_images \n'


# Magnitude of the Gaussian noise added to embeddings when fuzzy_prompt is on.
rand_mag = _noise_amount

# The LPIPS init loss only applies when an init image is given.
if _init_image == '':
  print("No init image detected. Setting init_scale to 0...")
  init_scale = 0
else:
  init_scale = _init_scale
  if init_scale == 0:
    msg_runtime = msg_runtime + 'Notice: init_image is set but there is no init_scale! Try 1000 as a default. \n'
if not _image_prompts == "":
  msg_runtime = msg_runtime + 'Image prompt: ' + _image_prompts + ' \n'
  image_prompts = [
      _image_prompts,
  ]
else: image_prompts = []
perlin_mode = _noise_mode  # 'mixed' ('gray', 'color')

# Weight of the saturation loss in cond_fn (0 disables it).
sat_scale = _saturation_scale


################################################
#---------------------------------------------------------------------------------------
# Output image directory handling
import os
output_folder_images = os.path.join(_drive_location,_project_name)
print ("\n output_folder_images is ",output_folder_images)
dir_make(output_folder_images)
temp_image_storage = '/content/image_storage/'  # for non-upscaled images
dir_make(temp_image_storage)
#---------------------------------------------------------------------------------------
def calculate_scale_multiplier():
  """Multiply the three scale settings by `_scale_multiplier`.

  Stores `_scale_multiplier * _clip_guidance_scale`, `* _tv_scale` and
  `* _range_scale` in the module-level `clip_guidance_scale`, `tv_scale`
  and `range_scale`, and returns them as a tuple in that order.
  """
  global clip_guidance_scale, tv_scale, range_scale
  factor = _scale_multiplier
  clip_guidance_scale, tv_scale, range_scale = (
      factor * _clip_guidance_scale,
      factor * _tv_scale,
      factor * _range_scale,
  )
  return clip_guidance_scale, tv_scale, range_scale
calculate_scale_multiplier()

#---------------------------------
# UPSCALING
msg_runtime = msg_runtime + 'Upscaler model: ' + _upscale_model + '; Target resolution: ' + str(_target_resolution) + ' \n'
# Native scale factor of the selected Real-ESRGAN model, kept as a string.
# RealESRGAN_x2plus is the only 2x choice; the other models (ESRGAN 4x,
# RealESRNet_x4plus, RealESRGAN_x4plus, RealESRGAN_x4plus_anime_6B) are 4x.
# Both branches assign the name, so upscale_value is always defined.
upscale_value = "2" if _upscale_model == 'RealESRGAN_x2plus' else "4"
 




def Image_upscale(filename,image):
  """Upscale one generated image with Real-ESRGAN, possibly in several passes.

  Shells out to `inference_realesrgan.py` (after `%cd /content/Real-ESRGAN/`),
  reading the un-upscaled copy at `temp_image_storage + _project_name +
  filename` and writing results into `output_folder_images`. `Load_config`
  decides the per-pass `--outscale` and the number of passes from
  `model_size` and `_target_resolution`. Every pass after the first reads
  the previous pass's output.

  If a pass's file check fails, the run restarts from pass 0 once in
  "performance mode" (half precision, half the per-pass scale, twice the
  passes). On a second failure, `Load_config` saves `image` un-upscaled to
  `output_folder_images` and calls `sys.exit()`.

  Args:
    filename: file name of the generated image; its stem is prefixed with
      `_project_name` to build the input/output paths.
    image: object with a `.save(path)` method, used only for the un-upscaled
      fallback save. Presumably a PIL image; confirm against the caller.

  Returns:
    None. Side effects: changes the working directory to
    /content/Real-ESRGAN/ (not restored), writes files into
    `output_folder_images`, and may `rm` an intermediate file.

  NOTE(review): relies on names defined elsewhere in the notebook (`math`,
  `os`, `file_check`, `_project_name`, `output_folder_images`,
  `temp_image_storage`, `_upscale_model`, `_esrgan_tilesize`,
  `_target_resolution`, `model_size`).
  """

  def round_up_to_even(f):
    # Smallest even integer >= f (e.g. 3 -> 4, 4.2 -> 6).
    return math.ceil(f / 2.) * 2

  def File_addSuffix(filename,needed_passes,completed_passes):
    """Return the output path expected after `completed_passes` passes.

    Builds `output_folder_images/<_project_name><stem of filename>`, adds one
    `_out` suffix per completed pass, and appends `.jpg` when
    `completed_passes > 1` or `needed_passes == 1`, otherwise `.png`.
    This matches the `--ext` the main loop picks for each pass. The `_out`
    suffix is assumed to mirror what the Real-ESRGAN script appends to its
    outputs; confirm.

    Only called with `completed_passes >= 1`. Also prints the resulting path.
    """
    #if not completed_passes == needed_passes:
    stripped_filename = os.path.splitext(filename)[0] # drop the extension from the generated image's name
    #stripped_filename = os.path.join(_project_name,stripped_filename)
    stripped_filename = _project_name + stripped_filename

    #stripped_filename = os.path.join(output_folder_images,stripped_filename)
    # One '_out' per completed pass.
    z = 1
    while z < completed_passes + 1:
      stripped_filename = stripped_filename + '_out'
      z = z + 1
    if completed_passes > 1 or needed_passes == 1:
      target_extension = 'jpg'  # extension without the dot; the dot is added below
    else: target_extension = 'png'
    stripped_filename = stripped_filename + '.' + target_extension  # file name of the output after completed_passes passes
    # if completed_passes == needed_passes:
    #   return
    # else:
    current_upscale_target = os.path.join(output_folder_images,stripped_filename)
    #prior_upscale_target = current_upscale_target #newest addition
    #print('Prior upscale target is ',prior_upscale_target)
    print('Current upscale target is ',current_upscale_target)
    return current_upscale_target

  # The Real-ESRGAN inference script is run from its checkout directory.
  %cd /content/Real-ESRGAN/
  # Input of pass 0.
  # NOTE(review): the save below is commented out, so this assumes the
  # caller has already written the un-upscaled image to this path; confirm.
  temp_image_filepath = temp_image_storage + _project_name + filename
  #image.save(temp_image_filepath) #here we save the un-upscaled so we can load it IN the upscaler

  retry_attempt = 0


  def Load_config(retry_attempt,_upscale_performance_mode):
    """Work out the upscaling plan for the current attempt.

    Returns:
      (_outscale, needed_passes, _skip_upscaling, half_precision, model_size).
      `_outscale` is the numeric per-pass scale factor (the caller converts
      it to str); `needed_passes` is rounded to an int.

    When `retry_attempt > 1` it gives up instead: it saves `image`
    un-upscaled to `output_folder_images/filename`, sets the global
    `_skip_upscaling`, and calls `sys.exit()`.
    """
    needed_passes = 1
    global _skip_upscaling
    if retry_attempt > 1:
      print("\n UPSCALING FAILED! Try raising your cutn_batch settings.")
      print("Attempting to save un-upscaled file... \n")
      image.save((os.path.join(output_folder_images,filename)))
      _skip_upscaling = True
      import sys
      sys.exit()

    global model_size

    global _target_resolution
    outscale_resolution_int = int(_target_resolution)
    # The 256 model is capped at a 2048 target.
    if int(model_size) == 256 and int(_target_resolution) > 2048:
      print("\n NOTE: 256 model not currently compatible with 4096 upscaling. \n You can use the 512 model by picking it at the top of the Notebook. \n")
      outscale_resolution_int = 2048
    # Total scale factor from model resolution to target resolution.
    _outscale = outscale_resolution_int / model_size
      #ex: 2048 / 512 = _outscale of 4
    # Odd or fractional factors are rounded up to the next even integer.
    if not (_outscale % 2) == 0: #number is not even
      _outscale = round_up_to_even(_outscale)
      #if _outscale > 4: _outscale = 4
    # Factors above 4: needed_passes = factor / 4, and the per-pass factor
    # is halved. Any overshoot is trimmed on the final pass via next_outscale
    # in the main loop.
    if _outscale > 4 :
      needed_passes = _outscale / 4
      _outscale = _outscale / 2
    # Performance mode: half the per-pass factor, twice the passes, half precision.
    if _upscale_performance_mode:
      _outscale = _outscale / 2
      needed_passes = needed_passes * 2
      half_precision = True
    else:
      # _outscale = 4
      # needed_passes = 2
      _skip_upscaling = False
      half_precision = False
    # NOTE(review): in performance mode _skip_upscaling is not assigned in
    # this function, so the global's existing value is returned (NameError
    # if it was never set anywhere); confirm.
    needed_passes = round(needed_passes)
    print("\n NOTICE: Performing ",needed_passes,"needed_passes...")
    print("Outscale is", _outscale)
    print("Needed passes is",needed_passes,"\n")
    return _outscale,needed_passes,_skip_upscaling,half_precision,model_size

  run_config = True  # recompute the plan on the next loop iteration
  #double_pass = False
  #second_pass = False
  _upscale_performance_mode = False  # switched on after the first failed pass

  completed_passes = 0

  retry_attempt = 0
  upscale_complete = False
  # Becomes 2 once the plan would overshoot _target_resolution; the final pass
  # then gets `--outscale 2`. Never reset, not even when the plan is recomputed.
  next_outscale = 0

  while not upscale_complete:
    # (Re)build the plan on the first iteration and after every failure,
    # restarting from pass 0.
    if run_config:
      _outscale,needed_passes,_skip_upscaling,half_precision,model_size = Load_config(retry_attempt,_upscale_performance_mode)
      completed_passes = 0
    # Extra CLI flags for inference_realesrgan.py; face enhancement is always requested.
    end_flag = '--face_enhance '
    #--input $current_upscale_target
    # Text form for the shell command (a no-op once it is already a str).
    _outscale = str(_outscale)
    if half_precision: end_flag = end_flag + '--half '
    # if not _upscale_model == 'RealESRGAN 2x':
    #   if next_outscale == 0:
    #     end_flag = end_flag + '--outscale ' + _outscale + ' '
    #   else: end_flag = end_flag + '--outscale ' + next_outscale + ' '


    # Pass 0 reads the temp copy; each later pass reads the previous pass's output.
    #current_upscale_target = File_addSuffix(filename,needed_passes,completed_passes)
    if completed_passes < 1:
      end_flag = end_flag + '--input ' + temp_image_filepath + ' '
    else:
      ##################################################################################
      current_upscale_target = File_addSuffix(filename,needed_passes,completed_passes)
      #################################################################################
      end_flag = end_flag + '--input ' + current_upscale_target + ' '

    # The first pass of a multi-pass run is written as png; every other pass as jpg.
    if completed_passes > 0 or needed_passes == 1:
      output_extension = 'jpg'
    else: output_extension = 'png'

    # If model_size * per-pass factor * needed_passes**2 exceeds the target,
    # the final pass is limited to 2x.
    if (int(model_size) * int(float(_outscale)) * (int(needed_passes) ** 2)) > int(_target_resolution):
      next_outscale = 2
    # Non-x2 models with no overshoot get the per-pass --outscale on every pass.
    # Otherwise --outscale is passed only on the final pass, using next_outscale.
    # NOTE(review): for RealESRGAN_x2plus with next_outscale still 0, the
    # final pass receives `--outscale 0`; confirm this is intended.
    if not _upscale_model == 'RealESRGAN_x2plus' and next_outscale == 0:
      end_flag = end_flag + '--outscale ' + _outscale + ' '
    else:
      if completed_passes + 1 == needed_passes:
        end_flag = end_flag + '--outscale ' + str(next_outscale) + ' '
        print("next_outscale is",next_outscale)
    if end_flag.endswith(' '): end_flag = end_flag[:-1] #removes extra space

    # Run one Real-ESRGAN pass (the condition is always true inside this loop).
    if not upscale_complete:
      !python inference_realesrgan.py --model_name $_upscale_model --output $output_folder_images --tile $_esrgan_tilesize --ext $output_extension $end_flag


    # if double_pass: completed_passes = completed_passes + 1
    # else: completed_passes = completed_passes + 2
    # NOTE(review): this checks the pass's INPUT (the temp copy on pass 0,
    # the previous output afterwards), not the file this pass just wrote.
    # Assuming file_check returns a truthy value when the path exists, a
    # failed inference run would not be detected here; confirm.
    if completed_passes > 0: _upscaled_path = file_check(current_upscale_target)
    else: _upscaled_path = file_check(temp_image_filepath)  #attempting to load in performance_mode

    if _upscaled_path:
      print("\n NOTICE: Upscale pass ",completed_passes," COMPLETE! \n")
      #if double_pass: second_pass = True
      run_config = False  # keep the current plan for the next pass
      completed_passes = completed_passes + 1
      # Multi-pass run finished: delete the final pass's input (an
      # intermediate file). With more than two passes, earlier
      # intermediates are left on disk.
      if completed_passes >= needed_passes and needed_passes > 1:
        print("\n Removing previous file: ",current_upscale_target,"\n")
        !rm $current_upscale_target
    else:
      if completed_passes < 1: print("\n ERROR: CAN'T FIND UPSCALED FILE ", temp_image_filepath)
      else: print("\n ERROR: CAN'T FIND UPSCALED FILE ", current_upscale_target)
      if not _upscale_performance_mode:
        print("\n RECOVERY STATUS: ATTEMPTING TO RUN IN upscale_performance mode... \n")
        _upscale_performance_mode = True
      # Restart from pass 0 with a fresh plan; Load_config exits once
      # retry_attempt exceeds 1.
      completed_passes = 0
      retry_attempt = retry_attempt + 1
      run_config = True
    #if completed_passes >= needed_passes - 1:
    if completed_passes >= needed_passes:
      upscale_complete = True
    if upscale_complete:
      return

# Debug mode: force a fixed, larger skip_timesteps value.
if _debug_mode:
  print("\n Setting skip to large number... \n")
  skip_timesteps = 20


#-------------------------------------------------------
# Run generation, freeing Python and CUDA cached memory before and after.
gc.collect()
torch.cuda.empty_cache()
try:
    do_run()
except KeyboardInterrupt:
    # Interrupting the cell is the normal way to stop a run, so it is
    # swallowed. The finally clause below already does the memory cleanup.
    pass
finally:
    print('\n seed', seed)
    print('\n Output(s) saved to ',output_folder_images)
    gc.collect()
    torch.cuda.empty_cache()

###############################################

#Example prompts

- 'College Campus' art piece inspired by Edward Hopper
- Liminal space, liminal hotel hallway rendered in unreal engine, top post on r/liminalspaces

Prompting diffusion models might seem trickier to master than prompting VQGAN models, but it still allows for some level of control. Sometimes less is better.

It doesn't always need the `trending on artstation` type of lingo; sometimes it will benefit more from something like `photorealistic 4k nature replication`.

Diffusion is somewhat unique in that it often doesn't duplicate the desired subject. For example, `cat photo` will usually give you only one cat.

------------------------------------------------

### A Giant List of Terms to Try

Credit for this list goes to @Atman on Discord and/or whoever else contributed to the Pastebin.

*'8k resolution'
,'pencil sketch'
,'8K 3D'
,'creative commons attribution'
,'deviantart'
,'CryEngine'
,'Unreal Engine'
,'concept art'
,'photoillustration'
,'pixiv'
,'Flickr'
,'ArtStation HD'
,'Behance HD'
,'HDR'
,'anime'
,'filmic'
,'Stock photo'
,'Ambient occlusion'
,'Global illumination'
,'Chalk art'
,'Low poly'
,'Booru'
,'Polycount'
,'Acrylic art'
,'Hyperrealism'
,'Zbrush Central'
,'Rendered in Cinema4D'
,'Rendered in Maya'
,'Photo taken with Nikon D750'
,'Tilt shift'
,'Mixed media'
,'Depth of field'
,'DSLR'
,'Detailed painting'
,'Volumetric lighting'
,'Storybook illustration'
,'Unsplash contest winner'
,'#vfxfriday'
,'Ultrafine detail'
,'20 megapixels'
,'Photo taken with Fujifilm Superia'
,'Photo taken with Ektachrome'
,'matte painting'
,'reimagined by industrial light and magic'
,'Watercolor'
,'CGSociety'
,'childs drawing'
,'marble sculpture'
,'airbrush art'
,'renaissance painting'
,'Velvia'
,'Provia'
,'photo taken with Provia'
,'prerendered graphics'
,'criterion collection'
,'dye-transfer'
,'stipple'
,'Parallax'
,'Bryce 3D'
,'Terragen'
,'(2013) directed by cinematography by'
,'Bokeh'
,'1990s 1995'
,'1970s 1975'
,'1920s 1925'
,'charcoal drawing'
,'commission for'
,'furaffinity'
,'flat shading'
,'ink drawing'
,'artwork'
,'oil on canvas'
,'macro photography'
,'hall of mirrors'
,'polished'
,'sunrays shine upon it'
,'aftereffects'
,'iridescent'
,'#film'
,'datamosh'
,'(1962) directed by cinematography'
,'holographic'
,'dutch golden age'
,'digitally enhanced'
,'National Geographic photo'
,'Associated Press photo'
,'matte background'
,'Art on Instagram'
,'#myportfolio'
,'digital illustration'
,'stock photo'
,'aftereffects'
,'speedpainting'
,'colorized'
,'detailed'
,'psychedelic'
,'wavy'
,'groovy'
,'movie poster'
,'pop art'
,'made of beads and yarn'
,'made of feathers'
,'made of crystals'
,'made of liquid metal'
,'made of glass'
,'made of cardboard'
,'made of vines'
,'made of cheese'
,'made of flowers'
,'made of insects'
,'made of mist'
,'made of paperclips'
,'made of rubber'
,'made of plastic'
,'made of wire'
,'made of trash'
,'made of wrought iron'
,'made of all of the above'
,'tattoo'
,'woodcut'
,'American propaganda'
,'Soviet propaganda'
,'PS1 graphics'
,'Fine art'
,'HD mod'
,'Photorealistic'
,'Poster art'
,'Constructivism'
,'pre-Raphaelite'
,'Impressionism'
,'Lowbrow'
,'RTX on'
,'chiaroscuro'
,'Egyptian art'
,'Fauvism'
,'shot on 70mm'
,'Art Deco'
,'Picasso'
,'Da Vinci'
,'Academic art'
,'3840x2160'
,'Photocollage'
,'Cubism'
,'Surrealist'
,'THX Sound'
,'ZBrush'
,'Panorama'
,'smooth'
,'DC Comics'
,'Marvel Comics'
,'Ukiyo-e'
,'Flemish Baroque'
,'vray tracing'
,'pixel perfect'
,'quantum wavetracing'
,'Zbrush central contest winner'
,'ISO 200'
,'Bob Ross'
,'32k HUHD'
,'Photocopy'
,'DeviantArt HD'
,'infrared'
,'Angelic photograph'
,'Demonic photograph'
,'Biomorphic'
,'Windows Vista'
,'Skeuomorphic'
,'Physically based rendering'
,'Trance compilation CD'
,'Concert poster'
,'Steampunk'
,'Sketchfab'
,'Goth'
,'Wiccan'
,'trending on artstation'
,'featured on artstation'
,'artstation HQ'
,'artstation contest winner'
,'ultra HD'
,'high quality photo'
,'instax'
,'ilford HP5'
,'infrared'
,'Lomo'
,'Matte drawing'
,'matte photo'
,'glowing neon'
,'Xbox 360 graphics'
,'flickering light'
,'Playstation 5 screenshot'
,'Kodak Gold 200'
,'by Edward Hopper'
,'rough'
,'maximalist'
,'minimalist'
,'Kodak Ektar'
,'Kodak Portra'
,'geometric'
,'cluttered'
,'Rococo'
,'destructive'
,'by James Gurney'
,'by Thomas Kinkade'
,'by Vincent Di Fate'
,'by Jim Burns'
,'androgynous'
,'masculine'
,'genderless'
,'feminine'
,'extremely gendered, masculine and feminine'
,'4k result'
,'#pixelart'
,'voxel art'
,'wimmelbilder'
,'dystopian art'
,'apocalypse art'
,'apocalypse landscape'
,'2D game art'
,'Windows XP'
,'y2k aesthetic'
,'#screenshotsaturday'
,'seapunk'
,'vaporwave'
,'Ilya Kuvshinov'
,'Paul Cezanne'
,'Henry Moore'
,'phallic'
,'creepypasta'
,'retrowave'
,'synthwave'
,'outrun'*