Skip to content

Commit

Permalink
Merge pull request #65 from prs-eth/pipeline_variable
Browse files Browse the repository at this point in the history
Add pipeline variables
  • Loading branch information
markkua committed May 25, 2024
2 parents dfc2e11 + c91a70a commit ed281c7
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 137 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,10 @@ Activate the environment again after restarting the terminal session.

### 🚀 Run inference with LCM (faster)

The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 to 4:
The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 (default) to 4. Run with default LCM setting:

```bash
python run.py \
--denoise_steps 4 \
--ensemble_size 5 \
--input_rgb_dir input/in-the-wild_example \
--output_dir output/in-the-wild_example_lcm
```
Expand All @@ -156,11 +154,11 @@ The default settings are optimized for the best result. However, the behavior of

- Trade-offs between the **accuracy** and **speed** (for both options, larger values result in better accuracy at the cost of slower inference.)
- `--ensemble_size`: Number of inference passes in the ensemble. For LCM `ensemble_size` is more important than `denoise_steps`. Default: ~~10~~ 5 (for LCM).
- `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. Default: ~~10~~ 4 (for LCM).
- `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. When unassigned (`None`), will read default setting from model config. Default: ~~10 4 (for LCM)~~ `None`.

- By default, the inference script resizes input images to the *processing resolution*, and then resizes the prediction back to the original resolution. This gives the best quality, as Stable Diffusion, from which Marigold is derived, performs best at 768x768 resolution.

- `--processing_res`: the processing resolution; set 0 to process the input resolution directly. Default: 768.
- `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`.
- `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False.
- `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`.

Expand Down
5 changes: 4 additions & 1 deletion infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Last modified: 2024-04-15
# Last modified: 2024-05-24
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -213,6 +213,9 @@ def check_directory(directory):
logging.debug("run without xformers")

pipe = pipe.to(device)
logging.info(
f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
)

# -------------------- Inference and saving --------------------
with torch.no_grad():
Expand Down
109 changes: 79 additions & 30 deletions marigold/marigold_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-05-24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,7 +20,7 @@


import logging
from typing import Dict, Union
from typing import Dict, Optional, Union

import numpy as np
import torch
Expand All @@ -33,13 +34,13 @@
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.ensemble import ensemble_depth
from .util.image_util import (
chw2hwc,
colorize_depth_maps,
Expand Down Expand Up @@ -85,6 +86,25 @@ class MarigoldPipeline(DiffusionPipeline):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
scale_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
the model config. When used together with the `shift_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
shift_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
the model config. When used together with the `scale_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
default_denoising_steps (`int`, *optional*):
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
quality with the given model. This value must be set in the model config. When the pipeline is called
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
default_processing_resolution (`int`, *optional*):
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
default value is used. This is required to ensure reasonable results with various model flavors trained
with varying optimal processing resolution values.
"""

rgb_latent_scale_factor = 0.18215
Expand All @@ -97,26 +117,40 @@ def __init__(
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scale_invariant: Optional[bool] = True,
shift_invariant: Optional[bool] = True,
default_denoising_steps: Optional[int] = None,
default_processing_resolution: Optional[int] = None,
):
super().__init__()

self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.register_to_config(
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
default_denoising_steps=default_denoising_steps,
default_processing_resolution=default_processing_resolution,
)

self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
self.default_denoising_steps = default_denoising_steps
self.default_processing_resolution = default_processing_resolution

self.empty_text_embed = None

@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
denoising_steps: int = 10,
ensemble_size: int = 10,
processing_res: int = 768,
denoising_steps: Optional[int] = None,
ensemble_size: int = 5,
processing_res: Optional[int] = None,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
Expand All @@ -131,18 +165,21 @@ def __call__(
Args:
input_image (`Image`):
Input RGB (or gray-scale) image.
processing_res (`int`, *optional*, defaults to `768`):
Maximum resolution of processing.
If set to 0: will not resize at all.
denoising_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
for Marigold-LCM models.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, processes at the original image resolution. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
value `None` resolves to the optimal value from the model config.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `processing_res` > 0.
resample_method: (`str`, *optional*, defaults to `bilinear`):
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `num_ensemble`.
If set to 0, the script will automatically decide the proper batch size.
Expand All @@ -152,6 +189,10 @@ def __call__(
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
Colormap used to colorize the depth map.
scale_invariant (`str`, *optional*, defaults to `True`):
Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
shift_invariant (`str`, *optional*, defaults to `True`):
Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
Returns:
Expand All @@ -161,6 +202,12 @@ def __call__(
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1`
"""
# Model-specific optimal default values leading to fast and reasonable results.
if denoising_steps is None:
denoising_steps = self.default_denoising_steps
if processing_res is None:
processing_res = self.default_processing_resolution

assert processing_res >= 0
assert ensemble_size >= 1

Expand All @@ -175,14 +222,15 @@ def __call__(
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
rgb = pil_to_tensor(input_image)
rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(input_image, torch.Tensor):
rgb = input_image.squeeze()
rgb = input_image
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = rgb.shape
assert (
3 == rgb.dim() and 3 == input_size[0]
), f"Wrong input shape {input_size}, expected [rgb, H, W]"
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

# Resize image
if processing_res > 0:
Expand All @@ -199,7 +247,7 @@ def __call__(

# ----------------- Predicting depth -----------------
# Batch repeated input image
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
Expand Down Expand Up @@ -231,35 +279,36 @@ def __call__(
generator=generator,
)
depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
depth_preds = torch.concat(depth_pred_ls, dim=0)
torch.cuda.empty_cache() # clear vram cache for ensembling

# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = ensemble_depths(
depth_preds, **(ensemble_kwargs or {})
depth_pred, pred_uncert = ensemble_depth(
depth_preds,
scale_invariant=self.scale_invariant,
shift_invariant=self.shift_invariant,
max_res=50,
**(ensemble_kwargs or {}),
)
else:
depth_pred = depth_preds
pred_uncert = None

# ----------------- Post processing -----------------
# Scale prediction to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
depth_pred = (depth_pred - min_d) / (max_d - min_d)

# Resize back to original resolution
if match_input_res:
depth_pred = resize(
depth_pred.unsqueeze(0),
input_size[1:],
depth_pred,
input_size[-2:],
interpolation=resample_method,
antialias=True,
).squeeze()
)

# Convert to numpy
depth_pred = depth_pred.squeeze()
depth_pred = depth_pred.cpu().numpy()
if pred_uncert is not None:
pred_uncert = pred_uncert.squeeze().cpu().numpy()

# Clip output range
depth_pred = depth_pred.clip(0, 1)
Expand Down
Loading

0 comments on commit ed281c7

Please sign in to comment.