feat(api): support both ESRGAN variants
ssube committed Dec 30, 2023
1 parent 6cee411 commit 0ddc162
Showing 2 changed files with 99 additions and 25 deletions.
33 changes: 22 additions & 11 deletions api/onnx_web/convert/upscaling/resrgan.py
@@ -5,7 +5,7 @@
import torch
from torch.onnx import export

from ...models.rrdb import RRDBNet
from ...models.rrdb import RRDBNetFixed, RRDBNetRescale
from ...models.srvgg import SRVGGNetCompact
from ..utils import ConversionContext, ModelDict

@@ -79,6 +79,14 @@ def convert_upscale_resrgan(
logger.info("ONNX model already exists, skipping")
return

torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
state_dict = torch_model["params_ema"]
elif "params" in torch_model:
state_dict = torch_model["params"]
else:
state_dict = torch_model

if TAG_X4_V3 in name:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
@@ -89,25 +97,28 @@
upscale=scale,
act_type="prelu",
)
else:
model = RRDBNet(
elif any(["RDB" in key for key in state_dict.keys()]):
# keys need fixed up to match. capitalized RDB is the best indicator.
state_dict = fix_resrgan_keys(state_dict)
model = RRDBNetFixed(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)

torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
elif "params" in torch_model:
model.load_state_dict(torch_model["params"], strict=False)
else:
# keys need fixed up to match
model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
else:
model = RRDBNetRescale(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)

model.load_state_dict(state_dict, strict=True)
model.to(conversion.training_device).train(False)
model.eval()

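For reference, the state-dict handling added above can be exercised on its own. A minimal sketch, assuming a local checkpoint path (model.pth is a placeholder) and a checkpoint in one of the three layouts the converter handles (params_ema, params, or a bare state dict):

import torch

# Placeholder path; any Real-ESRGAN style .pth checkpoint would do.
checkpoint = torch.load("model.pth", map_location="cpu")

# Unwrap the state dict the same way the converter does.
if "params_ema" in checkpoint:
    state_dict = checkpoint["params_ema"]
elif "params" in checkpoint:
    state_dict = checkpoint["params"]
else:
    state_dict = checkpoint

# Capitalized "RDB" in the keys marks the older third-party layout, which maps
# onto RRDBNetFixed after fix_resrgan_keys; otherwise the official layout maps
# onto RRDBNetRescale.
if any("RDB" in key for key in state_dict.keys()):
    print("third-party layout -> RRDBNetFixed (after fix_resrgan_keys)")
else:
    print("official layout -> RRDBNetRescale")
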
91 changes: 77 additions & 14 deletions api/onnx_web/models/rrdb.py
@@ -77,18 +77,24 @@ class RRDB(nn.Module):

def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
self.rdb1 = ResidualDenseBlock_5C(nf, gc)
self.rdb2 = ResidualDenseBlock_5C(nf, gc)
self.rdb3 = ResidualDenseBlock_5C(nf, gc)

def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
return out * 0.2 + x


class RRDBNet(nn.Module):
class RRDBNetRescale(nn.Module):
"""
RRDBNet with variable input channels based on scale.
This is the format expected by the official models.
In this architecture, the modules stay the same but input channels change.
"""

def __init__(
self,
num_in_ch=3,
@@ -98,7 +104,7 @@ def __init__(
num_grow_ch=32,
scale=4,
):
super(RRDBNet, self).__init__()
super(RRDBNetRescale, self).__init__()
self.scale = scale

if scale == 2:
@@ -107,7 +113,7 @@ def __init__(
num_in_ch = num_in_ch * 16

logger.trace(
"RRDBNet params: %s",
"RRDBNetRescale params: %s",
[num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
)

@@ -116,11 +122,8 @@ def __init__(
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

# upsampling
if self.scale > 1:
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

if self.scale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
@@ -152,3 +155,63 @@ def forward(self, x):
out = self.conv_last(self.lrelu(self.conv_hr(feat)))

return out


class RRDBNetFixed(nn.Module):
"""
RRDBNet with fixed input channels regardless of scale.
This is the format expected by many third-party models.
In this architecture, the modules come and go based on scale, but the input channels stay the same.
"""

def __init__(
self,
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
):
super(RRDBNetFixed, self).__init__()
self.scale = scale

logger.trace(
"RRDBNetFixed params: %s",
[num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
)

self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
self.body = make_layer(RRDB, num_block, nf=num_feat, gc=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

# upsampling
if self.scale > 1:
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

if self.scale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)

self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

def forward(self, x):
feat = self.conv_first(x)
trunk = self.conv_body(self.body(feat))
feat = feat + trunk

if self.scale > 1:
feat = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
)

if self.scale == 4:
feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
)

out = self.conv_last(self.lrelu(self.conv_hr(feat)))

return out
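
The two classes differ mainly in how they react to the scale argument. A rough illustration (a sketch, not part of the commit; it assumes the api directory is on the import path and that the package's logging setup, which provides the trace level used by these modules, has been loaded):

from onnx_web.models.rrdb import RRDBNetFixed, RRDBNetRescale

# Instantiate both variants at the same scale.
fixed = RRDBNetFixed(scale=2)
rescale = RRDBNetRescale(scale=2)

# Fixed: the first conv keeps 3 input channels, and conv_up2 only exists at scale 4.
print(fixed.conv_first.in_channels)   # 3
print(hasattr(fixed, "conv_up2"))     # False at scale=2

# Rescale: the upsampling convs are always present, and the input channels are
# adjusted for scales below 4 (the exact multipliers are in lines hidden from this diff).
print(rescale.conv_first.in_channels)
print(hasattr(rescale, "conv_up2"))   # True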
