In [29]:
import random as rn

In [37]:
# Field modulus p = 2^255 - 19 and the coefficient A of the curve
# equation y^2 = x^3 + A*x^2 + x used throughout this notebook.
p = (1 << 255) - 19
A = 486_662

In [23]:
# Show the modulus in hexadecimal (with the 0x prefix).
print(f"{p:#x}")

0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed


In [55]:
# Precomputed constants; c2, c3 and c4 are read by
# map_to_curve_elligator2_curve25519.
c1 = (p + 3) // 8               # exponent (p + 3) / 8, computed with integer division
c2 = pow(2, c1, p)              # 2^c1 mod p
c3 = pow(2, (p - 1) // 4, p)    # 2^((p - 1) / 4) mod p, used as a square root of -1
c4 = (p - 5) // 8               # exponent (p - 5) / 8, computed with integer division

# One hexadecimal value per line, in the order c1, c2, c3, c4.
print(*(hex(value) for value in (c1, c2, c3, c4)), sep="\n")


0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe
0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b1
0x2b8324804fc1df0b2b4d00993dfbd7a72f431806ad2fe478c4ee1b274a0ea0b0
0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd


In [None]:
def is_on_curve25519(x, y):
    """Return True when (x, y) satisfies y^2 = x^3 + A*x^2 + x modulo p.

    Both sides are reduced modulo the module-level prime ``p`` before being
    compared, so inputs outside [0, p) are accepted.
    """
    x_squared = pow(x, 2, p)
    x_cubed = pow(x, 3, p)
    right_side = (x_cubed + A * x_squared + x) % p
    left_side = pow(y, 2, p)
    return left_side == right_side

In [40]:
def cmov(a, b, c):
    """Conditional move: return ``b`` when ``c`` is truthy, otherwise ``a``.

    Implemented with an ordinary conditional expression.
    """
    return b if c else a
    
def sgn0(x):
    """Return the "sign" of ``x``, defined here as ``x`` modulo 2."""
    _, parity = divmod(x, 2)
    return parity

def inv0(x):
    """Return x^(p - 2) mod p.

    For nonzero ``x`` modulo the prime ``p`` this is the multiplicative
    inverse; for ``x`` congruent to 0 the result is 0.
    """
    exponent = p - 2
    return pow(x, exponent, p)

In [48]:
def map_to_curve_elligator2_curve25519(u):
    """Map the field element ``u`` to a point on y^2 = x^3 + A*x^2 + x (mod p).

    Straight-line Elligator 2 computation that avoids field inversions by
    carrying x as a fraction.  Relies on the module-level values ``p``, ``A``,
    ``c2``, ``c3``, ``c4`` and the helpers ``cmov`` and ``sgn0``.

    Parameters
    ----------
    u : int
        Input field element; it is reduced modulo ``p`` by the arithmetic below.

    Returns
    -------
    tuple
        ``(xn, xd, y, 1)`` where the affine point is ``(xn / xd, y / 1)``.
        Every component is a canonical residue in ``[0, p)``.
    """
    tv1 = pow(u, 2, p)
    tv1 = 2 * tv1 % p
    # Parenthesised on purpose: `tv1 + 1 % p` would parse as `tv1 + (1 % p)`.
    xd = (tv1 + 1) % p            # Nonzero: -1 is square (mod p), tv1 is not
    x1n = -A % p                  # x1 = x1n / xd = -A / (1 + 2 * u^2)
    tv2 = pow(xd, 2, p)
    gxd = tv2 * xd % p            # gxd = xd^3
    gx1 = A * tv1 % p             # x1n + A * xd
    gx1 = gx1 * x1n % p           # x1n^2 + A * x1n * xd
    # Parenthesised for the same precedence reason as xd above.
    gx1 = (gx1 + tv2) % p         # x1n^2 + A * x1n * xd + xd^2
    gx1 = gx1 * x1n % p           # x1n^3 + A * x1n^2 * xd + x1n * xd^2
    tv3 = pow(gxd, 2, p)
    tv2 = pow(tv3, 2, p)          # gxd^4
    tv3 = tv3 * gxd % p           # gxd^3
    tv3 = tv3 * gx1 % p           # gx1 * gxd^3
    tv2 = tv2 * tv3 % p           # gx1 * gxd^7
    y11 = pow(tv2, c4, p)         # (gx1 * gxd^7)^((p - 5) / 8)
    y11 = y11 * tv3 % p           # gx1 * gxd^3 * (gx1 * gxd^7)^((p - 5) / 8)
    y12 = y11 * c3 % p
    tv2 = pow(y11, 2, p)
    tv2 = tv2 * gxd % p
    e1 = tv2 == gx1
    y1 = cmov(y12, y11, e1)       # If g(x1) is square, this is its sqrt
    x2n = x1n * tv1 % p           # x2 = x2n / xd = 2 * u^2 * x1n / xd
    y21 = y11 * u % p
    y21 = y21 * c2 % p
    y22 = y21 * c3 % p
    gx2 = gx1 * tv1 % p           # g(x2) = gx2 / gxd = 2 * u^2 * g(x1)
    tv2 = pow(y21, 2, p)
    tv2 = tv2 * gxd % p
    e2 = tv2 == gx2
    y2 = cmov(y22, y21, e2)       # If g(x2) is square, this is its sqrt
    tv2 = pow(y1, 2, p)
    tv2 = tv2 * gxd % p
    e3 = tv2 == gx1
    xn = cmov(x2n, x1n, e3)       # If e3, x = x1, else x = x2
    y = cmov(y2, y1, e3)          # If e3, y = y1, else y = y2
    e4 = sgn0(y) == 1             # Fix sign of y
    # Negate inside the field: a bare `-y` would return a negative integer
    # whose parity does not reflect the sign change.
    y = cmov(y, (-y) % p, e3 ^ e4)
    return (xn, xd, y, 1)

In [52]:
# Pick a random test input in the inclusive range [0, p - 1] and show it in hex.
u = rn.randint(0, p - 1)
print(f"{u:#x}")

0x1b898a0b2ff3e1766310aafe636344854d567706074eaa5e36e41a0e7ad206b5


In [56]:
# Map u to fractional coordinates (xn / xd, yn / yd) ...
xn, xd, yn, yd = map_to_curve_elligator2_curve25519(u)

# ... then divide out the denominators to obtain affine x and y modulo p.
x = (xn * inv0(xd)) % p
y = (yn * inv0(yd)) % p

print(hex(x), hex(y), sep="\n")

0x1707caa60b5e50ec85d273fcc8cd805c5dc8daa55e7a02b8b7b7298a36102ed2
0x48d3bb8f816b184a1b6a9641e826435049b50fcdb54da0ebd975c996ecaa3795


In [57]:
# Check that the mapped point satisfies the curve equation, using the shared
# helper rather than a second copy of the formula.
is_on_curve25519(x, y)

True

In [None]:
# Randomised check: map 1000 random inputs and stop at the first one whose
# image does not lie on the curve.  "END" is printed in either case.
for i in range(1000):
    u = rn.randint(0, p - 1)
    xn, xd, yn, yd = map_to_curve_elligator2_curve25519(u)
    x, y = xn * inv0(xd) % p, yn * inv0(yd) % p
    if is_on_curve25519(x, y):
        continue
    print("Failed on ", hex(u))
    break
print("END")