Skip to content

Commit

Permalink
feat(api): add stage for local standard deviation denoising for XL
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 28, 2023
1 parent 7d56689 commit f28fdda
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
logger = getLogger(__name__)


class BlendDenoiseStage(BaseStage):
class BlendDenoiseFastNLMeansStage(BaseStage):
max_tile = SizeChart.max

def run(
Expand Down
113 changes: 113 additions & 0 deletions api/onnx_web/chain/blend_denoise_localstd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from logging import getLogger
from typing import Optional

import numpy as np
from PIL import Image

from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)


class BlendDenoiseLocalStdStage(BaseStage):
max_tile = SizeChart.max

def run(
self,
_worker: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
*,
strength: int = 3,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> StageResult:
logger.info("denoising source images")

results = []
for source in sources.as_numpy():
results.append(remove_noise(source))

return StageResult(arrays=results)


def downscale_image(image):
result_image = np.zeros((image.shape[0] // 2, image.shape[1] // 2), dtype=np.uint8)

for i in range(0, image.shape[0] - 1, 2):
for j in range(0, image.shape[1] - 1, 2):
# Average the four neighboring pixels
pixel_average = np.mean(image[i : i + 2, j : j + 2], axis=(0, 1))
result_image[i // 2, j // 2] = pixel_average.astype(np.uint8)

return result_image


def replace_noise(region, threshold):
# Identify stray pixels (brightness significantly deviates from surrounding pixels)
central_pixel = region[1, 1]

region_median = np.median(region)
region_deviation = np.std(region)
diff = np.abs(central_pixel - region_median)

# If the whole region is fairly consistent but the central pixel deviates significantly,
if diff > region_deviation and diff > threshold:
surrounding_pixels = region[region != central_pixel]
surrounding_median = np.median(surrounding_pixels)
# replace it with the median of surrounding pixels
region[1, 1] = surrounding_median
return True

return False


def remove_noise(image, region_size=(6, 6), threshold=10):
# Assuming 'image' is a 3D numpy array representing the RGB image

# Create a copy of the original image to store the result
result_image = np.copy(image)
# result_mask = np.ones_like(image) * 255

# Iterate over regions in each channel
i_inc = region_size[0] // 2
j_inc = region_size[1] // 2

for i in range(i_inc, image.shape[0] - i_inc, 1):
for j in range(j_inc, image.shape[1] - j_inc, 1):
i_min = i - (region_size[0] // 2)
i_max = i + (region_size[0] // 2)
j_min = j - (region_size[1] // 2)
j_max = j + (region_size[1] // 2)

# Extract region from each channel
region_red = downscale_image(image[i_min:i_max, j_min:j_max, 0])
region_green = downscale_image(image[i_min:i_max, j_min:j_max, 1])
region_blue = downscale_image(image[i_min:i_max, j_min:j_max, 2])

replaced = any(
[
replace_noise(region_red, threshold),
replace_noise(region_green, threshold),
]
)

# Apply the noise removal function to each channel
if replaced:
# Assign the processed region back to the result image
result_image[i - 1 : i + 1, j - 1 : j + 1, 0] = region_red[1, 1]
result_image[i - 1 : i + 1, j - 1 : j + 1, 1] = region_green[1, 1]
result_image[i - 1 : i + 1, j - 1 : j + 1, 2] = region_blue[1, 1]

# result_mask[i-1:i+1, j-1:j+1, 0] = 0
# result_mask[i-1:i+1, j-1:j+1, 1] = 0
# result_mask[i-1:i+1, j-1:j+1, 2] = 0

return result_image # , result_mask)
7 changes: 5 additions & 2 deletions api/onnx_web/chain/stages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from logging import getLogger

from .base import BaseStage
from .blend_denoise import BlendDenoiseStage
from .blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage
from .blend_denoise_localstd import BlendDenoiseLocalStdStage
from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
Expand All @@ -27,7 +28,9 @@
logger = getLogger(__name__)

CHAIN_STAGES = {
"blend-denoise": BlendDenoiseStage,
"blend-denoise": BlendDenoiseFastNLMeansStage,
"blend-denoise-fastnlmeans": BlendDenoiseFastNLMeansStage,
"blend-denoise-localstd": BlendDenoiseLocalStdStage,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PIL import Image, ImageOps

from ..chain import (
BlendDenoiseStage,
BlendDenoiseLocalStdStage,
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
Expand Down Expand Up @@ -78,9 +78,9 @@ def run_txt2img_pipeline(

# apply upscaling and correction, before highres
highres_size = get_highres_tile(server, params, highres, tile_size)
if params.is_panorama():
if params.is_xl():
chain.stage(
BlendDenoiseStage(),
BlendDenoiseLocalStdStage(),
StageParams(tile_size=highres_size),
)

Expand Down

0 comments on commit f28fdda

Please sign in to comment.