#!/usr/bin/env python3
"""
WebSocket benchmark with SHA256 verification.

1. Generate a self-signed certificate in the working directory (once):
   openssl req -x509 -newkey rsa:2048 -nodes -keyout localhost.key -out localhost.crt -days 365 -subj "/CN=localhost"

2. Generate the test data file and its SHA256 hash (once per TOTAL_GB setting):
   python benchmark.py generate

3. Run the server in one terminal:
   python benchmark.py server

4. Run the client tests in another terminal:
   python benchmark.py client --test download
   python benchmark.py client --test upload
   # (Comparison) Test aiohttp upload speed
   python benchmark.py client --test upload_aiohttp
"""
import argparse
import asyncio
import hashlib
import mmap
import os
import ssl
import sys
import time
from collections.abc import AsyncGenerator
from ssl import SSLContext

import uvloop
from aiohttp import web
from curl_cffi import AsyncSession, CurlOpt

# --- Benchmark configuration ---
# Payload size in GiB (units of 1024**3 bytes); annotated as float, so it may
# be fractional.
TOTAL_GB: float = 10
# Bytes per WebSocket message, used by the server sender and the upload clients.
CHUNK_SIZE: int = 64 * 1024
# Payload file written by `generate` and loaded by the server / upload clients.
DATA_FILENAME: str = "testdata.bin"
# File holding the hex SHA256 digest of DATA_FILENAME.
HASH_FILENAME: str = "testdata.hash"
# Payload size in bytes, derived from TOTAL_GB.
TOTAL_BYTES: int = int(TOTAL_GB * (1 << 30))


def generate_files():
    """Write DATA_FILENAME (exactly TOTAL_BYTES bytes) and its SHA256 digest.

    The payload is one random 1 MiB block repeated. Any tail shorter than
    1 MiB is filled with a prefix of that block, so the file is exactly
    TOTAL_BYTES long. The hex digest of the payload goes to HASH_FILENAME.

    Generation is skipped only when the data file already has the expected
    size AND the hash file exists. On any error the process exits with
    status 1.
    """
    print(f"--- Generating {TOTAL_GB:.2f} GB test file: '{DATA_FILENAME}' ---")
    if (
        os.path.exists(DATA_FILENAME)
        and os.path.getsize(DATA_FILENAME) == TOTAL_BYTES
        and os.path.exists(HASH_FILENAME)
    ):
        print("File already exists. Skipping generation.")
        return
    start_time = time.perf_counter()
    hasher = hashlib.sha256()
    megabyte_chunk = os.urandom(1024 * 1024)
    full_chunks, remainder = divmod(TOTAL_BYTES, len(megabyte_chunk))
    try:
        # Remove any previous digest first: if this run fails part-way, no
        # stale hash is left behind to pair with a new or partial data file.
        try:
            os.remove(HASH_FILENAME)
        except FileNotFoundError:
            pass
        with open(DATA_FILENAME, "wb") as f:
            for i in range(full_chunks):
                f.write(megabyte_chunk)
                hasher.update(megabyte_chunk)
                if (i + 1) % 100 == 0:
                    # (i + 1) blocks have been written at this point.
                    print(
                        f"\rGenerating... {((i + 1) * len(megabyte_chunk)) / TOTAL_BYTES:.0%}",
                        end="",
                        flush=True,
                    )
            if remainder:
                # Write the sub-MiB tail. Without it the file would be shorter
                # than TOTAL_BYTES and the size check above would never match.
                tail = megabyte_chunk[:remainder]
                f.write(tail)
                hasher.update(tail)
        print("\rGeneration complete.     ")
        digest = hasher.hexdigest()
        with open(HASH_FILENAME, "w") as f:
            f.write(digest)
        print(
            f"Generated files in {time.perf_counter() - start_time:.2f}s. Hash: {digest}"
        )
    except Exception as e:
        # Top-level CLI boundary: report and exit non-zero.
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)


