Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Post handshake auth #350

Merged
merged 8 commits into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions scripts/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def printUsage(s=None):
[-c CERT] [-k KEY] [-t TACK] [-v VERIFIERDB] [-d DIR] [-l LABEL] [-L LENGTH]
[--reqcert] [--param DHFILE] [--psk PSK] [--psk-ident IDENTITY]
[--psk-sha384] [--ssl3] [--max-ver VER] [--tickets COUNT] [--cipherlist]
[--request-pha] [--require-pha]
HOST:PORT

client
Expand All @@ -102,6 +103,9 @@ def printUsage(s=None):
finished
--cipherlist - comma separated ciphers to enable. For ex. aes128ccm,3des
You can specify this option multiple times.
--request-pha - ask client for post-handshake authentication
--require-pha - abort connection if client didn't provide certificate in
post-handshake authentication
CERT, KEY - the file with key and certificates that will be used by client or
server. The server can accept multiple pairs of `-c` and `-k` options
to configure different certificates (like RSA and ECDSA)
Expand Down Expand Up @@ -159,6 +163,8 @@ def handleArgs(argv, argString, flagsList=[]):
max_ver = None
tickets = None
ciphers = []
request_pha = False
require_pha = False

for opt, arg in opts:
if opt == "-k":
Expand Down Expand Up @@ -232,6 +238,10 @@ def handleArgs(argv, argString, flagsList=[]):
tickets = int(arg)
elif opt == "--cipherlist":
ciphers.append(arg)
elif opt == "--request-pha":
request_pha = True
elif opt == "--require-pha":
require_pha = True
else:
assert(False)

Expand Down Expand Up @@ -294,6 +304,10 @@ def handleArgs(argv, argString, flagsList=[]):
retList.append(tickets)
if "cipherlist=" in flagsList:
retList.append(ciphers)
if "request-pha" in flagsList:
retList.append(request_pha)
if "require-pha" in flagsList:
retList.append(require_pha)
return retList


Expand Down Expand Up @@ -494,11 +508,11 @@ def serverCmd(argv):
(address, privateKey, cert_chain, virtual_hosts, tacks, verifierDB,
directory, reqCert,
expLabel, expLength, dhparam, psk, psk_ident, psk_hash, ssl3,
max_ver, tickets, cipherlist) = \
max_ver, tickets, cipherlist, request_pha, require_pha) = \
handleArgs(argv, "kctbvdlL",
["reqcert", "param=", "psk=",
"psk-ident=", "psk-sha384", "ssl3", "max-ver=",
"tickets=", "cipherlist="])
"tickets=", "cipherlist=", "request-pha", "require-pha"])


if (cert_chain and not privateKey) or (not cert_chain and privateKey):
Expand Down Expand Up @@ -558,6 +572,28 @@ def do_GET(self):
else:
raise ValueError("Invalid return from "
"send_keyupdate_request")
if self.path.startswith('/secret'):
try:
for i in self.connection.request_post_handshake_auth():
pass
except ValueError:
self.wfile.write(b'HTTP/1.0 401 Certificate authentication'
b' required\r\n')
self.wfile.write(b'Connection: close\r\n')
self.wfile.write(b'Content-Length: 0\r\n\r\n')
return
self.connection.read(0, 0)
if self.connection.session.clientCertChain:
print(" Got client certificate in post-handshake auth: "
"{0}".format(self.connection.session
.clientCertChain.getFingerprint()))
else:
print(" No certificate from client received")
self.wfile.write(b'HTTP/1.0 401 Certificate authentication'
b' required\r\n')
self.wfile.write(b'Connection: close\r\n')
self.wfile.write(b'Content-Length: 0\r\n\r\n')
return
return super(MySimpleHTTPHandler, self).do_GET()

class MyHTTPServer(ThreadingMixIn, TLSSocketServerMixIn, HTTPServer):
Expand All @@ -576,6 +612,7 @@ def handshake(self, connection):
1)
connection.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 5))
connection.client_cert_required = require_pha
connection.handshakeServer(certChain=cert_chain,
privateKey=privateKey,
verifierDB=verifierDB,
Expand All @@ -589,6 +626,13 @@ def handshake(self, connection):
sni=sni)
# As an example (does not work here):
#nextProtos=[b"spdy/3", b"spdy/2", b"http/1.1"])
try:
if request_pha:
for i in connection.request_post_handshake_auth():
pass
except ValueError:
# if we can't do PHA, we can't do it
pass
stop = time_stamp()
except TLSRemoteAlert as a:
if a.description == AlertDescription.user_canceled:
Expand Down
105 changes: 97 additions & 8 deletions tests/tlstest.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,10 @@ def connect():
connection = connect()
try:
connection.handshakeClientCert(settings=settings)
assert(False)
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.illegal_parameter:
raise
raise
connection.close()
else:
test_no += 1
Expand Down Expand Up @@ -707,6 +707,43 @@ def connect():

