Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consume connections better in socket-level tests #1958

Merged
merged 4 commits into from
Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 63 additions & 0 deletions dummyserver/testcase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import threading
from contextlib import contextmanager

import pytest
from tornado import ioloop, web

from urllib3.connection import HTTPConnection

from dummyserver.server import (
SocketServerThread,
run_tornado_app,
Expand Down Expand Up @@ -217,3 +220,63 @@ class IPv6HTTPDummyProxyTestCase(HTTPDummyProxyTestCase):

proxy_host = "::1"
proxy_host_alt = "127.0.0.1"


class ConnectionMarker(object):
"""
Marks an HTTP(S)Connection's socket after a request was made.

Helps a test server understand when a client finished a request,
without implementing a complete HTTP server.
"""

MARK_FORMAT = b"$#MARK%04x*!"

@classmethod
@contextmanager
def mark(cls, monkeypatch):
"""
Mark connections under in that context.
"""

orig_request = HTTPConnection.request
orig_request_chunked = HTTPConnection.request_chunked

def call_and_mark(target):
def part(self, *args, **kwargs):
result = target(self, *args, **kwargs)
self.sock.sendall(cls._get_socket_mark(self.sock, False))
return result

return part

with monkeypatch.context() as m:
m.setattr(HTTPConnection, "request", call_and_mark(orig_request))
m.setattr(
HTTPConnection, "request_chunked", call_and_mark(orig_request_chunked)
)
yield

@classmethod
def consume_request(cls, sock, chunks=65536):
"""
Consume a socket until after the HTTP request is sent.
"""
consumed = bytearray()
mark = cls._get_socket_mark(sock, True)
while True:
b = sock.recv(chunks)
if not b:
break
consumed += b
if consumed.endswith(mark):
break
return consumed

@classmethod
def _get_socket_mark(cls, sock, server):
if server:
port = sock.getpeername()[1]
else:
port = sock.getsockname()[1]
return cls.MARK_FORMAT % (port,)
49 changes: 26 additions & 23 deletions test/with_dummyserver/test_chunked_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from urllib3 import HTTPConnectionPool
from urllib3.util.retry import Retry
from urllib3.util import SUPPRESS_USER_AGENT
from dummyserver.testcase import SocketDummyServerTestCase, consume_socket
from test import notWindows
from dummyserver.testcase import (
SocketDummyServerTestCase,
consume_socket,
ConnectionMarker,
)

# Retry failed tests
pytestmark = pytest.mark.flaky
Expand Down Expand Up @@ -156,56 +159,56 @@ def socket_handler(listener):
sock.close()
assert self.chunked_requests == 2

@notWindows
def test_preserve_chunked_on_redirect(self):
def test_preserve_chunked_on_redirect(self, monkeypatch):
self.chunked_requests = 0

def socket_handler(listener):
for i in range(2):
sock = listener.accept()[0]
request = consume_socket(sock)
request = ConnectionMarker.consume_request(sock)
if b"Transfer-Encoding: chunked" in request.split(b"\r\n"):
self.chunked_requests += 1

if i == 0:
sock.send(
sock.sendall(
b"HTTP/1.1 301 Moved Permanently\r\n"
b"Location: /redirect\r\n\r\n"
)
else:
sock.send(b"HTTP/1.1 200 OK\r\n\r\n")
sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n")
sock.close()

self._start_server(socket_handler)
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(redirect=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
with ConnectionMarker.mark(monkeypatch):
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(redirect=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2

@notWindows
def test_preserve_chunked_on_broken_connection(self):
def test_preserve_chunked_on_broken_connection(self, monkeypatch):
self.chunked_requests = 0

def socket_handler(listener):
for i in range(2):
sock = listener.accept()[0]
request = consume_socket(sock)
request = ConnectionMarker.consume_request(sock)
if b"Transfer-Encoding: chunked" in request.split(b"\r\n"):
self.chunked_requests += 1

if i == 0:
# Bad HTTP version will trigger a connection close
sock.send(b"HTTP/0.5 200 OK\r\n\r\n")
sock.sendall(b"HTTP/0.5 200 OK\r\n\r\n")
else:
sock.send(b"HTTP/1.1 200 OK\r\n\r\n")
sock.sendall(b"HTTP/1.1 200 OK\r\n\r\n")
sock.close()

self._start_server(socket_handler)
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(read=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2
with ConnectionMarker.mark(monkeypatch):
with HTTPConnectionPool(self.host, self.port) as pool:
retries = Retry(read=1)
pool.urlopen(
"GET", "/", chunked=True, preload_content=False, retries=retries
)
assert self.chunked_requests == 2