From f4aee2a56531931b45123b51953ee6a8ac608232 Mon Sep 17 00:00:00 2001 From: Ivan Nikolchev Date: Thu, 26 Mar 2020 11:51:49 +0100 Subject: [PATCH 1/2] add new calcKey function --- tlslite/mathtls.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tlslite/mathtls.py b/tlslite/mathtls.py index f7920e594..6eb8164d1 100644 --- a/tlslite/mathtls.py +++ b/tlslite/mathtls.py @@ -595,6 +595,85 @@ def calcFinished(version, masterSecret, cipherSuite, handshakeHashes, return verifyData +def calcKey(version, secret, cipherSuite, label, handshakeHashes=None, + clientRandom=None, serverRandom=None, outputLength=None): + """ + Method for calculating different keys depending on input. + It can be used to calculate finished value, master secret, + extended master secret or key expansion. + + :param version: TLS protocol version tuple + :param secret: master secret or premasterSecret which will be + used in the PRF. + :param cipherSuite: Negotiated cipher suite of the connection. + :param label: label for the key you want to calculate + (ex. 'master secret', 'extended master secret', etc). + :param handshakeHashes: running hash of the handshake messages + needed for calculating extended master secret or finished value. + :param clientRandom: client random needed for calculating + master secret or key expansion. + :param serverRandom: server random needed for calculating + master secret or key expansion. + :outputLength: Number of bytes to output. + """ + + + # SSL3 calculations. + if version == (3, 0): + # Calculating Finished value, either for message sent + # by server or by client + if label == "client finished": + senderStr = b"\x43\x4C\x4E\x54" + return handshakeHashes.digestSSL(secret, senderStr) + elif label == "server finished": + senderStr = b"\x53\x52\x56\x52" + return handshakeHashes.digestSSL(secret, senderStr) + else: + assert label in ["key expansion", "master secret"] + func = PRF_SSL + + # TLS1.0 or TLS1.1 calculations. + elif version in ((3, 1), (3, 2)): + func = PRF + # Seed needed for calculating extended master secret + if label == "extended master secret": + seed = handshakeHashes.digest('md5') + \ + handshakeHashes.digest('sha1') + # Seed needed for calculating Finished value + elif label in ["server finished", "client finished"]: + seed = handshakeHashes.digest() + else: + assert label in ["key expansion", "master secret"] + + # TLS1.2 calculations. + else: + assert version == (3, 3) + if cipherSuite in CipherSuite.sha384PrfSuites: + func = PRF_1_2_SHA384 + # Seed needed for calculating Finished value or extended master + # secret + if label in ["extended master secret", "server finished", + "client finished"]: + seed = handshakeHashes.digest('sha384') + else: + assert label in ["key expansion", "master secret"] + else: + # Same as above, just using sha256 + func = PRF_1_2 + if label in ["extended master secret", "server finished", + "client finished"]: + seed = handshakeHashes.digest('sha256') + else: + assert label in ["key expansion", "master secret"] + + # Seed needed for calculating key expansion or master secret + if label == "key expansion": seed = serverRandom + clientRandom + if label == "master secret": seed = clientRandom + serverRandom + + if func == PRF_SSL: + return func(secret, seed, outputLength) + return func(secret, compatAscii2Bytes(label), seed, outputLength) + def makeX(salt, username, password): if len(username)>=256: raise ValueError("username too long") From 9a4d655ce6f11500f6e23d20a924c26a2e5ea1b3 Mon Sep 17 00:00:00 2001 From: Ivan Nikolchev Date: Tue, 31 Mar 2020 02:09:10 +0200 Subject: [PATCH 2/2] replace old functions with the new calcKey one --- tlslite/keyexchange.py | 12 +-- tlslite/recordlayer.py | 37 +------ tlslite/tlsconnection.py | 65 +++++++----- unit_tests/test_tlslite_mathtls.py | 119 ++++++++++++---------- unit_tests/test_tlslite_tlsrecordlayer.py | 36 ++++--- 5 files changed, 136 insertions(+), 133 deletions(-) diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 42ccc91f0..54ac7b0f9 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -6,7 +6,7 @@ import ecdsa from .mathtls import goodGroupParameters, makeK, makeU, makeX, \ - calcMasterSecret, paramStrength, RFC7919_GROUPS + paramStrength, RFC7919_GROUPS, calcKey from .errors import TLSInsufficientSecurity, TLSUnknownPSKIdentity, \ TLSIllegalParameterException, TLSDecryptionFailed, TLSInternalError, \ TLSDecodeError @@ -260,11 +260,11 @@ def calcVerifyBytes(version, handshakeHashes, signatureAlg, prf_name = None, peer_tag=b'client', key_type="rsa"): """Calculate signed bytes for Certificate Verify""" if version == (3, 0): - masterSecret = calcMasterSecret(version, - 0, - premasterSecret, - clientRandom, - serverRandom) + masterSecret = calcKey(version, premasterSecret, + 0, "master secret", + clientRandom=clientRandom, + serverRandom=serverRandom, + outputLength=48) verifyBytes = handshakeHashes.digestSSL(masterSecret, b"") elif version in ((3, 1), (3, 2)): if key_type != "ecdsa": diff --git a/tlslite/recordlayer.py b/tlslite/recordlayer.py index 9533585d3..708f66704 100644 --- a/tlslite/recordlayer.py +++ b/tlslite/recordlayer.py @@ -30,8 +30,7 @@ from .errors import TLSRecordOverflow, TLSIllegalParameterException,\ TLSAbruptCloseError, TLSDecryptionFailed, TLSBadRecordMAC, \ TLSUnexpectedMessage -from .mathtls import createMAC_SSL, createHMAC, PRF_SSL, PRF, PRF_1_2, \ - PRF_1_2_SHA384 +from .mathtls import createMAC_SSL, createHMAC, calcKey class RecordSocket(object): """ @@ -1097,34 +1096,6 @@ def _getHMACMethod(version): return createMACFunc - def _calcKeyBlock(self, cipherSuite, masterSecret, clientRandom, - serverRandom, outputLength): - """Calculate the overall key to slice up""" - if self.version == (3, 0): - keyBlock = PRF_SSL(masterSecret, - serverRandom + clientRandom, - outputLength) - elif self.version in ((3, 1), (3, 2)): - keyBlock = PRF(masterSecret, - b"key expansion", - serverRandom + clientRandom, - outputLength) - elif self.version == (3, 3): - if cipherSuite in CipherSuite.sha384PrfSuites: - keyBlock = PRF_1_2_SHA384(masterSecret, - b"key expansion", - serverRandom + clientRandom, - outputLength) - else: - keyBlock = PRF_1_2(masterSecret, - b"key expansion", - serverRandom + clientRandom, - outputLength) - else: - raise AssertionError() - - return keyBlock - def calcSSL2PendingStates(self, cipherSuite, masterSecret, clientRandom, serverRandom, implementations): """ @@ -1215,8 +1186,10 @@ def calcPendingStates(self, cipherSuite, masterSecret, clientRandom, outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) #Calculate Keying Material from Master Secret - keyBlock = self._calcKeyBlock(cipherSuite, masterSecret, clientRandom, - serverRandom, outputLength) + keyBlock = calcKey(self.version, masterSecret, cipherSuite, + "key expansion", clientRandom=clientRandom, + serverRandom=serverRandom, + outputLength=outputLength) #Slice up Keying Material clientPendingState = ConnectionState() diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 4e831c8c2..662b94314 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -1788,16 +1788,16 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom, # use the certificate authentication, the hashes are the same if not cvhh: cvhh = self._handshake_hash - masterSecret = calcExtendedMasterSecret(self.version, - cipherSuite, - premasterSecret, - cvhh) + masterSecret = calcKey(self.version, premasterSecret, + cipherSuite, "extended master secret", + handshakeHashes=cvhh, + outputLength=48) else: - masterSecret = calcMasterSecret(self.version, - cipherSuite, - premasterSecret, - clientRandom, - serverRandom) + masterSecret = calcKey(self.version, premasterSecret, + cipherSuite, "master secret", + clientRandom=clientRandom, + serverRandom=serverRandom, + outputLength=48) self._calcPendingStates(cipherSuite, masterSecret, clientRandom, serverRandom, cipherImplementations) @@ -4069,16 +4069,16 @@ def _serverFinished(self, premasterSecret, clientRandom, serverRandom, # to regular handshake hash if not cvhh: cvhh = self._handshake_hash - masterSecret = calcExtendedMasterSecret(self.version, - cipherSuite, - premasterSecret, - cvhh) + masterSecret = calcKey(self.version, premasterSecret, + cipherSuite, "extended master secret", + handshakeHashes=cvhh, + outputLength=48) else: - masterSecret = calcMasterSecret(self.version, - cipherSuite, - premasterSecret, - clientRandom, - serverRandom) + masterSecret = calcKey(self.version, premasterSecret, + cipherSuite, "master secret", + clientRandom=clientRandom, + serverRandom=serverRandom, + outputLength=48) #Calculate pending connection states self._calcPendingStates(cipherSuite, masterSecret, @@ -4125,12 +4125,16 @@ def _sendFinished(self, masterSecret, cipherSuite=None, nextProto=None, for result in self._sendMsg(nextProtoMsg): yield result + #Figure out the correct label to use + if self._client: + label = "client finished" + else: + label = "server finished" #Calculate verification data - verifyData = calcFinished(self.version, - masterSecret, - cipherSuite, - self._handshake_hash, - self._client) + verifyData = calcKey(self.version, masterSecret, + cipherSuite, label, + handshakeHashes=self._handshake_hash, + outputLength=12) if self.fault == Fault.badFinished: verifyData[0] = (verifyData[0]+1)%256 @@ -4175,12 +4179,17 @@ def _getFinished(self, masterSecret, cipherSuite=None, if nextProto: self.next_proto = nextProto + #Figure out which label to use. + if self._client: + label = "server finished" + else: + label = "client finished" + #Calculate verification data - verifyData = calcFinished(self.version, - masterSecret, - cipherSuite, - self._handshake_hash, - not self._client) + verifyData = calcKey(self.version, masterSecret, + cipherSuite, label, + handshakeHashes=self._handshake_hash, + outputLength=12) #Get and check Finished message under new state for result in self._getMsg(ContentType.handshake, diff --git a/unit_tests/test_tlslite_mathtls.py b/unit_tests/test_tlslite_mathtls.py index 721e0340e..e96e5d472 100644 --- a/unit_tests/test_tlslite_mathtls.py +++ b/unit_tests/test_tlslite_mathtls.py @@ -9,15 +9,17 @@ except ImportError: import unittest -from tlslite.mathtls import PRF_1_2, calcMasterSecret, calcFinished, \ - calcExtendedMasterSecret, paramStrength +from tlslite.mathtls import PRF_1_2, calcKey, paramStrength from tlslite.handshakehashes import HandshakeHashes from tlslite.constants import CipherSuite class TestCalcMasterSecret(unittest.TestCase): def test_with_empty_values(self): - ret = calcMasterSecret((3, 3), 0, bytearray(48), bytearray(32), - bytearray(32)) + ret = calcKey((3, 3), bytearray(48), 0, + "master secret", + clientRandom=bytearray(32), + serverRandom=bytearray(32), + outputLength=48) self.assertEqual(bytearray( b'I\xcf\xae\xe5[\x86\x92\xd3\xbbm\xd6\xeekSo/' + @@ -32,10 +34,10 @@ def setUp(self): self.handshakeHashes.update(bytearray(48)) def test_with_TLS_1_0(self): - ret = calcExtendedMasterSecret((3, 1), - 0, - bytearray(48), - self.handshakeHashes) + ret = calcKey((3, 1), bytearray(48), 0, + "extended master secret", + handshakeHashes=self.handshakeHashes, + outputLength=48) self.assertEqual(ret, bytearray( b'/\xe9\x86\xda\xda\xa9)\x1eyJ\xc9\x13E\xe4\xfc\xe7\x842m7(\xb4' b'\x98\xb7\xbc\xa5\xda\x1d\xf3\x15\xea\xdf:i\xeb\x9bA\x8f\xe7' @@ -43,10 +45,10 @@ def test_with_TLS_1_0(self): )) def test_with_TLS_1_2(self): - ret = calcExtendedMasterSecret((3, 3), - 0, - bytearray(48), - self.handshakeHashes) + ret = calcKey((3, 3), bytearray(48), 0, + "extended master secret", + handshakeHashes=self.handshakeHashes, + outputLength=48) self.assertEqual(ret, bytearray( b'\x03\xc93Yx\xcbjSEmz*\x0b\xc3\xc04G\xf3\xe3{\xee\x13\x8b\xac' b'\xd7\xb7\xe6\xbaY\x86\xd5\xf2o?\x8f\xc6\xf2\x19\x1d\x06\xe0N' @@ -54,11 +56,12 @@ def test_with_TLS_1_2(self): )) def test_with_TLS_1_2_and_SHA384_PRF(self): - ret = calcExtendedMasterSecret((3, 3), - CipherSuite. - TLS_RSA_WITH_AES_256_GCM_SHA384, - bytearray(48), - self.handshakeHashes) + ret = calcKey((3, 3), bytearray(48), + CipherSuite. + TLS_RSA_WITH_AES_256_GCM_SHA384, + "extended master secret", + handshakeHashes=self.handshakeHashes, + outputLength=48) self.assertEqual(ret, bytearray( b"\xd6\xed}K\xfbo\xb2\xdb\xa4\xee\xa1\x0f\x8f\x07*\x84w/\xbf_" b"\xbd\xc1U^\x93\xcf\xe8\xca\x82\xb7_B\xa3O\xd9V\x86\x12\xfd\x08" @@ -85,11 +88,9 @@ class TestCalcFinishedInSSL3(TestCalcFinished): def setUp(self): super(TestCalcFinishedInSSL3, self).setUp() - self.finished = calcFinished((3, 0), - bytearray(48), - 0, - self.hhashes, - True) + self.finished = calcKey((3, 0), bytearray(48), 0, "client finished", + handshakeHashes=self.hhashes, + outputLength=12) def test_client_value(self): self.assertEqual(bytearray( b'\x15\xa9\xd7\xf1\x8bV\xecY\xab\xee\xbaS\x9c}\xffW\xa0'+ @@ -97,7 +98,9 @@ def test_client_value(self): self.finished) def test_server_value(self): - ret = calcFinished((3, 0), bytearray(48), 0, self.hhashes, False) + ret = calcKey((3, 0), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(bytearray( b'\xe3^aCb\x8a\xfc\x98\xbf\xd7\x08\xddX\xdc[\xeac\x02\xdb'+ @@ -105,12 +108,16 @@ def test_server_value(self): ret) def test_if_multiple_runs_are_the_same(self): - ret2 = calcFinished((3, 0), bytearray(48), 0, self.hhashes, True) + ret2 = calcKey((3, 0), bytearray(48), 0, "client finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(self.finished, ret2) def test_if_client_and_server_values_differ(self): - ret_srv = calcFinished((3, 0), bytearray(48), 0, self.hhashes, False) + ret_srv = calcKey((3, 0), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertNotEqual(self.finished, ret_srv) @@ -118,11 +125,9 @@ class TestCalcFinishedInTLS1_0(TestCalcFinished): def setUp(self): super(TestCalcFinishedInTLS1_0, self).setUp() - self.finished = calcFinished((3, 1), - bytearray(48), - 0, - self.hhashes, - True) + self.finished = calcKey((3, 1), bytearray(48), 0, 'client finished', + handshakeHashes=self.hhashes, + outputLength=12) def test_client_value(self): self.assertEqual(12, len(self.finished)) @@ -131,7 +136,9 @@ def test_client_value(self): self.finished) def test_server_value(self): - ret_srv = calcFinished((3, 1), bytearray(48), 0, self.hhashes, False) + ret_srv = calcKey((3, 1), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(12, len(ret_srv)) self.assertEqual(bytearray( @@ -139,12 +146,16 @@ def test_server_value(self): ret_srv) def test_if_client_and_server_values_differ(self): - ret_srv = calcFinished((3, 1), bytearray(48), 0, self.hhashes, False) + ret_srv = calcKey((3, 1), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertNotEqual(self.finished, ret_srv) def test_if_values_for_TLS1_0_and_TLS1_0_are_same(self): - ret = calcFinished((3, 2), bytearray(48), 0, self.hhashes, True) + ret = calcKey((3, 2), bytearray(48), 0, "client finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(self.finished, ret) @@ -152,11 +163,9 @@ class TestCalcFinishedInTLS1_2WithSHA256(TestCalcFinished): def setUp(self): super(TestCalcFinishedInTLS1_2WithSHA256, self).setUp() - self.finished = calcFinished((3, 3), - bytearray(48), - 0, - self.hhashes, - True) + self.finished = calcKey((3, 3), bytearray(48), 0, "client finished", + handshakeHashes=self.hhashes, + outputLength=12) def test_client_value(self): self.assertEqual(12, len(self.finished)) @@ -165,7 +174,9 @@ def test_client_value(self): self.finished) def test_server_value(self): - ret_srv = calcFinished((3, 3), bytearray(48), 0, self.hhashes, False) + ret_srv = calcKey((3, 3), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(12, len(self.finished)) self.assertEqual(bytearray( @@ -173,7 +184,9 @@ def test_server_value(self): ret_srv) def test_if_client_and_server_values_differ(self): - ret_srv = calcFinished((3, 3), bytearray(48), 0, self.hhashes, False) + ret_srv = calcKey((3, 3), bytearray(48), 0, "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertNotEqual(ret_srv, self.finished) @@ -181,11 +194,11 @@ class TestCalcFinishedInTLS1_2WithSHA384(TestCalcFinished): def setUp(self): super(TestCalcFinishedInTLS1_2WithSHA384, self).setUp() - self.finished = calcFinished((3, 3), - bytearray(48), - CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, - self.hhashes, - True) + self.finished = calcKey((3, 3), bytearray(48), + CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, + "client finished", + handshakeHashes=self.hhashes, + outputLength=12) def test_client_value(self): self.assertEqual(12, len(self.finished)) @@ -194,17 +207,21 @@ def test_client_value(self): self.finished) def test_server_value(self): - ret_srv = calcFinished((3, 3), bytearray(48), - CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, - self.hhashes, False) + ret_srv = calcKey((3, 3), bytearray(48), + CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, + "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertEqual(bytearray( b'\x02St\x13\xa8\xe6\xb6\xa2\x1c4\xff\xc5'), ret_srv) def test_if_client_and_server_values_differ(self): - ret_srv = calcFinished((3, 3), bytearray(48), - CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, - self.hhashes, False) + ret_srv = calcKey((3, 3), bytearray(48), + CipherSuite.TLS_RSA_WITH_AES_256_GCM_SHA384, + "server finished", + handshakeHashes=self.hhashes, + outputLength=12) self.assertNotEqual(self.finished, ret_srv) diff --git a/unit_tests/test_tlslite_tlsrecordlayer.py b/unit_tests/test_tlslite_tlsrecordlayer.py index 22bb8ede1..832857dd4 100644 --- a/unit_tests/test_tlslite_tlsrecordlayer.py +++ b/unit_tests/test_tlslite_tlsrecordlayer.py @@ -26,7 +26,7 @@ from tlslite.extensions import TLSExtension from tlslite.constants import ContentType, HandshakeType, CipherSuite, \ CertificateType -from tlslite.mathtls import calcMasterSecret, PRF_1_2 +from tlslite.mathtls import PRF_1_2, calcKey from tlslite.x509 import X509 from tlslite.x509certchain import X509CertChain from tlslite.utils.keyfactory import parsePEMKey @@ -724,11 +724,12 @@ def test_full_connection_with_RSA_kex(self): else: break - master_secret = calcMasterSecret((3,3), - CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, - premasterSecret, - client_hello.random, - server_hello.random) + master_secret = calcKey((3, 3), premasterSecret, + CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, + "master secret", + clientRandom=client_hello.random, + serverRandom=server_hello.random, + outputLength=48) record_layer._calcPendingStates( CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, @@ -774,11 +775,13 @@ def test_full_connection_with_RSA_kex(self): self.assertEqual(bytearray(b'\x03\x03' + b'\x00'*46), srv_premaster_secret) - srv_master_secret = calcMasterSecret(srv_record_layer.version, - CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, - srv_premaster_secret, - srv_client_hello.random, - srv_server_hello.random) + srv_master_secret = calcKey(srv_record_layer.version, + srv_premaster_secret, + CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, + "master secret", + clientRandom=srv_client_hello.random, + serverRandom=srv_server_hello.random, + outputLength=48) srv_record_layer._calcPendingStates(srv_cipher_suite, srv_master_secret, srv_client_hello.random, @@ -972,11 +975,12 @@ def test_full_connection_with_external_server(self): else: break - master_secret = calcMasterSecret((3,3), - CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, - premasterSecret, - client_hello.random, - server_hello.random) + master_secret = calcKey((3, 3), premasterSecret, + CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, + "master secret", + clientRandom=client_hello.random, + serverRandom=server_hello.random, + outputLength=48) record_layer._calcPendingStates( CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,