Support restarting workers after max requests
This is useful as a "solution" to memory leaks in apps, as it ensures
that once the maximum number of requests has been handled the worker
restarts, freeing any leaked memory.

The options match those used by Gunicorn.

This also ensures that the workers self-heal: if a worker crashes it
will be restarted.
pgjones committed Jan 1, 2024
1 parent c0468e5 commit 7c39c68
Showing 21 changed files with 163 additions and 68 deletions.
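As a rough sketch of how the new options might be set programmatically (the application path and the numbers here are made up for illustration; on the command line the equivalent flags are --max-requests and --max-requests-jitter, added in src/hypercorn/__main__.py below):

from hypercorn.config import Config

config = Config()
config.application_path = "myapp:app"  # hypothetical application
config.max_requests = 1000             # recycle each worker after roughly 1000 requests
config.max_requests_jitter = 100       # stagger restarts by randint(0, 100)

Because each worker draws its own jittered limit, a pool of workers does not recycle all at once.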
17 changes: 17 additions & 0 deletions src/hypercorn/__main__.py
@@ -89,6 +89,19 @@ def main(sys_args: Optional[List[str]] = None) -> int:
default=sentinel,
type=int,
)
parser.add_argument(
"--max-requests",
help="""Maximum number of requests a worker will process before restarting""",
default=sentinel,
type=int,
)
parser.add_argument(
"--max-requests-jitter",
help="This jitter causes the max-requests per worker to be "
"randomized by randint(0, max_requests_jitter)",
default=sentinel,
type=int,
)
parser.add_argument(
"-g", "--group", help="Group to own any unix sockets.", default=sentinel, type=int
)
@@ -252,6 +265,10 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode:
config.keyfile_password = args.keyfile_password
if args.log_config is not sentinel:
config.logconfig = args.log_config
if args.max_requests is not sentinel:
config.max_requests = args.max_requests
if args.max_requests_jitter is not sentinel:
config.max_requests_jitter = args.max_requests_jitter
if args.pid is not sentinel:
config.pid_path = args.pid
if args.root_path is not sentinel:
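The sentinel default used for these flags is what lets the CLI distinguish "flag not given" from any real value, so only explicitly passed options override the Config. A standalone sketch of the pattern (object() stands in for Hypercorn's own sentinel):

import argparse

sentinel = object()  # stand-in for Hypercorn's sentinel default

parser = argparse.ArgumentParser()
parser.add_argument("--max-requests", default=sentinel, type=int)
args = parser.parse_args(["--max-requests", "500"])

max_requests = None  # mirrors Config.max_requests defaulting to None
if args.max_requests is not sentinel:
    # Only override when the flag was actually supplied on the command line.
    max_requests = args.max_requests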
23 changes: 21 additions & 2 deletions src/hypercorn/asyncio/run.py
@@ -4,9 +4,11 @@
import platform
import signal
import ssl
import sys
from functools import partial
from multiprocessing.synchronize import Event as EventType
from os import getpid
from random import randint
from socket import socket
from typing import Any, Awaitable, Callable, Optional, Set

@@ -30,6 +32,14 @@
except ImportError:
from taskgroup import Runner # type: ignore

try:
from asyncio import TaskGroup
except ImportError:
from taskgroup import TaskGroup # type: ignore

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup


def _share_socket(sock: socket) -> socket:
# Windows requires the socket be explicitly shared across
@@ -84,7 +94,10 @@ def _signal_handler(*_: Any) -> None: # noqa: N803
ssl_context = config.create_ssl_context()
ssl_handshake_timeout = config.ssl_handshake_timeout

context = WorkerContext()
max_requests = None
if config.max_requests is not None:
max_requests = config.max_requests + randint(0, config.max_requests_jitter)
context = WorkerContext(max_requests)
server_tasks: Set[asyncio.Task] = set()

async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
@@ -136,7 +149,13 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
await config.log.info(f"Running on https://{bind} (QUIC) (CTRL + C to quit)")

try:
await raise_shutdown(shutdown_trigger)
async with TaskGroup() as task_group:
task_group.create_task(raise_shutdown(shutdown_trigger))
task_group.create_task(raise_shutdown(context.terminate.wait))
except BaseExceptionGroup as error:
_, other_errors = error.split((ShutdownError, KeyboardInterrupt))
if other_errors is not None:
raise other_errors
except (ShutdownError, KeyboardInterrupt):
pass
finally:
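The shutdown handling above waits on two triggers at once and swallows only the expected exceptions. A self-contained sketch of that pattern (ShutdownError and the events are stand-ins for Hypercorn's own; Python 3.11+ is assumed for asyncio.TaskGroup and the built-in BaseExceptionGroup):

import asyncio


class ShutdownError(Exception):
    """Raised by a watcher task once its trigger fires."""


async def raise_shutdown(trigger) -> None:
    await trigger()
    raise ShutdownError()


async def main() -> None:
    external_shutdown = asyncio.Event()  # stands in for shutdown_trigger
    terminate = asyncio.Event()          # stands in for context.terminate
    terminate.set()                      # pretend max_requests was exceeded
    try:
        async with asyncio.TaskGroup() as task_group:
            task_group.create_task(raise_shutdown(external_shutdown.wait))
            task_group.create_task(raise_shutdown(terminate.wait))
    except BaseExceptionGroup as error:
        # Keep only unexpected errors; expected shutdown signals are ignored.
        _, other_errors = error.split((ShutdownError, KeyboardInterrupt))
        if other_errors is not None:
            raise other_errors


asyncio.run(main())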
15 changes: 13 additions & 2 deletions src/hypercorn/asyncio/worker_context.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Type, Union
from typing import Optional, Type, Union

from ..typing import Event

@@ -26,9 +26,20 @@ def is_set(self) -> bool:
class WorkerContext:
event_class: Type[Event] = EventWrapper

def __init__(self) -> None:
def __init__(self, max_requests: Optional[int]) -> None:
self.max_requests = max_requests
self.requests = 0
self.terminate = self.event_class()
self.terminated = self.event_class()

async def mark_request(self) -> None:
if self.max_requests is None:
return

self.requests += 1
if self.requests > self.max_requests:
await self.terminate.set()

@staticmethod
async def sleep(wait: Union[float, int]) -> None:
return await asyncio.sleep(wait)
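A quick way to exercise the new counter (a sketch, assuming the module path from this diff is importable):

import asyncio

from hypercorn.asyncio.worker_context import WorkerContext


async def demo() -> None:
    context = WorkerContext(max_requests=3)
    for _ in range(4):
        await context.mark_request()
    # terminate only fires once the count exceeds max_requests.
    print(context.terminate.is_set())  # True


asyncio.run(demo())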
2 changes: 2 additions & 0 deletions src/hypercorn/config.py
@@ -92,6 +92,8 @@ class Config:
logger_class = Logger
loglevel: str = "INFO"
max_app_queue_size: int = 10
max_requests: Optional[int] = None
max_requests_jitter: int = 0
pid_path: Optional[str] = None
server_names: List[str] = []
shutdown_timeout = 60 * SECONDS
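These two settings combine per worker, as in the asyncio and trio run modules in this commit; a small sketch of the arithmetic:

from random import randint

max_requests = 1000
max_requests_jitter = 100

# Each worker draws its own limit at startup, so restarts are staggered
# rather than all workers recycling at the same moment.
effective_limit = max_requests + randint(0, max_requests_jitter)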
1 change: 1 addition & 0 deletions src/hypercorn/protocol/h11.py
@@ -236,6 +236,7 @@ async def _create_stream(self, request: h11.Request) -> None:
)
)
self.keep_alive_requests += 1
await self.context.mark_request()

async def _send_h11_event(self, event: H11SendableEvent) -> None:
try:
1 change: 1 addition & 0 deletions src/hypercorn/protocol/h2.py
@@ -354,6 +354,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
)
)
self.keep_alive_requests += 1
await self.context.mark_request()

async def _create_server_push(
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
1 change: 1 addition & 0 deletions src/hypercorn/protocol/h3.py
@@ -125,6 +125,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
raw_path=raw_path,
)
)
await self.context.mark_request()

async def _create_server_push(
self, stream_id: int, path: bytes, headers: List[Tuple[bytes, bytes]]
81 changes: 53 additions & 28 deletions src/hypercorn/run.py
@@ -4,6 +4,7 @@
import signal
import time
from multiprocessing import get_context
from multiprocessing.connection import wait
from multiprocessing.context import BaseContext
from multiprocessing.process import BaseProcess
from multiprocessing.synchronize import Event as EventType
@@ -12,12 +13,10 @@

from .config import Config, Sockets
from .typing import WorkerFunc
from .utils import load_application, wait_for_changes, write_pid_file
from .utils import check_for_updates, files_to_watch, load_application, write_pid_file


def run(config: Config) -> int:
exit_code = 0

if config.pid_path is not None:
write_pid_file(config.pid_path)

@@ -42,67 +41,82 @@ def run(config: Config) -> int:
if config.use_reloader and config.workers == 0:
raise RuntimeError("Cannot reload without workers")

if config.use_reloader or config.workers == 0:
# Load the application so that the correct paths are checked for
# changes, but only when the reloader is being used.
load_application(config.application_path, config.wsgi_max_body_size)

exitcode = 0
if config.workers == 0:
worker_func(config, sockets)
else:
if config.use_reloader:
# Load the application so that the correct paths are checked for
# changes, but only when the reloader is being used.
load_application(config.application_path, config.wsgi_max_body_size)

ctx = get_context("spawn")

active = True
shutdown_event = ctx.Event()

def shutdown(*args: Any) -> None:
nonlocal active, shutdown_event
shutdown_event.set()
active = False

processes: List[BaseProcess] = []
while active:
# Ignore SIGINT before creating the processes, so that they
# inherit the signal handling. This means that the shutdown
# function controls the shutdown.
signal.signal(signal.SIGINT, signal.SIG_IGN)

shutdown_event = ctx.Event()
processes = start_processes(config, worker_func, sockets, shutdown_event, ctx)

def shutdown(*args: Any) -> None:
nonlocal active, shutdown_event
shutdown_event.set()
active = False
_populate(processes, config, worker_func, sockets, shutdown_event, ctx)

for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}:
if hasattr(signal, signal_name):
signal.signal(getattr(signal, signal_name), shutdown)

if config.use_reloader:
wait_for_changes(shutdown_event)
shutdown_event.set()
files = files_to_watch()
while True:
finished = wait((process.sentinel for process in processes), timeout=1)
updated = check_for_updates(files)
if updated:
shutdown_event.set()
for process in processes:
process.join()
shutdown_event.clear()
break
if len(finished) > 0:
break
else:
active = False
wait(process.sentinel for process in processes)

for process in processes:
process.join()
if process.exitcode != 0:
exit_code = process.exitcode
exitcode = _join_exited(processes)
if exitcode != 0:
shutdown_event.set()
active = False

for process in processes:
process.terminate()

exitcode = _join_exited(processes) if exitcode != 0 else exitcode

for sock in sockets.secure_sockets:
sock.close()

for sock in sockets.insecure_sockets:
sock.close()

return exit_code
return exitcode


def start_processes(
def _populate(
processes: List[BaseProcess],
config: Config,
worker_func: WorkerFunc,
sockets: Sockets,
shutdown_event: EventType,
ctx: BaseContext,
) -> List[BaseProcess]:
processes = []
for _ in range(config.workers):
) -> None:
for _ in range(config.workers - len(processes)):
process = ctx.Process( # type: ignore
target=worker_func,
kwargs={"config": config, "shutdown_event": shutdown_event, "sockets": sockets},
@@ -117,4 +131,15 @@ def start_processes(
processes.append(process)
if platform.system() == "Windows":
time.sleep(0.1)
return processes


def _join_exited(processes: List[BaseProcess]) -> int:
exitcode = 0
for index in reversed(range(len(processes))):
worker = processes[index]
if worker.exitcode is not None:
worker.join()
exitcode = worker.exitcode if exitcode == 0 else exitcode
del processes[index]

return exitcode
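The supervision loop above keeps the pool at full strength by waiting on process sentinels, reaping exited workers, and spawning replacements. A self-contained sketch of that loop (the worker body, pool size, and round count are made up for illustration):

import time
from multiprocessing import get_context
from multiprocessing.connection import wait


def worker() -> None:
    # Stand-in for a Hypercorn worker that exits after hitting its max_requests limit.
    time.sleep(0.2)


def main() -> None:
    ctx = get_context("spawn")
    processes = []
    for _ in range(3):  # a bounded number of supervision rounds for the demo
        # Top the pool back up to the desired size (cf. _populate).
        while len(processes) < 2:
            process = ctx.Process(target=worker)
            process.start()
            processes.append(process)
        # Block until at least one worker has exited (cf. wait on the sentinels).
        wait([process.sentinel for process in processes])
        # Reap exited workers so the next round replaces them (cf. _join_exited).
        for index in reversed(range(len(processes))):
            if processes[index].exitcode is not None:
                processes[index].join()
                del processes[index]
    for process in processes:
        process.terminate()
        process.join()


if __name__ == "__main__":
    main()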
7 changes: 6 additions & 1 deletion src/hypercorn/trio/run.py
@@ -3,6 +3,7 @@
import sys
from functools import partial
from multiprocessing.synchronize import Event as EventType
from random import randint
from typing import Awaitable, Callable, Optional

import trio
@@ -37,7 +38,10 @@ async def worker_serve(
config.set_statsd_logger_class(StatsdLogger)

lifespan = Lifespan(app, config)
context = WorkerContext()
max_requests = None
if config.max_requests is not None:
max_requests = config.max_requests + randint(0, config.max_requests_jitter)
context = WorkerContext(max_requests)

async with trio.open_nursery() as lifespan_nursery:
await lifespan_nursery.start(lifespan.handle_lifespan)
@@ -82,6 +86,7 @@ async def worker_serve(
async with trio.open_nursery(strict_exception_groups=True) as nursery:
if shutdown_trigger is not None:
nursery.start_soon(raise_shutdown, shutdown_trigger)
nursery.start_soon(raise_shutdown, context.terminate.wait)

nursery.start_soon(
partial(
15 changes: 13 additions & 2 deletions src/hypercorn/trio/worker_context.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Type, Union
from typing import Optional, Type, Union

import trio

@@ -27,9 +27,20 @@ def is_set(self) -> bool:
class WorkerContext:
event_class: Type[Event] = EventWrapper

def __init__(self) -> None:
def __init__(self, max_requests: Optional[int]) -> None:
self.max_requests = max_requests
self.requests = 0
self.terminate = self.event_class()
self.terminated = self.event_class()

async def mark_request(self) -> None:
if self.max_requests is None:
return

self.requests += 1
if self.requests > self.max_requests:
await self.terminate.set()

@staticmethod
async def sleep(wait: Union[float, int]) -> None:
return await trio.sleep(wait)
4 changes: 4 additions & 0 deletions src/hypercorn/typing.py
@@ -290,8 +290,12 @@ def is_set(self) -> bool:

class WorkerContext(Protocol):
event_class: Type[Event]
terminate: Event
terminated: Event

async def mark_request(self) -> None:
...

@staticmethod
async def sleep(wait: Union[float, int]) -> None:
...