In [5]:
from collections import namedtuple
from math import inf
from Crypto.Util.number import inverse
# Affine point on the curve. When `isinf` is True the point is the point at
# infinity O and its x / y fields are only placeholders.
Point = namedtuple("Point", ["x", "y", "isinf"], defaults=[False])


def pt_to_str(p: Point):
    """Render a point as "x,y", or as "O" for the point at infinity."""
    if p.isinf:
        return "O"
    return ",".join((str(p.x), str(p.y)))


Point.__str__ = pt_to_str

# Curve parameters for E: y^2 = x^3 + a*x + b over the prime field F_p.
p = 9739
a = 497
b = 1768
def elliptic_curve_addition(p1: Point, p2: Point):
    """Add two points on the curve y^2 = x^3 + a*x + b over F_p.

    Uses the module-level curve parameters ``p`` (field modulus) and ``a``
    (linear coefficient); ``b`` does not appear in the group-law formulas.

    :param p1: first summand (may be the point at infinity O)
    :param p2: second summand (may be the point at infinity O)
    :return: p1 + p2, with coordinates reduced into [0, p)
    :raises ValueError: from ``inverse`` when a slope denominator is not
        invertible mod p, which for prime p only happens for points that are
        not on the curve
    """
    # O is the identity element: O + Q = Q, P + O = P, and O + O = O.
    if p1.isinf:
        return p2
    if p2.isinf:
        return p1
    same_x = (p1.x - p2.x) % p == 0
    # P + (-P) = O. The negation of (x, y) is (x, p - y), so compare modulo p;
    # a plain `p1.y == -p2.y` test misses every case except y == 0.
    if same_x and (p1.y + p2.y) % p == 0:
        return Point(0, 0, isinf=True)
    if same_x and (p1.y - p2.y) % p == 0:
        # Doubling: slope of the tangent line, (3x^2 + a) / (2y).
        slope = (3 * p1.x ** 2 + a) * inverse((2 * p1.y) % p, p) % p
    else:
        # Distinct points: slope of the chord, (y2 - y1) / (x2 - x1).
        slope = (p2.y - p1.y) * inverse((p2.x - p1.x) % p, p) % p
    p3x = (slope ** 2 - p1.x - p2.x) % p
    p3y = (slope * (p1.x - p3x) - p1.y) % p
    return Point(p3x, p3y)

def elliptic_curve_multiplication(p: Point, n: int):
    """Compute the scalar multiple [n]P with the double-and-add algorithm.

    Note: the parameter ``p`` is the *point* being multiplied and shadows the
    module-level field modulus ``p`` inside this function. The modulus is only
    needed by ``elliptic_curve_addition``, which reads the module-level value.

    :param p: the point P to multiply (may be the point at infinity O)
    :param n: non-negative integer scalar
    :return: [n]P; the point at infinity for n == 0
    :raises ValueError: if n is negative
    """
    if n < 0:
        raise ValueError("scalar n must be non-negative")
    addend = p                   # P * 2^i for the bit i currently being processed
    result = Point(0, 0, True)   # running sum, starts at O
    # Once the doubled point reaches O, every remaining term is O, so stop
    # early; this also means elliptic_curve_addition is never asked for O + O.
    while n > 0 and not addend.isinf:
        if n & 1:
            result = elliptic_curve_addition(addend, result)
        n >>= 1
        # Skip the doubling after the final bit: its result would be discarded.
        if n:
            addend = elliptic_curve_addition(addend, addend)
    return result

# Commented-out points, not used anywhere in this cell:
# p1 = Point(493,5564)
# p2 = Point(1539,4742)
# p3 = Point(4403,5202)
# p4 = Point(5323,5438)
# x-coordinate of the point to be multiplied below (presumably the other
# party's public point, of which only x is given — confirm). Its y-coordinate
# is recovered further down as a square root of x^3 + a*x + b mod p.
p5x = 4726
from sympy import solve
from sympy.abc import x, y
from math import sqrt
# NOTE(review): `g` is not used anywhere in this view — presumably the curve's
# base point G; confirm before removing.
g = Point(1804, 5368)
# Right-hand side of y^2 = x^3 + a*x + b (mod p) evaluated at x = p5x.
ypow2 = (pow(p5x, 3, p) + a * p5x + b) % p
import random
import logging
import sys

# Logger used by the Tonelli-Shanks helpers below to report their progress
# (all of their messages are emitted with .info()).
_logger = logging.getLogger("tonellishanks")


def legendre_symbol(a: int, p: int, /) -> int:
    """Evaluate Euler's criterion a^((p-1)/2) mod p.

    For an odd prime p this is the Legendre symbol of a, encoded as a residue
    in [0, p): 1 for a non-zero square, p - 1 (standing for -1) for a
    non-square, and 0 when p divides a. Callers compare against 1 and p - 1.

    :param a: the integer to classify
    :param p: odd modulus (oddness is asserted; primality is not checked)
    :return: pow(a, (p - 1) // 2, p)
    """
    assert p % 2 != 0
    half_exponent = (p - 1) // 2
    return pow(a, half_exponent, p)


