<a href="https://colab.research.google.com/github/rikalestari1411/rikaimageenhancement/blob/main/Image_Enhancement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#<font face="Trebuchet MS" size="6">Neural Image Super-Resolution<font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><font color="#999" size="4">Latent Diffusion + FBCNN</font><font color="#999" size="4">&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;</font><a href="https://github.com/olaviinha/NeuralImageSuperResolution" target="_blank"><font color="#999" size="4">Github</font></a>

This notebook combines [Latent Diffusion](https://github.com/CompVis/latent-diffusion) and [FBCNN](https://github.com/jiaxi-jiang/FBCNN) in an attempt to improve and enhance image quality.

This notebook does not increase image resolution per se; it only improves image quality and sharpness.

`local_models_dir` and `input` should be paths relative to your Google Drive root. I.e. if your Google Drive has a directory called _images_ and under that directory you have a file _face.jpg_, then `input` value should be `images/face.jpg`

In [None]:
#@title #Setup
#@markdown This cell needs to be run only once. It will mount your Google Drive and setup prerequisites.<br>
#@markdown <small>- Mounting Drive will enable this notebook to save outputs directly to your Drive. Otherwise you will need to copy/download them manually from this notebook.</small>
#@markdown <small><br>- `local_models_dir` is a directory path in your Google Drive to store models used by this notebook. Useful in case you don't want the notebook to download the models every time or in case a third-party model download link gets broken in the future.</small>

# Flag consulted by the inhagcutils bootstrap check further down in this cell.
force_setup = False
# Extra package names appended to the `pip install import-ipynb` command below.
pip_packages = ''
# Git URL cloned into /content when non-empty (provides the FBCNN sources).
main_repository = 'https://github.com/olaviinha/FBCNN.git'
# Colab form field: mount Google Drive, or fall back to a local faux drive.
mount_drive = True #@param {type:"boolean"}

# Colab form field: models directory relative to the Drive root; empty selects
# the default directory inside the cloned FBCNN repository.
local_models_dir = "" #@param {type:"string"}

# Download the repo from Github
import os
from google.colab import output
import warnings
warnings.filterwarnings('ignore')
%cd /content/

# inhagcutils
if not os.path.isfile('/content/inhagcutils.ipynb') and force_setup == False:
  !pip -q install import-ipynb {pip_packages}
  !curl -s -O https://raw.githubusercontent.com/olaviinha/inhagcutils/master/inhagcutils.ipynb
import import_ipynb
from inhagcutils import *

# Mount Drive
if mount_drive:
  # Mount and symlink only when missing so this cell can be re-run safely.
  if not os.path.isdir('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
  if not os.path.isdir('/content/mydrive'):
    os.symlink('/content/drive/My Drive', '/content/mydrive')
  # Assign drive_root unconditionally: it used to be set only inside the two
  # branches above, so re-running the cell with Drive already mounted and
  # linked left drive_root undefined.
  drive_root = '/content/mydrive/'
  drive_root_set = True
else:
  # No Drive: use a local directory that plays the role of the Drive root.
  create_dirs(['/content/faux_drive'])
  drive_root = '/content/faux_drive/'

if main_repository is not '':
  !git clone {main_repository}

if local_models_dir is '':
  local_models_dir = '/content/FBCNN/models/'
  r = requests.get('https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/fbcnn_color.pth', allow_redirects=True)
  open(local_models_dir, 'wb').write(r.content)
  r = requests.get('https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/fbcnn_gray.pth', allow_redirects=True)
  open(local_models_dir, 'wb').write(r.content)
else:
  local_models_dir = drive_root+fix_path(local_models_dir)

# JPEG artifact removal
sys.path.insert(1, '/content/FBCNN/models')
sys.path.insert(1, '/content/FBCNN/utils')
import os.path
import numpy as np
from collections import OrderedDict
import torch
import cv2
from PIL import Image, ImageOps
import utils_image as util
from network_fbcnn import FBCNN as net
import requests
import datetime

# SuperRes
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
!pip install -e ./taming-transformers
!pip install ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb
import ipywidgets as widgets
import os
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan # checking correct import from taming
from torchvision.datasets.utils import download_url
%cd '/content/latent-diffusion'
from functools import partial
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
# from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
%cd '/content'
from google.colab import files
from IPython.display import Image as ipyimg
from numpy import asarray
from einops import rearrange, repeat
import torch, torchvision
import time
from omegaconf import OmegaConf
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):
    """Run one DDIM sampling pass of `model` conditioned on `img`.

    Args:
        model: diffusion model; its `split_input_params` attribute is set or
            removed here depending on the conditioning size.
        img: input image handed to `get_cond`.
        task: conditioning mode handed to `get_cond`.
        custom_steps: number of sampling steps.
        eta: DDIM eta value.
        resize_enabled: forwarded unchanged to `make_convolutional_sample`.
        classifier_ckpt: accepted for interface compatibility; not used.
        global_step: accepted for interface compatibility; not used.

    Returns:
        The log dict produced by `make_convolutional_sample`.
    """
    cond_batch = get_cond(task, img)

    # Patch-wise processing is enabled only when both dimensions of the
    # conditioning "image" entry are at least one patch (128) in size.
    img_h, img_w = cond_batch["image"].shape[1:3]
    if img_h >= 128 and img_w >= 128:
        patch = 128
        hop = 64
        model.split_input_params = {"ks": (patch, patch), "stride": (hop, hop),
                                    "vqf": 4,
                                    "patch_distributed_vq": True,
                                    "tie_braker": False,
                                    "clip_max_weight": 0.5,
                                    "clip_min_weight": 0.01,
                                    "clip_max_tie_weight": 0.5,
                                    "clip_min_tie_weight": 0.01}
    elif hasattr(model, "split_input_params"):
        # Drop a setting left over from an earlier, larger image.
        delattr(model, "split_input_params")

    # Fixed sampling configuration: a single DDIM run, no mask, no corrector,
    # no custom latent shape (hence no pre-drawn starting noise).
    sampling_options = dict(
        mode='ddim',
        custom_steps=custom_steps,
        eta=eta,
        swap_mode=False,
        masked=False,
        invert_mask=False,
        quantize_x0=False,
        custom_schedule=None,
        decode_interval=10,
        resize_enabled=resize_enabled,
        custom_shape=None,
        temperature=1.,
        noise_dropout=0.,
        corrector=None,
        corrector_kwargs=None,
        x_T=None,
        save_intermediate_vid=False,
        make_progrow=True,
        ddim_use_x0_pred=False,
    )
    return make_convolutional_sample(cond_batch, model, **sampling_options)




def do_superres(img, mode):
  """Sharpen a PIL image with the latent-diffusion super-resolution model.

  The image is optionally halved before sampling (faster presets), run through
  `sr_run`, and the result is resized back to the input's original size with
  Lanczos resampling, so the output resolution equals the input resolution.

  Args:
    img: PIL image to enhance.
    mode: one of 'Faster', 'Fast', 'Slow', 'Very Slow'.

  Returns:
    A PIL image with the same width and height as `img`.

  Raises:
    ValueError: if `mode` is not one of the supported presets (previously an
      unknown mode surfaced as an UnboundLocalError further down).
  """
  # preset -> (diffusion steps, pre-downsample rate)
  presets = {
      'Faster': (25, 2),
      'Fast': (100, 2),
      'Slow': (25, 1),
      'Very Slow': (100, 1),
  }
  if mode not in presets:
    raise ValueError(f'Unknown superres mode {mode!r}; expected one of {list(presets)}')
  sr_diffusion_steps, downsample_rate = presets[mode]
  sr_eta = 1.0

  width_og, height_og = img.size

  # Downsample Pre
  im_lowres = img
  if downsample_rate != 1:
    # Clamp to 1 px so very small inputs never request a zero-sized resize.
    width_downsampled_pre = max(1, width_og // downsample_rate)
    height_downsampled_pre = max(1, height_og // downsample_rate)
    im_lowres = img.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)

  logs = sr_run(sr_model["model"], im_lowres, sr_diffMode, sr_diffusion_steps, sr_eta)

  # Map the sample from [-1, 1] to uint8 and move channels last.
  sample = logs["sample"].detach().cpu()
  sample = torch.clamp(sample, -1., 1.)
  sample = (sample + 1.) / 2. * 255
  sample = sample.numpy().astype(np.uint8)
  sample = np.transpose(sample, (0, 2, 3, 1))
  upscaled = Image.fromarray(sample[0])

  # Downsample Post: always back to the original size.
  return upscaled.resize((width_og, height_og), Image.LANCZOS)

  
def remove_artifacts(input_img, input_quality):
  """Remove JPEG artifacts from a PIL image with FBCNN.

  Reads the module-level `is_gray` flag to choose the grayscale or color
  checkpoint, and `local_models_dir` to locate it.

  Args:
    input_img: PIL image to clean.
    input_quality: removal strength on a 0-100 scale; the model is given
      `input_quality / 100` as its quality-factor input.

  Returns:
    A PIL image built from the model output.
  """
  if is_gray:
    n_channels = 1 # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_gray.pth'
  else:
    n_channels = 3 # set 1 for grayscale image, set 3 for color image
    model_name = 'fbcnn_color.pth'
  nc = [64,128,256,512]
  nb = 4
  model_path = local_models_dir+model_name
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
  model.eval()
  for param in model.parameters():
      param.requires_grad = False
  model = model.to(device)

  if n_channels == 1:
      pixels = np.array(ImageOps.grayscale(input_img))
      pixels = np.expand_dims(pixels, axis=2)  # HxWx1
  else:
      # PIL already yields RGB. The previous BGR2RGB conversion was written
      # for cv2.imread input and fed the model red/blue-swapped channels.
      # convert('RGB') also normalizes palette, grayscale and RGBA inputs.
      pixels = np.array(input_img.convert('RGB'))  # HxWx3

  img_L = util.uint2tensor4(pixels).to(device)
  # Equivalent to the former 1 - (100 - input_quality) / 100.
  qf_input = torch.tensor([[input_quality/100]], device=device)
  # One forward pass with the user-chosen quality factor; the earlier extra
  # pass without qf_input produced a result that was immediately overwritten.
  with torch.no_grad():
    img_E, QF = model(img_L, qf_input)
  img_E = util.tensor2single(img_E)
  img_E = util.single2uint(img_E)

  return Image.fromarray(img_E)


def download_models(mode):
  """Download the super-resolution config and checkpoint.

  Files are stored below `{local_models_dir}/superres/`.

  Args:
    mode: only "superresolution" is supported.

  Returns:
    Tuple `(path_conf, path_ckpt)` of the downloaded files.

  Raises:
    NotImplementedError: for any other `mode`.
  """
  if mode != "superresolution":
    raise NotImplementedError

  # this is the small bsr light model
  superres_dir = f'{local_models_dir}/superres'
  conf_target = f'{superres_dir}/project.yaml'
  ckpt_target = f'{superres_dir}/last.ckpt'

  download_url('https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1', conf_target)
  download_url('https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1', ckpt_target)

  # NOTE(review): presumably download_url treats its second argument as a
  # directory and names the file after the URL's last segment ("?dl=1"),
  # which is why that segment is appended here - confirm against torchvision.
  return conf_target + '/?dl=1', ckpt_target + '/?dl=1'
      
def get_model(mode):
  """Download the files for `mode` and return the loaded model wrapper.

  Returns:
    The first element returned by `load_model_from_config`, i.e. a dict of
    the form `{"model": model}`; the checkpoint's global step is discarded.
  """
  conf_file, ckpt_file = download_models(mode)
  wrapped_model, _ = load_model_from_config(OmegaConf.load(conf_file), ckpt_file)
  return wrapped_model

def load_model_from_config(config, ckpt):
  """Instantiate the model described by `config` and load weights from `ckpt`.

  Args:
    config: configuration object with a `model` entry understood by
      `instantiate_from_config`.
    ckpt: path to a checkpoint containing "global_step" and "state_dict".

  Returns:
    Tuple `({"model": model}, global_step)`; the model has been moved to CUDA
    and put in eval mode.
  """
  checkpoint = torch.load(ckpt, map_location="cpu")
  step = checkpoint["global_step"]
  weights = checkpoint["state_dict"]

  diffusion_model = instantiate_from_config(config.model)
  # strict=False: missing/unexpected keys are tolerated and not reported.
  diffusion_model.load_state_dict(weights, strict=False)
  diffusion_model.cuda()
  diffusion_model.eval()
  return {"model": diffusion_model}, step


def get_cond(mode, img):
  """Build the conditioning batch for the sampler.

  Args:
    mode: conditioning mode; only "superresolution" produces entries.
    img: image accepted by torchvision's `ToTensor`.

  Returns:
    For "superresolution", a dict with "LR_image" (the input scaled to
    [-1, 1], channels last, on CUDA) and "image" (the input resized by a
    factor of 4, channels last). For any other mode, an empty dict.
  """
  cond = {}
  if mode != "superresolution":
    return cond

  scale = 4
  low_res = torchvision.transforms.ToTensor()(img).unsqueeze(0)
  target_size = [scale * low_res.shape[2], scale * low_res.shape[3]]
  upsampled = torchvision.transforms.functional.resize(low_res, size=target_size, antialias=True)

  upsampled = rearrange(upsampled, '1 c h w -> 1 h w c')
  low_res = rearrange(low_res, '1 c h w -> 1 h w c')
  low_res = (2. * low_res - 1.).to(torch.device("cuda"))

  cond["LR_image"] = low_res
  cond["image"] = upsampled
  return cond

class DDIMSampler(object):
    """DDIM sampling loop around a diffusion model.

    `make_schedule` precomputes the per-step tables, `sample` is the public
    entry point, `ddim_sampling` iterates over the timesteps and
    `p_sample_ddim` performs a single update step.
    """

    def __init__(self, model, schedule="linear", **kwargs):
        """Store the model and its number of DDPM timesteps.

        `schedule` is stored as-is; extra keyword arguments are ignored.
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store `attr` on the sampler as `name`; tensors are moved to CUDA first."""
        if isinstance(attr, torch.Tensor):
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute timesteps and alpha/sigma tables for `ddim_num_steps` steps.

        All tables are stored on the sampler through `register_buffer`.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               **kwargs
               ):
        """Build the schedule for `S` steps and run DDIM sampling.

        Args:
            S: number of DDIM steps.
            batch_size: number of samples to draw.
            shape: `(C, H, W)` of one sample.
            conditioning: tensor or dict of tensors; a batch-size mismatch only
                prints a warning.
            eta: DDIM eta passed to `make_schedule`.
            x_T: optional starting tensor; random noise is drawn when None.
            log_every_t: interval at which intermediates are recorded.

        `normals_sequence` and extra keyword arguments are accepted but unused;
        the remaining arguments are forwarded to `ddim_sampling`.

        Returns:
            Tuple `(samples, intermediates)`.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        # print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        """Iterate the timesteps from last to first, updating the sample each step.

        Returns:
            Tuple `(img, intermediates)` where `intermediates` holds the
            recorded 'x_inter' and 'pred_x0' tensors.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            # Caller asked for a prefix of the DDIM schedule.
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        #print(f"Running DDIM Sharpening with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)

        for i, step in enumerate(iterator):
            # index counts down into the precomputed tables.
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        """Perform one DDIM update.

        Args:
            x: current sample.
            c: conditioning passed to `model.apply_model`.
            t: timestep tensor for the batch.
            index: position of this step in the precomputed tables.

        Returns:
            Tuple `(x_prev, pred_x0)`.
        """
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)
        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        # This table is registered on the sampler by make_schedule, not on the
        # model, so it must be read from self.
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0


