In [1]:
import argparse
import glob
import os
from copy import copy
from pprint import pprint

import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import imwrite
from gfpgan import GFPGANer
from gfpgan.archs.stylegan2_clean_arch import ModulatedConv2d
from realesrgan import RealESRGANer
from tqdm import tqdm
from utils import *

from concrete.ml.torch.hybrid_model import HybridFHEModel

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


## Purpose 

The purpose of this notebook is to compute, for an input image of size 3 * 256 * 256:
- the transmission latency 
- data size of input and output tensors for each intermediate layer of the model (Conv2d and Linear) 


### Overview of the GFP-GAN Architecture

The GFP-GAN pipeline is divided into 3 main components:

1. ***Face Cropping (restorer.face_helper):*** To detect and crop faces from input images.

   Composition:
   - 82 Standard Convolutional Layers with:
        + Kernel Sizes: (7, 7) or (3, 3) or (1, 1)
        + Strides: (2, 2) or (1, 1)
        + Padding: (3, 3) or (1, 1)

2. ***Face Restoration (restorer.gfpgan):*** To restore and enhance the quality of cropped facial images.

   Composition:
   - 32 Linear Layers
   - 79 Standard Convolutional Layers with:
        + Kernel Sizes: (3, 3) or (1, 1)
        + Strides: (1, 1)
        + Padding: (1, 1)
   - 23 Modulated Convolutional Layers (ModulatedConv2d), with: Kernel Sizes: 3 or 1


3. **Background Enhancement (restorer.upsampler):** To enhance the background details of the images after face restoration.

   Composition:
   - 351 Standard Convolutional Layers with fixed configurations:
        + Kernel Size: (3, 3)
        + Stride: (1, 1)
        + Padding: (1, 1)


**Absence of Grouped or Dilated Convolutions or Depthwise Convolutions**

+ Grouped Convolutions: Convolutions where the input is divided into parts/groups, we have a set of filters for each group, the result is concatenated.
(groups=1 by default).

+ Dilated Convolutions: Convolutions where the kernel is expanded by inserting zeros between its elements, increasing the receptive field without increasing the number of parameters.
(dilation=1 by default).



**Modulated**:

+ The convolutional weights are dynamically adjusted (modulated) for each input sample based on a style vector.
+ This modulation allows the network to adapt its convolutional filters per sample, enabling more control over generated features.
+ Modulate Weights Process: For each ModulatedConv2d layer:
        - The style vector is transformed (usually via another linear layer) to obtain modulation weights.
        - These weights modulate the convolutional filters.
        - Demodulation: After modulation, weights can vary in magnitude, leading to instability during training. Demodulation normalizes the weights to maintain a consistent signal magnitude across the layers.


```
Style Vector (w)
        |
Modulation Weights (s)
        |
Modulated Weights (s * k)
        |
(Optional) Demodulation
        |
Convolution Operation
        |
Output Feature Maps
```

Recall: H_out​ = ⌊​H_in​ + 2 × P_h​ − D_h ​* (K_h ​− 1) − 1 ⌋ / S_h ​+ 1


### Load the models

In [2]:
class Args:

    def __init__(self):

        self.input = "GFPGAN/inputs/whole_imgs"
        self.output = "results"
        self.version = "1.4"
        self.upscale = 5
        self.bg_upsampler = "realesrgan"
        self.bg_tile = 400
        self.suffix = None
        self.only_center_face = False
        self.aligned = False
        self.ext = "auto"
        self.weight = 0.5


args = Args()

In [3]:
use_background_improvement = True

if args.bg_upsampler == "realesrgan":
    if use_background_improvement:

        half = True if torch.cuda.is_available() else False

        model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2
        )
        # No linear modules in this model
        bg_upsampler = RealESRGANer(
            scale=2,  # Do not change this value
            model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            model=model,
            tile=args.bg_tile,
            tile_pad=10,
            pre_pad=0,
            half=half,
        )  # need to set False in CPU mode

In [4]:
if args.version == "1.3":
    arch = "clean"
    channel_multiplier = 2
    model_name = "GFPGANv1.3"
    url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    local_model_path = "GFPGANv1.3.pth"