def _choose_b(p: int, /, *, det=True) -> int:
    """Find a quadratic non-residue modulo the odd prime p.

    :param p: odd prime modulus (asserted to be > 2 and odd)
    :param det: if True, test the candidates 2, 3, 4, ... in order; otherwise
        test 2 first and then draw candidates with random.randrange(2, p)
    :return: some b in [2, p) for which legendre_symbol(b, p) == p - 1
    """
    assert p > 2
    assert p % 2 != 0

    def next_candidate(current: int) -> int:
        # Deterministic mode walks upward; random mode ignores `current`.
        return current + 1 if det else random.randrange(2, p)

    candidate = 2
    attempts = 1
    while legendre_symbol(candidate, p) == 1:
        candidate = next_candidate(candidate)
        attempts += 1

    assert candidate < p
    assert legendre_symbol(candidate, p) == p - 1

    _logger.info("Found b = %d after %d attempts", candidate, attempts)
    return candidate


def _tonelli_shanks_recursive(a: int, k: int, p: int, b: int, b_inverse: int, /) -> int:
    """
    Computes a square root of a modulo prime p (one recursive round of Tonelli-Shanks).

    Invariant on entry: a^m = 1 (mod p) with m = (p-1)/2^k. The loop halves m
    while a^m stays 1. If it reaches a^m = -1, a is multiplied by b^(2^(k-1)).
    Because b^((p-1)/2) = -1, this flips a^m back to 1, and the function recurses
    on that product. The root of the original a is then recovered by multiplying
    with b^(-2^(k-2)). If m instead becomes odd while a^m = 1, then a^((m+1)/2)
    is a root directly.

    Every round is logged at INFO level through _logger.

    :param a: the number to take the square root of; must satisfy 0 < a < p
    :param k: positive integer, such that a^m = 1 (mod p) where m = (p-1)/(2^k)
    :param p: odd prime p modulo which we are working
    :param b: an arbitrary non-square modulo p
    :param b_inverse: the inverse of b modulo p, i.e., b * b_inverse = 1 (mod p)
    :return: one of the square roots of a modulo p (the other can be obtained via negation modulo p)
    """

    assert p > 2
    assert 0 < a < p
    assert k > 0

    # m = (p - 1) / 2^k
    m = (p - 1) >> k

    # assumption
    assert pow(a, m, p) == 1

    # a^m is known to be 1 from the assertion above, so it is not recomputed
    a_m = 1

    # check that b is indeed a non-square modulo p
    assert legendre_symbol(b, p) == p - 1

    _logger.info("-------- [New round] --------")
    _logger.info("a = %d, m = %d, a^m = 1", a, m)

    # Halve m (and bump k) while that is possible and a^m is still 1.
    while m % 2 == 0 and a_m == 1:

        m >>= 1
        k += 1

        assert m == (p - 1) >> k

        a_m = pow(a, m, p)

        _logger.info(
            "m is even and a^m = 1 => we divide m by 2 and get: m = %d, a^m = %s",
            m,
            "1" if a_m == 1 else "-1"
        )

        # the new a_m squared is the previous a^m = 1, and in the field Z/pZ the
        # only square roots of 1 are 1 and -1
        assert a_m == 1 or a_m == p - 1

    assert a_m == 1 or a_m == p - 1

    if a_m == p - 1:
        # a^m = -1 (mod p)
        _logger.info("m = %d, a^m = -1 => we multiply a^m with a legendre symbol of a non-square b modulo p", m)
        # k >= 2 so that b_power_half = 2^(k-2) is an integer
        assert k >= 2
        b_power = 1 << (k - 1)
        b_power_half = 1 << (k - 2)
        assert pow(a, m, p) == p - 1
        # b^(b_power * m) = b^((p-1)/2) = -1, hence (a * b^b_power)^m = (-1)(-1) = 1
        assert b_power * m == (p - 1) >> 1
        a_next = (a * pow(b, b_power, p)) % p
        _logger.info("(a * b^%d)^m = (a * b^%d)^%d = %d^%d = 1", b_power, b_power, m, a_next, m)
        _logger.info(
            "It follows that a_next := a * b^%d = %d * %d = %d is a square whose root yields a root of a",
            b_power,
            a,
            pow(b, b_power, p),
            a_next
        )
        assert pow(a_next, m, p) == 1
        # Recurse with the current k: a_next satisfies the entry invariant for it.
        a_next_root = _tonelli_shanks_recursive(a_next, k, p, b, b_inverse)
        _logger.info("The root of a_next = %d is %d", a_next, a_next_root)
        # sqrt(a) = sqrt(a_next) * b^(-2^(k-2)).
        # NOTE(review): a_root is only reduced mod p at the return, so the log
        # line below prints the unreduced product.
        a_root = a_next_root * pow(b_inverse, b_power_half, p)
        _logger.info("sqrt(a_next)^2 = %d^2 = a_next = a * b^%d = sqrt(a)^2 * b^%d", a_next_root, b_power, b_power)
        _logger.info(
            "=> sqrt(a = %d) = sqrt(a_next) * b^(-%d) = %d * %d = %d",
            a,
            b_power_half,
            a_next_root,
            pow(b_inverse, b_power_half, p),
            a_root
        )
        _logger.info("-------- [Round complete] --------")
        return a_root % p
    assert a_m == 1
    assert m % 2 == 1
    _logger.info("-------- [Round complete] --------")
    # we now handle the case when m is odd
    # this case is easy, a^((m+1)/2) is a square root of a:
    # (a^((m+1)/2))^2 = a^(m+1) = a^m * a = a
    return pow(a, (m + 1) >> 1, p)


