In [2]:
# Read the damaged PEM and flatten its base64 body into a single string.
# The BEGIN/END armour lines (first and last) are dropped, and each body line is
# right-padded with "?" up to the 64-character PEM line width, so every missing
# character is marked as unknown.
# NOTE(review): right-padding assumes characters were removed from the END of
# each line - TODO confirm. The final body line of a PEM is normally shorter
# than 64 chars, so it also receives "?" padding here.
with open("modified.pem", "r") as f:
    modified = "".join(
        body_line.ljust(64, "?")
        for body_line in f.read().strip().split("\n")[1:-1]
    )

In [3]:
import base64

# "?" marks a base64 character that is missing from the PEM body.
unkchar = "?"
pad = ""  # appended to both base64 strings before decoding

key_b64 = modified
# Decode the body twice, in lockstep:
#   key_bytes - the key with every unknown character decoded as "A" (six 0 bits);
#   key_mask  - "/" (six 1 bits) for each known character, "A" for unknown ones.
# A 1 bit in key_mask therefore means the same bit of key_bytes is known.
key_bytes = base64.b64decode(key_b64.replace(unkchar, "A") + pad)
key_mask = base64.b64decode(
    "".join("A" if ch == unkchar else "/" for ch in key_b64) + pad
)

# Parallel [data, mask] read cursors consumed field by field by take().
bufs = [key_bytes, key_mask]
def take(l, desc="?", show=True, full=False, check=None, checkhex=None, buffers=None):
    """Consume the next ``l`` bytes from the parallel [data, mask] buffers.

    The key bytes and the known-bit mask are read in lockstep, so successive
    calls peel off consecutive fields of the DER encoding.

    Args:
        l: number of bytes to consume.
        desc: label printed above the dump when ``show`` is true; also used in
            assertion messages.
        show: print hex dumps of the consumed data and mask bytes.
        full: dump all ``l`` bytes even when ``l >= 30``. Otherwise long fields
            show only their first/last 15 bytes, and the mask line also shows
            "<set bits>/<total bits>".
        check: expected bytes for this field. Every known bit (set in the mask)
            must match; unknown bits are not compared.
        checkhex: hex string for ``check``; used only when ``check`` is None.
        buffers: two-element list ``[data, mask]`` to read from. Defaults to
            the module-level ``bufs``. Both entries are replaced in place by
            the bytes remaining after this field.

    Returns:
        ``[value, mask]``: the consumed data and mask bytes as big-endian ints
        (an empty slice yields 0).

    Raises:
        AssertionError: if ``check`` is not ``l`` bytes long, or disagrees with
            a known bit.
    """
    if buffers is None:
        buffers = bufs
    if check is None and checkhex is not None:
        check = bytes.fromhex(checkhex)

    data_chunk = buffers[0][:l]
    mask_chunk = buffers[1][:l]
    # int.from_bytes, unlike int(b.hex(), 16), also accepts an empty slice.
    value = int.from_bytes(data_chunk, "big")
    known = int.from_bytes(mask_chunk, "big")

    if check:
        # A check of a different length would be compared misaligned and
        # could pass by accident.
        assert len(check) == l, f"{desc}: check has {len(check)} bytes, expected {l}"
        # Compare only the known bits, whatever the unknown data bits hold.
        assert (int.from_bytes(check, "big") ^ value) & known == 0, (
            f"{desc}: known bits differ from expected {check.hex()}"
        )

    if show:
        print(desc)
        for is_mask, chunk in ((False, data_chunk), (True, mask_chunk)):
            if l < 30 or full:
                print(chunk.hex())
            else:
                bit_summary = [f"{bin(known).count('1')}/{l*8}"] if is_mask else []
                print(chunk[:15].hex(), "...", chunk[-15:].hex(), *bit_summary)
        print()

    # Advance both cursors past the consumed field.
    buffers[0] = buffers[0][l:]
    buffers[1] = buffers[1][l:]
    return [value, known]

# Walk the DER encoding field by field. Order matters: every take() call
# consumes bytes from the shared [data, mask] buffers.
# Outer wrapper (presumably a PKCS#8 PrivateKeyInfo) and the start of the inner
# RSA key. Length bytes depend on the key and are only skipped; fixed-content
# fields are checked against their known bits.
take(4, "struct header", show=False)
take(3, "int length 1, version 0", show=False, checkhex="020100")
# AlgorithmIdentifier: SEQUENCE { OID 1.2.840.113549.1.1.1 (rsaEncryption), NULL }
take(15, "sequence ? keytype", show=False, checkhex="300d06092a864886f70d0101010500")
take(4, "octet stream header", show=False)
take(4, "sequence header", show=False)
take(3, "int length 1, version 0", show=False, checkhex="020100")

