In [1]:
# First RSA prime factor: a random prime with lbound 2^1023 and upper bound
# 2^1024, i.e. a 1024-bit prime.
p = random_prime(2**1024, lbound=2**1023)
p

176560261280700021104198920229722802800997439469497246700485693000187332389437612943959631164376460005349413312164276388723911940245042099113792756723215125177496311819480794475570086185739437576104377780561032193974393186245545758818618329643582794906138342518777172283464714526741898087094151912579220791469

In [2]:
# Second RSA prime factor, drawn the same way as p (1024-bit random prime).
q = random_prime(2**1024, lbound=2**1023)
q

118000551743360086929938573812557675958933367432905902981986724378730875384763601190908009895785030261651879126191862824976423007629650214208070642866078787685716693767756778196269761805920104221459999465719153881659225150606565485633012688385840969012449035359831533721476919395154216160188098943187879244723

In [4]:
# RSA modulus: the product of the two prime factors.
n = p*q

In [5]:
# Bit length of the modulus, read directly off the integer instead of going
# through floor(log2(n)) + 1.
n.nbits()

2048

In [6]:
# Euler's totient of n = p*q (valid for distinct primes p and q).
phi_n = (p-1)*(q-1)

In [7]:
# Draw a random public exponent below phi(n), resampling until it is coprime
# to phi(n) so that it has an inverse modulo phi(n).
while True:
    e = ZZ.random_element(phi_n)
    if gcd(e, phi_n) == 1:
        break
e

12198233055532555311609366736366824423840730293625300492506809084034548013222228297415509651880232078587452483755232498121417493981630254317837526042917352943278746783893188055752067338101691460145519110708264824516123898649718231640881606162371143310582189162712743283280773878270025319902887184014371271581812172281573191048655930994099977117482487248452697418623021388621775572187974005971034107576316134542813751422773195820222267173581492979609841124234338337528045894741329157262676715680102242434685687145495336124660585276101760424884972410644210612316785283124170409501623634127587346990900078432296725051533

In [11]:
# In practice a small fixed public exponent is used instead of a random one.
e = 2**16 + 1
# A fixed exponent is not automatically coprime to phi(n): fail loudly here
# rather than deriving a bogus private exponent from it later.
assert gcd(e, phi_n) == 1, "e is not invertible modulo phi(n); regenerate p and q"

In [9]:
# Private exponent: the inverse of e modulo phi(n), via extended Euclid.
# xgcd returns (g, s, t) with g = s*e + t*phi_n.
g, s, _t = xgcd(e, phi_n)
# The inverse only exists when gcd(e, phi_n) == 1.
assert g == 1, "e is not invertible modulo phi(n)"
# The Bezout coefficient may be negative; reduce it to the canonical
# representative in [0, phi_n).
d = s % phi_n
d

1037307498210229798034591993241309870735940595896570299336720652594871192780207887431181839643847605149877632809295556441608963854176249268575457680274592935994263313342620441436873342268579664630744417120811838264100841812422786233821868923593179009147272221447054788566641352390355290603877423231123846118038382313825232081570816312148488783986554169355663702180180034151221603650505972236068129315865892481385224601886178522687631441703040460231203022675079139477576290370081000740031431837305079943074582349279898987296209770938926394034577103996239717653499655541374599608320062903140683742624349339212111631977

In [12]:
# Textbook RSA encryption of a random message representative m below n.
m = ZZ.random_element(n)
c = power_mod(m, e, n) # c = m^e mod n
c

10738207750639980214939964500763504664173900912148402577047388637366579367834080172557463518555702010061366435308493360920654321560238576304273255009173672083268780773715520796528228534940206168317363346238680063770074305572167614264148656517118188128313466886121602299211700165617871908427023392820010444174491595637094814064555742916766313232464623215007625263996367379666975774839259762101247634703821536703181483379564984366519419620678919809689115721397882465923865621212329231498568005134389893723669167715930558844405668445816749402839863532592754315832265715114089427966319074008630169248516357889389745083105

In [13]:
# Decrypt with the private exponent (c^d mod n) and compare with the original m.
power_mod(c, d, n) == m

True

In [18]:
import hashlib

def mgf(seed: bytes, length: int, hash_func = hashlib.sha1):
    """Mask generation function (MGF1 construction).

    Hashes ``seed`` followed by a 4-byte big-endian counter (0, 1, 2, ...),
    concatenates the digests until at least ``length`` bytes are available,
    and returns the first ``length`` bytes.

    :param seed: input bytes the mask is derived from.
    :param length: number of mask bytes to return.
    :param hash_func: hashlib-style constructor; SHA-1 by default.
    :return: ``length`` bytes of mask (empty when ``length`` is 0).
    """
    chunks = []
    produced = 0
    index = 0

    while produced < length:
        # int() keeps this working when the counter is not a plain Python int.
        block = hash_func(seed + int(index).to_bytes(4, 'big')).digest()
        chunks.append(block)
        produced += len(block)
        index += 1

    return b''.join(chunks)[:length]

