diff --git a/README.md b/README.md index 41c2deb..c81a7e6 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,8 @@ Implements Coron's reformulation of Coppersmith's algorithm for finding small in Paper: http://www.jscoron.fr/publications/bivariate.pdf -Used in CSAW CTF Quals 2016 to solve Still Broken Box. (BTW, if you want an implementation of a crypto algorithm, write a crypto CTF challenge that needs it and read writeups.) - -~~**Warning: CTF Quality Code!**~~ Should be much more readable now. - -## Why? -Why not? ## Doesn't Sage provide this with `small_roots()`? -`small_roots()` only works with univariate polynomials. (which still would have saved me a lot of time in CSAW...) +`small_roots()` only works with univariate polynomials. + +## History +Used in CSAW CTF Quals 2016 to solve Still Broken Box. (BTW, if you want an implementation of a crypto algorithm, write a crypto CTF challenge that needs it and read writeups.) diff --git a/coppersmith.sage b/coppersmith.sage index 68afc1c..99644fe 100644 --- a/coppersmith.sage +++ b/coppersmith.sage @@ -1,6 +1,6 @@ -def coron(pol, X, Y, k=2, debug=False): +def coron(pol, X, Y, M=None, k=2, debug=False): """ - Returns all small roots of pol. + Returns all small roots of pol over the integers (modulo M if it is given). Applies Coron's reformulation of Coppersmith's algorithm for finding small integer roots of bivariate polynomials modulo an integer. @@ -9,6 +9,7 @@ def coron(pol, X, Y, k=2, debug=False): pol: The polynomial to find small integer roots of. X: Upper limit on x. Y: Upper limit on y. + M: Modulus. If M==None, then pol is considered over the integers. k: Determines size of lattice. Increase if the algorithm fails. debug: Turn on for debug print stuff. @@ -23,56 +24,55 @@ def coron(pol, X, Y, k=2, debug=False): raise ValueError("pol is not bivariate") P. = PolynomialRing(ZZ) - pol = pol(x,y) + pol = P( pol(x,y) ) + + # removing common factor of the coefficients + if M: + M //= gcd(M,pol.content()) + pol //= pol.content() + + if len(pol.factor()) > 1: + raise ValueError("pol is reducible") # Handle case where pol(0,0) == 0 xoffset = 0 - - while pol(xoffset,0) == 0: + while pol(xoffset,0) == 0 or (M and gcd(pol(xoffset,0),M) != 1): xoffset += 1 + if debug: + print("Offset:", xoffset) pol = pol(x+xoffset,y) + p00 = pol(0,0) # Handle case where gcd(pol(0,0),X*Y) != 1 - while gcd(pol(0,0), X) != 1: + while gcd(p00, X) != 1: X = next_prime(X, proof=False) - while gcd(pol(0,0), Y) != 1: + while gcd(p00, Y) != 1: Y = next_prime(Y, proof=False) - pol = P(pol/gcd(pol.coefficients())) # seems to be helpful - p00 = pol(0,0) - delta = max(pol.degree(x),pol.degree(y)) # maximum degree of any variable + delta = max(pol.degree(x),pol.degree(y)) # maximum degree of any variable + + if M: + u = M + else: + W = max(abs(i) for i in pol(x*X,y*Y).coefficients()) + u = W + ((1-W) % abs(p00)) - W = max(abs(i) for i in pol(x*X,y*Y).coefficients()) - u = W + ((1-W) % abs(p00)) - N = u*(X*Y)^k # modulus for polynomials + N = u*(X*Y)^k # modulus for polynomials # Construct polynomials p00inv = inverse_mod(p00,N) - polq = P(sum((i*p00inv % N)*j for i,j in zip(pol.coefficients(), - pol.monomials()))) - polynomials = [] - for i in range(delta+k+1): - for j in range(delta+k+1): - if 0 <= i <= k and 0 <= j <= k: - polynomials.append(polq * x^i * y^j * X^(k-i) * Y^(k-j)) - else: - polynomials.append(x^i * y^j * N) + polq = P( sum((i*p00inv % N)*j for i,j in pol) ) + + polynomials = [ polq * x^i * y^j * X^(k-i) * Y^(k-j) if (0 <= i <= k and 0 <= j <= k) + else x^i * y^j * N for i in range(delta+k+1) for j in range(delta+k+1) ] # Make list of monomials for matrix indices - monomials = [] - for i in polynomials: - for j in i.monomials(): - if j not in monomials: - monomials.append(j) - monomials.sort() + monomials = sorted( set( sum( (i.monomials() for i in polynomials), [] ) ) ) # Construct lattice spanned by polynomials with xX and yY - L = matrix(ZZ,len(monomials)) - for i in range(len(monomials)): - for j in range(len(monomials)): - L[i,j] = polynomials[i](X*x,Y*y).monomial_coefficient(monomials[j]) + L = matrix(ZZ, len(monomials), lambda i,j: polynomials[i](X*x,Y*y).monomial_coefficient(monomials[j]) ) # makes lattice upper triangular # probably not needed, but it makes debug output pretty @@ -92,31 +92,33 @@ def coron(pol, X, Y, k=2, debug=False): for i in range(L.nrows()): if debug: - print("Trying row {}".format(i)) + print(f"Trying row {i}") # i'th row converted to polynomial dividing out X and Y pol2 = P(sum(map(mul, zip(L[i],monomials)))(x/X,y/Y)) r = pol.resultant(pol2, y) - if r.is_constant(): # not independent + if r.is_constant(): # not independent continue for x0, _ in r.univariate_polynomial().roots(): - if x0-xoffset in [i[0] for i in roots]: + if x0+xoffset in [i[0] for i in roots]: continue if debug: - print("Potential x0:",x0) + print( "Potential x0:",x0 ) for y0, _ in pol(x0,y).univariate_polynomial().roots(): if debug: - print("Potential y0:",y0) - if (x0-xoffset,y0) not in roots and pol(x0,y0) == 0: - roots.append((x0-xoffset,y0)) + print( "Potential y0:",y0 ) + v = pol(x0,y0) + if M: + v %= M + if v == 0 and (x0+xoffset,y0) not in roots: + roots.append((x0+xoffset,y0)) return roots def main(): # Example 1: recover p,q prime given n=pq and the lower bits of p - print("---EXAMPLE 1---") nbits = 512 # bitlength of primes @@ -163,7 +165,6 @@ def main(): # Example 2: recover p,q prime given n=pq and the upper bits of p # This can be done with a univariate polynomial and Howgrave-Graham, # but this is another way to do it with a bivariate polynomial. - print("---EXAMPLE 2---") nbits = 512 # bitlength of primes