# Checked like the d/p/q headers below.
take(4, "int length 0x101 (for n)", checkhex="02820101")
n, nmask = take(0x101, "value of n")

take(5, "int length 0x03, value of e = 65537", checkhex="0203010001")
e = 0x10001

take(4, "int length 0x100", checkhex="02820100")
d, dmask = take(256, "value of d")

take(3, "int length 0x81", checkhex="028181")
p, pmask = take(129, "value of p")

take(3, "int length 0x81", checkhex="028181")
q, qmask = take(129, "value of q")

int length 0x101 (for n)
02820101
ffffffff

value of n
0096e5d0c15710d408135a223dcf6f ... 6930000000000000001adeaf833a09
ffffffffffffffffffffffffffffff ... fff000000000000000ffffffffffff 1876/2056

int length 0x03, value of e = 65537
0203010001
ffffffffff

int length 0x100
02820100
ffffffff

value of d
07b15cab09d8c57b49ea14230328f4 ... 9cc862002cad5462897b9b52c13647
ffffffffffffffffffffffffffffff ... ffffffffffffffffffffffffffffff 1568/2048

int length 0x81
000000
f00000

value of p
00000000000000000000000000001c ... 52d1982e75f30c9758e4e9e1708b8f
0000000000000000000000000000ff ... ffffffffffffffffffffffffffffff 620/1032

int length 0x81
028181
ffffff

value of q
00be1539c000000000000000000000 ... 0000000000000000baeb6b1e7ff4e7
ffffffffff00000000000000000000 ... 0000000000000000ffffffffffffff 492/1032



In [4]:
# First attempt at k: when d is the inverse of e modulo (p-1)*(q-1),
# d*e - 1 = k*(p-1)*(q-1) for some 0 < k < e.
# The low 56 bits (14 hex digits) of d, p and q are all known (q's mask ends in
# exactly 14 f's), so compare both sides modulo 2**56 and brute-force k.
# No output was recorded for this cell, i.e. no k matched; the next cell
# retries with an extra multiplier g.
bits_for_k = 14 * 4
mod = 1 << bits_for_k


def M(x):
    """Return ``x`` reduced modulo ``mod``."""
    return x % mod


lhs = M(d * e - 1)
rhs = M((p - 1) * (q - 1))
for k in range(1, e):
    if M(k * rhs) != lhs:
        continue
    print(f"Found {k=}")

In [5]:
# Second attempt: allow a small multiplier g, i.e. search for
#     g*(d*e - 1) == k*(p-1)*(q-1)   (mod 2**56)
# using only the low 56 bits, which are known for d, p and q.
# NOTE(review): g > 1 is presumably needed because d was generated modulo
# lcm(p-1, q-1) instead of (p-1)*(q-1), making g = gcd(p-1, q-1) - confirm.
bits_for_k = 14 * 4
mod = 2 ** bits_for_k


def M(x):
    """Return ``x`` reduced modulo ``mod``."""
    return x % mod


lhs = M(d * e - 1)
rhs = M((p - 1) * (q - 1))
found_gk = None
# g = 0 is skipped: it only asks whether k*rhs == 0, which says nothing about d.
# The search stops at the first hit, i.e. the smallest g (then smallest k).
for g in range(1, 100):
    target = M(g * lhs)  # invariant across the inner k loop
    for k in range(1, e):
        if M(k * rhs) == target:
            found_gk = (g, k)
            print(f"Found {g=} {k=}")
            break
    if found_gk is not None:
        break

assert found_gk is not None, "no (g, k) pair found; widen the g range or check the known low bits"
g, k = found_gk

Found g=10 k=33411


In [6]:
from itertools import product

# (partial value, known-bit mask) for each quantity tracked by the search.
params = dict(
    n=(n, nmask),
    d=(d, dmask),
    p=(p, pmask),
    q=(q, qmask),
)

def get_fiddle_choices(name, cur, fiddle):
    """List the candidate values of ``name`` once bit ``fiddle`` is decided.

    ``cur`` already holds the bits below ``fiddle`` (a single power of two).
    If that bit is known (set in the mask), the only candidate is ``cur`` plus
    the known bit value; otherwise both ``cur`` and ``cur + fiddle`` are tried.
    """
    known_value, known_mask = params[name]
    if not fiddle & known_mask:
        return [cur, cur + fiddle]
    return [cur + (known_value & fiddle)]
    
