diff --git a/tests/tlstest.py b/tests/tlstest.py index 1df22ccf7..927af5da3 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -297,8 +297,8 @@ def connect(): try: connection.handshakeClientCert(settings=settings) assert False - except TLSLocalAlert as e: - assert "certificate with curve" in str(e) + except TLSRemoteAlert as e: + assert "handshake_failure" in str(e) connection.close() test_no += 1 @@ -1665,8 +1665,8 @@ def connect(): connection.handshakeServer(certChain=x509ecdsaChain, privateKey=x509ecdsaKey, settings=settings) assert False - except TLSRemoteAlert as e: - assert "handshake_failure" in str(e) + except TLSLocalAlert as e: + assert "curve in the public key is not supported by the client" in str(e) connection.close() test_no += 1 diff --git a/tlslite/constants.py b/tlslite/constants.py index f6cc35ef5..d250fe26e 100644 --- a/tlslite/constants.py +++ b/tlslite/constants.py @@ -203,6 +203,50 @@ class SignatureAlgorithm(TLSEnum): ed448 = 8 # RFC 8422 +class AlgorithmOID(TLSEnum): + """ + Algorithm OIDs as defined in rfc5758(ecdsa), + rfc5754(rsa, sha), rfc3447(rss-pss). + The key is the DER encoded OID as a int and + the value is the algorithm id. + """ + oid = {} + + #ecdsa_sha1 + oid[111196837196800525313] = (2, 3) + #ecdsa_sha224 + oid[28484837066454644032257] = (3, 3) + #ecdsa_sha256 + oid[28484837066454644032258] = (4, 3) + #ecdsa_sha384 + oid[28484837066454644032259] = (5, 3) + #ecdsa_sha512 + oid[28484837066454644032260] = (6, 3) + + #rsa_sha1 + oid[7296840655416892695052549] = (2, 1) + #rsa_sha224 + oid[7296840655416892695052558] = (3, 1) + #rsa_sha256 + oid[7296840655416892695052555] = (4, 1) + #rsa_sha384 + oid[7296840655416892695052556] = (5, 1) + #rsa_sha512 + oid[7296840655416892695052557] = (6, 1) + + #rsa_pss + oid[7296840655416892695052554] = 8 + + #sha224 + oid[3806363433629502450256813752836] = 3 + #sha256 + oid[3806363433629502450256813752833] = 4 + #sha384 + oid[3806363433629502450256813752834] = 5 + #sha512 + oid[3806363433629502450256813752835] = 6 + + class SignatureScheme(TLSEnum): """ Signature scheme used for signalling supported signature algorithms. diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 5c8dcc59c..1df18eaae 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -3255,19 +3255,6 @@ def _serverGetClientHello(self, settings, private_key, cert_chain, for result in self._sendMsg(alert): yield result - try: - sig_scheme, cert_chain, private_key = \ - self._pickServerKeyExchangeSig(settings, - clientHello, - cert_chain, - private_key, - version) - except TLSHandshakeFailure as alert: - for result in self._sendError( - AlertDescription.handshake_failure, - str(alert)): - yield result - #Check if there's intersection between supported curves by client and #server clientGroups = clientHello.getExtension(ExtensionType.supported_groups) @@ -3379,8 +3366,7 @@ def _serverGetClientHello(self, settings, private_key, cert_chain, cipherSuites = CipherSuite.filterForVersion(cipherSuites, minVersion=version, maxVersion=version) - cipherSuites = CipherSuite.filter_for_certificate(cipherSuites, - cert_chain) + #If resumption was requested and we have a session cache... if clientHello.session_id and sessionCache: session = None @@ -3552,29 +3538,26 @@ def _serverGetClientHello(self, settings, private_key, cert_chain, # #Given the current ciphersuite ordering, this means we prefer SRP #over non-SRP. - for cipherSuite in cipherSuites: - if cipherSuite in clientHello.cipher_suites: - break - else: - if clientGroups and \ - any(i in range(256, 512) for i in clientGroups) and \ - any(i in CipherSuite.dhAllSuites - for i in clientHello.cipher_suites): - for result in self._sendError( - AlertDescription.insufficient_security, - "FFDHE groups not acceptable and no other common " - "ciphers"): - yield result - else: - for result in self._sendError(\ - AlertDescription.handshake_failure, - "No mutual ciphersuite"): - yield result - if cipherSuite in CipherSuite.srpAllSuites and \ - not clientHello.srp_username: - for result in self._sendError(\ - AlertDescription.unknown_psk_identity, - "Client sent a hello, but without the SRP username"): + + try: + cipherSuite, sig_scheme, cert_chain, private_key = \ + self._server_select_certificate(settings, clientHello, + cipherSuites, cert_chain, + private_key, version) + except TLSHandshakeFailure as err: + for result in self._sendError( + AlertDescription.handshake_failure, + str(err)): + yield result + except TLSInsufficientSecurity as err: + for result in self._sendError( + AlertDescription.insufficient_security, + str(err)): + yield result + except TLSIllegalParameterException as err: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(err)): yield result #If an RSA suite is chosen, check for certificate type intersection @@ -3845,6 +3828,152 @@ def _serverSRPKeyExchange(self, clientHello, serverHello, verifierDB, yield premasterSecret, privateKey, serverCertChain + def _server_select_certificate(self, settings, client_hello, + cipher_suites, cert_chain, + private_key, version): + """ + This method makes the decision on which certificate/key pair, + signature algorithm and cipher to use based on the certificate. + """ + + last_cert = False + possible_certs = [] + sigalg_cert_ext = False + + # Get client groups + client_groups = client_hello. \ + getExtension(ExtensionType.supported_groups) + if client_groups is not None: + client_groups = client_groups.groups + + # If client did send signature_algorithms_cert use it, + # otherwise fallback to signature_algorithms. + # Client can also decide not to send sigalg extension + client_sigalgs = \ + client_hello. \ + getExtension(ExtensionType.signature_algorithms_cert) + if client_sigalgs is not None: + client_sigalgs = \ + client_hello. \ + getExtension(ExtensionType.signature_algorithms_cert). \ + sigalgs + sigalg_cert_ext = True + else: + client_sigalgs = \ + client_hello. \ + getExtension(ExtensionType.signature_algorithms) + if client_sigalgs is not None: + client_sigalgs = \ + client_hello. \ + getExtension(ExtensionType.signature_algorithms). \ + sigalgs + else: + client_sigalgs = [] + + # Get all the certificates we can offer + alt_certs = ((X509CertChain(i.certificates), i.key) for vh in + settings.virtual_hosts for i in vh.keys) + certs = [(cert, key) + for cert, key in chain([(cert_chain, private_key)], alt_certs)] + + for cert, key in certs: + + # Check if this is the last (cert, key) pair we have to check + if (cert, key) == certs[-1]: + last_cert = True + + # Mandatory checks. If any one of these checks fail, the certificate + # is not usuable. + try: + # Find a suitable ciphersuite based on the certificate + ciphers = CipherSuite.filter_for_certificate(cipher_suites, cert) + for cipher in ciphers: + if cipher in client_hello.cipher_suites: + break + else: + if client_groups and \ + any(i in range(256, 512) for i in client_groups) and \ + any(i in CipherSuite.dhAllSuites + for i in client_hello.cipher_suites): + raise TLSInsufficientSecurity( + "FFDHE groups not acceptable and no other common " + "ciphers") + raise TLSHandshakeFailure("No mutual ciphersuite") + + # Find a signature algorithm based on the certificate + try: + sig_scheme, _, _ = \ + self._pickServerKeyExchangeSig(settings, + client_hello, + cert, + key, + version, + False) + except TLSHandshakeFailure: + raise TLSHandshakeFailure( + "No common signature algorithms") + + # If the certificate is ECDSA, we must check curve compatibility + if cert and cert.x509List[0].certAlg == 'ecdsa' and \ + client_groups and client_sigalgs: + public_key = cert.getEndEntityPublicKey() + curve = public_key.curve_name + for name, aliases in CURVE_ALIASES.items(): + if curve in aliases: + curve = getattr(GroupName, name) + break + + if version <= (3, 3) and curve not in client_groups: + raise TLSHandshakeFailure( + "The curve in the public key is not " + "supported by the client: {0}" \ + .format(GroupName.toRepr(curve))) + + if version >= (3, 4): + if GroupName.toRepr(curve) not in \ + ('secp256r1', 'secp384r1', 'secp521r1'): + raise TLSIllegalParameterException( + "Curve in public key is not supported " + "in TLS1.3") + + # If all mandatory checks passed add + # this as possible certificate we can use. + possible_certs.append((cipher, sig_scheme, cert, key)) + + except: + if last_cert and not possible_certs: + raise + continue + + # Non-mandatory checks, if these fail the certificate is still usable + # but we should try to find one that passes all the checks + + # Check if the certificate is signed with a signature algorithm + # supported by the client. + # If the client did send signature_algorithm_cert extension, + # those check also apply to all the other certs in the chain + # with exception on self-signed certs. + if cert: + if sigalg_cert_ext: + cert_chain_ok = True + for i in range(len(cert.x509List)): + if cert.x509List[i].issuer == cert.x509List[i].subject: + break + if cert.x509List[i].sigalg not in client_sigalgs: + cert_chain_ok = False + break + if not cert_chain_ok: + break + else: + if cert.x509List[0].sigalg not in client_sigalgs: + break + + # If all mandatory and non-mandatory checks passed + # return the (cert, key) pair, cipher and sig_scheme + return cipher, sig_scheme, cert, key + return possible_certs[0] + + def _serverCertKeyExchange(self, clientHello, serverHello, sigHashAlg, serverCertChain, keyExchange, reqCert, reqCAs, cipherSuite, @@ -4231,7 +4360,7 @@ def _handshakeWrapperAsync(self, handshaker, checker): @staticmethod def _pickServerKeyExchangeSig(settings, clientHello, certList=None, private_key=None, - version=(3, 3)): + version=(3, 3), checkAlt=True): """Pick a hash that matches most closely the supported ones""" hashAndAlgsExt = clientHello.getExtension( ExtensionType.signature_algorithms) @@ -4247,8 +4376,12 @@ def _pickServerKeyExchangeSig(settings, clientHello, certList=None, # sha1 should be picked return "sha1", certList, private_key - alt_certs = ((X509CertChain(i.certificates), i.key) for vh in - settings.virtual_hosts for i in vh.keys) + if checkAlt: + alt_certs = ((X509CertChain(i.certificates), i.key) for vh in + settings.virtual_hosts for i in vh.keys) + else: + alt_certs = () + for certs, key in chain([(certList, private_key)], alt_certs): supported = TLSConnection._sigHashesToList(settings, diff --git a/tlslite/x509.py b/tlslite/x509.py index f860aac96..776ad2da2 100644 --- a/tlslite/x509.py +++ b/tlslite/x509.py @@ -13,6 +13,7 @@ from .utils.keyfactory import _createPublicRSAKey, _create_public_ecdsa_key from .utils.pem import * from .utils.compat import compatHMAC +from .constants import AlgorithmOID class X509(object): @@ -41,6 +42,8 @@ def __init__(self): self.publicKey = None self.subject = None self.certAlg = None + self.sigalg = None + self.issuer = None def __hash__(self): """Calculate hash of object.""" @@ -81,6 +84,28 @@ def parseBinary(self, bytes): self.bytes = bytearray(bytes) parser = ASN1Parser(self.bytes) + # Get the SignatureAlgorithm + signature_algorithm_identifier = parser.getChild(1) + self.sigalg = signature_algorithm_identifier.getChild(0) + + # Get the DER encoded OID as hex string + self.sigalg = bytearray([self.sigalg.type.tag_id]) + \ + bytearray([self.sigalg.length]) + \ + self.sigalg.value + self.sigalg = bytesToNumber(self.sigalg) + + # Finally get the (hash, signature) pair coresponding to it + # If it is rsa-pss we need to check the aditional parameters field + # to extract the hash algorithm + if AlgorithmOID.oid[self.sigalg] == 8: + sigalg_hash = signature_algorithm_identifier.getChild(1) + sigalg_hash = bytesToNumber(sigalg_hash.getChild(0).value) + + self.sigalg = (AlgorithmOID.oid[self.sigalg], + AlgorithmOID.oid[sigalg_hash]) + else: + self.sigalg = AlgorithmOID.oid[self.sigalg] + # Get the tbsCertificate tbs_certificate = parser.getChild(0) # Is the optional version field present? @@ -95,6 +120,10 @@ def parseBinary(self, bytes): # Get serial number self.serial_number = bytesToNumber(tbs_certificate.getChild(serial_number_index).value) + # Get the issuer + self.issuer = tbs_certificate.getChildBytes( + subject_public_key_info_index - 3) + # Get the subject self.subject = tbs_certificate.getChildBytes( subject_public_key_info_index - 1)