def tonelli_shanks(a: int, p: int, /, *, deterministic=True) -> int | None:
    """
    Computes a square root of a modulo prime p

    :param a: the number to take the square root of; any integer is accepted
        and reduced modulo p first
    :param p: odd prime p modulo which we are working
    :param deterministic: whether to search for the non-square b deterministically
    :return: one of the square roots of a modulo p (the other can be obtained via
        negation modulo p); 0 when a = 0 (mod p); None when a is not a square
        modulo p
    """
    assert p > 2
    # Accept any representative of the residue class; the recursive helper
    # requires 0 < a < p.
    a %= p
    if a == 0:
        # 0 is its own (only) square root. Both the Fermat check and the
        # recursive helper below would reject it.
        return 0
    # quick Fermat check (base a) that p is plausibly prime
    assert pow(a, p - 1, p) == 1
    if legendre_symbol(a, p) != 1:
        # a is not a square modulo p
        return None
    _logger.info("======== [Starting algorithm with a = %d, p = %d] ========", a, p)
    b = _choose_b(p, det=deterministic)
    b_inverse = inverse(b, p)
    assert b * b_inverse % p == 1
    # Start with k = 1: Euler's criterion just confirmed a^((p-1)/2) = 1.
    return _tonelli_shanks_recursive(a, 1, p, b, b_inverse)
# Recover a y-coordinate for x = p5x. Either square root works: [n](x, y) and
# [n](x, -y) are negatives of each other and so share the same x-coordinate,
# which is the only part of the result the decryption cell uses.
y = tonelli_shanks(ypow2, p)
if y is None:
    raise ValueError(f"x = {p5x} is not the x-coordinate of a point on the curve")
# NOTE(review): 6534 is presumably the local private scalar — confirm.
secret = elliptic_curve_multiplication(Point(p5x, y), 6534)
# NOTE(review): stale snippet — `flag` is not defined here; the key is derived
# from secret.x in the next cell instead.
# from Crypto.Hash import SHA1
# SHA1.new(str(flag.x).encode()).hexdigest()

In [6]:
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
import hashlib


def is_pkcs7_padded(message, block_size=16):
    """Return True if ``message`` ends with well-formed PKCS#7 padding.

    Valid padding is n copies of the byte value n, with 1 <= n <= block_size.
    The upper bound matters: ``unpad(..., 16)`` rejects longer "padding", so
    accepting it here would make the caller crash instead of falling back to
    the raw plaintext.

    :param message: the decrypted bytes to inspect
    :param block_size: cipher block size in bytes (16 for AES)
    :return: True if the padding is well-formed, False otherwise (including
        for an empty message)
    """
    if not message:
        return False
    pad_len = message[-1]
    # A zero pad byte, or one longer than the block or the message, is never valid.
    if not 1 <= pad_len <= min(block_size, len(message)):
        return False
    return message[-pad_len:] == bytes([pad_len]) * pad_len


def decrypt_flag(shared_secret: int, iv: str, ciphertext: str):
    """Decrypt the AES-CBC flag with a key derived from the shared secret.

    The AES-128 key is the first 16 bytes of SHA-1 over the decimal ASCII
    representation of ``shared_secret``.

    :param shared_secret: integer shared secret
    :param iv: hex-encoded CBC initialisation vector
    :param ciphertext: hex-encoded ciphertext
    :return: the plaintext decoded as ASCII, with PKCS#7 padding stripped when
        is_pkcs7_padded recognises it
    """
    secret_bytes = str(shared_secret).encode('ascii')
    aes_key = hashlib.sha1(secret_bytes).digest()[:16]
    # Decode the ciphertext before the IV, as the original did, so malformed
    # hex is reported for the same argument first.
    ct_bytes = bytes.fromhex(ciphertext)
    iv_bytes = bytes.fromhex(iv)
    decrypted = AES.new(aes_key, AES.MODE_CBC, iv_bytes).decrypt(ct_bytes)
    if not is_pkcs7_padded(decrypted):
        return decrypted.decode('ascii')
    return unpad(decrypted, 16).decode('ascii')

# Use the x-coordinate of the shared point as the secret and decrypt the flag;
# the IV and ciphertext are given as hex strings.
print(decrypt_flag(secret.x, 'cd9da9f1c60925922377ea952afc212c', 'febcbe3a3414a730b125931dccf912d2239f3e969c4334d95ed0ec86f6449ad8'))

crypto{3ff1c1ent_k3y_3xch4ng3}
