Skip to content

Commit

Permalink
Merge fd273d0 into 483700a
Browse files Browse the repository at this point in the history
  • Loading branch information
methane committed Jun 11, 2021
2 parents 483700a + fd273d0 commit 808d1e6
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 7 deletions.
2 changes: 1 addition & 1 deletion rsa/pkcs1.py
Expand Up @@ -49,7 +49,7 @@
"SHA-512": b"\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40",
}

HASH_METHODS: typing.Dict[str, typing.Callable[[], HashType]] = {
HASH_METHODS: typing.Dict[str, typing.Callable[..., HashType]] = {
"MD5": hashlib.md5,
"SHA-1": hashlib.sha1,
"SHA-224": hashlib.sha224,
Expand Down
207 changes: 201 additions & 6 deletions rsa/pkcs1_v2.py
Expand Up @@ -15,14 +15,24 @@
"""Functions for PKCS#1 version 2 encryption and signing
This module implements certain functionality from PKCS#1 version 2. Main
documentation is RFC 2437: https://tools.ietf.org/html/rfc2437
documentation is RFC 8017: https://tools.ietf.org/html/rfc8017
"""

from rsa import (
common,
pkcs1,
transform,
)
import os
from hmac import compare_digest

from . import common, transform, core, key, pkcs1
from ._compat import xor_bytes


def _constant_time_select(v: int, t: int, f: int) -> int:
"""Return t if v else f.
v must be 0 or 1. (False and True are allowed)
t and f are integer between 0 and 255.
"""
v -= 1
return (~v & t) | (v & f)


def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes:
Expand Down Expand Up @@ -81,8 +91,193 @@ def mgf1(seed: bytes, length: int, hasher: str = "SHA-1") -> bytes:
return output[:length]


def _OAEP_encode(
message: bytes, keylength: int, label, hash_method: str, mgf1_hash_method: str
) -> bytes:
try:
hasher = pkcs1.HASH_METHODS[hash_method](label)
except KeyError:
raise ValueError(
"Invalid `hash_method` specified. Please select one of: {hash_list}".format(
hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys()))
)
)
hash_length = hasher.digest_size
max_message_length = keylength - 2 * hash_length - 2
message_length = len(message)
if message_length > max_message_length:
raise OverflowError(
"message is too long; at most %s bytes, given %s bytes"
% (max_message_length, len(message))
)

lhash = hasher.digest()
ps = bytearray(keylength - message_length - 2 * hash_length - 2)
db = (
hasher.digest()
+ b"\0" * (keylength - message_length - 2 * hash_length - 2)
+ b"\x01"
+ message
)

seed = os.urandom(hash_length)
db_mask = mgf1(seed, keylength - hash_length - 1, mgf1_hash_method)
masked_db = xor_bytes(db, db_mask)

seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method)
masked_seed = xor_bytes(seed, seed_mask)

em = b"\x00" + masked_seed + masked_db
return em


def encrypt_OAEP(
message: bytes,
pub_key: key.PublicKey,
label: bytes = b"",
hash_method: str = "SHA-1",
mgf1_hash_method: str = None,
) -> bytes:
"""Encrypts the given message using PKCS#1 v2 RSA-OEAP.
:param message: the message to encrypt.
:param pub_key: the public key to encrypt with.
:param label: optional RSA-OAEP label.
:param hash_method: hash function to be used. 'SHA-1' (default),
'SHA-256', 'SHA-384', and 'SHA-512' can be used.
:param mgf1_hash_method: hash function to be used by MGF1 function.
If it is None (default), *hash_method* is used.
"""
# NOTE: Some hash method other than listed in the docstring can be used
# for hash_method. But the RFC 8017 recommends only them.
if mgf1_hash_method is None:
mgf1_hash_method = hash_method
keylength = common.byte_size(pub_key.n)

em = _OAEP_encode(message, keylength, label, hash_method, mgf1_hash_method)

m = transform.bytes2int(em)
encrypted = core.encrypt_int(m, pub_key.e, pub_key.n)
c = transform.int2bytes(encrypted, keylength)

return c


