Skip to content

Commit

Permalink
feat(api): switch to codeformer lib that works with torch 2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 19, 2023
1 parent 6a6a3f0 commit 7ed30ee
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
91 changes: 86 additions & 5 deletions api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from logging import getLogger
from os import path
from typing import Optional

import torch
from PIL import Image
from torchvision.transforms.functional import normalize

from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
Expand All @@ -16,7 +19,7 @@ class CorrectCodeformerStage(BaseStage):
def run(
self,
worker: WorkerContext,
_server: ServerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: StageResult,
Expand All @@ -27,10 +30,88 @@ def run(
) -> StageResult:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
from codeformer.basicsr.utils import img2tensor, tensor2img
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper

upscale = upscale.with_args(**kwargs)

device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return StageResult(images=[pipe(source) for source in sources.as_image()])

net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device.torch_str())

ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()

det_model = "retinaface_resnet50"
face_helper = FaceRestoreHelper(
upscale.face_outscale,
face_size=512,
crop_ratio=(1, 1),
det_model=det_model,
save_ext="png",
use_parse=True,
device=device.torch_str(),
)

results = []
for img in sources.as_image():
# clean all the intermediate results to process the next image
face_helper.clean_all()
face_helper.read_image(img)

# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
)
logger.debug("detected %s faces", num_det_faces)

# align and warp each face
face_helper.align_warp_face()

# face restoration for each cropped face
for cropped_face in face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(
cropped_face / 255.0, bgr2rgb=True, float32=True
)
normalize(
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device.torch_str())

try:
with torch.no_grad():
output = net(
cropped_face_t, w=upscale.face_strength, adain=True
)[0]
restored_face = tensor2img(
output, rgb2bgr=True, min_max=(-1, 1)
)

del output
except Exception:
logger.exception("inference failed for CodeFormer")
restored_face = tensor2img(
cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
)

restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face, cropped_face)

# paste_back
face_helper.get_inverse_affine(None)

# paste each restored face to the input image
results.append(
face_helper.paste_faces_to_input_image(upsample_img=img, draw_box=False)
)

return StageResult.from_images(results)
1 change: 1 addition & 0 deletions api/onnx_web/server/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def apply_patch_basicsr(server: ServerContext):
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
try:
import codeformer.basicsr.utils # download_util
import codeformer.facelib.utils.misc

codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
Expand Down
3 changes: 2 additions & 1 deletion api/requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ onnx==1.15.0
optimum==1.16.0
safetensors==0.4.1
timm==0.9.12
torchsde==0.2.6
transformers==4.36.1

#### Upscaling and face correction
basicsr==1.4.2
codeformer-perceptor==0.1.2
# codeformer-perceptor==0.1.2
facexlib==0.2.5
gfpgan==1.3.8
realesrgan==0.3.0
Expand Down

0 comments on commit 7ed30ee

Please sign in to comment.