## QR code generation using Concrete ML by Horaizon27 team

### Imports

In [None]:
import time

import numpy as np
import pandas as pd
from pathlib import Path
from concrete import fhe
from concrete.ml.torch.compile import compile_brevitas_qat_model

import training_utils
import qrcode_utils

### Constants

In [None]:
# Size argument handed to training_utils.create_qrcodes_datasets.
DATASET_SIZE = 4000

# A QR code symbol has (17 + 4 * version) modules per side, so version 1
# gives 21; this value is used as the image side length throughout.
QRCODE_VERSION = 1
QRCODE_IMAGE_SIZE = 4 * QRCODE_VERSION + 17

# Training data layout: "<style>-train_data/{default,styled}".
STYLE_NAME = "green_orange"
TRAINING_QRCODES_DIR = Path(STYLE_NAME + "-train_data")
DEFAULT_QRCODES_DIR = TRAINING_QRCODES_DIR.joinpath("default")
STYLED_QRCODES_DIR = TRAINING_QRCODES_DIR.joinpath("styled")

# Root directory for inference output.
INFERENCE_RESULT_ROOT = Path("inference_result")

# Device the torch models are moved to before evaluation.
DEFAULT_DEVICE = "cpu"

### Load test data

In [None]:
# create_qrcodes_datasets returns three values; only the middle one is kept
# (presumably the test split, judging by the variable names - confirm in
# training_utils).
_, default_qrcode_test, _ = training_utils.create_qrcodes_datasets(
    DEFAULT_QRCODES_DIR, DATASET_SIZE
)
print(default_qrcode_test.shape)

# Same split of the styled reference images.
_, styled_qrcode_test, _ = training_utils.create_qrcodes_datasets(
    STYLED_QRCODES_DIR, DATASET_SIZE
)
print(styled_qrcode_test.shape)

## Models

We are going to load and compile the QuantAE and QuantAEPruned models:
- Run **model_training.ipynb** to train models
- Make sure that model files are located in **ae_{QRCODE_IMAGE_SIZE}** and **ae_{QRCODE_IMAGE_SIZE}_pruned** directories

Load pre-trained models

In [None]:
# Load the pre-trained QuantAE model from the "ae_<image size>" location,
# move it to DEFAULT_DEVICE and switch it to evaluation mode.
ae_model = training_utils.load_model(f"ae_{QRCODE_IMAGE_SIZE}", QRCODE_IMAGE_SIZE)
ae_model.to(device=DEFAULT_DEVICE).eval()

# Same for the QuantAEPruned model, stored under "ae_<image size>_pruned";
# pruned=True is forwarded to the loader.
pruned_ae_model = training_utils.load_model(
    f"ae_{QRCODE_IMAGE_SIZE}_pruned", QRCODE_IMAGE_SIZE, pruned=True
)
pruned_ae_model.to(device=DEFAULT_DEVICE).eval()

Compile models

In [None]:
def generate_inputset(image_size):
    """Build the one-sample input set used to compile the models.

    Args:
        image_size: Side length of the square input image.

    Returns:
        A float array of shape ``(1, 1, image_size, image_size)`` filled with
        ones, except for the top-right quadrant (rows ``[:half]``, columns
        ``[half:]`` with ``half = image_size // 2``), which is zero. The
        sample therefore contains both values.
    """
    half = image_size // 2
    inputset = np.ones((1, 1, image_size, image_size))
    # Index rows and columns in ONE subscript. Chaining ``[:half][half:]``
    # slices the row axis twice, which selects an empty view and leaves the
    # array all ones.
    inputset[0, 0, :half, half:] = 0
    return inputset