def decrypt_OAEP(
crypto: bytes,
priv_key: key.PrivateKey,
label: bytes = b"",
hash_method: str = "SHA-1",
mgf1_hash_method: str = None,
) -> bytes:
"""Decrypts the givem crypto using PKCS#1 v2 RSA-OAEP.
:param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
:param priv_key: the private key to decrypt with.
:param label: optional RSA-OAEP label.
:param hash_method: hash function to be used. 'SHA-1' (default),
'SHA-256', 'SHA-384', and 'SHA-512' can be used.
:param mgf1_hash_method: hash function to be used by MGF1 function.
If it is None (default), *hash_method* is used.
:raise rsa.pkcs1.DecryptionError: when the decryption fails. No details are given as
to why the code thinks the decryption fails, as this would leak
information about the private key.
>>> import rsa
>>> (pub_key, priv_key) = rsa.newkeys(512)
It works with binary data:
>>> crypto = encrypt_OAEP(b'hello', pub_key)
>>> decrypt_OAEP(crypto, priv_key)
b'hello'
You can pass optional label data too:
>>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world')
>>> decrypt_OAEP(crypto, priv_key, label=b'world')
b'hello'
Altering the encrypted information will cause a
:py:class:`rsa.pkcs1.DecryptionError`.
>>> crypto = encrypt_OAEP(b'hello', pub_key)
>>> crypto = crypto[0:5] + bytes([(ord(crypto[5:6])+1)%256]) + crypto[6:] # change a byte
>>> decrypt_OAEP(crypto, priv_key)
Traceback (most recent call last):
...
rsa.pkcs1.DecryptionError: Decryption failed
Changing label will also cause the error.
>>> crypto = encrypt_OAEP(b'hello', pub_key, label=b'world')
>>> decrypt_OAEP(crypto, priv_key, label=b'universe')
Traceback (most recent call last):
...
rsa.pkcs1.DecryptionError: Decryption failed
"""
if mgf1_hash_method is None:
mgf1_hash_method = hash_method

# todo: Step 1: length checking
k = common.byte_size(priv_key.n)
if k != len(crypto):
raise pkcs1.DecryptionError("Decryption failed")

# Step 2: RSA Decryption
c = transform.bytes2int(crypto)
m = priv_key.blinded_decrypt(c)
em = transform.int2bytes(m, k)

# Step 3: EME-OAEP decoding
try:
hasher = pkcs1.HASH_METHODS[hash_method](label)
except KeyError:
raise ValueError(
"Invalid `hash_method` specified. Please select one of: {hash_list}".format(
hash_list=", ".join(sorted(pkcs1.HASH_METHODS.keys()))
)
)
hash_length = hasher.digest_size
lhash = hasher.digest()
Y = em[0:1]
masked_seed = em[1 : 1 + hash_length]
masked_db = em[1 + hash_length :]

seed_mask = mgf1(masked_db, hash_length, mgf1_hash_method)
seed = xor_bytes(masked_seed, seed_mask)

db_mask = mgf1(seed, k - hash_length - 1, mgf1_hash_method)
db = xor_bytes(masked_db, db_mask)

lhash_ = db[:hash_length]
rest = db[hash_length:]

# NOTE: Take care about timing attack. See note in the RFC.
hash_is_good = compare_digest(lhash, lhash_)

index = invalid = 0
looking_one = 1

for i, c in enumerate(rest):
iszero = c == 0
isone = c == 1

index = _constant_time_select(looking_one & isone, i, index)
looking_one = _constant_time_select(isone, 0, looking_one)
invalid = _constant_time_select(looking_one & ~iszero, 1, invalid)

if invalid | looking_one | (not hash_is_good):
raise pkcs1.DecryptionError("Decryption failed")

return rest[index + 1 :]


__all__ = [
"mgf1",
"encrypt_OAEP",
"decrypt_OAEP",
]

if __name__ == "__main__":
Expand Down
43 changes: 43 additions & 0 deletions tests/test_pkcs1_v2.py
Expand Up @@ -18,9 +18,13 @@
http://www.itomorrowmag.com/emc-plus/rsa-labs/standards-initiatives/pkcs-rsa-cryptography-standard.htm
"""

import struct
import unittest

import rsa
from rsa import pkcs1_v2
from rsa._compat import byte
from rsa.pkcs1 import DecryptionError


class MGFTest(unittest.TestCase):
Expand Down Expand Up @@ -77,3 +81,42 @@ def test_invalid_hasher(self):
def test_invalid_length(self):
with self.assertRaises(OverflowError):
pkcs1_v2.mgf1(b"\x06\xe1\xde\xb2", length=2 ** 50)


class BinaryTest(unittest.TestCase):
def setUp(self):
(self.pub, self.priv) = rsa.newkeys(512)

def test_enc_dec(self):
message = struct.pack(">IIII", 0, 0, 0, 1)
print("\tMessage: %r" % message)

encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub)
print("\tEncrypted: %r" % encrypted)

decrypted = pkcs1_v2.decrypt_OAEP(encrypted, self.priv)
print("\tDecrypted: %r" % decrypted)

self.assertEqual(message, decrypted)

def test_decoding_failure(self):
message = struct.pack(">IIII", 0, 0, 0, 1)
encrypted = pkcs1_v2.encrypt_OAEP(message, self.pub)

# Alter the encrypted stream
a = encrypted[5]
altered_a = (a + 1) % 256
encrypted = encrypted[:5] + byte(altered_a) + encrypted[6:]

self.assertRaises(DecryptionError, pkcs1_v2.decrypt_OAEP, encrypted, self.priv)

def test_randomness(self):
"""Encrypting the same message twice should result in different
cryptos.
"""

message = struct.pack(">IIII", 0, 0, 0, 1)
encrypted1 = pkcs1_v2.encrypt_OAEP(message, self.pub)
encrypted2 = pkcs1_v2.encrypt_OAEP(message, self.pub)

self.assertNotEqual(encrypted1, encrypted2)

0 comments on commit 808d1e6

Please sign in to comment.