In [1]:
import argparse
import glob
import os
from copy import copy
from pprint import pprint

import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import imwrite
from gfpgan import GFPGANer
from gfpgan.archs.stylegan2_clean_arch import ModulatedConv2d
from realesrgan import RealESRGANer
from tqdm import tqdm
from utils import *

from concrete.ml.torch.hybrid_model import HybridFHEModel

from utils import *

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [2]:
class Args:
    """Plain container for the options used throughout this notebook.

    Every option is a keyword-only argument whose default reproduces the
    values previously hard-coded here, so ``Args()`` behaves as before while
    ``Args(upscale=2)`` overrides a single setting without editing the class.
    """

    def __init__(
        self,
        *,
        input="GFPGAN/inputs/whole_imgs",
        output="results",
        version="1.4",
        upscale=5,
        bg_upsampler="realesrgan",
        bg_tile=400,
        suffix=None,
        only_center_face=False,
        aligned=False,
        ext="auto",
        weight=0.5,
    ):
        # Folder whose entries are globbed as input images.
        self.input = input
        # Folder that receives every saved result.
        self.output = output
        # GFPGAN version string compared against "1.3" / "1.4".
        self.version = version
        # Forwarded to GFPGANer as `upscale`.
        self.upscale = upscale
        # Background upsampler name; compared against "realesrgan".
        self.bg_upsampler = bg_upsampler
        # Forwarded to RealESRGANer as `tile`.
        self.bg_tile = bg_tile
        # Optional tag appended to saved file names (None adds nothing).
        self.suffix = suffix
        # Forwarded to `restorer.enhance` as `only_center_face`.
        self.only_center_face = only_center_face
        # Forwarded to `restorer.enhance` as `has_aligned`.
        self.aligned = aligned
        # Extension of the restored image; "auto" reuses the input file's extension.
        self.ext = ext
        # Forwarded to `restorer.enhance` as `weight`.
        self.weight = weight


args = Args()

In [3]:
# Collect every entry found directly under the input folder, in sorted order.
img_list = sorted(glob.glob(f"{args.input}/*"))

# Validate the inputs before touching the filesystem, and say which folder was
# empty so the failure is actionable.
assert len(img_list) >= 1, f"No input files found in '{args.input}'."

# Create the output folder only once there is something to process.
os.makedirs(args.output, exist_ok=True)

In [4]:
# Toggle for building the Real-ESRGAN background upsampler below.
use_background_improvement = True

# Always bind `bg_upsampler`: it is passed to GFPGANer later on, and leaving it
# undefined when the upsampler is disabled would raise a NameError there.
bg_upsampler = None

if args.bg_upsampler == "realesrgan" and use_background_improvement:
    # Half precision is only enabled when CUDA is available;
    # it needs to be False in CPU mode.
    half = torch.cuda.is_available()

    # No linear modules in this model
    model = RRDBNet(
        num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
    )
    bg_upsampler = RealESRGANer(
        scale=2,  # Do not change this value
        model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
        model=model,
        tile=args.bg_tile,
        tile_pad=10,
        pre_pad=0,
        half=half,
    )

In [5]:
# Pick the architecture settings and weight location for the requested version.
if args.version == "1.3":
    arch = "clean"
    channel_multiplier = 2
    model_name = "GFPGANv1.3"
    url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    local_model_path = "GFPGANv1.3.pth"
elif args.version == "1.4":
    arch = "clean"
    channel_multiplier = 2
    model_name = "GFPGANv1.4"
    url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
    local_model_path = "GFPGANv1.4.pth"
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Wrong model version {args.version}.")

# Determine the model path: use the first local copy of the weights that
# exists, otherwise fall back to the release URL (pre-trained models are then
# downloaded from it).
model_path = url
for weights_dir in ("experiments/pretrained_models", "gfpgan/weights"):
    candidate = os.path.join(weights_dir, model_name + ".pth")
    if os.path.isfile(candidate):
        model_path = candidate
        break

# Build the face restorer around the selected weights.
restorer = GFPGANer(
    model_path=model_path,
    upscale=args.upscale,
    arch=arch,
    channel_multiplier=channel_multiplier,
    bg_upsampler=bg_upsampler,
)

### Model 1

In [6]:
# Model 1: face cropping and extraction.
# This is a FaceRestoreHelper backed by retinaface_resnet50.
# It contains no linear layers.

face_helper_model = restorer.face_helper.face_det
face_helper_state_dict = face_helper_model.state_dict()

# `copy` is the shallow copy from the standard library.
# NOTE(review): the copy presumably still shares its sub-modules and weights
# with `restorer.gfpgan` — confirm that this sharing is intended.
gfpgan = copy(restorer.gfpgan)

### Model 2

In [7]:
# Inventory the layers of the GFPGAN model by type. The comprehension below
# unpacks each entry of the result as a (name, module) pair.
gfpgan_linear_modules = extract_specific_module(
    gfpgan, dtype_layer=torch.nn.Linear, verbose=False
)

