Skip to content

Commit

Permalink
Fix threading issue introduced in 4.7
Browse files Browse the repository at this point in the history
Computing the blinding factor and its inverse was done in a thread-unsafe
manner. Locking the computation & update of the blinding factors, and
passing these around in frame- and stack-bound data, solves this.

This fixes part of the issues reported in #173,
but there is more going on in that particular report.
  • Loading branch information
sybrenstuvel committed Feb 15, 2021
1 parent 3af4e65 commit 88418f0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 38 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Python-RSA changelog

## Version 4.7.1 - in development

- Fix threading issue introduced in 4.7 ([#173](https://github.com/sybrenstuvel/python-rsa/issues/173)

## Version 4.7 - released 2021-01-10

- Fix [#165](https://github.com/sybrenstuvel/python-rsa/issues/165):
Expand Down
74 changes: 44 additions & 30 deletions rsa/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""

import logging
import threading
import typing
import warnings

Expand All @@ -49,7 +50,7 @@
class AbstractKey:
"""Abstract superclass for private and public keys."""

__slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse')
__slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse', 'mutex')

def __init__(self, n: int, e: int) -> None:
self.n = n
Expand All @@ -58,6 +59,10 @@ def __init__(self, n: int, e: int) -> None:
# These will be computed properly on the first call to blind().
self.blindfac = self.blindfac_inverse = -1

# Used to protect updates to the blinding factor in multi-threaded
# environments.
self.mutex = threading.Lock()

@classmethod
def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
"""Loads a key in PKCS#1 PEM format, implement in a subclass.
Expand Down Expand Up @@ -148,36 +153,33 @@ def save_pkcs1(self, format: str = 'PEM') -> bytes:
method = self._assert_format_exists(format, methods)
return method()

def blind(self, message: int) -> int:
"""Performs blinding on the message using random number 'r'.
def blind(self, message: int) -> typing.Tuple[int, int]:
"""Performs blinding on the message.
:param message: the message, as integer, to blind.
:type message: int
:param r: the random number to blind with.
:type r: int
:return: the blinded message.
:rtype: int
:return: tuple (the blinded message, the inverse of the used blinding factor)
The blinding is such that message = unblind(decrypt(blind(encrypt(message))).
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
"""
self._update_blinding_factor()
return (message * pow(self.blindfac, self.e, self.n)) % self.n
blindfac, blindfac_inverse = self._update_blinding_factor()
blinded = (message * pow(blindfac, self.e, self.n)) % self.n
return blinded, blindfac_inverse

def unblind(self, blinded: int) -> int:
"""Performs blinding on the message using random number 'r'.
def unblind(self, blinded: int, blindfac_inverse: int) -> int:
"""Performs blinding on the message using random number 'blindfac_inverse'.
:param blinded: the blinded message, as integer, to unblind.
:param r: the random number to unblind with.
:param blindfac: the factor to unblind with.
:return: the original message.
The blinding is such that message = unblind(decrypt(blind(encrypt(message))).
See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
"""

return (self.blindfac_inverse * blinded) % self.n
return (blindfac_inverse * blinded) % self.n

def _initial_blinding_factor(self) -> int:
for _ in range(1000):
Expand All @@ -186,18 +188,29 @@ def _initial_blinding_factor(self) -> int:
return blind_r
raise RuntimeError('unable to find blinding factor')

def _update_blinding_factor(self):
if self.blindfac < 0:
# Compute initial blinding factor, which is rather slow to do.
self.blindfac = self._initial_blinding_factor()
self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
else:
# Reuse previous blinding factor as per section 9 of 'A Timing
# Attack against RSA with the Chinese Remainder Theorem' by Werner
# Schindler.
# See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
self.blindfac = pow(self.blindfac, 2, self.n)
self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)
def _update_blinding_factor(self) -> typing.Tuple[int, int]:
"""Update blinding factors.
Computing a blinding factor is expensive, so instead this function
does this once, then updates the blinding factor as per section 9
of 'A Timing Attack against RSA with the Chinese Remainder Theorem'
by Werner Schindler.
See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
:return: the new blinding factor and its inverse.
"""

with self.mutex:
if self.blindfac < 0:
# Compute initial blinding factor, which is rather slow to do.
self.blindfac = self._initial_blinding_factor()
self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
else:
# Reuse previous blinding factor.
self.blindfac = pow(self.blindfac, 2, self.n)
self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)

return self.blindfac, self.blindfac_inverse

class PublicKey(AbstractKey):
"""Represents a public RSA key.
Expand Down Expand Up @@ -446,9 +459,10 @@ def blinded_decrypt(self, encrypted: int) -> int:
:rtype: int
"""

blinded = self.blind(encrypted) # blind before decrypting
# Blinding and un-blinding should be using the same factor
blinded, blindfac_inverse = self.blind(encrypted)
decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)
return self.unblind(decrypted)
return self.unblind(decrypted, blindfac_inverse)

def blinded_encrypt(self, message: int) -> int:
"""Encrypts the message using blinding to prevent side-channel attacks.
Expand All @@ -460,9 +474,9 @@ def blinded_encrypt(self, message: int) -> int:
:rtype: int
"""

blinded = self.blind(message) # blind before encrypting
blinded, blindfac_inverse = self.blind(message)
encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
return self.unblind(encrypted)
return self.unblind(encrypted, blindfac_inverse)

@classmethod
def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':
Expand Down
15 changes: 7 additions & 8 deletions tests/test_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ def test_blinding(self):
message = 12345
encrypted = rsa.core.encrypt_int(message, pk.e, pk.n)

blinded_1 = pk.blind(encrypted) # blind before decrypting
blinded_1, unblind_1 = pk.blind(encrypted) # blind before decrypting
decrypted = rsa.core.decrypt_int(blinded_1, pk.d, pk.n)
unblinded_1 = pk.unblind(decrypted)
unblinded_1 = pk.unblind(decrypted, unblind_1)

self.assertEqual(unblinded_1, message)

# Re-blinding should use a different blinding factor.
blinded_2 = pk.blind(encrypted) # blind before decrypting
blinded_2, unblind_2 = pk.blind(encrypted) # blind before decrypting
self.assertNotEqual(blinded_1, blinded_2)

# The unblinding should still work, though.
decrypted = rsa.core.decrypt_int(blinded_2, pk.d, pk.n)
unblinded_2 = pk.unblind(decrypted)
unblinded_2 = pk.unblind(decrypted, unblind_2)
self.assertEqual(unblinded_2, message)


Expand Down Expand Up @@ -69,10 +69,9 @@ def getprime(_):
# This exponent will cause two other primes to be generated.
exponent = 136407

(p, q, e, d) = rsa.key.gen_keys(64,
accurate=False,
getprime_func=getprime,
exponent=exponent)
(p, q, e, d) = rsa.key.gen_keys(
64, accurate=False, getprime_func=getprime, exponent=exponent
)
self.assertEqual(39317, p)
self.assertEqual(33107, q)

Expand Down

0 comments on commit 88418f0

Please sign in to comment.