Skip to content

Commit

Permalink
Merge pull request #350 from tomato42/post-handshake-auth
Browse files Browse the repository at this point in the history
Post handshake auth
  • Loading branch information
tomato42 committed Nov 27, 2019
2 parents 1abaefe + 114cdd5 commit 3e7bddd
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 55 deletions.
48 changes: 46 additions & 2 deletions scripts/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def printUsage(s=None):
[-c CERT] [-k KEY] [-t TACK] [-v VERIFIERDB] [-d DIR] [-l LABEL] [-L LENGTH]
[--reqcert] [--param DHFILE] [--psk PSK] [--psk-ident IDENTITY]
[--psk-sha384] [--ssl3] [--max-ver VER] [--tickets COUNT] [--cipherlist]
[--request-pha] [--require-pha]
HOST:PORT
client
Expand All @@ -102,6 +103,9 @@ def printUsage(s=None):
finished
--cipherlist - comma separated ciphers to enable. For ex. aes128ccm,3des
You can specify this option multiple times.
--request-pha - ask client for post-handshake authentication
--require-pha - abort connection if client didn't provide certificate in
post-handshake authentication
CERT, KEY - the file with key and certificates that will be used by client or
server. The server can accept multiple pairs of `-c` and `-k` options
to configure different certificates (like RSA and ECDSA)
Expand Down Expand Up @@ -159,6 +163,8 @@ def handleArgs(argv, argString, flagsList=[]):
max_ver = None
tickets = None
ciphers = []
request_pha = False
require_pha = False

for opt, arg in opts:
if opt == "-k":
Expand Down Expand Up @@ -232,6 +238,10 @@ def handleArgs(argv, argString, flagsList=[]):
tickets = int(arg)
elif opt == "--cipherlist":
ciphers.append(arg)
elif opt == "--request-pha":
request_pha = True
elif opt == "--require-pha":
require_pha = True
else:
assert(False)

Expand Down Expand Up @@ -294,6 +304,10 @@ def handleArgs(argv, argString, flagsList=[]):
retList.append(tickets)
if "cipherlist=" in flagsList:
retList.append(ciphers)
if "request-pha" in flagsList:
retList.append(request_pha)
if "require-pha" in flagsList:
retList.append(require_pha)
return retList


Expand Down Expand Up @@ -494,11 +508,11 @@ def serverCmd(argv):
(address, privateKey, cert_chain, virtual_hosts, tacks, verifierDB,
directory, reqCert,
expLabel, expLength, dhparam, psk, psk_ident, psk_hash, ssl3,
max_ver, tickets, cipherlist) = \
max_ver, tickets, cipherlist, request_pha, require_pha) = \
handleArgs(argv, "kctbvdlL",
["reqcert", "param=", "psk=",
"psk-ident=", "psk-sha384", "ssl3", "max-ver=",
"tickets=", "cipherlist="])
"tickets=", "cipherlist=", "request-pha", "require-pha"])


if (cert_chain and not privateKey) or (not cert_chain and privateKey):
Expand Down Expand Up @@ -558,6 +572,28 @@ def do_GET(self):
else:
raise ValueError("Invalid return from "
"send_keyupdate_request")
if self.path.startswith('/secret'):
try:
for i in self.connection.request_post_handshake_auth():
pass
except ValueError:
self.wfile.write(b'HTTP/1.0 401 Certificate authentication'
b' required\r\n')
self.wfile.write(b'Connection: close\r\n')
self.wfile.write(b'Content-Length: 0\r\n\r\n')
return
self.connection.read(0, 0)
if self.connection.session.clientCertChain:
print(" Got client certificate in post-handshake auth: "
"{0}".format(self.connection.session
.clientCertChain.getFingerprint()))
else:
print(" No certificate from client received")
self.wfile.write(b'HTTP/1.0 401 Certificate authentication'
b' required\r\n')
self.wfile.write(b'Connection: close\r\n')
self.wfile.write(b'Content-Length: 0\r\n\r\n')
return
return super(MySimpleHTTPHandler, self).do_GET()