elif args.version == "1.4":
    arch = "clean"
    channel_multiplier = 2
    model_name = "GFPGANv1.4"
    url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
    local_model_path = "GFPGANv1.4.pth"

# determine model paths
model_path = os.path.join("experiments/pretrained_models", model_name + ".pth")
if not os.path.isfile(model_path):
    model_path = os.path.join("gfpgan/weights", model_name + ".pth")
if not os.path.isfile(model_path):
    # download pre-trained models from url
    model_path = url

restorer = GFPGANer(
    model_path=model_path,
    upscale=args.upscale,
    arch=arch,
    channel_multiplier=channel_multiplier,
    bg_upsampler=bg_upsampler,
)

In [5]:
BYTES_PER_VALUE = 2  # For 16-bit precision
EXPANSION_FACTOR = 5
DATA_TRANSMISSION = 2

In [6]:
def compute_total_data_bytes_for_selected_layer_class(layer_shapes, selected_class_name):

    total_data = 0

    for info in layer_shapes:
        class_name = info["class_name"]
        layer_name = info["layer_name"]

        if class_name == selected_class_name:
            input_shapes = info["input_shapes"]
            output_shapes = info["output_shapes"]

            try:
                C_in, H_in, W_in = extract_dimensions(input_shapes, layer_name)
                C_out, H_out, W_out = extract_dimensions(output_shapes, layer_name)

                input_size = C_in * H_in * W_in
                output_size = C_out * H_out * W_out

                layer_data = input_size + output_size
                total_data += layer_data

            except ValueError as e:
                print("Error:", e)

    print("\n\n===================================================================")

    print(
        f"Sum of input and output elements for all '{selected_class_name}' layers): {total_data} elements\n"
    )

    # Data transmission
    total_data *= DATA_TRANSMISSION
    print(
        f"Total data after accounting for data transmission (x{DATA_TRANSMISSION}): {total_data} elements\n"
    )

    total_data *= EXPANSION_FACTOR
    print(
        f"Total data after applying expansion factor (x{EXPANSION_FACTOR}): {total_data} elements\n"
    )

    total_data_bytes = total_data * BYTES_PER_VALUE
    print(
        f"Total data in bytes (assuming {BYTES_PER_VALUE}-byte precision per value): {total_data_bytes} bytes\n"
    )

    # Convert to MB and GB
    total_data_mb = total_data_bytes / (1024**2)
    total_data_gb = total_data_bytes / (1024**3)

    print(f"Total data size:")
    print(f"  -{total_data_bytes} bytes")
    print(f"  -{total_data_mb:.2f} MB")
    print(f"  -{total_data_gb:.2f} GB")

    return total_data_bytes

In [7]:
def compute_total_network_size(layer_shapes):

    total_data = 0
    total_params = 0

    for info in layer_shapes:
        class_name = info["class_name"]
        layer_name = info["module"].__class__.__name__
        module = info["module"]

        input_shapes = info["input_shapes"]
        output_shapes = info["output_shapes"]

        try:
            C_in, H_in, W_in = extract_dimensions(input_shapes, layer_name)
            input_size = C_in * H_in * W_in

            C_out, H_out, W_out = extract_dimensions(output_shapes, layer_name)
            output_size = C_out * H_out * W_out

            params_size = 0
            for param in module.parameters():
                params_size += param.numel()

            layer_data = input_size + output_size + params_size
            total_data += layer_data
            total_params += params_size

            # print(f"Layer: {layer_name} ({class_name})")
            # print(f"  Input size: {input_size}")
            # print(f"  Output size: {output_size}")
            # print(f"  Parameters size: {params_size}")
            # print(f"  Total layer size: {layer_data}\n")

        except ValueError as e:
            print(f"Error processing layer {layer_name}: {e}")

    total_network_size = total_data + total_params

    total_memory_bytes = total_network_size * BYTES_PER_VALUE

    # Convert to MB and GB
    total_memory_mb = total_memory_bytes / (1024**2)
    total_memory_gb = total_memory_bytes / (1024**3)

    print("\n\n===================================================================")

    print(f"Total network size in bytes: {total_memory_bytes} bytes")
    print(f"Total network size in megabytes: {total_memory_mb:.2f} MB")
    print(f"Total network size in megabytes: {total_memory_gb:.2f} GB")

    return total_network_size

