Skip to content

Commit

Permalink
Refactor the certificate selection process
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivan Nikolchev committed Jul 1, 2020
1 parent 4521c5c commit 070cb8e
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 45 deletions.
8 changes: 4 additions & 4 deletions tests/tlstest.py
Expand Up @@ -297,8 +297,8 @@ def connect():
try:
connection.handshakeClientCert(settings=settings)
assert False
except TLSLocalAlert as e:
assert "certificate with curve" in str(e)
except TLSRemoteAlert as e:
assert "handshake_failure" in str(e)
connection.close()

test_no += 1
Expand Down Expand Up @@ -1665,8 +1665,8 @@ def connect():
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSRemoteAlert as e:
assert "handshake_failure" in str(e)
except TLSLocalAlert as e:
assert "curve in the public key is not supported by the client" in str(e)
connection.close()

test_no += 1
Expand Down
44 changes: 44 additions & 0 deletions tlslite/constants.py
Expand Up @@ -203,6 +203,50 @@ class SignatureAlgorithm(TLSEnum):
ed448 = 8 # RFC 8422


class AlgorithmOID(TLSEnum):
"""
Algorithm OIDs as defined in rfc5758(ecdsa),
rfc5754(rsa, sha), rfc3447(rss-pss).
The key is the DER encoded OID as a int and
the value is the algorithm id.
"""
oid = {}

#ecdsa_sha1
oid[111196837196800525313] = (2, 3)
#ecdsa_sha224
oid[28484837066454644032257] = (3, 3)
#ecdsa_sha256
oid[28484837066454644032258] = (4, 3)
#ecdsa_sha384
oid[28484837066454644032259] = (5, 3)
#ecdsa_sha512
oid[28484837066454644032260] = (6, 3)

#rsa_sha1
oid[7296840655416892695052549] = (2, 1)
#rsa_sha224
oid[7296840655416892695052558] = (3, 1)
#rsa_sha256
oid[7296840655416892695052555] = (4, 1)
#rsa_sha384
oid[7296840655416892695052556] = (5, 1)
#rsa_sha512
oid[7296840655416892695052557] = (6, 1)

#rsa_pss
oid[7296840655416892695052554] = 8

#sha224
oid[3806363433629502450256813752836] = 3
#sha256
oid[3806363433629502450256813752833] = 4
#sha384
oid[3806363433629502450256813752834] = 5
#sha512
oid[3806363433629502450256813752835] = 6