test_no += 1

print("Test {0} - good mutual X.509, PHA, TLSv1.3".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeClientCert(x509Chain, x509Key, settings=settings)
synchro.recv(1)
b = connection.read(0, 0)
assert b == b''
testConnClient(connection)
assert(isinstance(connection.session.serverCertChain, X509CertChain))
connection.close()

test_no += 1

print("Test {0} - mutual X.509, PHA, no client cert, TLSv1.3".format(test_no))
synchro.recv(1)
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeClientCert(X509CertChain(), x509Key, settings=settings)
synchro.recv(1)
b = connection.read(0, 0)
assert b == b''
try:
connection.read(0, 0)
assert False
except TLSRemoteAlert as e:
assert e.description == AlertDescription.certificate_required
assert "certificate_required" in str(e), str(e)

connection.close()

test_no += 1

print("Test {0} - good mutual X.509, TLSv1.1".format(test_no))
synchro.recv(1)
connection = connect()
Expand Down Expand Up @@ -784,7 +821,7 @@ def connect():
connection.handshakeClientSRP("test", "garbage",
serverName=address[0],
session=session, settings=settings)
assert(False)
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.bad_record_mac:
raise
Expand Down Expand Up @@ -1014,7 +1051,7 @@ def connect():
settings.maxVersion = (3, 2)
try:
connection.handshakeClientCert(settings=settings)
assert()
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.inappropriate_fallback:
raise
Expand Down Expand Up @@ -1110,6 +1147,7 @@ def connect():
try:
connection.handshakeClientCert(serverName=address[0], session=session,
settings=settings)
assert False
except TLSRemoteAlert as e:
assert(str(e) == "illegal_parameter")
else:
Expand Down Expand Up @@ -1344,7 +1382,8 @@ def heartbeat_response_check(message):

print("Test {0}: POP3 good".format(test_no))
except (socket.error, socket.timeout) as e:
print("Non-critical error: socket error trying to reach internet server: ", e)
print("Non-critical error: socket error trying to reach internet "
"server: ", e)

synchro.close()

Expand Down Expand Up @@ -1603,6 +1642,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSRemoteAlert as e:
assert "handshake_failure" in str(e)
connection.close()
Expand Down Expand Up @@ -1634,6 +1674,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509ecdsaChain,
privateKey=x509ecdsaKey, settings=settings)
assert False
except TLSLocalAlert as e:
assert "No common signature algorithms" in str(e)
connection.close()
Expand Down Expand Up @@ -1740,7 +1781,7 @@ def connect():
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
tacks=[tackUnrelated], settings=settings)
assert(False)
assert False
except TLSRemoteAlert as alert:
if alert.description != AlertDescription.illegal_parameter:
raise
Expand Down Expand Up @@ -1934,6 +1975,52 @@ def connect():

test_no += 1

