<a href="https://colab.research.google.com/github/sadnow/ESRGAN-UltraFast-CLIP-Guided-Diffusion-Colab/blob/main/Upscaling-UltraQuick_CLIP_Guided_Diffusion_HQ_256x256_and_512x512.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generates images from text prompts with CLIP guided diffusion.

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.

Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.

**Update**: Sep 19th 2021


Further improvements from Dango233 and nshepperd helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.

Katherine's original notebook can be found here:
https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA

---

**Update**: 9/27/21
Further UI simplifications and Real-ESRGAN integration made by sadnow on 9/27/21. Updates can be found at https://github.com/sadnow/ESRGAN-UltraFast-CLIP-Guided-Diffusion-Colab if available.

- Notebook now does 256x diffusion and then upscales to 1024x1024.

- Real-ESRGAN upscaling at the end of every batch

- Added a "project name" setting and changed temp folder location

- Added a super settings slider for simplicity and an optional override toggle



In [None]:
#@markdown ##Check GPU
#@markdown ***This Notebook supports "Run All. Once everything is set to your liking, you can do `Ctrl+F9`!***
import torch
# Check the GPU status
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

enable_error_checking = False#@param {type:"boolean"}
if enable_error_checking:
  !nvidia-smi
else:
  !nvidia-smi
  !nvidia-smi -i 0 -e 0

In [None]:
def dir_make(a):
  """Create the directory `a`, including missing parents, if it does not exist.

  Args:
    a: Path of the directory to create.

  Nothing is created or printed when `a` already exists. The directory is made
  with `os.makedirs` instead of a shell `mkdir`, so paths containing spaces or
  shell metacharacters are handled literally.
  """
  import os
  if not os.path.exists(a):
    print("Creating"+a+"...")
    # exist_ok guards against another process creating the directory in between.
    os.makedirs(a, exist_ok=True)

def dir_check(a):
  """Return True when the path `a` exists, False otherwise.

  The check is `os.path.exists`, so an existing regular file also yields True.
  """
  import os
  return os.path.exists(a)

def file_check(a):
  """Return True when `a` is an existing regular file, False otherwise."""
  import os
  return os.path.isfile(a)

def image_upscale(model_path,scale,input,output):
    """Run the Real-ESRGAN inference script on `input` and write results to `output`.

    Args:
        model_path: Forwarded to the script's `--model_path` option.
        scale: Forwarded to the script's `--netscale` option.
        input: Forwarded to the script's `--input` option.
        output: Forwarded to the script's `--output` option.

    The script is always invoked with `--face_enhance` and `--ext jpg`.
    Side effect: the notebook's working directory is `/content/` afterwards.
    """
    # The inference script is launched from inside the Real-ESRGAN checkout.
    %cd /content/Real-ESRGAN/
    # NOTE(review): the arguments are interpolated into the shell command unquoted,
    # so values containing spaces would presumably be split - confirm before relying on it.
    !python inference_realesrgan.py --model_path $model_path --netscale $scale --face_enhance --input $input --output $output --ext jpg
    %cd /content/

from google.colab import drive

#@title Choose model here:
#diffusion_model = "256x256_diffusion_uncond" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
# The form field above is commented out, so the 256x256 model is pinned here.
diffusion_model = "256x256_diffusion_uncond"

#@markdown If you connect your Google Drive, you can save the final image of each run on your drive.

google_drive = True #@param {type:"boolean"}

#@markdown You can use your mounted Google Drive to load the model checkpoint file if you've already got a copy downloaded there. This will save time (and resources!) when you re-visit this notebook in the future.

#@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:
yes_please = True #@param {type:"boolean"}

# Base folder used for the checkpoint (when `yes_please` is set) and as the
# parent of the per-project output folder built in the settings cell.
_drive_location = '/content/drive/MyDrive/AI/Diffusion/' #@param{type:"string"}
# Later cells build paths by plain string concatenation, so the folder must end with '/'.
contains_slash = (_drive_location.endswith('/'))
if not contains_slash:
  _drive_location = _drive_location + '/'

# Install and import dependencies

In [None]:
#@title Download diffusion model

model_path = '/content/'
if google_drive:
    from google.colab import drive
    drive.mount('/content/drive')
    if yes_please:
        dir_make(_drive_location)
        model_path = _drive_location