### Define an input of size : 3, 256, 256


In [8]:
input_tensor = torch.randn(1, 3, 256, 256)

### Model 1 - Face cropping and extraction:

In [9]:
face_helper_model = restorer.face_helper.face_det

restorer_face_helper_linear = extract_specific_module(
    face_helper_model, dtype_layer=torch.nn.Linear, verbose=False
)
restorer_face_helper_conv2d = extract_specific_module(
    face_helper_model, dtype_layer=torch.nn.Conv2d, verbose=False
)

print(
    f"{len(restorer_face_helper_linear)}-Linear Layers, "
    f"{len(restorer_face_helper_conv2d)}-Conv Layers"
)

0-Linear Layers, 82-Conv Layers


In [10]:
# Get the input and output shapes for each layer
layer_shapes = custom_torch_summary(face_helper_model, input_tensor, verbose=5)



Layer: body.conv1 - Conv2d
  Input shapes: [torch.Size([1, 3, 256, 256])]
  Output shapes: torch.Size([1, 64, 128, 128])

Layer: body.bn1 - BatchNorm2d
  Input shapes: [torch.Size([1, 64, 128, 128])]
  Output shapes: torch.Size([1, 64, 128, 128])

Layer: body.relu - ReLU
  Input shapes: [torch.Size([1, 64, 128, 128])]
  Output shapes: torch.Size([1, 64, 128, 128])

Layer: body.maxpool - MaxPool2d
  Input shapes: [torch.Size([1, 64, 128, 128])]
  Output shapes: torch.Size([1, 64, 64, 64])

Layer: body.layer1.0.conv1 - Conv2d
  Input shapes: [torch.Size([1, 64, 64, 64])]
  Output shapes: torch.Size([1, 64, 64, 64])

Layer: body.layer1.0.bn1 - BatchNorm2d
  Input shapes: [torch.Size([1, 64, 64, 64])]
  Output shapes: torch.Size([1, 64, 64, 64])



In [11]:
total_network_size = compute_total_network_size(layer_shapes)
total_network_size

Error processing layer IntermediateLayerGetter: Unexpected shape {1: torch.Size([1, 512, 32, 32]), 2: torch.Size([1, 1024, 16, 16]), 3: torch.Size([1, 2048, 8, 8])} in layer IntermediateLayerGetter
Error processing layer FPN: not enough values to unpack (expected 4, got 3)
Error processing layer BboxHead: Unexpected shape torch.Size([1, 2048, 4]) in layer BboxHead
Error processing layer BboxHead: Unexpected shape torch.Size([1, 512, 4]) in layer BboxHead
Error processing layer BboxHead: Unexpected shape torch.Size([1, 128, 4]) in layer BboxHead
Error processing layer ClassHead: Unexpected shape torch.Size([1, 2048, 2]) in layer ClassHead
Error processing layer ClassHead: Unexpected shape torch.Size([1, 512, 2]) in layer ClassHead
Error processing layer ClassHead: Unexpected shape torch.Size([1, 128, 2]) in layer ClassHead
Error processing layer LandmarkHead: Unexpected shape torch.Size([1, 2048, 10]) in layer LandmarkHead
Error processing layer LandmarkHead: Unexpected shape torch.Size

284124992

In [12]:
_ = compute_total_data_bytes_for_selected_layer_class(layer_shapes, selected_class_name="Conv2d")



Sum of input and output elements for all 'Conv2d' layers): 32897024 elements

Total data after accounting for data transmission (x2): 65794048 elements

Total data after applying expansion factor (x5): 328970240 elements

Total data in bytes (assuming 2-byte precision per value): 657940480 bytes

Total data size:
  -657940480 bytes
  -627.46 MB
  -0.61 GB


### Model 2 - Face Restoration (restorer.gfpgan):

In [13]:
restorer_gfpgan = restorer.gfpgan

restorer_gfpgan_linear = extract_specific_module(
    restorer_gfpgan, dtype_layer=torch.nn.Linear, verbose=False
)

restorer_gfpgan_conv2d = extract_specific_module(
    restorer_gfpgan, dtype_layer=torch.nn.Conv2d, verbose=False
)

