Skip to content

Commit

Permalink
feat(api): add params for more SwinIR models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 11, 2023
1 parent 23fb752 commit 2a7621c
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 26 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
noise_source_uniform,
valid_image,
)
from .onnx import OnnxNet, OnnxTensor
from .onnx import OnnxRRDBNet, OnnxTensor
from .params import (
Border,
DeviceParams,
Expand Down
14 changes: 7 additions & 7 deletions api/onnx_web/chain/upscale_bsrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def load_bsrgan(
cache_pipe = server.cache.get("bsrgan", cache_key)

if cache_pipe is not None:
logger.info("reusing existing BSRGAN pipeline")
logger.debug("reusing existing BSRGAN pipeline")
return cache_pipe

logger.debug("loading BSRGAN model from %s", model_path)
logger.info("loading BSRGAN model from %s", model_path)

pipe = OnnxModel(
server,
Expand Down Expand Up @@ -62,7 +62,7 @@ def upscale_bsrgan(
logger.warn("no upscaling model given, skipping")
return source

logger.info("correcting faces with BSRGAN model: %s", upscale.upscale_model)
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = load_bsrgan(server, stage, upscale, device)

Expand All @@ -73,13 +73,13 @@ def upscale_bsrgan(
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("BSRGAN input shape: %s", image.shape)
logger.trace("BSRGAN input shape: %s", image.shape)

scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.info("BSRGAN output shape: %s", dest.shape)
logger.trace("BSRGAN output shape: %s", dest.shape)

for x in range(tile_x):
for y in range(tile_y):
Expand All @@ -90,7 +90,7 @@ def upscale_bsrgan(
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
logger.debug(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
Expand All @@ -114,5 +114,5 @@ def upscale_bsrgan(
dest = (dest * 255.0).round().astype(np.uint8)

output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
logger.debug("output image size: %s x %s", output.width, output.height)
return output
6 changes: 3 additions & 3 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
from PIL import Image

from ..onnx import OnnxNet
from ..models.rrdb import RRDBNet
from ..onnx import OnnxRRDBNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
Expand All @@ -20,7 +21,6 @@ def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

Expand All @@ -38,7 +38,7 @@ def load_resrgan(

if params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),
Expand Down
2 changes: 2 additions & 0 deletions api/onnx_web/chain/upscale_swinir.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ def upscale_swinir(
device = job.get_device()
swinir = load_swinir(server, stage, upscale, device)

# TODO: add support for other tile sizes (currently fixed at 64x64)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]

# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
Expand Down
6 changes: 6 additions & 0 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth",
"scale": 4,
},
{
"model": "bsrgan",
"name": "upscaling-bsrgan-x2",
"source": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGANx2.pth",
"scale": 2,
},
],
# download only
"sources": [
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/convert/correction/gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from os import path

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from torch.onnx import export

from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict

logger = getLogger(__name__)
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/convert/upscaling/bsrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def convert_upscaling_bsrgan(
return

logger.info("loading and training model")
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
model = RRDBNet(
in_nc=3,
out_nc=3,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/convert/upscaling/resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch.onnx import export

from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict

logger = getLogger(__name__)
Expand All @@ -17,7 +18,6 @@ def convert_upscale_resrgan(
model: ModelDict,
source: str,
):
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

name = model.get("name")
Expand Down
51 changes: 41 additions & 10 deletions api/onnx_web/convert/upscaling/swinir.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,50 @@ def convert_upscaling_swinir(
return

logger.info("loading and training model")
img_size = (64, 64) # TODO: does this need to be a fixed value?
# values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
params = {
"depths": [6, 6, 6, 6, 6, 6],
"embed_dim": 180,
"img_range": 1.0,
"img_size": (64, 64),
"in_chans": 3,
"num_heads": [6, 6, 6, 6, 6, 6],
"resi_connection": "1conv",
"upsampler": "pixelshuffle",
"window_size": 8,
}

if "lightweight" in name:
logger.debug("using SwinIR lightweight params")
params["depths"] = [6, 6, 6, 6]
params["embed_dim"] = 60
params["num_heads"] = [6, 6, 6, 6]
params["upsampler"] = "pixelshuffledirect"
elif "real" in name:
# TODO: add params for the large real-world SR model (SwinIR-L), which does not use the default depths/embed_dim/num_heads
logger.debug("using SwinIR real params")
params["upsampler"] = "nearest+conv"
elif "gray_dn" in name:
params["img_size"] = (128, 128)
params["in_chans"] = 1
params["upsampler"] = ""
elif "color_dn" in name:
params["img_size"] = (128, 128)
params["upsampler"] = ""
elif "gray_jpeg" in name:
params["img_size"] = (126, 126)
params["in_chans"] = 1
params["upsampler"] = ""
params["window_size"] = 7
elif "color_jpeg" in name:
params["img_size"] = (126, 126)
params["upsampler"] = ""
params["window_size"] = 7

model = SwinIR(
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
img_range=1.0,
img_size=img_size,
in_chans=3,
mlp_ratio=2,
num_heads=[6, 6, 6, 6, 6, 6],
resi_connection="1conv",
upscale=scale,
upsampler="pixelshuffle",
window_size=8,
**params,
)

torch_model = torch.load(source, map_location=conversion.map_location)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .onnx_net import OnnxTensor, OnnxNet
from .onnx_net import OnnxTensor, OnnxRRDBNet
4 changes: 2 additions & 2 deletions api/onnx_web/onnx/onnx_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def size(self):
return np.shape(self.source)


class OnnxNet:
class OnnxRRDBNet:
"""
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
Provides the RRDBNet interface using an ONNX session.
"""

def __init__(
Expand Down

0 comments on commit 2a7621c

Please sign in to comment.