Skip to content

Commit

Permalink
add parsenet to face_restoration_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed May 26, 2022
1 parent e5768d1 commit d0ea779
Showing 1 changed file with 56 additions and 17 deletions.
73 changes: 56 additions & 17 deletions facexlib/utils/face_restoration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from facexlib.detection import init_detection_model
from facexlib.utils.misc import imwrite
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite


def get_largest_face(det_faces, h, w):
Expand Down Expand Up @@ -54,6 +56,7 @@ def __init__(self,
save_ext='png',
template_3points=False,
pad_blur=False,
use_parse=False,
device=None):
self.template_3points = template_3points # improve robustness
self.upscale_factor = upscale_factor
Expand Down Expand Up @@ -94,6 +97,10 @@ def __init__(self,
# init face detection model
self.face_det = init_detection_model(det_model, half=False, device=device)

# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(model_name='parsenet', device=device)

def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor

Expand Down Expand Up @@ -290,22 +297,54 @@ def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
extra_offset = 0
inverse_affine[:, 2] += extra_offset
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]

if self.use_parse:
# inference
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).cuda()
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()

mask = np.zeros(out.shape)
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
for idx, color in enumerate(MASK_COLORMAP):
mask[out == idx] = color
# blur the mask
mask = cv2.GaussianBlur(mask, (101, 101), 11)
mask = cv2.GaussianBlur(mask, (101, 101), 11)
# remove the black borders
thres = 10
mask[:thres, :] = 0
mask[-thres:, :] = 0
mask[:, :thres] = 0
mask[:, -thres:] = 0
mask = mask / 255.

mask = cv2.resize(mask, restored_face.shape[:2])
mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
inv_soft_mask = mask[:, :, None]
pasted_face = inv_restored

else: # use square parse maps
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]

if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
alpha = upsample_img[:, :, 3:]
Expand Down

0 comments on commit d0ea779

Please sign in to comment.