Skip to content

Commit

Permalink
fix(api): convert Real ESRGAN v3 using same arch as runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 18, 2023
1 parent 2c9d96d commit 338fc23
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
10 changes: 5 additions & 5 deletions api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
last_pipeline_instance = None
last_pipeline_params = (None, None)

x4_v3_tag = "real-esrgan-x4-v3"
TAG_X4_V3 = "real-esrgan-x4-v3"


def load_resrgan(
Expand All @@ -37,7 +37,7 @@ def load_resrgan(
if not path.isfile(model_path):
raise Exception("Real ESRGAN model not found at %s" % model_path)

elif params.format == "onnx":
if params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet(
server,
Expand All @@ -46,7 +46,7 @@ def load_resrgan(
sess_options=device.sess_options(),
)
elif params.format == "pth":
if x4_v3_tag in model_file:
if TAG_X4_V3 in model_file:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
Expand All @@ -69,8 +69,8 @@ def load_resrgan(
raise Exception("unknown platform %s" % params.format)

dni_weight = None
if params.upscale_model == x4_v3_tag and params.denoise != 1:
wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag))
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise]

Expand Down
31 changes: 23 additions & 8 deletions api/onnx_web/convert/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from torch.onnx import export

from .utils import ConversionContext, ModelDict

logger = getLogger(__name__)

TAG_X4_V3 = "real-esrgan-x4-v3"


@torch.no_grad()
def convert_upscale_resrgan(
Expand All @@ -28,14 +31,26 @@ def convert_upscale_resrgan(
return

logger.info("loading and training model")
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)

if TAG_X4_V3 in name:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=scale,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)

torch_model = torch.load(source, map_location=ctx.map_location)
if "params_ema" in torch_model:
Expand Down

0 comments on commit 338fc23

Please sign in to comment.