In [1]:
from typing import List, Tuple
import os
import math
from argparse import ArgumentParser, Namespace

import numpy as np
import torch
import einops
import pytorch_lightning as pl
from PIL import Image
from omegaconf import OmegaConf

from model.spaced_sampler import SpacedSampler
from model.ddim_sampler import DDIMSampler
from model.cldm import ControlLDM
from utils.image import (
    wavelet_reconstruction, adaptive_instance_normalization, auto_resize, pad
)
from utils.common import instantiate_from_config, load_state_dict
from utils.file import list_image_files, get_file_name_parts

  from .autonotebook import tqdm as notebook_tqdm
  rank_zero_deprecation(


In [2]:
@torch.no_grad()
def process(
    model: ControlLDM,
    control_imgs: List[np.ndarray],
    sampler: str,
    steps: int,
    strength: float,
    color_fix_type: str,
    disable_preprocess_model: bool,
    positive_prompt: str = "",
    prompt_scale: float = 1.0
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """
    Apply DiffBIR model on a list of low-quality images.
    
    Args:
        model (ControlLDM): Model.
        control_imgs (List[np.ndarray]): A list of low-quality images (HWC, RGB, range in [0, 255]).
            All images are stacked into one batch, so they must share the same shape.
        sampler (str): Sampler name. "ddpm" selects `SpacedSampler`; any other value
            selects `DDIMSampler`.
        steps (int): Sampling steps.
        strength (float): Control strength. Set to 1.0 during training.
        color_fix_type (str): Type of color correction for samples, one of
            "adain", "wavelet" or "none".
        disable_preprocess_model (bool): If specified, preprocess model (SwinIR) will not be used.
        positive_prompt (str): Text prompt used as cross-attention conditioning for
            every sample in the batch.
        prompt_scale (float): Passed to the spaced sampler as
            `unconditional_guidance_scale`. It is not forwarded to the DDIM sampler.
    
    Returns:
        preds (List[np.ndarray]): Restoration results (HWC, RGB, range in [0, 255]).
        stage1_preds (List[np.ndarray]): Outputs of preprocess model (HWC, RGB, range in [0, 255]). 
            If `disable_preprocess_model` is specified, then the preprocess model's outputs are
            the same as the low-quality inputs.
    
    Raises:
        ValueError: If the preprocess model is requested (not disabled) but `model`
            has no `preprocess_model` attribute.
        AssertionError: If `color_fix_type` is not one of the supported values.
    """
    # Validate up front so a typo does not cost a full sampling run before failing.
    assert color_fix_type in ("adain", "wavelet", "none"), f"unexpected color fix type: {color_fix_type}"

    n_samples = len(control_imgs)
    # Use a separate local name instead of rebinding the `sampler` string argument.
    if sampler == "ddpm":
        sampler_impl = SpacedSampler(model, var_type="fixed_small")
    else:
        sampler_impl = DDIMSampler(model)

    # uint8 HWC images -> float NCHW batch in [0, 1] on the model's device.
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    # TODO: model.preprocess_model = lambda x: x
    if not disable_preprocess_model:
        # The preprocess stage was requested: a missing model is an error, not a
        # silent skip. Disabling the stage never requires the model to exist.
        if not hasattr(model, "preprocess_model"):
            raise ValueError("model doesn't have a preprocess model.")
        control = model.preprocess_model(control)
    
    height, width = control.size(-2), control.size(-1)
    cond = {
        "c_latent": [model.apply_condition_encoder(control)],
        "c_crossattn": [model.get_learned_conditioning([positive_prompt] * n_samples)]
    }
    # NOTE(review): 13 presumably matches the number of control outputs the model
    # consumes -- confirm against the ControlLDM definition.
    model.control_scales = [strength] * 13
    
    # Initial noise: 4 channels at 1/8 of the (padded) image resolution.
    shape = (n_samples, 4, height // 8, width // 8)
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    if isinstance(sampler_impl, SpacedSampler):
        samples = sampler_impl.sample(
            steps, shape, cond,
            unconditional_guidance_scale=prompt_scale,
            unconditional_conditioning=None,
            cond_fn=None, x_T=x_T
        )
    else:
        # NOTE(review): `prompt_scale` is not passed on this path; confirm whether
        # the DDIM sampler should receive a guidance scale as well.
        samples, _ = sampler_impl.sample(
            S=steps, batch_size=shape[0], shape=shape[1:],
            conditioning=cond, unconditional_conditioning=None,
            x_T=x_T, eta=0
        )
    x_samples = model.decode_first_stage(samples)
    # Map decoder output from [-1, 1] to [0, 1].
    x_samples = ((x_samples + 1) / 2).clamp(0, 1)
    
    # apply color correction (borrowed from StableSR)
    if color_fix_type == "adain":
        x_samples = adaptive_instance_normalization(x_samples, control)
    elif color_fix_type == "wavelet":
        x_samples = wavelet_reconstruction(x_samples, control)
    
    # Back to uint8 HWC arrays on the CPU.
    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    control = (einops.rearrange(control, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    
    preds = [x_samples[i] for i in range(n_samples)]
    stage1_preds = [control[i] for i in range(n_samples)]
    
    return preds, stage1_preds

In [3]:
def parse_args(x) -> Namespace:
    """Declare the inference command-line options and parse ``x``.

    Args:
        x: Sequence of command-line tokens, handed unchanged to
            ``ArgumentParser.parse_args``.

    Returns:
        Namespace: The parsed options, one attribute per declared flag.
    """
    # (flag, keyword arguments for add_argument), listed in declaration order.
    option_table = (
        # model weights and configuration
        ("--ckpt", dict(required=True, type=str)),
        ("--config", dict(required=True, type=str)),
        ("--reload_swinir", dict(action="store_true")),
        ("--swinir_ckpt", dict(type=str, default="")),
        # input images and sampling settings
        ("--input", dict(type=str, required=True)),
        ("--sampler", dict(type=str, default="ddpm", choices=["ddpm", "ddim"])),
        ("--steps", dict(required=True, type=int)),
        ("--sr_scale", dict(type=float, default=1)),
        ("--image_size", dict(type=int, default=512)),
        ("--repeat_times", dict(type=int, default=1)),
        ("--disable_preprocess_model", dict(action="store_true")),
        # post-processing and output
        ("--color_fix_type", dict(type=str, default="wavelet", choices=["wavelet", "adain", "none"])),
        ("--resize_back", dict(action="store_true")),
        ("--output", dict(type=str, required=True)),
        ("--show_lq", dict(action="store_true")),
        ("--skip_if_exist", dict(action="store_true")),
        # runtime
        ("--seed", dict(type=int, default=231)),
        ("--device", dict(type=str, default="cuda", choices=["cpu", "cuda"])),
    )

    cli = ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args(x)

# The notebook has no real command line, so parse a fixed token list and then
# layer the notebook-specific settings on top of the parsed namespace.
args = parse_args([
    "--ckpt", "weights/general_full_v1.ckpt",
    "--config", "configs/model/cldm.yaml",
    "--reload_swinir",
    "--swinir_ckpt", "weights/general_swinir_v1.ckpt",
    "--input", "X",  # placeholder for the required flag; replaced just below
    "--output", "output",
    "--steps", "50",
])
vars(args).update(
    input="inputs/my/",
    output="output",
    steps=50,
    sr_scale=3,
    image_size=300,
    resize_back=True,
)

In [4]:
# Seed the random number generators once, before the model is built, so that a
# fresh "Restart & Run All" reproduces the same samples (--seed was otherwise
# parsed but never applied).
pl.seed_everything(args.seed)

# Honor --device, but fall back to the CPU when CUDA was requested and is not
# available. With the default --device=cuda this picks the same device as before.
device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"

model: ControlLDM = instantiate_from_config(OmegaConf.load(args.config))
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# Weights are loaded on the CPU first and moved to `device` at the end.
load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
# reload preprocess model if specified
if args.reload_swinir:
    if not hasattr(model, "preprocess_model"):
        raise ValueError(f"model don't have a preprocess model.")
    print(f"reload swinir model from {args.swinir_ckpt}")
    load_state_dict(model.preprocess_model, torch.load(args.swinir_ckpt, map_location="cpu"), strict=True)
model.freeze()
# Trailing semicolon suppresses the module repr in the notebook output.
model.to(device);

ControlLDM: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.


Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
reload swinir model from weights/general_swinir_v1.ckpt


In [6]:
# Settings for this sampling pass. Edit and re-run this cell to try other values;
# each run writes to a new, numbered file instead of overwriting earlier results.
args.input = "inputs/my/"
args.output = "output"
args.steps = 50
args.sr_scale = 3
args.image_size = 300
args.resize_back = True
positive_prompt = ""
prompt_scale = 1.
strength = 0.95

assert os.path.isdir(args.input)

# Report the sampler that is actually passed to process() below.
print(f"sampling {args.steps} steps using {args.sampler} sampler")
for file_path in list_image_files(args.input, follow_links=True):
    lq = Image.open(file_path).convert("RGB")
    if args.sr_scale != 1:
        # Upscale the input first; the model then restores it at the target size.
        lq = lq.resize(
            tuple(math.ceil(side * args.sr_scale) for side in lq.size),
            Image.BICUBIC
        )
    lq_resized = auto_resize(lq, args.image_size)
    # Pad with scale=64; the padding is cropped off again after sampling.
    lq_padded = pad(np.array(lq_resized), scale=64)
    
    # Mirror the input's relative path under the output directory and pick the
    # first "<stem>_<n>.png" name that does not exist yet.
    save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
    parent_path, stem, _ = get_file_name_parts(save_path)
    run_index = 0
    save_path = os.path.join(parent_path, f"{stem}_{run_index}.png")
    while os.path.exists(save_path):
        run_index += 1
        save_path = os.path.join(parent_path, f"{stem}_{run_index}.png")

    
    os.makedirs(parent_path, exist_ok=True)
    
    preds, stage1_preds = process(
        model, [lq_padded], steps=args.steps, sampler=args.sampler,
        strength=strength,
        color_fix_type=args.color_fix_type,
        disable_preprocess_model=args.disable_preprocess_model,
        positive_prompt=positive_prompt,
        prompt_scale=prompt_scale
    )
    
    # A single image was submitted, so take the only entry of each result list.
    pred, stage1_pred = preds[0], stage1_preds[0]
    
    # remove padding
    pred = pred[:lq_resized.height, :lq_resized.width, :]
    stage1_pred = stage1_pred[:lq_resized.height, :lq_resized.width, :]
    
    if args.show_lq:
        # Save a side-by-side strip: input, (stage-1 output,) final result.
        if args.resize_back:
            if lq_resized.size != lq.size:
                pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
                stage1_pred = np.array(Image.fromarray(stage1_pred).resize(lq.size, Image.LANCZOS))
            lq = np.array(lq)
        else:
            lq = np.array(lq_resized)
        images = [lq, pred] if args.disable_preprocess_model else [lq, stage1_pred, pred]
        Image.fromarray(np.concatenate(images, axis=1)).save(save_path)
    else:
        if args.resize_back and lq_resized.size != lq.size:
            Image.fromarray(pred).resize(lq.size, Image.LANCZOS).save(save_path)
        else:
            Image.fromarray(pred).save(save_path)
    print(f"save to {save_path}")


sampling 50 steps using ddpm sampler
start to sample from a given noise
Running Spaced Sampling with 50 timesteps


Spaced Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.62it/s]


save to output/synthetic fruits (bbox)_ds0_182 (x300, sigma=4)_1.png
start to sample from a given noise
Running Spaced Sampling with 50 timesteps


Spaced Sampler: 100%|██████████| 50/50 [00:07<00:00,  6.62it/s]


save to output/synthetic fruits (bbox)_ds0_182 (x300, sigma=2)_1.png
