diff --git a/tlslite/messages.py b/tlslite/messages.py index 712e84fe..d724ad4b 100644 --- a/tlslite/messages.py +++ b/tlslite/messages.py @@ -699,6 +699,28 @@ def write(self): return self._write() +class HelloRequest(HandshakeMsg): + """ + Handling of Hello Request messages. + """ + + def __init__(self): + super(HelloRequest, self).__init__(HandshakeType.hello_request) + + def create(self): + return self + + def write(self): + return self.postWrite(Writer()) + + def parse(self, parser): + # verify that the message is empty (the buffer will just contain + # the length from header) + parser.startLengthCheck(3) + parser.stopLengthCheck() + return self + + class ServerHello(HelloMessage): """ Handling of Server Hello messages. diff --git a/unit_tests/test_tlslite_messages.py b/unit_tests/test_tlslite_messages.py index bfe34be0..b27e8d89 100644 --- a/unit_tests/test_tlslite_messages.py +++ b/unit_tests/test_tlslite_messages.py @@ -21,7 +21,7 @@ ClientMasterKey, ClientFinished, ServerFinished, CertificateStatus, \ Certificate, Finished, HelloMessage, ChangeCipherSpec, NextProtocol, \ ApplicationData, EncryptedExtensions, CertificateEntry, \ - NewSessionTicket, SessionTicketPayload, Heartbeat + NewSessionTicket, SessionTicketPayload, Heartbeat, HelloRequest from tlslite.utils.codec import Parser from tlslite.constants import CipherSuite, CertificateType, ContentType, \ AlertLevel, AlertDescription, ExtensionType, ClientCertificateType, \ @@ -3668,5 +3668,46 @@ def test_create_response(self): self.assertEqual(heartbeat_request.payload, heartbeat_response.payload) +class TestHelloRequest(unittest.TestCase): + def setUp(self): + self.msg = HelloRequest() + + def test___init__(self): + self.assertIsNotNone(self.msg) + self.assertEqual(self.msg.handshakeType, 0) + + def test_create(self): + msg = self.msg.create() + + self.assertIs(msg, self.msg) + + def test_write(self): + self.assertEqual(self.msg.write(), + bytearray(b'\x00' # handshake type + b'\x00\x00\x00' # overall length + )) + + def test_parse(self): + parser = Parser(bytearray(# b'\x00' # type + b'\x00\x00\x00')) # overall length + + msg = self.msg.parse(parser) + + self.assertIsInstance(msg, HelloRequest) + + def test_parse_with_truncated_length(self): + parser = Parser(bytearray(# b'\x00' # type + b'\x00\x00')) # overall length (truncated) + with self.assertRaises(SyntaxError): + self.msg.parse(parser) + + def test_parse_with_non_zero_payload(self): + parser = Parser(bytearray(# b'\x00' # type + b'\x00\x00\x01' # overall length + b'\xff')) # some garbage + with self.assertRaises(SyntaxError): + self.msg.parse(parser) + + if __name__ == '__main__': unittest.main()