Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests that leak threads (pytest 8) #3358

Merged
merged 9 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
h2==4.1.0
coverage==7.4.1
PySocks==1.7.1
pytest==7.4.4
pytest==8.0.2
pytest-timeout==2.1.0
pyOpenSSL==24.0.0
idna==3.4
Expand Down
2 changes: 2 additions & 0 deletions dummyserver/socketserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ def __init__(
socket_handler: typing.Callable[[socket.socket], None],
host: str = "localhost",
ready_event: threading.Event | None = None,
quit_event: threading.Event | None = None,
) -> None:
super().__init__()
self.daemon = True

self.socket_handler = socket_handler
self.host = host
self.ready_event = ready_event
self.quit_event = quit_event

def _start_server(self) -> None:
if self.USE_IPV6:
Expand Down
73 changes: 61 additions & 12 deletions dummyserver/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ssl
import threading
import typing
from test import LONG_TIMEOUT

import hypercorn
import pytest
Expand All @@ -19,11 +20,19 @@


def consume_socket(
sock: SSLTransport | socket.socket, chunks: int = 65536
sock: SSLTransport | socket.socket,
chunks: int = 65536,
quit_event: threading.Event | None = None,
) -> bytearray:
consumed = bytearray()
sock.settimeout(LONG_TIMEOUT)
while True:
b = sock.recv(chunks)
if quit_event and quit_event.is_set():
break
try:
b = sock.recv(chunks)
except (TimeoutError, socket.timeout):
continue
assert isinstance(b, bytes)
consumed += b
if b.endswith(b"\r\n\r\n"):
Expand Down Expand Up @@ -57,11 +66,16 @@ class SocketDummyServerTestCase:

@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.start()
ready_event.wait(5)
Expand All @@ -71,23 +85,41 @@ def _start_server(

@classmethod
def start_response_handler(
cls, response: bytes, num: int = 1, block_send: threading.Event | None = None
cls,
response: bytes,
num: int = 1,
block_send: threading.Event | None = None,
) -> threading.Event:
ready_event = threading.Event()
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(num):
ready_event.set()

sock = listener.accept()[0]
consume_socket(sock)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
consume_socket(sock, quit_event=quit_event)
if quit_event.is_set():
sock.close()
return
if block_send:
block_send.wait()
while not block_send.wait(LONG_TIMEOUT):
if quit_event.is_set():
sock.close()
return
block_send.clear()
sock.send(response)
sock.close()

cls._start_server(socket_handler)
cls._start_server(socket_handler, quit_event=quit_event)
return ready_event

@classmethod
Expand All @@ -103,7 +135,19 @@ def start_basic_handler(
@classmethod
def teardown_class(cls) -> None:
if hasattr(cls, "server_thread"):
cls.server_thread.join(0.1)
if cls.server_thread.quit_event:
cls.server_thread.quit_event.set()
cls.server_thread.join(LONG_TIMEOUT * 2 + 5.0)
if cls.server_thread.is_alive():
raise Exception("server_thread did not exit")

def teardown_method(self) -> None:
if hasattr(self, "server_thread"):
if self.server_thread.quit_event:
self.server_thread.quit_event.set()
self.server_thread.join(LONG_TIMEOUT * 2 + 5.0)
if self.server_thread.is_alive():
raise Exception("server_thread did not exit")

def assert_header_received(
self,
Expand All @@ -128,11 +172,16 @@ def assert_header_received(
class IPV4SocketDummyServerTestCase(SocketDummyServerTestCase):
@classmethod
def _start_server(
cls, socket_handler: typing.Callable[[socket.socket], None]
cls,
socket_handler: typing.Callable[[socket.socket], None],
quit_event: threading.Event | None = None,
) -> None:
ready_event = threading.Event()
cls.server_thread = SocketServerThread(
socket_handler=socket_handler, ready_event=ready_event, host=cls.host
socket_handler=socket_handler,
ready_event=ready_event,
host=cls.host,
quit_event=quit_event,
)
cls.server_thread.USE_IPV6 = False
cls.server_thread.start()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ markers = [
"limit_memory: Limit memory with memray",
"requires_network: This test needs access to the Internet",
"integration: Slow integrations tests not run by default",
"server_threads",
ecerulm marked this conversation as resolved.
Show resolved Hide resolved
]
log_level = "DEBUG"
filterwarnings = [
Expand Down
1 change: 1 addition & 0 deletions test/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_memory_usage(

assert len(get_func(buffer)) == 10 * 2**20

@pytest.mark.server_threads
@pytest.mark.limit_memory("10.01 MB")
def test_get_all_memory_usage_single_chunk(self) -> None:
buffer = BytesQueueBuffer()
Expand Down
19 changes: 15 additions & 4 deletions test/test_ssltransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import select
import socket
import ssl
import threading
import typing
from unittest import mock

Expand Down Expand Up @@ -111,20 +112,29 @@ def setup_class(cls) -> None:
cls.server_context, cls.client_context = server_client_ssl_contexts()

def start_dummy_server(
self, handler: typing.Callable[[socket.socket], None] | None = None
self,
handler: typing.Callable[[socket.socket], None] | None = None,
validate: bool = True,
) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
sock = listener.accept()[0]
try:
with self.server_context.wrap_socket(sock, server_side=True) as ssock:
request = consume_socket(ssock)
request = consume_socket(
ssock,
quit_event=quit_event,
)
if not validate:
return
validate_request(request)
ssock.send(sample_response())
except (ConnectionAbortedError, ConnectionResetError):
return

chosen_handler = handler if handler else socket_handler
self._start_server(chosen_handler)
self._start_server(chosen_handler, quit_event=quit_event)

@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_start_closed_socket(self) -> None:
Expand All @@ -135,10 +145,11 @@ def test_start_closed_socket(self) -> None:
with pytest.raises(OSError):
SSLTransport(sock, context)

@pytest.mark.server_threads
@pytest.mark.timeout(PER_TEST_TIMEOUT)
def test_close_after_handshake(self) -> None:
"""Socket errors should be bubbled up"""
self.start_dummy_server()
self.start_dummy_server(validate=False)

sock = socket.create_connection((self.host, self.port))
with SSLTransport(
Expand Down
1 change: 1 addition & 0 deletions test/with_dummyserver/test_connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def test_total_applies_connect(self) -> None:
finally:
conn.close()

@pytest.mark.server_threads
def test_total_timeout(self) -> None:
block_event = Event()
ready_event = self.start_basic_handler(block_send=block_event, num=2)
Expand Down
63 changes: 51 additions & 12 deletions test/with_dummyserver/test_socketlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import socket
import ssl
import tempfile
import threading
import typing
import zlib
from collections import OrderedDict
Expand Down Expand Up @@ -954,8 +955,13 @@ def socket_handler(listener: socket.socket) -> None:
assert pool.pool.qsize() == 1
assert response.connection is None

@pytest.mark.server_threads
def test_socket_close_socket_then_file(self) -> None:
def consume_ssl_socket(listener: socket.socket) -> None:
quit_event = threading.Event()

def consume_ssl_socket(
listener: socket.socket,
) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
sock,
Expand All @@ -964,11 +970,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand All @@ -982,7 +988,10 @@ def consume_ssl_socket(listener: socket.socket) -> None:
ssl_sock.sendall(b"hello")
assert ssl_sock.fileno() == -1

@pytest.mark.server_threads
def test_socket_close_stays_open_with_makefile_open(self) -> None:
quit_event = threading.Event()

def consume_ssl_socket(listener: socket.socket) -> None:
try:
with listener.accept()[0] as sock, original_ssl_wrap_socket(
Expand All @@ -992,11 +1001,11 @@ def consume_ssl_socket(listener: socket.socket) -> None:
certfile=DEFAULT_CERTS["certfile"],
ca_certs=DEFAULT_CA,
) as ssl_sock:
consume_socket(ssl_sock)
consume_socket(ssl_sock, quit_event=quit_event)
except (ConnectionResetError, ConnectionAbortedError, OSError):
pass

self._start_server(consume_ssl_socket)
self._start_server(consume_ssl_socket, quit_event=quit_event)
with socket.create_connection(
(self.host, self.port)
) as sock, contextlib.closing(
Expand Down Expand Up @@ -2231,12 +2240,30 @@ def socket_handler(listener: socket.socket) -> None:


class TestMultipartResponse(SocketDummyServerTestCase):
@pytest.mark.server_threads
def test_multipart_assert_header_parsing_no_defects(self) -> None:
quit_event = threading.Event()

def socket_handler(listener: socket.socket) -> None:
for _ in range(2):
sock = listener.accept()[0]
while not sock.recv(65536).endswith(b"\r\n\r\n"):
pass
listener.settimeout(LONG_TIMEOUT)

while True:
if quit_event and quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue

sock.settimeout(LONG_TIMEOUT)
while True:
if quit_event and quit_event.is_set():
sock.close()
return
if sock.recv(65536).endswith(b"\r\n\r\n"):
break

sock.sendall(
b"HTTP/1.1 404 Not Found\r\n"
Expand All @@ -2252,7 +2279,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)
from urllib3.connectionpool import log

with mock.patch.object(log, "warning") as log_warning:
Expand Down Expand Up @@ -2302,21 +2329,33 @@ def socket_handler(listener: socket.socket) -> None:
assert b"Content-Length: 0\r\n" in sent_bytes
assert b"transfer-encoding" not in sent_bytes.lower()

@pytest.mark.server_threads
@pytest.mark.parametrize("chunked", [True, False])
@pytest.mark.parametrize("method", ["POST", "PUT", "PATCH"])
@pytest.mark.parametrize("body_type", ["file", "generator", "bytes"])
def test_chunked_specified(
self, method: str, chunked: bool, body_type: str
) -> None:
quit_event = threading.Event()
buffer = bytearray()
expected_bytes = b"\r\n\r\na\r\nxxxxxxxxxx\r\n0\r\n\r\n"

def socket_handler(listener: socket.socket) -> None:
nonlocal buffer
sock = listener.accept()[0]
sock.settimeout(0)
listener.settimeout(LONG_TIMEOUT)
while True:
if quit_event.is_set():
return
try:
sock = listener.accept()[0]
break
except (TimeoutError, socket.timeout):
continue
sock.settimeout(LONG_TIMEOUT)

while expected_bytes not in buffer:
if quit_event.is_set():
return
with contextlib.suppress(BlockingIOError):
buffer += sock.recv(65536)

Expand All @@ -2327,7 +2366,7 @@ def socket_handler(listener: socket.socket) -> None:
)
sock.close()

self._start_server(socket_handler)
self._start_server(socket_handler, quit_event=quit_event)

body: typing.Any
if body_type == "generator":
Expand Down