Skip to content

Commit

Permalink
implement deterministic implicit rejection for RSA decryption
Browse files Browse the repository at this point in the history
  • Loading branch information
tomato42 committed Dec 3, 2020
1 parent 771d033 commit 982cfe2
Show file tree
Hide file tree
Showing 2 changed files with 1,387 additions and 14 deletions.
167 changes: 153 additions & 14 deletions tlslite/utils/rsakey.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from . import tlshashlib as hashlib
from ..errors import MaskTooLongError, MessageTooLongError, EncodingError, \
InvalidSignature, UnknownRSAType
from .constanttime import ct_isnonzero_u32, ct_neq_u32, ct_lsb_prop_u8, \
ct_lsb_prop_u16, ct_lt_u32


class RSAKey(object):
Expand Down Expand Up @@ -45,6 +47,7 @@ def __init__(self, n=0, e=0, key_type="rsa"):
self.e = e
# pylint: enable=invalid-name
self.key_type = key_type
self._key_hash = None
raise NotImplementedError()

def __len__(self):
Expand Down Expand Up @@ -389,38 +392,174 @@ def encrypt(self, bytes):
paddedBytes = self._addPKCS1Padding(bytes, 2)
return self._raw_public_key_op_bytes(paddedBytes)

def _dec_prf(self, key, label, out_len):
"""PRF for deterministic implicit rejection in RSA decryption.
:param bytes key: key to use for derivation
:param bytes label: name of the keystream generated
:param int out_len: length of output, in bits
:rtype: bytes
:returns: a random bytestring
"""
out = bytearray()

if out_len % 8 != 0:
raise ValueError("only multiples of 8 supported as output size")

iterator = 0
while len(out) <= out_len // 8:
out += secureHMAC(
key,
numberToByteArray(iterator, 2) + label +
numberToByteArray(out_len, 2),
"sha256")
iterator += 1

