# Image watermarking using Concrete ML by Horaizon27 team

### Imports

In [None]:
import cv2
import numpy as np
from PIL import Image
from concrete import fhe

from utils import (
    create_watermark_mask, 
    create_verification_mask,
    create_zero_diagonal_mask,
    create_inputset_for_circuit,
    generate_random_text
)

### Constants

In [None]:
# Size every image is resized to on load; also passed to the FHE circuit builders.
IMAGE_SIZE = (128, 128)
# Side length of the square blocks used by the block-wise DCT encoder.
DCT_BLOCK_SIZE = 8
# Number of random watermark messages embedded per image, and number of random
# texts tried per watermarked image in the false-positive check.
ITERATIONS = 10

# Base names of the sample images; each one is substituted into IMAGE_NAME_PATTERN.
IMAGES = ["black128", "red128", "white128", "coala128", "philosopher128", "robot128", "salmon128", "space128"]
# Path template for the source images, formatted with the image base name.
IMAGE_NAME_PATTERN = "images/{}.png"
# Path template for the outputs, formatted with (image base name, watermark message).
WATERMARKED_IMAGES_NAME_PATTERN = "watermarked_images/{}_{}.png"

## Image watermarking

### Load image

In [None]:
def load_image(image_path, image_size):
    """Load an image from disk as an RGB array of the requested size.

    Args:
        image_path: Path of the image file to open.
        image_size: Target size handed to ``Image.resize``.

    Returns:
        A numpy array built from the resized RGB image.
    """
    # The context manager closes the file handle that a bare Image.open()
    # would otherwise keep open.
    with Image.open(image_path) as source:
        # Convert every non-RGB mode (RGBA, grayscale, palette, ...) so the
        # result always carries three channels; the DCT code indexes shape[2].
        rgb_image = source if source.mode == "RGB" else source.convert("RGB")
        # resize() returns a new in-memory image, so the array stays valid
        # after the source file is closed.
        return np.asarray(rgb_image.resize(image_size))

### Image preprocessing

#### DCT encoding

