From 7bc9f80abbb580cea11a3730f2a514634f344062 Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Thu, 2 Aug 2018 17:11:01 +0200 Subject: [PATCH] Improve readability of Python implementation * 1-to-1 correspondence to functions in pseudocode * naming of functions closer to pseudocode * lowercase / uppercase names more consistent * more consistent parentheses * don't reassign to k in signing algorithm (#7) * fail signing if private nonce is 0 (#13) --- bip-schnorr.mediawiki | 73 ++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/bip-schnorr.mediawiki b/bip-schnorr.mediawiki index 1ea31e5d9b..fd59c05f0e 100644 --- a/bip-schnorr.mediawiki +++ b/bip-schnorr.mediawiki @@ -284,56 +284,63 @@ p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) -def point_add(p1, p2): - if (p1 is None): - return p2 - if (p2 is None): - return p1 - if (p1[0] == p2[0] and p1[1] != p2[1]): +def point_add(P1, P2): + if P1 is None: + return P2 + if P2 is None: + return P1 + if P1[0] == P2[0] and P1[1] != P2[1]: return None - if (p1 == p2): - lam = (3 * p1[0] * p1[0] * pow(2 * p1[1], p - 2, p)) % p + if P1 == P2: + lam = (3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p)) % p else: - lam = ((p2[1] - p1[1]) * pow(p2[0] - p1[0], p - 2, p)) % p - x3 = (lam * lam - p1[0] - p2[0]) % p - return (x3, (lam * (p1[0] - x3) - p1[1]) % p) + lam = ((P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p)) % p + x3 = (lam * lam - P1[0] - P2[0]) % p + return (x3, (lam * (P1[0] - x3) - P1[1]) % p) -def point_mul(p, n): - r = None +def point_mul(P, n): + R = None for i in range(256): - if ((n >> i) & 1): - r = point_add(r, p) - p = point_add(p, p) - return r + if (n >> i) & 1: + R = point_add(R, P) + P = point_add(P, P) + return R -def bytes_point(p): - return (b'\x03' if p[1] & 1 else b'\x02') + p[0].to_bytes(32, byteorder="big") +def bytes_from_int(x): + return x.to_bytes(32, byteorder="big") -def sha256(b): - return int.from_bytes(hashlib.sha256(b).digest(), byteorder="big") +def bytes_from_point(P): + return (b'\x03' if P[1] & 1 else b'\x02') + bytes_from_int(P[0]) -def on_curve(point): - return (pow(point[1], 2, p) - pow(point[0], 3, p)) % p == 7 +def int_from_bytes(b): + return int.from_bytes(b, byteorder="big") + +def hash_sha256(b): + return hashlib.sha256(b).digest() + +def on_curve(P): + return (pow(P[1], 2, p) - pow(P[0], 3, p)) % p == 7 def jacobi(x): return pow(x, (p - 1) // 2, p) def schnorr_sign(msg, seckey): - k = sha256(seckey.to_bytes(32, byteorder="big") + msg) % n - R = point_mul(G, k) - if jacobi(R[1]) != 1: - k = n - k - e = sha256(R[0].to_bytes(32, byteorder="big") + bytes_point(point_mul(G, seckey)) + msg) % n - return R[0].to_bytes(32, byteorder="big") + ((k + e * seckey) % n).to_bytes(32, byteorder="big") + k0 = int_from_bytes(hash_sha256(bytes_from_int(seckey) + msg)) % n + if k0 == 0: + raise RuntimeError('Failure. This happens only with negligible probability.') + R = point_mul(G, k0) + k = n - k0 if (jacobi(R[1]) != 1) else k0 + e = int_from_bytes(hash_sha256(bytes_from_int(R[0]) + bytes_from_point(point_mul(G, seckey)) + msg)) % n + return bytes_from_int(R[0]) + bytes_from_int((k + e * seckey) % n) def schnorr_verify(msg, pubkey, sig): - if (not on_curve(pubkey)): + if not on_curve(pubkey): return False - r = int.from_bytes(sig[0:32], byteorder="big") - s = int.from_bytes(sig[32:64], byteorder="big") + r = int_from_bytes(sig[0:32]) + s = int_from_bytes(sig[32:64]) if r >= p or s >= n: return False - e = sha256(sig[0:32] + bytes_point(pubkey) + msg) % n + e = int_from_bytes(hash_sha256(sig[0:32] + bytes_from_point(pubkey) + msg)) % n R = point_add(point_mul(G, s), point_mul(pubkey, n - e)) if R is None or jacobi(R[1]) != 1 or R[0] != r: return False