From c8bb193967f1ca6ecaf52f5eef64ae40da71a1fe Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 11:05:32 -0700 Subject: [PATCH 01/17] Don't bother checking Python version, just force it to bytes. --- wsproto/extensions.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/wsproto/extensions.py b/wsproto/extensions.py index aaff2f1..6b540fa 100644 --- a/wsproto/extensions.py +++ b/wsproto/extensions.py @@ -8,7 +8,6 @@ import zlib -from .compat import PY2 from .frame_protocol import CloseReason, Opcode, RsvBits @@ -163,11 +162,8 @@ def frame_inbound_payload_data(self, proto, data): if not self._inbound_compressed or not self._inbound_is_compressible: return data - if PY2: - data = str(data) - try: - return self._decompressor.decompress(data) + return self._decompressor.decompress(bytes(data)) except zlib.error: return CloseReason.INVALID_FRAME_PAYLOAD_DATA @@ -213,9 +209,7 @@ def frame_outbound(self, proto, opcode, rsv, data, fin): self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -bits) - if PY2: - data = str(data) - data = self._compressor.compress(data) + data = self._compressor.compress(bytes(data)) if fin: data += self._compressor.flush(zlib.Z_SYNC_FLUSH) From ea11eb96b82258f3c3c20641b7a640ebe41c720e Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 11:05:57 -0700 Subject: [PATCH 02/17] Use a slightly better default parameter pattern. --- wsproto/extensions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/wsproto/extensions.py b/wsproto/extensions.py index 6b540fa..9d1f720 100644 --- a/wsproto/extensions.py +++ b/wsproto/extensions.py @@ -39,12 +39,19 @@ def frame_outbound(self, proto, opcode, rsv, data, fin): class PerMessageDeflate(Extension): name = 'permessage-deflate' + DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 + DEFAULT_SERVER_MAX_WINDOW_BITS = 15 + def __init__(self, client_no_context_takeover=False, - client_max_window_bits=15, server_no_context_takeover=False, - server_max_window_bits=15): + client_max_window_bits=None, server_no_context_takeover=False, + server_max_window_bits=None): self.client_no_context_takeover = client_no_context_takeover + if client_max_window_bits is None: + client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS self.client_max_window_bits = client_max_window_bits self.server_no_context_takeover = server_no_context_takeover + if server_max_window_bits is None: + server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS self.server_max_window_bits = server_max_window_bits self._compressor = None From 4b2db763b2a88876f185b61d8f830a188f9b189d Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 11:06:16 -0700 Subject: [PATCH 03/17] Move a conditional. --- wsproto/extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wsproto/extensions.py b/wsproto/extensions.py index 9d1f720..d9ce8f5 100644 --- a/wsproto/extensions.py +++ b/wsproto/extensions.py @@ -177,10 +177,10 @@ def frame_inbound_payload_data(self, proto, data): def frame_inbound_complete(self, proto, fin): if not fin: return - elif not self._inbound_compressed: - return elif not self._inbound_is_compressible: return + elif not self._inbound_compressed: + return try: data = self._decompressor.decompress(b'\x00\x00\xff\xff') From 86388ee22a052ade77c80a719af8240ffe9337b1 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 11:06:33 -0700 Subject: [PATCH 04/17] Get wsproto.extensions up to 100% test coverage. --- test/test_extensions.py | 36 +++ test/test_permessage_deflate.py | 448 ++++++++++++++++++++++++++++++++ 2 files changed, 484 insertions(+) create mode 100644 test/test_extensions.py create mode 100644 test/test_permessage_deflate.py diff --git a/test/test_extensions.py b/test/test_extensions.py new file mode 100644 index 0000000..0e65774 --- /dev/null +++ b/test/test_extensions.py @@ -0,0 +1,36 @@ +import wsproto.extensions as wpext +import wsproto.frame_protocol as fp + + +class TestExtension(object): + def test_enabled(self): + ext = wpext.Extension() + assert not ext.enabled() + + def test_offer(self): + ext = wpext.Extension() + assert ext.offer(None) is None + + def test_accept(self): + ext = wpext.Extension() + assert ext.accept(None, None) is None + + def test_frame_inbound_header(self): + ext = wpext.Extension() + result = ext.frame_inbound_header(None, None, None, None) + assert result == fp.RsvBits(False, False, False) + + def test_frame_inbound_payload_data(self): + ext = wpext.Extension() + data = object() + assert ext.frame_inbound_payload_data(None, data) == data + + def test_frame_inbound_complete(self): + ext = wpext.Extension() + assert ext.frame_inbound_complete(None, None) is None + + def test_frame_outbound(self): + ext = wpext.Extension() + rsv = fp.RsvBits(True, True, True) + data = object() + assert ext.frame_outbound(None, None, rsv, data, None) == (rsv, data) diff --git a/test/test_permessage_deflate.py b/test/test_permessage_deflate.py new file mode 100644 index 0000000..e1054d1 --- /dev/null +++ b/test/test_permessage_deflate.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- + +import zlib + +import pytest + +import wsproto.extensions as wpext +import wsproto.frame_protocol as fp + + +class TestPerMessageDeflate(object): + parameter_sets = [ + { + 'client_no_context_takeover': False, + 'client_max_window_bits': 15, + 'server_no_context_takeover': False, + 'server_max_window_bits': 15, + }, + { + 'client_no_context_takeover': True, + 'client_max_window_bits': 9, + 'server_no_context_takeover': False, + 'server_max_window_bits': 15, + }, + { + 'client_no_context_takeover': False, + 'client_max_window_bits': 15, + 'server_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'client_no_context_takeover': True, + 'client_max_window_bits': 8, + 'server_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'client_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'server_no_context_takeover': True, + 'client_max_window_bits': 8, + }, + { + 'client_max_window_bits': None, + 'server_max_window_bits': None, + }, + {}, + ] + + def make_offer_string(self, params): + offer = ['permessage-deflate'] + + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + offer.append('client_max_window_bits') + else: + offer.append('client_max_window_bits=%d' % + params['client_max_window_bits']) + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + offer.append('server_max_window_bits') + else: + offer.append('server_max_window_bits=%d' % + params['server_max_window_bits']) + if params.get('client_no_context_takeover', False): + offer.append('client_no_context_takeover') + if params.get('server_no_context_takeover', False): + offer.append('server_no_context_takeover') + + return '; '.join(offer) + + def compare_params_to_string(self, params, ext, param_string): + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + bits = ext.client_max_window_bits + else: + bits = params['client_max_window_bits'] + assert 'client_max_window_bits=%d' % bits in param_string + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + bits = ext.server_max_window_bits + else: + bits = params['server_max_window_bits'] + assert 'server_max_window_bits=%d' % bits in param_string + if params.get('client_no_context_takeover', False): + assert 'client_no_context_takeover' in param_string + if params.get('server_no_context_takeover', False): + assert 'server_no_context_takeover' in param_string + + @pytest.mark.parametrize('params', parameter_sets) + def test_offer(self, params): + ext = wpext.PerMessageDeflate(**params) + offer = ext.offer(None) + + self.compare_params_to_string(params, ext, offer) + + @pytest.mark.parametrize('params', parameter_sets) + def test_finalize(self, params): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + params = dict(params) + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + del params['client_max_window_bits'] + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + del params['server_max_window_bits'] + offer = self.make_offer_string(params) + ext.finalize(None, offer) + + if params.get('client_max_window_bits', None): + assert ext.client_max_window_bits == \ + params['client_max_window_bits'] + if params.get('server_max_window_bits', None): + assert ext.server_max_window_bits == \ + params['server_max_window_bits'] + assert ext.client_no_context_takeover is \ + params.get('client_no_context_takeover', False) + assert ext.server_no_context_takeover is \ + params.get('server_no_context_takeover', False) + + assert ext.enabled() + + def test_finalize_ignores_rubbish(self): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + ext.finalize(None, 'i am the lizard queen; worship me') + + assert ext.enabled() + + @pytest.mark.parametrize('params', parameter_sets) + def test_accept(self, params): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + offer = self.make_offer_string(params) + print(repr(offer)) + + response = ext.accept(None, offer) + print(repr(response)) + + if ext.client_no_context_takeover: + assert 'client_no_context_takeover' in response + if ext.server_no_context_takeover: + assert 'server_no_context_takeover' in response + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + bits = ext.client_max_window_bits + else: + bits = params['client_max_window_bits'] + assert ext.client_max_window_bits == bits + assert 'client_max_window_bits=%d' % bits in response + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + bits = ext.server_max_window_bits + else: + bits = params['server_max_window_bits'] + assert ext.server_max_window_bits == bits + assert 'server_max_window_bits=%d' % bits in response + + def test_accept_ignores_rubbish(self): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + ext.accept(None, 'i am the lizard queen; worship me') + + assert ext.enabled() + + def test_inbound_uncompressed_control_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.PING, + fp.RsvBits(False, False, False), + len(payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, payload) + assert data == payload + + assert ext.frame_inbound_complete(proto, True) is None + + def test_inbound_compressed_control_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.PING, + fp.RsvBits(True, False, False), + len(payload)) + assert result == fp.CloseReason.PROTOCOL_ERROR + + def test_inbound_compressed_continuation_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.CONTINUATION, + fp.RsvBits(True, False, False), + len(payload)) + assert result == fp.CloseReason.PROTOCOL_ERROR + + def test_inbound_uncompressed_data_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(False, False, False), + len(payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, payload) + assert data == payload + + assert ext.frame_inbound_complete(proto, True) is None + + @pytest.mark.parametrize('client', [True, False]) + def test_client_inbound_compressed_single_data_frame(self, client): + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, compressed_payload) + data += ext.frame_inbound_complete(proto, True) + assert data == payload + + @pytest.mark.parametrize('client', [True, False]) + def test_client_inbound_compressed_multiple_data_frames(self, client): + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + split = 3 + data = b'' + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + split) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, + compressed_payload[:split]) + assert not isinstance(result, fp.CloseReason) + data += result + assert ext.frame_inbound_complete(proto, False) is None + + result = ext.frame_inbound_header(proto, fp.Opcode.CONTINUATION, + fp.RsvBits(False, False, False), + len(compressed_payload) - split) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, + compressed_payload[split:]) + assert not isinstance(result, fp.CloseReason) + data += result + + result = ext.frame_inbound_complete(proto, True) + assert not isinstance(result, fp.CloseReason) + data += result + + assert data == payload + + def test_inbound_bad_zlib_payload(self): + compressed_payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, compressed_payload) + assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA + + def test_inbound_bad_zlib_decoder_end_state(self, monkeypatch): + compressed_payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + + class FailDecompressor(object): + def decompress(self, data): + return b'' + + def flush(self): + raise zlib.error() + + monkeypatch.setattr(ext, '_decompressor', FailDecompressor()) + + result = ext.frame_inbound_complete(proto, True) + assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA + + @pytest.mark.parametrize('client,no_context_takeover', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_decompressor_reset(self, client, no_context_takeover): + if client: + args = {'server_no_context_takeover': no_context_takeover} + else: + args = {'client_no_context_takeover': no_context_takeover} + ext = wpext.PerMessageDeflate(**args) + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), 0) + assert result.rsv1 + + assert ext._decompressor is not None + + result = ext.frame_inbound_complete(proto, True) + assert not isinstance(result, fp.CloseReason) + + if no_context_takeover: + assert ext._decompressor is None + else: + assert ext._decompressor is not None + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), 0) + assert result.rsv1 + + assert ext._decompressor is not None + + def test_outbound_uncompressible_opcode(self): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + + rsv, data = ext.frame_outbound(proto, fp.Opcode.PING, rsv, payload, + True) + + assert rsv.rsv1 is False + assert data == payload + + @pytest.mark.parametrize('client', [True, False]) + def test_outbound_compress_single_frame(self, client): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, payload, + True) + + assert rsv.rsv1 is True + assert data == compressed_payload + + @pytest.mark.parametrize('client', [True, False]) + def test_outbound_compress_multiple_frames(self, client): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + split = 12 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, + payload[:split], False) + assert rsv.rsv1 is True + + rsv = fp.RsvBits(False, False, False) + rsv, more_data = ext.frame_outbound(proto, fp.Opcode.CONTINUATION, rsv, + payload[split:], True) + assert rsv.rsv1 is False + assert data + more_data == compressed_payload + + @pytest.mark.parametrize('client,no_context_takeover', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_compressor_reset(self, client, no_context_takeover): + if client: + args = {'client_no_context_takeover': no_context_takeover} + else: + args = {'server_no_context_takeover': no_context_takeover} + ext = wpext.PerMessageDeflate(**args) + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + rsv = fp.RsvBits(False, False, False) + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, b'', + False) + assert rsv.rsv1 is True + assert ext._compressor is not None + + rsv = fp.RsvBits(False, False, False) + rsv, data = ext.frame_outbound(proto, fp.Opcode.CONTINUATION, rsv, b'', + True) + assert rsv.rsv1 is False + if no_context_takeover: + assert ext._compressor is None + else: + assert ext._compressor is not None + + rsv = fp.RsvBits(False, False, False) + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, b'', + False) + assert rsv.rsv1 is True + assert ext._compressor is not None + + @pytest.mark.parametrize('params', parameter_sets) + def test_repr(self, params): + ext = wpext.PerMessageDeflate(**params) + self.compare_params_to_string(params, ext, repr(ext)) From 25e8924bdcd96c40787211ddf3013876e0b52fed Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:43:52 -0700 Subject: [PATCH 05/17] Make finalize a formal part of the Extension API. --- test/test_extensions.py | 4 ++++ wsproto/extensions.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_extensions.py b/test/test_extensions.py index 0e65774..5f41806 100644 --- a/test/test_extensions.py +++ b/test/test_extensions.py @@ -15,6 +15,10 @@ def test_accept(self): ext = wpext.Extension() assert ext.accept(None, None) is None + def test_finalize(self): + ext = wpext.Extension() + assert ext.finalize(None, None) is None + def test_frame_inbound_header(self): ext = wpext.Extension() result = ext.frame_inbound_header(None, None, None, None) diff --git a/wsproto/extensions.py b/wsproto/extensions.py index d9ce8f5..ad7daa1 100644 --- a/wsproto/extensions.py +++ b/wsproto/extensions.py @@ -23,6 +23,9 @@ def offer(self, connection): def accept(self, connection, offer): return None + def finalize(self, connection, offer): + return None + def frame_inbound_header(self, proto, opcode, rsv, payload_length): return RsvBits(False, False, False) @@ -30,7 +33,7 @@ def frame_inbound_payload_data(self, proto, data): return data def frame_inbound_complete(self, proto, fin): - pass + return None def frame_outbound(self, proto, opcode, rsv, data, fin): return (rsv, data) From 62b272940ad1297473f86eca4d47ebd190a4d08d Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:44:38 -0700 Subject: [PATCH 06/17] Standardise on the wsproto.compat.PY3 constant. --- test/test_upgrade.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/test_upgrade.py b/test/test_upgrade.py index 8190b47..0bfb7ab 100644 --- a/test/test_upgrade.py +++ b/test/test_upgrade.py @@ -6,19 +6,16 @@ import base64 import email import random -import sys +from wsproto.compat import PY3 from wsproto.connection import WSConnection, CLIENT, SERVER from wsproto.events import ( ConnectionEstablished, ConnectionFailed, ConnectionRequested ) -IS_PYTHON3 = sys.version_info >= (3, 0) - - def parse_headers(headers): - if IS_PYTHON3: + if PY3: headers = email.message_from_bytes(headers) else: headers = email.message_from_string(headers) From f16248d9ef496b3b7da2d087ea3be7435a850ffc Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:45:32 -0700 Subject: [PATCH 07/17] Make sure we're decoding subprotocol and extension headers properly. --- wsproto/connection.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wsproto/connection.py b/wsproto/connection.py index 9c4aaec..c8a99a6 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -66,7 +66,7 @@ def _normed_header_dict(h11_headers): # wrong, because those can contain quoted strings, which can in turn contain # commas. XX FIXME def _split_comma_header(value): - return [piece.strip() for piece in value.split(b',')] + return [piece.decode('ascii').strip() for piece in value.split(b',')] class WSConnection(object): @@ -320,6 +320,7 @@ def _establish_client_connection(self, event): subprotocol = headers.get(b'sec-websocket-protocol', None) if subprotocol is not None: + subprotocol = subprotocol.decode('ascii') if subprotocol not in self.subprotocols: return ConnectionFailed(CloseReason.PROTOCOL_ERROR, "unrecognized subprotocol {!r}" @@ -330,7 +331,6 @@ def _establish_client_connection(self, event): accepts = _split_comma_header(extensions) for accept in accepts: - accept = accept.decode('ascii') name = accept.split(';', 1)[0].strip() for extension in self.extensions: if extension.name == name: @@ -400,7 +400,6 @@ def accept(self, event, subprotocol=None): offers = _split_comma_header(extensions) for offer in offers: - offer = offer.decode('ascii') name = offer.split(';', 1)[0].strip() for extension in self.extensions: if extension.name == name: From 14b55a89bd8f2aba6216c905a8881f1fa9b129a9 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:46:01 -0700 Subject: [PATCH 08/17] Fix up some corner cases in upgrade message handling. --- wsproto/connection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wsproto/connection.py b/wsproto/connection.py index c8a99a6..6062bf9 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -239,11 +239,15 @@ def _process_upgrade(self, data): event = self._upgrade_connection.next_event() if event is h11.NEED_DATA: break - elif self.client and isinstance(event, h11.InformationalResponse): + elif self.client and isinstance(event, (h11.InformationalResponse, + h11.Response)): data = self._upgrade_connection.trailing_data[0] return self._establish_client_connection(event), data elif not self.client and isinstance(event, h11.Request): return self._process_connection_request(event), None + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message") self._incoming = b'' return None, None From 512bada892f837cf016209f7cf09039e9d9cbd2a Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:46:23 -0700 Subject: [PATCH 09/17] Whitespace. --- wsproto/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wsproto/connection.py b/wsproto/connection.py index 6062bf9..3309ce1 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -268,7 +268,6 @@ def events(self): try: for frame in self._proto.received_frames(): - if frame.opcode is Opcode.PING: assert frame.frame_finished and frame.message_finished self._outgoing += self._proto.pong(frame.payload) From 497b48e6bd8919b613f697f811c9b23243f71c51 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 16:14:00 -0700 Subject: [PATCH 10/17] Be more careful in the http upgrade path. --- wsproto/connection.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wsproto/connection.py b/wsproto/connection.py index 3309ce1..c9f709d 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -236,7 +236,11 @@ def receive_bytes(self, data): def _process_upgrade(self, data): self._upgrade_connection.receive_data(data) while True: - event = self._upgrade_connection.next_event() + try: + event = self._upgrade_connection.next_event() + except h11.RemoteProtocolError: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' if event is h11.NEED_DATA: break elif self.client and isinstance(event, (h11.InformationalResponse, @@ -247,7 +251,7 @@ def _process_upgrade(self, data): return self._process_connection_request(event), None else: return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad HTTP message") + "Bad HTTP message"), b'' self._incoming = b'' return None, None From 31bc6f51bb56a93171b23f2df71893ee704488a6 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 14:46:32 -0700 Subject: [PATCH 11/17] Fully test the upgrade path, both server and client. --- test/test_upgrade.py | 554 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 554 insertions(+) diff --git a/test/test_upgrade.py b/test/test_upgrade.py index 0bfb7ab..163bdfe 100644 --- a/test/test_upgrade.py +++ b/test/test_upgrade.py @@ -7,11 +7,14 @@ import email import random +import pytest + from wsproto.compat import PY3 from wsproto.connection import WSConnection, CLIENT, SERVER from wsproto.events import ( ConnectionEstablished, ConnectionFailed, ConnectionRequested ) +from wsproto.extensions import Extension def parse_headers(headers): @@ -23,6 +26,26 @@ def parse_headers(headers): return dict(headers.items()) +class FakeExtension(Extension): + name = 'fake' + + def __init__(self, offer_response=None, accept_response=None): + self.offer_response = offer_response + self.accepted_offer = None + self.offered = None + self.accept_response = accept_response + + def offer(self, proto): + return self.offer_response + + def finalize(self, proto, offer): + self.accepted_offer = offer + + def accept(self, proto, offer): + self.offered = offer + return self.accept_response + + class TestClientUpgrade(object): def initiate(self, host, path, **kwargs): ws = WSConnection(CLIENT, host, path, **kwargs) @@ -130,6 +153,206 @@ def test_bad_upgrade_header(self): ws.receive_bytes(response) assert isinstance(next(ws.events()), ConnectionFailed) + def test_simple_extension_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert _ext.name == headers['sec-websocket-extensions'] + + def test_simple_extension_non_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=False) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert 'sec-websocket-extensions' not in headers + + def test_extension_offer_with_params(self): + ext_parameters = 'parameter1=value1; parameter2=value2' + _ext = FakeExtension(offer_response=ext_parameters) + + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert headers['sec-websocket-extensions'] == \ + '%s; %s' % (_ext.name, ext_parameters) + + def test_simple_extension_accept(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: " + \ + _ext.name.encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionEstablished) + assert _ext.name in _ext.accepted_offer + + def test_extension_accept_with_parameters(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + ext_parameters = 'parameter1=value1; parameter2=value2' + extensions = _ext.name + '; ' + ext_parameters + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: " + \ + extensions.encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionEstablished) + assert _ext.accepted_offer == extensions + + def test_accept_an_extension_we_do_not_recognise(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: pretend\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_wrong_status_code_in_response(self): + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = self.initiate(_host, _path) + + response = b"HTTP/1.1 200 OK\r\n" + response += b"Server: SimpleHTTP/0.6 Python/3.6.1\r\n" + response += b"Date: Fri, 02 Jun 2017 20:40:39 GMT\r\n" + response += b"Content-type: application/octet-stream\r\n" + response += b"Content-Length: 0\r\n" + response += b"Last-Modified: Fri, 02 Jun 2017 20:40:00 GMT\r\n" + response += b"Connection: close\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_response_takes_a_few_goes(self): + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = self.initiate(_host, _path) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"\r\n" + + split = len(response) // 2 + + ws.receive_bytes(response[:split]) + with pytest.raises(StopIteration): + next(ws.events()) + + ws.receive_bytes(response[split:]) + assert isinstance(next(ws.events()), ConnectionEstablished) + + def test_subprotocol_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + for subprotocol in subprotocols: + assert subprotocol in headers['sec-websocket-protocol'] + + def test_subprotocol_accept(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Protocol: " + \ + subprotocols[0].encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + event = next(ws.events()) + assert isinstance(event, ConnectionEstablished) + assert event.subprotocol == subprotocols[0] + + def test_subprotocol_accept_unoffered(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Protocol: three\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + class TestServerUpgrade(object): def test_correct_request(self): @@ -165,3 +388,334 @@ def test_correct_request(self): assert headers['connection'].lower() == 'upgrade' assert headers['upgrade'].lower() == 'websocket' assert headers['sec-websocket-accept'] == accept_token.decode('ascii') + + def test_wrong_method(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'POST ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_bad_connection(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Zoinks\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_bad_upgrade(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebPocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_missing_version(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_missing_key(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_subprotocol_offers(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + def test_accept_subprotocol(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + ws.accept(event, 'two') + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert int(code) == 101 + assert headers['sec-websocket-protocol'] == 'two' + + def test_accept_wrong_subprotocol(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + with pytest.raises(ValueError): + ws.accept(event, 'three') + + def test_simple_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=True) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == ext.name + assert headers['sec-websocket-extensions'] == ext.name + + def test_extension_negotiation_with_our_parameters(self): + test_host = 'frob.nitz' + test_path = '/fnord' + offered_params = 'parameter1=value3; parameter2=value4' + ext_params = 'parameter1=value1; parameter2=value2' + ext = FakeExtension(accept_response=ext_params) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'; ' + \ + offered_params.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == '%s; %s' % (ext.name, offered_params) + assert headers['sec-websocket-extensions'] == \ + '%s; %s' % (ext.name, ext_params) + + def test_disinterested_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=False) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == ext.name + assert 'sec-websocket-extensions' not in headers + + def test_unwanted_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=False) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: pretend\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert 'sec-websocket-extensions' not in headers + + def test_not_an_http_request_at_all(self): + ws = WSConnection(SERVER) + + request = b'Good god, what is this?\r\n\r\n' + + ws.receive_bytes(request) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_h11_somehow_loses_its_mind(self): + ws = WSConnection(SERVER) + ws._upgrade_connection.next_event = lambda: object() + + ws.receive_bytes(b'') + assert isinstance(next(ws.events()), ConnectionFailed) From 416d65e9a580493a2f0c5846d25a0bd79d0c6db4 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 16:14:44 -0700 Subject: [PATCH 12/17] Test the rest of the connection mechanics. --- test/test_connection.py | 242 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 test/test_connection.py diff --git a/test/test_connection.py b/test/test_connection.py new file mode 100644 index 0000000..ffbd058 --- /dev/null +++ b/test/test_connection.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- + +import pytest + +from wsproto.connection import WSConnection, CLIENT, SERVER, ConnectionState +from wsproto.events import (ConnectionClosed, TextReceived, BytesReceived) +from wsproto.frame_protocol import CloseReason, FrameProtocol + + +class FakeProtocol(object): + def __init__(self): + self.send_data_response = None + self.close_response = None + self.received_frames_response = [] + + self.send_data_payload = None + self.send_data_final = None + self.close_code = None + self.close_reason = None + self.receive_bytes_bytes = None + + def send_data(self, payload, final): + self.send_data_payload = payload + self.send_data_final = final + return self.send_data_response + + def close(self, code, reason): + self.close_code = code + self.close_reason = reason + return self.close_response + + def receive_bytes(self, data): + self.receive_bytes_bytes = data + + def received_frames(self): + return self.received_frames_response + + +class TestConnection(object): + @pytest.mark.parametrize('final', [True, False]) + def test_send_data(self, final): + data = b'x' * 23 + payload = b'y' * 23 + + proto = FakeProtocol() + proto.send_data_response = payload + + connection = WSConnection(SERVER) + connection._proto = proto + connection.send_data(data, final) + + assert proto.send_data_payload == data + assert proto.send_data_final is final + assert connection.bytes_to_send() == payload + + @pytest.mark.parametrize('code,reason', [ + (CloseReason.NORMAL_CLOSURE, u'bye'), + (CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), + ]) + def test_close(self, code, reason): + payload = b'y' * 23 + + proto = FakeProtocol() + proto.close_response = payload + + connection = WSConnection(SERVER) + connection._proto = proto + connection.close(code, reason) + + assert proto.close_code is code + assert proto.close_reason == reason + assert connection.bytes_to_send() == payload + + def test_normal_closure(self): + payload = b'y' * 23 + + proto = FakeProtocol() + proto.close_response = payload + + connection = WSConnection(SERVER) + connection._proto = proto + connection.close() + + connection.bytes_to_send() + connection.receive_bytes(None) + with pytest.raises(StopIteration): + next(connection.events()) + assert connection.closed + + def test_abnormal_closure(self): + payload = b'y' * 23 + + proto = FakeProtocol() + proto.close_response = payload + + connection = WSConnection(SERVER) + connection._proto = proto + connection._state = ConnectionState.OPEN + + connection.receive_bytes(None) + assert isinstance(next(connection.events()), ConnectionClosed) + assert connection.closed + + def test_bytes_send_all(self): + connection = WSConnection(SERVER) + connection._outgoing = b'fnord fnord' + assert connection.bytes_to_send() == b'fnord fnord' + assert connection.bytes_to_send() == b'' + + def test_bytes_send_some(self): + connection = WSConnection(SERVER) + connection._outgoing = b'fnord fnord' + assert connection.bytes_to_send(5) == b'fnord' + assert connection.bytes_to_send() == b' fnord' + + def test_receive_bytes(self): + payload = b'y' * 23 + + proto = FakeProtocol() + + connection = WSConnection(SERVER) + connection._proto = proto + connection._state = ConnectionState.OPEN + + connection.receive_bytes(payload) + assert proto.receive_bytes_bytes == payload + + def test_events_ping(self): + payload = b'x' * 23 + frame = b'\x89' + bytearray([len(payload)]) + payload + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = FrameProtocol(True, []) + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(frame) + with pytest.raises(StopIteration): + next(connection.events()) + output = connection.bytes_to_send() + assert output[:2] == b'\x8a' + bytearray([len(payload) | 0x80]) + + def test_events_close(self): + payload = b'\x03\xe8' + b'x' * 23 + frame = b'\x88' + bytearray([len(payload)]) + payload + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = FrameProtocol(True, []) + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(frame) + event = next(connection.events()) + assert isinstance(event, ConnectionClosed) + assert event.code == CloseReason.NORMAL_CLOSURE + assert event.reason == payload[2:].decode('utf8') + + output = connection.bytes_to_send() + assert output[:2] == b'\x88' + bytearray([len(payload) | 0x80]) + + @pytest.mark.parametrize('text,payload,full_message,full_frame', [ + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', True, True), + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', False, True), + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', False, False), + (False, b'x' * 23, True, True), + (False, b'x' * 23, False, True), + (False, b'x' * 23, False, False), + ]) + def test_data_events(self, text, payload, full_message, full_frame): + if text: + opcode = 0x01 + encoded_payload = payload.encode('utf8') + else: + opcode = 0x02 + encoded_payload = payload + + if full_message: + opcode = bytearray([opcode | 0x80]) + else: + opcode = bytearray([opcode]) + + if full_frame: + length = bytearray([len(encoded_payload)]) + else: + length = bytearray([len(encoded_payload) + 100]) + + frame = opcode + length + encoded_payload + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = FrameProtocol(True, []) + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(frame) + event = next(connection.events()) + if text: + assert isinstance(event, TextReceived) + else: + assert isinstance(event, BytesReceived) + assert event.data == payload + assert event.frame_finished is full_frame + assert event.message_finished is full_message + + assert not connection.bytes_to_send() + + def test_frame_protocol_somehow_loses_its_mind(self): + class FailFrame(object): + opcode = object() + + class DoomProtocol(object): + def receive_bytes(self, data): + return None + + def received_frames(self): + return [FailFrame()] + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = DoomProtocol() + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(b'') + with pytest.raises(StopIteration): + next(connection.events()) + assert not connection.bytes_to_send() + + def test_frame_protocol_gets_fed_garbage(self): + payload = b'x' * 23 + frame = b'\x09' + bytearray([len(payload)]) + payload + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = FrameProtocol(True, []) + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(frame) + event = next(connection.events()) + assert isinstance(event, ConnectionClosed) + assert event.code == CloseReason.PROTOCOL_ERROR + + output = connection.bytes_to_send() + assert output[:1] == b'\x88' From 4628adc38789f52e8e2ef0cdf600b9fbed7b30ab Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Fri, 2 Jun 2017 16:27:24 -0700 Subject: [PATCH 13/17] Test events (really event __repr__) --- test/test_events.py | 80 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 test/test_events.py diff --git a/test/test_events.py b/test/test_events.py new file mode 100644 index 0000000..6b280ba --- /dev/null +++ b/test/test_events.py @@ -0,0 +1,80 @@ +import pytest + +from h11 import Request + +from wsproto.events import ( + ConnectionClosed, + ConnectionEstablished, + ConnectionRequested, +) +from wsproto.frame_protocol import CloseReason + + +def test_connection_requested_repr_no_subprotocol(): + method = b'GET' + target = b'/foo' + headers = { + b'host': b'localhost', + b'sec-websocket-version': b'13', + } + http_version = b'1.1' + + req = Request(method=method, target=target, headers=list(headers.items()), + http_version=http_version) + + event = ConnectionRequested([], req) + r = repr(event) + + assert 'ConnectionRequested' in r + assert target.decode('ascii') in r + + +def test_connection_requested_repr_with_subprotocol(): + method = b'GET' + target = b'/foo' + headers = { + b'host': b'localhost', + b'sec-websocket-version': b'13', + b'sec-websocket-protocol': b'fnord', + } + http_version = b'1.1' + + req = Request(method=method, target=target, headers=list(headers.items()), + http_version=http_version) + + event = ConnectionRequested([], req) + r = repr(event) + + assert 'ConnectionRequested' in r + assert target.decode('ascii') in r + assert headers[b'sec-websocket-protocol'].decode('ascii') in r + + +@pytest.mark.parametrize('subprotocol,extensions', [ + ('sproto', None), + (None, ['fake']), + ('sprout', ['pretend']), +]) +def test_connection_established_repr(subprotocol, extensions): + event = ConnectionEstablished(subprotocol, extensions) + r = repr(event) + + if subprotocol: + assert subprotocol in r + if extensions: + for extension in extensions: + assert extension in r + + +@pytest.mark.parametrize('code,reason', [ + (CloseReason.NORMAL_CLOSURE, None), + (CloseReason.NORMAL_CLOSURE, 'because i felt like it'), + (CloseReason.INVALID_FRAME_PAYLOAD_DATA, 'GOOD GOD WHAT DID YOU DO'), +]) +def test_connection_closed_repr(code, reason): + event = ConnectionClosed(code, reason) + r = repr(event) + + assert repr(code) in r + if reason: + assert reason in r From b082068903874b56f6e5a2ffa50a2491612deccc Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Wed, 7 Jun 2017 14:58:13 -0400 Subject: [PATCH 14/17] Add ping methods. --- wsproto/connection.py | 7 +++++++ wsproto/frame_protocol.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/wsproto/connection.py b/wsproto/connection.py index c9f709d..0a36a39 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -433,3 +433,10 @@ def accept(self, event, subprotocol=None): self._outgoing += self._upgrade_connection.send(response) self._proto = FrameProtocol(self.client, self.extensions) self._state = ConnectionState.OPEN + + def ping(self, payload=None): + if payload is not None: + payload = bytes(payload) + else: + payload = b'' + self._outgoing += self._proto.ping(payload) diff --git a/wsproto/frame_protocol.py b/wsproto/frame_protocol.py index 9426826..e94c68f 100644 --- a/wsproto/frame_protocol.py +++ b/wsproto/frame_protocol.py @@ -492,6 +492,9 @@ def close(self, code=None, reason=None): return self._serialize_frame(Opcode.CLOSE, payload) + def ping(self, payload=b''): + return self._serialize_frame(Opcode.PING, payload) + def pong(self, payload=b''): return self._serialize_frame(Opcode.PONG, payload) From c31fb97a18e57b1685775d69c86c01d93b4a11db Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Wed, 7 Jun 2017 14:58:27 -0400 Subject: [PATCH 15/17] Test new ping methods. --- test/test_frame_protocol.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_frame_protocol.py b/test/test_frame_protocol.py index 0c25c07..d02397e 100644 --- a/test/test_frame_protocol.py +++ b/test/test_frame_protocol.py @@ -1117,6 +1117,17 @@ def test_local_only_close_reason(self): data = proto.close(code=fp.CloseReason.NO_STATUS_RCVD) assert data == b'\x88\x02\x03\xe8' + def test_ping_without_payload(self): + proto = fp.FrameProtocol(client=False, extensions=[]) + data = proto.ping() + assert data == b'\x89\x00' + + def test_ping_with_payload(self): + proto = fp.FrameProtocol(client=False, extensions=[]) + payload = u'ยฏ\_(ใƒ„)_/ยฏ'.encode('utf8') + data = proto.ping(payload) + assert data == b'\x89' + bytearray([len(payload)]) + payload + def test_pong_without_payload(self): proto = fp.FrameProtocol(client=False, extensions=[]) data = proto.pong() From 365e120dbeafb08014614121c544cb416bddfc13 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Wed, 7 Jun 2017 14:58:47 -0400 Subject: [PATCH 16/17] Rewrite connection tests so we're not dealing with implementation details. --- test/test_connection.py | 236 +++++++++++++++++++--------------------- 1 file changed, 114 insertions(+), 122 deletions(-) diff --git a/test/test_connection.py b/test/test_connection.py index ffbd058..4844ed4 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,105 +1,105 @@ # -*- coding: utf-8 -*- +import itertools + import pytest from wsproto.connection import WSConnection, CLIENT, SERVER, ConnectionState -from wsproto.events import (ConnectionClosed, TextReceived, BytesReceived) +from wsproto.events import ( + ConnectionClosed, + ConnectionEstablished, + ConnectionRequested, + TextReceived, + BytesReceived, +) from wsproto.frame_protocol import CloseReason, FrameProtocol -class FakeProtocol(object): - def __init__(self): - self.send_data_response = None - self.close_response = None - self.received_frames_response = [] - - self.send_data_payload = None - self.send_data_final = None - self.close_code = None - self.close_reason = None - self.receive_bytes_bytes = None +class TestConnection(object): + def create_connection(self): + server = WSConnection(SERVER) + client = WSConnection(CLIENT, host='localhost', resource='foo') - def send_data(self, payload, final): - self.send_data_payload = payload - self.send_data_final = final - return self.send_data_response + server.receive_bytes(client.bytes_to_send()) + event = next(server.events()) + assert isinstance(event, ConnectionRequested) - def close(self, code, reason): - self.close_code = code - self.close_reason = reason - return self.close_response + server.accept(event) + client.receive_bytes(server.bytes_to_send()) + assert isinstance(next(client.events()), ConnectionEstablished) - def receive_bytes(self, data): - self.receive_bytes_bytes = data + return client, server - def received_frames(self): - return self.received_frames_response + def test_negotiation(self): + self.create_connection() + @pytest.mark.parametrize('as_client,final', [ + (True, True), + (True, False), + (False, True), + (False, False) + ]) + def test_send_and_receive(self, as_client, final): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client -class TestConnection(object): - @pytest.mark.parametrize('final', [True, False]) - def test_send_data(self, final): data = b'x' * 23 - payload = b'y' * 23 - - proto = FakeProtocol() - proto.send_data_response = payload - connection = WSConnection(SERVER) - connection._proto = proto - connection.send_data(data, final) + me.send_data(data, final) + them.receive_bytes(me.bytes_to_send()) - assert proto.send_data_payload == data - assert proto.send_data_final is final - assert connection.bytes_to_send() == payload + event = next(them.events()) + assert isinstance(event, BytesReceived) + assert event.data == data + assert event.message_finished is final - @pytest.mark.parametrize('code,reason', [ - (CloseReason.NORMAL_CLOSURE, u'bye'), - (CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), + @pytest.mark.parametrize('as_client,code,reason', [ + (True, CloseReason.NORMAL_CLOSURE, u'bye'), + (True, CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), + (False, CloseReason.NORMAL_CLOSURE, u'bye'), + (False, CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), ]) - def test_close(self, code, reason): - payload = b'y' * 23 - - proto = FakeProtocol() - proto.close_response = payload + def test_close(self, as_client, code, reason): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client - connection = WSConnection(SERVER) - connection._proto = proto - connection.close(code, reason) + me.close(code, reason) + them.receive_bytes(me.bytes_to_send()) - assert proto.close_code is code - assert proto.close_reason == reason - assert connection.bytes_to_send() == payload + event = next(them.events()) + assert isinstance(event, ConnectionClosed) + assert event.code is code + assert event.reason == reason def test_normal_closure(self): - payload = b'y' * 23 - - proto = FakeProtocol() - proto.close_response = payload - - connection = WSConnection(SERVER) - connection._proto = proto - connection.close() + client, server = self.create_connection() - connection.bytes_to_send() - connection.receive_bytes(None) - with pytest.raises(StopIteration): - next(connection.events()) - assert connection.closed + for conn in (client, server): + conn.close() + conn.receive_bytes(None) + with pytest.raises(StopIteration): + print(repr(next(conn.events()))) + assert conn.closed def test_abnormal_closure(self): - payload = b'y' * 23 - - proto = FakeProtocol() - proto.close_response = payload - - connection = WSConnection(SERVER) - connection._proto = proto - connection._state = ConnectionState.OPEN + client, server = self.create_connection() - connection.receive_bytes(None) - assert isinstance(next(connection.events()), ConnectionClosed) - assert connection.closed + for conn in (client, server): + conn.receive_bytes(None) + event = next(conn.events()) + assert isinstance(event, ConnectionClosed) + assert event.code is CloseReason.ABNORMAL_CLOSURE + assert conn.closed def test_bytes_send_all(self): connection = WSConnection(SERVER) @@ -113,50 +113,45 @@ def test_bytes_send_some(self): assert connection.bytes_to_send(5) == b'fnord' assert connection.bytes_to_send() == b' fnord' - def test_receive_bytes(self): - payload = b'y' * 23 - - proto = FakeProtocol() - - connection = WSConnection(SERVER) - connection._proto = proto - connection._state = ConnectionState.OPEN - - connection.receive_bytes(payload) - assert proto.receive_bytes_bytes == payload + @pytest.mark.parametrize('as_client', [True, False]) + def test_ping_pong(self, as_client): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client - def test_events_ping(self): payload = b'x' * 23 - frame = b'\x89' + bytearray([len(payload)]) + payload - connection = WSConnection(CLIENT, host='localhost', resource='foo') - connection._proto = FrameProtocol(True, []) - connection._state = ConnectionState.OPEN - connection.bytes_to_send() + me.ping(payload) + wire_data = me.bytes_to_send() + assert wire_data[0] == 0x89 + masked = bool(wire_data[1] & 0x80) + assert wire_data[1] & ~0x80 == len(payload) + if masked: + maskbytes = itertools.cycle(bytearray(wire_data[2:6])) + data = bytearray(b ^ next(maskbytes) + for b in bytearray(wire_data[6:])) + else: + data = wire_data[2:] + assert data == payload - connection.receive_bytes(frame) + them.receive_bytes(wire_data) with pytest.raises(StopIteration): - next(connection.events()) - output = connection.bytes_to_send() - assert output[:2] == b'\x8a' + bytearray([len(payload) | 0x80]) - - def test_events_close(self): - payload = b'\x03\xe8' + b'x' * 23 - frame = b'\x88' + bytearray([len(payload)]) + payload - - connection = WSConnection(CLIENT, host='localhost', resource='foo') - connection._proto = FrameProtocol(True, []) - connection._state = ConnectionState.OPEN - connection.bytes_to_send() - - connection.receive_bytes(frame) - event = next(connection.events()) - assert isinstance(event, ConnectionClosed) - assert event.code == CloseReason.NORMAL_CLOSURE - assert event.reason == payload[2:].decode('utf8') - - output = connection.bytes_to_send() - assert output[:2] == b'\x88' + bytearray([len(payload) | 0x80]) + print(repr(next(them.events()))) + wire_data = them.bytes_to_send() + assert wire_data[0] == 0x8a + masked = bool(wire_data[1] & 0x80) + assert wire_data[1] & ~0x80 == len(payload) + if masked: + maskbytes = itertools.cycle(bytearray(wire_data[2:6])) + data = bytearray(b ^ next(maskbytes) + for b in bytearray(wire_data[6:])) + else: + data = wire_data[2:] + assert data == payload @pytest.mark.parametrize('text,payload,full_message,full_frame', [ (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', True, True), @@ -225,18 +220,15 @@ def received_frames(self): assert not connection.bytes_to_send() def test_frame_protocol_gets_fed_garbage(self): + client, server = self.create_connection() + payload = b'x' * 23 frame = b'\x09' + bytearray([len(payload)]) + payload - connection = WSConnection(CLIENT, host='localhost', resource='foo') - connection._proto = FrameProtocol(True, []) - connection._state = ConnectionState.OPEN - connection.bytes_to_send() - - connection.receive_bytes(frame) - event = next(connection.events()) + client.receive_bytes(frame) + event = next(client.events()) assert isinstance(event, ConnectionClosed) assert event.code == CloseReason.PROTOCOL_ERROR - output = connection.bytes_to_send() + output = client.bytes_to_send() assert output[:1] == b'\x88' From 3e933440fcdda1f6fc25b3d95f1872c4f13d1888 Mon Sep 17 00:00:00 2001 From: Benno Rice Date: Wed, 7 Jun 2017 15:54:13 -0400 Subject: [PATCH 17/17] Rewrite the connection level ping method to fix coverage. --- wsproto/connection.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/wsproto/connection.py b/wsproto/connection.py index 0a36a39..0c11e33 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -435,8 +435,5 @@ def accept(self, event, subprotocol=None): self._state = ConnectionState.OPEN def ping(self, payload=None): - if payload is not None: - payload = bytes(payload) - else: - payload = b'' + payload = bytes(payload or b'') self._outgoing += self._proto.ping(payload)