diff --git a/test/test_connection.py b/test/test_connection.py new file mode 100644 index 0000000..4844ed4 --- /dev/null +++ b/test/test_connection.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- + +import itertools + +import pytest + +from wsproto.connection import WSConnection, CLIENT, SERVER, ConnectionState +from wsproto.events import ( + ConnectionClosed, + ConnectionEstablished, + ConnectionRequested, + TextReceived, + BytesReceived, +) +from wsproto.frame_protocol import CloseReason, FrameProtocol + + +class TestConnection(object): + def create_connection(self): + server = WSConnection(SERVER) + client = WSConnection(CLIENT, host='localhost', resource='foo') + + server.receive_bytes(client.bytes_to_send()) + event = next(server.events()) + assert isinstance(event, ConnectionRequested) + + server.accept(event) + client.receive_bytes(server.bytes_to_send()) + assert isinstance(next(client.events()), ConnectionEstablished) + + return client, server + + def test_negotiation(self): + self.create_connection() + + @pytest.mark.parametrize('as_client,final', [ + (True, True), + (True, False), + (False, True), + (False, False) + ]) + def test_send_and_receive(self, as_client, final): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client + + data = b'x' * 23 + + me.send_data(data, final) + them.receive_bytes(me.bytes_to_send()) + + event = next(them.events()) + assert isinstance(event, BytesReceived) + assert event.data == data + assert event.message_finished is final + + @pytest.mark.parametrize('as_client,code,reason', [ + (True, CloseReason.NORMAL_CLOSURE, u'bye'), + (True, CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), + (False, CloseReason.NORMAL_CLOSURE, u'bye'), + (False, CloseReason.GOING_AWAY, u'๐Ÿ‘‹๐Ÿ‘‹'), + ]) + def test_close(self, as_client, code, reason): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client + + me.close(code, reason) + them.receive_bytes(me.bytes_to_send()) + + event = next(them.events()) + assert isinstance(event, ConnectionClosed) + assert event.code is code + assert event.reason == reason + + def test_normal_closure(self): + client, server = self.create_connection() + + for conn in (client, server): + conn.close() + conn.receive_bytes(None) + with pytest.raises(StopIteration): + print(repr(next(conn.events()))) + assert conn.closed + + def test_abnormal_closure(self): + client, server = self.create_connection() + + for conn in (client, server): + conn.receive_bytes(None) + event = next(conn.events()) + assert isinstance(event, ConnectionClosed) + assert event.code is CloseReason.ABNORMAL_CLOSURE + assert conn.closed + + def test_bytes_send_all(self): + connection = WSConnection(SERVER) + connection._outgoing = b'fnord fnord' + assert connection.bytes_to_send() == b'fnord fnord' + assert connection.bytes_to_send() == b'' + + def test_bytes_send_some(self): + connection = WSConnection(SERVER) + connection._outgoing = b'fnord fnord' + assert connection.bytes_to_send(5) == b'fnord' + assert connection.bytes_to_send() == b' fnord' + + @pytest.mark.parametrize('as_client', [True, False]) + def test_ping_pong(self, as_client): + client, server = self.create_connection() + if as_client: + me = client + them = server + else: + me = server + them = client + + payload = b'x' * 23 + + me.ping(payload) + wire_data = me.bytes_to_send() + assert wire_data[0] == 0x89 + masked = bool(wire_data[1] & 0x80) + assert wire_data[1] & ~0x80 == len(payload) + if masked: + maskbytes = itertools.cycle(bytearray(wire_data[2:6])) + data = bytearray(b ^ next(maskbytes) + for b in bytearray(wire_data[6:])) + else: + data = wire_data[2:] + assert data == payload + + them.receive_bytes(wire_data) + with pytest.raises(StopIteration): + print(repr(next(them.events()))) + wire_data = them.bytes_to_send() + assert wire_data[0] == 0x8a + masked = bool(wire_data[1] & 0x80) + assert wire_data[1] & ~0x80 == len(payload) + if masked: + maskbytes = itertools.cycle(bytearray(wire_data[2:6])) + data = bytearray(b ^ next(maskbytes) + for b in bytearray(wire_data[6:])) + else: + data = wire_data[2:] + assert data == payload + + @pytest.mark.parametrize('text,payload,full_message,full_frame', [ + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', True, True), + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', False, True), + (True, u'ฦ’รฑรถยฎโˆ‚๐Ÿ˜Ž', False, False), + (False, b'x' * 23, True, True), + (False, b'x' * 23, False, True), + (False, b'x' * 23, False, False), + ]) + def test_data_events(self, text, payload, full_message, full_frame): + if text: + opcode = 0x01 + encoded_payload = payload.encode('utf8') + else: + opcode = 0x02 + encoded_payload = payload + + if full_message: + opcode = bytearray([opcode | 0x80]) + else: + opcode = bytearray([opcode]) + + if full_frame: + length = bytearray([len(encoded_payload)]) + else: + length = bytearray([len(encoded_payload) + 100]) + + frame = opcode + length + encoded_payload + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = FrameProtocol(True, []) + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(frame) + event = next(connection.events()) + if text: + assert isinstance(event, TextReceived) + else: + assert isinstance(event, BytesReceived) + assert event.data == payload + assert event.frame_finished is full_frame + assert event.message_finished is full_message + + assert not connection.bytes_to_send() + + def test_frame_protocol_somehow_loses_its_mind(self): + class FailFrame(object): + opcode = object() + + class DoomProtocol(object): + def receive_bytes(self, data): + return None + + def received_frames(self): + return [FailFrame()] + + connection = WSConnection(CLIENT, host='localhost', resource='foo') + connection._proto = DoomProtocol() + connection._state = ConnectionState.OPEN + connection.bytes_to_send() + + connection.receive_bytes(b'') + with pytest.raises(StopIteration): + next(connection.events()) + assert not connection.bytes_to_send() + + def test_frame_protocol_gets_fed_garbage(self): + client, server = self.create_connection() + + payload = b'x' * 23 + frame = b'\x09' + bytearray([len(payload)]) + payload + + client.receive_bytes(frame) + event = next(client.events()) + assert isinstance(event, ConnectionClosed) + assert event.code == CloseReason.PROTOCOL_ERROR + + output = client.bytes_to_send() + assert output[:1] == b'\x88' diff --git a/test/test_events.py b/test/test_events.py new file mode 100644 index 0000000..6b280ba --- /dev/null +++ b/test/test_events.py @@ -0,0 +1,80 @@ +import pytest + +from h11 import Request + +from wsproto.events import ( + ConnectionClosed, + ConnectionEstablished, + ConnectionRequested, +) +from wsproto.frame_protocol import CloseReason + + +def test_connection_requested_repr_no_subprotocol(): + method = b'GET' + target = b'/foo' + headers = { + b'host': b'localhost', + b'sec-websocket-version': b'13', + } + http_version = b'1.1' + + req = Request(method=method, target=target, headers=list(headers.items()), + http_version=http_version) + + event = ConnectionRequested([], req) + r = repr(event) + + assert 'ConnectionRequested' in r + assert target.decode('ascii') in r + + +def test_connection_requested_repr_with_subprotocol(): + method = b'GET' + target = b'/foo' + headers = { + b'host': b'localhost', + b'sec-websocket-version': b'13', + b'sec-websocket-protocol': b'fnord', + } + http_version = b'1.1' + + req = Request(method=method, target=target, headers=list(headers.items()), + http_version=http_version) + + event = ConnectionRequested([], req) + r = repr(event) + + assert 'ConnectionRequested' in r + assert target.decode('ascii') in r + assert headers[b'sec-websocket-protocol'].decode('ascii') in r + + +@pytest.mark.parametrize('subprotocol,extensions', [ + ('sproto', None), + (None, ['fake']), + ('sprout', ['pretend']), +]) +def test_connection_established_repr(subprotocol, extensions): + event = ConnectionEstablished(subprotocol, extensions) + r = repr(event) + + if subprotocol: + assert subprotocol in r + if extensions: + for extension in extensions: + assert extension in r + + +@pytest.mark.parametrize('code,reason', [ + (CloseReason.NORMAL_CLOSURE, None), + (CloseReason.NORMAL_CLOSURE, 'because i felt like it'), + (CloseReason.INVALID_FRAME_PAYLOAD_DATA, 'GOOD GOD WHAT DID YOU DO'), +]) +def test_connection_closed_repr(code, reason): + event = ConnectionClosed(code, reason) + r = repr(event) + + assert repr(code) in r + if reason: + assert reason in r diff --git a/test/test_extensions.py b/test/test_extensions.py new file mode 100644 index 0000000..5f41806 --- /dev/null +++ b/test/test_extensions.py @@ -0,0 +1,40 @@ +import wsproto.extensions as wpext +import wsproto.frame_protocol as fp + + +class TestExtension(object): + def test_enabled(self): + ext = wpext.Extension() + assert not ext.enabled() + + def test_offer(self): + ext = wpext.Extension() + assert ext.offer(None) is None + + def test_accept(self): + ext = wpext.Extension() + assert ext.accept(None, None) is None + + def test_finalize(self): + ext = wpext.Extension() + assert ext.finalize(None, None) is None + + def test_frame_inbound_header(self): + ext = wpext.Extension() + result = ext.frame_inbound_header(None, None, None, None) + assert result == fp.RsvBits(False, False, False) + + def test_frame_inbound_payload_data(self): + ext = wpext.Extension() + data = object() + assert ext.frame_inbound_payload_data(None, data) == data + + def test_frame_inbound_complete(self): + ext = wpext.Extension() + assert ext.frame_inbound_complete(None, None) is None + + def test_frame_outbound(self): + ext = wpext.Extension() + rsv = fp.RsvBits(True, True, True) + data = object() + assert ext.frame_outbound(None, None, rsv, data, None) == (rsv, data) diff --git a/test/test_frame_protocol.py b/test/test_frame_protocol.py index 0c25c07..d02397e 100644 --- a/test/test_frame_protocol.py +++ b/test/test_frame_protocol.py @@ -1117,6 +1117,17 @@ def test_local_only_close_reason(self): data = proto.close(code=fp.CloseReason.NO_STATUS_RCVD) assert data == b'\x88\x02\x03\xe8' + def test_ping_without_payload(self): + proto = fp.FrameProtocol(client=False, extensions=[]) + data = proto.ping() + assert data == b'\x89\x00' + + def test_ping_with_payload(self): + proto = fp.FrameProtocol(client=False, extensions=[]) + payload = u'ยฏ\_(ใƒ„)_/ยฏ'.encode('utf8') + data = proto.ping(payload) + assert data == b'\x89' + bytearray([len(payload)]) + payload + def test_pong_without_payload(self): proto = fp.FrameProtocol(client=False, extensions=[]) data = proto.pong() diff --git a/test/test_permessage_deflate.py b/test/test_permessage_deflate.py new file mode 100644 index 0000000..e1054d1 --- /dev/null +++ b/test/test_permessage_deflate.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- + +import zlib + +import pytest + +import wsproto.extensions as wpext +import wsproto.frame_protocol as fp + + +class TestPerMessageDeflate(object): + parameter_sets = [ + { + 'client_no_context_takeover': False, + 'client_max_window_bits': 15, + 'server_no_context_takeover': False, + 'server_max_window_bits': 15, + }, + { + 'client_no_context_takeover': True, + 'client_max_window_bits': 9, + 'server_no_context_takeover': False, + 'server_max_window_bits': 15, + }, + { + 'client_no_context_takeover': False, + 'client_max_window_bits': 15, + 'server_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'client_no_context_takeover': True, + 'client_max_window_bits': 8, + 'server_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'client_no_context_takeover': True, + 'server_max_window_bits': 9, + }, + { + 'server_no_context_takeover': True, + 'client_max_window_bits': 8, + }, + { + 'client_max_window_bits': None, + 'server_max_window_bits': None, + }, + {}, + ] + + def make_offer_string(self, params): + offer = ['permessage-deflate'] + + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + offer.append('client_max_window_bits') + else: + offer.append('client_max_window_bits=%d' % + params['client_max_window_bits']) + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + offer.append('server_max_window_bits') + else: + offer.append('server_max_window_bits=%d' % + params['server_max_window_bits']) + if params.get('client_no_context_takeover', False): + offer.append('client_no_context_takeover') + if params.get('server_no_context_takeover', False): + offer.append('server_no_context_takeover') + + return '; '.join(offer) + + def compare_params_to_string(self, params, ext, param_string): + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + bits = ext.client_max_window_bits + else: + bits = params['client_max_window_bits'] + assert 'client_max_window_bits=%d' % bits in param_string + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + bits = ext.server_max_window_bits + else: + bits = params['server_max_window_bits'] + assert 'server_max_window_bits=%d' % bits in param_string + if params.get('client_no_context_takeover', False): + assert 'client_no_context_takeover' in param_string + if params.get('server_no_context_takeover', False): + assert 'server_no_context_takeover' in param_string + + @pytest.mark.parametrize('params', parameter_sets) + def test_offer(self, params): + ext = wpext.PerMessageDeflate(**params) + offer = ext.offer(None) + + self.compare_params_to_string(params, ext, offer) + + @pytest.mark.parametrize('params', parameter_sets) + def test_finalize(self, params): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + params = dict(params) + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + del params['client_max_window_bits'] + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + del params['server_max_window_bits'] + offer = self.make_offer_string(params) + ext.finalize(None, offer) + + if params.get('client_max_window_bits', None): + assert ext.client_max_window_bits == \ + params['client_max_window_bits'] + if params.get('server_max_window_bits', None): + assert ext.server_max_window_bits == \ + params['server_max_window_bits'] + assert ext.client_no_context_takeover is \ + params.get('client_no_context_takeover', False) + assert ext.server_no_context_takeover is \ + params.get('server_no_context_takeover', False) + + assert ext.enabled() + + def test_finalize_ignores_rubbish(self): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + ext.finalize(None, 'i am the lizard queen; worship me') + + assert ext.enabled() + + @pytest.mark.parametrize('params', parameter_sets) + def test_accept(self, params): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + offer = self.make_offer_string(params) + print(repr(offer)) + + response = ext.accept(None, offer) + print(repr(response)) + + if ext.client_no_context_takeover: + assert 'client_no_context_takeover' in response + if ext.server_no_context_takeover: + assert 'server_no_context_takeover' in response + if 'client_max_window_bits' in params: + if params['client_max_window_bits'] is None: + bits = ext.client_max_window_bits + else: + bits = params['client_max_window_bits'] + assert ext.client_max_window_bits == bits + assert 'client_max_window_bits=%d' % bits in response + if 'server_max_window_bits' in params: + if params['server_max_window_bits'] is None: + bits = ext.server_max_window_bits + else: + bits = params['server_max_window_bits'] + assert ext.server_max_window_bits == bits + assert 'server_max_window_bits=%d' % bits in response + + def test_accept_ignores_rubbish(self): + ext = wpext.PerMessageDeflate() + assert not ext.enabled() + + ext.accept(None, 'i am the lizard queen; worship me') + + assert ext.enabled() + + def test_inbound_uncompressed_control_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.PING, + fp.RsvBits(False, False, False), + len(payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, payload) + assert data == payload + + assert ext.frame_inbound_complete(proto, True) is None + + def test_inbound_compressed_control_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.PING, + fp.RsvBits(True, False, False), + len(payload)) + assert result == fp.CloseReason.PROTOCOL_ERROR + + def test_inbound_compressed_continuation_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.CONTINUATION, + fp.RsvBits(True, False, False), + len(payload)) + assert result == fp.CloseReason.PROTOCOL_ERROR + + def test_inbound_uncompressed_data_frame(self): + payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(False, False, False), + len(payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, payload) + assert data == payload + + assert ext.frame_inbound_complete(proto, True) is None + + @pytest.mark.parametrize('client', [True, False]) + def test_client_inbound_compressed_single_data_frame(self, client): + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + + data = ext.frame_inbound_payload_data(proto, compressed_payload) + data += ext.frame_inbound_complete(proto, True) + assert data == payload + + @pytest.mark.parametrize('client', [True, False]) + def test_client_inbound_compressed_multiple_data_frames(self, client): + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + split = 3 + data = b'' + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + split) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, + compressed_payload[:split]) + assert not isinstance(result, fp.CloseReason) + data += result + assert ext.frame_inbound_complete(proto, False) is None + + result = ext.frame_inbound_header(proto, fp.Opcode.CONTINUATION, + fp.RsvBits(False, False, False), + len(compressed_payload) - split) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, + compressed_payload[split:]) + assert not isinstance(result, fp.CloseReason) + data += result + + result = ext.frame_inbound_complete(proto, True) + assert not isinstance(result, fp.CloseReason) + data += result + + assert data == payload + + def test_inbound_bad_zlib_payload(self): + compressed_payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + result = ext.frame_inbound_payload_data(proto, compressed_payload) + assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA + + def test_inbound_bad_zlib_decoder_end_state(self, monkeypatch): + compressed_payload = b'x' * 23 + + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), + len(compressed_payload)) + assert result.rsv1 + + class FailDecompressor(object): + def decompress(self, data): + return b'' + + def flush(self): + raise zlib.error() + + monkeypatch.setattr(ext, '_decompressor', FailDecompressor()) + + result = ext.frame_inbound_complete(proto, True) + assert result is fp.CloseReason.INVALID_FRAME_PAYLOAD_DATA + + @pytest.mark.parametrize('client,no_context_takeover', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_decompressor_reset(self, client, no_context_takeover): + if client: + args = {'server_no_context_takeover': no_context_takeover} + else: + args = {'client_no_context_takeover': no_context_takeover} + ext = wpext.PerMessageDeflate(**args) + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), 0) + assert result.rsv1 + + assert ext._decompressor is not None + + result = ext.frame_inbound_complete(proto, True) + assert not isinstance(result, fp.CloseReason) + + if no_context_takeover: + assert ext._decompressor is None + else: + assert ext._decompressor is not None + + result = ext.frame_inbound_header(proto, fp.Opcode.BINARY, + fp.RsvBits(True, False, False), 0) + assert result.rsv1 + + assert ext._decompressor is not None + + def test_outbound_uncompressible_opcode(self): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=True, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + + rsv, data = ext.frame_outbound(proto, fp.Opcode.PING, rsv, payload, + True) + + assert rsv.rsv1 is False + assert data == payload + + @pytest.mark.parametrize('client', [True, False]) + def test_outbound_compress_single_frame(self, client): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, payload, + True) + + assert rsv.rsv1 is True + assert data == compressed_payload + + @pytest.mark.parametrize('client', [True, False]) + def test_outbound_compress_multiple_frames(self, client): + ext = wpext.PerMessageDeflate() + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + + rsv = fp.RsvBits(False, False, False) + payload = b'x' * 23 + split = 12 + compressed_payload = b'\xaa\xa8\xc0\n\x00\x00' + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, + payload[:split], False) + assert rsv.rsv1 is True + + rsv = fp.RsvBits(False, False, False) + rsv, more_data = ext.frame_outbound(proto, fp.Opcode.CONTINUATION, rsv, + payload[split:], True) + assert rsv.rsv1 is False + assert data + more_data == compressed_payload + + @pytest.mark.parametrize('client,no_context_takeover', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_compressor_reset(self, client, no_context_takeover): + if client: + args = {'client_no_context_takeover': no_context_takeover} + else: + args = {'server_no_context_takeover': no_context_takeover} + ext = wpext.PerMessageDeflate(**args) + ext._enabled = True + proto = fp.FrameProtocol(client=client, extensions=[ext]) + rsv = fp.RsvBits(False, False, False) + + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, b'', + False) + assert rsv.rsv1 is True + assert ext._compressor is not None + + rsv = fp.RsvBits(False, False, False) + rsv, data = ext.frame_outbound(proto, fp.Opcode.CONTINUATION, rsv, b'', + True) + assert rsv.rsv1 is False + if no_context_takeover: + assert ext._compressor is None + else: + assert ext._compressor is not None + + rsv = fp.RsvBits(False, False, False) + rsv, data = ext.frame_outbound(proto, fp.Opcode.BINARY, rsv, b'', + False) + assert rsv.rsv1 is True + assert ext._compressor is not None + + @pytest.mark.parametrize('params', parameter_sets) + def test_repr(self, params): + ext = wpext.PerMessageDeflate(**params) + self.compare_params_to_string(params, ext, repr(ext)) diff --git a/test/test_upgrade.py b/test/test_upgrade.py index 8190b47..163bdfe 100644 --- a/test/test_upgrade.py +++ b/test/test_upgrade.py @@ -6,19 +6,19 @@ import base64 import email import random -import sys +import pytest + +from wsproto.compat import PY3 from wsproto.connection import WSConnection, CLIENT, SERVER from wsproto.events import ( ConnectionEstablished, ConnectionFailed, ConnectionRequested ) - - -IS_PYTHON3 = sys.version_info >= (3, 0) +from wsproto.extensions import Extension def parse_headers(headers): - if IS_PYTHON3: + if PY3: headers = email.message_from_bytes(headers) else: headers = email.message_from_string(headers) @@ -26,6 +26,26 @@ def parse_headers(headers): return dict(headers.items()) +class FakeExtension(Extension): + name = 'fake' + + def __init__(self, offer_response=None, accept_response=None): + self.offer_response = offer_response + self.accepted_offer = None + self.offered = None + self.accept_response = accept_response + + def offer(self, proto): + return self.offer_response + + def finalize(self, proto, offer): + self.accepted_offer = offer + + def accept(self, proto, offer): + self.offered = offer + return self.accept_response + + class TestClientUpgrade(object): def initiate(self, host, path, **kwargs): ws = WSConnection(CLIENT, host, path, **kwargs) @@ -133,6 +153,206 @@ def test_bad_upgrade_header(self): ws.receive_bytes(response) assert isinstance(next(ws.events()), ConnectionFailed) + def test_simple_extension_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert _ext.name == headers['sec-websocket-extensions'] + + def test_simple_extension_non_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=False) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert 'sec-websocket-extensions' not in headers + + def test_extension_offer_with_params(self): + ext_parameters = 'parameter1=value1; parameter2=value2' + _ext = FakeExtension(offer_response=ext_parameters) + + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + assert headers['sec-websocket-extensions'] == \ + '%s; %s' % (_ext.name, ext_parameters) + + def test_simple_extension_accept(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: " + \ + _ext.name.encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionEstablished) + assert _ext.name in _ext.accepted_offer + + def test_extension_accept_with_parameters(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + ext_parameters = 'parameter1=value1; parameter2=value2' + extensions = _ext.name + '; ' + ext_parameters + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: " + \ + extensions.encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionEstablished) + assert _ext.accepted_offer == extensions + + def test_accept_an_extension_we_do_not_recognise(self): + _host = 'frob.nitz' + _path = '/fnord' + _ext = FakeExtension(offer_response=True) + + ws, method, path, version, headers = \ + self.initiate(_host, _path, extensions=[_ext]) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Extensions: pretend\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_wrong_status_code_in_response(self): + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = self.initiate(_host, _path) + + response = b"HTTP/1.1 200 OK\r\n" + response += b"Server: SimpleHTTP/0.6 Python/3.6.1\r\n" + response += b"Date: Fri, 02 Jun 2017 20:40:39 GMT\r\n" + response += b"Content-type: application/octet-stream\r\n" + response += b"Content-Length: 0\r\n" + response += b"Last-Modified: Fri, 02 Jun 2017 20:40:00 GMT\r\n" + response += b"Connection: close\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_response_takes_a_few_goes(self): + _host = 'frob.nitz' + _path = '/fnord' + + ws, method, path, version, headers = self.initiate(_host, _path) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"\r\n" + + split = len(response) // 2 + + ws.receive_bytes(response[:split]) + with pytest.raises(StopIteration): + next(ws.events()) + + ws.receive_bytes(response[split:]) + assert isinstance(next(ws.events()), ConnectionEstablished) + + def test_subprotocol_offer(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + for subprotocol in subprotocols: + assert subprotocol in headers['sec-websocket-protocol'] + + def test_subprotocol_accept(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Protocol: " + \ + subprotocols[0].encode('ascii') + b"\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + event = next(ws.events()) + assert isinstance(event, ConnectionEstablished) + assert event.subprotocol == subprotocols[0] + + def test_subprotocol_accept_unoffered(self): + _host = 'frob.nitz' + _path = '/fnord' + subprotocols = ['one', 'two'] + + ws, method, path, version, headers = \ + self.initiate(_host, _path, subprotocols=subprotocols) + + key = headers['sec-websocket-key'].encode('ascii') + accept_token = ws._generate_accept_token(key) + + response = b"HTTP/1.1 101 Switching Protocols\r\n" + response += b"Connection: Upgrade\r\n" + response += b"Upgrade: WebSocket\r\n" + response += b"Sec-WebSocket-Accept: " + accept_token + b"\r\n" + response += b"Sec-WebSocket-Protocol: three\r\n" + response += b"\r\n" + + ws.receive_bytes(response) + assert isinstance(next(ws.events()), ConnectionFailed) + class TestServerUpgrade(object): def test_correct_request(self): @@ -168,3 +388,334 @@ def test_correct_request(self): assert headers['connection'].lower() == 'upgrade' assert headers['upgrade'].lower() == 'websocket' assert headers['sec-websocket-accept'] == accept_token.decode('ascii') + + def test_wrong_method(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'POST ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_bad_connection(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Zoinks\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_bad_upgrade(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebPocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_missing_version(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_missing_key(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionFailed) + + def test_subprotocol_offers(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + def test_accept_subprotocol(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + ws.accept(event, 'two') + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert int(code) == 101 + assert headers['sec-websocket-protocol'] == 'two' + + def test_accept_wrong_subprotocol(self): + test_host = 'frob.nitz' + test_path = '/fnord' + + ws = WSConnection(SERVER) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b'GET ' + test_path.encode('ascii') + b' HTTP/1.1\r\n' + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Protocol: one, two\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + assert event.proposed_subprotocols == ['one', 'two'] + + with pytest.raises(ValueError): + ws.accept(event, 'three') + + def test_simple_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=True) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == ext.name + assert headers['sec-websocket-extensions'] == ext.name + + def test_extension_negotiation_with_our_parameters(self): + test_host = 'frob.nitz' + test_path = '/fnord' + offered_params = 'parameter1=value3; parameter2=value4' + ext_params = 'parameter1=value1; parameter2=value2' + ext = FakeExtension(accept_response=ext_params) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'; ' + \ + offered_params.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == '%s; %s' % (ext.name, offered_params) + assert headers['sec-websocket-extensions'] == \ + '%s; %s' % (ext.name, ext_params) + + def test_disinterested_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=False) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: ' + \ + ext.name.encode('ascii') + b'\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert ext.offered == ext.name + assert 'sec-websocket-extensions' not in headers + + def test_unwanted_extension_negotiation(self): + test_host = 'frob.nitz' + test_path = '/fnord' + ext = FakeExtension(accept_response=False) + + ws = WSConnection(SERVER, extensions=[ext]) + + nonce = bytes(random.getrandbits(8) for x in range(0, 16)) + nonce = base64.b64encode(nonce) + + request = b"GET " + test_path.encode('ascii') + b" HTTP/1.1\r\n" + request += b'Host: ' + test_host.encode('ascii') + b'\r\n' + request += b'Connection: Upgrade\r\n' + request += b'Upgrade: WebSocket\r\n' + request += b'Sec-WebSocket-Version: 13\r\n' + request += b'Sec-WebSocket-Key: ' + nonce + b'\r\n' + request += b'Sec-WebSocket-Extensions: pretend\r\n' + request += b'\r\n' + + ws.receive_bytes(request) + event = next(ws.events()) + assert isinstance(event, ConnectionRequested) + ws.accept(event) + + data = ws.bytes_to_send() + response, headers = data.split(b'\r\n', 1) + version, code, reason = response.split(b' ') + headers = parse_headers(headers) + + assert 'sec-websocket-extensions' not in headers + + def test_not_an_http_request_at_all(self): + ws = WSConnection(SERVER) + + request = b'Good god, what is this?\r\n\r\n' + + ws.receive_bytes(request) + assert isinstance(next(ws.events()), ConnectionFailed) + + def test_h11_somehow_loses_its_mind(self): + ws = WSConnection(SERVER) + ws._upgrade_connection.next_event = lambda: object() + + ws.receive_bytes(b'') + assert isinstance(next(ws.events()), ConnectionFailed) diff --git a/wsproto/connection.py b/wsproto/connection.py index 9c4aaec..0c11e33 100644 --- a/wsproto/connection.py +++ b/wsproto/connection.py @@ -66,7 +66,7 @@ def _normed_header_dict(h11_headers): # wrong, because those can contain quoted strings, which can in turn contain # commas. XX FIXME def _split_comma_header(value): - return [piece.strip() for piece in value.split(b',')] + return [piece.decode('ascii').strip() for piece in value.split(b',')] class WSConnection(object): @@ -236,14 +236,22 @@ def receive_bytes(self, data): def _process_upgrade(self, data): self._upgrade_connection.receive_data(data) while True: - event = self._upgrade_connection.next_event() + try: + event = self._upgrade_connection.next_event() + except h11.RemoteProtocolError: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' if event is h11.NEED_DATA: break - elif self.client and isinstance(event, h11.InformationalResponse): + elif self.client and isinstance(event, (h11.InformationalResponse, + h11.Response)): data = self._upgrade_connection.trailing_data[0] return self._establish_client_connection(event), data elif not self.client and isinstance(event, h11.Request): return self._process_connection_request(event), None + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' self._incoming = b'' return None, None @@ -264,7 +272,6 @@ def events(self): try: for frame in self._proto.received_frames(): - if frame.opcode is Opcode.PING: assert frame.frame_finished and frame.message_finished self._outgoing += self._proto.pong(frame.payload) @@ -320,6 +327,7 @@ def _establish_client_connection(self, event): subprotocol = headers.get(b'sec-websocket-protocol', None) if subprotocol is not None: + subprotocol = subprotocol.decode('ascii') if subprotocol not in self.subprotocols: return ConnectionFailed(CloseReason.PROTOCOL_ERROR, "unrecognized subprotocol {!r}" @@ -330,7 +338,6 @@ def _establish_client_connection(self, event): accepts = _split_comma_header(extensions) for accept in accepts: - accept = accept.decode('ascii') name = accept.split(';', 1)[0].strip() for extension in self.extensions: if extension.name == name: @@ -400,7 +407,6 @@ def accept(self, event, subprotocol=None): offers = _split_comma_header(extensions) for offer in offers: - offer = offer.decode('ascii') name = offer.split(';', 1)[0].strip() for extension in self.extensions: if extension.name == name: @@ -427,3 +433,7 @@ def accept(self, event, subprotocol=None): self._outgoing += self._upgrade_connection.send(response) self._proto = FrameProtocol(self.client, self.extensions) self._state = ConnectionState.OPEN + + def ping(self, payload=None): + payload = bytes(payload or b'') + self._outgoing += self._proto.ping(payload) diff --git a/wsproto/extensions.py b/wsproto/extensions.py index aaff2f1..ad7daa1 100644 --- a/wsproto/extensions.py +++ b/wsproto/extensions.py @@ -8,7 +8,6 @@ import zlib -from .compat import PY2 from .frame_protocol import CloseReason, Opcode, RsvBits @@ -24,6 +23,9 @@ def offer(self, connection): def accept(self, connection, offer): return None + def finalize(self, connection, offer): + return None + def frame_inbound_header(self, proto, opcode, rsv, payload_length): return RsvBits(False, False, False) @@ -31,7 +33,7 @@ def frame_inbound_payload_data(self, proto, data): return data def frame_inbound_complete(self, proto, fin): - pass + return None def frame_outbound(self, proto, opcode, rsv, data, fin): return (rsv, data) @@ -40,12 +42,19 @@ def frame_outbound(self, proto, opcode, rsv, data, fin): class PerMessageDeflate(Extension): name = 'permessage-deflate' + DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 + DEFAULT_SERVER_MAX_WINDOW_BITS = 15 + def __init__(self, client_no_context_takeover=False, - client_max_window_bits=15, server_no_context_takeover=False, - server_max_window_bits=15): + client_max_window_bits=None, server_no_context_takeover=False, + server_max_window_bits=None): self.client_no_context_takeover = client_no_context_takeover + if client_max_window_bits is None: + client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS self.client_max_window_bits = client_max_window_bits self.server_no_context_takeover = server_no_context_takeover + if server_max_window_bits is None: + server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS self.server_max_window_bits = server_max_window_bits self._compressor = None @@ -163,21 +172,18 @@ def frame_inbound_payload_data(self, proto, data): if not self._inbound_compressed or not self._inbound_is_compressible: return data - if PY2: - data = str(data) - try: - return self._decompressor.decompress(data) + return self._decompressor.decompress(bytes(data)) except zlib.error: return CloseReason.INVALID_FRAME_PAYLOAD_DATA def frame_inbound_complete(self, proto, fin): if not fin: return - elif not self._inbound_compressed: - return elif not self._inbound_is_compressible: return + elif not self._inbound_compressed: + return try: data = self._decompressor.decompress(b'\x00\x00\xff\xff') @@ -213,9 +219,7 @@ def frame_outbound(self, proto, opcode, rsv, data, fin): self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -bits) - if PY2: - data = str(data) - data = self._compressor.compress(data) + data = self._compressor.compress(bytes(data)) if fin: data += self._compressor.flush(zlib.Z_SYNC_FLUSH) diff --git a/wsproto/frame_protocol.py b/wsproto/frame_protocol.py index 9426826..e94c68f 100644 --- a/wsproto/frame_protocol.py +++ b/wsproto/frame_protocol.py @@ -492,6 +492,9 @@ def close(self, code=None, reason=None): return self._serialize_frame(Opcode.CLOSE, payload) + def ping(self, payload=b''): + return self._serialize_frame(Opcode.PING, payload) + def pong(self, payload=b''): return self._serialize_frame(Opcode.PONG, payload)