Skip to content

Commit

Permalink
feat(api): support more RealESRGAN-based models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 27, 2023
1 parent f5e7b3b commit 9588643
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion api/onnx_web/convert/upscaling/resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,39 @@

TAG_X4_V3 = "real-esrgan-x4-v3"

SPECIAL_KEYS = {
"model.0.bias": "conv_first.bias",
"model.0.weight": "conv_first.weight",
"model.1.sub.23.bias": "conv_body.bias",
"model.1.sub.23.weight": "conv_body.weight",
"model.3.bias": "conv_up1.bias",
"model.3.weight": "conv_up1.weight",
"model.6.bias": "conv_up2.bias",
"model.6.weight": "conv_up2.weight",
"model.8.bias": "conv_hr.bias",
"model.8.weight": "conv_hr.weight",
"model.10.bias": "conv_last.bias",
"model.10.weight": "conv_last.weight",
}

SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"


def fix_resrgan_keys(model):
original_keys = list(model.keys())
for key in original_keys:
if key in SPECIAL_KEYS:
new_key = SPECIAL_KEYS[key]
else:
# convert RDBN keys
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"

model[new_key] = model[key]
del model[key]

return model


@torch.no_grad()
def convert_upscale_resrgan(
Expand Down Expand Up @@ -54,8 +87,11 @@ def convert_upscale_resrgan(
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
else:
elif "params" in torch_model:
model.load_state_dict(torch_model["params"], strict=False)
else:
# keys need fixed up to match
model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)

model.to(conversion.training_device).train(False)
model.eval()
Expand Down

0 comments on commit 9588643

Please sign in to comment.