restorer_gfpgan_modulated_conv2d = extract_specific_module(
    restorer_gfpgan, dtype_layer=ModulatedConv2d, verbose=False
)

print(
    f"{len(restorer_gfpgan_linear)}-Linear Layers, "
    f"{len(restorer_gfpgan_conv2d)}-Conv Layers, "
    f"{len(restorer_gfpgan_modulated_conv2d)}-Conv Modulated Layers"
)

32-Linear Layers, 79-Conv Layers, 23-Conv Modulated Layers


In [14]:
bbox = face_helper_model(input_tensor)

type(bbox), len(bbox), [(b.shape) for b in bbox]

(tuple,
 3,
 [torch.Size([1, 2688, 4]),
  torch.Size([1, 2688, 2]),
  torch.Size([1, 2688, 10])])

In [15]:
# <!!!> restorer_gfpgan takes 3 * 512 * 512 input for each detected face

layer_shapes = custom_torch_summary(restorer_gfpgan, torch.randn(1, 3, 512, 512), verbose=5)

Error processing module ConstantInput(): Couldn't deal with layer 'stylegan_decoder.constant_input': unrecognized data type <class 'int'>
Input type: <class 'tuple'>, Output type: <class 'torch.Tensor'>
Error processing module StyleGAN2GeneratorCSFT(
  (style_mlp): Sequential(
    (0): NormStyleCode()
    (1): Linear(in_features=512, out_features=512, bias=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=512, out_features=512, bias=True)
    (6): LeakyReLU(negative_slope=0.2, inplace=True)
    (7): Linear(in_features=512, out_features=512, bias=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Linear(in_features=512, out_features=512, bias=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Linear(in_features=512, out_features=512, bias=True)
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    

In [16]:
total_data = compute_total_data_bytes_for_selected_layer_class(
    layer_shapes, selected_class_name="Conv2d"
)
total_data



Sum of input and output elements for all 'Conv2d' layers): 319471552 elements

Total data after accounting for data transmission (x2): 638943104 elements

Total data after applying expansion factor (x5): 3194715520 elements

Total data in bytes (assuming 2-byte precision per value): 6389431040 bytes

Total data size:
  -6389431040 bytes
  -6093.44 MB
  -5.95 GB


6389431040

### Model 3 - Background Enhancement (restorer.upsampler):

In [17]:
upsampler_model = restorer.bg_upsampler.model

restorer_upsampler_linear = extract_specific_module(
    upsampler_model, dtype_layer=torch.nn.Linear, verbose=False
)
restorer_upsampler_conv2d = extract_specific_module(
    upsampler_model, dtype_layer=torch.nn.Conv2d, verbose=False
)

print(
    f"{len(restorer_upsampler_linear)}-Linear Layers, "
    f"{len(restorer_upsampler_conv2d)}-Conv Layers"
)

0-Linear Layers, 351-Conv Layers


In [18]:
# from torchsummary import summary

# summary(upsampler_model, input_tensor.squeeze().shape)

# # <!!!!> Saved in upsampler_sizes_v2.txt and upsampler_sizes.txt

In [19]:
with open("upsampler_sizes_v2.txt", "r") as f:
    data = f.readlines()

previous_output_shape = input_tensor.shape

parsed_data = [parse_line(line) for line in data][1:]

formatted_data = reformat_data(parsed_data, previous_output_shape)

In [20]:
total_data = compute_total_data_bytes_for_selected_layer_class(
    formatted_data, selected_class_name="Conv2d"
)

# Sanity check only
assert (
    total_data
    == filter_conv_layers(data, (3, 256, 256))[1]
    * DATA_TRANSMISSION
    * EXPANSION_FACTOR
    * BYTES_PER_VALUE
)



Sum of input and output elements for all 'Conv2d' layers): 514785280 elements

Total data after accounting for data transmission (x2): 1029570560 elements

Total data after applying expansion factor (x5): 5147852800 elements

Total data in bytes (assuming 2-byte precision per value): 10295705600 bytes

Total data size:
  -10295705600 bytes
  -9818.75 MB
  -9.59 GB


In [21]:
# Generalization

ref_w = ref_h = 256
new_w = new_h = 512

(new_h * new_w) / (ref_w * ref_h) * 9.59

38.36