diff --git a/src/scitokens.cpp b/src/scitokens.cpp index 0919bcc..d3d8346 100644 --- a/src/scitokens.cpp +++ b/src/scitokens.cpp @@ -1134,6 +1134,80 @@ int keycache_stop_background_refresh(char **err_msg) { return keycache_set_background_refresh(0, err_msg); } +int keycache_load_jwks(const char *issuer, char **jwks, char **err_msg) { + if (!issuer) { + if (err_msg) { + *err_msg = strdup("Issuer may not be a null pointer"); + } + return -1; + } + if (!jwks) { + if (err_msg) { + *err_msg = strdup("JWKS output pointer may not be null."); + } + return -1; + } + try { + *jwks = strdup(scitokens::Validator::load_jwks(issuer).c_str()); + } catch (std::exception &exc) { + if (err_msg) { + *err_msg = strdup(exc.what()); + } + return -1; + } + return 0; +} + +int keycache_get_jwks_metadata(const char *issuer, char **metadata, + char **err_msg) { + if (!issuer) { + if (err_msg) { + *err_msg = strdup("Issuer may not be a null pointer"); + } + return -1; + } + if (!metadata) { + if (err_msg) { + *err_msg = strdup("Metadata output pointer may not be null."); + } + return -1; + } + try { + *metadata = + strdup(scitokens::Validator::get_jwks_metadata(issuer).c_str()); + } catch (std::exception &exc) { + if (err_msg) { + *err_msg = strdup(exc.what()); + } + return -1; + } + return 0; +} + +int keycache_delete_jwks(const char *issuer, char **err_msg) { + if (!issuer) { + if (err_msg) { + *err_msg = strdup("Issuer may not be a null pointer"); + } + return -1; + } + try { + if (!scitokens::Validator::delete_jwks(issuer)) { + if (err_msg) { + *err_msg = + strdup("Failed to delete JWKS cache entry for issuer."); + } + return -1; + } + } catch (std::exception &exc) { + if (err_msg) { + *err_msg = strdup(exc.what()); + } + return -1; + } + return 0; +} + int config_set_int(const char *key, int value, char **err_msg) { return scitoken_config_set_int(key, value, err_msg); } diff --git a/src/scitokens.h b/src/scitokens.h index 88dcc68..b067425 100644 --- a/src/scitokens.h +++ b/src/scitokens.h @@ -309,6 +309,36 @@ int keycache_set_background_refresh(int enabled, char **err_msg); */ int keycache_stop_background_refresh(char **err_msg); +/** + * Load the JWKS from the keycache for a given issuer, refreshing only if + * needed. + * - Returns 0 if successful, nonzero on failure. + * - If the existing JWKS has not expired, this will return the cached JWKS + * without triggering a download. + * - If the JWKS has expired or does not exist, this will attempt to refresh + * it from the issuer. + * - `jwks` is an output variable set to the contents of the JWKS. + */ +int keycache_load_jwks(const char *issuer, char **jwks, char **err_msg); + +/** + * Get metadata for a cached JWKS entry. + * - Returns 0 if successful, nonzero on failure. + * - `metadata` is an output variable set to a JSON string containing: + * - "expires": expiration time (Unix epoch seconds) + * - "next_update": next update time (Unix epoch seconds) + * - If the issuer does not exist in the cache, returns an error. + */ +int keycache_get_jwks_metadata(const char *issuer, char **metadata, + char **err_msg); + +/** + * Delete a JWKS entry from the keycache. + * - Returns 0 if successful, nonzero on failure. + * - If the issuer does not exist in the cache, this is not considered an error. + */ +int keycache_delete_jwks(const char *issuer, char **err_msg); + /** * APIs for managing scitokens configuration parameters. */ diff --git a/src/scitokens_cache.cpp b/src/scitokens_cache.cpp index a2094b1..3dd0a8c 100644 --- a/src/scitokens_cache.cpp +++ b/src/scitokens_cache.cpp @@ -22,6 +22,9 @@ namespace { // This handles concurrent access from multiple threads/processes constexpr int SQLITE_BUSY_TIMEOUT_MS = 5000; +// Default time before expiry when next_update should occur (4 hours) +constexpr int64_t DEFAULT_NEXT_UPDATE_OFFSET_S = 4 * 3600; + void initialize_cachedb(const std::string &keycache_file) { sqlite3 *db; @@ -257,7 +260,7 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer, sqlite3_close(db); iter = top_obj.find("next_update"); if (iter == top_obj.end() || !iter->second.is()) { - next_update = expiry - 4 * 3600; + next_update = expiry - DEFAULT_NEXT_UPDATE_OFFSET_S; } else { next_update = iter->second.get(); } @@ -406,7 +409,7 @@ scitokens::Validator::get_all_issuers_from_db(int64_t now) { if (next_update_iter == top_obj.end() || !next_update_iter->second.is()) { // If next_update is not set, default to 4 hours before expiry - next_update = expiry - 4 * 3600; + next_update = expiry - DEFAULT_NEXT_UPDATE_OFFSET_S; } else { next_update = next_update_iter->second.get(); } @@ -425,3 +428,138 @@ scitokens::Validator::get_all_issuers_from_db(int64_t now) { sqlite3_close(db); return result; } + +std::string scitokens::Validator::load_jwks(const std::string &issuer) { + auto now = std::time(NULL); + picojson::value jwks; + int64_t next_update; + + try { + // Try to get from cache + if (get_public_keys_from_db(issuer, now, jwks, next_update)) { + // Check if refresh is needed (expired based on next_update) + if (now <= next_update) { + // Still valid, return cached version + return jwks.serialize(); + } + // Past next_update, need to refresh + } + } catch (const NegativeCacheHitException &) { + // Negative cache hit - return empty keys + return std::string("{\"keys\": []}"); + } + + // Either not in cache or past next_update - refresh + if (!refresh_jwks(issuer)) { + throw CurlException("Failed to load JWKS for issuer: " + issuer); + } + + // Get the newly refreshed JWKS + return get_jwks(issuer); +} + +std::string scitokens::Validator::get_jwks_metadata(const std::string &issuer) { + auto now = std::time(NULL); + int64_t next_update = -1; + int64_t expires = -1; + + // Get the metadata from database without expiry check + auto cache_fname = get_cache_file(); + if (cache_fname.size() == 0) { + throw std::runtime_error("Unable to access cache file"); + } + + sqlite3 *db; + int rc = sqlite3_open(cache_fname.c_str(), &db); + if (rc) { + sqlite3_close(db); + throw std::runtime_error("Failed to open cache database"); + } + sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS); + + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?", + -1, &stmt, NULL); + if (rc != SQLITE_OK) { + sqlite3_close(db); + throw std::runtime_error("Failed to prepare database query"); + } + + if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), + SQLITE_STATIC) != SQLITE_OK) { + sqlite3_finalize(stmt); + sqlite3_close(db); + throw std::runtime_error("Failed to bind issuer to query"); + } + + rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + const unsigned char *data = sqlite3_column_text(stmt, 0); + std::string metadata(reinterpret_cast(data)); + sqlite3_finalize(stmt); + sqlite3_close(db); + + picojson::value json_obj; + auto err = picojson::parse(json_obj, metadata); + if (!err.empty() || !json_obj.is()) { + throw JsonException("Invalid JSON in cache entry"); + } + + auto top_obj = json_obj.get(); + + // Extract expires + auto iter = top_obj.find("expires"); + if (iter != top_obj.end() && iter->second.is()) { + expires = iter->second.get(); + } + + // Extract next_update + iter = top_obj.find("next_update"); + if (iter != top_obj.end() && iter->second.is()) { + next_update = iter->second.get(); + } else if (expires != -1) { + // Default next_update to 4 hours before expiry + next_update = expires - DEFAULT_NEXT_UPDATE_OFFSET_S; + } + + // Build metadata JSON (add future keys at top level if needed) + picojson::object metadata_obj; + if (expires != -1) { + metadata_obj["expires"] = picojson::value(expires); + } + if (next_update != -1) { + metadata_obj["next_update"] = picojson::value(next_update); + } + + return picojson::value(metadata_obj).serialize(); + } else { + sqlite3_finalize(stmt); + sqlite3_close(db); + throw std::runtime_error("Issuer not found in cache"); + } +} + +bool scitokens::Validator::delete_jwks(const std::string &issuer) { + auto cache_fname = get_cache_file(); + if (cache_fname.size() == 0) { + return false; + } + + sqlite3 *db; + int rc = sqlite3_open(cache_fname.c_str(), &db); + if (rc) { + sqlite3_close(db); + return false; + } + sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS); + + // Use the existing remove_issuer_entry function + // Note: remove_issuer_entry closes the database on error + if (remove_issuer_entry(db, issuer, true) != 0) { + // Database already closed by remove_issuer_entry + return false; + } + + sqlite3_close(db); + return true; +} diff --git a/src/scitokens_internal.h b/src/scitokens_internal.h index f682739..b54d46d 100644 --- a/src/scitokens_internal.h +++ b/src/scitokens_internal.h @@ -1309,6 +1309,26 @@ class Validator { static std::vector> get_all_issuers_from_db(int64_t now); + /** + * Load JWKS for a given issuer, refreshing only if needed. + * Returns the JWKS string. If refresh is needed and fails, throws + * exception. + */ + static std::string load_jwks(const std::string &issuer); + + /** + * Get metadata for a cached JWKS entry. + * Returns a JSON string with expires, next_update, and extra fields. + * Throws exception if issuer not found in cache. + */ + static std::string get_jwks_metadata(const std::string &issuer); + + /** + * Delete a JWKS entry from the keycache. + * Returns true on success, false on failure. + */ + static bool delete_jwks(const std::string &issuer); + private: static std::unique_ptr get_public_key_pem(const std::string &issuer, const std::string &kid, diff --git a/test/main.cpp b/test/main.cpp index b301c4a..3c7238a 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1,11 +1,17 @@ #include "../src/scitokens.h" #include "test_utils.h" +#include #include #include #include #include +#ifndef PICOJSON_USE_INT64 +#define PICOJSON_USE_INT64 +#endif +#include #include +#include #include using scitokens_test::SecureTempDir; @@ -629,6 +635,20 @@ class KeycacheTest : public ::testing::Test { protected: std::string demo_scitokens_url = "https://demo.scitokens.org"; + static int64_t get_next_update_from_metadata(const std::string &metadata) { + picojson::value root; + std::string err = picojson::parse(root, metadata); + if (!err.empty() || !root.is()) { + return -1; + } + auto &root_obj = root.get(); + auto nu_it = root_obj.find("next_update"); + if (nu_it != root_obj.end() && nu_it->second.is()) { + return nu_it->second.get(); + } + return -1; + } + void SetUp() override { char *err_msg = nullptr; auto rv = keycache_set_jwks(demo_scitokens_url.c_str(), @@ -955,6 +975,236 @@ TEST_F(KeycacheTest, NegativeCacheTest) { << "Should have 2 negative cache hits. JSON: " << json_str; } +TEST_F(KeycacheTest, LoadJwksTest) { + // Test load API - should return cached JWKS without triggering refresh + char *err_msg = nullptr; + char *jwks = nullptr; + + // Capture metadata before load to ensure no refresh changes it + char *metadata_before = nullptr; + auto rv = keycache_get_jwks_metadata(demo_scitokens_url.c_str(), + &metadata_before, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + ASSERT_TRUE(metadata_before != nullptr); + int64_t next_update_before = + get_next_update_from_metadata(std::string(metadata_before)); + free(metadata_before); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Load JWKS - should return cached version from SetUp() + rv = keycache_load_jwks(demo_scitokens_url.c_str(), &jwks, &err_msg); + ASSERT_TRUE(rv == 0) << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(jwks != nullptr); + std::string jwks_str(jwks); + free(jwks); + if (err_msg) + free(err_msg); + + EXPECT_EQ(demo_scitokens, jwks_str); + + // Metadata should be unchanged (no refresh triggered) + char *metadata_after = nullptr; + rv = keycache_get_jwks_metadata(demo_scitokens_url.c_str(), &metadata_after, + &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + ASSERT_TRUE(metadata_after != nullptr); + int64_t next_update_after = + get_next_update_from_metadata(std::string(metadata_after)); + free(metadata_after); + if (err_msg) + free(err_msg); + + EXPECT_EQ(next_update_before, next_update_after); +} + +TEST_F(KeycacheTest, LoadJwksMissingTest) { + // Test load API with missing issuer - should attempt refresh + char *err_msg = nullptr; + char *jwks = nullptr; + + // Try to load a non-existent issuer - will fail to refresh + auto rv = keycache_load_jwks("https://demo.scitokens.org/nonexistent", + &jwks, &err_msg); + ASSERT_FALSE(rv == 0); // Should fail since issuer doesn't exist + if (err_msg) + free(err_msg); +} + +TEST_F(KeycacheTest, LoadJwksTriggersRefreshWhenStale) { + // Force next_update in the past so load_jwks triggers a refresh + char *err_msg = nullptr; + + // Reset monitoring to capture only this test's activity + scitoken_reset_monitoring_stats(&err_msg); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Save and override update interval so next_update is "now" + int original_update_interval = + scitoken_config_get_int("keycache.update_interval_s", &err_msg); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + auto rv = + scitoken_config_set_int("keycache.update_interval_s", 0, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Re-set JWKS so the stored next_update uses the new interval (now) + rv = keycache_set_jwks(demo_scitokens_url.c_str(), demo_scitokens.c_str(), + &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Capture metadata before load to confirm next_update advances + char *metadata_before = nullptr; + rv = keycache_get_jwks_metadata(demo_scitokens_url.c_str(), + &metadata_before, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + ASSERT_TRUE(metadata_before != nullptr); + int64_t next_update_before = + get_next_update_from_metadata(std::string(metadata_before)); + free(metadata_before); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Ensure current time passes the stored next_update + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // Load JWKS - should detect stale entry and trigger refresh + char *jwks = nullptr; + rv = keycache_load_jwks(demo_scitokens_url.c_str(), &jwks, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + ASSERT_TRUE(jwks != nullptr); + std::string jwks_str(jwks); + free(jwks); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + EXPECT_EQ(demo_scitokens, jwks_str); + + // Verify next_update moved forward (refresh occurred) + char *metadata_after = nullptr; + rv = keycache_get_jwks_metadata(demo_scitokens_url.c_str(), &metadata_after, + &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + ASSERT_TRUE(metadata_after != nullptr); + int64_t next_update_after = + get_next_update_from_metadata(std::string(metadata_after)); + free(metadata_after); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + EXPECT_GT(next_update_after, next_update_before); + + // Restore original update interval + rv = scitoken_config_set_int("keycache.update_interval_s", + original_update_interval, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + if (err_msg) + free(err_msg); +} + +TEST_F(KeycacheTest, GetMetadataTest) { + // Test metadata API - should return expires and next_update + char *err_msg = nullptr; + char *metadata = nullptr; + + // Get metadata for cached issuer + auto rv = keycache_get_jwks_metadata(demo_scitokens_url.c_str(), &metadata, + &err_msg); + ASSERT_TRUE(rv == 0) << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(metadata != nullptr); + std::string metadata_str(metadata); + free(metadata); + if (err_msg) + free(err_msg); + + // Verify JSON structure - should have expires and next_update fields + EXPECT_NE(metadata_str.find("\"expires\":"), std::string::npos); + EXPECT_NE(metadata_str.find("\"next_update\":"), std::string::npos); +} + +TEST_F(KeycacheTest, GetMetadataMissingTest) { + // Test metadata API with missing issuer + char *err_msg = nullptr; + char *metadata = nullptr; + + // Try to get metadata for non-existent issuer + auto rv = keycache_get_jwks_metadata("https://demo.scitokens.org/unknown", + &metadata, &err_msg); + ASSERT_FALSE(rv == 0); // Should fail + if (err_msg) + free(err_msg); +} + +TEST_F(KeycacheTest, DeleteJwksTest) { + // Test delete API + char *err_msg = nullptr; + + // First verify the issuer is in cache + char *jwks = nullptr; + auto rv = + keycache_get_cached_jwks(demo_scitokens_url.c_str(), &jwks, &err_msg); + ASSERT_TRUE(rv == 0) << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(jwks != nullptr); + free(jwks); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Delete the entry + rv = keycache_delete_jwks(demo_scitokens_url.c_str(), &err_msg); + ASSERT_TRUE(rv == 0) << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Verify it's gone - get_cached_jwks should return empty keys + rv = keycache_get_cached_jwks(demo_scitokens_url.c_str(), &jwks, &err_msg); + ASSERT_TRUE(rv == 0) << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(jwks != nullptr); + std::string jwks_str(jwks); + free(jwks); + if (err_msg) + free(err_msg); + + EXPECT_EQ(jwks_str, "{\"keys\": []}"); +} + +TEST_F(KeycacheTest, DeleteJwksNonExistentTest) { + // Test delete API with non-existent issuer - should not fail + char *err_msg = nullptr; + + auto rv = keycache_delete_jwks("https://demo.scitokens.org/never-existed", + &err_msg); + ASSERT_TRUE(rv == 0) + << (err_msg ? err_msg : "unknown error"); // Should succeed (idempotent) + if (err_msg) + free(err_msg); +} + class IssuerSecurityTest : public ::testing::Test { protected: void SetUp() override {