Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests that leak threads (pytest 8) #3358

Merged
merged 9 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
h2==4.1.0
coverage==7.4.1
PySocks==1.7.1
pytest==7.4.4
pytest==8.0.2
pytest-timeout==2.1.0
pyOpenSSL==24.0.0
idna==3.4
Expand Down
2 changes: 2 additions & 0 deletions dummyserver/socketserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ def __init__(
socket_handler: typing.Callable[[socket.socket], None],
host: str = "localhost",
ready_event: threading.Event | None = None,
quit_event: threading.Event | None = None,
) -> None:
super().__init__()
self.daemon = True

self.socket_handler = socket_handler
self.host = host
self.ready_event = ready_event
self.quit_event = quit_event

def _start_server(self) -> None:
if self.USE_IPV6:
Expand Down
76 changes: 64 additions & 12 deletions dummyserver/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ssl
import threading
import typing
from test import LONG_TIMEOUT

import hypercorn
import pytest
Expand All @@ -19,11 +20,19 @@


def consume_socket(
sock: SSLTransport | socket.socket, chunks: int = 65536
sock: SSLTransport | socket.socket,
chunks: int = 65536,
quit_event: threading.Event | None = None,
) -> bytearray:
consumed = bytearray()
sock.settimeout(LONG_TIMEOUT)
while True:
b = sock.recv(chunks)
if quit_event and quit_event.is_set():
break
try:
b = sock.recv(chunks)
except (TimeoutError, socket.timeout):
continue
assert isinstance(b, bytes)
consumed += b
if b.endswith(b"\r\n\r\n"):
Expand Down Expand Up @@ -57,11 +66,16 @@ class SocketDummyServerTestCase:

@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.start()
ready_event.wait(5)
Expand All @@ -71,23 +85,41 @@ def _start_server(

@classmethod
def start_response_handler(
cls, response: bytes, num: int = 1, block_send: threading.Event | None = None
cls,
response: bytes,
num: int = 1,
block_send: threading.Event | None = None,
) -> threading.Event:
ready_event = threading.Event()
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(num):
ready_event.set()

sock = listener.accept()[0]
consume_socket(sock)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
consume_socket(sock, quit_event=quit_event)
if quit_event.is_set():
sock.close()
return
if block_send:
block_send.wait()
while not block_send.wait(LONG_TIMEOUT):
if quit_event.is_set():
sock.close()
return
block_send.clear()
sock.send(response)
sock.close()

cls._start_server(socket_handler)
cls._start_server(socket_handler, quit_event=quit_event)
return ready_event

@classmethod
Expand All @@ -100,10 +132,25 @@ def start_basic_handler(
block_send,
)

@staticmethod
def quit_server_thread(server_thread: SocketServerThread) -> None:
if server_thread.quit_event:
server_thread.quit_event.set()
# in principle the maximum time that the thread can take to notice
# the quit_event is LONG_TIMEOUT and the thread should terminate
# shortly after that, we give 5 seconds leeway just in case
server_thread.join(LONG_TIMEOUT * 2 + 5.0)
if server_thread.is_alive():
raise Exception("server_thread did not exit")

@classmethod
def teardown_class(cls) -> None:
if hasattr(cls, "server_thread"):
cls.server_thread.join(0.1)
cls.quit_server_thread(cls.server_thread)

def teardown_method(self) -> None:
if hasattr(self, "server_thread"):
self.quit_server_thread(self.server_thread)

def assert_header_received(
self,
Expand All @@ -128,11 +175,16 @@ def assert_header_received(
class IPV4SocketDummyServerTestCase(SocketDummyServerTestCase):
@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.USE_IPV6 = False
cls.server_thread.start()
Expand Down
18 changes: 14 additions & 4 deletions test/test_ssltransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import select
import socket
import ssl
import threading
import typing
from unittest import mock

Expand Down Expand Up @@ -111,20 +112,29 @@ def setup_class(cls) -> None:
cls.server_context, cls.client_context = server_client_ssl_contexts()

def start_dummy_server(
self, handler: typing.Callable[[socket.socket], None] | None = None
self,
handler: typing.Callable[[socket.socket], None] | None = None,
validate: bool = True,
) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
sock = listener.accept()[0]
try:
with self.server_context.wrap_socket(sock, server_side=True) as ssock:
request = consume_socket(ssock)
request = consume_socket(
ssock,
quit_event=quit_event,
)
if not validate:
return
validate_request(request)
ssock.send(sample_response())
except (ConnectionAbortedError, ConnectionResetError):
return

chosen_handler = handler if handler else socket_handler
self._start_server(chosen_handler)
self._start_server(chosen_handler, quit_event=quit_event)

@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_start_closed_socket(self) -> None:
Expand All @@ -138,7 +148,7 @@ def test_start_closed_socket(self) -> None:
@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_close_after_handshake(self) -> None:
"""Socket errors should be bubbled up"""
self.start_dummy_server()
self.start_dummy_server(validate=False)

sock = socket.create_connection((self.host, self.port))
with SSLTransport(
Expand Down
59 changes: 47 additions & 12 deletions test/with_dummyserver/test_socketlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import socket
import ssl
import tempfile
import threading
import typing
import zlib
from collections import OrderedDict
Expand Down Expand Up @@ -955,7 +956,11 @@ def socket_handler(listener: socket.socket) -> None:
assert response.connection is None

def test_socket_close_socket_then_file(self) -> None:
def consume_ssl_socket(listener: socket.socket) -> None:
quit_event = threading.Event()

def consume_ssl_socket(
listener: socket.socket,
) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
sock,
Expand All @@ -964,11 +969,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand All @@ -983,6 +988,8 @@ def consume_ssl_socket(listener: socket.socket) -> None:
assert ssl_sock.fileno() == -1

def test_socket_close_stays_open_with_makefile_open(self) -> None:
quit_event = threading.Event()

def consume_ssl_socket(listener: socket.socket) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
Expand All @@ -992,11 +999,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand Down Expand Up @@ -2232,11 +2239,28 @@ def socket_handler(listener: socket.socket) -> None:

class TestMultipartResponse(SocketDummyServerTestCase):
def test_multipart_assert_header_parsing_no_defects(self) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(2):
sock = listener.accept()[0]
while not sock.recv(65536).endswith(b"\r\n\r\n"):
pass
listener.settimeout(LONG_TIMEOUT)

while True:
if quit_event and quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue

sock.settimeout(LONG_TIMEOUT)
while True:
if quit_event and quit_event.is_set():
sock.close()
return
if sock.recv(65536).endswith(b"\r\n\r\n"):
break

sock.sendall(
b"HTTP/1.1 404 Not Found\r\n"
Expand All @@ -2252,7 +2276,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)
from urllib3.connectionpool import log

with mock.patch.object(log, "warning") as log_warning:
Expand Down Expand Up @@ -2308,15 +2332,26 @@ def socket_handler(listener: socket.socket) -> None:
def test_chunked_specified(
self, method: str, chunked: bool, body_type: str
) -> None:
quit_event = threading.Event()
buffer = bytearray()
expected_bytes = b"\r\n\r\na\r\nxxxxxxxxxx\r\n0\r\n\r\n"

def socket_handler(listener: socket.socket) -> None:
nonlocal buffer
sock = listener.accept()[0]
sock.settimeout(0)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
sock.settimeout(LONG_TIMEOUT)

while expected_bytes not in buffer:
if quit_event.is_set():
return
with contextlib.suppress(BlockingIOError):
buffer += sock.recv(65536)

Expand All @@ -2327,7 +2362,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)

body: typing.Any
if body_type == "generator":
Expand Down