return out[:out_len//8]

def decrypt(self, encBytes):
"""Decrypt the passed-in bytes.
This requires the key to have a private component. It performs
PKCS1 decryption of the passed-in data.
PKCS#1 v1.5 decryption operation of the passed-in data.
Note: as a workaround against Bleichenbacher-like attacks, it will
return a deterministically selected random message in case the padding
checks failed. It returns an error (None) only in case the ciphertext
is of incorrect length or encodes an integer bigger than the modulus
of the key (i.e. it's publically invalid).
:type encBytes: bytes-like object
:param encBytes: The value which will be decrypted.
:rtype: bytearray or None
:returns: A PKCS1 decryption of the passed-in data or None if
the data is not properly formatted.
:returns: A PKCS#1 v1.5 decryption of the passed-in data or None if
the provided data is not properly formatted.
"""
if not self.hasPrivateKey():
raise AssertionError()
if self.key_type != "rsa":
raise ValueError("Decryption requires RSA key, \"{0}\" present"
.format(self.key_type))
try:
decBytes = self._raw_private_key_op_bytes(encBytes)
dec_bytes = self._raw_private_key_op_bytes(encBytes)
except ValueError:
# _raw_private_key_op_bytes fails only when encBytes >= self.n,
# or when len(encBytes) != numBytes(self.n) and that's public
# information, so we don't have to handle it
# in sidechannel secure way
return None
#Check first two bytes
if decBytes[0] != 0 or decBytes[1] != 2:
return None
#Scan through for zero separator
for x in range(1, len(decBytes)-1):
if decBytes[x]== 0:
break
else:
return None
return decBytes[x+1:] #Return everything after the separator

###################
# here be dragons #
###################
# While the code is written as-if it was side-channel secure, in
# practice, because of cPython implementation details IT IS NOT
# see:
# https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/

n = self.n

# maximum length we can return is reduced by the mandatory prefix:
# (0x00 0x02), 8 bytes of padding, so this is the position of the
# null separator byte, as counted from the last position
max_sep_offset = numBytes(n) - 10

# the private exponent (d) doesn't change so `_key_hash` doesn't
# change, calculate it only once
if not hasattr(self, '_key_hash') or not self._key_hash:
self._key_hash = secureHash(numberToByteArray(self.d, numBytes(n)),
"sha256")

kdk = secureHMAC(self._key_hash, encBytes, "sha256")

# we need 128 2-byte numbers, encoded as the number of bits
length_randoms = self._dec_prf(kdk, b"length", 128 * 2 * 8)

message_random = self._dec_prf(kdk, b"message", numBytes(n) * 8)

# select the last length that's not too large to return
synth_length = 0
length_rand_iter = iter(length_randoms)
length_mask = (1 << numBits(max_sep_offset)) - 1
for high, low in zip(length_rand_iter, length_rand_iter):
# interpret the two bytes from the PRF output as 16-bit big-endian
# integer
len_candidate = (high << 8) + low
len_candidate &= length_mask
# equivalent to:
# if len_candidate < max_sep_offset:
# synth_length = len_candidate
mask = ct_lt_u32(len_candidate, max_sep_offset)
mask = ct_lsb_prop_u16(mask)
synth_length = synth_length & (0xffff^mask) | len_candidate & mask

synth_msg_start = numBytes(n) - synth_length

error_detected = 0

# enumerate over all decrypted bytes
em_bytes = enumerate(dec_bytes)
# first check if first two bytes specify PKCS#1 v1.5 encryption padding
_, val = next(em_bytes)
error_detected |= ct_isnonzero_u32(val)
_, val = next(em_bytes)
error_detected |= ct_neq_u32(val, 0x02)
# then look for for the null separator byte among the padding bytes
# but inspect all decrypted bytes, even if we already find the
# separator earlier
msg_start = 0
for pos, val in em_bytes:
# padding must be at least 8 bytes long, fail if any of the first
# 8 bytes of it are zero
# equivalent to:
# if pos < 10 and not val:
# error_detected = 0x01
error_detected |= ct_lt_u32(pos, 10) & (1 ^ ct_isnonzero_u32(val))

# update the msg_start only once; when it's 0
# (pos+1) because we want to skip the null separator
# equivalent to:
# if pos >= 10 and not msg_start and not val:
# msg_start = pos+1
mask = (1 ^ ct_lt_u32(pos, 10)) & (1 ^ ct_isnonzero_u32(val)) \
& (1 ^ ct_isnonzero_u32(msg_start))
mask = ct_lsb_prop_u16(mask)
msg_start = msg_start & (0xffff ^ mask) | (pos+1) & mask

# if separator wasn't found, it's an error
# equivalent to:
# if not msg_start:
# error_detected = 0x01
error_detected |= 1 ^ ct_isnonzero_u32(msg_start)

# equivalent to:
# if error_detected:
# ret_msg_start = synth_msg_start
# else:
# ret_msg_start = msg_start
mask = ct_lsb_prop_u16(error_detected)
ret_msg_start = msg_start & (0xffff ^ mask) | synth_msg_start & mask

# as at this point the length doesn't leak the information if the
# padding was correct or not, we don't have to worry about the
# length of the returned value (and thus the size of the buffer we
# pass to the caller); but we still need to read both buffers
# to ensure that the memory access patern is preserved (that both
# buffers are accessed, not just the one we return)

# equivalent to:
# if error_detected:
# return message_random[ret_msg_start:]
# else:
# return dec_bytes[ret_msg_start:]
mask = ct_lsb_prop_u8(error_detected)
not_mask = 0xff ^ mask
ret = bytearray(
x & not_mask | y & mask for x, y in
zip(dec_bytes[ret_msg_start:], message_random[ret_msg_start:]))

return ret

def _rawPrivateKeyOp(self, message):
raise NotImplementedError()
Expand Down

0 comments on commit 982cfe2

Please sign in to comment.