Skip to content

Commit

Permalink
Merge 567e3f7 into 8affa13
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardThiessen committed Oct 9, 2017
2 parents 8affa13 + 567e3f7 commit d91005d
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 67 deletions.
107 changes: 46 additions & 61 deletions rsa/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,81 +570,68 @@ def _save_pkcs1_pem(self):
return rsa.pem.save_pem(der, b'RSA PRIVATE KEY')


def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
def find_p_q(nbits, getprime_func=rsa.prime.getprimebyrange, accurate="unused parameter retained for compatibility"):
"""Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
will not be equal.
:param nbits: the number of bits in each of p and q.
:param getprime_func: the getprime function, defaults to
:py:func:`rsa.prime.getprime`.
:py:func:`rsa.prime.getprimebyrange`.
*Introduced in Python-RSA 3.1*
:param accurate: whether to enable accurate mode or not.
:returns: (p, q), where p > q
>>> (p, q) = find_p_q(128)
>>> from rsa import common
>>> common.bit_size(p * q)
256
When not in accurate mode, the number of bits can be slightly less
>>> (p, q) = find_p_q(128, accurate=False)
>>> from rsa import common
>>> common.bit_size(p * q) <= 256
True
>>> common.bit_size(p * q) > 240
True
"""
#constraints implemented are from FIPS 186-4 appendix B-3.1.2
#http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf#page=62

#primes are chosen to be between (2**nbits) and ceil(2**nbits/sqrt(2))
#this ensures that regardless of what primes are chosen the final modulus
#will be between (2**(2*nbits)) and (2**(2*nbits-1))

maximum=2**nbits#up to but not including
#multiply by fractional representation of 1/sqrt(2)(rounded up) for minimum
minimum=(maximum*0xb504f333f9de6484597d89b3754abea0)
minimum>>=128 #divide by fractional divisor
minimum+=1#round up

### code for generating the above constant
# divisor=2**128#fraction divisor
# target=divisor**2//2
# increment=divisor
# value=0
# while increment:
# value+=increment
# if value**2>target:
# value-=increment
# increment//=2
# value+=1#round up
# print(hex(value))

while 1:#loop allows for restarting keygen process if primes do not meet required conditions
# Choose the two primes
log.debug('find_p_q(%i): Finding p', nbits)
p = getprime_func(minimum,maximum)
log.debug('find_p_q(%i): Finding q', nbits)
q = getprime_func(minimum,maximum)

#check that the modulus has the correct bit size
assert rsa.common.bit_size(p * q)==nbits*2

#test to ensure they are far enough apart (FIPS 168-4 appendix B-3.1.2)
required_distance=2**max(0,nbits-100)
if abs(p-q)<required_distance:
log.debug('find_p_q(%i): p and q not far enough apart, restarting', nbits)
continue#try again
break

total_bits = nbits * 2

# Make sure that p and q aren't too close or the factoring programs can
# factor n.
shift = nbits // 16
pbits = nbits + shift
qbits = nbits - shift

# Choose the two initial primes
log.debug('find_p_q(%i): Finding p', nbits)
p = getprime_func(pbits)
log.debug('find_p_q(%i): Finding q', nbits)
q = getprime_func(qbits)

def is_acceptable(p, q):
"""Returns True iff p and q are acceptable:
- p and q differ
- (p * q) has the right nr of bits (when accurate=True)
"""

if p == q:
return False

if not accurate:
return True

# Make sure we have just the right amount of bits
found_size = rsa.common.bit_size(p * q)
return total_bits == found_size

# Keep choosing other primes until they match our requirements.
change_p = False
while not is_acceptable(p, q):
# Change p on one iteration and q on the other
if change_p:
p = getprime_func(pbits)
else:
q = getprime_func(qbits)

change_p = not change_p

# We want p > q as described on
# http://www.di-mgt.com.au/rsa_alg.html#crt
return max(p, q), min(p, q)


Expand Down Expand Up @@ -752,13 +739,11 @@ def newkeys(nbits, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT):
raise ValueError('Pool size (%i) should be >= 1' % poolsize)

# Determine which getprime function to use
getprime_func = rsa.prime.getprimebyrange
if poolsize > 1:
from rsa import parallel
import functools
getprime_func = parallel.exec_parallel_curry(getprime_func, poolsize)

getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
else:
getprime_func = rsa.prime.getprime

# Generate the key components
(p, q, e, d) = gen_keys(nbits, getprime_func, accurate=accurate, exponent=exponent)
Expand Down
39 changes: 38 additions & 1 deletion rsa/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,44 @@ def getprime(nbits, poolsize):
return result


__all__ = ['getprime']
def exec_parallel_curry(function,poolsize):
"""returns a multiprocess version of the supplied function for the given poolsize"""
def retfunc(*args,**kwargs):
return exec_parallel(poolsize,function,args,kwargs)
return retfunc

