diff --git a/src/scitokens_internal.cpp b/src/scitokens_internal.cpp index b90a301..345165b 100644 --- a/src/scitokens_internal.cpp +++ b/src/scitokens_internal.cpp @@ -189,10 +189,28 @@ struct local_base64url : public jwt::alphabet::base64url { }; +// Assuming a padding, decode +std::string b64url_decode_nopadding(const std::string &input) +{ + std::string result = input; + switch (result.size() % 4) { + case 1: + result += "="; // fallthrough + case 2: + result += "="; // fallthrough + case 3: + result += "="; // fallthrough + default: + break; + } + return jwt::base::decode(result); +} + + std::string es256_from_coords(const std::string &x_str, const std::string &y_str) { - auto x_decode = jwt::base::decode(x_str); - auto y_decode = jwt::base::decode(y_str); + auto x_decode = b64url_decode_nopadding(x_str); + auto y_decode = b64url_decode_nopadding(y_str); std::unique_ptr ec(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); if (!ec.get()) { @@ -232,8 +250,8 @@ es256_from_coords(const std::string &x_str, const std::string &y_str) { std::string rs256_from_coords(const std::string &e_str, const std::string &n_str) { - auto e_decode = jwt::base::decode(e_str); - auto n_decode = jwt::base::decode(n_str); + auto e_decode = b64url_decode_nopadding(e_str); + auto n_decode = b64url_decode_nopadding(n_str); std::unique_ptr e_bignum(BN_bin2bn(reinterpret_cast(e_decode.c_str()), e_decode.size(), nullptr), BN_free); std::unique_ptr n_bignum(BN_bin2bn(reinterpret_cast(n_decode.c_str()), n_decode.size(), nullptr), BN_free); @@ -399,10 +417,33 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, auto key_obj = find_key_id(keys, kid); auto iter = key_obj.find("alg"); + std::string alg; if (iter == key_obj.end() || (!iter->second.is())) { - throw JsonException("Key is missing algorithm name"); - } - auto alg = iter->second.get(); + auto iter2 = key_obj.find("kty"); + if (iter2 == key_obj.end() || !iter2->second.is()) { + throw JsonException("Key is missing key type"); + } else { + auto kty = iter2->second.get(); + if (kty == "RSA") { + alg = "RS256"; + } else if (kty == "EC") { + auto iter3 = key_obj.find("crv"); + if (iter3 == key_obj.end() || !iter3->second.is()) { + throw JsonException("EC key is missing curve name"); + } + auto crv = iter2->second.get(); + if (crv == "P-256") { + alg = "EC256"; + } else { + throw JsonException("Unsupported EC curve in public key"); + } + } else { + throw JsonException("Unknown public key type"); + } + } + } else { + alg = iter->second.get(); + } if (alg != "RS256" and alg != "ES256") { throw UnsupportedKeyException("Issuer is using an unsupported algorithm"); }