# Number of low bits to recover. p and q are stored in 0x81 = 129 DER bytes
# (presumably a leading 0x00 plus 1024 bits).
endpos = 1024
results = []


def search_and_prune(pos, vals):
    """Recover p, q and the low bits of n, d one bit at a time (branch and prune).

    Starting at bit ``pos`` with partial values ``vals = (p, q, n, d)`` (bits
    below ``pos`` already fixed), every combination of candidate values for bit
    ``pos`` is tried. A combination is kept only if, modulo 2**(pos + 1),
        p * q == n   and   g * (d*e - 1) == k * (p-1) * (q-1).
    Each candidate that reaches ``endpos`` bits is appended to ``results`` and
    printed.

    An explicit stack replaces the original recursion. The search is
    ``endpos`` levels deep, which exceeds CPython's default recursion limit of
    1000. Children are pushed in reverse, so candidates are visited, and
    ``results`` filled, in the same depth-first order as the recursive form.
    """
    stack = [(pos, vals)]
    while stack:
        depth, partial = stack.pop()
        if depth == endpos:
            results.append(partial)
            print(partial)
            continue

        fiddle = 1 << depth  # the bit decided at this level
        mod = fiddle << 1    # constraints are checked on the low depth+1 bits
        survivors = []
        for next_vals in product(
            *[get_fiddle_choices(name, cur, fiddle) for name, cur in zip("pqnd", partial)]
        ):
            next_p, next_q, next_n, next_d = next_vals
            if (next_p * next_q) % mod != next_n:
                continue
            if ((next_d * e - 1) * g) % mod != (k * (next_p - 1) * (next_q - 1)) % mod:
                continue
            survivors.append((depth + 1, next_vals))
        stack.extend(reversed(survivors))


search_and_prune(0, (0, 0, 0, 0))
print(f"{len(results)=}")

(142710244548240793159699572784667942965705597965582813340076903550284587474979963221304279150582562939600583748288321502816326805387067383591435334776777246613012496090769018352386968305950402019439913918994189509860606368864108133875043107015125440921782018517384461932865474416011156850686483240695029074831, 133480760455551509062668077026860103880645887467422448435638716951583414195873116366992771274871618887571707777864389848099440814273801401323900893718018500769253238628985089867000767857097674388416864916468961564235405662871602980335735605932732414354471646167477790927427901295396336834533290550708175631591, 171269452369763223105693853999419492894476014718721905173840697753104840608608269495169129212915423863384814606526302751637942997180493599265963313884370644645194138012347852542059210676732844702723968555892423740128895142432473916119607171743443255213262085046184067804312360871974587154300898744257682291209, 827893409458384169311006890306189448017871107122121288315502730994

In [7]:
from gmpy2 import is_prime

# Candidates can differ in d's top bit: because g is even here, bit i of d
# multiplies out of the check at level i, so the last bit is never constrained.
# p and q, however, must be identical across every candidate.
assert results, "search_and_prune found no candidates"
assert all(cand[:2] == results[0][:2] for cand in results), "candidates disagree on (p, q)"

# n and d bound here hold only their low `endpos` bits; the next cell
# recomputes both from p and q.
p, q, n, d = results[0]
assert is_prime(p)
assert is_prime(q)

In [8]:
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_OAEP

# Rebuild the full private key from the recovered primes and decrypt the
# OAEP ciphertext stored in out.txt.
n = p * q
phi = (p - 1) * (q - 1)
d = pow(e, -1, phi)  # modular inverse of e

key = RSA.construct((n, e, d))
cipher = PKCS1_OAEP.new(key)
# `with` closes the file; the original bare open(...).read() leaked the handle.
with open("out.txt", "rb") as ct_file:
    ciphertext = ct_file.read()
print(cipher.decrypt(ciphertext))

b'UMDCTF{impressive_recovery!_i_forgot_to_tell_you_this_but_the_private_key_ends_with_VATE KEY-----}'


In [16]:
from collections import Counter

masks = [nmask, dmask, pmask, qmask]
# For each of the low 1024 bit positions, count how many of the four masks
# leave that bit unknown (mask bit 0), then tally those counts.
Counter(
    sum(1 for m in masks if not (m >> bit) & 1)
    for bit in range(1024)
)

Counter({2: 488, 1: 348, 0: 188})