if diffusion_model == '256x256_diffusion_uncond':
    !wget --continue 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt' -P {model_path}
elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
    !wget --continue 'https://the-eye.eu/public/AI/models/512x512_diffusion_unconditional_ImageNet/512x512_diffusion_uncond_finetune_008100.pt' -P {model_path}

if google_drive and not yes_please:
    model_path = _drive_location

In [None]:
def install_ESRGAN():
    """Clone Real-ESRGAN into /content, install its dependencies and fetch three pretrained models.

    Takes no arguments and returns nothing; all work is done through shell
    commands. The notebook's working directory is `/content/` when it finishes.
    """
    %cd /content/
    print("Installing libraries for Real-ESRGAN upscaling.")
    !git clone https://github.com/xinntao/Real-ESRGAN.git
    %cd /content/Real-ESRGAN
    # Packages installed explicitly before the repository's own requirements file.
    !pip install basicsr -q
    !pip install facexlib -q
    !pip install gfpgan -q
    !pip install -r requirements.txt -q
    %cd /content/Real-ESRGAN
    !python setup.py develop -q
    # `-nc` skips a download when the file already exists in the target folder.
    !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
    !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models
    !wget -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P experiments/pretrained_models
    print("Finished Installing libraries for Real-ESRGAN upscaling.")
    %cd /content/
    #

#@markdown ##Git and pip install
# Fetch CLIP and the guided-diffusion fork, then install both in editable mode
# from the freshly cloned folders.
!git clone https://github.com/openai/CLIP
!git clone https://github.com/crowsonkb/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./guided-diffusion
# NOTE(review): `datetime` is a standard-library module; this presumably pulls an
# unrelated PyPI package that the notebook does not need - confirm and drop if so.
!pip install lpips datetime
# Set up the Real-ESRGAN upscaler last.
install_ESRGAN()


In [None]:
#@markdown ##imports
import gc
import io
import math
import sys
from IPython import display
import lpips
from PIL import Image, ImageOps
import requests
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import random

# Define necessary functions

In [None]:

# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869

def interp(t):
    """Evaluate the cubic 3*t^2 - 2*t^3 used as the fade curve by `perlin`."""
    squared = t**2
    cubed = t ** 3
    return 3 * squared - 2 * cubed

def perlin(width, height, scale=10, device=None):
    """Build one (width * scale, height * scale) tensor of gradient noise.

    Random gradients are drawn on a (width + 1) x (height + 1) lattice and the
    contributions of the four corners of each cell are blended with `interp`.
    """
    grad_x, grad_y = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
    # Fractional positions inside a cell along each axis (endpoint excluded).
    frac_x = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
    frac_y = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
    fade_x = 1 - interp(frac_x)
    fade_y = 1 - interp(frac_y)
    rise_x = 1 - fade_x
    rise_y = 1 - fade_y
    rest_x = 1 - frac_x
    rest_y = 1 - frac_y

    noise = 0
    # One term per cell corner, accumulated in the same order as before.
    noise = noise + fade_x * fade_y * (grad_x[:-1, :-1] * frac_x + grad_y[:-1, :-1] * frac_y)
    noise = noise + rise_x * fade_y * (-grad_x[1:, :-1] * rest_x + grad_y[1:, :-1] * frac_y)
    noise = noise + fade_x * rise_y * (grad_x[:-1, 1:] * frac_x - grad_y[:-1, 1:] * rest_y)
    noise = noise + rise_x * rise_y * (-grad_x[1:, 1:] * rest_x - grad_y[1:, 1:] * rest_y)

    # (width, height, scale, scale) -> (width, scale, height, scale) -> flat 2-D map.
    interleaved = noise.permute(0, 2, 1, 3).contiguous()
    return interleaved.view(width * scale, height * scale)

