#!/usr/bin/env python3
"""
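WebSocket server example over TLS (WSS).

Before running, generate the self-signed certificate/key pair that main()
loads (by relative path, i.e. from the current working directory) with:

    openssl req -x509 -newkey rsa:2048 -nodes -keyout localhost.key -out localhost.crt -days 365 -subj "/CN=localhost"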
"""

import os
import ssl  # Import the ssl module
from asyncio import TaskGroup, sleep
from collections.abc import AsyncGenerator
from ssl import SSLContext

from aiohttp import WebSocketError, WSMessageTypeError, web
from uvloop import new_event_loop

# Amount of random data streamed to each connected client, in GiB
# (binary_data_generator scales it by 1024**3 to get bytes).
TOTAL_GB = 10
# Size of each binary WebSocket frame payload, in bytes (64 KiB).
CHUNK_SIZE = 64 * 1024


async def binary_data_generator(total_gb: float, chunk_size: int) -> AsyncGenerator[bytes, None]:
    """
    An asynchronous generator that yields chunks of random binary data.

    Args:
        total_gb: The total amount of data to generate, in GiB. It is
            multiplied by 1024**3 and truncated to whole bytes; zero or a
            negative value yields nothing.
        chunk_size: The maximum size of each data chunk in bytes. Every chunk
            is exactly this long except possibly the last one.

    Yields:
        Chunks of random bytes produced by ``os.urandom``.

    Raises:
        ValueError: On the first iteration, if ``chunk_size`` is not positive.
            A zero size would otherwise yield empty chunks forever, because
            the byte counter could never advance.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    bytes_to_send: int = int(total_gb * 1024**3)
    bytes_sent = 0

    while bytes_sent < bytes_to_send:
        # The last chunk is trimmed so exactly bytes_to_send bytes are produced.
        chunk: bytes = os.urandom(min(chunk_size, bytes_to_send - bytes_sent))
        yield chunk
        bytes_sent += len(chunk)
        # Yield to the event loop once per chunk so other tasks get scheduled
        # between chunks.
        await sleep(0)


async def recv(ws: web.WebSocketResponse) -> None:
    """Read and discard incoming messages from *ws*.

    This server never acts on client input, so BINARY and any other
    non-error messages are simply ignored. The loop runs until iteration over
    *ws* stops, or until an ERROR message arrives. In the ERROR case the
    underlying exception is printed instead of being silently dropped, and
    the function returns.
    """
    async for msg in ws:
        if msg.type == web.WSMsgType.ERROR:
            print(f"WebSocket receive error: {ws.exception()}")
            break
        # Any other message type is read and discarded.


async def send(ws: web.WebSocketResponse) -> None:
    """Stream TOTAL_GB of random data to *ws* as CHUNK_SIZE binary frames.

    A ConnectionError raised while sending (e.g. the peer went away) is
    printed and swallowed, so this coroutine returns normally instead of
    raising.
    """
    payloads = binary_data_generator(TOTAL_GB, CHUNK_SIZE)
    try:
        async for payload in payloads:
            await ws.send_bytes(payload)
    except ConnectionError as err:
        print(err)


async def ws_handler(request: web.Request) -> web.WebSocketResponse:
    """Example server that just sends binary messages continuously"""
    ws: web.WebSocketResponse = web.WebSocketResponse()
    _ = await ws.prepare(request)
    print("Secure client connected.")

    try:
        async with TaskGroup() as tg:
            _ = tg.create_task(recv(ws))
            _ = tg.create_task(send(ws))

    except* (WebSocketError, WSMessageTypeError, ConnectionError) as exc:
        print(f"Connection closed with exception group: {exc}")

    print("Client disconnected.")
    return ws


def main() -> None:
    """Serve the /ws WebSocket endpoint over TLS on 127.0.0.1:4443.

    The certificate and key are loaded from ``localhost.crt`` and
    ``localhost.key``, both relative to the current working directory. The
    app runs on a uvloop event loop.
    """
    host, port = "127.0.0.1", 4443

    context: SSLContext = SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain("localhost.crt", "localhost.key")

    application: web.Application = web.Application()
    _ = application.add_routes(routes=[web.get("/ws", ws_handler)])

    # Host and port feed both the banner and run_app, so they cannot drift apart.
    print(f"Starting server on wss://{host}:{port}/ws")
    web.run_app(
        application,
        host=host,
        port=port,
        loop=new_event_loop(),
        ssl_context=context,
    )


if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    main()
