Skip to content

Commit

Permalink
add proposal for TLSKeyStore
Browse files Browse the repository at this point in the history
  (RSA, MasterSecret, PreMasterSecret, ...)
  • Loading branch information
tintinweb committed May 16, 2016
1 parent 85d07c6 commit 3b7d120
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 70 deletions.
3 changes: 2 additions & 1 deletion examples/server_rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
tls_socket = TLSSocket(socket_, client=False)
tls_socket.bind(("", 8443))
tls_socket.listen(1)
tls_socket.tls_ctx.rsa_load_keys_from_file(os.path.join(basedir, "tests/integration/keys/key.pem"))
tls_socket.crypto.session.secret = TLSSecretRSA(open(os.path.join(basedir, "tests/integration/keys/key.pem")).read())
#tls_socket.tls_ctx.rsa_load_keys_from_file(os.path.join(basedir, "tests/integration/keys/key.pem"))
c_socket, _ = tls_socket.accept()

r = c_socket.recvall()
Expand Down
271 changes: 202 additions & 69 deletions scapy_ssl_tls/ssl_tls_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def x509_extract_pubkey_from_pem(public_key_string):

return x509_extract_pubkey_from_der(der)


def int_to_str(int_):
hex_ = "%x" % int_
return binascii.unhexlify("%s%s" % ("" if len(hex_) % 2 == 0 else "0", hex_))
Expand All @@ -115,7 +114,166 @@ def str_to_ec_point(ansi_str, ec_curve):
x, y = str_to_int(ansi_str[:len(ansi_str) // 2]), str_to_int(ansi_str[len(ansi_str) // 2:])
return ec.Point(ec_curve, x, y)

class TLSSecret(object):
def __init__(self, secret=None, **identification):
if secret:
self.load(secret)
self.identification = identification

def __eq__(self, other):
''' override __eq__ and __hash__ to autom. sort out dup. secrets in the set'''
return hash(self)==hash(other)

def __hash__(self, *args, **kwargs):
return hash(self.secret) ^ hash(frozenset(self.identification.items()))

def load(self, data):
raise NotImplementedError()

def matches(self, identification):
''' returns True if identification is in self.identification
'''
for k,v in identification.iteritems():
if not self.identification.get(k) == v:
return False
return True

class TLSSecretMasterSecret(TLSSecret):
def load(self, secret):
self.secret = secret

class TLSSecretPreMasterSecret(TLSSecretMasterSecret): pass

class TLSSecretRSA(TLSSecret):
def __hash__(self, *args, **kwargs):
# self.secret is an instance of RSA key
return hash(self.secret.exportKey("DER")) ^ hash(frozenset(self.identification.items()))
def load(self, secret):
try:
# try load DER
self.secret = RSA.importKey(secret) # try der
self.public = self.private.publickey()
self.identification['public'] = self.public.export('DER') # allows to match pubkey to privkey
return
except: pass
try:
# try load base64
self.secret = RSA.importKey(secret.decode("base64")) # try base64str
self.public = self.private.publickey()
self.identification['public'] = self.public.export('DER')
return
except: pass
try:
# try load pem
self.secret = RSA.importKey(self.pem_decode(secret)) # try pem
self.public = self.private.publickey()
self.identification['public'] = self.public.export('DER')
return
except: pass

def pem_decode(self, data):
pemo = pem_get_objects(data)
for key_pk in (k for k in pemo.keys() if "PRIVATE" in k.upper()):
try:
return pemo[key_pk].get("full")
except ValueError:
pass
return data

class TLSKeyStore(object):
''' Stores a set of TLS Secrets
'''
def __init__(self):
self.secrets = set([]) # set of distinct TLSSecrets

def add(self, secret):
self.secrets.add(secret)

def find(self, **identification):
''' return all secrets matching a specific connection/session/...
'''
return (secret for secret in self.secrets if secret.matches(identification))

def get(self, **identification):
''' return first matching TLSSecret or None'''
try:
return self.find(**identification).next()
except StopIteration:
return None

def save(self, path):
with open(path, 'w') as f:
for str_secret in self._save_extended_nsskeylog(self.secrets):
f.write("%s"%str_secret)

def load(self, data):
for secret in self._load_extended_nsskeylog(data):
self.add(secret)

def _save_extended_nsskeylog(self, secrets):
for s in secrets:
nss_prefix, nss_id = None, ''
if isinstance(s, TLSSecretMasterSecret):
if s.identification.get('client_random'):
nss_prefix = "CLIENT_RANDOM"
nss_id = s.identification.get('client_random')
elif s.identification.get('session_id'):
nss_prefix = "RSA SESSION-ID:"
nss_id = s.identification.get('session_id')
nss_id = nss_id.encode("hex")
nss_data = s.secret.encode("hex")
elif isinstance(s, TLSSecretPreMasterSecret):
if s.identification.get('client_random'):
nss_prefix = "PMS_CLIENT_RANDOM"
nss_id = s.identification.get('client_random')
elif s.identification.get('encrypted_pms'):
nss_prefix = "RSA "
nss_id = s.identification.get('encrypted_pms')
nss_id = nss_id.encode("hex")
nss_data = s.secret.encode(hex)
elif isinstance(s, TLSSecretRSA):
nss_prefix = "RSA-BASE64:"
nss_data = binascii.b2a_base64(s.secret.exportKey("DER"))
if nss_prefix:
yield '%s %s%s' % (nss_prefix.upper(),
'%s '%nss_id.upper() if nss_id else '',
nss_data.upper())

def _load_extended_nsskeylog(self, data):
# See this wireshark comment for a description of this file format:
# https://github.com/wireshark/wireshark/blob/d4dd4fd8481a2059713619a3e0d28ced7edbdf31/epan/dissectors/packet-ssl-utils.c#L4666-#L4691
# format extensions:
# RSA-BASE64: <base64 encoded rsa private key>
for line in data.split('\n'):
try:
line = line.upper().strip()
if line.startswith("CLIENT_RANDOM"):
_, client_random, master_secret = line.split()
client_random = client_random.strip().decode("hex")
master_secret = master_secret.strip().decode("hex")
yield TLSSecretMasterSecret(master_secret, client_random=client_random)
elif line.startswith("PMS_CLIENT_RANDOM"):
_, client_random, premaster_secret = line.split()
client_random = client_random.strip().decode("hex")
premaster_secret = premaster_secret.strip().decode("hex")
yield TLSSecretPreMasterSecret(premaster_secret, client_random=client_random)
elif line.startswith("RSA SESSION-ID:"):
_, session_id, master_secret = line.split()
session_id = session_id.strip().split(":")[1].decode("hex")
master_secret = master_secret.strip().split(":")[1].decode("hex")
yield TLSSecretMasterSecret(master_secret, session_id=session_id)
elif line.startswith("RSA "):
_, encrypted_pms, premaster_secret = line.split()
encrypted_pms = encrypted_pms.strip().decode("hex")
premaster_secret = premaster_secret.strip().decode("hex")
yield TLSSecretPreMasterSecret(premaster_secret, encrypted_pms=encrypted_pms)
elif line.startswith("RSA-BASE64:"): # proprietary addon scapy-ssl_tls
_, b64_private_key = line.split()
private_key = b64_private_key.strip().decode("base64")
yield TLSSecretRSA(private_key)
except ValueError:
pass

class TLSSessionCtx(object):

def __init__(self, client=True):
Expand Down Expand Up @@ -193,7 +351,6 @@ def __init__(self, client=True):
self.crypto.session = namedtuple('session', ["encrypted_premaster_secret",
'premaster_secret',
'master_secret',
'secret_maps'
"prf"])

self.crypto.session.encrypted_premaster_secret = None
Expand All @@ -204,14 +361,8 @@ def __init__(self, client=True):
self.crypto.session.randombytes.client = None
self.crypto.session.randombytes.server = None

self.crypto.session.secret_maps = namedtuple('secret_maps',['client_random_to_master',
'client_random_to_pms',
'session_id_to_master',
'encrypted_pms_to_pms'])
self.crypto.session.secret_maps.client_random_to_master = {}
self.crypto.session.secret_maps.client_random_to_pms = {}
self.crypto.session.secret_maps.session_id_to_master = {}
self.crypto.session.secret_maps.encrypted_pms_to_pms = {}
self.crypto.keystore = TLSKeyStore() # 'CookieJar' like secret store
self.crypto.session.secret = None # the session secret class TLSSecret

self.crypto.session.key = namedtuple('key',['client','server'])
self.crypto.session.key.server = namedtuple('server',['mac','encryption','iv', "seq_num"])
Expand Down Expand Up @@ -384,22 +535,20 @@ def _process(self,p):
# fetch randombytes for crypto stuff
if not self.crypto.session.randombytes.client:
self.crypto.session.randombytes.client = struct.pack("!I", p[tls.TLSClientHello].gmt_unix_time) + p[tls.TLSClientHello].random_bytes

# If the client random is related to a known (pre-)master secret, use it
if self.crypto.session.secret_maps.client_random_to_master.get(
self.crypto.session.randombytes.client) is not None:
self.crypto.session.master_secret = self.crypto.session.secret_maps.client_random_to_master.get(
self.crypto.session.randombytes.client)
elif self.crypto.session.secret_maps.client_random_to_pms.get(
self.crypto.session.randombytes.client) is not None:
self.crypto.session.premaster_secret = self.crypto.session.secret_maps.client_random_to_pms.get(
self.crypto.session.randombytes.client)

# If the client provided a session id matching a known master secret, use it
if self.crypto.session.secret_maps.session_id_to_master.get(
self.params.handshake.client.session_id) is not None:
self.crypto.session.master_secret = self.crypto.session.secret_maps.session_id_to_master.get(
self.params.handshake.client.session_id)

# check keystore for secret material
# Todo: refactor this code to only use crypto.session.secret in future.
if not self.crypto.session.secret:
for ks_secret in (self.crypto.keystore.get(client_random = self.crypto.session.randombytes.client),
self.crypto.keystore.get(session_id = self.params.handshake.client.session_id)):
if isinstance(ks_secret, TLSSecretMasterSecret):
self.crypto.session.secret = ks_secret
self.crypto.session.master_secret = ks_secret.secret
break
elif isinstance(ks_secret, TLSSecretPreMasterSecret):
self.crypto.session.secret = ks_secret
self.crypto.session.premaster_secret = ks_secret.secret
break

# Generate a random PMS. Overriden at decryption time if private key is provided
if self.crypto.session.premaster_secret is None:
Expand All @@ -414,10 +563,15 @@ def _process(self,p):
self.crypto.session.randombytes.server = struct.pack("!I", p[tls.TLSServerHello].gmt_unix_time) + p[tls.TLSServerHello].random_bytes

# If the session id is related to a known master secret, load it now
if self.crypto.session.secret_maps.session_id_to_master.get(
self.params.handshake.server.session_id) is not None:
self.crypto.session.master_secret = self.crypto.session.secret_maps.session_id_to_master.get(
self.params.handshake.server.session_id)
# check keystore for secret material
# Todo: refactor this code to only use crypto.session.secret in future.
if not self.crypto.session.secret:
for ks_secret in (self.crypto.keystore.get(session_id = self.params.handshake.client.session_id),):
if isinstance(ks_secret, TLSSecretMasterSecret):
self.crypto.session.secret = ks_secret
self.crypto.session.master_secret = ks_secret.secret
break

# negotiated params
if not self.params.negotiated.ciphersuite:
self.params.negotiated.ciphersuite = p[tls.TLSServerHello].cipher_suite
Expand Down Expand Up @@ -446,6 +600,13 @@ def _process(self,p):
# fetch server pubkey // PKCS1_v1_5
cert = p[tls.TLSCertificateList].certificates[0].data
self.crypto.server.rsa.pubkey = x509_extract_pubkey_from_der(str(cert))
# find matching private-key for server pubkey
if not self.crypto.session.secret:
for ks_secret in (self.crypto.keystore.get(public = self.crypto.server.rsa.pubkey.export('DER')),):
if isinstance(ks_secret, TLSSecretRSA):
self.crypto.session.secret = ks_secret
self.crypto.client.rsa.privkey, self.crypto.client.rsa.pubkey = ks_secret.secret, ks_secret.public
break
# TODO: In the future also handle kex = DH and extract static DH params from cert
elif self.params.negotiated.key_exchange is not None and self.params.negotiated.sig == DSA:
# TODO: Handle DSA sig key loading here to allow sig checks
Expand Down Expand Up @@ -489,11 +650,17 @@ def _process(self,p):
self.crypto.session.premaster_secret = PKCS1_v1_5.new(self.crypto.server.rsa.privkey).decrypt(
self.crypto.session.encrypted_premaster_secret, None)

# Or if the encrypted PMS maps to a known decrypted PMS, use that one.
# Use the first 8 encrypted bytes as the identifier (see 'load_secrets_from_file') for details
enc_identifier = self.crypto.session.encrypted_premaster_secret[:8]
if self.crypto.session.secret_maps.encrypted_pms_to_pms.get(enc_identifier) is not None:
self.crypto.session.premaster_secret = self.crypto.session.secret_maps.encrypted_pms_to_pms.get(enc_identifier)
if not self.crypto.session.secret:
# Or if the encrypted PMS maps to a known decrypted PMS, use that one.
# Use the first 8 encrypted bytes as the identifier (see 'load_secrets_from_file') for details
enc_identifier = self.crypto.session.encrypted_premaster_secret[:8]
# check keystore for secret material
# Todo: refactor this code to only use crypto.session.secret in future.
for ks_secret in (self.crypto.keystore.get(encrypted_pms = enc_identifier),):
if isinstance(ks_secret, TLSSecretPreMasterSecret):
self.crypto.session.secret = ks_secret
self.crypto.session.premaster_secret = ks_secret.secret
break

elif p.haslayer(tls.TLSClientDHParams):
self.crypto.client.dh.y_c = p[tls.TLSClientDHParams].data
Expand Down Expand Up @@ -560,40 +727,6 @@ def rsa_load_keys(self, priv_key, client=False):
else:
self.crypto.server.rsa.privkey, self.crypto.server.rsa.pubkey = self._rsa_load_keys(priv_key)

def _load_secret_line(self, line):
try:
line = line.upper()
if line.startswith("CLIENT_RANDOM"):
_, client_random, master_secret = line.split()
client_random = client_random.decode("hex")
master_secret = master_secret.decode("hex")
self.crypto.session.secret_maps.client_random_to_master[client_random] = master_secret
elif line.startswith("PMS_CLIENT_RANDOM"):
_, client_random, premaster_secret = line.split()
client_random = client_random.decode("hex")
premaster_secret = premaster_secret.decode("hex")
self.crypto.session.secret_maps.client_random_to_pms[client_random] = premaster_secret
elif line.startswith("RSA SESSION-ID:"):
_, session_id, master_secret = line.split()
session_id = session_id.split(":")[1].decode("hex")
master_secret = master_secret.split(":")[1].decode("hex")
self.crypto.session.secret_maps.session_id_to_master[session_id] = master_secret
elif line.startswith("RSA "):
_, encrypted_pms, premaster_secret = line.split()
encrypted_pms = encrypted_pms.decode("hex")
premaster_secret = premaster_secret.decode("hex")
self.crypto.session.secret_maps.encrypted_pms_to_pms[encrypted_pms] = premaster_secret
except ValueError:
return


def load_secrets_from_file(self, secret_file):
# See this wireshark comment for a description of this file format:
# https://github.com/wireshark/wireshark/blob/d4dd4fd8481a2059713619a3e0d28ced7edbdf31/epan/dissectors/packet-ssl-utils.c#L4666-#L4691
with open(secret_file,'r') as f:
for entry in f.readlines():
self._load_secret_line(entry)

def _generate_random_pms(self, version):
return "%s%s" % (struct.pack("!H", version), os.urandom(46))

Expand Down

0 comments on commit 3b7d120

Please sign in to comment.