All quotes below are excerpted from https://toadstyle.org/cryptopals/60.txt

```
60. Single-Coordinate Ladders and Insecure Twists

All our hard work is about to pay some dividends. Here's a list of
cool-kids jargon you'll be able to deploy after completing this
challenge:

* Montgomery curve
* single-coordinate ladder
* isomorphism
* birational equivalence
* quadratic twist
* trace of Frobenius

Not that you'll understand it all; you won't. But you'll at least be
able to silence crypto-dilettantes on Twitter.

Now, to the task at hand. In the last problem, we implemented ECDH
using a short Weierstrass curve form, like this:

    y^2 = x^3 + a*x + b

For a long time, this has been the most popular curve form. The NIST
P-curves standardized in the 90s look like this. It's what you'll see
first in most elliptic curve tutorials (including this one).

We can do a lot better. Meet the Montgomery curve:

    B*v^2 = u^3 + A*u^2 + u

Although it's almost as old as the Weierstrass form, it's been buried
in the literature until somewhat recently. The Montgomery curve has a
killer feature in the form of a simple and efficient algorithm to
compute scalar multiplication: the Montgomery ladder.

Here's the ladder:

    function ladder(u, k):
        u2, w2 := (1, 0)
        u3, w3 := (u, 1)
        for i in reverse(range(bitlen(p))):
            b := 1 & (k >> i)
            u2, u3 := cswap(u2, u3, b)
            w2, w3 := cswap(w2, w3, b)
            u3, w3 := ((u2*u3 - w2*w3)^2,
                       u * (u2*w3 - w2*u3)^2)
            u2, w2 := ((u2^2 - w2^2)^2,
                       4*u2*w2 * (u2^2 + A*u2*w2 + w2^2))
            u2, u3 := cswap(u2, u3, b)
            w2, w3 := cswap(w2, w3, b)
        return u2 * w2^(p-2)

You are not expected to understand this.
```

In [148]:
from dataclasses import dataclass
from itertools import combinations
from random import randrange
from math import log, ceil
from pprint import pprint

from challenge_31 import do_sha256, hmac
from challenge_39 import invmod
from challenge_57 import primegen, int_to_bytes, crt
from challenge_59 import tonelli_shanks, NoQuadraticResidueError

In [111]:
def cswap(a, b, i):
    # b,a if i else a,b
    return (b, a) if i else (a, b)  # absurdly, this is faster than the arithmetic implementation in python
    #return (b*i + a*(1-i), a*i + b*(1-i))


@dataclass
class MontyCurve:
    a: int
    b: int
    p: int
    
    def mul(self, u, k):
        a, p = self.a, self.p
        blp = p.bit_length()  # bit length of p (exact, unlike ceil(log(p, 2)) computed with floats)
        u2, w2 = (1, 0)
        u3, w3 = (u, 1)
        for i in reversed(range(blp)):
            b = 1 & (k >> i)
            #u2, u3 = cswap(u2, u3, b)
            #w2, w3 = cswap(w2, w3, b)
            u2, u3 = (u3, u2) if b else (u2, u3)
            w2, w3 = (w3, w2) if b else (w2, w3)
            u3, w3 = ((u2*u3 - w2*w3)**2 % p,
                       u * (u2*w3 - w2*u3)**2 % p)
            u2, w2 = ((u2**2 - w2**2)**2 % p,
                       4*u2*w2 * (u2**2 + a*u2*w2 + w2**2) % p)
            #u2, u3 = cswap(u2, u3, b)
            #w2, w3 = cswap(w2, w3, b)
            u2, u3 = (u3, u2) if b else (u2, u3)
            w2, w3 = (w3, w2) if b else (w2, w3)
        return (u2 * pow(w2, p-2, p)) % p



# ==== Profiling ====
# %timeit get_extra_coeffs()
# for 3 different implementations of MontyCurve.mul:
# 2.87 s ± 66.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)  cswap function (branching)
# 2.68 s ± 64.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)  cswap inlined (branching)
# 3.14 s ± 33.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)  cswap inlined (arithmetic)

```
Go ahead and implement the ladder. Remember that all computations are
in GF(233970423115425145524320034830162017933).

Oh yeah, the curve parameters. You might be thinking that since we're
switching to a new curve format, we also need to pick out a whole new
curve. But you'd be totally wrong! It turns out that some short
Weierstrass curves can be converted into Montgomery curves.

You can perform this conversion algebraically. But it's kind of a
pain, so here you go:

    v^2 = u^3 + 534*u^2 + u

Through cunning and foresight, I have chosen this curve specifically
to have a really simple map between Weierstrass and Montgomery
forms. Here it is:

    u = x - 178
    v = y

Which makes our base point:

    (4, 85518893674295321206118380980485522083)

Or, you know. Just 4.

Anyway, implement the ladder. Verify ladder(4, n) = 0. Map some points
back and forth between your Weierstrass and Montgomery representations
and verify them.
```
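The quoted map can be checked by hand. Substituting x = u + 178 into the challenge-59 Weierstrass equation (the `WeierCurve` parameters a=-95051, b=11279326 used further down) and collecting powers of u should give exactly u^3 + 534*u^2 + u. Here is a small sketch that does the collection with plain integers; the helper name `shifted_cubic` is only illustrative:

```python
def shifted_cubic(a, b, s):
    """Coefficients (c3, c2, c1, c0) of f(u + s), where f(x) = x^3 + a*x + b.

    Expanding (u + s)^3 + a*(u + s) + b and collecting powers of u gives
    u^3 + 3s*u^2 + (3s^2 + a)*u + (s^3 + a*s + b).
    """
    return 1, 3 * s, 3 * s * s + a, s**3 + a * s + b


# Challenge-59 Weierstrass coefficients, shifted by x = u + 178
mont_coeffs = shifted_cubic(a=-95051, b=11279326, s=178)
print(mont_coeffs)  # (1, 534, 1, 0), i.e. u^3 + 534*u^2 + u
```

The constant term vanishes and the u coefficient comes out as 1, so the result already has the Montgomery shape B*v^2 = u^3 + A*u^2 + u with B = 1 and A = 534, which is why v = y needs no scaling. Going the other way, a B = 1 Montgomery curve becomes y^2 = x^3 + a*x + b via x = u + A/3, with a = 1 - A^2/3 and b = 2*A^3/27 - A/3; for A = 534 that shift is the integer 178 and the formulas reproduce -95051 and 11279326.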

In [112]:
curve = MontyCurve(a=534, b=1, p=233970423115425145524320034830162017933)
order = 233970423115425145498902418297807005944  # value copied from 59.txt

assert curve.mul(4, order) == 0

In [4]:
from challenge_59 import Curve as WeierCurve
w_curve = WeierCurve(a=-95051, b=11279326, p=233970423115425145524320034830162017933)
w_base = (182, 85518893674295321206118380980485522083)

In [5]:
def to_monty(pt):
    if pt is w_curve.zero:
        return 0
    x = pt[0]
    return (x - 178) % curve.p  # reduce mod p so that x < 178 still lands in [0, p)


def to_weier(u):
    if u == 0:
        return w_curve.zero
    p, a = curve.p, curve.a
    v1, v2 = tonelli_shanks(
        (pow(u, 3, p) + a*pow(u, 2, p) + u) % p,
        p
    )
    x = (u + 178) % p
    return (x, v1), (x, v2)  # two possibilities, one per square root


def test_monty_weier_conversion():
    assert to_monty(w_base) == 4
    
    for _ in range(50):
        i = randrange(0, 10000)
        w = w_curve.mul(w_base, i)
        m = curve.mul(4, i)
        
        assert to_monty(w) == m
        assert w in to_weier(m)
        
    print("All good!", flush=True)

test_monty_weier_conversion()

All good!


---
```
One nice thing about the Montgomery ladder is its lack of special
cases. Specifically, no special handling of: P1 = O; P2 = O; P1 = P2;
or P1 = -P2. Contrast that with our Weierstrass addition function and
its battalion of ifs.

And there's a security benefit, too: by ignoring the v coordinate, we
take away a lot of leeway from the attacker. Recall that the ability
to choose arbitrary (x, y) pairs let them cherry-pick points from any
curve they can think of. The single-coordinate ladder robs the
attacker of that freedom.

But hang on a tick! Give this a whirl:

    ladder(76600469441198017145391791613091732004, 11)
```

In [6]:
u = 76600469441198017145391791613091732004
curve.mul(u, 11)

0

---
```
What the heck? What's going on here?

Let's do a quick sanity check. Here's the curve equation again:

    v^2 = u^3 + 534*u^2 + u

Plug in u and take the square root to recover v.
```

In [7]:
def u_to_v(u, curve=curve):
    p = curve.p
    v_sq = (u**3 + curve.a * u**2 + u) % p
    return tonelli_shanks(v_sq, p)

try:
    u_to_v(u)
except NoQuadraticResidueError:
    print("ERROR: Square root of", u, "does not exist!")

ERROR: Square root of 76600469441198017145391791613091732004 does not exist!


---
```
You should detect that something is quite wrong. This u does not
represent a point on our curve! Not every u does.

This means that even though we can only submit one coordinate, we
still have a little bit of leeway to find invalid
points. Specifically, an input u such that u^3 + 534*u^2 + u is not a
quadratic residue can never represent a point on our curve. So where
the heck are we?

The other curve we're on is a sister curve called a "quadratic twist",
or simply "the twist". There is actually a whole family of quadratic
twists to our curve, but they're all isomorphic to each
other. Remember that that means they have the same number of points,
the same subgroups, etc. So it doesn't really matter which particular
twist we use; in fact, we don't even need to pick one.

...

If Alice chose a curve with an insecure twist, i.e. one with a
partially smooth order, then some doors open back up for Eve. She can
choose low-order points on the twisted curve, send them to Alice, and
perform the invalid-curve attack as before.

The only caveat is that she won't be able to recover the full secret
using off-curve points, only a fraction of it. But we know how to
handle that.

So:

1. Calculate the order of the twist and find its small factors. This
   one should have a bunch under 2^24.
```
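The next cell leans on the fact that the curve and its twist together have 2*p + 2 points. A brute-force count over toy primes makes this concrete: every u whose right-hand side u^3 + A*u^2 + u is a nonzero square gives two points on the curve, every nonsquare gives two points on the twist, and a zero gives one point on each; with one point at infinity per curve, the total is 2*p + 2. This sketch (the function name and the toy parameters are illustrative, and brute force only works for small p) tests squareness with Euler's criterion and also checks the trace-of-Frobenius relations from the jargon list:

```python
def count_points(a, p):
    """Brute-force point counts (#E, #E') for E: v^2 = u^3 + a*u^2 + u over GF(p)
    and its quadratic twist E': B'*v^2 = u^3 + a*u^2 + u, with B' a nonsquare.

    Only practical for small odd primes p; each count includes the point at infinity.
    """
    n_curve = n_twist = 1  # one point at infinity on each
    for u in range(p):
        f = (u * u * u + a * u * u + u) % p
        if f == 0:
            n_curve += 1  # (u, 0) lies on both curves
            n_twist += 1
        elif pow(f, (p - 1) // 2, p) == 1:  # Euler's criterion: f is a square
            n_curve += 2  # (u, v) and (u, -v) on the curve
        else:
            n_twist += 2  # f is a nonsquare, so f/B' is a square: two points on the twist
    return n_curve, n_twist


# toy primes; A^2 != 4 (mod p) in each case, so the curves are nonsingular
for toy_p, toy_a in [(5, 0), (101, 5), (1009, 534)]:
    n_curve, n_twist = count_points(toy_a, toy_p)
    trace = toy_p + 1 - n_curve  # #E = p + 1 - t
    assert n_curve + n_twist == 2 * toy_p + 2
    assert n_twist == toy_p + 1 + trace  # #E' = p + 1 + t
    assert trace * trace <= 4 * toy_p  # Hasse bound: |t| <= 2*sqrt(p)
    print(toy_p, toy_a, n_curve, n_twist)
```

For p = 5, A = 0 it reports 4 points on the curve and 8 on the twist.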

In [8]:
# The ordinary curve and its twist have 2*p + 2 points between them.
# The curve's order is known, so we can take the difference to find the twist's order:
twist_order = 2*curve.p + 2 - order

print("Twist's order:", twist_order)
print("Factoring...")
%time factors = [p for p in primegen(up_to=2**24) if twist_order % p == 0 and (twist_order // p) % p != 0]
print("Small, non-repeated factors:", factors)

Twist's order: 233970423115425145549737651362517029924
Factoring...
Small, non-repeated factors: [11, 107, 197, 1621, 105143, 405373, 2323367]
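The same numbers can be read through the trace of Frobenius t: the curve has p + 1 - t points and its twist p + 1 + t, so the twist order follows from t alone, and Hasse's bound |t| <= 2*sqrt(p) limits how far either order can stray from p + 1. A short check using only constants already in this notebook (the names `trace` and `twist_n` are just for this sketch):

```python
P = 233970423115425145524320034830162017933
CURVE_ORDER = 233970423115425145498902418297807005944  # from 59.txt, as above

trace = P + 1 - CURVE_ORDER  # trace of Frobenius: #E = p + 1 - t
twist_n = P + 1 + trace      # #E' = p + 1 + t, the same as 2*P + 2 - CURVE_ORDER

assert twist_n == 2 * P + 2 - CURVE_ORDER
assert trace * trace <= 4 * P  # Hasse bound: |t| <= 2*sqrt(p)
print("trace of Frobenius:", trace)  # 25417616532355011990
print("twist order:", twist_n)       # matches the twist order printed above
```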


---
```
2. Find points with those orders. This is simple:

   a. Choose a random u mod p and verify that u^3 + A*u^2 + u is a
      nonsquare in GF(p).

   b. Call the order of the twist n. To find an element of order q,
      calculate ladder(u, n/q).
```
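Step 2a only needs a yes/no answer, so Euler's criterion is enough: for an odd prime p and a not divisible by p, a^((p-1)/2) mod p is 1 when a is a square and p - 1 when it isn't. The cell below instead detects a nonsquare by catching `NoQuadraticResidueError` from its Tonelli-Shanks routine; here is an alternative sketch that never computes the root (the names `is_nonsquare` and `on_twist` are illustrative):

```python
P = 233970423115425145524320034830162017933
A = 534


def is_nonsquare(a, p):
    """Euler's criterion: True iff a is a quadratic nonresidue mod the odd prime p."""
    return pow(a, (p - 1) // 2, p) == p - 1


def on_twist(u, a=A, p=P):
    """True iff u^3 + a*u^2 + u is a nonsquare mod p, i.e. u is the
    u-coordinate of a point on the quadratic twist, not on the curve itself."""
    return is_nonsquare((u * u * u + a * u * u + u) % p, p)


print(on_twist(76600469441198017145391791613091732004))  # True: the u from the ladder(u, 11) example
print(on_twist(4))  # False: the base point lies on the curve
```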

In [9]:
# encapsulate this search in a function to avoid cluttering the top level namespace with temp variables
def get_twist_point(fac):
    p = curve.p
    ladder = curve.mul

    while True:
        u = randrange(0, p)
        expr = (pow(u, 3, p) + curve.a*pow(u, 2, p) + u) % p
        try:
            tonelli_shanks(expr, p)
        except NoQuadraticResidueError:
            pass  # nonsquare: u is on the twist, which is what we want
        else:
            continue  # square: u is on the curve itself, so try another u
        elem = ladder(u, twist_order // fac)
        if elem != 0:
            break

    assert twist_order % fac == 0
    assert ladder(elem, fac) == 0

    return elem

twist_points = {fac: get_twist_point(fac) for fac in factors}
pprint(twist_points)

print("\nHere's the order-11 twist subgroup we found:")
for i in range(13):
    print(curve.mul(twist_points[11], i))
print("...\n")

# note how the i'th and (11-i)'th elements are equal: [11-i]P = -[i]P, and negation doesn't change the u-coordinate
# this is the cause of the 'combinatorial explosion' alluded to at the end of 60.txt
# we can probably get some "easy" wins by taking additional residues mod e.g. 11*107, 11*197, ...
# each one would halve our final search's work factor, at the cost of some precomputation
# we'll see diminishing returns as the precomputation starts to involve larger subgroups and gets more expensive
# thus, this becomes an optimization problem - but I'm getting ahead of myself here
# we'll have to work through some more preliminaries before we get to work on solving that problem

{11: 105888069003703096891937904030103459645,
 107: 232784264231402442109051618948138095864,
 197: 4809650968340552704772177512072571923,
 1621: 70782733606194749784487348130521439533,
 105143: 18205162903737310491909882733161523543,
 405373: 96985955565507758684441919626600641905,
 2323367: 152304473698859030858339554113356746138}

Here's the order-11 twist subgroup we found:
0
105888069003703096891937904030103459645
1430388126279164727092494211327512206
76600469441198017145391791613091732004
4612483201341222105440076661179035958
173527332646559565669040569905840307859
173527332646559565669040569905840307859
4612483201341222105440076661179035958
76600469441198017145391791613091732004
1430388126279164727092494211327512206
105888069003703096891937904030103459645
0
105888069003703096891937904030103459645
...



---
```
3. Send these points to Alice to recover portions of her secret.
```

In [10]:
# implementation of Alice here is modeled on challenge 58's Bob
# cf that block's comments

def alice_coro(message):
    p, mul = curve.p, curve.mul
    priv = randrange(0, p)
    print("DEBUG: Private key =", priv)
    pub = mul(4, priv)  # remember - 4 is our Montgomery curve's generator

    h = (yield pub)
    while True:
        secret = mul(h, priv)
        K = do_sha256(int_to_bytes(secret))
        t = hmac(K, message)
        h = (yield (message, t))

alice = alice_coro(b"no alarms and no surprises")
alice_pubkey = next(alice)
print("Alice initialized. Pubkey:", alice_pubkey)

DEBUG: Private key = 217127450270066099617232436535682564139
Alice initialized. Pubkey: 150460403881749324846754032855571228541


In [11]:
def recover_coefficient(g, order, message, t):
    ladder = curve.mul
    for i in range(order):
        guess = ladder(g, i)
        K = do_sha256(int_to_bytes(guess))
        tag = hmac(K, message)
        if tag == t:
            return (i, -i % order)  # k = ±i (mod order); the modulo keeps i = 0 as (0, 0)
    print("coefficient not found (?!)")
    raise Exception("this should never happen")
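`recover_coefficient` reruns the full ladder for every guess, which is why the order-2323367 search below takes minutes. Consecutive multiples of g differ by g itself, so each one can instead be derived from the previous two with a single differential addition, the same step as the ladder's `u3, w3` update: on this curve shape, u(P+Q) * u(P-Q) = ((u(P)*u(Q) - 1) / (u(P) - u(Q)))^2. A sketch, assuming g has odd order (true for the prime-order twist points here), reporting the point at infinity as 0 the way `MontyCurve.mul` does, and needing Python 3.8+ for `pow(x, -1, p)`; the name `x_multiples` is hypothetical:

```python
def x_multiples(g, count, a, p):
    """Yield the u-coordinates of [0]G, [1]G, ..., [count-1]G, where g = u(G).

    Works on B*v^2 = u^3 + a*u^2 + u over GF(p) and on its twist alike, since B
    never appears. Assumes G has odd order n with count <= n, so none of the
    denominators below vanish. The point at infinity is reported as 0.
    """
    g %= p
    if count > 0:
        yield 0  # [0]G: point at infinity
    if count > 1:
        yield g  # [1]G
    if count <= 2:
        return
    # [2]G: the ladder's doubling step applied to (u2, w2) = (g, 1)
    num = (g * g - 1) ** 2
    den = 4 * g * (g * g + a * g + 1)
    prev2, prev1 = g, num * pow(den, -1, p) % p
    yield prev1
    for _ in range(3, count):
        # differential addition with P = [i-1]G, Q = G, P - Q = [i-2]G:
        # u(P+Q) = (u(P)*u(Q) - 1)^2 / (u(P-Q) * (u(P) - u(Q))^2)
        num = (prev1 * g - 1) ** 2
        den = prev2 * (prev1 - g) ** 2
        prev2, prev1 = prev1, num * pow(den, -1, p) % p
        yield prev1
```

Inside `recover_coefficient`, the loop header could become `for i, guess in enumerate(x_multiples(g, order, curve.a, curve.p)):` with the `ladder(g, i)` call dropped; each step then costs one modular inversion rather than a full ladder of about 128 iterations.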

In [128]:
coeffs = {}

print("This step may take a while.")
for small_order, small_pt in twist_points.items():
    print("\nRecovering possible residues mod", small_order, "...", flush=True)
    message, t = alice.send(small_pt)
    %time i1, i2 = recover_coefficient(small_pt, small_order, message, t)
    coeffs[small_order] = i1, i2

print("Done!\n")
pprint(coeffs)

This step may take a while.

Recovering residue mod 11 ...
CPU times: user 909 µs, sys: 0 ns, total: 909 µs
Wall time: 943 µs

Recovering residue mod 107 ...
CPU times: user 9.08 ms, sys: 0 ns, total: 9.08 ms
Wall time: 10.3 ms

Recovering residue mod 197 ...
CPU times: user 3.69 ms, sys: 0 ns, total: 3.69 ms
Wall time: 3.75 ms

Recovering residue mod 1621 ...
CPU times: user 367 ms, sys: 1.86 ms, total: 369 ms
Wall time: 385 ms

Recovering residue mod 105143 ...
CPU times: user 22.9 s, sys: 1.85 ms, total: 22.9 s
Wall time: 23.2 s

Recovering residue mod 405373 ...
CPU times: user 13 s, sys: 0 ns, total: 13 s
Wall time: 13.2 s

Recovering residue mod 2323367 ...
CPU times: user 8min 28s, sys: 51.5 ms, total: 8min 29s
Wall time: 8min 36s
Done!

{11: (1, 10),
 107: (17, 90),
 197: (7, 190),
 1621: (651, 970),
 105143: (46247, 58896),
 405373: (26839, 378534),
 2323367: (956882, 1366485)}


In [120]:
# Let's make like a tree and generate some more residues :)


def get_extra_coeffs():
    extra_coeffs = {}
    #for p1, p2 in combinations(factors[:4], 2):
    p1 = factors[0]
    for p2 in factors[1:5]:
        print(f"Recovering possible residues mod {p1}*{p2}", flush=True)
        o = p1*p2
        while True:
            pt = get_twist_point(o)
            # make sure pt has order p1*p2, not order p1 or p2
            if curve.mul(pt, p1) != 0 and curve.mul(pt, p2) != 0:
                break
        i1, i2 = recover_coefficient(pt, o, *alice.send(pt))
        extra_coeffs[p1, p2] = i1, i2
    print("Done!\n")
    return extra_coeffs

# we've just "cheated" a bit by bringing in some data that we only know after evaluating the above cells
# this data is the index `5` in `for p2 in factors[1:5]`
# (which I chose only after inspecting the list of factors to determine an appropriate cutoff point)

# so yes - this attack isn't 100% automated
# it could be automated at the cost of a little work and some added complexity
# but you're not paying me enough for that - yet :)

%time extra_coeffs = get_extra_coeffs()
pprint(extra_coeffs)

Recovering possible residues mod 11*107
Recovering possible residues mod 11*197
Recovering possible residues mod 11*1621
Recovering possible residues mod 11*105143
Done!

CPU times: user 2min 21s, sys: 11.3 ms, total: 2min 21s
Wall time: 2min 23s
{(11, 107): (197, 980),
 (11, 197): (978, 1189),
 (11, 1621): (4212, 13619),
 (11, 105143): (269182, 887391)}


In [141]:
residue_pairs = []

for t1, t2 in extra_coeffs.items():
    p1, p2 = t1
    r1, r2 = t2
    pairs = (r1%p1, r1%p2), (r2%p1, r2%p2)
    residue_pairs.append(sorted(pairs))

res_seq_1 = (coeffs[11][0],) + tuple(t[0][1] for t in residue_pairs)
res_seq_2 = (coeffs[11][1],) + tuple(t[1][1] for t in residue_pairs)

pprint(residue_pairs)
print()
print(res_seq_1)
print(res_seq_2)

[[(1, 17), (10, 90)],
 [(1, 7), (10, 190)],
 [(1, 651), (10, 970)],
 [(1, 58896), (10, 46247)]]

(1, 17, 7, 651, 58896)
(10, 90, 190, 970, 46247)
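The pairing above works because a u-coordinate only pins k down up to sign: each small-order query yields {r, q - r}, and the sign choices for different q are a priori independent. A residue mod 11*q links them, since reducing k (or -k) mod 11*q gives sign-matched residues mod 11 and mod q at once. A minimal sketch of that step, with a hypothetical helper name:

```python
def link_signs(pair_mod_q1q2, q1, q2):
    """Map each candidate residue of k mod q1 to the matching residue of k mod q2.

    pair_mod_q1q2 holds the two candidates (s, q1*q2 - s) recovered for k mod q1*q2;
    they correspond to k and -k, so reducing each one mod q1 and mod q2 keeps
    the two residues sign-consistent.
    """
    return {s % q1: s % q2 for s in pair_mod_q1q2}


# values taken from the extra_coeffs output above
print(link_signs((197, 980), 11, 107))   # {10: 90, 1: 17}
print(link_signs((978, 1189), 11, 197))  # {10: 190, 1: 7}
```

Reading off the k = 1 (mod 11) entries across all four linked pairs gives `res_seq_1`, and the k = 10 entries give `res_seq_2`. Only the residues mod 405373 and 2323367 stay unlinked, which leaves 2 × 2 × 2 = 8 combinations.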


In [147]:
def possible_residues():
    # yields lists of possible (i.e. internally consistent) combinations of residue values
    # this is another place where we "cheat" a little by using results from earlier in the notebook
    # it's not really necessary here; it just helps keep the code terse
    for r1 in coeffs[405373]:
        for r2 in coeffs[2323367]:
            t = (r1, r2)
            yield res_seq_1 + t
            yield res_seq_2 + t

print(len(tuple(possible_residues())), "possible combinations of residues (down from 128)")

8 possible combinations of residues (down from 128)


In [149]:
# ok it's go time

for residues in possible_residues():
    residue, modulus = crt(residues, factors)
    print("Guess:", residue, "mod", modulus)

Guess: 27371474982420709302617654 mod 37220200115549684379403037
Guess: 30195693567363723420087351 mod 37220200115549684379403037
Guess: 18519464853909645531783472 mod 37220200115549684379403037
Guess: 21343683438852659649253169 mod 37220200115549684379403037
Guess: 15876516676697024730149868 mod 37220200115549684379403037
Guess: 18700735261640038847619565 mod 37220200115549684379403037
Guess: 7024506548185960959315686 mod 37220200115549684379403037
Guess: 9848725133128975076785383 mod 37220200115549684379403037
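`crt` comes from the challenge 57 notebook and isn't shown here. Judging only by how it is called, it takes residues plus pairwise-coprime moduli and returns a (residue, modulus) pair. The sketch below has that assumed shape and uses the standard incremental construction, so it is not necessarily what challenge_57 actually does; the name `crt_sketch` is hypothetical and `pow(M, -1, m)` needs Python 3.8+:

```python
def crt_sketch(residues, moduli):
    """Combine x = r_i (mod m_i) for pairwise-coprime m_i into one congruence.

    Returns (x, M) with M = product of the m_i and 0 <= x < M.
    """
    x, M = 0, 1
    for r, m in zip(residues, moduli):
        # pick t with x + M*t = r (mod m); M is invertible mod m by coprimality
        t = (r - x) * pow(M, -1, m) % m
        x += M * t
        M *= m
    return x, M


print(crt_sketch((1, 2), (3, 5)))  # (7, 15)
```

Fed the residues of the matching candidate, (10, 90, 190, 970, 46247, 26839, 956882), with `factors`, a correct CRT should reproduce the second guess above, 30195693567363723420087351.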


In [150]:
217127450270066099617232436535682564139 % 37220200115549684379403037

30195693567363723420087351