diff --git a/saltyrtc/server/message.py b/saltyrtc/server/message.py index ff83094..72a55b7 100644 --- a/saltyrtc/server/message.py +++ b/saltyrtc/server/message.py @@ -220,6 +220,7 @@ def unpack(cls, client, data): # Decrypt if directed at us and keys have been exchanged # or just return a raw message to be sent to another client + expect_type = None if destination_type == AddressType.server: data = data[NONCE_LENGTH:] if not client.authenticated and client.type is None: @@ -231,6 +232,8 @@ def unpack(cls, client, data): cls._decrypt_payload(client, nonce, data)) except MessageError: pass + else: + expect_type = MessageType.client_auth # Try client-hello (unencrypted) if payload is None: @@ -238,9 +241,11 @@ def unpack(cls, client, data): payload = cls._unpack_payload(data) except MessageError: payload = None + else: + expect_type = MessageType.client_hello # Still no payload? - if payload is None: + if expect_type is None or payload is None: message = 'Expected either client-hello or client-auth, got neither' raise MessageError(message) else: @@ -266,6 +271,10 @@ def unpack(cls, client, data): except ValueError as exc: raise MessageError('Unknown message type: {}'.format(type_)) from exc + # Ensure type isn't violated + if expect_type is not None and type_ != expect_type: + raise MessageError('Expected type {}, got {}'.format(expect_type, type_)) + # Check and convert payload on appropriate message class try: message_class = cls._get_message_classes()[type_] diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7d64758..8eecdb1 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -354,6 +354,33 @@ def test_subprotocol_downgrade_2( assert not client.ws_client.open assert client.ws_client.close_code == CloseCode.protocol_error + @pytest.mark.asyncio + def test_initiator_handshake_unencrypted( + self, cookie_factory, pack_nonce, server, client_factory + ): + """ + Check that we cannot do a complete handshake for an initiator + when 'client-auth' is not encrypted. + """ + client = yield from client_factory() + + # server-hello, already checked in another test + message, _, sck, s, d, start_scsn = yield from client.recv() + + # client-auth + cck, ccsn = cookie_factory(), 2**32 - 1 + yield from client.send(pack_nonce(cck, 0x00, 0x00, ccsn), { + 'type': 'client-auth', + 'your_cookie': sck, + 'subprotocols': pytest.saltyrtc.subprotocols, + }) + ccsn += 1 + + # Expect protocol error + yield from server.wait_connections_closed() + assert not client.ws_client.open + assert client.ws_client.close_code == CloseCode.protocol_error + @pytest.mark.asyncio def test_initiator_handshake( self, cookie_factory, initiator_key, pack_nonce, server, client_factory, @@ -447,6 +474,39 @@ def test_responder_handshake( yield from client.close() yield from server.wait_connections_closed() + @pytest.mark.asyncio + def test_responder_handshake_unencrypted( + self, cookie_factory, responder_key, pack_nonce, client_factory, server + ): + """ + Check that we can do a complete handshake for a responder. + """ + client = yield from client_factory() + + # server-hello, already checked in another test + message, _, sck, s, d, start_scsn = yield from client.recv() + + # client-hello + cck, ccsn = cookie_factory(), 2**32 - 1 + yield from client.send(pack_nonce(cck, 0x00, 0x00, ccsn), { + 'type': 'client-hello', + 'key': responder_key.pk, + }) + ccsn += 1 + + # client-auth + yield from client.send(pack_nonce(cck, 0x00, 0x00, ccsn), { + 'type': 'client-auth', + 'your_cookie': sck, + 'subprotocols': pytest.saltyrtc.subprotocols, + }) + ccsn += 1 + + # Expect protocol error + yield from server.wait_connections_closed() + assert not client.ws_client.open + assert client.ws_client.close_code == CloseCode.protocol_error + @pytest.mark.asyncio def test_client_factory_handshake( self, server, client_factory, initiator_key, responder_key @@ -633,6 +693,37 @@ def test_invalid_destination_after_handshake( assert not responder.ws_client.open assert responder.ws_client.close_code == CloseCode.protocol_error + @pytest.mark.asyncio + def test_unencrypted_packet_after_initiator_handshake( + self, pack_nonce, server, client_factory + ): + """ + Check that the server closes with Protocol Error when an + unencrypted packet is being sent by an initiator. + """ + # Initiator handshake + initiator, i = yield from client_factory(initiator_handshake=True) + assert len(i['responders']) == 0 + + # Drop non-existing responder (encrypted) + yield from initiator.send(pack_nonce(i['cck'], 0x01, 0x00, i['ccsn']), { + 'type': 'drop-responder', + 'id': 0x02, + }) + i['ccsn'] += 1 + + # Drop non-existing responder (unencrypted) + yield from initiator.send(pack_nonce(i['cck'], 0x01, 0x00, i['ccsn']), { + 'type': 'drop-responder', + 'id': 0x02, + }, box=None) + i['ccsn'] += 1 + + # Expect protocol error + yield from server.wait_connections_closed() + assert not initiator.ws_client.open + assert initiator.ws_client.close_code == CloseCode.protocol_error + @pytest.mark.asyncio def test_new_initiator(self, server, client_factory): """