class MyHTTPServer(ThreadingMixIn, TLSSocketServerMixIn, HTTPServer):
Expand All @@ -576,6 +612,7 @@ def handshake(self, connection):
1)
connection.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 5))
connection.client_cert_required = require_pha
connection.handshakeServer(certChain=cert_chain,
privateKey=privateKey,
verifierDB=verifierDB,
Expand All @@ -589,6 +626,13 @@ def handshake(self, connection):
sni=sni)
# As an example (does not work here):
#nextProtos=[b"spdy/3", b"spdy/2", b"http/1.1"])
try:
if request_pha:
for i in connection.request_post_handshake_auth():
pass
except ValueError:
# if we can't do PHA, we can't do it
pass
stop = time_stamp()
except TLSRemoteAlert as a:
if a.description == AlertDescription.user_canceled:
Expand Down
105 changes: 97 additions & 8 deletions tests/tlstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,10 @@ def connect():
connection = connect()
try:
connection.handshakeClientCert(settings=settings)
assert(False)
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.illegal_parameter:
raise
raise
connection.close()
else:
test_no += 1
Expand Down Expand Up @@ -707,6 +707,43 @@ def connect():

test_no += 1

print("Test {0} - good mutual X.509, PHA, TLSv1.3".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeClientCert(x509Chain, x509Key, settings=settings)
synchro.recv(1)
b = connection.read(0, 0)
assert b == b''
testConnClient(connection)
assert(isinstance(connection.session.serverCertChain, X509CertChain))
connection.close()

test_no += 1

print("Test {0} - mutual X.509, PHA, no client cert, TLSv1.3".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeClientCert(X509CertChain(), x509Key, settings=settings)
synchro.recv(1)
b = connection.read(0, 0)
assert b == b''
try:
connection.read(0, 0)
assert False
except TLSRemoteAlert as e:
assert e.description == AlertDescription.certificate_required
assert "certificate_required" in str(e), str(e)

connection.close()

test_no += 1

print("Test {0} - good mutual X.509, TLSv1.1".format(test_no))
synchro.recv(1)
connection = connect()
Expand Down Expand Up @@ -784,7 +821,7 @@ def connect():
connection.handshakeClientSRP("test", "garbage",
serverName=address[0],
session=session, settings=settings)
assert(False)
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.bad_record_mac:
raise
Expand Down Expand Up @@ -1014,7 +1051,7 @@ def connect():
settings.maxVersion = (3, 2)
try:
connection.handshakeClientCert(settings=settings)
assert()
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.inappropriate_fallback:
raise
Expand Down Expand Up @@ -1110,6 +1147,7 @@ def connect():
try:
connection.handshakeClientCert(serverName=address[0], session=session,
settings=settings)
assert False
except TLSRemoteAlert as e:
assert(str(e) == "illegal_parameter")
else:
Expand Down Expand Up @@ -1344,7 +1382,8 @@ def heartbeat_response_check(message):

print("Test {0}: POP3 good".format(test_no))
except (socket.error, socket.timeout) as e:
print("Non-critical error: socket error trying to reach internet server: ", e)
print("Non-critical error: socket error trying to reach internet "
"server: ", e)

synchro.close()

Expand Down Expand Up @@ -1603,6 +1642,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSRemoteAlert as e:
assert "handshake_failure" in str(e)
connection.close()
Expand Down Expand Up @@ -1634,6 +1674,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSLocalAlert as e:
assert "No common signature algorithms" in str(e)
connection.close()
Expand Down Expand Up @@ -1740,7 +1781,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
tacks=[tackUnrelated], settings=settings)
assert(False)
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.illegal_parameter:
raise
Expand Down Expand Up @@ -1934,6 +1975,52 @@ def connect():

test_no += 1

print("Test {0} - good mutual X.509, PHA, TLSv1.3".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
assert connection.session.clientCertChain is None
for result in connection.request_post_handshake_auth(settings):
assert result in (0, 1)
synchro.send(b'R')
testConnServer(connection)

assert connection.session.clientCertChain is not None
assert isinstance(connection.session.clientCertChain, X509CertChain)
connection.close()

test_no += 1

print("Test {0} - mutual X.509, PHA, no client cert, TLSv1.3".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
connection.client_cert_required = True
assert connection.session.clientCertChain is None
for result in connection.request_post_handshake_auth(settings):
assert result in (0, 1)
synchro.send(b'R')
try:
testConnServer(connection)
assert False
except TLSLocalAlert as e:
assert "Client did not provide a certificate in post-handshake" in \
str(e)
assert e.description == AlertDescription.certificate_required

assert connection.session.clientCertChain is None
connection.close()

test_no += 1

print("Test {0} - good mutual X.509, TLSv1.1".format(test_no))
synchro.send(b'R')
connection = connect()
Expand Down Expand Up @@ -1995,13 +2082,14 @@ def connect():
synchro.send(b'R')
try:
connection.read(min=1, max=1)
assert() #Client is going to close the socket without a close_notify
assert False #Client is going to close the socket without a close_notify
except TLSAbruptCloseError as e:
pass
synchro.send(b'R')
connection = connect()
try:
connection.handshakeServer(verifierDB=verifierDB, sessionCache=sessionCache)
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.bad_record_mac:
raise
Expand Down Expand Up @@ -2210,7 +2298,7 @@ def server_bind(self):
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
assert()
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.inappropriate_fallback:
raise
Expand Down Expand Up @@ -2273,6 +2361,7 @@ def server_bind(self):
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
sessionCache=sessionCache)
assert False
except TLSLocalAlert as e:
assert(str(e) == "illegal_parameter")
else:
Expand Down
2 changes: 2 additions & 0 deletions tlslite/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class ExtensionType(TLSEnum):
supported_versions = 43 # TLS 1.3
cookie = 44 # TLS 1.3
psk_key_exchange_modes = 45 # TLS 1.3
post_handshake_auth = 49 # TLS 1.3
signature_algorithms_cert = 50 # TLS 1.3
key_share = 51 # TLS 1.3
supports_npn = 13172
Expand Down Expand Up @@ -487,6 +488,7 @@ class AlertDescription(TLSEnum):
bad_certificate_status_response = 113 # RFC 6066
bad_certificate_hash_value = 114 # RFC 6066
unknown_psk_identity = 115
certificate_required = 116 # RFC 8446
no_application_protocol = 120 # RFC 7301


Expand Down
36 changes: 2 additions & 34 deletions tlslite/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,6 @@ class TLSAlert(TLSError):

pass

_descriptionStr = {\
AlertDescription.close_notify: "close_notify",\
AlertDescription.unexpected_message: "unexpected_message",\
AlertDescription.bad_record_mac: "bad_record_mac",\
AlertDescription.decryption_failed: "decryption_failed",\
AlertDescription.record_overflow: "record_overflow",\
AlertDescription.decompression_failure: "decompression_failure",\
AlertDescription.handshake_failure: "handshake_failure",\
AlertDescription.no_certificate: "no certificate",\
AlertDescription.bad_certificate: "bad_certificate",\
AlertDescription.unsupported_certificate: "unsupported_certificate",\
AlertDescription.certificate_revoked: "certificate_revoked",\
AlertDescription.certificate_expired: "certificate_expired",\
AlertDescription.certificate_unknown: "certificate_unknown",\
AlertDescription.illegal_parameter: "illegal_parameter",\
AlertDescription.unknown_ca: "unknown_ca",\
AlertDescription.access_denied: "access_denied",\
AlertDescription.decode_error: "decode_error",\
AlertDescription.decrypt_error: "decrypt_error",\
AlertDescription.export_restriction: "export_restriction",\
AlertDescription.protocol_version: "protocol_version",\
AlertDescription.insufficient_security: "insufficient_security",\
AlertDescription.internal_error: "internal_error",\
AlertDescription.inappropriate_fallback: "inappropriate_fallback",\
AlertDescription.user_canceled: "user_canceled",\
AlertDescription.no_renegotiation: "no_renegotiation",\
AlertDescription.unknown_psk_identity: "unknown_psk_identity"}


class TLSLocalAlert(TLSAlert):
"""A TLS alert has been signalled by the local implementation.
Expand All @@ -109,9 +81,7 @@ def __init__(self, alert, message=None):
self.message = message

def __str__(self):
alertStr = TLSAlert._descriptionStr.get(self.description)
if alertStr == None:
alertStr = str(self.description)
alertStr = AlertDescription.toStr(self.description)
if self.message:
return alertStr + ": " + self.message
else:
Expand All @@ -136,9 +106,7 @@ def __init__(self, alert):
self.level = alert.level

def __str__(self):
alertStr = TLSAlert._descriptionStr.get(self.description)
if alertStr == None:
alertStr = str(self.description)
alertStr = AlertDescription.toStr(self.description)
return alertStr


Expand Down
Loading

0 comments on commit 3e7bddd

Please sign in to comment.