def xor_bytes(lhs, rhs):
    """Return the byte-wise XOR of ``lhs`` and ``rhs`` as a ``bytes`` object.

    The inputs are paired with ``zip``, so the result is as long as the
    shorter of the two.
    """
    return bytes(a ^^ b for a, b in zip(lhs, rhs))

In [23]:
import secrets

def rsa_oaep_encode(rsa_length: int, label: bytes, message: bytes, hash_func=hashlib.sha1):
    """OAEP-encode ``message`` into a block of ``rsa_length`` bytes.

    The result has the layout ``0x00 || masked_seed || masked_data_block``,
    where the data block is ``hash(label) || zero padding || 0x01 || message``.
    This only builds the padded block; it does not perform the RSA
    exponentiation.

    :param rsa_length: length of the RSA modulus in bytes.
    :param label: label whose hash is embedded in the data block.
    :param message: message bytes to encode.
    :param hash_func: hashlib-style constructor used for the label hash and
        for both mask generation calls; SHA-1 by default.
    :return: the encoded block, ``rsa_length`` bytes long.
    :raises ValueError: if ``message`` does not fit into ``rsa_length`` bytes.
    """
    hash_length = hash_func().digest_size # 20 for sha1

    # Two hash-sized fields plus the leading 0x00 and the 0x01 separator are
    # fixed overhead; whatever remains is zero padding.
    padding_string_length = rsa_length - 2*hash_length - 2 - len(message)
    if padding_string_length < 0:
        raise ValueError('message too long')
    padding_string = bytes(padding_string_length)

    seed = secrets.token_bytes(hash_length)

    hashed_label = hash_func(label).digest()

    data_block = hashed_label + padding_string + b'\x01' + message

    data_block_mask = mgf(seed, len(data_block), hash_func)
    masked_data_block = xor_bytes(data_block, data_block_mask)

    # The seed mask must use the same hash as the decoder does; otherwise the
    # round trip fails for any hash_func other than the default.
    seed_mask = mgf(masked_data_block, hash_length, hash_func)
    masked_seed = xor_bytes(seed, seed_mask)

    return b'\x00' + masked_seed + masked_data_block

def rsa_oaep_decode(rsa_length: int, label: bytes, encoded_message: bytes, hash_func=hashlib.sha1):
    """Recover the message from an OAEP-encoded block.

    Expects the layout ``0x00 || masked_seed || masked_data_block`` and
    undoes the two masking steps. The unmasked data block must read
    ``hash(label) || zero padding || 0x01 || message``.

    All validation failures raise the same ``ValueError`` so the caller
    cannot tell which check failed. The checks themselves are not
    constant-time.

    :param rsa_length: length of the RSA modulus in bytes.
    :param label: label the block is expected to have been encoded with.
    :param encoded_message: the encoded block, ``rsa_length`` bytes long.
    :param hash_func: hashlib-style constructor used by the encoder;
        SHA-1 by default.
    :return: the decoded message bytes.
    :raises ValueError: if the block has the wrong length, a non-zero
        leading byte, a mismatching label hash, non-zero padding, or no
        0x01 separator.
    """
    hash_length = hash_func().digest_size

    # The block must be exactly one modulus long and large enough to hold
    # the fixed fields (leading byte, seed, label hash, separator).
    if len(encoded_message) != rsa_length or rsa_length < 2*hash_length + 2:
        raise ValueError('decryption error')

    leading_byte = encoded_message[0]
    masked_seed = encoded_message[1:hash_length + 1]
    masked_data_block = encoded_message[hash_length + 1:]

    seed_mask = mgf(masked_data_block, hash_length, hash_func)
    seed = xor_bytes(masked_seed, seed_mask)

    data_block_mask = mgf(seed, len(masked_data_block), hash_func)
    data_block = xor_bytes(masked_data_block, data_block_mask)

    hashed_label = data_block[:hash_length]

    # The first 0x01 after the label hash separates padding from message;
    # everything in between must be zero bytes.
    separator_index = data_block.find(b'\x01', hash_length)
    padding_string = data_block[hash_length:max(separator_index, hash_length)]

    if (leading_byte != 0
            or hashed_label != hash_func(label).digest()
            or separator_index < 0
            or any(padding_string)):
        raise ValueError('decryption error')

    return data_block[separator_index + 1:]

In [24]:
# Round trip through the OAEP encoding only, for a 256-byte (2048-bit) modulus
# and an empty label; the encoded block is not RSA-encrypted in this cell.
encoded_message = rsa_oaep_encode(256, b'', b'attack at dawn')
rsa_oaep_decode(256, b'', encoded_message)

b'attack at dawn'