# In [None]:
import os
import sys
import time
import zipfile
import cv2
import numpy as np
import onnxruntime as ort
import torch
import requests
import tempfile
import subprocess
import json
from tqdm import tqdm
from typing import Tuple, List, Optional
import random

# Keep onnxruntime quiet: only messages at severity 3 or higher are emitted.
ort.set_default_logger_severity(3)

# --- Remote artifacts --------------------------------------------------------
# ONNX age model; main() stores it at DOWNLOAD_MODEL_PATH + MODEL_NAME.
MODEL_URL = "https://zkevm-4.s3.us-east-2.amazonaws.com/model"
# Zip of test images, unpacked and resized by download_and_process_images().
IMAGE_DATASETS_URL = "https://storage.omron.ai/age.zip"
# Each circuit's files are fetched from CIRCUIT_BASE_URL + "<name>/".
CIRCUIT_BASE_URL = "https://zkevm-4.s3.us-east-2.amazonaws.com/"
# Shorten this list (e.g. to the first two) to benchmark fewer circuits.
CIRCUIT_NAMES = ["circuit1", "circuit2", "circuit3"]

# --- Local layout ------------------------------------------------------------
MODEL_NAME = "age.onnx"
DOWNLOAD_MODEL_PATH = "./competition/model/"
DOWNLOAD_IMAGE_DATASETS = "./competition/test_data/"
DOWNLOAD_CIRCUIT_PATH = "./competition/circuit/"
TEMP_FOLDER = "./competition/tmp"

# ezkl CLI binary, resolved relative to the current working directory.
LOCAL_EZKL_PATH = "../../target/release/ezkl"

# Number of images sampled per circuit by benchmark().
TEST_COUNT = 10

