In [1]:
import argparse
import glob
import os
from copy import copy
from pathlib import Path
from pprint import pprint

import concrete.fhe
import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, imwrite, tensor2img
from gfpgan.archs.stylegan2_clean_arch import ModulatedConv2d
from gfpgan_utils import *
from realesrgan import RealESRGANer
from torchvision.transforms.functional import normalize
from tqdm import tqdm

from concrete.ml.torch.hybrid_model import HybridFHEModel
from gfpgan import GFPGANer

# !pip install concrete-ml
# !pip install basicsr
# !pip install facexlib
# !pip install realesrgan
# !pip install lmdb
# !pip install opencv-python
# !pip install pyyaml
# !pip install tb-nightly
# !pip install yapf

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In [2]:
class Args:
    """Plain container for the fixed settings used by the cells of this notebook.

    Every option is stored as an instance attribute of the same name.
    """

    def __init__(self):
        # Option name -> value, applied in declaration order.
        settings = {
            "input": "GFPGAN/inputs/whole_imgs",
            "output": Path("results"),
            "version": "1.4",
            "upscale": 5,
            "bg_upsampler": "realesrgan",
            "bg_tile": 400,
            "suffix": None,
            "only_center_face": False,
            "aligned": False,
            "ext": "auto",
            "weight": 0.5,
        }
        for option, value in settings.items():
            setattr(self, option, value)


args = Args()

In [3]:
# Gather everything directly under the input folder in a deterministic (sorted)
# order, then keep only the first entry for this run.
found_paths = glob.glob(f"{args.input}/*")
found_paths.sort()
img_list = found_paths[:1]

# Make sure the results folder exists before anything is written to it.
os.makedirs(args.output, exist_ok=True)

# At least one input is required for the cells below.
assert len(img_list) > 0

img_list

['GFPGAN/inputs/whole_imgs/00.jpg']

In [4]:
class GFPGANFHEHybridWrapper(GFPGANer):
    """``GFPGANer`` variant whose restoration network is run through a ``HybridFHEModel``.

    Expected call order: ``use_hybrid_model(...)`` -> ``compile(...)`` -> call the
    instance (or ``forward``) on an image.
    """

    def __init__(self, **kwargs):
        """Build the underlying ``GFPGANer`` with a RealESRGAN x2 background upsampler.

        Keyword Args:
            upscale: Forwarded to ``GFPGANer`` as ``upscale`` (default 2).
            bg_tile: Forwarded to ``RealESRGANer`` as ``tile`` (default 0).
        """
        use_cuda = torch.cuda.is_available()

        super().__init__(
            model_path=get_gfpgan_path(),
            upscale=kwargs.get("upscale", 2),
            arch="clean",
            channel_multiplier=2,
            bg_upsampler=RealESRGANer(
                scale=2,
                model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                model=RRDBNet(
                    num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
                ),
                tile=kwargs.get("bg_tile", 0),
                tile_pad=10,
                pre_pad=0,
                # Half precision is only requested when CUDA is available.
                half=use_cuda,
            ),
            device=torch.device("cuda" if use_cuda else "cpu"),
        )

        # Model 1: the face helper (`self.face_helper`) is already set by the parent class.

        # Model 2: the restoration network and its (not yet created) hybrid wrapper.
        self.base_gfpgan = self.gfpgan
        self.hybrid_gfpgan = None

        # Model 3: the network used by the background upsampler.
        self.upsampler = self.bg_upsampler.model

    def _prepare_faces(self, img, has_aligned, only_center_face):
        """Prepare faces by either resizing (aligned) or detecting landmarks (unaligned).

        Args:
            img: Input image; resized to 512x512 when `has_aligned` is true, otherwise
                handed to the face helper.
            has_aligned: Whether `img` is already a single aligned face.
            only_center_face: Forwarded to the landmark detection step.
        """
        if has_aligned:
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # Get face landmarks
            self.face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, eye_dist_threshold=5
            )
            # Align and warp each face
            self.face_helper.align_warp_face()

    def _prepare_data(self, cropped_face):
        """Convert one cropped face into a normalized, batched tensor on `self.device`.

        The face is divided by 255, converted with `img2tensor` (BGR->RGB, float32),
        normalized in place with mean 0.5 / std 0.5 per channel, and given a leading
        batch dimension.
        """
        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
        return cropped_face_t

    def _finalize(self, img, has_aligned, paste_back):
        """Optionally upsample the background and paste the restored faces back.

        Returns:
            Tuple ``(cropped_faces, restored_faces, restored_img)``; ``restored_img`` is
            ``None`` when the input was aligned or `paste_back` is false.
        """
        if not has_aligned and paste_back:
            bg_img = None
            if self.bg_upsampler is not None:
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]

            self.face_helper.get_inverse_affine(None)
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img

        return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

    # Backward-compatible alias for the original (misspelled) method name.
    _finalyze = _finalize

    def forward(self, img, fhe_mode="disable", **kwargs):
        """Forward pass with FHE mode, handling face preparation and restoration.

        Args:
            img: Input image.
            fhe_mode: Mode string forwarded to the hybrid model.

        Keyword Args:
            has_aligned: Input is already an aligned face (default False).
            only_center_face: Restrict detection to the center face (default False).
            paste_back: Paste restored faces into the input image (default True).
            weight: Forwarded to the hybrid model (default 0.5).
            return_rgb: Forwarded to the hybrid model (default False).

        Returns:
            Tuple ``(cropped_faces, restored_faces, restored_img)``.
        """

        assert self.hybrid_gfpgan is not None, "Hybrid model is not initialized. Please run `use_hybrid_model` before."

        has_aligned = kwargs.get("has_aligned", False)
        only_center_face = kwargs.get("only_center_face", False)
        paste_back = kwargs.get("paste_back", True)
        weight = kwargs.get("weight", 0.5)
        return_rgb = kwargs.get("return_rgb", False)

        self.face_helper.clean_all()

        # Prepare aligned or unaligned faces
        self._prepare_faces(img, has_aligned, only_center_face)

        # Face restoration
        for cropped_face in self.face_helper.cropped_faces:

            cropped_face_t = self._prepare_data(cropped_face)

            # Inference only: no gradients are needed, so avoid building an autograd graph.
            with torch.no_grad():
                output = self.hybrid_gfpgan(
                    cropped_face_t, fhe_mode=fhe_mode, return_rgb=return_rgb, weight=weight
                )[0]

            # Convert to image
            restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)).astype(
                "uint8"
            )

            self.face_helper.add_restored_face(restored_face)

        # Upsample and paste back faces to the original image
        return self._finalize(img, has_aligned, paste_back)

    def __call__(self, img, fhe_mode="disable", **kwargs):
        """Alias for `forward`; `fhe_mode` has the same default as in `forward`."""
        return self.forward(img, fhe_mode, **kwargs)

    def compile(
        self,
        inputs,
        n_bits=5,
    ):
        """Compile the hybrid model.

        Args:
            inputs: Data passed to the hybrid model's `compile_model`.
            n_bits: Bit-width passed to `compile_model`.

        Raises:
            ValueError: If `use_hybrid_model` has not been called yet.
        """

        if self.hybrid_gfpgan is None:
            raise ValueError(
                "Hybrid model is not initialized. Please run `use_hybrid_model` before compiling."
            )

        # Compile hybrid model
        self.hybrid_gfpgan.compile_model(
            inputs,
            n_bits=n_bits,
        )

    def use_hybrid_model(self, remote_layers):
        """Initialize the hybrid model with specified remote layers.

        Args:
            remote_layers: Names of the sub-modules of `base_gfpgan` handed to
                `HybridFHEModel` as remote layers.
        """
        # Covers any empty selection, not only a literal empty list.
        if not remote_layers:
            print("No remote layers specified; using the base model.")
        self.hybrid_gfpgan = HybridFHEModel(self.base_gfpgan, remote_layers)