@torch.no_grad()
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
                              invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
                              resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
    """Sample from `model` for `batch` and collect the results in a log dict.

    Encodes the batch with `model.get_input`, runs `convsample_ddim` inside the
    model's EMA scope and decodes the sample with `model.decode_first_stage`.

    Args:
        batch: input dict understood by `model.get_input`.
        model: diffusion model.
        custom_steps: number of sampling steps.
        eta: DDIM eta.
        custom_shape: when given, the latent is replaced by noise of this shape.
        ddim_use_x0_pred: decode the last recorded x0 prediction instead of
            the final sample.

    `mode`, `swap_mode`, `masked`, `invert_mask`, `custom_schedule`,
    `decode_interval`, `resize_enabled` and `make_progrow` are accepted but
    not used in this function; the remaining arguments are forwarded to
    `convsample_ddim`.

    Returns:
        Dict with "input", "reconstruction", "original_conditioning", "sample"
        and "time", plus a conditioning-key entry and the "sample_noquant" /
        "sample_diff" entries when available.
    """
    log = dict()
    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    log_every_t = 1 if save_intermediate_vid else None
    if custom_shape is not None:
        z = torch.randn(custom_shape)
        # print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
    z0 = None
    log["input"] = x
    log["reconstruction"] = xrec
    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key =='class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]
    with model.ema_scope("Plotting"):
        t0 = time.time()
        img_cb = None
        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
                                                temperature=temperature, noise_dropout=noise_dropout,
                                                score_corrector=corrector, corrector_kwargs=corrector_kwargs,
                                                x_T=x_T, log_every_t=log_every_t)
        t1 = time.time()
        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]
    x_sample = model.decode_first_stage(sample)
    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        # Best effort: skip the extra entries if the non-quantized decode is
        # unavailable. `Exception` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        pass
    log["sample"] = x_sample
    log["time"] = t1 - t0
    return log

