Skip to content

Commit

Permalink
Port RSA to rust (#9152)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Aug 20, 2023
1 parent 8b4025a commit 0000b40
Show file tree
Hide file tree
Showing 10 changed files with 768 additions and 832 deletions.
270 changes: 15 additions & 255 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.ciphers import _CipherContext
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
from cryptography.hazmat.backends.openssl.rsa import (
_RSAPrivateKey,
_RSAPublicKey,
)
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
from cryptography.hazmat.bindings.openssl import binding
from cryptography.hazmat.primitives import hashes, serialization
Expand Down Expand Up @@ -63,7 +59,6 @@
XTS,
Mode,
)
from cryptography.hazmat.primitives.serialization import ssh
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PBES,
PKCS12Certificate,
Expand Down Expand Up @@ -358,24 +353,7 @@ def generate_rsa_private_key(
self, public_exponent: int, key_size: int
) -> rsa.RSAPrivateKey:
rsa._verify_rsa_parameters(public_exponent, key_size)

rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)

bn = self._int_to_bn(public_exponent)
bn = self._ffi.gc(bn, self._lib.BN_free)

res = self._lib.RSA_generate_key_ex(
rsa_cdata, key_size, bn, self._ffi.NULL
)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

# We can skip RSA key validation here since we just generated the key
return _RSAPrivateKey(
self, rsa_cdata, evp_pkey, unsafe_skip_rsa_key_validation=True
)
return rust_openssl.rsa.generate_private_key(public_exponent, key_size)