In [5]:
# Forward the notebook configuration so `args.upscale` and `args.bg_tile` take effect;
# without these keyword arguments the wrapper falls back to upscale=2 and bg_tile=0.
wrapped_gfpgan = GFPGANFHEHybridWrapper(upscale=args.upscale, bg_tile=args.bg_tile)

In [6]:
linear_layer_names = extract_specific_module(
    wrapped_gfpgan.base_gfpgan, dtype_layer=torch.nn.Linear
)

# Drop the "style_mlp" layers only when the model runs with `input_is_latent`
# (presumably they are not exercised in that mode; confirm against the GFPGAN
# architecture). When `input_is_latent` is false every linear layer is kept;
# the previous `and` condition discarded all of them in that case.
linear_layer_names = [
    name
    for name in linear_layer_names
    if not (wrapped_gfpgan.base_gfpgan.input_is_latent and "style_mlp" in name)
]

conv_layer_names = extract_specific_module(wrapped_gfpgan.base_gfpgan, dtype_layer=torch.nn.Conv2d)

print(f"{len(linear_layer_names)=} - {len(conv_layer_names)=}")

len(linear_layer_names)=24 - len(conv_layer_names)=79


In [7]:
wrapped_gfpgan.use_hybrid_model(linear_layer_names)

# Input for compilation: a single random tensor of shape (1, 3, 512, 512).
# A dedicated seeded generator keeps it identical across "Restart & Run All"
# without touching the global torch RNG state.
input_size = (3, 512, 512)
calibration_rng = torch.Generator().manual_seed(0)
inputs = torch.randn((1, *input_size), generator=calibration_rng)

wrapped_gfpgan.compile(inputs, n_bits=4)

In [8]:
fhe_mode = "execute"
# NOTE(review): the label mentions conv layers, but only `linear_layer_names` is passed
# to `use_hybrid_model` in this notebook; confirm the label is still accurate.
extra = "conv_&_linear"

# The destination does not depend on the image, so build it once.
output_dir = Path(args.output) / f"{extra}_{fhe_mode}"

for img_path in tqdm(img_list):

    input_img, basename, ext = read_img(img_path)

    # Forward the configured options explicitly so edits to `args` actually take
    # effect instead of silently falling back to the wrapper's defaults.
    cropped_faces, restored_faces, restored_img = wrapped_gfpgan(
        input_img,
        fhe_mode=fhe_mode,
        has_aligned=args.aligned,
        only_center_face=args.only_center_face,
        weight=args.weight,
    )

    save_restored_faces(
        cropped_faces,
        restored_faces,
        restored_img,
        output_dir,
        basename,
        args.suffix,
        ext,
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Hybrid model with fhe_mode='execute'