@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, img_callback=None,
                    temperature=1., noise_dropout=0., score_corrector=None,
                    corrector_kwargs=None, x_T=None, log_every_t=None
                    ):
    """Draw samples with a fresh `DDIMSampler` built around `model`.

    Args:
        model: diffusion model handed to `DDIMSampler`.
        cond: conditioning forwarded to the sampler.
        steps: number of DDIM steps.
        shape: full batch shape; its first entry is the batch size and the
            remainder is the per-sample shape.

    `img_callback` and `log_every_t` are accepted but not forwarded to the
    sampler; the remaining arguments are passed through.

    Returns:
        Tuple `(samples, intermediates)` from `DDIMSampler.sample`.
    """
    sampler = DDIMSampler(model)
    batch_size, sample_shape = shape[0], shape[1:]
    samples, intermediates = sampler.sample(
        steps,
        batch_size=batch_size,
        shape=sample_shape,
        conditioning=cond,
        callback=callback,
        normals_sequence=normals_sequence,
        quantize_x0=quantize_x0,
        eta=eta,
        mask=mask,
        x0=x0,
        temperature=temperature,
        verbose=False,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
        x_T=x_T,
    )
    return samples, intermediates

# Progress bar used by DDIMSampler.ddim_sampling.
from tqdm.notebook import tqdm
# Used by the Run cell to collect garbage between images.
import gc
# Conditioning mode and loaded model wrapper, both read by do_superres.
sr_diffMode = 'superresolution'
# Downloads the config/checkpoint if needed and loads the model onto the GPU.
sr_model = get_model('superresolution')