In [None]:
def image_dct_encode(img_array, block_size=None):
    """Apply a block-wise 2-D DCT to every channel of an image array.

    Each channel is cut into ``block_size`` x ``block_size`` tiles and every
    tile is transformed with ``cv2.dct``. Rows and columns that do not fill a
    complete tile are left as zeros in the output.

    Args:
        img_array: Array indexed as (height, width, channel).
        block_size: Side length of the DCT tiles. Defaults to
            ``DCT_BLOCK_SIZE`` when omitted, mirroring the ``block_size``
            parameter of ``image_dct_decode``.

    Returns:
        A float32 array of the same shape holding the DCT coefficients
        rounded to the nearest integer value.
    """
    if block_size is None:
        block_size = DCT_BLOCK_SIZE

    img_height = img_array.shape[0]
    img_width = img_array.shape[1]
    img_channels = img_array.shape[2]

    img_array_encoded = np.zeros((img_height, img_width, img_channels), np.float32)

    # Only whole tiles are transformed; integer division avoids the float
    # round trip of int(a / b).
    rows_end = (img_height // block_size) * block_size
    cols_end = (img_width // block_size) * block_size

    for channel in range(img_channels):
        # cv2.dct needs floating point input: take a contiguous float32 copy.
        plane = np.array(img_array[:, :, channel], dtype=np.float32, order="C")

        for top in range(0, rows_end, block_size):
            for left in range(0, cols_end, block_size):
                block = plane[top : top + block_size, left : left + block_size]
                img_array_encoded[
                    top : top + block_size,
                    left : left + block_size,
                    channel,
                ] = cv2.dct(block)

    return np.around(img_array_encoded)

### Watermark embedding

In [None]:
def watermark_embeding(array: np.ndarray, message_mask: np.ndarray):
    """Combine an array with a watermark mask; this is the function compiled into the FHE circuit.

    The array is multiplied element-wise by the mask returned from
    ``create_zero_diagonal_mask(array.shape)`` and ``message_mask`` is then
    added to the product.

    Args:
        array: Values to watermark (declared "encrypted" when the circuit is built).
        message_mask: Mask added to the array (declared "clear" when the circuit is built).

    Returns:
        ``array * zero_diagonal_mask + message_mask``.
    """
    # NOTE(review): presumably this mask zeroes the cells that will carry the
    # message so the addition below overwrites them — confirm in utils.
    zero_diagonal_mask = create_zero_diagonal_mask(array.shape)
    result = array * zero_diagonal_mask 
    result += message_mask
    return result

# Compiled embedding circuits keyed by image size, so repeated calls share one compilation.
_WATERMARK_CIRCUITS = {}


def create_watermark_circuit(img_size):
    """Return the FHE circuit for ``watermark_embeding`` at the given image size.

    Compilation is expensive, and the inputs visible here (the function, the
    encryption statuses and the inputset built from ``img_size``) depend only
    on the size. The circuit is therefore compiled once per size and reused on
    later calls instead of being rebuilt for every image.

    Args:
        img_size: Image size forwarded to ``create_inputset_for_circuit``.

    Returns:
        The compiled circuit; the same object is returned for equal sizes.
    """
    # Lists and tuples describing the same size share one cache entry;
    # a scalar size is used as the key directly.
    try:
        cache_key = tuple(img_size)
    except TypeError:
        cache_key = img_size

    circuit = _WATERMARK_CIRCUITS.get(cache_key)
    if circuit is None:
        fhe_compiler = fhe.Compiler(
            function=watermark_embeding,
            parameter_encryption_statuses={
                "array": "encrypted",
                "message_mask": "clear"
            }
        )
        inputset_for_compiler = create_inputset_for_circuit(img_size)
        circuit = fhe_compiler.compile(inputset_for_compiler)
        _WATERMARK_CIRCUITS[cache_key] = circuit
    return circuit

def get_watermarked_array(
    watermark_circuit, img_array, watermark_text
):
    """Watermark channel 1 of ``img_array`` through the FHE circuit.

    Channel 1 is copied as int16, run through
    ``watermark_circuit.encrypt_run_decrypt`` together with the mask built from
    ``watermark_text``, and the result is written back as float32.

    Args:
        watermark_circuit: Compiled circuit exposing ``encrypt_run_decrypt``.
        img_array: Array indexed as (height, width, channel); modified in place.
        watermark_text: Text turned into a mask by ``create_watermark_mask``.

    Returns:
        The same ``img_array`` object, with channel 1 replaced.
    """
    # astype() already yields an independent copy of the channel.
    channel_values = img_array[:, :, 1].astype(np.int16)
    text_mask = create_watermark_mask(watermark_text, channel_values.shape)

    decrypted = watermark_circuit.encrypt_run_decrypt(channel_values, text_mask)

    img_array[:, :, 1] = decrypted.astype(np.float32)
    return img_array

### Image postprocessing

#### DCT decode

In [None]:
def image_dct_decode(img_array, block_size=None):
    """Apply a block-wise inverse 2-D DCT to every channel of an array.

    Each channel is cut into ``block_size`` x ``block_size`` tiles and every
    tile is transformed with ``cv2.idct``. Rows and columns that do not fill a
    complete tile are left as zeros in the output. The input is not modified.

    Args:
        img_array: DCT coefficients indexed as (height, width, channel).
        block_size: Side length of the tiles. Defaults to ``DCT_BLOCK_SIZE``
            when omitted, so the decoder always matches ``image_dct_encode``
            (previously a separate literal 8).

    Returns:
        A float32 array of the same shape holding the decoded values.
    """
    if block_size is None:
        block_size = DCT_BLOCK_SIZE

    img_array = np.asarray(img_array)
    img_height = img_array.shape[0]
    img_width = img_array.shape[1]
    img_channels = img_array.shape[2]

    img_array_decoded = np.zeros((img_height, img_width, img_channels), np.float32)

    # Only whole tiles are transformed.
    rows_end = (img_height // block_size) * block_size
    cols_end = (img_width // block_size) * block_size

    for channel in range(img_channels):
        # One contiguous float32 copy per channel is enough; the input is
        # only read, so no extra copy of the whole array is needed.
        plane = np.array(img_array[:, :, channel], dtype=np.float32, order="C")

        for top in range(0, rows_end, block_size):
            for left in range(0, cols_end, block_size):
                block = plane[top : top + block_size, left : left + block_size]
                img_array_decoded[
                    top : top + block_size,
                    left : left + block_size,
                    channel,
                ] = cv2.idct(block)

    return img_array_decoded

#### Convert to RGB array

In [None]:
def round_array_data(img_array):
    """Clamp values to the 0..255 range and convert them to uint8.

    The input is left untouched: clipping happens on a new array instead of
    overwriting the caller's data, which also makes read-only arrays and plain
    sequences acceptable.

    Args:
        img_array: Numeric array (or array-like) of pixel values.

    Returns:
        A uint8 array with every value clipped to [0, 255] and rounded.
    """
    clipped = np.clip(img_array, 0, 255)
    return np.around(clipped).astype(np.uint8)

### Save watermarked image

In [None]:
def save_watermark_image(img_array, image_path):
    """Write an image array to ``image_path``, creating the target directory if needed.

    Args:
        img_array: Array accepted by ``Image.fromarray``.
        image_path: Destination handed to ``Image.save``.
    """
    import os

    # Only path-like destinations have a directory to prepare; anything else
    # (e.g. an open file object) is passed to PIL unchanged.
    if isinstance(image_path, (str, bytes, os.PathLike)):
        parent_dir = os.path.dirname(image_path)
        if parent_dir:
            # Without this, a missing output folder raises FileNotFoundError.
            os.makedirs(parent_dir, exist_ok=True)

    Image.fromarray(img_array).save(image_path)

### Full pipeline

In [None]:
def embed_watermark_to_image(image_path, image_size, watermark_message, watermarked_image_path):
    """Run the whole embedding pipeline for one image and save the result.

    Steps: load and resize the image, DCT-encode it, embed the message through
    the FHE watermark circuit, DCT-decode, convert to uint8 and save.

    Args:
        image_path: Path of the source image.
        image_size: Size used for loading and for building the circuit.
        watermark_message: Text to embed.
        watermarked_image_path: Path the watermarked image is written to.
    """
    dct_coefficients = image_dct_encode(load_image(image_path, image_size))

    circuit = create_watermark_circuit(image_size)
    watermarked_coefficients = get_watermarked_array(
        circuit, dct_coefficients, watermark_message
    )

    pixels = round_array_data(image_dct_decode(watermarked_coefficients))
    save_watermark_image(pixels, watermarked_image_path)

### Invisibility check

In [None]:
def calculate_psnr(img_path, watermarked_img_path, img_size, max_value=255):
    """Compute the peak signal-to-noise ratio between two image files.

    Both images are loaded with ``load_image`` at ``img_size`` and compared
    element-wise.

    Args:
        img_path: Path of the reference image.
        watermarked_img_path: Path of the image to compare against it.
        img_size: Size both images are resized to before comparison.
        max_value: Maximum possible pixel value used in the PSNR formula.

    Returns:
        The PSNR in decibels, or 100 when the images are identical.
    """
    # Images typically load as uint8; subtracting and squaring in that dtype
    # wraps modulo 256 and corrupts the error, so do the arithmetic in float.
    reference = load_image(img_path, img_size).astype(np.float64)
    candidate = load_image(watermarked_img_path, img_size).astype(np.float64)

    mse = np.mean((reference - candidate) ** 2)
    if mse == 0:
        # Identical images: the ratio is unbounded, report a fixed cap instead.
        return 100
    return 20 * np.log10(max_value / np.sqrt(mse))

In [None]:
# For every sample image: embed ITERATIONS random messages, remember each
# (message, output path) pair for the verification checks, and report the
# average PSNR against the source image.
watermarking_data = {}

for image_name in IMAGES:
    image_path = IMAGE_NAME_PATTERN.format(image_name)
    image_data = []
    psnr_data = []
    for _ in range(ITERATIONS):
        message = generate_random_text(15)
        watermarked_image_path = WATERMARKED_IMAGES_NAME_PATTERN.format(image_name, message)
        embed_watermark_to_image(image_path, IMAGE_SIZE, message, watermarked_image_path)
        psnr = calculate_psnr(image_path, watermarked_image_path, IMAGE_SIZE)

        image_data.append(
            {
                "message": message,
                "watermarked_image_path": watermarked_image_path,
            }
        )
        psnr_data.append(psnr)

    watermarking_data[image_name] = image_data
    # The metric is PSNR; the label was previously misspelled "PNSR".
    print (f'Average PSNR for image "{image_name}" is {sum(psnr_data) / ITERATIONS}')


## Watermark verification

### FHE functions

In [None]:
def check_watermark(array: np.ndarray, verification_mask: np.ndarray):
    """Score an array against a verification mask.

    Args:
        array: Values to check (declared "encrypted" when the circuit is built).
        verification_mask: Weights applied element-wise (declared "clear").

    Returns:
        The sum of the element-wise product of both inputs.
    """
    return np.sum(array * verification_mask)

# Compiled verification circuits keyed by image size, so repeated calls share one compilation.
_VERIFICATION_CIRCUITS = {}


def create_verification_circuit(img_size):
    """Return the FHE circuit for ``check_watermark`` at the given image size.

    Compilation is expensive, and the inputs visible here (the function, the
    encryption statuses and the inputset built from ``img_size``) depend only
    on the size. The circuit is therefore compiled once per size and reused on
    later calls instead of being rebuilt for every verification.

    Args:
        img_size: Image size forwarded to ``create_inputset_for_circuit``.

    Returns:
        The compiled circuit; the same object is returned for equal sizes.
    """
    # Lists and tuples describing the same size share one cache entry;
    # a scalar size is used as the key directly.
    try:
        cache_key = tuple(img_size)
    except TypeError:
        cache_key = img_size

    circuit = _VERIFICATION_CIRCUITS.get(cache_key)
    if circuit is None:
        fhe_compiler = fhe.Compiler(
            function=check_watermark,
            parameter_encryption_statuses={
                "array": "encrypted",
                "verification_mask": "clear"
            }
        )
        inputset_for_compiler = create_inputset_for_circuit(img_size)
        circuit = fhe_compiler.compile(inputset_for_compiler)
        _VERIFICATION_CIRCUITS[cache_key] = circuit
    return circuit

#### Full pipeline

In [None]:
def is_watermarked_image_with_text(img_path, img_size, watermark_text, threshold=50):
    """Check whether an image file carries the watermark for ``watermark_text``.

    The image is loaded, DCT-encoded, and channel 1 (as int16) is scored by the
    FHE verification circuit against the mask built from ``watermark_text``.

    Args:
        img_path: Path of the image to verify.
        img_size: Size used for loading, for the mask and for the circuit.
        watermark_text: Text whose watermark is expected in the image.
        threshold: Scores strictly below this value count as a match.
            Defaults to 50, the value that used to be hard-coded.

    Returns:
        True-like when the score is below ``threshold``, False-like otherwise.
    """
    verification_circuit = create_verification_circuit(img_size)
    verification_mask = create_verification_mask(watermark_text, img_size)

    watermarked_img_array = load_image(img_path, img_size)
    watermarked_img_array_dct_encoded = image_dct_encode(watermarked_img_array)
    # Only channel 1 is scored; it is the channel the embedding step writes to.
    array = watermarked_img_array_dct_encoded[:, :, 1].astype(np.int16)

    score = verification_circuit.encrypt_run_decrypt(array, verification_mask)
    return score < threshold

### Check watermark correctness

#### False negative check

Texts from watermarking stage should pass verification

In [None]:
# Every message embedded earlier must be accepted for its own watermarked file;
# entries that fail verification are collected and reported.
errors = []

for image_name in IMAGES:
    for image_data in watermarking_data[image_name]:
        watermarked_path = image_data["watermarked_image_path"]
        print(f'Checking file "{watermarked_path}"')
        verified = is_watermarked_image_with_text(
            watermarked_path, IMAGE_SIZE, image_data["message"]
        )
        if not verified:
            errors.append(image_data)

if errors:
    print(f"Found {len(errors)} errors")
    print(errors)
else:
    print("No watermarking errors")

#### False positive check

Random text should fail the verification check

In [None]:
# Each watermarked file is tested against ITERATIONS random texts; every
# accepted text adds the file's entry to the error list.
errors = []

for image_name in IMAGES:
    for image_data in watermarking_data[image_name]:
        watermarked_path = image_data["watermarked_image_path"]
        print(f'Checking file "{watermarked_path}"')
        errors.extend(
            image_data
            for _ in range(ITERATIONS)
            if is_watermarked_image_with_text(
                watermarked_path, IMAGE_SIZE, generate_random_text(15)
            )
        )

if errors:
    print(f"Found {len(errors)} errors")
    print(errors)
else:
    print("No watermarking errors")