def perlin_ms(octaves, width, height, grayscale, device=device):
    """Sum several `perlin` layers into one (grayscale) or three stacked channel maps.

    Each entry of `octaves` is the weight of one layer. From layer to layer the
    lattice size doubles while the per-cell resolution halves. Every channel
    starts from a constant 0.5 and the channels are concatenated along dim 0.
    """
    n_channels = 1 if grayscale else 3
    channels = [0.5] * n_channels
    # channels = [0.0] * n_channels
    for ch in range(n_channels):
        cell = 2 ** len(octaves)
        grid_w, grid_h = width, height
        for amplitude in octaves:
            layer = perlin(grid_w, grid_h, cell, device)
            channels[ch] += layer * amplitude
            cell //= 2
            grid_w *= 2
            grid_h *= 2
    return torch.cat(channels)

def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):
    """Render multi-octave noise as an auto-contrasted PIL image.

    The noise from `perlin_ms` is resized to the notebook-level
    (`side_x`, `side_y`), clamped to [0, 1] and converted to a PIL image; the
    grayscale variant is converted to RGB afterwards.
    """
    noise = perlin_ms(octaves, width, height, grayscale)
    target = (side_x, side_y)
    if grayscale:
        resized = TF.resize(size=target, img=noise.unsqueeze(0))
        picture = TF.to_pil_image(resized.clamp(0, 1)).convert('RGB')
    else:
        # The three channels arrive stacked along dim 0; split them back out.
        stacked = noise.reshape(-1, 3, noise.shape[0]//3, noise.shape[1])
        resized = TF.resize(size=target, img=stacked)
        picture = TF.to_pil_image(resized.clamp(0, 1).squeeze())

    return ImageOps.autocontrast(picture)

#################################################################################################################
def fetch(url_or_path):
    """Return a binary file-like object for an http(s) URL or a local path.

    URLs are downloaded completely into an in-memory buffer positioned at the
    start (an HTTP error status raises); anything else is opened with
    `open(..., 'rb')`. The caller owns the returned object.
    """
    location = str(url_or_path)
    if location.startswith(('http://', 'https://')):
        response = requests.get(url_or_path)
        response.raise_for_status()
        return io.BytesIO(response.content)
    return open(url_or_path, 'rb')


def parse_prompt(prompt):
    """Split 'text:weight' into (text, float weight); the weight defaults to 1.

    For prompts starting with http:// or https:// the colon of the scheme is
    kept as part of the text, so only a trailing ':weight' is split off.
    """
    if prompt.startswith(('http://', 'https://')):
        pieces = prompt.rsplit(':', 2)
        # Re-attach the scheme to the rest of the URL.
        pieces = [pieces[0] + ':' + pieces[1]] + pieces[2:]
    else:
        pieces = prompt.rsplit(':', 1)
    defaults = ['', '1']
    pieces = pieces + defaults[len(pieces):]
    text, weight = pieces[0], pieces[1]
    return text, float(weight)

def sinc(x):
    """Element-wise sin(pi*x) / (pi*x), with the value 1 where x == 0."""
    scaled = math.pi * x
    ratio = torch.sin(scaled) / scaled
    return torch.where(x != 0, ratio, x.new_ones([]))

def lanczos(x, a):
    """Lanczos window of half-width `a` sampled at `x`, normalised to sum to 1."""
    inside = torch.logical_and(-a < x, x < a)
    window = torch.where(inside, sinc(x) * sinc(x/a), x.new_zeros([]))
    total = window.sum()
    return window / total

def ramp(ratio, width):
    """Return symmetric sample positions spaced `ratio` apart around 0.

    ceil(width / ratio + 1) non-negative positions are generated by repeated
    addition, mirrored to the negative side, and the two outermost values are
    dropped.
    """
    count = math.ceil(width / ratio + 1)
    positive = torch.empty([count])
    position = 0
    for slot in range(count):
        positive[slot] = position
        position += ratio
    negative = -positive[1:].flip([0])
    mirrored = torch.cat([negative, positive])
    return mirrored[1:-1]

def resample(input, size, align_corners=True):
    """Resize an (n, c, h, w) batch to `size` = (dh, dw).

    An axis that shrinks is first filtered with a Lanczos kernel (reflect
    padding keeps the spatial size); the final resize is bicubic interpolation.
    """
    batch, channels, src_h, src_w = input.shape
    dst_h, dst_w = size

    # Treat every channel as its own single-channel image for the 1-D filters.
    planes = input.reshape([batch * channels, 1, src_h, src_w])

    if dst_h < src_h:
        taps_h = lanczos(ramp(dst_h / src_h, 2), 2).to(planes.device, planes.dtype)
        margin_h = (taps_h.shape[0] - 1) // 2
        padded = F.pad(planes, (0, 0, margin_h, margin_h), 'reflect')
        planes = F.conv2d(padded, taps_h[None, None, :, None])

    if dst_w < src_w:
        taps_w = lanczos(ramp(dst_w / src_w, 2), 2).to(planes.device, planes.dtype)
        margin_w = (taps_w.shape[0] - 1) // 2
        padded = F.pad(planes, (margin_w, margin_w, 0, 0), 'reflect')
        planes = F.conv2d(padded, taps_w[None, None, None, :])

    restored = planes.reshape([batch, channels, src_h, src_w])
    return F.interpolate(restored, size, mode='bicubic', align_corners=align_corners)

class MakeCutouts(nn.Module):
    """Produce `cutn` augmented square views of an image batch, each `cut_size` wide.

    Args:
        cut_size: Side length every cutout is resampled to.
        cutn: Number of cutouts produced per call to `forward`.
        skip_augs: When True the torchvision augmentation pipeline is bypassed.
    """

    def __init__(self, cut_size, cutn, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.skip_augs = skip_augs
        # Geometric/colour augmentations interleaved with small Gaussian noise.
        self.augs = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomPerspective(distortion_scale=0.4, p=0.7),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            T.RandomGrayscale(p=0.15),
            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        ])

    def forward(self, input):
        """Return all cutouts concatenated along the batch dimension.

        The input is zero-padded by a quarter of its height first. The count
        comes from `self.cutn` (the value given to the constructor) rather than
        from a notebook-level global, so the module honours its own
        configuration.
        """
        input = T.Pad(input.shape[2]//4, fill=0)(input)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)

        cutouts = []
        for ch in range(self.cutn):
            if ch > self.cutn - self.cutn//4:
                # The last cutouts of the set use the whole padded image.
                cutout = input.clone()
            else:
                # Random square crop: size is a normal(0.8, 0.3) fraction of the
                # shorter side, clipped to [cut_size / max_size, 1].
                size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]

            if not self.skip_augs:
                cutout = self.augs(cutout)
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            del cutout

        cutouts = torch.cat(cutouts, dim=0)
        return cutouts


