## All challenge text is excerpted from https://toadstyle.org/cryptopals/59.txt

```
59. Elliptic Curve Diffie-Hellman and Invalid-Curve Attacks

I'm not going to show you any graphs - if you want to see one, you can
find them in, like, every other elliptic curve tutorial on the
internet. Personally, I've never been able to gain much insight from
them.

They're also really hard to draw in ASCII.

The key thing to understand about elliptic curves is that they're a
setting analogous in many ways to one we're more familiar with, the
multiplicative integers mod p. So if we learn how certain primitive
operations are defined, we can reason about them using a lot of tools
we already have in our utility belts.

Let's dig in. An elliptic curve E is just an equation like this:

    y^2 = x^3 + a*x + b

The choice of the a and b coefficients defines the curve.

We'll use the notation GF(p) to talk about a finite field of size
p. (The "GF" is for "Galois field", another name for a finite field.)
When we take a curve E over field GF(p) (written E(GF(p))), what we're
saying is that only points with both x and y in GF(p) are valid.

For example, (3, 6) might be a valid point in E(GF(7)), but it
wouldn't be a valid point in E(GF(5)); 6 is not a member of GF(5).

(3, 4.7) wouldn't be a valid point on either curve, since 4.7 is not
an integer and thus not a member of either field.

What about (3, -1)? This one is on the curve, but remember we're in
some GF(p). So in GF(7), -1 is actually 6. That means (3, -1) and (3,
6) are the same point. In GF(5), -1 is 4, so blah blah blah you get
what I'm saying.

Okay: if these points are going to form a group analogous to the
multiplicative integers mod p, we need to have an analogous set of
primitive functions to work with them.

1. In the multiplicative integers mod p, we combined two elements by
   multiplying them together and taking the remainder modulo p.

   We combine elliptic curve points by adding them. We'll talk about
   what that means in a hot second.

2. We used 1 as a multiplicative identity: y * 1 = y for all y.

   On an elliptic curve, we define the identity O as an abstract
   "point at infinity" that doesn't map to any actual (x, y)
   pair. This might feel like a bit of a hack, but it works.

   On the curve, we have the straightforward rule that P + O = P for
   all P.

   In your code, you can just write something like O := object(),
   since it only ever gets used in pointer comparisons. Or you can use
   some sentinel coordinate that doesn't satisfy the curve equation;
   (0, 1) is popular.

3. We had a modinv function to invert an integer mod p. This acted as
   a stand-in for division. Given y, it finds x such that y * x = 1.

   Inversion is way easier in elliptic curves. Just flip the sign on
   y, and remember that we're in GF(p):

       invert((x, y)) = (x, -y) = (x, p-y)

   Just like with multiplicative inverses, we have this rule on
   elliptic curves:

       P + (-P) = P + invert(P) = O

Incidentally, these primitives, along with a finite set of elements,
are all we need to define a finite cyclic group, which is all we need
to define the Diffie-Hellman function. Not important to understand the
abstract jargon, just FYI.

Let's talk about addition. Here it is:

    function add(P1, P2):
        if P1 = O:
            return P2

        if P2 = O:
            return P1

        if P1 = invert(P2):
            return O

        x1, y1 := P1
        x2, y2 := P2

        if P1 = P2:
            m := (3*x1^2 + a) / (2*y1)
        else:
            m := (y2 - y1) / (x2 - x1)

        x3 := m^2 - x1 - x2
        y3 := m*(x1 - x3) - y1

        return (x3, y3)

The first three checks are simple - they pretty much just implement
the rules we have for the identity and inversion.

After that we, uh, use math. You can read more about that part
elsewhere, if you're interested. It's not too important to us, but it
(sort of) makes sense in the context of those graphs I'm not showing
you.

There's one more thing we need. In the multiplicative integers, we
expressed repeated multiplication as exponentiation, e.g.:

    y * y * y * y * y = y^5

We implemented this using a modexp function that walked the bits of
the exponent with a square-and-multiply inner loop.

On elliptic curves, we'll use scalar multiplication to express
repeated addition, e.g.:

    P + P + P + P + P = 5*P

Don't be confused by the shared notation: scalar multiplication is not
analogous to multiplication in the integers. It's analogous to
exponentiation.

Your scalarmult function will look pretty much exactly the same as
your modexp function, except with the primitives swapped out.

Actually, you wanna hear something great? You could define a generic
scale function parameterized over a group that works as a drop-in
implementation for both. Like this:

    function scale(x, k):
        result := identity
        while k > 0:
            if odd(k):
                result := combine(result, x)
            x := combine(x, x)
            k := k >> 1
        return result

The combine function would delegate to modular multiplication or
elliptic curve point addition depending on the group. It's kind of like the
definition of a group constitutes a kind of interface, and we have
these two different implementations we can swap out freely.

To extend this metaphor, here's a generic Diffie-Hellman:

    function generate_keypair():
        secret := random(1, baseorder)
        public := scale(base, secret)
        return (secret, public)

    function compute_secret(peer_public, self_secret):
        return scale(peer_public, self_secret)

Simplicity itself! The base and baseorder attributes map to g and q in
the multiplicative integer setting. It's pretty much the same on a
curve: we'll have a base point G and its order n such that:

    n*G = O

The fact that these two settings share so many similarities (and can
even share a naive implementation) is great news. It means we already
have a lot of the tools we need to reason about (and attack) elliptic
curves!

Let's put this newfound knowledge into action. Implement a set of
functions up to and including elliptic curve scalar
multiplication. (Remember that all computations are in GF(p), i.e. mod
p.) You can use this curve:

    y^2 = x^3 - 95051*x + 11279326

Over GF(233970423115425145524320034830162017933). Use this base point:

    (182, 85518893674295321206118380980485522083)

It has order 29246302889428143187362802287225875743.
```
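Here's a minimal Python sketch of that generic `scale`. It assumes a hypothetical `Group` bundle holding the identity and the combine function, which the pseudocode above leaves implicit. The sketch only instantiates the multiplicative integers mod a toy prime: p = 23, g = 4, q = 11. These are illustrative values, not the challenge's. Since 4 is a square mod 23, its order divides 11, and because it isn't 1 its order is exactly 11. In this group `scale` should agree with Python's three-argument `pow`.

```python
from dataclasses import dataclass
from random import randrange
from typing import Any, Callable


@dataclass(frozen=True)
class Group:
    """Hypothetical 'group interface': an identity element plus a combine operation."""
    identity: Any
    combine: Callable[[Any, Any], Any]


def scale(group, x, k):
    """Combine x with itself k times (k >= 0), walking the bits of k like modexp does."""
    result = group.identity
    while k > 0:
        if k & 1:
            result = group.combine(result, x)
        x = group.combine(x, x)
        k >>= 1
    return result


def generate_keypair(group, base, baseorder):
    secret = randrange(1, baseorder)  # secret in [1, baseorder)
    public = scale(group, base, secret)
    return secret, public


def compute_secret(group, peer_public, self_secret):
    return scale(group, peer_public, self_secret)


# Toy multiplicative group mod p (illustrative parameters, not the challenge's)
p, g, q = 23, 4, 11  # g = 4 = 2^2 has order 11 in the integers mod 23
mod_mult = Group(identity=1, combine=lambda a, b: (a * b) % p)

# scale agrees with modular exponentiation
assert all(scale(mod_mult, g, k) == pow(g, k, p) for k in range(50))
assert scale(mod_mult, g, q) == 1  # g^q == 1, the analogue of n*G = O on a curve

# a generic Diffie-Hellman handshake over this group
a_secret, a_public = generate_keypair(mod_mult, g, q)
b_secret, b_public = generate_keypair(mod_mult, g, q)
assert compute_secret(mod_mult, b_public, a_secret) == compute_secret(mod_mult, a_public, b_secret)
```

You get `Curve.mul` below by swapping in point addition for `combine` and the point at infinity for `identity`. That method bakes the group into the class rather than passing it in.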

In [1]:
from functools import reduce
from itertools import count
from operator import mul
from random import randrange
from math import log
from dataclasses import dataclass

from challenge_31 import do_sha256, hmac
from challenge_39 import invmod
from from_notebook.challenge_57 import primegen, crt

In [2]:
@dataclass(frozen=True)
class Curve:
    """
    Generic class for elliptic curves represented in Weierstrass form.
    """
    
    a: int
    b: int
    p: int
    
    zero = object()  # share the zero singleton between curve instances

    def inv(self, pt):
        x, y = pt
        p = self.p
        return (x, (-y) % p)  # = (x, -y) mod p; stays (x, 0) when y == 0

    def add(self, p1, p2):  # don't worry about how this works. it's ~magic~
        zero = self.zero
        if p1 is zero: return p2
        if p2 is zero: return p1
        if p1 == self.inv(p2): return zero  # p1 + (-p1) = 0

        a, p = self.a, self.p
        x1, y1 = p1
        x2, y2 = p2
        
        if p1 == p2:
            top = (3 * x1**2 + a) % p
            btm = (2 * y1) % p
        else:
            top = (y2 - y1) % p
            btm = (x2 - x1) % p
        m = (top * invmod(btm, p)) % p
        
        x3 = (m**2 - x1 - x2) % p
        y3 = (m*(x1 - x3) - y1) % p

        return x3, y3
    
    def mul(self, pt, k):
        result = self.zero
        add = self.add
        while k:
            if k & 1:
                result = add(result, pt)
            pt = add(pt, pt)
            k >>= 1
        return result

    def point(self, x, y):
        """
        Returns a new point on the curve.
        
        (Note that this class's other methods operate on raw coordinates, not Point instances)
        """
        return Point(x, y, self)


@dataclass(frozen=True)
class Point:
    x: int
    y: int
    curve: Curve
    
    def __add__(self, other):
        curve = self.curve
        zero = curve.zero
        assert isinstance(other, Point) or other is zero
        p1 = self.x, self.y
        p2 = other if other is zero else (other.x, other.y)
        result = curve.add(p1, p2)
        return result if result is zero else curve.point(*result)
    
    def __mul__(self, other):
        assert isinstance(other, int)
        curve = self.curve
        pt = self.x, self.y
        result = curve.mul(pt, other)
        return result if result is curve.zero else curve.point(*result)
    
    def __radd__(self, other):
        return self + other
    
    def __rmul__(self, other):
        return self * other
    
    # we could add more dunder methods here but it turns out we don't need them for this challenge

```
Oh yeah, order. Finding the order of an elliptic curve group turns out
to be a bit tricky, so just trust me when I tell you this one has
order 233970423115425145498902418297807005944. That factors to 2^3 *
29246302889428143187362802287225875743.

If your implementation works correctly, it should be easy to verify:
remember that multiplying the base point by its order should yield the
group identity.
```
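The order is simply handed to us because counting points on a 128-bit curve needs something like Schoof's algorithm. On a toy curve, brute force is fine. It gives a quick way to check the claim that multiplying any point by the group order yields the identity. This is a minimal sketch with a few assumptions:

- It uses the illustrative curve y^2 = x^3 + 2x + 3 over GF(97), which is not one of the challenge's curves.
- It uses `None` as the point at infinity.

```python
# Toy curve y^2 = x^3 + 2x + 3 over GF(97): illustrative parameters, not the challenge's.
A, B, P = 2, 3, 97
O = None  # the point at infinity


def ec_add(p1, p2):
    """Add two points on y^2 = x^3 + A*x + B over GF(P) (note that B is never used)."""
    if p1 is O:
        return p2
    if p2 is O:
        return p1
    x1, y1 = p1
    x2, y2 = p2
    if x1 == x2 and (y1 + y2) % P == 0:
        return O  # P + (-P) = O, which also covers doubling a point with y = 0
    if p1 == p2:
        m = (3 * x1 * x1 + A) * pow(2 * y1, P - 2, P) % P  # tangent slope
    else:
        m = (y2 - y1) * pow(x2 - x1, P - 2, P) % P  # chord slope; Fermat inverse since P is prime
    x3 = (m * m - x1 - x2) % P
    y3 = (m * (x1 - x3) - y1) % P
    return (x3, y3)


def ec_mul(pt, k):
    """Double-and-add scalar multiplication: k * pt."""
    result = O
    while k:
        if k & 1:
            result = ec_add(result, pt)
        pt = ec_add(pt, pt)
        k >>= 1
    return result


# Brute-force point count: only feasible because P is tiny.
points = [(x, y) for x in range(P) for y in range(P)
          if (y * y - (x ** 3 + A * x + B)) % P == 0]
N = len(points) + 1  # +1 for the point at infinity

# Hasse's bound: |N - (P + 1)| <= 2*sqrt(P)
assert abs(N - (P + 1)) <= 2 * P ** 0.5
# Lagrange: every point's order divides N, so N * pt is the identity
assert all(ec_mul(pt, N) is O for pt in points)
print("group order:", N)
```

The Hasse check is just a sanity test on the count. For any elliptic curve over GF(p), the number of points lies within 2*sqrt(p) of p + 1.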

In [3]:
curve = Curve(a=-95051, b=11279326, p=233970423115425145524320034830162017933)

base = curve.point(182, 85518893674295321206118380980485522083)
order = 29246302889428143187362802287225875743

assert base * order is curve.zero  # simple correctness check

```
Implement ECDH and verify that you can do a handshake correctly. In
this case, Alice and Bob's secrets will be scalars modulo the base
point order and their public elements will be points. If you
implemented the primitives correctly, everything should "just work".
```

In [4]:
class ECDHKeypair:
    _priv = None
    pub = None
    
    def __init__(self, curve):
        self.curve = curve
        self.keygen()
        
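    def keygen(self):
        # NB: uses the module-level `base` point and its `order` from In [3];
        # secrets are drawn from [1, order), since 0 would make the public key the identity
        priv = randrange(1, order)
        pub = base * priv
        self._priv, self.pub = priv, pub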
    
    def handshake(self, other_pub):
        return other_pub * self._priv

In [5]:
# Let's run through a test handshake to make sure our ECDH implementation is sound.
# We'll encapsulate this test in a function to avoid polluting the top-level namespace.

def test_handshake():
    alice = ECDHKeypair(curve)
    bob = ECDHKeypair(curve)

    alice_secret = alice.handshake(bob.pub)
    bob_secret = bob.handshake(alice.pub)

    print("Alice's version of shared secret:", alice_secret)
    print("Bob's version of shared secret:  ", bob_secret)
    print()
    assert alice_secret == bob_secret
    print("ECDH handshake successful!")
    
test_handshake()

Alice's version of shared secret: Point(x=59865467645989007467777152194241321196, y=58931224118967431601178956576373450454, curve=Curve(a=-95051, b=11279326, p=233970423115425145524320034830162017933))
Bob's version of shared secret:   Point(x=59865467645989007467777152194241321196, y=58931224118967431601178956576373450454, curve=Curve(a=-95051, b=11279326, p=233970423115425145524320034830162017933))

ECDH handshake successful!


```
Next, reconfigure your protocol from #57 to use ECDH.
```

In [6]:
assert log(curve.p, 2) < 128


def point_to_bytes(pt):
    return repr(pt).encode('ascii')


def bob_coro(message, curve):
    keypair = ECDHKeypair(curve)
    print("Bob: Private key =", keypair._priv)

    # announce our public key on coroutine initialization (before generating first response)
    output = keypair.pub
    
    while True:
        remote_pub = (yield output)
        secret = keypair.handshake(remote_pub)
        mac_key = do_sha256(point_to_bytes(secret))
        mac = hmac(mac_key, message)
        output = (message, mac)

In [7]:
# quick test: make sure Bob gives us correct MACs and doesn't throw any errors

def test_bob(n=10):
    bob = bob_coro(message=b'a pile driver provider for liars', curve=curve)
    bob_pub = next(bob)

    for _ in range(n):
        keypair = ECDHKeypair(curve)
        message, mac = bob.send(keypair.pub)

        mac_key = do_sha256(point_to_bytes(keypair.handshake(bob_pub)))
        assert hmac(mac_key, message) == mac

    print("Bob appears to be working!")

test_bob()

Bob: Private key = 16919048753646620131254890879527295666
Bob appears to be working!


```
Can we apply the subgroup-confinement attacks from #57 in this
setting? At first blush, it seems like it will be pretty difficult,
since the cofactor is so small. We can recover, like, three bits by
sending a point with order 8, but that's about it. There just aren't
enough small-order points on the curve.

How about not on the curve?

Wait, what? Yeah, points *not* on the curve. Look closer at our
combine function. Notice anything missing? The b parameter of the
curve is not accounted for anywhere. This is because we have four
inputs to the calculation: the curve parameters (a, b) and the point
coordinates (x, y). Given any three, you can calculate the fourth. In
other words, we don't need b because b is already baked into every
valid (x, y) pair.

There's a dangerous assumption there: namely, that the peer will
submit a valid (x, y) pair. If Eve can submit an invalid pair, that
really opens up her play: now she can pick points from any curve that
differs only in its b parameter. All she has to do is find some curves
with small subgroups and cherry-pick a few points of small
order. Alice will unwittingly compute the shared secret on the wrong
curve and leak a few bits of her private key in the process.

How do we find suitable curves? Well, remember that I mentioned
counting points on elliptic curves is tricky. If you're very brave,
you can implement Schoof-Elkies-Atkin. Or you can use a computer
algebra system like SageMath. Or you can just use these curves I
generated for you:

y^2 = x^3 - 95051*x + 210

y^2 = x^3 - 95051*x + 504

y^2 = x^3 - 95051*x + 727

They have orders:

233970423115425145550826547352470124412

233970423115425145544350131142039591210

233970423115425145545378039958152057148

They should have a fair few small factors between them. So: find some
points of small order and send them to Alice. You can use the same
trick from before to find points of some prime order r. Suppose the
group has order q. Pick some random point and multiply by q/r. If you
land on the identity, start over.

It might not be immediately obvious how to choose random points, but
you can just pick an x and calculate y. This will require you to
implement a modular square root algorithm; use Tonelli-Shanks, it's
pretty straightforward.
```
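You can convince yourself that b really is baked into the points. For a fixed (x, y), the equation y^2 = x^3 + a*x + b pins down exactly one b mod p. So a doubling routine that only knows a and p will happily double a point from some other curve, and it will land back on that other curve. This is a minimal sketch with toy values, all illustrative assumptions: p = 97, a = 2, a "real" b = 3, and an attacker's b = 7.

```python
P, A = 97, 2           # toy field size and the shared 'a' coefficient (illustrative)
B_REAL, B_EVIL = 3, 7  # the victim's curve vs. an attacker-chosen curve


def b_of(pt):
    """The unique b (mod P) for which pt lies on y^2 = x^3 + A*x + b."""
    x, y = pt
    return (y * y - x ** 3 - A * x) % P


def double(pt):
    """Tangent-line point doubling. Only A and P appear here; b never does."""
    x, y = pt
    m = (3 * x * x + A) * pow(2 * y, P - 2, P) % P  # P is prime, so Fermat gives the inverse
    x3 = (m * m - 2 * x) % P
    y3 = (m * (x - x3) - y) % P
    return x3, y3


# Any point on the attacker's curve with y != 0 (so the tangent isn't vertical)
evil = next((x, y) for x in range(P) for y in range(1, P) if b_of((x, y)) == B_EVIL)

assert b_of(evil) != B_REAL           # it is not on the real curve...
assert b_of(double(evil)) == B_EVIL   # ...and doubling it stays on the attacker's curve
```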

In [8]:
class NoQuadraticResidueError(Exception):  pass


def eulers_criterion(n, p):
    # tests whether a nonzero n is a quadratic residue mod p
    # (i.e. whether there exists x such that pow(x, 2, p) == n)
    return pow(n, (p-1)//2, p) == 1


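def modsqrt(n, p):
    """Return both square roots of n modulo an odd prime p, via Tonelli-Shanks."""
    n %= p
    if n == 0:
        return 0, 0  # 0 is its own (only) square root
    if not eulers_criterion(n, p):
        raise NoQuadraticResidueError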

    # reference: https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm

    # 1. find Q, S such that Q is odd and Q * 2**S = p-1
    Q, S = p-1, 0
    while Q & 1 == 0:  # equivalent to Q % 2 == 0
        Q >>= 1  # equivalent to Q //= 2
        S += 1

    # 2. find some int z such that z is not a quadratic residue mod p
    z = 2
    while eulers_criterion(z, p):
        z += 1
        assert z < p

    # 3. initialize main loop's state variables
    M = S
    c = pow(z, Q, p)
    t = pow(n, Q, p)
    R = pow(n, (Q+1) // 2, p)

    # 4. main loop
    while t != 1:
        # find the least i (0 < i < M) such that t^(2^i) == 1, using repeated squaring
        t_sq = t
        for i in count(1):
            assert i < M  # cheap correctness check; always holds because n is a residue
                          # (we checked Euler's criterion at the top of this function)
            t_sq = pow(t_sq, 2, p)
            if t_sq == 1:
                break

        # update state variables and loop
        b = pow(c, 2**(M - i - 1), p)
        M = i
        c = pow(b, 2, p)
        t = (t * c) % p
        R = (R * b) % p

    return R, (-R) % p

In [9]:
def test_tonelli_shanks():
    p = 17
    for i in range(1, p):
        sq = pow(i, 2, p)
        roots = modsqrt(sq, p)
        assert i in roots
    print("Tonelli-Shanks appears to be working!")
test_tonelli_shanks()

Tonelli-Shanks appears to be working!
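Why Tonelli-Shanks rather than something simpler? For primes p ≡ 3 (mod 4), a square root of a residue n is just n^((p+1)/4) mod p. Tonelli-Shanks degenerates to exactly this when S = 1. The challenge prime is ≡ 1 (mod 4), though, so the general algorithm is needed. Below is a small sketch of the shortcut for comparison. `sqrt_3_mod_4` is a hypothetical helper, not used anywhere else in this notebook.

```python
def sqrt_3_mod_4(n, p):
    """Both square roots of a quadratic residue n modulo a prime p with p % 4 == 3."""
    assert p % 4 == 3, "shortcut only valid for p ≡ 3 (mod 4)"
    r = pow(n, (p + 1) // 4, p)  # r^2 = n^((p+1)/2) = n * n^((p-1)/2) = n when n is a residue
    if (r * r - n) % p != 0:
        raise ValueError("n is not a quadratic residue mod p")
    return r, (-r) % p


# Exhaustive check against a toy prime with 19 % 4 == 3
for i in range(1, 19):
    assert i in sqrt_3_mod_4(i * i % 19, 19)

# The challenge prime is ≡ 1 (mod 4), so the shortcut does not apply to it
CHALLENGE_P = 233970423115425145524320034830162017933
print(CHALLENGE_P % 4)  # → 1
```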


```
Implement the key-recovery attack from #57 using small-order points
from invalid curves.
```

In [10]:
def find_point_of_order_r(r, curve, curve_order):
    a, b, p, zero = curve.a, curve.b, curve.p, curve.zero
    
    while True:
        # generate a random point by picking an x-coordinate and trying to solve for y
        x = randrange(0, p)
        y_sq = (pow(x, 3, p) + a*x + b) % p
        
        # try to go from y^2 to y
        try:
            y = modsqrt(y_sq, p)[0]  # arbitrarily take the first of the two square roots returned by Tonelli-Shanks
        except NoQuadraticResidueError:
            continue
        pt = curve.point(x, y)
        
        # try to go from our random point to a point with order r; if successful, return
        pt2 = pt * (curve_order // r)
        if pt2 is not zero:
            assert pt2 * r is zero
            return pt2

In [11]:
# Parameters for our new curves:
new_curves = [Curve(curve.a, b, curve.p) for b in (210, 504, 727)]
new_orders = [233970423115425145550826547352470124412,
              233970423115425145544350131142039591210,
              233970423115425145545378039958152057148]

In [12]:
bob = bob_coro(message=b'a pile driver provider for liars', curve=curve)
bob_pub = next(bob)
bob_pub  # outputs bob's public key

Bob: Private key = 19990612835856002233477428662883116341


Point(x=109902732143885232361967428973062751214, y=195899022950719055491266162866661997777, curve=Curve(a=-95051, b=11279326, p=233970423115425145524320034830162017933))

In [13]:
moduli = []
residues = []

for new_curve, new_order in zip(new_curves, new_orders):
    print("\nNow using b =", new_curve.b)
    print("Partially factoring curve's order...")

    small_non_repeated_factors = [p for p in primegen(up_to=2**16)
                                  if new_order % p == 0 and new_order % (p * p) != 0]

    divisors = [d for d in small_non_repeated_factors
                if d not in moduli and d != 2]  # d=2 gives us points with y-coord 0 - more trouble than it's worth

    moduli += divisors
    
    if divisors:
        print("New moduli:", divisors)
        print("Gathering residues...")
        for d in divisors:
            base_pt = find_point_of_order_r(d, new_curve, new_order)
            message, mac = bob.send(base_pt)
            
            # run exhaustive search on range(d) to determine bob._priv % d
            pt = curve.zero
            for i in range(d):
                mac_key = do_sha256(point_to_bytes(pt))
                if hmac(mac_key, message) == mac:
                    break
                pt = base_pt + pt
            else:
                raise Exception("couldn't find mac key")
            
            residues.append(i)  # i = bob._priv % d
        print("Done.")

assert reduce(mul, moduli, 1) > curve.p  # make sure we have enough moduli for the CRT to work
print("\nWe have enough data to use the CRT!")


Now using b = 210
Partially factoring curve's order...
New moduli: [3, 11, 23, 31, 89, 4999, 28411, 45361]
Gathering residues...
Done.

Now using b = 504
Partially factoring curve's order...
New moduli: [5, 7, 61, 12157, 34693]
Gathering residues...
Done.

Now using b = 727
Partially factoring curve's order...
New moduli: [37, 67, 607, 1979, 13327, 13799]
Gathering residues...
Done.

We have enough data to use the CRT!
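The `crt` helper comes from the challenge 57 notebook and isn't shown here. Judging by how the next cell uses it, it returns a (solution, modulus) pair. Here's a minimal self-contained sketch of that recombination step under that assumption, with a few more:

- `crt_sketch` is a hypothetical name.
- It assumes pairwise-coprime moduli.
- It needs Python 3.8+ for `pow(x, -1, m)`.

It's checked against a made-up secret and a few of the moduli from the output above.

```python
from functools import reduce
from operator import mul


def crt_sketch(residues, moduli):
    """Solve x ≡ r_i (mod m_i) for pairwise-coprime m_i; return (x, M) with 0 <= x < M."""
    M = reduce(mul, moduli, 1)
    x = 0
    for r, m in zip(residues, moduli):
        Mi = M // m
        # Mi * (Mi^-1 mod m) is ≡ 1 mod m and ≡ 0 mod every other modulus
        x += r * Mi * pow(Mi, -1, m)
    return x % M, M


toy_secret = 123456789                       # made-up value for the demo
toy_moduli = [3, 11, 23, 31, 89, 4999]       # product is larger than toy_secret
toy_residues = [toy_secret % m for m in toy_moduli]

x, M = crt_sketch(toy_residues, toy_moduli)
assert x == toy_secret  # recovered exactly, since toy_secret < M
```

The same condition shows up in the notebook's own check: recovery is only unique once the product of the moduli exceeds the secret.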


In [14]:
print("Residues:", residues)
print("Moduli:", moduli)

res = crt(residues, moduli)  # this is our guess for Bob's private key
assert res[1] > curve.p
bob_priv = res[0]
bob_pub_derived = base * bob_priv

print()
print("Bob's private key (derived):", bob_priv)
#print("Bob's public key (derived): ", bob_pub_derived)
#print("Bob's public key (actual):  ", bob_pub)
print()
assert bob_pub == bob_pub_derived
print("It worked!")

Residues: [1, 4, 7, 14, 76, 3476, 16084, 20794, 1, 3, 20, 4842, 8292, 16, 34, 357, 834, 12886, 8622]
Moduli: [3, 11, 23, 31, 89, 4999, 28411, 45361, 5, 7, 61, 12157, 34693, 37, 67, 607, 1979, 13327, 13799]

Bob's private key (derived): 19990612835856002233477428662883116341

It worked!
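The fix for this whole class of attack is to validate the peer's point before using it. Below is a minimal sketch of such a check, with these assumptions:

- It runs on the toy curve y^2 = x^3 + 2x + 3 over GF(97) rather than the challenge curve.
- `validate_public_point` and `find_point` are hypothetical helpers.
- `None` stands in for the point at infinity.

```python
P, A, B = 97, 2, 3  # toy curve y^2 = x^3 + 2x + 3 over GF(97) (illustrative)


def validate_public_point(pt, a=A, b=B, p=P):
    """Accept pt only if it is a finite point on y^2 = x^3 + a*x + b over GF(p)."""
    if pt is None:  # None plays the role of the point at infinity in this sketch
        raise ValueError("the identity is not a valid public key")
    x, y = pt
    if not (0 <= x < p and 0 <= y < p):
        raise ValueError("coordinates out of range")
    if (y * y - (x ** 3 + a * x + b)) % p != 0:
        raise ValueError("point is not on the curve")
    return pt


def find_point(b, a=A, p=P):
    """Brute-force some point on y^2 = x^3 + a*x + b (only feasible for a toy field)."""
    return next((x, y) for x in range(p) for y in range(p)
                if (y * y - (x ** 3 + a * x + b)) % p == 0)


honest = find_point(B)   # on the agreed-upon curve
invalid = find_point(7)  # on an attacker-chosen curve sharing the same a and p

assert validate_public_point(honest) == honest
try:
    validate_public_point(invalid)
except ValueError as err:
    print("rejected:", err)  # prints "rejected: point is not on the curve"
```

In `bob_coro`, a check like this would sit right before `keypair.handshake(remote_pub)`. An on-curve check alone still admits the real curve's own small-order points (the cofactor of 8 mentioned earlier). A more thorough check might also confirm that multiplying the point by the base point order gives the identity.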