def _do_work(pipe,function,args,kwargs):
result=function(*args,**kwargs)
pipe.send(result)

def exec_parallel(poolsize,function,args,kwargs):
"""
carries out a function in multiple processes. Returns the first return value
to be produced by any process
"""
(pipe_recv, pipe_send) = mp.Pipe(duplex=False)

# Create processes
try:
procs = [mp.Process(target=_do_work, args=(pipe_send,function,args,kwargs))
for _ in range(poolsize)]
# Start processes
for p in procs:
p.start()

result = pipe_recv.recv()
finally:
pipe_recv.close()
pipe_send.close()

# Terminate processes
for p in procs:
p.terminate()

return result


__all__ = ['getprime','exec_parallel']

if __name__ == '__main__':
print('Running doctests 1000x or until failure')
Expand Down
80 changes: 80 additions & 0 deletions rsa/prime.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,86 @@ def getprime(nbits):

# Retry if not prime

small_primes=(3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61,
67, 71, 73, 79, 83, 89, 97)
def prime_sieve(start,end):
"""for numbers in range(start,end) sieves out composites using trial
division returning potential primes
>>> sieve=prime_sieve(100,120)
>>> next(sieve)
101
>>> [x for x in sieve]
[103, 107, 109, 113]
"""

#handle small numbers
if start<=small_primes[-1]:
if start<=2:yield 2
for p in small_primes:
if p<=start:yield p
start|=1#make start odd
#We use an offset when doing the trial divisions. It is much smaller than
#the full number. This makes the modulo operations fast. When yielding a
#candidate we add the start and offset to get the candidate value.
residues=tuple((-start%p,p) for p in small_primes)
#start+offset=0 (mod p) <---condition to check for
#offset=-start (mod p)
#offset%p=(-start)%p <--that's the residue
offset=0
span=end-start
while offset<span:
for residue,p in residues:
if (offset%p)==residue:break
else:#all trial divisions were successful
yield start+offset
offset+=2

def getprimebyrange(start,end,initial=None):
"""Returns a prime number randomly chosen from range(start,end)
randomly chooses an initial point within the range
This can be overriden with the optional initial argument
starts at the initial point scanning range(initial,end) then trying
range(start,initial)
>>> p = getprimebyrange(100,200)
>>> 100<=p<200
True
>>> is_prime(p-1)
False
>>> is_prime(p)
True
>>> is_prime(p+1)
False
>>> getprimebyrange(10000,20000,initial=10000)
10007
>>> getprimebyrange(10000,20000,initial=10010)
10037
>>> #when no primes in range(initial,end), it tries range(start,initial)
>>> getprimebyrange(10000,10020,initial=10010)
10007
"""
#randomly choose the initial point in the range (unless specified)
if initial is None:
initial=rsa.randnum.randrange(start, end)
#check top part of range
for candidate in prime_sieve(initial, end):
# Test for primeness
if is_prime(candidate):
return candidate
#nothing in the top part of the given range
#check bottom part of range
for candidate in prime_sieve(start, initial):
#integer = rsa.randnum.read_random_odd_int(nbits)
# Test for primeness
if is_prime(candidate):
return candidate
#nothing the bottom half either
raise ValueError("no primes in range")


def are_relatively_prime(a, b):
"""Returns True if a and b are relatively prime, and False if they
Expand Down
12 changes: 12 additions & 0 deletions rsa/randnum.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,15 @@ def randint(maxvalue):
tries += 1

return value

def randrange(start,end):
"""Returns a random integer from range(start,end)
"""
assert end>start
span=end-start
#get an int with at least 64 extra bits
#because of the extra bits, value%span wraps around at least 2^64 times
#the non-uniformity of the resulting distribution is below 2**-64
bytes=(common.bit_size(span)+64+7)//8
value = transform.bytes2int(os.urandom(bytes))
return start+value%span
13 changes: 8 additions & 5 deletions tests/test_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,23 @@ def test_custom_getprime_func(self):
# List of primes to test with, in order [p, q, p, q, ....]
# By starting with two of the same primes, we test that this is
# properly rejected.
primes = [64123, 64123, 64123, 50957, 39317, 33107]
primes = [64123, 64123,#discarded because identical
61871, 61909,#rejected because too close
64123, 50957,#discarded because of custom exponent
61871, 46877]#should be returned

def getprime(_):
def getprime(a,b):
return primes.pop(0)

# This exponent will cause two other primes to be generated.
exponent = 136407

(p, q, e, d) = rsa.key.gen_keys(64,
(p, q, e, d) = rsa.key.gen_keys(32,
accurate=False,
getprime_func=getprime,
exponent=exponent)
self.assertEqual(39317, p)
self.assertEqual(33107, q)
self.assertEqual(61871, p)
self.assertEqual(46877, q)


class HashTest(unittest.TestCase):
Expand Down

0 comments on commit d91005d

Please sign in to comment.