Skip to content

Commit

Permalink
fix(api): better handling of alpha channels
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Nov 26, 2023
1 parent c134edf commit 1c3b2f8
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 18 deletions.
9 changes: 5 additions & 4 deletions api/onnx_web/chain/blend_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from PIL import Image

from ..params import ImageParams, SizeChart, StageParams
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
Expand Down Expand Up @@ -36,9 +36,10 @@ def run(
logger.info("combining source images using grid layout")

images = sources.as_image()
size = images[0].size
ref_image = images[0]
size = Size(*ref_image.size)

output = Image.new("RGB", (size[0] * width, size[1] * height))
output = Image.new(ref_image.mode, (size.width * width, size.height * height))

# TODO: labels
if order is None:
Expand All @@ -49,7 +50,7 @@ def run(
y = i // width

n = order[i]
output.paste(images[n], (x * size[0], y * size[1]))
output.paste(images[n], (x * size.width, y * size.height))

return StageResult(images=[*images, output])

Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def run(
) -> StageResult:
logger.info("blending image using mask")

mult_mask = Image.new("RGBA", stage_mask.size, color="black")
# TODO: does this need an alpha channel?
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L")

Expand Down
14 changes: 13 additions & 1 deletion api/onnx_web/chain/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,16 @@ def as_image(self) -> List[Image.Image]:
if self.images is not None:
return self.images

return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]


def shape_mode(arr: np.ndarray) -> str:
if len(arr.shape) != 3:
raise ValueError("unknown array format")

if arr.shape[-1] == 3:
return "RGB"
elif arr.shape[-1] == 4:
return "RGBA"

raise ValueError("unknown image format")
6 changes: 4 additions & 2 deletions api/onnx_web/chain/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def blend_tiles(
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
)

scaled_size = (height * scale, width * scale, 3)
channels = max([4 if tile_image.mode == "RGBA" else 3 for _left, _top, tile_image in tiles])
scaled_size = (height * scale, width * scale, channels)

count = np.zeros(scaled_size)
value = np.zeros(scaled_size)

Expand Down Expand Up @@ -221,7 +223,7 @@ def blend_tiles(
margin_left : equalized.shape[1] + margin_right,
np.newaxis,
],
3,
channels,
axis=2,
)

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/image/mask_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def mask_filter_none(
) -> Image.Image:
width, height = dims

noise = Image.new("RGB", (width, height), fill)
noise = Image.new(mask.mode, (width, height), fill)
noise.paste(mask, origin)

return noise
Expand Down
19 changes: 11 additions & 8 deletions api/onnx_web/image/noise_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ def noise_source_fill_edge(
"""
width, height = dims

noise = Image.new("RGB", (width, height), fill)
noise = Image.new(source.mode, (width, height), fill)
noise.paste(source, origin)

return noise


def noise_source_fill_mask(
_source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
) -> Image.Image:
"""
Fill the whole canvas, no source or noise.
"""
width, height = dims

noise = Image.new("RGB", (width, height), fill)
noise = Image.new(source.mode, (width, height), fill)

return noise

Expand All @@ -52,7 +52,7 @@ def noise_source_gaussian(


def noise_source_uniform(
_source: Image.Image, dims: Point, _origin: Point, **kw
source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
Expand All @@ -61,18 +61,19 @@ def noise_source_uniform(
noise_g = random.uniform(0, 256, size=size)
noise_b = random.uniform(0, 256, size=size)

# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))

for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

return noise
return noise.convert(source.mode)


def noise_source_normal(
_source: Image.Image, dims: Point, _origin: Point, **kw
source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
Expand All @@ -81,14 +82,15 @@ def noise_source_normal(
noise_g = random.normal(128, 32, size=size)
noise_b = random.normal(128, 32, size=size)

# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))

for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))

return noise
return noise.convert(source.mode)


def noise_source_histogram(
Expand All @@ -112,11 +114,12 @@ def noise_source_histogram(
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
)

# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height))

for x in range(width):
for y in range(height):
i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))

return noise
return noise.convert(source.mode)
2 changes: 1 addition & 1 deletion api/onnx_web/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def expand_image(
size = tuple(size)
origin = (expand.left, expand.top)

full_source = Image.new("RGB", size, fill)
full_source = Image.new(source.mode, size, fill)
full_source.paste(source, origin)

# new mask pixels need to be filled with white so they will be replaced
Expand Down

0 comments on commit 1c3b2f8

Please sign in to comment.