In [1]:
import os
from glob import glob
from random import choice
from os.path import join, sep

In [2]:
import keras
import tensorflow as tf
from rich.progress import track
from dotenv import find_dotenv, load_dotenv

In [3]:
from src.addons.watermark.models import create_watermark, create_extract_mark
from src.addons.watermark.models import WatermarkModel
from src.addons.data.pipeline import test_pipeline
from src.addons.visualize.table import print_tables
from src.addons.data.augment import attacks

---

In [4]:
# Locate a .env file and load its variables into the process environment.
# The return value is bound to `_` so the cell produces no output.
dotenv_file = find_dotenv()
_ = load_dotenv(dotenv_file)

In [5]:
# Collect the test images from "<RAW_PATH>/tests" and build the evaluation dataset.
# Indexing os.environ raises a KeyError naming the variable when RAW_PATH is
# missing, instead of the opaque TypeError that join(None, ...) would produce.
raw_path = os.environ["RAW_PATH"]
# glob returns matches in arbitrary order, so sort for a reproducible dataset order.
images_path = sorted(glob(join(raw_path, "tests", "*.jpg")))
if not images_path:
    # Fail early: an empty test set would make the reported metrics meaningless.
    raise FileNotFoundError(f"No .jpg test images found under {join(raw_path, 'tests')}")
# NOTE(review): 32 is presumably the batch size — confirm against test_pipeline.
test_ds = test_pipeline(images_path, 32)

2023-12-25 23:29:42.504270: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2023-12-25 23:29:42.504296: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 32.00 GB
2023-12-25 23:29:42.504300: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 10.67 GB
2023-12-25 23:29:42.504333: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-12-25 23:29:42.504352: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [6]:
# Build the embedding network and restore its saved weights.
# NOTE(review): the arguments look like (image shape, watermark shape, strength)
# — confirm against create_watermark.
embedding = create_watermark((128, 128, 3), (8, 8, 1), 1.0)
# Indexing os.environ raises a KeyError naming MODELS_PATH when it is unset,
# instead of the opaque TypeError that join(None, ...) would produce.
embedding_weights = join(
    os.environ["MODELS_PATH"],
    "storage",
    "embedding.25_12_2023_20_03_40.weights.h5",
)
embedding.load_weights(embedding_weights)

In [7]:
# Build the extractor network and restore its saved weights.
# NOTE(review): (128, 128, 3) is presumably the input image shape — confirm
# against create_extract_mark.
extractor = create_extract_mark((128, 128, 3))
# Indexing os.environ raises a KeyError naming MODELS_PATH when it is unset,
# instead of the opaque TypeError that join(None, ...) would produce.
extractor_weights = join(
    os.environ["MODELS_PATH"],
    "storage",
    "extractor.25_12_2023_20_03_40.weights.h5",
)
extractor.load_weights(extractor_weights)

---

In [8]:
# Wrap the embedding and extractor networks in a single WatermarkModel,
# which is the object evaluated against each attack below.
models = WatermarkModel(
    embedding=embedding,
    extractor=extractor,
)

In [9]:
# Evaluate the model once per entry of `attacks`, with a progress bar.
# `results` maps each attack to [PSNR, BER] stored as plain Python floats.
results = dict()
for attack in track(attacks):
    psnr, ber = models.evaluate(test_ds, attack)
    scores = [float(metric) for metric in (psnr, ber)]
    results[attack] = scores

In [10]:
# Render the collected metrics as a table: one row per attack, in the
# order given by `attacks`, with the attack followed by its PSNR and BER.
headers = ["Attack", "PSNR", "BER"]
content = []
for attack_name in attacks:
    content.append([attack_name, *results[attack_name]])
print_tables("Result", headers, content)