gfpgan_conv2d = extract_specific_module(gfpgan, dtype_layer=torch.nn.Conv2d, verbose=False)

gfpgan_modulated_conv2d = extract_specific_module(
    gfpgan, dtype_layer=ModulatedConv2d, verbose=False
)

print(len(gfpgan_linear_modules), len(gfpgan_conv2d), len(gfpgan_modulated_conv2d))

# Names of the linear layers handed to the hybrid model. The "style_mlp" layers
# are dropped only when the model takes latent inputs; otherwise every linear
# layer is kept. The previous condition returned an empty list whenever
# `input_is_latent` was falsy.
# NOTE(review): presumably style_mlp is bypassed for latent inputs — confirm
# against the GFPGAN generator's forward pass.
skip_style_mlp = restorer.gfpgan.input_is_latent
gfpgan_linear = [
    name
    for name, _ in gfpgan_linear_modules
    if not (skip_style_mlp and "style_mlp" in name)
]

print(len(gfpgan_linear), len(gfpgan_conv2d), len(gfpgan_modulated_conv2d))

32 79 23
24 79 23


In [8]:

# Per-sample input shape (presumably channels, height, width — TODO confirm).
input_size = (3, 512, 512)

# Number of random samples in the batch used to compile the hybrid model.
compile_size = 2

# Draw the compilation batch from a dedicated, seeded generator so that it is
# identical on every "Restart & Run All" and the global RNG state is untouched.
calibration_rng = torch.Generator().manual_seed(0)
inputs = torch.randn((compile_size, *input_size), generator=calibration_rng)

# Wrap GFPGAN so that the linear layers listed in `gfpgan_linear` are the ones
# handled by the hybrid model.
hybrid_model = HybridFHEModel(
    gfpgan,
    gfpgan_linear,
    verbose=2,
)
# Compile hybrid model
hybrid_model.compile_model(
    inputs,
    n_bits=8,
)

In [9]:
# Sanity check: one forward pass of the hybrid model on the random batch
# `inputs` that was also used to compile it.
output = hybrid_model(inputs)

In [10]:
# Forward pass on a single random 512x512 sample (batch size 1).
# `input_size` and `inputs` are deliberately left untouched here: rebinding
# them to 256x256 values that are never passed to the model only left stale
# state behind for any cell that is re-run afterwards.
single_input = torch.randn((1, 3, 512, 512))
_ = hybrid_model(single_input)

In [11]:
# NOTE(review): this raises ZeroDivisionError — it appears to be a manual stop
# so that "Run All" does not continue into the inference section below.
# Confirm the intent and consider replacing it with an explicit flag.
1/0

ZeroDivisionError: division by zero

### Inference

In [None]:
# Optional tag appended to every saved file name; empty when no suffix is set.
suffix_tag = "" if args.suffix is None else f"_{args.suffix}"

for img_path in tqdm(img_list):
    # read image
    img_name = os.path.basename(img_path)
    print(f"Processing {img_name} ...")
    basename, ext = os.path.splitext(img_name)
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        # cv2.imread signals an unreadable path (sub-folder, non-image file)
        # by returning None rather than raising; skip it instead of crashing
        # on `input_img.shape` below.
        print(f"Skipping {img_name}: could not be read as an image.")
        continue
    print(f"{input_img.shape=}")

    # restore faces and background if necessary
    cropped_faces, restored_faces, restored_img = restorer.enhance(
        input_img,
        has_aligned=args.aligned,
        only_center_face=args.only_center_face,
        paste_back=True,
        weight=args.weight,
    )

    # save faces
    for idx, (cropped_face, restored_face) in tqdm(enumerate(zip(cropped_faces, restored_faces))):
        face_stem = f"{basename}_{idx:02d}"
        # save cropped face
        save_crop_path = os.path.join(args.output, "cropped_faces", f"{face_stem}.png")
        imwrite(cropped_face, save_crop_path)
        # save restored face
        save_face_name = f"{face_stem}{suffix_tag}.png"
        save_restore_path = os.path.join(args.output, "restored_faces", save_face_name)
        imwrite(restored_face, save_restore_path)
        # save comparison image: cropped and restored face joined along axis 1
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(args.output, "cmp", f"{face_stem}.png"))

    # save restored img
    if restored_img is not None:
        # "auto" reuses the input file's extension (without the leading dot).
        extension = ext[1:] if args.ext == "auto" else args.ext
        save_restore_path = os.path.join(
            args.output, "restored_imgs", f"{basename}{suffix_tag}.{extension}"
        )
        imwrite(restored_img, save_restore_path)
        # NOTE(review): stops after the first image that yields a restored
        # image — presumably to keep the demo short. Confirm; remove this
        # `break` to process the whole input folder.
        break

print(f"Results are in the [{args.output}] folder.")