In [1]:
from pathlib import Path

import numpy as np
from PIL import Image
import concrete.fhe as fhe
from concrete.fhe import Compiler, Configuration

In [2]:
# Core Functions
def encode_watermark(watermark: str, img_size: int):
    """
    Build the watermark mask and its complement, one 8x8 block per character.

    Characters are laid out left-to-right, top-to-bottom over the grid of
    ``img_size // 8`` x ``img_size // 8`` blocks. Within a block, bit ``k`` of
    the character code (``k = 0`` is the most significant bit) maps to the
    anti-diagonal cell ``(k, 7 - k)``. That cell is set to 0 in the mask when
    the bit is 0 and left at 1 otherwise.

    Args:
        watermark: Text to encode. Every character must have a code point in
            ``0..255`` so that it fits in exactly 8 bits.
        img_size: Side length of the square image the masks are applied to.

    Returns:
        ``(mask, mask_inverted)`` as ``int64`` arrays of shape
        ``(img_size, img_size)``. ``mask`` is 1 everywhere except the cells
        that encode 0-bits. ``mask_inverted`` is ``1 - mask``.

    Raises:
        ValueError: if the watermark needs more blocks than the image has, or
            if a character's code point does not fit in 8 bits.
    """
    # Built directly as int64 because Concrete accepts int64/int32 inputs.
    mask = np.ones((img_size, img_size), dtype=np.int64)
    blocks_per_row = img_size // 8
    total_blocks = blocks_per_row * blocks_per_row

    if len(watermark) > total_blocks:
        raise ValueError(f"Not enough blocks (need {len(watermark)}, total {total_blocks}).")

    for i, char in enumerate(watermark):
        code = ord(char)
        if code > 0xFF:
            # The 8x8 layout holds exactly 8 bits per character. Wider code
            # points cannot be represented, so reject them instead of
            # encoding a truncated value.
            raise ValueError(
                f"Character {char!r} at index {i} has code {code}, which does not fit in 8 bits."
            )
        row0 = (i // blocks_per_row) * 8
        col0 = (i % blocks_per_row) * 8

        for k in range(8):
            # k = 0 reads the most significant bit.
            if (code >> (7 - k)) & 1 == 0:
                mask[row0 + k, col0 + 7 - k] = 0

    mask_inverted = 1 - mask
    return mask, mask_inverted

def identity_function(x: np.ndarray) -> np.ndarray:
    """Return ``x`` wrapped in ``fhe.refresh`` (value-wise an identity).

    The notebook runs this circuit on freshly encrypted images before handing
    the result to the other circuits. The refresh is presumably there to reset
    ciphertext noise for that hand-off; see Concrete's ``fhe.refresh`` docs
    (TODO confirm).
    """
    refreshed = fhe.refresh(x)
    return refreshed

def embed_function(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Embed the watermark by element-wise multiplying ``x`` with ``mask``.

    With a 0/1 mask such as the one from ``encode_watermark``, pixels where
    the mask is 0 become 0 and all other pixels are unchanged. The product is
    wrapped in ``fhe.refresh``.
    """
    watermarked = x * mask
    return fhe.refresh(watermarked)

def check_function(x: np.ndarray, mask_inverted: np.ndarray) -> np.int64:
    """Sum the pixels of ``x`` that ``mask_inverted`` selects.

    With a 0/1 inverted mask and non-negative pixels (assumed; TODO confirm
    for other inputs), the total is 0 exactly when every selected pixel is 0.
    That is the state ``embed_function`` leaves the watermark cells in.
    """
    selected = fhe.refresh(x * mask_inverted)
    total = np.sum(selected)
    return total


In [3]:
# FHE Circuit Compilation Helpers
def compile_circuit(function, inputset, parameter_encryption_statuses, config):
    """Compile ``function`` into a Concrete FHE circuit.

    Args:
        function: Python function to trace.
        inputset: Representative inputs. Concrete uses them to determine
            value ranges (and hence bit-widths).
        parameter_encryption_statuses: Maps each parameter name of
            ``function`` to ``"encrypted"`` or ``"clear"``.
        config: ``Configuration`` forwarded to ``compile``.

    Returns:
        The compiled circuit object.
    """
    return Compiler(
        function=function,
        parameter_encryption_statuses=parameter_encryption_statuses,
    ).compile(inputset, configuration=config)

def compile_identity(inputset, config):
    """Compile ``identity_function`` with its single input ``x`` encrypted."""
    statuses = {"x": "encrypted"}
    return compile_circuit(identity_function, inputset, statuses, config)

def compile_embed(inputset, config):
    """Compile ``embed_function``: image ``x`` encrypted, ``mask`` in the clear."""
    statuses = {"x": "encrypted", "mask": "clear"}
    return compile_circuit(embed_function, inputset, statuses, config)

def compile_check(inputset, config):
    """Compile ``check_function``: image ``x`` encrypted, ``mask_inverted`` in the clear."""
    statuses = {"x": "encrypted", "mask_inverted": "clear"}
    return compile_circuit(check_function, inputset, statuses, config)

In [None]:
# Setup, Input Data, and Masks

SEED = 42
# Pixels are 8-bit. Concrete derives each circuit's bit-widths from its
# inputset, so every inputset below must span the full 0..PIXEL_MAX range
# that real images can reach.
PIXEL_MAX = 255

config = Configuration(
    composable=True,
    # Unsafe features + insecure key cache: acceptable for a demo only.
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    # Expand "~" explicitly rather than relying on the library to do it.
    insecure_key_cache_location=str(Path("~/.cml_keycache").expanduser()),
    global_p_error=1e-7,
)

img_size = 32
watermark_str = "H"

# Generate watermark masks
mask, mask_inverted = encode_watermark(watermark_str, img_size)


def load_green_channel(path, size):
    """Load an image file and return its green channel as an int64 ``(size, size)`` array.

    Raises:
        ValueError: if the image is not ``size`` x ``size`` pixels.
    """
    with Image.open(path) as img:
        # convert("RGB") guarantees three channels (grayscale/palette/LA PNGs
        # otherwise have no green channel at index 1).
        green = np.asarray(img.convert("RGB"))[:, :, 1].astype(np.int64)
    if green.shape != (size, size):
        raise ValueError(
            f"{path}: expected a {size}x{size} image, got {green.shape[0]}x{green.shape[1]}."
        )
    return green


# Load images and select the green channel (channel index 1)
original_image = load_green_channel("images/home_icon.png", img_size)
second_image = load_green_channel("images/book_icon.png", img_size)

rng = np.random.default_rng(SEED)


def _random_pixel_images(count=3):
    """Return ``count`` random pixel images plus one all-``PIXEL_MAX`` image.

    The constant image guarantees the inputset reaches the maximum pixel
    value, so compiled bit-widths cover real images.
    """
    samples = [
        rng.integers(0, PIXEL_MAX + 1, size=(img_size, img_size), dtype=np.int64)
        for _ in range(count)
    ]
    samples.append(np.full((img_size, img_size), PIXEL_MAX, dtype=np.int64))
    return samples


# Prepare input sets for FHE circuit compilation
identity_inputset = _random_pixel_images()
embed_inputset = [(img, mask) for img in _random_pixel_images()]
# The check circuit also receives un-watermarked, full-range images
# (e.g. second_image), so it must be calibrated on the full pixel range too.
check_inputset = [(img, mask_inverted) for img in _random_pixel_images()]

In [None]:
# Compile and Run FHE Circuits
#
# Pipeline: encrypt and refresh with circuit_id -> watermark with
# circuit_embed -> verify with circuit_check. The check circuit returns the
# sum of the pixels selected by mask_inverted; 0 is read as "watermark present".
#
# NOTE(review): the three circuits are compiled separately, yet ciphertexts
# produced by one are run through, and decrypted by, another. This presumably
# only works if they end up sharing key material (e.g. via the insecure key
# cache configured above) — TODO confirm. Concrete's `fhe.module` is the
# documented way to share ciphertexts between functions.

# Compile the FHE circuits using the provided input sets and configuration
circuit_id    = compile_identity(identity_inputset, config)
circuit_embed = compile_embed(embed_inputset, config)
circuit_check = compile_check(check_inputset, config)

# (A) On the client: Encrypt original images and process via identity_function
enc_original = circuit_id.encrypt(original_image)
out_id = circuit_id.run(enc_original)

# The second image is never watermarked; it serves as the negative control.
enc_second = circuit_id.encrypt(second_image)
out_id_false = circuit_id.run(enc_second)

# (B) On the server: Embed watermark using the embed_function
# `mask` is passed in the clear, matching compile_embed's statuses.
out_embed = circuit_embed.run(out_id, mask)

# (C) On a third machine: Check watermark using the check_function
out_check = circuit_check.run(out_embed, mask_inverted)
out_check_false = circuit_check.run(out_id_false, mask_inverted)

# Decrypt the sum and determine watermark presence (sum == 0 indicates presence)
decrypted_sum = circuit_check.decrypt(out_check)
print("Watermark present:", decrypted_sum == 0)

# NOTE: this reports False only if second_image has at least one non-zero
# pixel at the watermark cells; an image already 0 there would also read True.
decrypted_sum_false = circuit_check.decrypt(out_check_false)
print("Watermark present (second image):", decrypted_sum_false == 0)