def generate_rsa_parameters_supported(
self, public_exponent: int, key_size: int
Expand All @@ -401,46 +379,15 @@ def load_rsa_private_numbers(
numbers.public_numbers.e,
numbers.public_numbers.n,
)
rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
p = self._int_to_bn(numbers.p)
q = self._int_to_bn(numbers.q)
d = self._int_to_bn(numbers.d)
dmp1 = self._int_to_bn(numbers.dmp1)
dmq1 = self._int_to_bn(numbers.dmq1)
iqmp = self._int_to_bn(numbers.iqmp)
e = self._int_to_bn(numbers.public_numbers.e)
n = self._int_to_bn(numbers.public_numbers.n)
res = self._lib.RSA_set0_factors(rsa_cdata, p, q)
self.openssl_assert(res == 1)
res = self._lib.RSA_set0_key(rsa_cdata, n, e, d)
self.openssl_assert(res == 1)
res = self._lib.RSA_set0_crt_params(rsa_cdata, dmp1, dmq1, iqmp)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

return _RSAPrivateKey(
self,
rsa_cdata,
evp_pkey,
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
return rust_openssl.rsa.from_private_numbers(
numbers, unsafe_skip_rsa_key_validation
)

def load_rsa_public_numbers(
self, numbers: rsa.RSAPublicNumbers
) -> rsa.RSAPublicKey:
rsa._check_public_key_components(numbers.e, numbers.n)
rsa_cdata = self._lib.RSA_new()
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
e = self._int_to_bn(numbers.e)
n = self._int_to_bn(numbers.n)
res = self._lib.RSA_set0_key(rsa_cdata, n, e, self._ffi.NULL)
self.openssl_assert(res == 1)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)

return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.from_public_numbers(numbers)

def _create_evp_pkey_gc(self):
evp_pkey = self._lib.EVP_PKEY_new()
Expand Down Expand Up @@ -500,13 +447,8 @@ def _evp_pkey_to_private_key(
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
return _RSAPrivateKey(
self,
rsa_cdata,
evp_pkey,
return rust_openssl.rsa.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
elif (
Expand Down Expand Up @@ -573,10 +515,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif (
key_type == self._lib.EVP_PKEY_RSA_PSS
and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL
Expand Down Expand Up @@ -733,7 +674,9 @@ def load_pem_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
self._handle_key_loading_error()

Expand Down Expand Up @@ -796,7 +739,9 @@ def load_der_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return _RSAPublicKey(self, rsa_cdata, evp_pkey)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
self._handle_key_loading_error()

Expand Down Expand Up @@ -984,191 +929,6 @@ def elliptic_curve_exchange_algorithm_supported(
algorithm, ec.ECDH
)

def _private_key_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PrivateFormat,
encryption_algorithm: serialization.KeySerializationEncryption,
key,
evp_pkey,
cdata,
) -> bytes:
# validate argument types
if not isinstance(encoding, serialization.Encoding):
raise TypeError("encoding must be an item from the Encoding enum")
if not isinstance(format, serialization.PrivateFormat):
raise TypeError(
"format must be an item from the PrivateFormat enum"
)
if not isinstance(
encryption_algorithm, serialization.KeySerializationEncryption
):
raise TypeError(
"Encryption algorithm must be a KeySerializationEncryption "
"instance"
)

# validate password
if isinstance(encryption_algorithm, serialization.NoEncryption):
password = b""
elif isinstance(
encryption_algorithm, serialization.BestAvailableEncryption
):
password = encryption_algorithm.password
if len(password) > 1023:
raise ValueError(
"Passwords longer than 1023 bytes are not supported by "
"this backend"
)
elif (
isinstance(
encryption_algorithm, serialization._KeySerializationEncryption
)
and encryption_algorithm._format
is format
is serialization.PrivateFormat.OpenSSH
):
password = encryption_algorithm.password
else:
raise ValueError("Unsupported encryption type")

# PKCS8 + PEM/DER
if format is serialization.PrivateFormat.PKCS8:
if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_PKCS8PrivateKey
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_PKCS8PrivateKey_bio
else:
raise ValueError("Unsupported encoding for PKCS8")
return self._private_key_bytes_via_bio(
write_bio, evp_pkey, password
)

# TraditionalOpenSSL + PEM/DER
if format is serialization.PrivateFormat.TraditionalOpenSSL:
if self._fips_enabled and not isinstance(
encryption_algorithm, serialization.NoEncryption
):
raise ValueError(
"Encrypted traditional OpenSSL format is not "
"supported in FIPS mode."
)
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if encoding is serialization.Encoding.PEM:
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
return self._private_key_bytes_via_bio(
write_bio, cdata, password
)

if encoding is serialization.Encoding.DER:
if password:
raise ValueError(
"Encryption is not supported for DER encoded "
"traditional OpenSSL keys"
)
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.i2d_RSAPrivateKey_bio
return self._bio_func_output(write_bio, cdata)

raise ValueError("Unsupported encoding for TraditionalOpenSSL")

# OpenSSH + PEM
if format is serialization.PrivateFormat.OpenSSH:
if encoding is serialization.Encoding.PEM:
return ssh._serialize_ssh_private_key(
key, password, encryption_algorithm
)

raise ValueError(
"OpenSSH private key format can only be used"
" with PEM encoding"
)

# Anything that key-specific code was supposed to handle earlier,
# like Raw.
raise ValueError("format is invalid with this key")

def _private_key_bytes_via_bio(
self, write_bio, evp_pkey, password
) -> bytes:
if not password:
evp_cipher = self._ffi.NULL
else:
# This is a curated value that we will update over time.
evp_cipher = self._lib.EVP_get_cipherbyname(b"aes-256-cbc")

return self._bio_func_output(
write_bio,
evp_pkey,
evp_cipher,
password,
len(password),
self._ffi.NULL,
self._ffi.NULL,
)

def _bio_func_output(self, write_bio, *args) -> bytes:
bio = self._create_mem_bio_gc()
res = write_bio(bio, *args)
self.openssl_assert(res == 1)
return self._read_mem_bio(bio)

def _public_key_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PublicFormat,
key,
evp_pkey,
cdata,
) -> bytes:
if not isinstance(encoding, serialization.Encoding):
raise TypeError("encoding must be an item from the Encoding enum")
if not isinstance(format, serialization.PublicFormat):
raise TypeError(
"format must be an item from the PublicFormat enum"
)

# SubjectPublicKeyInfo + PEM/DER
if format is serialization.PublicFormat.SubjectPublicKeyInfo:
if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_PUBKEY
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_PUBKEY_bio
else:
raise ValueError(
"SubjectPublicKeyInfo works only with PEM or DER encoding"
)
return self._bio_func_output(write_bio, evp_pkey)

# PKCS1 + PEM/DER
if format is serialization.PublicFormat.PKCS1:
# Only RSA is supported here.
key_type = self._lib.EVP_PKEY_id(evp_pkey)
self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA)

if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_RSAPublicKey
elif encoding is serialization.Encoding.DER:
write_bio = self._lib.i2d_RSAPublicKey_bio
else:
raise ValueError("PKCS1 works only with PEM or DER encoding")
return self._bio_func_output(write_bio, cdata)

# OpenSSH + OpenSSH
if format is serialization.PublicFormat.OpenSSH:
if encoding is serialization.Encoding.OpenSSH:
return ssh.serialize_ssh_public_key(key)

raise ValueError(
"OpenSSH format must be used with OpenSSH encoding"
)

# Anything that key-specific code was supposed to handle earlier,
# like Raw, CompressedPoint, UncompressedPoint
raise ValueError("format is invalid with this key")

def dh_supported(self) -> bool:
return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL

Expand Down

0 comments on commit 0000b40

Please sign in to comment.