class SignatureScheme(TLSEnum):
"""
Signature scheme used for signalling supported signature algorithms.
Expand Down
215 changes: 174 additions & 41 deletions tlslite/tlsconnection.py
Expand Up @@ -3255,19 +3255,6 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
for result in self._sendMsg(alert):
yield result

try:
sig_scheme, cert_chain, private_key = \
self._pickServerKeyExchangeSig(settings,
clientHello,
cert_chain,
private_key,
version)
except TLSHandshakeFailure as alert:
for result in self._sendError(
AlertDescription.handshake_failure,
str(alert)):
yield result

#Check if there's intersection between supported curves by client and
#server
clientGroups = clientHello.getExtension(ExtensionType.supported_groups)
Expand Down Expand Up @@ -3379,8 +3366,7 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
cipherSuites = CipherSuite.filterForVersion(cipherSuites,
minVersion=version,
maxVersion=version)
cipherSuites = CipherSuite.filter_for_certificate(cipherSuites,
cert_chain)

#If resumption was requested and we have a session cache...
if clientHello.session_id and sessionCache:
session = None
Expand Down Expand Up @@ -3552,29 +3538,26 @@ def _serverGetClientHello(self, settings, private_key, cert_chain,
#
#Given the current ciphersuite ordering, this means we prefer SRP
#over non-SRP.
for cipherSuite in cipherSuites:
if cipherSuite in clientHello.cipher_suites:
break
else:
if clientGroups and \
any(i in range(256, 512) for i in clientGroups) and \
any(i in CipherSuite.dhAllSuites
for i in clientHello.cipher_suites):
for result in self._sendError(
AlertDescription.insufficient_security,
"FFDHE groups not acceptable and no other common "
"ciphers"):
yield result
else:
for result in self._sendError(\
AlertDescription.handshake_failure,
"No mutual ciphersuite"):
yield result
if cipherSuite in CipherSuite.srpAllSuites and \
not clientHello.srp_username:
for result in self._sendError(\
AlertDescription.unknown_psk_identity,
"Client sent a hello, but without the SRP username"):

try:
cipherSuite, sig_scheme, cert_chain, private_key = \
self._server_select_certificate(settings, clientHello,
cipherSuites, cert_chain,
private_key, version)
except TLSHandshakeFailure as err:
for result in self._sendError(
AlertDescription.handshake_failure,
str(err)):
yield result
except TLSInsufficientSecurity as err:
for result in self._sendError(
AlertDescription.insufficient_security,
str(err)):
yield result
except TLSIllegalParameterException as err:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(err)):
yield result

#If an RSA suite is chosen, check for certificate type intersection
Expand Down Expand Up @@ -3845,6 +3828,152 @@ def _serverSRPKeyExchange(self, clientHello, serverHello, verifierDB,

yield premasterSecret, privateKey, serverCertChain

def _server_select_certificate(self, settings, client_hello,
cipher_suites, cert_chain,
private_key, version):
"""
This method makes the decision on which certificate/key pair,
signature algorithm and cipher to use based on the certificate.
"""

last_cert = False
possible_certs = []
sigalg_cert_ext = False

# Get client groups
client_groups = client_hello. \
getExtension(ExtensionType.supported_groups)
if client_groups is not None:
client_groups = client_groups.groups

# If client did send signature_algorithms_cert use it,
# otherwise fallback to signature_algorithms.
# Client can also decide not to send sigalg extension
client_sigalgs = \
client_hello. \
getExtension(ExtensionType.signature_algorithms_cert)
if client_sigalgs is not None:
client_sigalgs = \
client_hello. \
getExtension(ExtensionType.signature_algorithms_cert). \
sigalgs
sigalg_cert_ext = True
else:
client_sigalgs = \
client_hello. \
getExtension(ExtensionType.signature_algorithms)
if client_sigalgs is not None:
client_sigalgs = \
client_hello. \
getExtension(ExtensionType.signature_algorithms). \
sigalgs
else:
client_sigalgs = []

# Get all the certificates we can offer
alt_certs = ((X509CertChain(i.certificates), i.key) for vh in
settings.virtual_hosts for i in vh.keys)
certs = [(cert, key)
for cert, key in chain([(cert_chain, private_key)], alt_certs)]

for cert, key in certs:

# Check if this is the last (cert, key) pair we have to check
if (cert, key) == certs[-1]:
last_cert = True

# Mandatory checks. If any one of these checks fail, the certificate
# is not usuable.
try:
# Find a suitable ciphersuite based on the certificate
ciphers = CipherSuite.filter_for_certificate(cipher_suites, cert)
for cipher in ciphers:
if cipher in client_hello.cipher_suites:
break
else:
if client_groups and \
any(i in range(256, 512) for i in client_groups) and \
any(i in CipherSuite.dhAllSuites
for i in client_hello.cipher_suites):
raise TLSInsufficientSecurity(
"FFDHE groups not acceptable and no other common "
"ciphers")
raise TLSHandshakeFailure("No mutual ciphersuite")

# Find a signature algorithm based on the certificate
try:
sig_scheme, _, _ = \
self._pickServerKeyExchangeSig(settings,
client_hello,
cert,
key,
version,
False)
except TLSHandshakeFailure:
raise TLSHandshakeFailure(
"No common signature algorithms")

# If the certificate is ECDSA, we must check curve compatibility
if cert and cert.x509List[0].certAlg == 'ecdsa' and \
client_groups and client_sigalgs:
public_key = cert.getEndEntityPublicKey()
curve = public_key.curve_name
for name, aliases in CURVE_ALIASES.items():
if curve in aliases:
curve = getattr(GroupName, name)
break

if version <= (3, 3) and curve not in client_groups:
raise TLSHandshakeFailure(
"The curve in the public key is not "
"supported by the client: {0}" \
.format(GroupName.toRepr(curve)))

if version >= (3, 4):
if GroupName.toRepr(curve) not in \
('secp256r1', 'secp384r1', 'secp521r1'):
raise TLSIllegalParameterException(
"Curve in public key is not supported "
"in TLS1.3")

# If all mandatory checks passed add
# this as possible certificate we can use.
possible_certs.append((cipher, sig_scheme, cert, key))

except:
if last_cert and not possible_certs:
raise
continue

# Non-mandatory checks, if these fail the certificate is still usable
# but we should try to find one that passes all the checks

# Check if the certificate is signed with a signature algorithm
# supported by the client.
# If the client did send signature_algorithm_cert extension,
# those check also apply to all the other certs in the chain
# with exception on self-signed certs.
if cert:
if sigalg_cert_ext:
cert_chain_ok = True
for i in range(len(cert.x509List)):
if cert.x509List[i].issuer == cert.x509List[i].subject:
break
if cert.x509List[i].sigalg not in client_sigalgs:
cert_chain_ok = False
break
if not cert_chain_ok:
break
else:
if cert.x509List[0].sigalg not in client_sigalgs:
break

# If all mandatory and non-mandatory checks passed
# return the (cert, key) pair, cipher and sig_scheme
return cipher, sig_scheme, cert, key
return possible_certs[0]


def _serverCertKeyExchange(self, clientHello, serverHello, sigHashAlg,
serverCertChain, keyExchange,
reqCert, reqCAs, cipherSuite,
Expand Down Expand Up @@ -4231,7 +4360,7 @@ def _handshakeWrapperAsync(self, handshaker, checker):
@staticmethod
def _pickServerKeyExchangeSig(settings, clientHello, certList=None,
private_key=None,
version=(3, 3)):
version=(3, 3), checkAlt=True):
"""Pick a hash that matches most closely the supported ones"""
hashAndAlgsExt = clientHello.getExtension(
ExtensionType.signature_algorithms)
Expand All @@ -4247,8 +4376,12 @@ def _pickServerKeyExchangeSig(settings, clientHello, certList=None,
# sha1 should be picked
return "sha1", certList, private_key

alt_certs = ((X509CertChain(i.certificates), i.key) for vh in
settings.virtual_hosts for i in vh.keys)
if checkAlt:
alt_certs = ((X509CertChain(i.certificates), i.key) for vh in
settings.virtual_hosts for i in vh.keys)
else:
alt_certs = ()


for certs, key in chain([(certList, private_key)], alt_certs):
supported = TLSConnection._sigHashesToList(settings,
Expand Down
29 changes: 29 additions & 0 deletions tlslite/x509.py
Expand Up @@ -13,6 +13,7 @@
from .utils.keyfactory import _createPublicRSAKey, _create_public_ecdsa_key
from .utils.pem import *
from .utils.compat import compatHMAC
from .constants import AlgorithmOID


class X509(object):
Expand Down Expand Up @@ -41,6 +42,8 @@ def __init__(self):
self.publicKey = None
self.subject = None
self.certAlg = None
self.sigalg = None
self.issuer = None

def __hash__(self):
"""Calculate hash of object."""
Expand Down Expand Up @@ -81,6 +84,28 @@ def parseBinary(self, bytes):
self.bytes = bytearray(bytes)
parser = ASN1Parser(self.bytes)

# Get the SignatureAlgorithm
signature_algorithm_identifier = parser.getChild(1)
self.sigalg = signature_algorithm_identifier.getChild(0)

# Get the DER encoded OID as hex string
self.sigalg = bytearray([self.sigalg.type.tag_id]) + \
bytearray([self.sigalg.length]) + \
self.sigalg.value
self.sigalg = bytesToNumber(self.sigalg)

# Finally get the (hash, signature) pair coresponding to it
# If it is rsa-pss we need to check the aditional parameters field
# to extract the hash algorithm
if AlgorithmOID.oid[self.sigalg] == 8:
sigalg_hash = signature_algorithm_identifier.getChild(1)
sigalg_hash = bytesToNumber(sigalg_hash.getChild(0).value)

self.sigalg = (AlgorithmOID.oid[self.sigalg],
AlgorithmOID.oid[sigalg_hash])
else:
self.sigalg = AlgorithmOID.oid[self.sigalg]

# Get the tbsCertificate
tbs_certificate = parser.getChild(0)
# Is the optional version field present?
Expand All @@ -95,6 +120,10 @@ def parseBinary(self, bytes):
# Get serial number
self.serial_number = bytesToNumber(tbs_certificate.getChild(serial_number_index).value)

# Get the issuer
self.issuer = tbs_certificate.getChildBytes(
subject_public_key_info_index - 3)

# Get the subject
self.subject = tbs_certificate.getChildBytes(
subject_public_key_info_index - 1)
Expand Down

0 comments on commit 070cb8e

Please sign in to comment.