In [70]:
import random as rn
import binascii
import hashlib
import math

In [49]:
# Field modulus used by every modular operation in this notebook.
p = (1 << 255) - 19
# Coefficient of the x^2 term in the curve equation y^2 = x^3 + A*x^2 + x.
A = 486662

In [50]:
# Show the modulus in hexadecimal.
print(f"{p:#x}")

0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed


In [73]:
# Constants consumed by map_to_curve_elligator2_curve25519 (c2, c3, c4);
# c1 is only needed to derive c2.
c1 = (p + 3) // 8               # Integer arithmetic
c4 = (p - 5) // 8               # Integer arithmetic
c2 = pow(2, c1, p)              # 2^c1 mod p
c3 = pow(2, (p - 1) // 4, p)    # 2^((p - 1) / 4) mod p

# One hex value per line, in the order c1, c2, c3, c4.
print(*map(hex, (c1, c2, c3, c4)), sep="\n")
# Sanity check of the relation between the two multipliers.
print(c2 == c3 + 1)


0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe
0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b1
0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0
0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd
True


In [103]:
# Scratch value: inspect its magnitude (log2) and its bit pattern.
w = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd
print(math.log2(w), bin(w), sep="\n")

252.0
0b111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111101


In [102]:
# Raise a fixed test value to the power 2 * (2^250 - 1) + 1 modulo p:
# first the (2^250 - 1) power, then one squaring and one extra multiplication.
qp = pow(0x63b5154e8c80d61c717393708d628efe4959c7129c17298ec7a5777eb4e64c27, 2**250 - 1, p)
qp = pow(qp, 2, p) * 0x63b5154e8c80d61c717393708d628efe4959c7129c17298ec7a5777eb4e64c27 % p
print(hex(qp))


0x7b5b8c3d947f0533bc21452c15a6f45cad761ec4f7a475cff5c36ef0410cbd6


In [69]:
# p with its three low bits dropped (floor division by 8), in hex.
print(hex(p // 8))

0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd


In [52]:
def is_on_curve25519(x, y):
    """Return True when (x, y) satisfies y^2 = x^3 + A*x^2 + x modulo p.

    Uses the module-level modulus ``p`` and coefficient ``A``.
    """
    lhs = pow(y, 2, p)
    rhs = (pow(x, 3, p) + A * pow(x, 2, p) + x) % p
    return lhs == rhs

In [53]:
def int2bytes(i: int) -> bytes:
    """Encode a non-negative integer as big-endian bytes of minimal length.

    The integer is rendered as hexadecimal, left-padded with a single '0'
    when the digit count is odd, and decoded to bytes. Zero encodes as a
    single zero byte. Leading zero bytes are never produced, so the output
    length depends on the magnitude of ``i``.

    Args:
        i: the integer to encode.

    Returns:
        The big-endian byte string (the original annotation wrongly said
        ``str``).

    Raises:
        binascii.Error: if ``i`` is negative (the '-' sign is not a hex digit).
    """
    hex_string = '%x' % i
    if len(hex_string) % 2:
        # unhexlify needs an even number of hex digits.
        hex_string = '0' + hex_string
    return binascii.unhexlify(hex_string)

In [62]:
def print_bytes(s: str, b: bytes):
    """Print the label ``s`` followed by ``b`` rendered as lowercase hex."""
    hex_pairs = [f"{byte:02x}" for byte in b]
    print(s, "".join(hex_pairs))

In [83]:
def cmov(a, b, c):
    """Conditional move: return ``a`` when ``c`` is falsy, otherwise ``b``.

    Also prints a trace line showing whether ``a`` was selected.
    """
    keep_a = not c
    print("cmov: ", keep_a)
    return a if keep_a else b
    
def sgn0(x):
    """Return the parity of ``x`` (its remainder modulo 2), used as its sign."""
    _, parity = divmod(x, 2)
    return parity

def inv0(x):
    """Return ``x`` raised to the power ``p - 2`` modulo the module-level ``p``.

    For a prime modulus this is the multiplicative inverse of a non-zero
    ``x``; an input of 0 yields 0.
    """
    exponent = p - 2
    return pow(x, exponent, p)

In [56]:
def sha512(s):
    """Return the raw 64-byte SHA-512 digest of the bytes-like object ``s``."""
    hasher = hashlib.sha512()
    hasher.update(s)
    return hasher.digest()

In [57]:
def expand_message(MSG: bytes, DST: bytes) -> bytes:
    """Frame ``MSG`` and ``DST`` with fixed markers and hash the result.

    The hashed input is: a fixed 32-byte tag, ``MSG``, the single byte 0x20,
    ``DST``, and the single byte 0x1E. The framed input and its length are
    printed as a trace before hashing.

    NOTE(review): 0x20 and 0x1E are hard-coded; presumably they are length
    fields (32 and 30) -- confirm against the specification being followed.

    Returns:
        The 64-byte SHA-512 digest of the framed input.
    """
    tag = int2bytes(0x8000000000000000000000000000000000000000000000000000000000545301)
    framed = tag + MSG + int2bytes(0x20) + DST + int2bytes(0x1E)
    print_bytes("to expand: ", framed)
    print(len(framed))
    return sha512(framed)

In [58]:
def hash_to_field(MSG: bytes, DST: bytes) -> int:
    """Hash ``MSG`` under the domain-separation tag ``DST`` to an integer mod p.

    Both arguments are byte strings: they are concatenated with bytes inside
    ``expand_message`` (the previous ``str`` annotations were wrong).

    Args:
        MSG: message bytes to hash.
        DST: domain-separation tag bytes.

    Returns:
        The digest from ``expand_message`` read as a big-endian integer and
        reduced modulo the module-level ``p``. The digest is printed as a
        trace on the way.
    """
    expanded = expand_message(MSG, DST)
    print_bytes("expanded: ", expanded)
    return int.from_bytes(expanded, 'big') % p

In [98]:
def map_to_curve_elligator2_curve25519(u):
    """Map the field element ``u`` to a point on curve25519 via Elligator 2.

    The steps mirror the straight-line curve25519 sample code from RFC 9380
    (Appendix G.2). All arithmetic is modulo the module-level ``p``; the
    function also uses ``A``, ``c2``, ``c3``, ``c4``, ``cmov`` and ``sgn0``.

    Args:
        u: integer field element.

    Returns:
        A tuple ``(xn, xd, y, 1)`` describing the affine point
        ``(xn / xd, y / 1)``. Every component is reduced into ``[0, p)``.

    Prints the intermediate ``tv2`` and ``y11`` values as a debug trace
    (``cmov`` prints its own trace as well).
    """
    tv1 = pow(u, 2, p)
    tv1 = 2 * tv1 % p
    # Parentheses are required: `tv1 + 1 % p` parses as `tv1 + (1 % p)`.
    xd = (tv1 + 1) % p        # Nonzero: -1 is square (mod p), tv1 is not
    x1n = -A % p              # x1 = x1n / xd = -A / (1 + 2 * u^2)
    tv2 = pow(xd, 2, p)
    gxd = tv2 * xd % p        # gxd = xd^3
    gx1 = A * tv1 % p         # x1n + A * xd
    gx1 = gx1 * x1n % p       # x1n^2 + A * x1n * xd
    # Same precedence pitfall as above: reduce the sum, not just tv2.
    gx1 = (gx1 + tv2) % p     # x1n^2 + A * x1n * xd + xd^2
    gx1 = gx1 * x1n % p       # x1n^3 + A * x1n^2 * xd + x1n * xd^2
    tv3 = pow(gxd, 2, p)
    tv2 = pow(tv3, 2, p)      # gxd^4
    tv3 = tv3 * gxd % p       # gxd^3
    tv3 = tv3 * gx1 % p       # gx1 * gxd^3
    tv2 = tv2 * tv3 % p       # gx1 * gxd^7
    print("tv2: ", hex(tv2))
    y11 = pow(tv2, c4, p)     # (gx1 * gxd^7)^((p - 5) / 8)
    print("y11: ", hex(y11))
    y11 = y11 * tv3 % p       # gx1 * gxd^3 * (gx1 * gxd^7)^((p - 5) / 8)
    y12 = y11 * c3 % p
    tv2 = pow(y11, 2, p)
    tv2 = tv2 * gxd % p
    e1 = tv2 == gx1
    y1 = cmov(y12, y11, e1)   # If g(x1) is square, this is its sqrt
    x2n = x1n * tv1 % p       # x2 = x2n / xd = 2 * u^2 * x1n / xd
    y21 = y11 * u % p
    y21 = y21 * c2 % p
    y22 = y21 * c3 % p
    gx2 = gx1 * tv1 % p       # g(x2) = gx2 / gxd = 2 * u^2 * g(x1)
    tv2 = pow(y21, 2, p)
    tv2 = tv2 * gxd % p
    e2 = tv2 == gx2
    y2 = cmov(y22, y21, e2)   # If g(x2) is square, this is its sqrt
    tv2 = pow(y1, 2, p)
    tv2 = tv2 * gxd % p
    e3 = tv2 == gx1
    xn = cmov(x2n, x1n, e3)   # If e3, x = x1, else x = x2
    y = cmov(y2, y1, e3)      # If e3, y = y1, else y = y2
    e4 = sgn0(y) == 1         # Fix sign of y
    # Negate inside the field so the returned y is never a negative integer.
    y = cmov(y, (-y) % p, e3 ^ e4)
    return (xn, xd, y, 1)

In [64]:
def point_generate(DST: bytes, rng: int):
    """Derive an affine curve25519 point from the integer ``rng``.

    ``rng`` is encoded to bytes with ``int2bytes``, hashed to a field element
    under the tag ``DST``, mapped to the curve, and the projective result is
    normalised to affine coordinates. Intermediate values are printed as a
    trace.

    Args:
        DST: domain-separation tag as bytes (it is concatenated with bytes
            downstream; the previous ``str`` annotation was wrong).
        rng: non-negative integer seed for the point.

    Returns:
        The tuple ``(x, y)`` of affine coordinates, each reduced modulo ``p``.

    NOTE(review): ``int2bytes`` drops leading zero bytes, so the encoded
    message is shorter than 32 bytes whenever ``rng < 2**248`` -- confirm
    that a variable-length message is intended.
    """
    m = int2bytes(rng)
    print("rng: ", hex(rng))
    print_bytes("m: ", m)

    u = hash_to_field(m, DST)

    print("u: ", hex(u))

    xn, xd, yn, yd = map_to_curve_elligator2_curve25519(u)

    # Convert from fractions (xn / xd, yn / yd) to affine coordinates.
    x = xn * inv0(xd) % p
    y = yn * inv0(yd) % p

    return x, y

In [66]:
# Draw a uniformly random 256-bit integer (0 .. 2^256 - 1 inclusive).
rng = rn.randrange(1 << 256)

In [104]:
# Generate four points from fresh random seeds and confirm each lies on the curve.
# The domain-separation tag is the same for every iteration, so build it once.
DST = int2bytes(0x54535F53504543545F4453540000000000000000000000000000000000D8)
for i in range(4):
    rng = rn.randint(0, 2**256 - 1)
    x, y = point_generate(DST, rng)

    print("x: ", hex(x))
    print("y: ", hex(y))
    print(is_on_curve25519(x, y))
    print("======================================================")

rng:  0xd68328b42d3f338b23303da1fd12a73a0d3fb1a1348f4ae2e51aca3584fd2712
m:  d68328b42d3f338b23303da1fd12a73a0d3fb1a1348f4ae2e51aca3584fd2712
to expand:  8000000000000000000000000000000000000000000000000000000000545301d68328b42d3f338b23303da1fd12a73a0d3fb1a1348f4ae2e51aca3584fd27122054535f53504543545f4453540000000000000000000000000000000000d81e
96
expanded:  4d30f65289b0b974e4e7a3240543db0c818d848a405d74734c110c23ad625958a098725adbfde2aaad316df6875efaf789ef32b3a005f7fc956307488cfa44f0
u:  0x15dd029b4c396a04a793a54f4f717ed2c4f0df392de54119dfead494499389c8
xd:  0x2d5a8b57273a2d059f924d49a908977b74157c0d80f230547bbcf977fff55fd9
gx1:  0x6cd57255ebafd0132559cbeec31349e842b12ed776907cd52eedf08118a9350c
tv2:  0x4ca689b7bde5abc5f88eb1e2c37845a08b6f25855a9e09c1b73c32f5c06e875d
y11:  0x7abd5f27b414938ab5af6d3c9c2bcde9d1bfc322f20b3a6352e6adadf0e3126f
cmov:  True
cmov:  True
cmov:  True
cmov:  True
cmov:  True
x:  0x671ad6d51a5c679fbbc15453d4284acea8241a06ec287a639449bf814e457e95
y:  0x30d06c46219

In [52]:
# Pick a uniformly random field element in [0, p - 1] and show it in hex.
u = rn.randrange(p)
print(f"{u:#x}")

0x1b898a0b2ff3e1766310aafe636344854d567706074eaa5e36e41a0e7ad206b5


In [56]:
# Map the sampled u to the curve and normalise the fractions to affine form.
xn, xd, yn, yd = map_to_curve_elligator2_curve25519(u)

x = xn * inv0(xd) % p
y = yn * inv0(yd) % p

# One coordinate per line: x first, then y.
print(hex(x), hex(y), sep="\n")

0x1707caa60b5e50ec85d273fcc8cd805c5dc8daa55e7a02b8b7b7298a36102ed2
0x48d3bb8f816b184a1b6a9641e826435049b50fcdb54da0ebd975c996ecaa3795


In [57]:
# Does the point computed above satisfy the curve equation?
is_on_curve25519(x, y)

True

In [None]:
# Randomised check: map 1000 random field elements and stop at the first
# one whose image is not on the curve.
for i in range(1000):
    u = rn.randint(0, p-1)
    xn, xd, yn, yd = map_to_curve_elligator2_curve25519(u)
    x = xn * inv0(xd) % p
    y = yn * inv0(yd) % p
    if is_on_curve25519(x, y):
        continue
    print("Failed on ", hex(u))
    break
print("END")