def get_compiled_model(model, inputset, n_bits=8):
    """Compile a Brevitas QAT torch model with Concrete ML.

    Args:
        model: The torch model handed to ``compile_brevitas_qat_model``.
        inputset: Representative input(s) passed as ``torch_inputset``.
        n_bits: Bit width used both for ``n_bits`` and for the rounding
            threshold (default 8).

    Returns:
        Whatever ``compile_brevitas_qat_model`` returns (the compiled module).
    """
    # NOTE(review): p_error=0.5 below looks very permissive - presumably a
    # deliberate speed/accuracy trade-off; confirm.
    rounding = {"n_bits": n_bits, "method": "approximate"}
    configuration = fhe.compilation.configuration.Configuration(
        use_gpu=False,
        enable_unsafe_features=True,
        parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
    )
    compiled_module = compile_brevitas_qat_model(
        torch_model=model,
        torch_inputset=inputset,
        n_bits=n_bits,
        rounding_threshold_bits=rounding,
        configuration=configuration,
        p_error=0.5,
    )
    return compiled_module

In [None]:
# One synthetic sample is shared as the input set for both compilations.
inputset = generate_inputset(QRCODE_IMAGE_SIZE)

# Compile QuantAE and QuantAEPruned with the default n_bits (8).
fhe_ae_model = get_compiled_model(ae_model, inputset)
fhe_pruned_ae_model = get_compiled_model(pruned_ae_model, inputset)

## Style Transfer comparison

For comparison, we will try five options:
1) QuantAE model (non-FHE)
2) Compiled QuantAE model in "simulate" mode
3) Compiled QuantAE model in "execute" mode
4) Compiled QuantAEPruned model in "simulate" mode
5) Compiled QuantAEPruned model in "execute" mode

We are going to measure **inference time**, **readability** and **reference diff**.

You can find more information about the models in **README.md**.

In [None]:
def get_style_transfer_stats(
    model, default_qrcode_np_arrays, styled_qrcode_np_arrays, mode, debug=False
):
    """Run style transfer over a test set and summarise the results.

    Args:
        model: For ``mode == "non-fhe"`` a callable model applied to a tensor;
            otherwise an object exposing ``forward(x, fhe=mode)``.
        default_qrcode_np_arrays: Sequence of plain QR code arrays (inputs).
        styled_qrcode_np_arrays: Reference styled QR codes, index-aligned
            with ``default_qrcode_np_arrays``.
        mode: ``"non-fhe"`` to call the model directly; any other value is
            forwarded unchanged as the ``fhe=`` argument (this notebook uses
            ``"simulate"`` and ``"execute"``).
        debug: When true, every styled and corrected image is saved as a
            JPEG under ``INFERENCE_RESULT_ROOT``.

    Returns:
        A dict with three entries:
        ``"avg_inference_time"`` - mean duration of the model call in seconds
        (input/output conversion is excluded);
        ``"qrcode_readability"`` - fraction of corrected outputs accepted by
        ``qrcode_utils.is_valid_qrcode``;
        ``"reference_diff"`` - mean of
        ``qrcode_utils.get_diff_between_image_arrays`` against the references.

    Raises:
        ValueError: If ``default_qrcode_np_arrays`` is empty.
    """
    # len() rather than truthiness: the inputs are typically numpy arrays,
    # whose truth value is ambiguous.
    test_dataset_size = len(default_qrcode_np_arrays)
    if test_dataset_size == 0:
        raise ValueError(
            "default_qrcode_np_arrays is empty; cannot compute statistics"
        )

    if debug:
        results_dir = INFERENCE_RESULT_ROOT
        results_dir.mkdir(exist_ok=True, parents=True)

    inference_time = []
    inference_image_diff = []
    readable_qrcodes = 0

    # Start at 1 so the debug file names are 1-based.
    for i, raw_qrcode in enumerate(default_qrcode_np_arrays, start=1):
        default_qrcode_array = np.float32(raw_qrcode)

        # perf_counter is monotonic and high-resolution, unlike wall-clock
        # time.time(), so short inference calls are timed reliably.
        if mode == "non-fhe":
            test_image_tensor = qrcode_utils.np_qrcode_array_to_tensor(
                default_qrcode_array
            )
            start_time = time.perf_counter()
            st_image_tensor = model(test_image_tensor)
            inference_time.append(time.perf_counter() - start_time)
            st_image_np_array = qrcode_utils.tensor_to_np_qrcode_array(
                st_image_tensor
            )
        else:
            test_image_ts_array = qrcode_utils.np_to_ts_qrcode_array(
                default_qrcode_array
            )
            start_time = time.perf_counter()
            st_image_ts_array = model.forward(test_image_ts_array, fhe=mode)
            inference_time.append(time.perf_counter() - start_time)
            st_image_np_array = qrcode_utils.ts_to_np_qrcode_array(
                st_image_ts_array
            )

        inference_image_diff.append(
            qrcode_utils.get_diff_between_image_arrays(
                st_image_np_array, styled_qrcode_np_arrays[i - 1]
            )
        )

        # Readability is judged on the corrected image; the input is scaled
        # by 255 before being passed as the correction reference.
        corrected_st_image_array = qrcode_utils.get_corrected_qrcode_image(
            st_image_np_array, default_qrcode_array * 255
        )
        if qrcode_utils.is_valid_qrcode(corrected_st_image_array):
            readable_qrcodes += 1

        if debug:
            qrcode_utils.qrcode_array_to_image(st_image_np_array).save(
                f"{results_dir}/styled_{i}.jpg"
            )
            qrcode_utils.qrcode_array_to_image(corrected_st_image_array).save(
                f"{results_dir}/corrected_{i}.jpg"
            )

    return {
        "avg_inference_time": sum(inference_time) / len(inference_time),
        "qrcode_readability": readable_qrcodes / test_dataset_size,
        "reference_diff": np.mean(inference_image_diff),
    }