def load_test_data() -> bytes:
    """Read the whole of DATA_FILENAME into memory and return it as bytes.

    Exits the process with status 1 if the file is missing or empty.
    """
    print(f"--- Loading '{DATA_FILENAME}' into memory... ---")
    try:
        # A plain read() returns the same bytes as copying out of an mmap.
        # Unlike mmap.mmap(fd, 0), which raises ValueError on an empty file,
        # it does not crash on a zero-length file.
        with open(DATA_FILENAME, "rb") as f:
            data = f.read()
    except FileNotFoundError:
        print(
            f"Error: '{DATA_FILENAME}' not found. Run 'python benchmark.py generate' first.",
            file=sys.stderr,
        )
        sys.exit(1)
    if not data:
        print(
            f"Error: '{DATA_FILENAME}' is empty. Run 'python benchmark.py generate' first.",
            file=sys.stderr,
        )
        sys.exit(1)
    return data


def load_source_hash() -> str:
    """Return the expected SHA256 hex digest stored in HASH_FILENAME.

    Surrounding whitespace is stripped. If the file does not exist, an error
    is printed to stderr and the process exits with status 1.
    """
    try:
        hash_file = open(HASH_FILENAME, "r")
    except FileNotFoundError:
        message = (
            f"Error: '{HASH_FILENAME}' not found. Run 'python benchmark.py generate' first."
        )
        print(message, file=sys.stderr)
        sys.exit(1)
    with hash_file:
        contents = hash_file.read()
    return contents.strip()


async def server_ws_send_handler(ws: web.WebSocketResponse, test_data: bytes):
    """Stream ``test_data`` to the client in CHUNK_SIZE binary frames, then close.

    The socket is closed even if a send fails. A send error propagates to
    the caller, and in that case no summary line is printed.
    """
    print("Client connected for DOWNLOAD. Sending data...")
    start_time = time.perf_counter()
    try:
        data_view = memoryview(test_data)
        for offset in range(0, len(data_view), CHUNK_SIZE):
            # Slicing a memoryview avoids copying each chunk out of test_data.
            await ws.send_bytes(data_view[offset : offset + CHUNK_SIZE])
    finally:
        await ws.close()
    # Report the bytes actually sent, not the configured TOTAL_GB, so the
    # figure stays correct if the data file size differs from the config.
    print(
        f"Server sent {len(test_data) / (1024**3):.2f} GB in {time.perf_counter() - start_time:.2f}s. Client disconnected."
    )


async def server_ws_recv_handler(ws: web.WebSocketResponse, source_hash: str):
    """Hash every binary frame received from the client and check it.

    Receiving stops when the message iterator ends or an ERROR message
    arrives. Afterwards the amount received, the elapsed time and both
    digests are printed, followed by a success/failure verdict comparing the
    digest with ``source_hash``.
    """
    print("Client connected for UPLOAD. Awaiting data...")
    hasher, bytes_received, start_time = hashlib.sha256(), 0, time.perf_counter()
    async for msg in ws:
        if msg.type == web.WSMsgType.BINARY:
            hasher.update(msg.data)
            bytes_received += len(msg.data)
        elif msg.type == web.WSMsgType.ERROR:
            # Surface transport/protocol errors explicitly instead of letting
            # them show up only as an unexplained hash mismatch below.
            print(f"WebSocket error during upload: {ws.exception()}", file=sys.stderr)
            break
    duration = time.perf_counter() - start_time
    print("\n--- Upload Complete (Server-Side) ---")
    print(f"Received {bytes_received / (1024**3):.2f} GB in {duration:.2f}s.")
    received_hash = hasher.hexdigest()
    print(f"Source Hash:   {source_hash}")
    print(f"Received Hash: {received_hash}")
    if source_hash == received_hash:
        print("✅ Server-Side Hash Verification SUCCESSFUL")
    else:
        print("❌ Server-Side Hash Verification FAILED")
    print("Client disconnected.")