def spherical_dist_loss(x, y):
    """Return 2 * arcsin(|x_hat - y_hat| / 2)^2 computed over the last dimension.

    Both inputs are L2-normalised along the last dimension first.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1).div(2)
    return half_chord.arcsin().pow(2).mul(2)


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al."""
    # Replicate the last column and row so every pixel has a right/lower neighbour.
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    x_diff = padded[..., :-1, 1:] - anchor
    y_diff = padded[..., 1:, :-1] - anchor
    squared = x_diff**2 + y_diff**2
    return squared.mean([1, 2, 3])


def range_loss(input):
    """Mean squared amount by which values fall outside [-1, 1], per batch item."""
    overshoot = input - input.clamp(-1, 1)
    return overshoot.pow(2).mean([1, 2, 3])

##########################################################################################################

#@markdown ##def do_run()
def do_run():
    loss_values = []
 
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
 
    make_cutouts = MakeCutouts(clip_size, cutn, skip_augs=skip_augs)
    target_embeds, weights = [], []
 
    for prompt in text_prompts:
        txt, weight = parse_prompt(prompt)
        txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
 
        if fuzzy_prompt:
            for i in range(25):
                target_embeds.append((txt + torch.randn(txt.shape).cuda() * rand_mag).clamp(0,1))
                weights.append(weight)
        else:
            target_embeds.append(txt)
            weights.append(weight)
 
    for prompt in image_prompts:
        path, weight = parse_prompt(prompt)
        img = Image.open(fetch(path)).convert('RGB')
        img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
        batch = make_cutouts(TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
        embed = clip_model.encode_image(normalize(batch)).float()
        if fuzzy_prompt:
            for i in range(25):
                target_embeds.append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))
                weights.extend([weight / cutn] * cutn)
        else:
            target_embeds.append(embed)
            weights.extend([weight / cutn] * cutn)
 
    target_embeds = torch.cat(target_embeds)
    weights = torch.tensor(weights, device=device)
    if weights.sum().abs() < 1e-3:
        raise RuntimeError('The weights must not sum to 0.')
    weights /= weights.sum().abs()
 
    init = None
    if init_image is not None:
        init = Image.open(fetch(init_image)).convert('RGB')
        init = init.resize((side_x, side_y), Image.LANCZOS)
        init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
    
    if perlin_init:
        if perlin_mode == 'color':
            init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
            init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
        elif perlin_mode == 'gray':
           init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
           init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
        else:
           init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
           init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
        
        # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)
        init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
        del init2
 
    cur_t = None
 
    def cond_fn(x, t, y=None):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            n = x.shape[0]
            my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
            out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
            fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
            x_in = out['pred_xstart'] * fac + x * (1 - fac)
            x_in_grad = torch.zeros_like(x_in)
            for i in range(cutn_batches):
                clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
                image_embeds = clip_model.encode_image(clip_in).float()
                dists = spherical_dist_loss(image_embeds.unsqueeze(1), target_embeds.unsqueeze(0))
                dists = dists.view([cutn, n, -1])
                losses = dists.mul(weights).sum(2).mean(0)
                loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch
                x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches
            tv_losses = tv_loss(x_in)
            range_losses = range_loss(out['pred_xstart'])
            sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()
            loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale
            if init is not None and init_scale:
                init_losses = lpips_model(x_in, init)
                loss = loss + init_losses.sum() * init_scale
            x_in_grad += torch.autograd.grad(loss, x_in)[0]
            grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
        if clamp_grad:
            magnitude = grad.square().mean().sqrt()
            return grad * magnitude.clamp(max=0.05) / magnitude
        return grad
 
    if model_config['timestep_respacing'].startswith('ddim'):
        sample_fn = diffusion.ddim_sample_loop_progressive
    else:
        sample_fn = diffusion.p_sample_loop_progressive
 
    for i in range(n_batches):
        cur_t = diffusion.num_timesteps - skip_timesteps - 1
 
        if model_config['timestep_respacing'].startswith('ddim'):
            samples = sample_fn(
                model,
                (batch_size, 3, side_y, side_x),
                clip_denoised=clip_denoised,
                model_kwargs={},
                cond_fn=cond_fn,
                progress=True,
                skip_timesteps=skip_timesteps,
                init_image=init,
                randomize_class=randomize_class,
                eta=eta,
            )
        else:
            samples = sample_fn(
                model,
                (batch_size, 3, side_y, side_x),
                clip_denoised=clip_denoised,
                model_kwargs={},
                cond_fn=cond_fn,
                progress=True,
                skip_timesteps=skip_timesteps,
                init_image=init,
                randomize_class=randomize_class,
            )

        for j, sample in enumerate(samples):
            display.clear_output(wait=True)
            cur_t -= 1
            if j % display_rate == 0 or cur_t == -1:
                for k, image in enumerate(sample['pred_xstart']):
                    tqdm.write(f'Batch {i}, step {j}, output {k}:')
                    current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
                    filename = f'progress_batch{i:05}_iteration{j:05}_output{k:05}_{current_time}.png'
                    image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
                    image.save('/content/image_storage/' + filename)
                    display.display(display.Image('/content/image_storage/' + filename))
                    if google_drive and cur_t == -1:
                        final_filename = output_folder_images + filename
                        print(final_filename)
                        #print('THIS VISIBLE? TEST!')
                        if upscale_multiplier >= 2:
                          temp_image_filename = temp_image_storage + _project_name + filename
                          image.save(temp_image_filename)
                          #gc.collect()
                          #image_upscale(realesrgan_mpath,upscale_value,temp_image_filename,'final_filename')
                          #def image_upscale(realesrgan_mpath,scale,temp_image_filename,_output_folder_images):
                          %cd /content/Real-ESRGAN/
                          !python inference_realesrgan.py --model_path $realesrgan_mpath --netscale $upscale_value --face_enhance --input $temp_image_filename --output $output_folder_images --ext jpg
                          if upscale_multiplier == 4:
                            gc.collect()
                            import os
                            doubleup_path = os.path.splitext(final_filename)[0]
                            doubleup_path = doubleup_path + '_out.jpg'
                            print("finalfilename was",final_filename)
                            print("Thingy is ",doubleup_path)
                            !python inference_realesrgan.py --model_path '/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x2plus.pth' --netscale 2 --face_enhance --input $doubleup_path --output $output_folder_images --ext jpg
                          %cd /content/
                        else:
                          image.save(output_folder_images + filename)

                        
 
        plt.plot(np.array(loss_values), 'r')