### Non-FHE 

In [None]:
# Make sure the torch model is in evaluation mode before measuring it.
ae_model.eval()
ae_results = get_style_transfer_stats(
    model=ae_model,
    default_qrcode_np_arrays=default_qrcode_test,
    styled_qrcode_np_arrays=styled_qrcode_test,
    mode="non-fhe",
    debug=False,
)

### Simulate mode

In [None]:
# Compiled QuantAE, forwarded with fhe="simulate".
fhe_ae_simulate_results = get_style_transfer_stats(
    model=fhe_ae_model,
    default_qrcode_np_arrays=default_qrcode_test,
    styled_qrcode_np_arrays=styled_qrcode_test,
    mode="simulate",
    debug=False,
)

# Compiled QuantAEPruned, forwarded with fhe="simulate".
fhe_ae_pruned_simulate_results = get_style_transfer_stats(
    model=fhe_pruned_ae_model,
    default_qrcode_np_arrays=default_qrcode_test,
    styled_qrcode_np_arrays=styled_qrcode_test,
    mode="simulate",
    debug=False,
)

### Execute mode

In [None]:
# Compiled QuantAE, forwarded with fhe="execute".
fhe_ae_execute_results = get_style_transfer_stats(
    model=fhe_ae_model,
    default_qrcode_np_arrays=default_qrcode_test,
    styled_qrcode_np_arrays=styled_qrcode_test,
    mode="execute",
    debug=False,
)

# Compiled QuantAEPruned, forwarded with fhe="execute".
fhe_ae_pruned_execute_results = get_style_transfer_stats(
    model=fhe_pruned_ae_model,
    default_qrcode_np_arrays=default_qrcode_test,
    styled_qrcode_np_arrays=styled_qrcode_test,
    mode="execute",
    debug=False,
)

## Results

In [None]:
# One row per metric, one column per model/mode combination; the order of the
# result dicts in the inner tuple must match the order of `columns`.
results = pd.DataFrame(
    [
        [
            run_stats[metric]
            for run_stats in (
                ae_results,
                fhe_ae_simulate_results,
                fhe_ae_pruned_simulate_results,
                fhe_ae_execute_results,
                fhe_ae_pruned_execute_results,
            )
        ]
        for metric in ("avg_inference_time", "qrcode_readability", "reference_diff")
    ],
    index=["Inference time", "Readability", "Reference diff"],
    columns=[
        "QuantAE (Non-FHE)",
        "QuantAE (Sim)",
        "QuantAEPruned (Sim)",
        "QuantAE (FHE)",
        "QuantAEPruned (FHE)",
    ],
)
results