diff --git a/ext/openssl/lib/openssl/pkey.rb b/ext/openssl/lib/openssl/pkey.rb index 3d1e8885ca22f6..39871e15dde854 100644 --- a/ext/openssl/lib/openssl/pkey.rb +++ b/ext/openssl/lib/openssl/pkey.rb @@ -7,6 +7,9 @@ require_relative 'marshal' module OpenSSL::PKey + # Alias of PKeyError. Before version 4.0.0, this was a subclass of PKeyError. + DHError = PKeyError + class DH include OpenSSL::Marshal @@ -102,7 +105,7 @@ def compute_key(pub_bn) # puts dh0.pub_key == dh.pub_key #=> false def generate_key! if OpenSSL::OPENSSL_VERSION_NUMBER >= 0x30000000 - raise DHError, "OpenSSL::PKey::DH is immutable on OpenSSL 3.0; " \ + raise PKeyError, "OpenSSL::PKey::DH is immutable on OpenSSL 3.0; " \ "use OpenSSL::PKey.generate_key instead" end @@ -147,6 +150,9 @@ def new(*args, &blk) # :nodoc: end end + # Alias of PKeyError. Before version 4.0.0, this was a subclass of PKeyError. + DSAError = PKeyError + class DSA include OpenSSL::Marshal @@ -242,13 +248,9 @@ def new(*args, &blk) # :nodoc: # sig = dsa.sign_raw(nil, digest) # p dsa.verify_raw(nil, sig, digest) #=> true def syssign(string) - q or raise OpenSSL::PKey::DSAError, "incomplete DSA" - private? or raise OpenSSL::PKey::DSAError, "Private DSA key needed!" - begin - sign_raw(nil, string) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::DSAError, $!.message - end + q or raise PKeyError, "incomplete DSA" + private? or raise PKeyError, "Private DSA key needed!" + sign_raw(nil, string) end # :call-seq: @@ -266,12 +268,13 @@ def syssign(string) # A \DSA signature value. def sysverify(digest, sig) verify_raw(nil, sig, digest) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::DSAError, $!.message end end if defined?(EC) + # Alias of PKeyError. Before version 4.0.0, this was a subclass of PKeyError. + ECError = PKeyError + class EC include OpenSSL::Marshal @@ -282,8 +285,6 @@ class EC # Consider using PKey::PKey#sign_raw and PKey::PKey#verify_raw instead. def dsa_sign_asn1(data) sign_raw(nil, data) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::ECError, $!.message end # :call-seq: @@ -293,8 +294,6 @@ def dsa_sign_asn1(data) # Consider using PKey::PKey#sign_raw and PKey::PKey#verify_raw instead. def dsa_verify_asn1(data, sig) verify_raw(nil, sig, data) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::ECError, $!.message end # :call-seq: @@ -334,6 +333,9 @@ def to_bn(conversion_form = group.point_conversion_form) end end + # Alias of PKeyError. Before version 4.0.0, this was a subclass of PKeyError. + RSAError = PKeyError + class RSA include OpenSSL::Marshal @@ -407,15 +409,11 @@ def new(*args, &blk) # :nodoc: # Consider using PKey::PKey#sign_raw and PKey::PKey#verify_raw, and # PKey::PKey#verify_recover instead. def private_encrypt(string, padding = PKCS1_PADDING) - n or raise OpenSSL::PKey::RSAError, "incomplete RSA" - private? or raise OpenSSL::PKey::RSAError, "private key needed." - begin - sign_raw(nil, string, { - "rsa_padding_mode" => translate_padding_mode(padding), - }) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::RSAError, $!.message - end + n or raise PKeyError, "incomplete RSA" + private? or raise PKeyError, "private key needed." + sign_raw(nil, string, { + "rsa_padding_mode" => translate_padding_mode(padding), + }) end # :call-seq: @@ -430,14 +428,10 @@ def private_encrypt(string, padding = PKCS1_PADDING) # Consider using PKey::PKey#sign_raw and PKey::PKey#verify_raw, and # PKey::PKey#verify_recover instead. def public_decrypt(string, padding = PKCS1_PADDING) - n or raise OpenSSL::PKey::RSAError, "incomplete RSA" - begin - verify_recover(nil, string, { - "rsa_padding_mode" => translate_padding_mode(padding), - }) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::RSAError, $!.message - end + n or raise PKeyError, "incomplete RSA" + verify_recover(nil, string, { + "rsa_padding_mode" => translate_padding_mode(padding), + }) end # :call-seq: @@ -452,14 +446,10 @@ def public_decrypt(string, padding = PKCS1_PADDING) # Deprecated in version 3.0. # Consider using PKey::PKey#encrypt and PKey::PKey#decrypt instead. def public_encrypt(data, padding = PKCS1_PADDING) - n or raise OpenSSL::PKey::RSAError, "incomplete RSA" - begin - encrypt(data, { - "rsa_padding_mode" => translate_padding_mode(padding), - }) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::RSAError, $!.message - end + n or raise PKeyError, "incomplete RSA" + encrypt(data, { + "rsa_padding_mode" => translate_padding_mode(padding), + }) end # :call-seq: @@ -473,15 +463,11 @@ def public_encrypt(data, padding = PKCS1_PADDING) # Deprecated in version 3.0. # Consider using PKey::PKey#encrypt and PKey::PKey#decrypt instead. def private_decrypt(data, padding = PKCS1_PADDING) - n or raise OpenSSL::PKey::RSAError, "incomplete RSA" - private? or raise OpenSSL::PKey::RSAError, "private key needed." - begin - decrypt(data, { - "rsa_padding_mode" => translate_padding_mode(padding), - }) - rescue OpenSSL::PKey::PKeyError - raise OpenSSL::PKey::RSAError, $!.message - end + n or raise PKeyError, "incomplete RSA" + private? or raise PKeyError, "private key needed." + decrypt(data, { + "rsa_padding_mode" => translate_padding_mode(padding), + }) end PKCS1_PADDING = 1 @@ -500,7 +486,7 @@ def private_decrypt(data, padding = PKCS1_PADDING) when PKCS1_OAEP_PADDING "oaep" else - raise OpenSSL::PKey::PKeyError, "unsupported padding mode" + raise PKeyError, "unsupported padding mode" end end end diff --git a/ext/openssl/ossl_cipher.c b/ext/openssl/ossl_cipher.c index f449c63b695b8b..a52518291124b3 100644 --- a/ext/openssl/ossl_cipher.c +++ b/ext/openssl/ossl_cipher.c @@ -33,7 +33,7 @@ static VALUE cCipher; static VALUE eCipherError; static VALUE eAuthTagError; -static ID id_auth_tag_len, id_key_set; +static ID id_auth_tag_len, id_key_set, id_cipher_holder; static VALUE ossl_cipher_alloc(VALUE klass); static void ossl_cipher_free(void *ptr); @@ -46,30 +46,58 @@ static const rb_data_type_t ossl_cipher_type = { 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; +#ifdef OSSL_USE_PROVIDER +static void +ossl_evp_cipher_free(void *ptr) +{ + // This is safe to call against const EVP_CIPHER * returned by + // EVP_get_cipherbyname() + EVP_CIPHER_free(ptr); +} + +static const rb_data_type_t ossl_evp_cipher_holder_type = { + "OpenSSL/EVP_CIPHER", + { + .dfree = ossl_evp_cipher_free, + }, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, +}; +#endif + /* * PUBLIC */ const EVP_CIPHER * -ossl_evp_get_cipherbyname(VALUE obj) +ossl_evp_cipher_fetch(VALUE obj, volatile VALUE *holder) { + *holder = Qnil; if (rb_obj_is_kind_of(obj, cCipher)) { - EVP_CIPHER_CTX *ctx; - - GetCipher(obj, ctx); - - return EVP_CIPHER_CTX_cipher(ctx); + EVP_CIPHER_CTX *ctx; + GetCipher(obj, ctx); + EVP_CIPHER *cipher = (EVP_CIPHER *)EVP_CIPHER_CTX_cipher(ctx); +#ifdef OSSL_USE_PROVIDER + *holder = TypedData_Wrap_Struct(0, &ossl_evp_cipher_holder_type, NULL); + if (!EVP_CIPHER_up_ref(cipher)) + ossl_raise(eCipherError, "EVP_CIPHER_up_ref"); + RTYPEDDATA_DATA(*holder) = cipher; +#endif + return cipher; } - else { - const EVP_CIPHER *cipher; - - StringValueCStr(obj); - cipher = EVP_get_cipherbyname(RSTRING_PTR(obj)); - if (!cipher) - ossl_raise(rb_eArgError, - "unsupported cipher algorithm: %"PRIsVALUE, obj); - return cipher; + const char *name = StringValueCStr(obj); + EVP_CIPHER *cipher = (EVP_CIPHER *)EVP_get_cipherbyname(name); +#ifdef OSSL_USE_PROVIDER + if (!cipher) { + ossl_clear_error(); + *holder = TypedData_Wrap_Struct(0, &ossl_evp_cipher_holder_type, NULL); + cipher = EVP_CIPHER_fetch(NULL, name, NULL); + RTYPEDDATA_DATA(*holder) = cipher; } +#endif + if (!cipher) + ossl_raise(eCipherError, "unsupported cipher algorithm: %"PRIsVALUE, + obj); + return cipher; } VALUE @@ -78,6 +106,9 @@ ossl_cipher_new(const EVP_CIPHER *cipher) VALUE ret; EVP_CIPHER_CTX *ctx; + // NOTE: This does not set id_cipher_holder because this function should + // only be called from ossl_engine.c, which will not use any + // reference-counted ciphers. ret = ossl_cipher_alloc(cCipher); AllocCipher(ret, ctx); if (EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, -1) != 1) @@ -114,19 +145,17 @@ ossl_cipher_initialize(VALUE self, VALUE str) { EVP_CIPHER_CTX *ctx; const EVP_CIPHER *cipher; - char *name; + VALUE cipher_holder; - name = StringValueCStr(str); GetCipherInit(self, ctx); if (ctx) { ossl_raise(rb_eRuntimeError, "Cipher already initialized!"); } + cipher = ossl_evp_cipher_fetch(str, &cipher_holder); AllocCipher(self, ctx); - if (!(cipher = EVP_get_cipherbyname(name))) { - ossl_raise(rb_eRuntimeError, "unsupported cipher algorithm (%"PRIsVALUE")", str); - } if (EVP_CipherInit_ex(ctx, cipher, NULL, NULL, NULL, -1) != 1) - ossl_raise(eCipherError, NULL); + ossl_raise(eCipherError, "EVP_CipherInit_ex"); + rb_ivar_set(self, id_cipher_holder, cipher_holder); return self; } @@ -268,7 +297,7 @@ ossl_cipher_pkcs5_keyivgen(int argc, VALUE *argv, VALUE self) { EVP_CIPHER_CTX *ctx; const EVP_MD *digest; - VALUE vpass, vsalt, viter, vdigest; + VALUE vpass, vsalt, viter, vdigest, md_holder; unsigned char key[EVP_MAX_KEY_LENGTH], iv[EVP_MAX_IV_LENGTH], *salt = NULL; int iter; @@ -283,7 +312,7 @@ ossl_cipher_pkcs5_keyivgen(int argc, VALUE *argv, VALUE self) iter = NIL_P(viter) ? 2048 : NUM2INT(viter); if (iter <= 0) rb_raise(rb_eArgError, "iterations must be a positive integer"); - digest = NIL_P(vdigest) ? EVP_md5() : ossl_evp_get_digestbyname(vdigest); + digest = NIL_P(vdigest) ? EVP_md5() : ossl_evp_md_fetch(vdigest, &md_holder); GetCipher(self, ctx); EVP_BytesToKey(EVP_CIPHER_CTX_cipher(ctx), digest, salt, (unsigned char *)RSTRING_PTR(vpass), RSTRING_LENINT(vpass), iter, key, iv); @@ -1110,4 +1139,5 @@ Init_ossl_cipher(void) id_auth_tag_len = rb_intern_const("auth_tag_len"); id_key_set = rb_intern_const("key_set"); + id_cipher_holder = rb_intern_const("EVP_CIPHER_holder"); } diff --git a/ext/openssl/ossl_cipher.h b/ext/openssl/ossl_cipher.h index 12da68ca3e9cd5..fba63a140f0b6d 100644 --- a/ext/openssl/ossl_cipher.h +++ b/ext/openssl/ossl_cipher.h @@ -10,7 +10,16 @@ #if !defined(_OSSL_CIPHER_H_) #define _OSSL_CIPHER_H_ -const EVP_CIPHER *ossl_evp_get_cipherbyname(VALUE); +/* + * Gets EVP_CIPHER from a String or an OpenSSL::Digest instance (discouraged, + * but still supported for compatibility). A holder object is created if the + * EVP_CIPHER is a "fetched" algorithm. + */ +const EVP_CIPHER *ossl_evp_cipher_fetch(VALUE obj, volatile VALUE *holder); +/* + * This is meant for OpenSSL::Engine#cipher. EVP_CIPHER must not be a fetched + * one. + */ VALUE ossl_cipher_new(const EVP_CIPHER *); void Init_ossl_cipher(void); diff --git a/ext/openssl/ossl_digest.c b/ext/openssl/ossl_digest.c index 329de6c1bab078..e2f1af7e61aa8f 100644 --- a/ext/openssl/ossl_digest.c +++ b/ext/openssl/ossl_digest.c @@ -21,6 +21,7 @@ */ static VALUE cDigest; static VALUE eDigestError; +static ID id_md_holder; static VALUE ossl_digest_alloc(VALUE klass); @@ -38,34 +39,62 @@ static const rb_data_type_t ossl_digest_type = { 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; +#ifdef OSSL_USE_PROVIDER +static void +ossl_evp_md_free(void *ptr) +{ + // This is safe to call against const EVP_MD * returned by + // EVP_get_digestbyname() + EVP_MD_free(ptr); +} + +static const rb_data_type_t ossl_evp_md_holder_type = { + "OpenSSL/EVP_MD", + { + .dfree = ossl_evp_md_free, + }, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, +}; +#endif + /* * Public */ const EVP_MD * -ossl_evp_get_digestbyname(VALUE obj) +ossl_evp_md_fetch(VALUE obj, volatile VALUE *holder) { - const EVP_MD *md; - ASN1_OBJECT *oid = NULL; - - if (RB_TYPE_P(obj, T_STRING)) { - const char *name = StringValueCStr(obj); - - md = EVP_get_digestbyname(name); - if (!md) { - oid = OBJ_txt2obj(name, 0); - md = EVP_get_digestbyobj(oid); - ASN1_OBJECT_free(oid); - } - if(!md) - ossl_raise(rb_eRuntimeError, "Unsupported digest algorithm (%"PRIsVALUE").", obj); - } else { + *holder = Qnil; + if (rb_obj_is_kind_of(obj, cDigest)) { EVP_MD_CTX *ctx; - GetDigest(obj, ctx); - - md = EVP_MD_CTX_get0_md(ctx); + EVP_MD *md = (EVP_MD *)EVP_MD_CTX_get0_md(ctx); +#ifdef OSSL_USE_PROVIDER + *holder = TypedData_Wrap_Struct(0, &ossl_evp_md_holder_type, NULL); + if (!EVP_MD_up_ref(md)) + ossl_raise(eDigestError, "EVP_MD_up_ref"); + RTYPEDDATA_DATA(*holder) = md; +#endif + return md; } + const char *name = StringValueCStr(obj); + EVP_MD *md = (EVP_MD *)EVP_get_digestbyname(name); + if (!md) { + ASN1_OBJECT *oid = OBJ_txt2obj(name, 0); + md = (EVP_MD *)EVP_get_digestbyobj(oid); + ASN1_OBJECT_free(oid); + } +#ifdef OSSL_USE_PROVIDER + if (!md) { + ossl_clear_error(); + *holder = TypedData_Wrap_Struct(0, &ossl_evp_md_holder_type, NULL); + md = EVP_MD_fetch(NULL, name, NULL); + RTYPEDDATA_DATA(*holder) = md; + } +#endif + if (!md) + ossl_raise(eDigestError, "unsupported digest algorithm: %"PRIsVALUE, + obj); return md; } @@ -75,6 +104,9 @@ ossl_digest_new(const EVP_MD *md) VALUE ret; EVP_MD_CTX *ctx; + // NOTE: This does not set id_md_holder because this function should + // only be called from ossl_engine.c, which will not use any + // reference-counted digests. ret = ossl_digest_alloc(cDigest); ctx = EVP_MD_CTX_new(); if (!ctx) @@ -121,10 +153,10 @@ ossl_digest_initialize(int argc, VALUE *argv, VALUE self) { EVP_MD_CTX *ctx; const EVP_MD *md; - VALUE type, data; + VALUE type, data, md_holder; rb_scan_args(argc, argv, "11", &type, &data); - md = ossl_evp_get_digestbyname(type); + md = ossl_evp_md_fetch(type, &md_holder); if (!NIL_P(data)) StringValue(data); TypedData_Get_Struct(self, EVP_MD_CTX, &ossl_digest_type, ctx); @@ -136,6 +168,7 @@ ossl_digest_initialize(int argc, VALUE *argv, VALUE self) if (!EVP_DigestInit_ex(ctx, md, NULL)) ossl_raise(eDigestError, "Digest initialization failed"); + rb_ivar_set(self, id_md_holder, md_holder); if (!NIL_P(data)) return ossl_digest_update(self, data); return self; @@ -442,4 +475,6 @@ Init_ossl_digest(void) rb_define_method(cDigest, "block_length", ossl_digest_block_length, 0); rb_define_method(cDigest, "name", ossl_digest_name, 0); + + id_md_holder = rb_intern_const("EVP_MD_holder"); } diff --git a/ext/openssl/ossl_digest.h b/ext/openssl/ossl_digest.h index 588a0c6f578a63..9c3bb2b149ede9 100644 --- a/ext/openssl/ossl_digest.h +++ b/ext/openssl/ossl_digest.h @@ -10,7 +10,15 @@ #if !defined(_OSSL_DIGEST_H_) #define _OSSL_DIGEST_H_ -const EVP_MD *ossl_evp_get_digestbyname(VALUE); +/* + * Gets EVP_MD from a String or an OpenSSL::Digest instance (discouraged, but + * still supported for compatibility). A holder object is created if the EVP_MD + * is a "fetched" algorithm. + */ +const EVP_MD *ossl_evp_md_fetch(VALUE obj, volatile VALUE *holder); +/* + * This is meant for OpenSSL::Engine#digest. EVP_MD must not be a fetched one. + */ VALUE ossl_digest_new(const EVP_MD *); void Init_ossl_digest(void); diff --git a/ext/openssl/ossl_hmac.c b/ext/openssl/ossl_hmac.c index b30482757997b6..250b427bdcc8f6 100644 --- a/ext/openssl/ossl_hmac.c +++ b/ext/openssl/ossl_hmac.c @@ -23,6 +23,7 @@ */ static VALUE cHMAC; static VALUE eHMACError; +static ID id_md_holder; /* * Public @@ -94,19 +95,22 @@ ossl_hmac_initialize(VALUE self, VALUE key, VALUE digest) { EVP_MD_CTX *ctx; EVP_PKEY *pkey; + const EVP_MD *md; + VALUE md_holder; GetHMAC(self, ctx); StringValue(key); + md = ossl_evp_md_fetch(digest, &md_holder); pkey = EVP_PKEY_new_raw_private_key(EVP_PKEY_HMAC, NULL, (unsigned char *)RSTRING_PTR(key), RSTRING_LENINT(key)); if (!pkey) ossl_raise(eHMACError, "EVP_PKEY_new_raw_private_key"); - if (EVP_DigestSignInit(ctx, NULL, ossl_evp_get_digestbyname(digest), - NULL, pkey) != 1) { + if (EVP_DigestSignInit(ctx, NULL, md, NULL, pkey) != 1) { EVP_PKEY_free(pkey); ossl_raise(eHMACError, "EVP_DigestSignInit"); } + rb_ivar_set(self, id_md_holder, md_holder); /* Decrement reference counter; EVP_MD_CTX still keeps it */ EVP_PKEY_free(pkey); @@ -300,4 +304,6 @@ Init_ossl_hmac(void) rb_define_method(cHMAC, "hexdigest", ossl_hmac_hexdigest, 0); rb_define_alias(cHMAC, "inspect", "hexdigest"); rb_define_alias(cHMAC, "to_s", "hexdigest"); + + id_md_holder = rb_intern_const("EVP_MD_holder"); } diff --git a/ext/openssl/ossl_kdf.c b/ext/openssl/ossl_kdf.c index f349939a80475e..e7429a76881c7d 100644 --- a/ext/openssl/ossl_kdf.c +++ b/ext/openssl/ossl_kdf.c @@ -35,7 +35,7 @@ static VALUE mKDF, eKDF; static VALUE kdf_pbkdf2_hmac(int argc, VALUE *argv, VALUE self) { - VALUE pass, salt, opts, kwargs[4], str; + VALUE pass, salt, opts, kwargs[4], str, md_holder; static ID kwargs_ids[4]; int iters, len; const EVP_MD *md; @@ -53,7 +53,7 @@ kdf_pbkdf2_hmac(int argc, VALUE *argv, VALUE self) salt = StringValue(kwargs[0]); iters = NUM2INT(kwargs[1]); len = NUM2INT(kwargs[2]); - md = ossl_evp_get_digestbyname(kwargs[3]); + md = ossl_evp_md_fetch(kwargs[3], &md_holder); str = rb_str_new(0, len); if (!PKCS5_PBKDF2_HMAC(RSTRING_PTR(pass), RSTRING_LENINT(pass), @@ -172,7 +172,7 @@ kdf_scrypt(int argc, VALUE *argv, VALUE self) static VALUE kdf_hkdf(int argc, VALUE *argv, VALUE self) { - VALUE ikm, salt, info, opts, kwargs[4], str; + VALUE ikm, salt, info, opts, kwargs[4], str, md_holder; static ID kwargs_ids[4]; int saltlen, ikmlen, infolen; size_t len; @@ -197,7 +197,7 @@ kdf_hkdf(int argc, VALUE *argv, VALUE self) len = (size_t)NUM2LONG(kwargs[2]); if (len > LONG_MAX) rb_raise(rb_eArgError, "length must be non-negative"); - md = ossl_evp_get_digestbyname(kwargs[3]); + md = ossl_evp_md_fetch(kwargs[3], &md_holder); str = rb_str_new(NULL, (long)len); pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL); diff --git a/ext/openssl/ossl_ns_spki.c b/ext/openssl/ossl_ns_spki.c index ffed3a64a602f6..51ec8532c3ee8c 100644 --- a/ext/openssl/ossl_ns_spki.c +++ b/ext/openssl/ossl_ns_spki.c @@ -283,13 +283,13 @@ ossl_spki_sign(VALUE self, VALUE key, VALUE digest) NETSCAPE_SPKI *spki; EVP_PKEY *pkey; const EVP_MD *md; + VALUE md_holder; pkey = GetPrivPKeyPtr(key); /* NO NEED TO DUP */ - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); GetSPKI(self, spki); - if (!NETSCAPE_SPKI_sign(spki, pkey, md)) { - ossl_raise(eSPKIError, NULL); - } + if (!NETSCAPE_SPKI_sign(spki, pkey, md)) + ossl_raise(eSPKIError, "NETSCAPE_SPKI_sign"); return self; } diff --git a/ext/openssl/ossl_ocsp.c b/ext/openssl/ossl_ocsp.c index 5a3a71cae0d83e..84d38760e5c0de 100644 --- a/ext/openssl/ossl_ocsp.c +++ b/ext/openssl/ossl_ocsp.c @@ -369,7 +369,7 @@ ossl_ocspreq_get_certid(VALUE self) static VALUE ossl_ocspreq_sign(int argc, VALUE *argv, VALUE self) { - VALUE signer_cert, signer_key, certs, flags, digest; + VALUE signer_cert, signer_key, certs, flags, digest, md_holder; OCSP_REQUEST *req; X509 *signer; EVP_PKEY *key; @@ -384,10 +384,7 @@ ossl_ocspreq_sign(int argc, VALUE *argv, VALUE self) key = GetPrivPKeyPtr(signer_key); if (!NIL_P(flags)) flg = NUM2INT(flags); - if (NIL_P(digest)) - md = NULL; - else - md = ossl_evp_get_digestbyname(digest); + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); if (NIL_P(certs)) flg |= OCSP_NOCERTS; else @@ -395,7 +392,8 @@ ossl_ocspreq_sign(int argc, VALUE *argv, VALUE self) ret = OCSP_request_sign(req, signer, key, md, x509s, flg); sk_X509_pop_free(x509s, X509_free); - if (!ret) ossl_raise(eOCSPError, NULL); + if (!ret) + ossl_raise(eOCSPError, "OCSP_request_sign"); return self; } @@ -1000,7 +998,7 @@ ossl_ocspbres_find_response(VALUE self, VALUE target) static VALUE ossl_ocspbres_sign(int argc, VALUE *argv, VALUE self) { - VALUE signer_cert, signer_key, certs, flags, digest; + VALUE signer_cert, signer_key, certs, flags, digest, md_holder; OCSP_BASICRESP *bs; X509 *signer; EVP_PKEY *key; @@ -1015,10 +1013,7 @@ ossl_ocspbres_sign(int argc, VALUE *argv, VALUE self) key = GetPrivPKeyPtr(signer_key); if (!NIL_P(flags)) flg = NUM2INT(flags); - if (NIL_P(digest)) - md = NULL; - else - md = ossl_evp_get_digestbyname(digest); + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); if (NIL_P(certs)) flg |= OCSP_NOCERTS; else @@ -1026,7 +1021,8 @@ ossl_ocspbres_sign(int argc, VALUE *argv, VALUE self) ret = OCSP_basic_sign(bs, signer, key, md, x509s, flg); sk_X509_pop_free(x509s, X509_free); - if (!ret) ossl_raise(eOCSPError, NULL); + if (!ret) + ossl_raise(eOCSPError, "OCSP_basic_sign"); return self; } @@ -1460,10 +1456,11 @@ ossl_ocspcid_initialize(int argc, VALUE *argv, VALUE self) else { X509 *x509s, *x509i; const EVP_MD *md; + VALUE md_holder; x509s = GetX509CertPtr(subject); /* NO NEED TO DUP */ x509i = GetX509CertPtr(issuer); /* NO NEED TO DUP */ - md = !NIL_P(digest) ? ossl_evp_get_digestbyname(digest) : NULL; + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); newid = OCSP_cert_to_id(md, x509s, x509i); if (!newid) diff --git a/ext/openssl/ossl_pkcs12.c b/ext/openssl/ossl_pkcs12.c index 0b7469e673f77c..f76e1625f596a2 100644 --- a/ext/openssl/ossl_pkcs12.c +++ b/ext/openssl/ossl_pkcs12.c @@ -271,7 +271,7 @@ static VALUE pkcs12_set_mac(int argc, VALUE *argv, VALUE self) { PKCS12 *p12; - VALUE pass, salt, iter, md_name; + VALUE pass, salt, iter, md_name, md_holder = Qnil; int iter_i = 0; const EVP_MD *md_type = NULL; @@ -285,7 +285,7 @@ pkcs12_set_mac(int argc, VALUE *argv, VALUE self) if (!NIL_P(iter)) iter_i = NUM2INT(iter); if (!NIL_P(md_name)) - md_type = ossl_evp_get_digestbyname(md_name); + md_type = ossl_evp_md_fetch(md_name, &md_holder); if (!PKCS12_set_mac(p12, RSTRING_PTR(pass), RSTRING_LENINT(pass), !NIL_P(salt) ? (unsigned char *)RSTRING_PTR(salt) : NULL, diff --git a/ext/openssl/ossl_pkcs7.c b/ext/openssl/ossl_pkcs7.c index 910ef9665c7919..0fcae1971cfa4b 100644 --- a/ext/openssl/ossl_pkcs7.c +++ b/ext/openssl/ossl_pkcs7.c @@ -68,6 +68,7 @@ static VALUE cPKCS7; static VALUE cPKCS7Signer; static VALUE cPKCS7Recipient; static VALUE ePKCS7Error; +static ID id_md_holder, id_cipher_holder; static void ossl_pkcs7_free(void *ptr) @@ -312,7 +313,7 @@ ossl_pkcs7_s_sign(int argc, VALUE *argv, VALUE klass) static VALUE ossl_pkcs7_s_encrypt(int argc, VALUE *argv, VALUE klass) { - VALUE certs, data, cipher, flags; + VALUE certs, data, cipher, flags, cipher_holder; STACK_OF(X509) *x509s; BIO *in; const EVP_CIPHER *ciph; @@ -326,7 +327,7 @@ ossl_pkcs7_s_encrypt(int argc, VALUE *argv, VALUE klass) "cipher must be specified. Before version 3.3, " \ "the default cipher was RC2-40-CBC."); } - ciph = ossl_evp_get_cipherbyname(cipher); + ciph = ossl_evp_cipher_fetch(cipher, &cipher_holder); flg = NIL_P(flags) ? 0 : NUM2INT(flags); ret = NewPKCS7(cPKCS7); in = ossl_obj2bio(&data); @@ -343,6 +344,7 @@ ossl_pkcs7_s_encrypt(int argc, VALUE *argv, VALUE klass) BIO_free(in); SetPKCS7(ret, p7); ossl_pkcs7_set_data(ret, data); + rb_ivar_set(ret, id_cipher_holder, cipher_holder); sk_X509_pop_free(x509s, X509_free); return ret; @@ -535,11 +537,14 @@ static VALUE ossl_pkcs7_set_cipher(VALUE self, VALUE cipher) { PKCS7 *pkcs7; + const EVP_CIPHER *ciph; + VALUE cipher_holder; GetPKCS7(self, pkcs7); - if (!PKCS7_set_cipher(pkcs7, ossl_evp_get_cipherbyname(cipher))) { - ossl_raise(ePKCS7Error, NULL); - } + ciph = ossl_evp_cipher_fetch(cipher, &cipher_holder); + if (!PKCS7_set_cipher(pkcs7, ciph)) + ossl_raise(ePKCS7Error, "PKCS7_set_cipher"); + rb_ivar_set(self, id_cipher_holder, cipher_holder); return cipher; } @@ -968,14 +973,15 @@ ossl_pkcs7si_initialize(VALUE self, VALUE cert, VALUE key, VALUE digest) EVP_PKEY *pkey; X509 *x509; const EVP_MD *md; + VALUE md_holder; pkey = GetPrivPKeyPtr(key); /* NO NEED TO DUP */ x509 = GetX509CertPtr(cert); /* NO NEED TO DUP */ - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); GetPKCS7si(self, p7si); - if (!(PKCS7_SIGNER_INFO_set(p7si, x509, pkey, md))) { - ossl_raise(ePKCS7Error, NULL); - } + if (!(PKCS7_SIGNER_INFO_set(p7si, x509, pkey, md))) + ossl_raise(ePKCS7Error, "PKCS7_SIGNER_INFO_set"); + rb_ivar_set(self, id_md_holder, md_holder); return self; } @@ -1161,4 +1167,7 @@ Init_ossl_pkcs7(void) DefPKCS7Const(BINARY); DefPKCS7Const(NOATTR); DefPKCS7Const(NOSMIMECAP); + + id_md_holder = rb_intern_const("EVP_MD_holder"); + id_cipher_holder = rb_intern_const("EVP_CIPHER_holder"); } diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index 37c132ef2ea680..2d66c6ce625d34 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -814,14 +814,14 @@ VALUE ossl_pkey_export_traditional(int argc, VALUE *argv, VALUE self, int to_der) { EVP_PKEY *pkey; - VALUE cipher, pass; + VALUE cipher, pass, cipher_holder; const EVP_CIPHER *enc = NULL; BIO *bio; GetPKey(self, pkey); rb_scan_args(argc, argv, "02", &cipher, &pass); if (!NIL_P(cipher)) { - enc = ossl_evp_get_cipherbyname(cipher); + enc = ossl_evp_cipher_fetch(cipher, &cipher_holder); pass = ossl_pem_passwd_value(pass); } @@ -849,7 +849,7 @@ static VALUE do_pkcs8_export(int argc, VALUE *argv, VALUE self, int to_der) { EVP_PKEY *pkey; - VALUE cipher, pass; + VALUE cipher, pass, cipher_holder; const EVP_CIPHER *enc = NULL; BIO *bio; @@ -860,7 +860,7 @@ do_pkcs8_export(int argc, VALUE *argv, VALUE self, int to_der) * TODO: EncryptedPrivateKeyInfo actually has more options. * Should they be exposed? */ - enc = ossl_evp_get_cipherbyname(cipher); + enc = ossl_evp_cipher_fetch(cipher, &cipher_holder); pass = ossl_pem_passwd_value(pass); } @@ -1111,7 +1111,7 @@ static VALUE ossl_pkey_sign(int argc, VALUE *argv, VALUE self) { EVP_PKEY *pkey; - VALUE digest, data, options, sig; + VALUE digest, data, options, sig, md_holder; const EVP_MD *md = NULL; EVP_MD_CTX *ctx; EVP_PKEY_CTX *pctx; @@ -1121,7 +1121,7 @@ ossl_pkey_sign(int argc, VALUE *argv, VALUE self) pkey = GetPrivPKeyPtr(self); rb_scan_args(argc, argv, "21", &digest, &data, &options); if (!NIL_P(digest)) - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(data); ctx = EVP_MD_CTX_new(); @@ -1190,7 +1190,7 @@ static VALUE ossl_pkey_verify(int argc, VALUE *argv, VALUE self) { EVP_PKEY *pkey; - VALUE digest, sig, data, options; + VALUE digest, sig, data, options, md_holder; const EVP_MD *md = NULL; EVP_MD_CTX *ctx; EVP_PKEY_CTX *pctx; @@ -1200,7 +1200,7 @@ ossl_pkey_verify(int argc, VALUE *argv, VALUE self) rb_scan_args(argc, argv, "31", &digest, &sig, &data, &options); ossl_pkey_check_public_key(pkey); if (!NIL_P(digest)) - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(sig); StringValue(data); @@ -1269,7 +1269,7 @@ static VALUE ossl_pkey_sign_raw(int argc, VALUE *argv, VALUE self) { EVP_PKEY *pkey; - VALUE digest, data, options, sig; + VALUE digest, data, options, sig, md_holder; const EVP_MD *md = NULL; EVP_PKEY_CTX *ctx; size_t outlen; @@ -1278,7 +1278,7 @@ ossl_pkey_sign_raw(int argc, VALUE *argv, VALUE self) GetPKey(self, pkey); rb_scan_args(argc, argv, "21", &digest, &data, &options); if (!NIL_P(digest)) - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(data); ctx = EVP_PKEY_CTX_new(pkey, /* engine */NULL); @@ -1345,7 +1345,7 @@ static VALUE ossl_pkey_verify_raw(int argc, VALUE *argv, VALUE self) { EVP_PKEY *pkey; - VALUE digest, sig, data, options; + VALUE digest, sig, data, options, md_holder; const EVP_MD *md = NULL; EVP_PKEY_CTX *ctx; int state, ret; @@ -1354,7 +1354,7 @@ ossl_pkey_verify_raw(int argc, VALUE *argv, VALUE self) rb_scan_args(argc, argv, "31", &digest, &sig, &data, &options); ossl_pkey_check_public_key(pkey); if (!NIL_P(digest)) - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(sig); StringValue(data); @@ -1408,7 +1408,7 @@ static VALUE ossl_pkey_verify_recover(int argc, VALUE *argv, VALUE self) { EVP_PKEY *pkey; - VALUE digest, sig, options, out; + VALUE digest, sig, options, out, md_holder; const EVP_MD *md = NULL; EVP_PKEY_CTX *ctx; int state; @@ -1418,7 +1418,7 @@ ossl_pkey_verify_recover(int argc, VALUE *argv, VALUE self) rb_scan_args(argc, argv, "21", &digest, &sig, &options); ossl_pkey_check_public_key(pkey); if (!NIL_P(digest)) - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(sig); ctx = EVP_PKEY_CTX_new(pkey, /* engine */NULL); @@ -1718,7 +1718,16 @@ Init_ossl_pkey(void) /* Document-class: OpenSSL::PKey::PKeyError * - *Raised when errors occur during PKey#sign or PKey#verify. + * Raised when errors occur during PKey#sign or PKey#verify. + * + * Before version 4.0.0, OpenSSL::PKey::PKeyError had the following + * subclasses. These subclasses have been removed and the constants are + * now defined as aliases of OpenSSL::PKey::PKeyError. + * + * * OpenSSL::PKey::DHError + * * OpenSSL::PKey::DSAError + * * OpenSSL::PKey::ECError + * * OpenSSL::PKey::RSAError */ ePKeyError = rb_define_class_under(mPKey, "PKeyError", eOSSLError); diff --git a/ext/openssl/ossl_pkey_dh.c b/ext/openssl/ossl_pkey_dh.c index 561007fec8b058..79509bef3d2f82 100644 --- a/ext/openssl/ossl_pkey_dh.c +++ b/ext/openssl/ossl_pkey_dh.c @@ -22,14 +22,13 @@ GetPKeyDH((obj), _pkey); \ (dh) = EVP_PKEY_get0_DH(_pkey); \ if ((dh) == NULL) \ - ossl_raise(eDHError, "failed to get DH from EVP_PKEY"); \ + ossl_raise(ePKeyError, "failed to get DH from EVP_PKEY"); \ } while (0) /* * Classes */ VALUE cDH; -static VALUE eDHError; /* * Private @@ -94,7 +93,7 @@ ossl_dh_initialize(int argc, VALUE *argv, VALUE self) #else dh = DH_new(); if (!dh) - ossl_raise(eDHError, "DH_new"); + ossl_raise(ePKeyError, "DH_new"); goto legacy; #endif } @@ -114,12 +113,12 @@ ossl_dh_initialize(int argc, VALUE *argv, VALUE self) pkey = ossl_pkey_read_generic(in, Qnil); BIO_free(in); if (!pkey) - ossl_raise(eDHError, "could not parse pkey"); + ossl_raise(ePKeyError, "could not parse pkey"); type = EVP_PKEY_base_id(pkey); if (type != EVP_PKEY_DH) { EVP_PKEY_free(pkey); - rb_raise(eDHError, "incorrect pkey type: %s", OBJ_nid2sn(type)); + rb_raise(ePKeyError, "incorrect pkey type: %s", OBJ_nid2sn(type)); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -130,7 +129,7 @@ ossl_dh_initialize(int argc, VALUE *argv, VALUE self) if (!pkey || EVP_PKEY_assign_DH(pkey, dh) != 1) { EVP_PKEY_free(pkey); DH_free(dh); - ossl_raise(eDHError, "EVP_PKEY_assign_DH"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_DH"); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -152,7 +151,7 @@ ossl_dh_initialize_copy(VALUE self, VALUE other) dh = DHparams_dup(dh_other); if (!dh) - ossl_raise(eDHError, "DHparams_dup"); + ossl_raise(ePKeyError, "DHparams_dup"); DH_get0_key(dh_other, &pub, &priv); if (pub) { @@ -162,7 +161,7 @@ ossl_dh_initialize_copy(VALUE self, VALUE other) if (!pub2 || (priv && !priv2)) { BN_clear_free(pub2); BN_clear_free(priv2); - ossl_raise(eDHError, "BN_dup"); + ossl_raise(ePKeyError, "BN_dup"); } DH_set0_key(dh, pub2, priv2); } @@ -171,7 +170,7 @@ ossl_dh_initialize_copy(VALUE self, VALUE other) if (!pkey || EVP_PKEY_assign_DH(pkey, dh) != 1) { EVP_PKEY_free(pkey); DH_free(dh); - ossl_raise(eDHError, "EVP_PKEY_assign_DH"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_DH"); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -250,11 +249,11 @@ ossl_dh_export(VALUE self) GetDH(self, dh); if (!(out = BIO_new(BIO_s_mem()))) { - ossl_raise(eDHError, NULL); + ossl_raise(ePKeyError, NULL); } if (!PEM_write_bio_DHparams(out, dh)) { BIO_free(out); - ossl_raise(eDHError, NULL); + ossl_raise(ePKeyError, NULL); } str = ossl_membio2str(out); @@ -284,11 +283,11 @@ ossl_dh_to_der(VALUE self) GetDH(self, dh); if((len = i2d_DHparams(dh, NULL)) <= 0) - ossl_raise(eDHError, NULL); + ossl_raise(ePKeyError, NULL); str = rb_str_new(0, len); p = (unsigned char *)RSTRING_PTR(str); if(i2d_DHparams(dh, &p) < 0) - ossl_raise(eDHError, NULL); + ossl_raise(ePKeyError, NULL); ossl_str_adjust(str, p); return str; @@ -315,7 +314,7 @@ ossl_dh_check_params(VALUE self) GetPKey(self, pkey); pctx = EVP_PKEY_CTX_new(pkey, /* engine */NULL); if (!pctx) - ossl_raise(eDHError, "EVP_PKEY_CTX_new"); + ossl_raise(ePKeyError, "EVP_PKEY_CTX_new"); ret = EVP_PKEY_param_check(pctx); EVP_PKEY_CTX_free(pctx); #else @@ -364,13 +363,6 @@ Init_ossl_dh(void) ePKeyError = rb_define_class_under(mPKey, "PKeyError", eOSSLError); #endif - /* Document-class: OpenSSL::PKey::DHError - * - * Generic exception that is raised if an operation on a DH PKey - * fails unexpectedly or in case an instantiation of an instance of DH - * fails due to non-conformant input data. - */ - eDHError = rb_define_class_under(mPKey, "DHError", ePKeyError); /* Document-class: OpenSSL::PKey::DH * * An implementation of the Diffie-Hellman key exchange protocol based on diff --git a/ext/openssl/ossl_pkey_dsa.c b/ext/openssl/ossl_pkey_dsa.c index cb38786b560c0a..34e1c7052165be 100644 --- a/ext/openssl/ossl_pkey_dsa.c +++ b/ext/openssl/ossl_pkey_dsa.c @@ -22,7 +22,7 @@ GetPKeyDSA((obj), _pkey); \ (dsa) = EVP_PKEY_get0_DSA(_pkey); \ if ((dsa) == NULL) \ - ossl_raise(eDSAError, "failed to get DSA from EVP_PKEY"); \ + ossl_raise(ePKeyError, "failed to get DSA from EVP_PKEY"); \ } while (0) static inline int @@ -43,7 +43,6 @@ DSA_PRIVATE(VALUE obj, OSSL_3_const DSA *dsa) * Classes */ VALUE cDSA; -static VALUE eDSAError; /* * Private @@ -105,7 +104,7 @@ ossl_dsa_initialize(int argc, VALUE *argv, VALUE self) #else dsa = DSA_new(); if (!dsa) - ossl_raise(eDSAError, "DSA_new"); + ossl_raise(ePKeyError, "DSA_new"); goto legacy; #endif } @@ -125,12 +124,12 @@ ossl_dsa_initialize(int argc, VALUE *argv, VALUE self) pkey = ossl_pkey_read_generic(in, pass); BIO_free(in); if (!pkey) - ossl_raise(eDSAError, "Neither PUB key nor PRIV key"); + ossl_raise(ePKeyError, "Neither PUB key nor PRIV key"); type = EVP_PKEY_base_id(pkey); if (type != EVP_PKEY_DSA) { EVP_PKEY_free(pkey); - rb_raise(eDSAError, "incorrect pkey type: %s", OBJ_nid2sn(type)); + rb_raise(ePKeyError, "incorrect pkey type: %s", OBJ_nid2sn(type)); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -141,7 +140,7 @@ ossl_dsa_initialize(int argc, VALUE *argv, VALUE self) if (!pkey || EVP_PKEY_assign_DSA(pkey, dsa) != 1) { EVP_PKEY_free(pkey); DSA_free(dsa); - ossl_raise(eDSAError, "EVP_PKEY_assign_DSA"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_DSA"); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -164,13 +163,13 @@ ossl_dsa_initialize_copy(VALUE self, VALUE other) (d2i_of_void *)d2i_DSAPrivateKey, (char *)dsa); if (!dsa_new) - ossl_raise(eDSAError, "ASN1_dup"); + ossl_raise(ePKeyError, "ASN1_dup"); pkey = EVP_PKEY_new(); if (!pkey || EVP_PKEY_assign_DSA(pkey, dsa_new) != 1) { EVP_PKEY_free(pkey); DSA_free(dsa_new); - ossl_raise(eDSAError, "EVP_PKEY_assign_DSA"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_DSA"); } RTYPEDDATA_DATA(self) = pkey; @@ -341,14 +340,6 @@ Init_ossl_dsa(void) ePKeyError = rb_define_class_under(mPKey, "PKeyError", eOSSLError); #endif - /* Document-class: OpenSSL::PKey::DSAError - * - * Generic exception that is raised if an operation on a DSA PKey - * fails unexpectedly or in case an instantiation of an instance of DSA - * fails due to non-conformant input data. - */ - eDSAError = rb_define_class_under(mPKey, "DSAError", ePKeyError); - /* Document-class: OpenSSL::PKey::DSA * * DSA, the Digital Signature Algorithm, is specified in NIST's diff --git a/ext/openssl/ossl_pkey_ec.c b/ext/openssl/ossl_pkey_ec.c index 8c97297a56193e..c063450a4f2ef3 100644 --- a/ext/openssl/ossl_pkey_ec.c +++ b/ext/openssl/ossl_pkey_ec.c @@ -23,7 +23,7 @@ static const rb_data_type_t ossl_ec_point_type; GetPKeyEC(obj, _pkey); \ (key) = EVP_PKEY_get0_EC_KEY(_pkey); \ if ((key) == NULL) \ - ossl_raise(eECError, "failed to get EC_KEY from EVP_PKEY"); \ + ossl_raise(ePKeyError, "failed to get EC_KEY from EVP_PKEY"); \ } while (0) #define GetECGroup(obj, group) do { \ @@ -43,7 +43,6 @@ static const rb_data_type_t ossl_ec_point_type; } while (0) VALUE cEC; -static VALUE eECError; static VALUE cEC_GROUP; static VALUE eEC_GROUP; static VALUE cEC_POINT; @@ -71,20 +70,20 @@ ec_key_new_from_group(VALUE arg) GetECGroup(arg, group); if (!(ec = EC_KEY_new())) - ossl_raise(eECError, NULL); + ossl_raise(ePKeyError, NULL); if (!EC_KEY_set_group(ec, group)) { EC_KEY_free(ec); - ossl_raise(eECError, NULL); + ossl_raise(ePKeyError, NULL); } } else { int nid = OBJ_sn2nid(StringValueCStr(arg)); if (nid == NID_undef) - ossl_raise(eECError, "invalid curve name"); + ossl_raise(ePKeyError, "invalid curve name"); if (!(ec = EC_KEY_new_by_curve_name(nid))) - ossl_raise(eECError, NULL); + ossl_raise(ePKeyError, NULL); EC_KEY_set_asn1_flag(ec, OPENSSL_EC_NAMED_CURVE); EC_KEY_set_conv_form(ec, POINT_CONVERSION_UNCOMPRESSED); @@ -114,12 +113,12 @@ ossl_ec_key_s_generate(VALUE klass, VALUE arg) if (!pkey || EVP_PKEY_assign_EC_KEY(pkey, ec) != 1) { EVP_PKEY_free(pkey); EC_KEY_free(ec); - ossl_raise(eECError, "EVP_PKEY_assign_EC_KEY"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_EC_KEY"); } RTYPEDDATA_DATA(obj) = pkey; if (!EC_KEY_generate_key(ec)) - ossl_raise(eECError, "EC_KEY_generate_key"); + ossl_raise(ePKeyError, "EC_KEY_generate_key"); return obj; } @@ -154,7 +153,7 @@ static VALUE ossl_ec_key_initialize(int argc, VALUE *argv, VALUE self) "without arguments; pkeys are immutable with OpenSSL 3.0"); #else if (!(ec = EC_KEY_new())) - ossl_raise(eECError, "EC_KEY_new"); + ossl_raise(ePKeyError, "EC_KEY_new"); goto legacy; #endif } @@ -178,7 +177,7 @@ static VALUE ossl_ec_key_initialize(int argc, VALUE *argv, VALUE self) type = EVP_PKEY_base_id(pkey); if (type != EVP_PKEY_EC) { EVP_PKEY_free(pkey); - rb_raise(eECError, "incorrect pkey type: %s", OBJ_nid2sn(type)); + rb_raise(ePKeyError, "incorrect pkey type: %s", OBJ_nid2sn(type)); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -188,7 +187,7 @@ static VALUE ossl_ec_key_initialize(int argc, VALUE *argv, VALUE self) if (!pkey || EVP_PKEY_assign_EC_KEY(pkey, ec) != 1) { EVP_PKEY_free(pkey); EC_KEY_free(ec); - ossl_raise(eECError, "EVP_PKEY_assign_EC_KEY"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_EC_KEY"); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -209,12 +208,12 @@ ossl_ec_key_initialize_copy(VALUE self, VALUE other) ec_new = EC_KEY_dup(ec); if (!ec_new) - ossl_raise(eECError, "EC_KEY_dup"); + ossl_raise(ePKeyError, "EC_KEY_dup"); pkey = EVP_PKEY_new(); if (!pkey || EVP_PKEY_assign_EC_KEY(pkey, ec_new) != 1) { EC_KEY_free(ec_new); - ossl_raise(eECError, "EVP_PKEY_assign_EC_KEY"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_EC_KEY"); } RTYPEDDATA_DATA(self) = pkey; @@ -263,7 +262,7 @@ ossl_ec_key_set_group(VALUE self, VALUE group_v) GetECGroup(group_v, group); if (EC_KEY_set_group(ec, group) != 1) - ossl_raise(eECError, "EC_KEY_set_group"); + ossl_raise(ePKeyError, "EC_KEY_set_group"); return group_v; #endif @@ -313,7 +312,7 @@ static VALUE ossl_ec_key_set_private_key(VALUE self, VALUE private_key) break; /* fallthrough */ default: - ossl_raise(eECError, "EC_KEY_set_private_key"); + ossl_raise(ePKeyError, "EC_KEY_set_private_key"); } return private_key; @@ -364,7 +363,7 @@ static VALUE ossl_ec_key_set_public_key(VALUE self, VALUE public_key) break; /* fallthrough */ default: - ossl_raise(eECError, "EC_KEY_set_public_key"); + ossl_raise(ePKeyError, "EC_KEY_set_public_key"); } return public_key; @@ -468,7 +467,7 @@ ossl_ec_key_export(int argc, VALUE *argv, VALUE self) GetEC(self, ec); if (EC_KEY_get0_public_key(ec) == NULL) - ossl_raise(eECError, "can't export - no public key set"); + ossl_raise(ePKeyError, "can't export - no public key set"); if (EC_KEY_get0_private_key(ec)) return ossl_pkey_export_traditional(argc, argv, self, 0); else @@ -496,7 +495,7 @@ ossl_ec_key_to_der(VALUE self) GetEC(self, ec); if (EC_KEY_get0_public_key(ec) == NULL) - ossl_raise(eECError, "can't export - no public key set"); + ossl_raise(ePKeyError, "can't export - no public key set"); if (EC_KEY_get0_private_key(ec)) return ossl_pkey_export_traditional(0, NULL, self, 1); else @@ -525,7 +524,7 @@ static VALUE ossl_ec_key_generate_key(VALUE self) GetEC(self, ec); if (EC_KEY_generate_key(ec) != 1) - ossl_raise(eECError, "EC_KEY_generate_key"); + ossl_raise(ePKeyError, "EC_KEY_generate_key"); return self; #endif @@ -550,18 +549,18 @@ static VALUE ossl_ec_key_check_key(VALUE self) GetEC(self, ec); pctx = EVP_PKEY_CTX_new(pkey, /* engine */NULL); if (!pctx) - ossl_raise(eECError, "EVP_PKEY_CTX_new"); + ossl_raise(ePKeyError, "EVP_PKEY_CTX_new"); if (EC_KEY_get0_private_key(ec) != NULL) { if (EVP_PKEY_check(pctx) != 1) { EVP_PKEY_CTX_free(pctx); - ossl_raise(eECError, "EVP_PKEY_check"); + ossl_raise(ePKeyError, "EVP_PKEY_check"); } } else { if (EVP_PKEY_public_check(pctx) != 1) { EVP_PKEY_CTX_free(pctx); - ossl_raise(eECError, "EVP_PKEY_public_check"); + ossl_raise(ePKeyError, "EVP_PKEY_public_check"); } } @@ -571,7 +570,7 @@ static VALUE ossl_ec_key_check_key(VALUE self) GetEC(self, ec); if (EC_KEY_check_key(ec) != 1) - ossl_raise(eECError, "EC_KEY_check_key"); + ossl_raise(ePKeyError, "EC_KEY_check_key"); #endif return Qtrue; @@ -1108,7 +1107,7 @@ static VALUE ossl_ec_group_to_string(VALUE self, int format) if (i != 1) { BIO_free(out); - ossl_raise(eECError, NULL); + ossl_raise(ePKeyError, NULL); } str = ossl_membio2str(out); @@ -1536,8 +1535,6 @@ void Init_ossl_ec(void) ePKeyError = rb_define_class_under(mPKey, "PKeyError", eOSSLError); #endif - eECError = rb_define_class_under(mPKey, "ECError", ePKeyError); - /* * Document-class: OpenSSL::PKey::EC * diff --git a/ext/openssl/ossl_pkey_rsa.c b/ext/openssl/ossl_pkey_rsa.c index b2983d3b53cfc2..ed65121acc8849 100644 --- a/ext/openssl/ossl_pkey_rsa.c +++ b/ext/openssl/ossl_pkey_rsa.c @@ -22,7 +22,7 @@ GetPKeyRSA((obj), _pkey); \ (rsa) = EVP_PKEY_get0_RSA(_pkey); \ if ((rsa) == NULL) \ - ossl_raise(eRSAError, "failed to get RSA from EVP_PKEY"); \ + ossl_raise(ePKeyError, "failed to get RSA from EVP_PKEY"); \ } while (0) static inline int @@ -44,7 +44,6 @@ RSA_PRIVATE(VALUE obj, OSSL_3_const RSA *rsa) * Classes */ VALUE cRSA; -static VALUE eRSAError; /* * Private @@ -98,7 +97,7 @@ ossl_rsa_initialize(int argc, VALUE *argv, VALUE self) #else rsa = RSA_new(); if (!rsa) - ossl_raise(eRSAError, "RSA_new"); + ossl_raise(ePKeyError, "RSA_new"); goto legacy; #endif } @@ -121,12 +120,12 @@ ossl_rsa_initialize(int argc, VALUE *argv, VALUE self) pkey = ossl_pkey_read_generic(in, pass); BIO_free(in); if (!pkey) - ossl_raise(eRSAError, "Neither PUB key nor PRIV key"); + ossl_raise(ePKeyError, "Neither PUB key nor PRIV key"); type = EVP_PKEY_base_id(pkey); if (type != EVP_PKEY_RSA) { EVP_PKEY_free(pkey); - rb_raise(eRSAError, "incorrect pkey type: %s", OBJ_nid2sn(type)); + rb_raise(ePKeyError, "incorrect pkey type: %s", OBJ_nid2sn(type)); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -137,7 +136,7 @@ ossl_rsa_initialize(int argc, VALUE *argv, VALUE self) if (!pkey || EVP_PKEY_assign_RSA(pkey, rsa) != 1) { EVP_PKEY_free(pkey); RSA_free(rsa); - ossl_raise(eRSAError, "EVP_PKEY_assign_RSA"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_RSA"); } RTYPEDDATA_DATA(self) = pkey; return self; @@ -160,12 +159,12 @@ ossl_rsa_initialize_copy(VALUE self, VALUE other) (d2i_of_void *)d2i_RSAPrivateKey, (char *)rsa); if (!rsa_new) - ossl_raise(eRSAError, "ASN1_dup"); + ossl_raise(ePKeyError, "ASN1_dup"); pkey = EVP_PKEY_new(); if (!pkey || EVP_PKEY_assign_RSA(pkey, rsa_new) != 1) { RSA_free(rsa_new); - ossl_raise(eRSAError, "EVP_PKEY_assign_RSA"); + ossl_raise(ePKeyError, "EVP_PKEY_assign_RSA"); } RTYPEDDATA_DATA(self) = pkey; @@ -320,7 +319,7 @@ ossl_rsa_to_der(VALUE self) * Signs _data_ using the Probabilistic Signature Scheme (RSA-PSS) and returns * the calculated signature. * - * RSAError will be raised if an error occurs. + * PKeyError will be raised if an error occurs. * * See #verify_pss for the verification operation. * @@ -349,7 +348,7 @@ ossl_rsa_to_der(VALUE self) static VALUE ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) { - VALUE digest, data, options, kwargs[2], signature; + VALUE digest, data, options, kwargs[2], signature, mgf1md_holder, md_holder; static ID kwargs_ids[2]; EVP_PKEY *pkey; EVP_PKEY_CTX *pkey_ctx; @@ -370,11 +369,11 @@ ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) salt_len = -1; /* RSA_PSS_SALTLEN_DIGEST */ else salt_len = NUM2INT(kwargs[0]); - mgf1md = ossl_evp_get_digestbyname(kwargs[1]); + mgf1md = ossl_evp_md_fetch(kwargs[1], &mgf1md_holder); pkey = GetPrivPKeyPtr(self); buf_len = EVP_PKEY_size(pkey); - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(data); signature = rb_str_new(NULL, (long)buf_len); @@ -407,7 +406,7 @@ ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) err: EVP_MD_CTX_free(md_ctx); - ossl_raise(eRSAError, NULL); + ossl_raise(ePKeyError, NULL); } /* @@ -417,7 +416,7 @@ ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) * Verifies _data_ using the Probabilistic Signature Scheme (RSA-PSS). * * The return value is +true+ if the signature is valid, +false+ otherwise. - * RSAError will be raised if an error occurs. + * PKeyError will be raised if an error occurs. * * See #sign_pss for the signing operation and an example code. * @@ -436,7 +435,7 @@ ossl_rsa_sign_pss(int argc, VALUE *argv, VALUE self) static VALUE ossl_rsa_verify_pss(int argc, VALUE *argv, VALUE self) { - VALUE digest, signature, data, options, kwargs[2]; + VALUE digest, signature, data, options, kwargs[2], mgf1md_holder, md_holder; static ID kwargs_ids[2]; EVP_PKEY *pkey; EVP_PKEY_CTX *pkey_ctx; @@ -456,10 +455,10 @@ ossl_rsa_verify_pss(int argc, VALUE *argv, VALUE self) salt_len = -1; /* RSA_PSS_SALTLEN_DIGEST */ else salt_len = NUM2INT(kwargs[0]); - mgf1md = ossl_evp_get_digestbyname(kwargs[1]); + mgf1md = ossl_evp_md_fetch(kwargs[1], &mgf1md_holder); GetPKey(self, pkey); - md = ossl_evp_get_digestbyname(digest); + md = ossl_evp_md_fetch(digest, &md_holder); StringValue(signature); StringValue(data); @@ -485,22 +484,21 @@ ossl_rsa_verify_pss(int argc, VALUE *argv, VALUE self) result = EVP_DigestVerifyFinal(md_ctx, (unsigned char *)RSTRING_PTR(signature), RSTRING_LEN(signature)); + EVP_MD_CTX_free(md_ctx); switch (result) { case 0: ossl_clear_error(); - EVP_MD_CTX_free(md_ctx); return Qfalse; case 1: - EVP_MD_CTX_free(md_ctx); return Qtrue; default: - goto err; + ossl_raise(ePKeyError, "EVP_DigestVerifyFinal"); } err: EVP_MD_CTX_free(md_ctx); - ossl_raise(eRSAError, NULL); + ossl_raise(ePKeyError, NULL); } /* @@ -544,14 +542,6 @@ Init_ossl_rsa(void) ePKeyError = rb_define_class_under(mPKey, "PKeyError", eOSSLError); #endif - /* Document-class: OpenSSL::PKey::RSAError - * - * Generic exception that is raised if an operation on an RSA PKey - * fails unexpectedly or in case an instantiation of an instance of RSA - * fails due to non-conformant input data. - */ - eRSAError = rb_define_class_under(mPKey, "RSAError", ePKeyError); - /* Document-class: OpenSSL::PKey::RSA * * RSA is an asymmetric public key algorithm that has been formalized in diff --git a/ext/openssl/ossl_ts.c b/ext/openssl/ossl_ts.c index c7d2bd271b94c4..3c505b64a9f6a5 100644 --- a/ext/openssl/ossl_ts.c +++ b/ext/openssl/ossl_ts.c @@ -1155,9 +1155,14 @@ ossl_tsfac_time_cb(struct TS_resp_ctx *ctx, void *data, time_t *sec, long *usec) } static VALUE -ossl_evp_get_digestbyname_i(VALUE arg) +ossl_evp_md_fetch_i(VALUE args_) { - return (VALUE)ossl_evp_get_digestbyname(arg); + VALUE *args = (VALUE *)args_, md_holder; + const EVP_MD *md; + + md = ossl_evp_md_fetch(args[1], &md_holder); + rb_ary_push(args[0], md_holder); + return (VALUE)md; } static VALUE @@ -1193,7 +1198,8 @@ ossl_obj2bio_i(VALUE arg) static VALUE ossl_tsfac_create_ts(VALUE self, VALUE key, VALUE certificate, VALUE request) { - VALUE serial_number, def_policy_id, gen_time, additional_certs, allowed_digests; + VALUE serial_number, def_policy_id, gen_time, additional_certs, + allowed_digests, allowed_digests_tmp = Qnil; VALUE str; STACK_OF(X509) *inter_certs; VALUE tsresp, ret = Qnil; @@ -1270,16 +1276,18 @@ ossl_tsfac_create_ts(VALUE self, VALUE key, VALUE certificate, VALUE request) allowed_digests = ossl_tsfac_get_allowed_digests(self); if (rb_obj_is_kind_of(allowed_digests, rb_cArray)) { - int i; - VALUE rbmd; - const EVP_MD *md; - - for (i = 0; i < RARRAY_LEN(allowed_digests); i++) { - rbmd = rb_ary_entry(allowed_digests, i); - md = (const EVP_MD *)rb_protect(ossl_evp_get_digestbyname_i, rbmd, &status); + allowed_digests_tmp = rb_ary_new_capa(RARRAY_LEN(allowed_digests)); + for (long i = 0; i < RARRAY_LEN(allowed_digests); i++) { + VALUE args[] = { + allowed_digests_tmp, + rb_ary_entry(allowed_digests, i), + }; + const EVP_MD *md = (const EVP_MD *)rb_protect(ossl_evp_md_fetch_i, + (VALUE)args, &status); if (status) goto end; - TS_RESP_CTX_add_md(ctx, md); + if (!TS_RESP_CTX_add_md(ctx, md)) + goto end; } } @@ -1293,6 +1301,7 @@ ossl_tsfac_create_ts(VALUE self, VALUE key, VALUE certificate, VALUE request) response = TS_RESP_create_response(ctx, req_bio); BIO_free(req_bio); + RB_GC_GUARD(allowed_digests_tmp); if (!response) { err_msg = "Error during response generation"; diff --git a/ext/openssl/ossl_x509cert.c b/ext/openssl/ossl_x509cert.c index 30e3c617531bde..c7653031b4bc1f 100644 --- a/ext/openssl/ossl_x509cert.c +++ b/ext/openssl/ossl_x509cert.c @@ -535,17 +535,14 @@ ossl_x509_sign(VALUE self, VALUE key, VALUE digest) X509 *x509; EVP_PKEY *pkey; const EVP_MD *md; + VALUE md_holder; pkey = GetPrivPKeyPtr(key); /* NO NEED TO DUP */ - if (NIL_P(digest)) { - md = NULL; /* needed for some key types, e.g. Ed25519 */ - } else { - md = ossl_evp_get_digestbyname(digest); - } + /* NULL needed for some key types, e.g. Ed25519 */ + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); GetX509(self, x509); - if (!X509_sign(x509, pkey, md)) { - ossl_raise(eX509CertError, NULL); - } + if (!X509_sign(x509, pkey, md)) + ossl_raise(eX509CertError, "X509_sign"); return self; } diff --git a/ext/openssl/ossl_x509crl.c b/ext/openssl/ossl_x509crl.c index 52174d1711487b..b9ee5f05692b52 100644 --- a/ext/openssl/ossl_x509crl.c +++ b/ext/openssl/ossl_x509crl.c @@ -349,17 +349,14 @@ ossl_x509crl_sign(VALUE self, VALUE key, VALUE digest) X509_CRL *crl; EVP_PKEY *pkey; const EVP_MD *md; + VALUE md_holder; GetX509CRL(self, crl); pkey = GetPrivPKeyPtr(key); /* NO NEED TO DUP */ - if (NIL_P(digest)) { - md = NULL; /* needed for some key types, e.g. Ed25519 */ - } else { - md = ossl_evp_get_digestbyname(digest); - } - if (!X509_CRL_sign(crl, pkey, md)) { - ossl_raise(eX509CRLError, NULL); - } + /* NULL needed for some key types, e.g. Ed25519 */ + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); + if (!X509_CRL_sign(crl, pkey, md)) + ossl_raise(eX509CRLError, "X509_CRL_sign"); return self; } diff --git a/ext/openssl/ossl_x509req.c b/ext/openssl/ossl_x509req.c index b4c29f877e8536..eae57969241954 100644 --- a/ext/openssl/ossl_x509req.c +++ b/ext/openssl/ossl_x509req.c @@ -312,17 +312,14 @@ ossl_x509req_sign(VALUE self, VALUE key, VALUE digest) X509_REQ *req; EVP_PKEY *pkey; const EVP_MD *md; + VALUE md_holder; GetX509Req(self, req); pkey = GetPrivPKeyPtr(key); /* NO NEED TO DUP */ - if (NIL_P(digest)) { - md = NULL; /* needed for some key types, e.g. Ed25519 */ - } else { - md = ossl_evp_get_digestbyname(digest); - } - if (!X509_REQ_sign(req, pkey, md)) { - ossl_raise(eX509ReqError, NULL); - } + /* NULL needed for some key types, e.g. Ed25519 */ + md = NIL_P(digest) ? NULL : ossl_evp_md_fetch(digest, &md_holder); + if (!X509_REQ_sign(req, pkey, md)) + ossl_raise(eX509ReqError, "X509_REQ_sign"); return self; } diff --git a/prism/prism.c b/prism/prism.c index 6a77dd0febd5d3..03b12e9db82d34 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -14621,18 +14621,6 @@ update_parameter_state(pm_parser_t *parser, pm_token_t *token, pm_parameters_ord return true; } -/** - * Ensures that after parsing a parameter, the next token is not `=`. - * Some parameters like `def(* = 1)` cannot become optional. When no parens - * are present like in `def * = 1`, this creates ambiguity with endless method definitions. - */ -static inline void -refute_optional_parameter(pm_parser_t *parser) { - if (match1(parser, PM_TOKEN_EQUAL)) { - pm_parser_err_previous(parser, PM_ERR_DEF_ENDLESS_PARAMETERS); - } -} - /** * Parse a list of parameters on a method definition. */ @@ -14685,10 +14673,6 @@ parse_parameters( parser->current_scope->parameters |= PM_SCOPE_PARAMETERS_FORWARDING_BLOCK; } - if (!uses_parentheses) { - refute_optional_parameter(parser); - } - pm_block_parameter_node_t *param = pm_block_parameter_node_create(parser, &name, &operator); if (repeated) { pm_node_flag_set_repeated_parameter((pm_node_t *)param); @@ -14710,10 +14694,6 @@ parse_parameters( bool succeeded = update_parameter_state(parser, &parser->current, &order); parser_lex(parser); - if (!uses_parentheses) { - refute_optional_parameter(parser); - } - parser->current_scope->parameters |= PM_SCOPE_PARAMETERS_FORWARDING_ALL; pm_forwarding_parameter_node_t *param = pm_forwarding_parameter_node_create(parser, &parser->previous); @@ -14895,10 +14875,6 @@ parse_parameters( context_pop(parser); pm_parameters_node_keywords_append(params, param); - if (!uses_parentheses) { - refute_optional_parameter(parser); - } - // If parsing the value of the parameter resulted in error recovery, // then we can put a missing node in its place and stop parsing the // parameters entirely now. @@ -14930,10 +14906,6 @@ parse_parameters( parser->current_scope->parameters |= PM_SCOPE_PARAMETERS_FORWARDING_POSITIONALS; } - if (!uses_parentheses) { - refute_optional_parameter(parser); - } - pm_node_t *param = (pm_node_t *) pm_rest_parameter_node_create(parser, &operator, &name); if (repeated) { pm_node_flag_set_repeated_parameter(param); @@ -14982,10 +14954,6 @@ parse_parameters( } } - if (!uses_parentheses) { - refute_optional_parameter(parser); - } - if (params->keyword_rest == NULL) { pm_parameters_node_keyword_rest_set(params, param); } else { @@ -19586,6 +19554,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b pm_token_t rparen; pm_parameters_node_t *params; + bool accept_endless_def = true; switch (parser->current.type) { case PM_TOKEN_PARENTHESIS_LEFT: { parser_lex(parser); @@ -19621,6 +19590,10 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b rparen = not_provided(parser); params = parse_parameters(parser, PM_BINDING_POWER_DEFINED, false, false, true, true, false, (uint16_t) (depth + 1)); + // Reject `def * = 1` and similar. We have to specifically check + // for them because they create ambiguity with optional arguments. + accept_endless_def = false; + context_pop(parser); break; } @@ -19642,6 +19615,9 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b if (token_is_setter_name(&name)) { pm_parser_err_token(parser, &name, PM_ERR_DEF_ENDLESS_SETTER); } + if (!accept_endless_def) { + pm_parser_err_previous(parser, PM_ERR_DEF_ENDLESS_PARAMETERS); + } equal = parser->previous; context_push(parser, PM_CONTEXT_DEF); diff --git a/spec/ruby/library/openssl/digest/initialize_spec.rb b/spec/ruby/library/openssl/digest/initialize_spec.rb index 1cd0409c4d838c..b5911716ca8c14 100644 --- a/spec/ruby/library/openssl/digest/initialize_spec.rb +++ b/spec/ruby/library/openssl/digest/initialize_spec.rb @@ -23,18 +23,14 @@ OpenSSL::Digest.new("sha512").name.should == "SHA512" end - it "throws an error when called with an unknown digest" do - -> { OpenSSL::Digest.new("wd40") }.should raise_error(RuntimeError, /Unsupported digest algorithm \(wd40\)/) + version_is OpenSSL::VERSION, "4.0.0" do + it "throws an error when called with an unknown digest" do + -> { OpenSSL::Digest.new("wd40") }.should raise_error(OpenSSL::Digest::DigestError, /wd40/) + end end it "cannot be called with a symbol" do - -> { OpenSSL::Digest.new(:SHA1) }.should raise_error(TypeError, /wrong argument type Symbol/) - end - - it "does not call #to_str on the argument" do - name = mock("digest name") - name.should_not_receive(:to_str) - -> { OpenSSL::Digest.new(name) }.should raise_error(TypeError, /wrong argument type/) + -> { OpenSSL::Digest.new(:SHA1) }.should raise_error(TypeError) end end @@ -62,7 +58,7 @@ end it "cannot be called with a digest class" do - -> { OpenSSL::Digest.new(OpenSSL::Digest::SHA1) }.should raise_error(TypeError, /wrong argument type Class/) + -> { OpenSSL::Digest.new(OpenSSL::Digest::SHA1) }.should raise_error(TypeError) end context "when called without an initial String argument" do diff --git a/spec/ruby/library/openssl/kdf/pbkdf2_hmac_spec.rb b/spec/ruby/library/openssl/kdf/pbkdf2_hmac_spec.rb index 40f85972759a5a..1112972060e18e 100644 --- a/spec/ruby/library/openssl/kdf/pbkdf2_hmac_spec.rb +++ b/spec/ruby/library/openssl/kdf/pbkdf2_hmac_spec.rb @@ -107,21 +107,15 @@ it "raises a TypeError when hash is neither a String nor an OpenSSL::Digest" do -> { OpenSSL::KDF.pbkdf2_hmac("secret", **@defaults, hash: Object.new) - }.should raise_error(TypeError, "wrong argument type Object (expected OpenSSL/Digest)") + }.should raise_error(TypeError) end - it "raises a TypeError when hash is neither a String nor an OpenSSL::Digest, it does not try to call #to_str" do - hash = mock("hash") - hash.should_not_receive(:to_str) - -> { - OpenSSL::KDF.pbkdf2_hmac("secret", **@defaults, hash: hash) - }.should raise_error(TypeError, "wrong argument type MockObject (expected OpenSSL/Digest)") - end - - it "raises a RuntimeError for unknown digest algorithms" do - -> { - OpenSSL::KDF.pbkdf2_hmac("secret", **@defaults, hash: "wd40") - }.should raise_error(RuntimeError, /Unsupported digest algorithm \(wd40\)/) + version_is OpenSSL::VERSION, "4.0.0" do + it "raises a OpenSSL::Digest::DigestError for unknown digest algorithms" do + -> { + OpenSSL::KDF.pbkdf2_hmac("secret", **@defaults, hash: "wd40") + }.should raise_error(OpenSSL::Digest::DigestError, /wd40/) + end end it "treats salt as a required keyword" do diff --git a/test/openssl/test_cipher.rb b/test/openssl/test_cipher.rb index 10858b353559e1..93766cfc88ce17 100644 --- a/test/openssl/test_cipher.rb +++ b/test/openssl/test_cipher.rb @@ -112,6 +112,9 @@ def test_initialize cipher = OpenSSL::Cipher.new("DES-EDE3-CBC") assert_raise(RuntimeError) { cipher.__send__(:initialize, "DES-EDE3-CBC") } assert_raise(RuntimeError) { OpenSSL::Cipher.allocate.final } + assert_raise(OpenSSL::Cipher::CipherError) { + OpenSSL::Cipher.new("no such algorithm") + } end def test_ctr_if_exists @@ -342,6 +345,24 @@ def test_aes_ocb_tag_len end if has_cipher?("aes-128-ocb") + def test_aes_gcm_siv + # RFC 8452 Appendix C.1., 8th example + key = ["01000000000000000000000000000000"].pack("H*") + iv = ["030000000000000000000000"].pack("H*") + aad = ["01"].pack("H*") + pt = ["0200000000000000"].pack("H*") + ct = ["1e6daba35669f4273b0a1a2560969cdf790d99759abd1508"].pack("H*") + tag = ["3b0a1a2560969cdf790d99759abd1508"].pack("H*") + ct_without_tag = ct.byteslice(0, ct.bytesize - tag.bytesize) + + cipher = new_encryptor("aes-128-gcm-siv", key: key, iv: iv, auth_data: aad) + assert_equal ct_without_tag, cipher.update(pt) << cipher.final + assert_equal tag, cipher.auth_tag + cipher = new_decryptor("aes-128-gcm-siv", key: key, iv: iv, auth_tag: tag, + auth_data: aad) + assert_equal pt, cipher.update(ct_without_tag) << cipher.final + end if openssl?(3, 2, 0) + def test_aes_gcm_key_iv_order_issue pt = "[ruby/openssl#49]" cipher = OpenSSL::Cipher.new("aes-128-gcm").encrypt @@ -368,7 +389,7 @@ def test_aes_keywrap_pad begin cipher = OpenSSL::Cipher.new("id-aes192-wrap-pad").encrypt - rescue OpenSSL::Cipher::CipherError, RuntimeError + rescue OpenSSL::Cipher::CipherError omit "id-aes192-wrap-pad is not supported: #$!" end cipher.key = kek diff --git a/test/openssl/test_digest.rb b/test/openssl/test_digest.rb index 5b4eb3c74c440a..2ef84cfa4c7d2a 100644 --- a/test/openssl/test_digest.rb +++ b/test/openssl/test_digest.rb @@ -10,6 +10,12 @@ def setup @d2 = OpenSSL::Digest::MD5.new end + def test_initialize + assert_raise(OpenSSL::Digest::DigestError) { + OpenSSL::Digest.new("no such algorithm") + } + end + def test_digest null_hex = "d41d8cd98f00b204e9800998ecf8427e" null_bin = [null_hex].pack("H*") @@ -62,8 +68,17 @@ def test_digest_constants end def test_digest_by_oid_and_name - check_digest(OpenSSL::ASN1::ObjectId.new("MD5")) - check_digest(OpenSSL::ASN1::ObjectId.new("SHA1")) + # SHA256 + o1 = OpenSSL::Digest.digest("SHA256", "") + o2 = OpenSSL::Digest.digest("sha256", "") + assert_equal(o1, o2) + o3 = OpenSSL::Digest.digest("2.16.840.1.101.3.4.2.1", "") + assert_equal(o1, o3) + + # An alias for SHA256 recognized by EVP_get_digestbyname(), but not by + # EVP_MD_fetch() + o4 = OpenSSL::Digest.digest("RSA-SHA256", "") + assert_equal(o1, o4) end def encode16(str) @@ -109,12 +124,12 @@ def test_sha3 assert_equal(s512, OpenSSL::Digest.hexdigest('SHA3-512', "")) end - def test_digest_by_oid_and_name_sha2 - check_digest(OpenSSL::ASN1::ObjectId.new("SHA224")) - check_digest(OpenSSL::ASN1::ObjectId.new("SHA256")) - check_digest(OpenSSL::ASN1::ObjectId.new("SHA384")) - check_digest(OpenSSL::ASN1::ObjectId.new("SHA512")) - end + def test_fetched_evp_md + # Pre-NIST Keccak is an example of a digest algorithm that doesn't have an + # NID and requires dynamic allocation of EVP_MD + hex = "c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470" + assert_equal(hex, OpenSSL::Digest.hexdigest("KECCAK-256", "")) + end if openssl?(3, 2, 0) def test_openssl_digest assert_equal OpenSSL::Digest::MD5, OpenSSL::Digest("MD5") @@ -132,17 +147,6 @@ def test_digests assert_include digests, "sha256" assert_include digests, "sha512" end - - private - - def check_digest(oid) - d = OpenSSL::Digest.new(oid.sn) - assert_not_nil(d) - d = OpenSSL::Digest.new(oid.ln) - assert_not_nil(d) - d = OpenSSL::Digest.new(oid.oid) - assert_not_nil(d) - end end end diff --git a/test/openssl/test_pkey.rb b/test/openssl/test_pkey.rb index 0943a7737db707..88299888f046bf 100644 --- a/test/openssl/test_pkey.rb +++ b/test/openssl/test_pkey.rb @@ -314,4 +314,11 @@ def test_to_text rsa = Fixtures.pkey("rsa-1") assert_include rsa.to_text, "publicExponent" end + + def test_legacy_error_classes + assert_same(OpenSSL::PKey::PKeyError, OpenSSL::PKey::DSAError) + assert_same(OpenSSL::PKey::PKeyError, OpenSSL::PKey::DHError) + assert_same(OpenSSL::PKey::PKeyError, OpenSSL::PKey::ECError) + assert_same(OpenSSL::PKey::PKeyError, OpenSSL::PKey::RSAError) + end end diff --git a/test/openssl/test_pkey_dh.rb b/test/openssl/test_pkey_dh.rb index 6ca5b1f5f8cd40..cd13283a2a7dce 100644 --- a/test/openssl/test_pkey_dh.rb +++ b/test/openssl/test_pkey_dh.rb @@ -140,7 +140,7 @@ def test_params_ok? # AWS-LC automatically does parameter checks on the parsed params. if aws_lc? - assert_raise(OpenSSL::PKey::DHError) { + assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey::DH.new(OpenSSL::ASN1::Sequence([ OpenSSL::ASN1::Integer(dh0.p + 1), OpenSSL::ASN1::Integer(dh0.g) diff --git a/test/openssl/test_pkey_dsa.rb b/test/openssl/test_pkey_dsa.rb index ef0fdf9182fb1b..1ec0bf0b4d4bf8 100644 --- a/test/openssl/test_pkey_dsa.rb +++ b/test/openssl/test_pkey_dsa.rb @@ -97,7 +97,7 @@ def test_sign_verify_raw sig = key.syssign(digest) assert_equal true, key.sysverify(digest, sig) assert_equal false, key.sysverify(digest, invalid_sig) - assert_sign_verify_false_or_error{ key.sysverify(digest, malformed_sig) } + assert_sign_verify_false_or_error { key.sysverify(digest, malformed_sig) } assert_equal true, key.verify_raw(nil, sig, digest) assert_equal false, key.verify_raw(nil, invalid_sig, digest) assert_sign_verify_false_or_error { key.verify_raw(nil, malformed_sig, digest) } @@ -148,7 +148,7 @@ def test_DSAPrivateKey_encrypted cipher = OpenSSL::Cipher.new("aes-128-cbc") exported = orig.to_pem(cipher, "abcdef\0\1") assert_same_dsa orig, OpenSSL::PKey::DSA.new(exported, "abcdef\0\1") - assert_raise(OpenSSL::PKey::DSAError) { + assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey::DSA.new(exported, "abcdef") } end diff --git a/test/openssl/test_pkey_ec.rb b/test/openssl/test_pkey_ec.rb index 58857cb038b10b..df91a1be255f07 100644 --- a/test/openssl/test_pkey_ec.rb +++ b/test/openssl/test_pkey_ec.rb @@ -54,7 +54,9 @@ def test_builtin_curves end def test_generate - assert_raise(OpenSSL::PKey::ECError) { OpenSSL::PKey::EC.generate("non-existent") } + assert_raise(OpenSSL::PKey::PKeyError) { + OpenSSL::PKey::EC.generate("non-existent") + } g = OpenSSL::PKey::EC::Group.new("prime256v1") ec = OpenSSL::PKey::EC.generate(g) assert_equal(true, ec.private?) @@ -65,7 +67,7 @@ def test_generate def test_generate_key ec = OpenSSL::PKey::EC.new("prime256v1") assert_equal false, ec.private? - assert_raise(OpenSSL::PKey::ECError) { ec.to_der } + assert_raise(OpenSSL::PKey::PKeyError) { ec.to_der } ec.generate_key! assert_equal true, ec.private? assert_nothing_raised { ec.to_der } @@ -109,13 +111,13 @@ def test_check_key assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.read(ec_key_data) } else key4 = OpenSSL::PKey.read(ec_key_data) - assert_raise(OpenSSL::PKey::ECError) { key4.check_key } + assert_raise(OpenSSL::PKey::PKeyError) { key4.check_key } end # EC#private_key= is deprecated in 3.0 and won't work on OpenSSL 3.0 if !openssl?(3, 0, 0) key2.private_key += 1 - assert_raise(OpenSSL::PKey::ECError) { key2.check_key } + assert_raise(OpenSSL::PKey::PKeyError) { key2.check_key } end end @@ -269,7 +271,7 @@ def test_ECPrivateKey_encrypted cipher = OpenSSL::Cipher.new("aes-128-cbc") exported = p256.to_pem(cipher, "abcdef\0\1") assert_same_ec p256, OpenSSL::PKey::EC.new(exported, "abcdef\0\1") - assert_raise(OpenSSL::PKey::ECError) { + assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey::EC.new(exported, "abcdef") } end diff --git a/test/openssl/test_pkey_rsa.rb b/test/openssl/test_pkey_rsa.rb index 90dd0481edfc67..86f51cf4385a01 100644 --- a/test/openssl/test_pkey_rsa.rb +++ b/test/openssl/test_pkey_rsa.rb @@ -9,8 +9,8 @@ def test_no_private_exp rsa = Fixtures.pkey("rsa-1") key.set_key(rsa.n, rsa.e, nil) key.set_factors(rsa.p, rsa.q) - assert_raise(OpenSSL::PKey::RSAError){ key.private_encrypt("foo") } - assert_raise(OpenSSL::PKey::RSAError){ key.private_decrypt("foo") } + assert_raise(OpenSSL::PKey::PKeyError){ key.private_encrypt("foo") } + assert_raise(OpenSSL::PKey::PKeyError){ key.private_decrypt("foo") } end if !openssl?(3, 0, 0) # Impossible state in OpenSSL 3.0 def test_private @@ -180,7 +180,7 @@ def test_sign_verify_raw_legacy # Failure cases assert_raise(ArgumentError){ key.private_encrypt() } assert_raise(ArgumentError){ key.private_encrypt("hi", 1, nil) } - assert_raise(OpenSSL::PKey::RSAError){ key.private_encrypt(plain0, 666) } + assert_raise(OpenSSL::PKey::PKeyError){ key.private_encrypt(plain0, 666) } end @@ -231,7 +231,7 @@ def test_sign_verify_pss key.verify_pss("SHA256", signature, data, salt_length: :auto, mgf1_hash: "SHA256") end - assert_raise(OpenSSL::PKey::RSAError) { + assert_raise(OpenSSL::PKey::PKeyError) { key.sign_pss("SHA256", data, salt_length: 223, mgf1_hash: "SHA256") } end @@ -373,7 +373,7 @@ def test_RSAPrivateKey_encrypted cipher = OpenSSL::Cipher.new("aes-128-cbc") exported = rsa.to_pem(cipher, "abcdef\0\1") assert_same_rsa rsa, OpenSSL::PKey::RSA.new(exported, "abcdef\0\1") - assert_raise(OpenSSL::PKey::RSAError) { + assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey::RSA.new(exported, "abcdef") } end diff --git a/test/prism/errors/endless_method_command_call_parameters.txt b/test/prism/errors/endless_method_command_call_parameters.txt index 94c4f88fc835df..5dc92ce7f9f0d9 100644 --- a/test/prism/errors/endless_method_command_call_parameters.txt +++ b/test/prism/errors/endless_method_command_call_parameters.txt @@ -1,24 +1,27 @@ def f x: = 1 - ^~ could not parse the endless method parameters + ^ could not parse the endless method parameters def f ... = 1 - ^~~ could not parse the endless method parameters + ^ could not parse the endless method parameters def f * = 1 - ^ could not parse the endless method parameters + ^ could not parse the endless method parameters def f ** = 1 - ^~ could not parse the endless method parameters + ^ could not parse the endless method parameters def f & = 1 - ^ could not parse the endless method parameters + ^ could not parse the endless method parameters def f *a = 1 - ^ could not parse the endless method parameters + ^ could not parse the endless method parameters def f **a = 1 - ^ could not parse the endless method parameters + ^ could not parse the endless method parameters def f &a = 1 - ^ could not parse the endless method parameters + ^ could not parse the endless method parameters + +def f a, (b) = 1 + ^ could not parse the endless method parameters diff --git a/zjit/src/asm/mod.rs b/zjit/src/asm/mod.rs index dca2b7b0cf018a..86176c0ec9bae5 100644 --- a/zjit/src/asm/mod.rs +++ b/zjit/src/asm/mod.rs @@ -208,6 +208,14 @@ impl CodeBlock { self.dropped_bytes } + /// Set dropped_bytes to false if the current zjit_alloc_bytes() + code_region_size + /// + page_size is below --zjit-mem-size. + pub fn update_dropped_bytes(&mut self) { + if self.mem_block.borrow().can_allocate() { + self.dropped_bytes = false; + } + } + /// Allocate a new label with a given name pub fn new_label(&mut self, name: String) -> Label { assert!(!name.contains(' '), "use underscores in label names, not spaces"); diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index e2f75e01c8fcba..69b030608be776 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -1750,7 +1750,15 @@ impl Assembler #[cfg(feature = "disasm")] let start_addr = cb.get_write_ptr(); let alloc_regs = Self::get_alloc_regs(); - let ret = self.compile_with_regs(cb, alloc_regs); + let had_dropped_bytes = cb.has_dropped_bytes(); + let ret = self.compile_with_regs(cb, alloc_regs).inspect_err(|err| { + // If we use too much memory to compile the Assembler, it would set cb.dropped_bytes = true. + // To avoid failing future compilation by cb.has_dropped_bytes(), attempt to reset dropped_bytes with + // the current zjit_alloc_bytes() which may be decreased after self is dropped in compile_with_regs(). + if *err == CompileError::OutOfMemory && !had_dropped_bytes { + cb.update_dropped_bytes(); + } + }); #[cfg(feature = "disasm")] if get_option!(dump_disasm) { diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 364b9225fe07e6..6a7707dd5a5a17 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -220,6 +220,11 @@ fn gen_iseq(cb: &mut CodeBlock, iseq: IseqPtr, function: Option<&Function>) -> R /// Compile an ISEQ into machine code fn gen_iseq_body(cb: &mut CodeBlock, iseq: IseqPtr, function: Option<&Function>, payload: &mut IseqPayload) -> Result { + // If we ran out of code region, we shouldn't attempt to generate new code. + if cb.has_dropped_bytes() { + return Err(CompileError::OutOfMemory); + } + // Convert ISEQ into optimized High-level IR if not given let function = match function { Some(function) => function, @@ -426,6 +431,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::GetClassVar { id, ic, state } => gen_getclassvar(jit, asm, *id, *ic, &function.frame_state(*state)), Insn::SetClassVar { id, val, ic, state } => no_output!(gen_setclassvar(jit, asm, *id, opnd!(val), *ic, &function.frame_state(*state))), Insn::SetIvar { self_val, id, val, state } => no_output!(gen_setivar(jit, asm, opnd!(self_val), *id, opnd!(val), &function.frame_state(*state))), + Insn::SetInstanceVariable { self_val, id, ic, val, state } => no_output!(gen_set_instance_variable(jit, asm, opnd!(self_val), *id, *ic, opnd!(val), &function.frame_state(*state))), Insn::SideExit { state, reason } => no_output!(gen_side_exit(jit, asm, reason, &function.frame_state(*state))), Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type), Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state)), @@ -840,6 +846,15 @@ fn gen_setivar(jit: &mut JITState, asm: &mut Assembler, recv: Opnd, id: ID, val: asm_ccall!(asm, rb_ivar_set, recv, id.0.into(), val); } +/// Emit an uncached instance variable store using the interpreter inline cache +fn gen_set_instance_variable(jit: &mut JITState, asm: &mut Assembler, recv: Opnd, id: ID, ic: *const iseq_inline_constant_cache, val: Opnd, state: &FrameState) { + gen_incr_counter(asm, Counter::dynamic_setivar_count); + // Setting an ivar can raise FrozenError, so we need proper frame state for exception handling. + gen_prepare_non_leaf_call(jit, asm, state); + let iseq = Opnd::Value(jit.iseq.into()); + asm_ccall!(asm, rb_vm_setinstancevariable, iseq, recv, id.0.into(), val, Opnd::const_ptr(ic)); +} + fn gen_getclassvar(jit: &mut JITState, asm: &mut Assembler, id: ID, ic: *const iseq_inline_cvar_cache_entry, state: &FrameState) -> Opnd { gen_prepare_non_leaf_call(jit, asm, state); asm_ccall!(asm, rb_vm_getclassvariable, VALUE::from(jit.iseq).into(), CFP, id.0.into(), Opnd::const_ptr(ic)) diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 449047d0dfd72c..e0a2e0fdff1802 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -675,6 +675,8 @@ pub enum Insn { GetIvar { self_val: InsnId, id: ID, state: InsnId }, /// Set `self_val`'s instance variable `id` to `val` SetIvar { self_val: InsnId, id: ID, val: InsnId, state: InsnId }, + /// Set `self_val`'s instance variable `id` to `val` using the interpreter inline cache + SetInstanceVariable { self_val: InsnId, id: ID, ic: *const iseq_inline_constant_cache, val: InsnId, state: InsnId }, /// Check whether an instance variable exists on `self_val` DefinedIvar { self_val: InsnId, id: ID, pushval: VALUE, state: InsnId }, @@ -866,7 +868,7 @@ impl Insn { | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::SetClassVar { .. } | Insn::ArrayExtend { .. } | Insn::ArrayPush { .. } | Insn::SideExit { .. } | Insn::SetGlobal { .. } | Insn::SetLocal { .. } | Insn::Throw { .. } | Insn::IncrCounter(_) | Insn::IncrCounterPtr { .. } - | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } => false, + | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } | Insn::SetInstanceVariable { .. } => false, _ => true, } } @@ -1193,6 +1195,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::LoadSelf => write!(f, "LoadSelf"), &Insn::LoadField { recv, id, offset, return_type: _ } => write!(f, "LoadField {recv}, :{}@{:p}", id.contents_lossy(), self.ptr_map.map_offset(offset)), Insn::SetIvar { self_val, id, val, .. } => write!(f, "SetIvar {self_val}, :{}, {val}", id.contents_lossy()), + Insn::SetInstanceVariable { self_val, id, val, .. } => write!(f, "SetInstanceVariable {self_val}, :{}, {val}", id.contents_lossy()), Insn::GetGlobal { id, .. } => write!(f, "GetGlobal :{}", id.contents_lossy()), Insn::SetGlobal { id, val, .. } => write!(f, "SetGlobal :{}, {val}", id.contents_lossy()), &Insn::GetLocal { level, ep_offset, use_sp: true, rest_param } => write!(f, "GetLocal l{level}, SP@{}{}", ep_offset + 1, if rest_param { ", *" } else { "" }), @@ -1817,6 +1820,7 @@ impl Function { &GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state }, &LoadField { recv, id, offset, return_type } => LoadField { recv: find!(recv), id, offset, return_type }, &SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state }, + &SetInstanceVariable { self_val, id, ic, val, state } => SetInstanceVariable { self_val: find!(self_val), id, ic, val: find!(val), state }, &GetClassVar { id, ic, state } => GetClassVar { id, ic, state }, &SetClassVar { id, val, ic, state } => SetClassVar { id, val: find!(val), ic, state }, &SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level }, @@ -1870,7 +1874,8 @@ impl Function { | Insn::IfTrue { .. } | Insn::IfFalse { .. } | Insn::Return { .. } | Insn::Throw { .. } | Insn::PatchPoint { .. } | Insn::SetIvar { .. } | Insn::SetClassVar { .. } | Insn::ArrayExtend { .. } | Insn::ArrayPush { .. } | Insn::SideExit { .. } | Insn::SetLocal { .. } | Insn::IncrCounter(_) - | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } | Insn::IncrCounterPtr { .. } => + | Insn::CheckInterrupts { .. } | Insn::GuardBlockParamProxy { .. } | Insn::IncrCounterPtr { .. } + | Insn::SetInstanceVariable { .. } => panic!("Cannot infer type of instruction with no output: {}", self.insns[insn.0]), Insn::Const { val: Const::Value(val) } => Type::from_value(*val), Insn::Const { val: Const::CBool(val) } => Type::from_cbool(*val), @@ -3379,7 +3384,8 @@ impl Function { worklist.push_back(self_val); worklist.push_back(state); } - &Insn::SetIvar { self_val, val, state, .. } => { + &Insn::SetIvar { self_val, val, state, .. } + | &Insn::SetInstanceVariable { self_val, val, state, .. } => { worklist.push_back(self_val); worklist.push_back(val); worklist.push_back(state); @@ -5089,13 +5095,13 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { } YARVINSN_setinstancevariable => { let id = ID(get_arg(pc, 0).as_u64()); - // ic is in arg 1 + let ic = get_arg(pc, 1).as_ptr(); let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); // Assume single-Ractor mode to omit gen_prepare_non_leaf_call on gen_setivar // TODO: We only really need this if self_val is a class/module fun.push_insn(block, Insn::PatchPoint { invariant: Invariant::SingleRactorMode, state: exit_id }); let val = state.stack_pop()?; - fun.push_insn(block, Insn::SetIvar { self_val: self_param, id, val, state: exit_id }); + fun.push_insn(block, Insn::SetInstanceVariable { self_val: self_param, id, ic, val, state: exit_id }); } YARVINSN_getclassvariable => { let id = ID(get_arg(pc, 0).as_u64()); diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index f824351eca7551..9b757433e1b032 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -3198,7 +3198,7 @@ mod hir_opt_tests { bb2(v6:BasicObject): v10:Fixnum[1] = Const Value(1) PatchPoint SingleRactorMode - SetIvar v6, :@foo, v10 + SetInstanceVariable v6, :@foo, v10 CheckInterrupts Return v10 "); diff --git a/zjit/src/hir/tests.rs b/zjit/src/hir/tests.rs index a8738a07157b16..af3fd3de9153d0 100644 --- a/zjit/src/hir/tests.rs +++ b/zjit/src/hir/tests.rs @@ -2185,7 +2185,7 @@ pub mod hir_build_tests { bb2(v6:BasicObject): v10:Fixnum[1] = Const Value(1) PatchPoint SingleRactorMode - SetIvar v6, :@foo, v10 + SetInstanceVariable v6, :@foo, v10 CheckInterrupts Return v10 "); diff --git a/zjit/src/virtualmem.rs b/zjit/src/virtualmem.rs index 770fbfba47d1c8..9741a7b13867d5 100644 --- a/zjit/src/virtualmem.rs +++ b/zjit/src/virtualmem.rs @@ -258,6 +258,13 @@ impl VirtualMemory { Ok(()) } + /// Return true if write_byte() can allocate a new page + pub fn can_allocate(&self) -> bool { + let memory_usage_bytes = self.mapped_region_bytes + zjit_alloc_bytes(); + let memory_limit_bytes = self.memory_limit_bytes.unwrap_or(self.region_size_bytes); + memory_usage_bytes + self.page_size_bytes < memory_limit_bytes + } + /// Make all the code in the region executable. Call this at the end of a write session. /// See [Self] for usual usage flow. pub fn mark_all_executable(&mut self) {