# Settings & Generation

In [None]:
timestep_respacing = 'ddim25' #@param ['25','50','150','250','500','1000','ddim25','ddim50','ddim150','ddim250','ddim500','ddim1000']
# Modify this value to decrease the number of timesteps.
# timestep_respacing = '25'
diffusion_steps = 1000

# The two supported checkpoints share every setting except the image size.
_image_size_by_model = {
    '512x512_diffusion_uncond_finetune_008100': 512,
    '256x256_diffusion_uncond': 256,
}

model_config = model_and_diffusion_defaults()
if diffusion_model in _image_size_by_model:
    model_config.update({
        'attention_resolutions': '32, 16, 8',
        'class_cond': False,
        'diffusion_steps': diffusion_steps,
        'rescale_timesteps': True,
        'timestep_respacing': timestep_respacing,
        'image_size': _image_size_by_model[diffusion_model],
        'learn_sigma': True,
        'noise_schedule': 'linear',
        'num_channels': 256,
        'num_head_channels': 64,
        'num_res_blocks': 2,
        'resblock_updown': True,
        'use_fp16': True,
        'use_scale_shift_norm': True,
    })
side_x = side_y = model_config['image_size']

# Build the network, load the checkpoint on the CPU, then move it to `device`.
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load(f'{model_path}{diffusion_model}.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device)
# Re-enable gradients only for parameters whose name mentions qkv, norm or proj.
for name, param in model.named_parameters():
    if any(tag in name for tag in ('qkv', 'norm', 'proj')):
        param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()

################################################

# Frozen CLIP model, its input resolution and normalisation, plus the LPIPS model.
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
lpips_model = lpips.LPIPS(net='vgg').to(device)

In [None]:
#@markdown ##Simple Settings
_text_prompt = 'Indoor sauna 4k, unsettling architecture, sauna creepypasta artwork'#@param {type:"string"}
_project_name = 'saunatimes' #@param{type:"string"}

#_run_upscaler = True #@param{type:"boolean"}
# NOTE(review): `_run_upscaler` is not read anywhere in the code visible here;
# upscaling in do_run() is driven by `upscale_multiplier` - confirm before removing.
_run_upscaler = True
#--------------------------
#Output image directory handling
# Final images go to a per-project folder under `_drive_location`.
output_folder_images = _drive_location + _project_name
# do_run() appends file names by string concatenation, so the folder must end with '/'.
contains_slash = (output_folder_images.endswith('/'))
if not contains_slash:
  print("Adding slash to _drive_location...")
  output_folder_images = output_folder_images + '/'
dir_make(output_folder_images)
temp_image_storage = '/content/image_storage/'  #for non-upscaled images
dir_make(temp_image_storage)
#--------------------------
#--------------------------

# Text prompts fed to CLIP; parse_prompt() splits an optional ':weight' suffix off each entry.
text_prompts = [
    # "an abstract painting of 'ravioli on a plate'",
    # 'cyberpunk wizard on top of a skyscraper, trending on artstation, photorealistic depiction of a cyberpunk wizard',
    _text_prompt,
    # 'cyberpunk wizard',
]
# Number of consecutive batches do_run() generates (its outer loop count).
n_batches =  10000#@param{type:"integer"}
# Single multiplier applied to the CLIP guidance, TV and range scales unless overridden below.
_easySettings = 1 #@param {type:"slider", min:0, max:2, step:0.1}
#@markdown `_easySettings` controls a few things. Play around!
#------------------------------------------------
#@markdown ---
##@markdown ##Bulk/Batch Settings

# Number of images sampled together in one batch (first dimension of the sample shape).
batch_size =  1#@param{type:"raw"}
# Optional image prompts (URL or local path, with the same ':weight' suffix convention).
image_prompts = [
    # 'mona.jpg',
]
#------------------------------------------------
#------------------------------------------------

##@markdown ---
#@markdown ##Advanced Settings (Optional)



# 350/50/50/32 and 500/0/0/64 have worked well for 25 timesteps on 256px
# Also, sometimes 1 cutn actually works out fine

#clip_guidance_scale = 5000 # 1000 - Controls how much the image should look like the prompt.
#clip_guidance_scale = 5000 # 1000 - Controls how much the image should look like the prompt.


#tv_scale = 150 # 150 - Controls the smoothness of the final output.
#range_scale = 150 # 150 - Controls how far out of range RGB values are allowed to be.
sat_scale = 0 # 0 - Controls how much saturation is allowed. From nshepperd's JAX notebook.
cutn = 16 #16
#Controls how many crops to take from the image.
# cutn_batches: number of cutout sets whose CLIP gradients are accumulated per step
# [Can help with OOM errors / Low VRAM]
cutn_batches = 4 #@param [2,4,8] {type:"raw"}

# Either take the three multipliers from their own sliders or reuse `_easySettings` for all of them.
override_easySettings = False #@param{type:"boolean"}
if override_easySettings:
  _clip_guidance_scale = 0.7 #@param {type:"slider", min:0, max:2, step:0.1}
  _tv_scale = 0.7 #@param {type:"slider", min:0, max:2, step:0.1}
  _range_scale = 0.7 #@param {type:"slider", min:0, max:2, step:0.1}
else:
  _clip_guidance_scale = _easySettings
  _tv_scale = _easySettings
  _range_scale = _easySettings

# The multipliers scale the base values 5000 / 150 / 150.
clip_guidance_scale = 5000 * _clip_guidance_scale
tv_scale = 150 * _tv_scale
range_scale = 150 * _range_scale

# Real-ESRGAN checkpoint used for the first upscaling pass in do_run().

realesrgan_mpath='/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth' #@param ['/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth','/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth','/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x2plus.pth']
# `upscale_value` is passed as --netscale: "2" for the x2plus checkpoint, "4" otherwise.
if realesrgan_mpath == '/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x2plus.pth':
  upscale_value="2"
else:
  upscale_value="4"

#upscale_multiplier = "2" #@param [2,4]
# do_run() upscales when this is >= 2 and runs an extra x2 pass when it equals 4.
upscale_multiplier = "2"
upscale_multiplier = int(upscale_multiplier)

init_image = None # None - URL or local path
init_scale = 0 # 0 - This enhances the effect of the init image, a good value is 1000
skip_timesteps = 6 # 0 - Controls the starting point along the diffusion timesteps
perlin_init = False # False - Option to start with random perlin noise
perlin_mode = 'mixed' # 'mixed' ('gray', 'color')

skip_augs = False # False - Controls whether to skip torchvision augmentations
randomize_class = True # True - Controls whether the imagenet class is randomly changed each iteration
clip_denoised = False # False - Determines whether CLIP discriminates a noisy or denoised image
clamp_grad = True # True - Experimental: Using adaptive clip grad in the cond_fn

seed = None
# seed = random.randint(0, 2**32) # Choose a random seed and print it at end of run for reproduction

fuzzy_prompt = False # False - Controls whether to add multiple noisy prompts to the prompt losses
rand_mag = 0.05 # 0.1 - Controls the magnitude of the random noise
eta = 0.5 # 0.0 - DDIM hyperparameter

#@markdown `cutn_batches` might increase quality at the cost of speed. `override_easysettings` does not affect it.

################################################


display_rate = 1
# 1 - A progress image is saved and displayed every `display_rate` sampling steps
# (the final step of a batch is always shown).

# Release unreferenced Python objects and cached CUDA memory before starting.
gc.collect()
torch.cuda.empty_cache()
try:
    do_run()
except KeyboardInterrupt:
    # A manual interrupt is an expected way to stop: free memory and carry on quietly.
    gc.collect()
    torch.cuda.empty_cache()
finally:
    # Always report the seed and output folder, then free memory once more.
    print('seed', seed)
    print('Output(s) saved to ',output_folder_images)
    gc.collect()
    torch.cuda.empty_cache()

###############################################