def run_server_blocking():
    """Serve the benchmark over WSS on 127.0.0.1:4443 (blocks in web.run_app).

    The payload and its expected hash are loaded once up front. Each /ws
    connection is dispatched on its ``test`` query parameter: "upload" is
    received and verified; anything else is served as a download.
    """
    test_data = load_test_data()
    source_hash = load_source_hash()

    ssl_context = SSLContext(ssl.PROTOCOL_TLS_SERVER)
    try:
        ssl_context.load_cert_chain("localhost.crt", "localhost.key")
    except FileNotFoundError:
        cert_error = "Error: Certificate files not found. Run `openssl` command."
        print(cert_error, file=sys.stderr)
        sys.exit(1)

    async def main_handler(request: web.Request) -> web.WebSocketResponse:
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        if request.query.get("test") != "upload":
            await server_ws_send_handler(ws, test_data)
        else:
            await server_ws_recv_handler(ws, source_hash)
        return ws

    app = web.Application()
    app.router.add_get("/ws", main_handler)
    print("Starting smart server on wss://127.0.0.1:4443/ws")
    # Suppress aiohttp's own startup banner.
    web.run_app(
        app,
        host="127.0.0.1",
        port=4443,
        ssl_context=ssl_context,
        print=lambda _: None,
    )


