Skip to content

Commit

Permalink
Merge pull request bitcoin#12 from real-or-random/bip-schnorr
Browse files Browse the repository at this point in the history
Improve readability of Python implementation
  • Loading branch information
sipa committed Nov 3, 2018
2 parents ce9fda8 + 7bc9f80 commit 7511d17
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions bip-schnorr.mediawiki
Original file line number Diff line number Diff line change
Expand Up @@ -284,56 +284,63 @@ p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)

def point_add(p1, p2):
if (p1 is None):
return p2
if (p2 is None):
return p1
if (p1[0] == p2[0] and p1[1] != p2[1]):
def point_add(P1, P2):
if P1 is None:
return P2
if P2 is None:
return P1
if P1[0] == P2[0] and P1[1] != P2[1]:
return None
if (p1 == p2):
lam = (3 * p1[0] * p1[0] * pow(2 * p1[1], p - 2, p)) % p
if P1 == P2:
lam = (3 * P1[0] * P1[0] * pow(2 * P1[1], p - 2, p)) % p
else:
lam = ((p2[1] - p1[1]) * pow(p2[0] - p1[0], p - 2, p)) % p
x3 = (lam * lam - p1[0] - p2[0]) % p
return (x3, (lam * (p1[0] - x3) - p1[1]) % p)
lam = ((P2[1] - P1[1]) * pow(P2[0] - P1[0], p - 2, p)) % p
x3 = (lam * lam - P1[0] - P2[0]) % p
return (x3, (lam * (P1[0] - x3) - P1[1]) % p)

def point_mul(p, n):
r = None
def point_mul(P, n):
R = None
for i in range(256):
if ((n >> i) & 1):
r = point_add(r, p)
p = point_add(p, p)
return r
if (n >> i) & 1:
R = point_add(R, P)
P = point_add(P, P)
return R

def bytes_point(p):
return (b'\x03' if p[1] & 1 else b'\x02') + p[0].to_bytes(32, byteorder="big")
def bytes_from_int(x):
return x.to_bytes(32, byteorder="big")

def sha256(b):
return int.from_bytes(hashlib.sha256(b).digest(), byteorder="big")
def bytes_from_point(P):
return (b'\x03' if P[1] & 1 else b'\x02') + bytes_from_int(P[0])

def on_curve(point):
return (pow(point[1], 2, p) - pow(point[0], 3, p)) % p == 7
def int_from_bytes(b):
return int.from_bytes(b, byteorder="big")

def hash_sha256(b):
return hashlib.sha256(b).digest()

def on_curve(P):
return (pow(P[1], 2, p) - pow(P[0], 3, p)) % p == 7

def jacobi(x):
return pow(x, (p - 1) // 2, p)

def schnorr_sign(msg, seckey):
k = sha256(seckey.to_bytes(32, byteorder="big") + msg) % n
R = point_mul(G, k)
if jacobi(R[1]) != 1:
k = n - k
e = sha256(R[0].to_bytes(32, byteorder="big") + bytes_point(point_mul(G, seckey)) + msg) % n
return R[0].to_bytes(32, byteorder="big") + ((k + e * seckey) % n).to_bytes(32, byteorder="big")
k0 = int_from_bytes(hash_sha256(bytes_from_int(seckey) + msg)) % n
if k0 == 0:
raise RuntimeError('Failure. This happens only with negligible probability.')
R = point_mul(G, k0)
k = n - k0 if (jacobi(R[1]) != 1) else k0
e = int_from_bytes(hash_sha256(bytes_from_int(R[0]) + bytes_from_point(point_mul(G, seckey)) + msg)) % n
return bytes_from_int(R[0]) + bytes_from_int((k + e * seckey) % n)

def schnorr_verify(msg, pubkey, sig):
if (not on_curve(pubkey)):
if not on_curve(pubkey):
return False
r = int.from_bytes(sig[0:32], byteorder="big")
s = int.from_bytes(sig[32:64], byteorder="big")
r = int_from_bytes(sig[0:32])
s = int_from_bytes(sig[32:64])
if r >= p or s >= n:
return False
e = sha256(sig[0:32] + bytes_point(pubkey) + msg) % n
e = int_from_bytes(hash_sha256(sig[0:32] + bytes_from_point(pubkey) + msg)) % n
R = point_add(point_mul(G, s), point_mul(pubkey, n - e))
if R is None or jacobi(R[1]) != 1 or R[0] != r:
return False
Expand Down

0 comments on commit 7511d17

Please sign in to comment.