# Clear the accumulated install/clone output of this cell.
output.clear()
# !nvidia-smi
# `op` and `c` are presumably provided by the inhagcutils star import - confirm.
op(c.ok, 'Setup finished.')

/content
[K     |████████████████████████████████| 1.6 MB 29.8 MB/s 
[?25himporting Jupyter notebook from inhagcutils.ipynb


MessageError: ignored

In [None]:
#@title #Run

#@markdown <small><font color="#999">`input` may be a file path or a directory path. Supports wildcards, e.g. `images/face_*.jpg`</font></small>
# NOTE: `input` shadows the builtin; the name is kept because it is the Colab form field.
input = "" #@param {type:"string"}

#@markdown <small><font color="#999">Leave `output_dir` empty to save in the same directory where `input` image(s) reside. Input images are not overwritten; result images are saved with a `_superres.png` suffix.</font></small>
output_dir = "" #@param {type:"string"}

#@markdown Latent Diffusion sharpening settings:
#@markdown <small><br><font color="#999">The slower the better.</font></small>
superres_quality = 'Very Slow' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']

#@markdown FBCNN JPEG artifact removal settings:
#@markdown <small><br><font color="#999">The larger number the cleaner image.</font></small>
artifact_removal_rate = 75 #@param {type:"slider"}
is_gray = False #@param {type:"boolean"}

input_imgs = glob(drive_root+input)

# Compare strings by value; `is not ''` tests object identity, which is an
# implementation detail (and a SyntaxWarning on Python 3.8+).
if output_dir != '':
  output_dir = drive_root+fix_path(output_dir)
  # makedirs also creates missing parent directories and tolerates an
  # already existing directory.
  os.makedirs(output_dir, exist_ok=True)
else:
  output_dir = drive_root+path_dir(input)

total_imgs = len(input_imgs)
if total_imgs == 0:
  # Make an empty match visible instead of silently reporting success.
  op(c.fail, 'No input images found', input)

for ndx, input_img in enumerate(input_imgs, start=1):

  # Release memory held over from the previous image before loading the next.
  gc.collect()
  torch.cuda.empty_cache()

  img_out_path = output_dir+basename(input_img)+'_superres.png'
  disp_in = input_img.replace(drive_root, '')
  disp_out = img_out_path.replace(drive_root, '')

  op(c.title, 'Processing image '+str(ndx)+'/'+str(total_imgs), path_leaf(input_img), time=True)

  # Never overwrite an existing result.
  if os.path.isfile(img_out_path):
    op(c.fail, 'File already exists, skipping', disp_out)
    continue

  # The context manager closes the source file once the result is written.
  with Image.open(input_img) as img_in:
    img_state = img_in

    if artifact_removal_rate > 0:
      img_state = remove_artifacts(img_state, artifact_removal_rate)

    if superres_quality != 'Off':
      img_state = do_superres(img_state, superres_quality)

    img_out = img_state
    img_out.save(img_out_path)

  if os.path.isfile(img_out_path):
    op(c.ok, 'Image saved as', disp_out)
  else:
    op(c.fail, 'An error occurred with', disp_out)
  print('')

op(c.ok, 'FIN.')