def download_model(
    url: str, save_directory: str, model_name: str, timeout: float = 60.0
) -> None:
    """Download the model at *url* to ``save_directory/model_name``.

    The download is skipped if the target file already exists. Data is
    streamed into a ``.part`` file that is renamed into place only after the
    transfer completes. A failed or interrupted download therefore never
    leaves a truncated model that a later run would mistake for a finished one.

    Args:
        url: HTTP(S) URL of the model file.
        save_directory: Directory to save into; created if missing.
        model_name: File name to store the model under.
        timeout: Connect/read timeout in seconds, passed to ``requests.get``.

    Download errors (``requests`` exceptions, including non-2xx HTTP status)
    are printed and swallowed, matching the script's best-effort behavior.
    """
    save_path = os.path.join(save_directory, model_name)

    if os.path.exists(save_path):
        print(f"{save_path} already exists. Skipping download.")
        return

    os.makedirs(save_directory, exist_ok=True)
    partial_path = save_path + ".part"

    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as file, tqdm(
                desc=model_name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    bar.update(len(data))

        # Publish the file only once every byte has been written.
        os.replace(partial_path, save_path)
        print(f"Downloaded {model_name} to {save_path}")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {model_name}: {e}")
    finally:
        # After a successful os.replace this path no longer exists; otherwise
        # it is an incomplete download that must not linger.
        if os.path.exists(partial_path):
            os.remove(partial_path)

def _download_circuit_file(file_url: str, save_path: str, desc: str, timeout: float) -> None:
    """Stream *file_url* to *save_path* through a temporary ``.part`` file.

    Raises ``requests.exceptions.RequestException`` (including ``HTTPError``
    for non-2xx responses). No partial file is left behind on failure.
    """
    partial_path = save_path + ".part"
    try:
        with requests.get(file_url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page would be saved as e.g. pk.key.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as file, tqdm(
                desc=desc,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    bar.update(len(data))

        os.replace(partial_path, save_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_circuit_files(
    urls,
    base_directory,
    file_names=('kzg.srs', 'model.compiled', 'pk.key', 'settings.json', 'vk.key'),
    timeout: float = 60.0,
) -> list:
    """Ensure each circuit's files exist locally and return their folders.

    Each URL must end with ``/``. Its second-to-last path segment names the
    circuit folder created under *base_directory*. If every file in
    *file_names* is already present, the circuit is reused as-is. Otherwise
    the full set is (re)downloaded, so the files of one circuit always come
    from the same fetch.

    Args:
        urls: Circuit base URLs, e.g. ``"https://host/circuit1/"``.
        base_directory: Local root for the per-circuit folders.
        file_names: Files that make up one circuit.
        timeout: Connect/read timeout in seconds for each HTTP request.

    Returns:
        Folders of the circuits that are complete on disk, in *urls* order.
        A circuit whose download failed is reported and omitted, so it is
        never benchmarked with broken files.
    """
    download_circuits_path = []
    had_failure = False

    for url in urls:
        circuit_name = url.split('/')[-2]
        circuit_folder = os.path.join(base_directory, circuit_name)
        os.makedirs(circuit_folder, exist_ok=True)

        all_files_exist = True
        for file_name in file_names:
            save_path = os.path.join(circuit_folder, file_name)
            if os.path.exists(save_path):
                print(f"{save_path} already exists. Skipping download.")
            else:
                all_files_exist = False

        if all_files_exist:
            print(f"All files for {circuit_name} already exist. Skipping download.")
            download_circuits_path.append(circuit_folder)
            continue

        try:
            for file_name in file_names:
                save_path = os.path.join(circuit_folder, file_name)
                _download_circuit_file(f"{url}{file_name}", save_path, file_name, timeout)
                print(f"Downloaded {file_name} to {save_path}")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading files for {circuit_name}: {e}")
            had_failure = True
            continue

        download_circuits_path.append(circuit_folder)

    print(f"download_circuits_path:{download_circuits_path}")
    if had_failure:
        print("Some circuit downloads failed; those circuits were skipped.")
    else:
        print("All files are downloaded successfully!")

    return download_circuits_path

def download_and_process_images(
    download_image_dataset, image_datasets_url, timeout: float = 60.0
) -> str:
    """Download, unpack and resize the test image dataset; return the output dir.

    Layout created under *download_image_dataset*:
      - ``age.zip``: the downloaded archive (reused if it is already a valid zip)
      - ``extracted/``: the archive's contents
      - ``processed_64x64/``: every .png/.jpg/.jpeg found, resized to 64x64

    If ``processed_64x64/`` already contains anything, all work is skipped.

    Args:
        download_image_dataset: Root directory for the dataset files.
        image_datasets_url: URL of the zip archive.
        timeout: Connect/read timeout in seconds for the download.

    Returns:
        Path of the ``processed_64x64`` directory.

    Raises:
        requests.exceptions.RequestException: if the download fails,
            including non-2xx HTTP status.
        zipfile.BadZipFile: if the archive cannot be read.
    """
    if os.path.exists(download_image_dataset):
        print(f"Test datasets dir {download_image_dataset} already exists!")
    else:
        print(f"Test datasets dir {download_image_dataset} does not exist, creating it.")
        os.makedirs(download_image_dataset)

    zip_path = os.path.join(download_image_dataset, "age.zip")
    extracted_path = os.path.join(download_image_dataset, "extracted")
    processed_path = os.path.join(download_image_dataset, "processed_64x64")

    os.makedirs(extracted_path, exist_ok=True)
    os.makedirs(processed_path, exist_ok=True)

    if os.listdir(processed_path):
        print(f"{processed_path} already exists and is not empty. Skipping processing.")
        return processed_path

    if zipfile.is_zipfile(zip_path):
        print(f"Reusing existing archive {zip_path}")
    else:
        print("Downloading dataset...")
        partial_path = zip_path + ".part"
        try:
            with requests.get(image_datasets_url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))

                with open(partial_path, "wb") as f, tqdm(
                    desc="Downloading", total=total_size, unit="iB", unit_scale=True
                ) as pbar:
                    for data in response.iter_content(chunk_size=1024):
                        pbar.update(f.write(data))

            # Only a complete download becomes age.zip.
            os.replace(partial_path, zip_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    print("Extracting zip...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    print("Processing images to 64x64...")
    for root, _, files in tqdm(os.walk(extracted_path)):
        for img_name in files:
            if not img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                continue
            img_path = os.path.join(root, img_name)
            # NOTE(review): output is flattened into one directory, so files
            # with the same name in different sub-folders overwrite each other.
            out_path = os.path.join(processed_path, img_name)
            try:
                img = cv2.imread(img_path)
                if img is None:
                    # Unreadable/undecodable file: skipped, as before.
                    continue
                img = cv2.resize(img, (64, 64))
                # cv2.imwrite reports failure via its return value, not an exception.
                if not cv2.imwrite(out_path, img):
                    print(f"Failed to process {img_name}: could not write {out_path}")
            except Exception as e:
                print(f"Failed to process {img_name}: {e}")

    print(f"Images processed and saved to {processed_path}")
    return processed_path

class ImageProcessor:
    """Small tensor-conversion helpers used to prepare model inputs."""

    @staticmethod
    def normalize(
        img: torch.Tensor,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ) -> torch.Tensor:
        """Return ``(img - mean) / std``, applied per channel.

        *mean* and *std* are reshaped to ``(C, 1, 1)`` so they broadcast over
        the trailing H and W axes. *img* is therefore expected to end in
        ``(C, H, W)``. The statistics are created on *img*'s device so GPU
        tensors also work. Defaults are immutable tuples; the values are the
        same as before.
        """
        mean_t = torch.as_tensor(mean, device=img.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, device=img.device).view(-1, 1, 1)
        return (img - mean_t) / std_t

    @staticmethod
    def to_tensor(img: np.ndarray) -> torch.Tensor:
        """Convert an ``(H, W, C)`` array to a float32 ``(C, H, W)`` tensor divided by 255.

        Assumes 8-bit pixel values, so the result lies in [0, 1].
        """
        chw = img.transpose((2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(chw).float() / 255.0

def preprocess_image(img_path) -> Optional[torch.Tensor]:
    """Load an image file and turn it into a normalized batch of one.

    Steps:
      1. Read with OpenCV (BGR channel order).
      2. Convert to RGB.
      3. Transpose to CHW and divide by 255 via ``ImageProcessor.to_tensor``.
      4. Apply ``ImageProcessor.normalize`` with its default mean/std.
      5. Prepend a batch axis, giving shape ``(1, C, H, W)``.

    No resizing is done here; the image's own H and W are kept.

    Returns:
        The tensor, or None (after printing a message) if the file cannot
        be read or any step raises.
    """
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for both missing and undecodable files.
            print(f"    Error: could not read image: {img_path}")
            return None

        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = ImageProcessor.normalize(ImageProcessor.to_tensor(image))

        # Add batch dimension.
        return tensor.unsqueeze(0)
    except Exception as e:
        print(f"    Error processing image data: {e}")
        return None

# onnxruntime sessions keyed by model path, so benchmarking N images loads
# the model once instead of N times. A model file replaced on disk during
# the same process is not reloaded.
_ONNX_SESSIONS: dict = {}


def inference_onnx_model(
    model_path: str, input_tensor: torch.Tensor
) -> Optional[List[np.ndarray]]:
    """Run the ONNX model at *model_path* on *input_tensor*.

    The tensor is fed to the model's first input. Every model output is
    requested, in the order reported by ``session.get_outputs()``.

    Returns:
        The list returned by ``InferenceSession.run`` (one entry per model
        output). Returns None, after printing a message, if loading or
        running the model raises; this includes *input_tensor* being None.
    """
    try:
        session = _ONNX_SESSIONS.get(model_path)
        if session is None:
            session = ort.InferenceSession(model_path)
            _ONNX_SESSIONS[model_path] = session

        input_name = session.get_inputs()[0].name

        # detach() is a no-op for tensors without grad, so it is always safe.
        input_data = input_tensor.detach().cpu().numpy()

        options = ort.RunOptions()
        options.log_severity_level = 3

        output_names = [output.name for output in session.get_outputs()]
        return session.run(output_names, {input_name: input_data}, options)

    except Exception as e:
        print(f"Error during ONNX inference: {e}")
        return None

def get_temp_folder() -> str:
    """Return TEMP_FOLDER, creating the directory first if the path does not exist yet."""
    folder = TEMP_FOLDER
    if os.path.exists(folder):
        return folder
    os.makedirs(folder, exist_ok=True)
    return folder

def generate_proof(
    circuit_dir: str, test_inputs: torch.Tensor
) -> Optional[Tuple[str, dict, float]]:
    """Generate an ezkl proof for *test_inputs* with the circuit in *circuit_dir*.

    Steps:
      1. Write the flattened inputs as ``{"input_data": [[...]]}`` JSON.
      2. Run ``ezkl gen-witness`` against ``model.compiled``.
      3. Run ``ezkl prove`` with ``pk.key``. Each command has a 300 s timeout.

    Returns:
        ``(proof_path, proof_data, proof_time)``, where *proof_time* is the
        wall-clock seconds spent in the ``prove`` step only. The caller owns
        the proof file at *proof_path*. Returns None, after printing
        diagnostics, on any failure.

    The temporary input and witness files are always deleted. The proof file
    is deleted too unless it is being returned.
    """
    model_path = os.path.join(circuit_dir, "model.compiled")
    if not os.path.exists(model_path):
        print(f"model.compiled not found at {model_path}")
        return None

    temp_input_path = witness_path = temp_proof_path = None
    succeeded = False
    try:
        input_data = {
            "input_data": [[float(x) for x in test_inputs.flatten().tolist()]]
        }
        temp_dir = get_temp_folder()

        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=temp_dir, delete=False
        ) as temp_input:
            # Record the path before writing so a failed dump is still cleaned up.
            temp_input_path = temp_input.name
            json.dump(input_data, temp_input, indent=2)

        # Reserve unique names for files that ezkl will write.
        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=temp_dir, delete=False
        ) as temp_witness:
            witness_path = temp_witness.name

        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=temp_dir, delete=False
        ) as temp_proof:
            temp_proof_path = temp_proof.name

        witness_result = subprocess.run(
            [
                LOCAL_EZKL_PATH,
                "gen-witness",
                "--data",
                temp_input_path,
                "--compiled-circuit",
                model_path,
                "--output",
                witness_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )

        if witness_result.returncode != 0:
            print(
                f"Witness generation failed with code {witness_result.returncode}"
            )
            print(f"STDOUT: {witness_result.stdout}")
            print(f"STDERR: {witness_result.stderr}")
            return None

        print("Witness generation successful, starting proof generation")
        proof_start = time.perf_counter()
        prove_result = subprocess.run(
            [
                LOCAL_EZKL_PATH,
                "prove",
                "--compiled-circuit",
                model_path,
                "--witness",
                witness_path,
                "--pk-path",
                os.path.join(circuit_dir, "pk.key"),
                "--proof-path",
                temp_proof_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        proof_time = time.perf_counter() - proof_start

        if prove_result.returncode != 0:
            print(
                f"Proof generation failed with code {prove_result.returncode}"
            )
            print(f"STDOUT: {prove_result.stdout}")
            print(f"STDERR: {prove_result.stderr}")
            return None

        with open(temp_proof_path) as f:
            proof_data = json.load(f)
        print(f"Proof timing - Proof: {proof_time:.3f}s")
        succeeded = True
        return temp_proof_path, proof_data, proof_time
    except Exception as e:
        print(f"Error generating proof: {e}")
        return None
    finally:
        for path in (temp_input_path, witness_path):
            if path and os.path.exists(path):
                os.unlink(path)
        # On success the proof file is handed to the caller; otherwise remove it.
        if not succeeded and temp_proof_path and os.path.exists(temp_proof_path):
            os.unlink(temp_proof_path)

def verify_proof(circuit_dir: str, proof_path: str) -> bool:
    """Check *proof_path* with ``ezkl verify`` using *circuit_dir*'s settings and vk.

    Returns True exactly when the ezkl process exits with status 0. Any
    exception, including a 300 s ``subprocess`` timeout, is printed and
    reported as False. In every case the proof file is deleted afterwards,
    if it still exists.
    """
    try:
        command = [
            LOCAL_EZKL_PATH, "verify",
            "--proof-path", proof_path,
            "--settings-path", os.path.join(circuit_dir, "settings.json"),
            "--vk-path", os.path.join(circuit_dir, "vk.key"),
        ]
        completed = subprocess.run(
            command, capture_output=True, text=True, timeout=300
        )
        passed = completed.returncode == 0
        return passed
    except Exception as exc:
        print(f"Error verifying proof: {exc}")
        return False
    finally:
        proof_still_there = os.path.exists(proof_path)
        if proof_still_there:
            os.unlink(proof_path)

def compare_outputs(expected: list, actual: list) -> float:
    """Score how closely *actual* matches *expected* as ``exp(-MAE)``.

    Both inputs are flattened and compared element by element in float64.
    The score is 1.0 for an exact match and decays toward 0 as the mean
    absolute error grows.

    Returns 0.0, with a printed message, when:
      - either side is empty,
      - the element counts differ (previously these were silently
        broadcast, giving a meaningless score),
      - the values cannot be converted to numbers, or
      - the score is not finite.
    """
    try:
        expected_flat = torch.as_tensor(np.asarray(expected, dtype=np.float64)).flatten()
        actual_flat = torch.as_tensor(np.asarray(actual, dtype=np.float64)).flatten()
    except Exception as e:
        print(f"Could not compare outputs: {e}")
        return 0.0

    if expected_flat.numel() == 0 or expected_flat.numel() != actual_flat.numel():
        print(
            f"Output size mismatch: expected {expected_flat.numel()} values, "
            f"got {actual_flat.numel()}"
        )
        return 0.0

    mae = (actual_flat - expected_flat).abs().mean()
    raw_accuracy = torch.exp(-mae).item()
    # NaN in either input would otherwise propagate into the averaged score.
    if not np.isfinite(raw_accuracy):
        print("Non-finite accuracy score; treating as 0.0")
        return 0.0
    return raw_accuracy

def _flatten_outputs(outputs) -> np.ndarray:
    """Concatenate every model output, each flattened, into one 1-D array (input order kept)."""
    flattened = []
    for out in outputs:
        flattened.extend(out.flatten())
    return np.array(flattened)


def _rescaled_outputs(proof_data: dict) -> List[float]:
    """Flatten ``pretty_public_inputs.rescaled_outputs`` of a proof JSON to floats ([] if absent)."""
    return [
        float(x)
        for sublist in proof_data.get("pretty_public_inputs", {}).get(
            "rescaled_outputs", []
        )
        for x in sublist
    ]


def benchmark(
    onnx_model_path, processed_path, circuit_dir, test_count
) -> Tuple[float, float, bool, float]:
    """Benchmark one circuit on a random sample of images.

    Up to *test_count* images are sampled from *processed_path*. For each one:
      1. Run the ONNX model to get a baseline output.
      2. Generate an ezkl proof with *circuit_dir*.
      3. Verify the proof.
      4. Score the proof's rescaled outputs against the baseline.

    All sampled images are processed before the results are aggregated.

    Returns:
        ``(avg_proof_size, avg_response_time, all_verified, avg_raw_accuracy)``.
        Returns ``(inf, inf, False, 0.0)`` in either of these cases:
          - no image is selected;
          - any image fails preprocessing, inference, proving or
            verification.
    """
    inf = float("inf")
    failure = (inf, inf, False, 0.0)

    image_files = [
        f for f in os.listdir(processed_path)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
    print(f"Total images found: {len(image_files)}")

    selected_images = random.sample(image_files, min(test_count, len(image_files)))
    print(f"Selected images: {selected_images}")
    if not selected_images:
        print("No images selected - nothing to benchmark")
        return failure

    raw_accuracy_scores: List[float] = []
    proof_sizes: List[float] = []
    response_times: List[float] = []
    verification_results: List[bool] = []

    def record_failure() -> None:
        # Keep all four lists the same length: one entry per image.
        raw_accuracy_scores.append(0.0)
        verification_results.append(False)
        proof_sizes.append(inf)
        response_times.append(inf)

    for img_name in selected_images:
        image_path = os.path.join(processed_path, img_name)
        input_tensor = preprocess_image(image_path)
        if input_tensor is None:
            print(f"Preprocessing failed for {img_name}")
            record_failure()
            continue

        model_outputs = inference_onnx_model(onnx_model_path, input_tensor)
        if model_outputs is None:
            print(f"Baseline inference failed for {img_name}")
            record_failure()
            continue
        baseline_output = _flatten_outputs(model_outputs)

        proof_result = generate_proof(circuit_dir, input_tensor)
        if not proof_result:
            print("Proof generation failed")
            record_failure()
            continue

        proof_path, proof_data, response_time = proof_result
        response_times.append(response_time)
        proof_sizes.append(len(proof_data.get("proof", [])))

        verify_result = verify_proof(circuit_dir, proof_path)
        verification_results.append(verify_result)

        if verify_result:
            raw_accuracy_scores.append(
                compare_outputs(baseline_output, _rescaled_outputs(proof_data))
            )
        else:
            print("Proof verification failed")
            raw_accuracy_scores.append(0.0)

    # Aggregate only after the loop. Previously this sat inside the loop and
    # returned after the first image.
    if not all(verification_results):
        print(
            "One or more verifications failed - setting all scores to 0"
        )
        return failure

    count = len(verification_results)
    return (
        sum(proof_sizes) / count,
        sum(response_times) / count,
        True,
        sum(raw_accuracy_scores) / count,
    )

def main():
    """Fetch the model, images and circuits, then benchmark each circuit and print averages."""
    print("** Age Verification Optimize Test **")

    print("** Downloading model... **")
    model_path = f"{DOWNLOAD_MODEL_PATH}{MODEL_NAME}"
    download_model(MODEL_URL, DOWNLOAD_MODEL_PATH, MODEL_NAME)

    print("** Downloading test image datasets... **")
    processed_path = download_and_process_images(
        DOWNLOAD_IMAGE_DATASETS, IMAGE_DATASETS_URL
    )

    print("** Downloading circuit files... **")
    circuit_urls = [CIRCUIT_BASE_URL + name + "/" for name in CIRCUIT_NAMES]
    circuit_dirs = download_circuit_files(circuit_urls, DOWNLOAD_CIRCUIT_PATH)

    print("** Start to benchmark... **")
    for circuit_dir in circuit_dirs:
        print(f"** Running: {circuit_dir}... **")
        size, latency, verified, accuracy = benchmark(
            model_path, processed_path, circuit_dir, TEST_COUNT
        )
        report_lines = [
            f"** {circuit_dir} Result: ",
            f"      - avg_proof_size: {size}",
            f"      - avg_response_time: {latency}",
            f"      - verification_results: {verified}",
            f"      - avg_raw_accuracy: {accuracy}",
        ]
        print("\n".join(report_lines))
    print("** Benchmark Completed... **")

if __name__ == "__main__":
    main()