async def data_chunk_iterator(
    data: bytes, chunk_size: "int | None" = None
) -> AsyncGenerator[memoryview, None]:
    """Asynchronously yield consecutive zero-copy slices of ``data``.

    Args:
        data: Any buffer-protocol object (e.g. bytes) to split.
        chunk_size: Slice length in bytes; defaults to CHUNK_SIZE. The final
            slice may be shorter.

    Yields:
        memoryview slices over ``data``. They are views, not copies, and
        share ``data``'s lifetime.

    Raises:
        ValueError: If ``chunk_size`` is not positive. The error is raised on
            first iteration, since this is a generator.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    view = memoryview(data)
    for start in range(0, len(view), chunk_size):
        yield view[start : start + chunk_size]


async def run_client_download(source_hash: str):
    """Download the payload from the server with curl_cffi and verify it.

    Received frames are buffered in memory. Download time and verification
    (hashing) time are measured and reported separately. The SHA256 of the
    received bytes is then compared against ``source_hash``.
    """
    print("--- Starting curl-cffi Benchmark (DOWNLOAD) ---")
    print(f"Expecting Hash: {source_hash}")
    received_chunks: list[bytes] = []
    try:
        # verify=False: the server uses a self-signed certificate.
        async with AsyncSession(impersonate="chrome", verify=False) as session:
            ws = await session.ws_connect("wss://127.0.0.1:4443/ws?test=download")
            print("Receiving data into memory...")
            start_time = time.perf_counter()
            async for msg in ws:
                received_chunks.append(msg)
            dl_duration = time.perf_counter() - start_time
        print("\nVerifying data...")
        start_time = time.perf_counter()
        # Hash chunk by chunk instead of hashing b"".join(received_chunks).
        # The digest is identical, but no second full-size copy of the
        # payload is allocated, so peak memory does not double.
        hasher = hashlib.sha256()
        bytes_received = 0
        for chunk in received_chunks:
            hasher.update(chunk)
            bytes_received += len(chunk)
        received_chunks.clear()  # release the buffered payload early
        received_hash = hasher.hexdigest()
        verify_duration = time.perf_counter() - start_time
        rate = (bytes_received / (1024**3) * 8) / dl_duration if dl_duration > 0 else 0
        print("\n--- Results ---")
        print(f"Received {bytes_received/(1024**3):.2f} GB.")
        print(
            f"  - Download Time: {dl_duration:.2f}s | Verification Time: {verify_duration:.2f}s"
        )
        print(f"Average Download Throughput: {rate:.2f} Gbps.")
        print(f"\nSource: {source_hash}\nActual: {received_hash}")
        if source_hash == received_hash:
            print("✅ Hash Verification SUCCESSFUL")
        else:
            print("❌ Hash Verification FAILED")
    finally:
        print("--- curl-cffi Benchmark Complete ---")


async def run_client_upload() -> None:
    """Upload the payload to the server with curl_cffi and report throughput.

    This client only times the send. The hash check for uploads happens in
    the server's upload handler.
    """
    print("--- Starting curl-cffi Benchmark (UPLOAD) ---")
    test_data = load_test_data()
    # Use the real payload size for reporting. The data file is not
    # guaranteed to match the TOTAL_GB / TOTAL_BYTES configuration.
    total_bytes = len(test_data)
    try:
        # verify=False: the server uses a self-signed certificate.
        async with AsyncSession(impersonate="chrome", verify=False) as session:
            ws = await session.ws_connect("wss://127.0.0.1:4443/ws?test=upload")
            # NOTE(review): `_curl` is a private curl_cffi attribute, and this
            # option is set after the connection is established. Confirm it
            # actually takes effect.
            ws._curl.setopt(CurlOpt.TCP_NODELAY, 1)
            print("Sending data from memory...")
            start_time = time.perf_counter()
            async for chunk in data_chunk_iterator(test_data):
                await ws.send(chunk)
            await ws.flush()
            await ws.close()
            duration = time.perf_counter() - start_time
        rate = (total_bytes / (1024**3) * 8) / duration if duration > 0 else 0
        print("\n--- Results ---")
        print(f"Sent {total_bytes / (1024**3):.2f} GB in {duration:.2f} seconds.")
        print(f"Average Upload Throughput: {rate:.2f} Gbps.")
    finally:
        del test_data
        print("--- curl-cffi Benchmark Complete ---")


async def run_client_upload_aiohttp():
    """Upload the payload with aiohttp's WebSocket client, for comparison.

    This client only times the send. The hash check for uploads happens in
    the server's upload handler.
    """
    import aiohttp

    print("--- Starting aiohttp Benchmark (UPLOAD) ---")
    test_data = load_test_data()
    # Use the real payload size for reporting. The data file is not
    # guaranteed to match the TOTAL_GB / TOTAL_BYTES configuration.
    total_bytes = len(test_data)
    try:
        async with aiohttp.ClientSession() as session:
            # ssl=False skips certificate verification (self-signed cert).
            async with session.ws_connect(
                "wss://127.0.0.1:4443/ws?test=upload", ssl=False
            ) as ws:
                print("Sending data from memory...")
                start_time = time.perf_counter()
                async for chunk in data_chunk_iterator(test_data):
                    await ws.send_bytes(chunk)
                await ws.close()
                duration = time.perf_counter() - start_time
        rate = (total_bytes / (1024**3) * 8) / duration if duration > 0 else 0
        print("\n--- Results ---")
        print(f"Sent {total_bytes / (1024**3):.2f} GB in {duration:.2f} seconds.")
        print(f"Average Upload Throughput: {rate:.2f} Gbps.")
    finally:
        del test_data
        print("--- aiohttp Benchmark Complete ---")


def main():
    """Parse the command line and run the selected mode.

    ``generate`` writes the payload and its hash. ``server`` and ``client``
    first install uvloop, then run the server or the client test chosen by
    ``--test``.
    """
    parser = argparse.ArgumentParser(description="WebSocket Unidirectional Benchmark")
    parser.add_argument(
        "mode",
        choices=["generate", "server", "client"],
        help="Operation",
    )
    parser.add_argument(
        "--test",
        choices=["download", "upload", "upload_aiohttp"],
        default="download",
    )
    args = parser.parse_args()

    if args.mode == "generate":
        generate_files()
        return

    # Both network modes run on uvloop.
    uvloop.install()
    if args.mode == "server":
        run_server_blocking()
        return

    # argparse restricts mode to its choices, so this is the "client" mode.
    client_runners = {
        "download": lambda: run_client_download(load_source_hash()),
        "upload": run_client_upload,
        "upload_aiohttp": run_client_upload_aiohttp,
    }
    asyncio.run(client_runners[args.test]())


if __name__ == "__main__":
    main()
