Skip to content

Commit

Permalink
support alpha, gray
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Aug 22, 2021
1 parent c1685f6 commit 3c06885
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0.2
0.2.0.3
42 changes: 31 additions & 11 deletions facexlib/utils/face_restoration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,16 @@ def read_image(self, img):
"""img can be image path or cv2 loaded image."""
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
if isinstance(img, str):
self.input_img = cv2.imread(img)
else:
self.input_img = img
img = cv2.imread(img)

if np.max(img) > 256: # 16-bit image
img = img / 65535 * 255
if len(img.shape) == 2: # gray image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # RGBA image with alpha channel
img = img[:, :, 0:3]

self.input_img = img

def get_face_landmarks_5(self, only_keep_largest=False, only_center_face=False, resize=None, blur_ratio=0.01):
if resize is None:
Expand All @@ -114,6 +121,7 @@ def get_face_landmarks_5(self, only_keep_largest=False, only_center_face=False,
scale = min(h, w) / resize
h, w = int(h / scale), int(w / scale)
input_img = cv2.resize(self.input_img, (w, h), cv2.INTER_LANCZOS4)

with torch.no_grad():
bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
for bbox in bboxes:
Expand Down Expand Up @@ -252,7 +260,7 @@ def add_restored_face(self, face):

def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
h, w, _ = self.input_img.shape
h_up, w_up = int(h * self.upscale_factor), (w * self.upscale_factor)
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

if upsample_img is None:
# simply resize the background
Expand All @@ -264,22 +272,34 @@ def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
mask = np.ones((*self.face_size, 3), dtype=np.float32)
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
np.uint8))
inv_restored_remove_border = inv_mask_erosion * inv_restored
total_face_area = np.sum(inv_mask_erosion) // 3
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
upsample_img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * upsample_img
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]

if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
alpha = upsample_img[:, :, 3:]
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
else:
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

upsample_img = upsample_img.astype(np.uint8)
if np.max(upsample_img) > 256: # 16-bit image
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
if save_path is not None:
path = os.path.splitext(save_path)[0]
save_path = f'{path}.{self.save_ext}'
Expand Down

0 comments on commit 3c06885

Please sign in to comment.