print("Test {0} - good mutual X.509, PHA, TLSv1.3".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
assert connection.session.clientCertChain is None
for result in connection.request_post_handshake_auth(settings):
assert result in (0, 1)
synchro.send(b'R')
testConnServer(connection)

assert connection.session.clientCertChain is not None
assert isinstance(connection.session.clientCertChain, X509CertChain)
connection.close()

test_no += 1

print("Test {0} - mutual X.509, PHA, no client cert, TLSv1.3".format(test_no))
synchro.send(b'R')
connection = connect()
settings = HandshakeSettings()
settings.minVersion = (3, 4)
settings.maxVersion = (3, 4)
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
connection.client_cert_required = True
assert connection.session.clientCertChain is None
for result in connection.request_post_handshake_auth(settings):
assert result in (0, 1)
synchro.send(b'R')
try:
testConnServer(connection)
assert False
except TLSLocalAlert as e:
assert "Client did not provide a certificate in post-handshake" in \
str(e)
assert e.description == AlertDescription.certificate_required

assert connection.session.clientCertChain is None
connection.close()

test_no += 1

print("Test {0} - good mutual X.509, TLSv1.1".format(test_no))
synchro.send(b'R')
connection = connect()
Expand Down Expand Up @@ -1995,13 +2082,14 @@ def connect():
synchro.send(b'R')
try:
connection.read(min=1, max=1)
assert() #Client is going to close the socket without a close_notify
assert False #Client is going to close the socket without a close_notify
except TLSAbruptCloseError as e:
pass
synchro.send(b'R')
connection = connect()
try:
connection.handshakeServer(verifierDB=verifierDB, sessionCache=sessionCache)
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.bad_record_mac:
raise
Expand Down Expand Up @@ -2210,7 +2298,7 @@ def server_bind(self):
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
settings=settings)
assert()
assert False
except TLSLocalAlert as alert:
if alert.description != AlertDescription.inappropriate_fallback:
raise
Expand Down Expand Up @@ -2273,6 +2361,7 @@ def server_bind(self):
try:
connection.handshakeServer(certChain=x509Chain, privateKey=x509Key,
sessionCache=sessionCache)
assert False
except TLSLocalAlert as e:
assert(str(e) == "illegal_parameter")
else:
Expand Down
2 changes: 2 additions & 0 deletions tlslite/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class ExtensionType(TLSEnum):
supported_versions = 43 # TLS 1.3
cookie = 44 # TLS 1.3
psk_key_exchange_modes = 45 # TLS 1.3
post_handshake_auth = 49 # TLS 1.3
signature_algorithms_cert = 50 # TLS 1.3
key_share = 51 # TLS 1.3
supports_npn = 13172
Expand Down Expand Up @@ -487,6 +488,7 @@ class AlertDescription(TLSEnum):
bad_certificate_status_response = 113 # RFC 6066
bad_certificate_hash_value = 114 # RFC 6066
unknown_psk_identity = 115
certificate_required = 116 # RFC 8446
no_application_protocol = 120 # RFC 7301


Expand Down
36 changes: 2 additions & 34 deletions tlslite/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,34 +59,6 @@ class TLSAlert(TLSError):

pass

_descriptionStr = {\
AlertDescription.close_notify: "close_notify",\
AlertDescription.unexpected_message: "unexpected_message",\
AlertDescription.bad_record_mac: "bad_record_mac",\
AlertDescription.decryption_failed: "decryption_failed",\
AlertDescription.record_overflow: "record_overflow",\
AlertDescription.decompression_failure: "decompression_failure",\
AlertDescription.handshake_failure: "handshake_failure",\
AlertDescription.no_certificate: "no certificate",\
AlertDescription.bad_certificate: "bad_certificate",\
AlertDescription.unsupported_certificate: "unsupported_certificate",\
AlertDescription.certificate_revoked: "certificate_revoked",\
AlertDescription.certificate_expired: "certificate_expired",\
AlertDescription.certificate_unknown: "certificate_unknown",\
AlertDescription.illegal_parameter: "illegal_parameter",\
AlertDescription.unknown_ca: "unknown_ca",\
AlertDescription.access_denied: "access_denied",\
AlertDescription.decode_error: "decode_error",\
AlertDescription.decrypt_error: "decrypt_error",\
AlertDescription.export_restriction: "export_restriction",\
AlertDescription.protocol_version: "protocol_version",\
AlertDescription.insufficient_security: "insufficient_security",\
AlertDescription.internal_error: "internal_error",\
AlertDescription.inappropriate_fallback: "inappropriate_fallback",\
AlertDescription.user_canceled: "user_canceled",\
AlertDescription.no_renegotiation: "no_renegotiation",\
AlertDescription.unknown_psk_identity: "unknown_psk_identity"}


class TLSLocalAlert(TLSAlert):
"""A TLS alert has been signalled by the local implementation.
Expand All @@ -109,9 +81,7 @@ def __init__(self, alert, message=None):
self.message = message

def __str__(self):
alertStr = TLSAlert._descriptionStr.get(self.description)
if alertStr == None:
alertStr = str(self.description)
alertStr = AlertDescription.toStr(self.description)
if self.message:
return alertStr + ": " + self.message
else:
Expand All @@ -136,9 +106,7 @@ def __init__(self, alert):
self.level = alert.level

def __str__(self):
alertStr = TLSAlert._descriptionStr.get(self.description)
if alertStr == None:
alertStr = str(self.description)
alertStr = AlertDescription.toStr(self.description)
return alertStr


Expand Down
Loading