diff --git a/CMakeLists.txt b/CMakeLists.txt index 93973fb..ec6d2b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,10 @@ add_library(SciTokens SHARED src/scitokens.cpp src/scitokens_internal.cpp src/sc target_compile_features(SciTokens PUBLIC cxx_std_11) # Use at least C++11 for building and when linking to scitokens target_include_directories(SciTokens PUBLIC ${JWT_CPP_INCLUDES} "${PROJECT_SOURCE_DIR}/src" PRIVATE ${CURL_INCLUDE_DIRS} ${OPENSSL_INCLUDE_DIRS} ${LIBCRYPTO_INCLUDE_DIRS} ${SQLITE_INCLUDE_DIRS} ${UUID_INCLUDE_DIRS}) -target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES}) +# Find threading library +find_package(Threads REQUIRED) + +target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES} Threads::Threads) if (UNIX) # pkg_check_modules fails to return an absolute path on RHEL7. Set the # link directories accordingly. diff --git a/src/scitokens.cpp b/src/scitokens.cpp index 2693140..0919bcc 100644 --- a/src/scitokens.cpp +++ b/src/scitokens.cpp @@ -43,14 +43,17 @@ void load_config_from_environment() { bool is_int; }; - const std::array known_configs = { + const std::array known_configs = { {{"keycache.update_interval_s", "KEYCACHE_UPDATE_INTERVAL_S", true}, {"keycache.expiration_interval_s", "KEYCACHE_EXPIRATION_INTERVAL_S", true}, {"keycache.cache_home", "KEYCACHE_CACHE_HOME", false}, {"tls.ca_file", "TLS_CA_FILE", false}, {"monitoring.file", "MONITORING_FILE", false}, - {"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true}}}; + {"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true}, + {"keycache.refresh_interval_ms", "KEYCACHE_REFRESH_INTERVAL_MS", true}, + {"keycache.refresh_threshold_ms", "KEYCACHE_REFRESH_THRESHOLD_MS", + true}}}; const char *prefix = "SCITOKEN_CONFIG_"; @@ -128,6 +131,13 @@ int configurer::Configuration::get_monitoring_file_interval() { return m_monitoring_file_interval; } +// Background refresh config +std::atomic_bool configurer::Configuration::m_background_refresh_enabled{false}; +std::atomic_int configurer::Configuration::m_refresh_interval_ms{ + 60000}; // 60 seconds +std::atomic_int configurer::Configuration::m_refresh_threshold_ms{ + 600000}; // 10 minutes + SciTokenKey scitoken_key_create(const char *key_id, const char *alg, const char *public_contents, const char *private_contents, char **err_msg) { @@ -1099,6 +1109,31 @@ int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg) { return 0; } +int keycache_set_background_refresh(int enabled, char **err_msg) { + try { + bool enable = (enabled != 0); + configurer::Configuration::set_background_refresh_enabled(enable); + + if (enable) { + scitokens::internal::BackgroundRefreshManager::get_instance() + .start(); + } else { + scitokens::internal::BackgroundRefreshManager::get_instance() + .stop(); + } + } catch (std::exception &exc) { + if (err_msg) { + *err_msg = strdup(exc.what()); + } + return -1; + } + return 0; +} + +int keycache_stop_background_refresh(char **err_msg) { + return keycache_set_background_refresh(0, err_msg); +} + int config_set_int(const char *key, int value, char **err_msg) { return scitoken_config_set_int(key, value, err_msg); } @@ -1145,6 +1180,28 @@ int scitoken_config_set_int(const char *key, int value, char **err_msg) { return 0; } + else if (_key == "keycache.refresh_interval_ms") { + if (value < 0) { + if (err_msg) { + *err_msg = strdup("Refresh interval must be positive."); + } + return -1; + } + configurer::Configuration::set_refresh_interval(value); + return 0; + } + + else if (_key == "keycache.refresh_threshold_ms") { + if (value < 0) { + if (err_msg) { + *err_msg = strdup("Refresh threshold must be positive."); + } + return -1; + } + configurer::Configuration::set_refresh_threshold(value); + return 0; + } + else { if (err_msg) { *err_msg = strdup("Key not recognized."); @@ -1178,6 +1235,14 @@ int scitoken_config_get_int(const char *key, char **err_msg) { return configurer::Configuration::get_monitoring_file_interval(); } + else if (_key == "keycache.refresh_interval_ms") { + return configurer::Configuration::get_refresh_interval(); + } + + else if (_key == "keycache.refresh_threshold_ms") { + return configurer::Configuration::get_refresh_threshold(); + } + else { if (err_msg) { *err_msg = strdup("Key not recognized."); diff --git a/src/scitokens.h b/src/scitokens.h index cdf3953..88dcc68 100644 --- a/src/scitokens.h +++ b/src/scitokens.h @@ -290,6 +290,25 @@ int keycache_get_cached_jwks(const char *issuer, char **jwks, char **err_msg); */ int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg); +/** + * Enable or disable the background refresh thread for JWKS. + * - When enabled, a background thread will periodically check if any known + * issuers need their JWKS refreshed based on the configured refresh interval + * and threshold. + * - If enabled=1 and the thread is not running, it will be started. + * - If enabled=0 and the thread is running, it will be stopped gracefully. + * - Returns 0 on success, nonzero on failure. + */ +int keycache_set_background_refresh(int enabled, char **err_msg); + +/** + * Stop the background refresh thread if it is running. + * - This is a convenience function equivalent to + * keycache_set_background_refresh(0, err_msg). + * - Returns 0 on success, nonzero on failure. + */ +int keycache_stop_background_refresh(char **err_msg); + /** * APIs for managing scitokens configuration parameters. */ @@ -308,6 +327,10 @@ int config_set_int(const char *key, int value, char **err_msg); * - "keycache.expiration_interval_s": Key cache expiration time (seconds) * - "monitoring.file_interval_s": Interval between monitoring file writes * (seconds, default 60) + * - "keycache.refresh_interval_ms": Background refresh thread check interval + * (milliseconds, default 60000) + * - "keycache.refresh_threshold_ms": Time before next_update when background + * refresh triggers (milliseconds, default 600000) */ int scitoken_config_set_int(const char *key, int value, char **err_msg); @@ -325,6 +348,10 @@ int config_get_int(const char *key, char **err_msg); * - "keycache.expiration_interval_s": Key cache expiration time (seconds) * - "monitoring.file_interval_s": Interval between monitoring file writes * (seconds, default 60) + * - "keycache.refresh_interval_ms": Background refresh thread check interval + * (milliseconds, default 60000) + * - "keycache.refresh_threshold_ms": Time before next_update when background + * refresh triggers (milliseconds, default 600000) */ int scitoken_config_get_int(const char *key, char **err_msg); diff --git a/src/scitokens_cache.cpp b/src/scitokens_cache.cpp index 12536c1..36173e2 100644 --- a/src/scitokens_cache.cpp +++ b/src/scitokens_cache.cpp @@ -308,3 +308,81 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer, sqlite3_close(db); return true; } + +std::vector> +scitokens::Validator::get_all_issuers_from_db(int64_t now) { + std::vector> result; + + auto cache_fname = get_cache_file(); + if (cache_fname.size() == 0) { + return result; + } + + sqlite3 *db; + int rc = sqlite3_open(cache_fname.c_str(), &db); + if (rc) { + sqlite3_close(db); + return result; + } + + sqlite3_stmt *stmt; + rc = sqlite3_prepare_v2(db, "SELECT issuer, keys FROM keycache", -1, &stmt, + NULL); + if (rc != SQLITE_OK) { + sqlite3_close(db); + return result; + } + + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + const unsigned char *issuer_data = sqlite3_column_text(stmt, 0); + const unsigned char *keys_data = sqlite3_column_text(stmt, 1); + + if (!issuer_data || !keys_data) { + continue; + } + + std::string issuer(reinterpret_cast(issuer_data)); + std::string metadata(reinterpret_cast(keys_data)); + + // Parse the metadata to get next_update and check expiry + picojson::value json_obj; + auto err = picojson::parse(json_obj, metadata); + if (!err.empty() || !json_obj.is()) { + continue; + } + + auto top_obj = json_obj.get(); + + // Get expiry time + auto expires_iter = top_obj.find("expires"); + if (expires_iter == top_obj.end() || + !expires_iter->second.is()) { + continue; + } + auto expiry = expires_iter->second.get(); + + // Get next_update time + auto next_update_iter = top_obj.find("next_update"); + int64_t next_update; + if (next_update_iter == top_obj.end() || + !next_update_iter->second.is()) { + // If next_update is not set, default to 4 hours before expiry + next_update = expiry - 4 * 3600; + } else { + next_update = next_update_iter->second.get(); + } + + // Include expired entries - they should be refreshed after a long + // downtime If expired, set next_update to now so they get refreshed + // immediately + if (now > expiry) { + next_update = now; + } + + result.push_back({issuer, next_update}); + } + + sqlite3_finalize(stmt); + sqlite3_close(db); + return result; +} diff --git a/src/scitokens_internal.cpp b/src/scitokens_internal.cpp index 4938f1d..025717b 100644 --- a/src/scitokens_internal.cpp +++ b/src/scitokens_internal.cpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -37,8 +38,101 @@ std::mutex key_refresh_mutex; namespace scitokens { +// Define the static once_flag for Validator +std::once_flag Validator::m_background_refresh_once; + namespace internal { +// BackgroundRefreshManager implementation +void BackgroundRefreshManager::start() { + std::lock_guard lock(m_mutex); + if (m_running.load(std::memory_order_acquire)) { + return; // Already running + } + m_shutdown.store(false, std::memory_order_release); + m_running.store(true, std::memory_order_release); + m_thread = std::make_unique( + &BackgroundRefreshManager::refresh_loop, this); +} + +void BackgroundRefreshManager::stop() { + std::unique_ptr thread_to_join; + + { + std::lock_guard lock(m_mutex); + if (!m_running.load(std::memory_order_acquire)) { + return; // Not running + } + + m_shutdown.store(true, std::memory_order_release); + m_running.store(false, std::memory_order_release); + thread_to_join = std::move(m_thread); + } + + m_cv.notify_all(); + + if (thread_to_join && thread_to_join->joinable()) { + thread_to_join->join(); + } +} + +void BackgroundRefreshManager::refresh_loop() { + while (!m_shutdown.load(std::memory_order_acquire)) { + auto interval = configurer::Configuration::get_refresh_interval(); + auto threshold = configurer::Configuration::get_refresh_threshold(); + + // Wait for the interval or until shutdown + { + std::unique_lock lock(m_mutex); + m_cv.wait_for(lock, std::chrono::milliseconds(interval), [this]() { + return m_shutdown.load(std::memory_order_acquire); + }); + } + + if (m_shutdown.load(std::memory_order_acquire)) { + break; + } + + // Get list of issuers from the database + auto now = std::time(NULL); + auto issuers = scitokens::Validator::get_all_issuers_from_db(now); + + for (const auto &issuer_pair : issuers) { + if (m_shutdown.load(std::memory_order_acquire)) { + break; + } + + const auto &issuer = issuer_pair.first; + const auto &next_update = issuer_pair.second; + + // Calculate time until next_update in milliseconds + int64_t time_until_update = (next_update - now) * 1000; + + // If next update is within threshold, try to refresh + if (time_until_update <= threshold) { + auto &stats = + MonitoringStats::instance().get_issuer_stats(issuer); + try { + // Perform refresh (this will use the refresh_jwks method) + scitokens::Validator::refresh_jwks(issuer); + stats.inc_background_successful_refresh(); + } catch (std::exception &) { + // Track failed refresh attempts + stats.inc_background_failed_refresh(); + // Silently ignore errors in background refresh to avoid + // disrupting the application. Background refresh is a + // best-effort optimization. If it fails, the next token + // verification will trigger a foreground refresh as usual. + } + } + } + + // Write monitoring file from background thread if configured + // This avoids writing from verify() when background thread is running + MonitoringStats::instance().maybe_write_monitoring_file(); + } +} + SimpleCurlGet::GetStatus SimpleCurlGet::perform_start(const std::string &url) { m_len = 0; diff --git a/src/scitokens_internal.h b/src/scitokens_internal.h index d8b8970..68f2050 100644 --- a/src/scitokens_internal.h +++ b/src/scitokens_internal.h @@ -6,8 +6,10 @@ #include #include +#include #include #include +#include #include #if defined(__GNUC__) @@ -60,6 +62,22 @@ class Configuration { return m_monitoring_file_configured.load(std::memory_order_relaxed); } + // Background refresh configuration + static void set_background_refresh_enabled(bool enabled) { + m_background_refresh_enabled = enabled; + } + static bool get_background_refresh_enabled() { + return m_background_refresh_enabled; + } + static void set_refresh_interval(int interval_ms) { + m_refresh_interval_ms = interval_ms; + } + static int get_refresh_interval() { return m_refresh_interval_ms; } + static void set_refresh_threshold(int threshold_ms) { + m_refresh_threshold_ms = threshold_ms; + } + static int get_refresh_threshold() { return m_refresh_threshold_ms; } + private: // Accessor functions for construct-on-first-use idiom static std::atomic_int &get_next_update_delta_ref() { @@ -108,6 +126,9 @@ class Configuration { static std::mutex m_monitoring_file_mutex; static std::atomic m_monitoring_file_configured; // Fast-path flag static std::atomic_int m_monitoring_file_interval; // In seconds, default 60 + static std::atomic_bool m_background_refresh_enabled; + static std::atomic_int m_refresh_interval_ms; // N milliseconds + static std::atomic_int m_refresh_threshold_ms; // M milliseconds // static bool check_dir(const std::string dir_path); static std::pair mkdir_and_parents_if_needed(const std::string dir_path); @@ -122,6 +143,45 @@ namespace internal { // Forward declaration class MonitoringStats; +/** + * Manages the background thread for refreshing JWKS. + * This is a singleton that starts/stops a background thread which periodically + * checks if any known issuers need their JWKS refreshed. + */ +class BackgroundRefreshManager { + public: + static BackgroundRefreshManager &get_instance() { + static BackgroundRefreshManager instance; + return instance; + } + + // Start the background refresh thread (can be called multiple times) + void start(); + + // Stop the background refresh thread (can be called multiple times) + void stop(); + + // Check if the background refresh thread is running + bool is_running() const { + return m_running.load(std::memory_order_acquire); + } + + private: + BackgroundRefreshManager() = default; + ~BackgroundRefreshManager() { stop(); } + BackgroundRefreshManager(const BackgroundRefreshManager &) = delete; + BackgroundRefreshManager & + operator=(const BackgroundRefreshManager &) = delete; + + void refresh_loop(); + + std::mutex m_mutex; + std::condition_variable m_cv; + std::unique_ptr m_thread; + std::atomic_bool m_shutdown{false}; + std::atomic_bool m_running{false}; +}; + class SimpleCurlGet { int m_maxbytes{1048576}; @@ -200,31 +260,63 @@ struct IssuerStats { std::atomic failed_refreshes{0}; std::atomic stale_key_uses{0}; - // Increment methods for atomic counters - void inc_successful_validation() { successful_validations++; } - void inc_unsuccessful_validation() { unsuccessful_validations++; } - void inc_expired_token() { expired_tokens++; } - void inc_sync_validation_started() { sync_validations_started++; } - void inc_async_validation_started() { async_validations_started++; } - void inc_stale_key_use() { stale_key_uses++; } - void inc_failed_refresh() { failed_refreshes++; } - void inc_expired_key() { expired_keys++; } - void inc_successful_key_lookup() { successful_key_lookups++; } - void inc_failed_key_lookup() { failed_key_lookups++; } - - // Time setters that accept std::chrono::duration + // Background refresh statistics (tracked by background thread) + std::atomic background_successful_refreshes{0}; + std::atomic background_failed_refreshes{0}; + + // Increment methods for atomic counters (use relaxed ordering for stats) + void inc_successful_validation() { + successful_validations.fetch_add(1, std::memory_order_relaxed); + } + void inc_unsuccessful_validation() { + unsuccessful_validations.fetch_add(1, std::memory_order_relaxed); + } + void inc_expired_token() { + expired_tokens.fetch_add(1, std::memory_order_relaxed); + } + void inc_sync_validation_started() { + sync_validations_started.fetch_add(1, std::memory_order_relaxed); + } + void inc_async_validation_started() { + async_validations_started.fetch_add(1, std::memory_order_relaxed); + } + void inc_stale_key_use() { + stale_key_uses.fetch_add(1, std::memory_order_relaxed); + } + void inc_failed_refresh() { + failed_refreshes.fetch_add(1, std::memory_order_relaxed); + } + void inc_expired_key() { + expired_keys.fetch_add(1, std::memory_order_relaxed); + } + void inc_successful_key_lookup() { + successful_key_lookups.fetch_add(1, std::memory_order_relaxed); + } + void inc_failed_key_lookup() { + failed_key_lookups.fetch_add(1, std::memory_order_relaxed); + } + void inc_background_successful_refresh() { + background_successful_refreshes.fetch_add(1, std::memory_order_relaxed); + } + void inc_background_failed_refresh() { + background_failed_refreshes.fetch_add(1, std::memory_order_relaxed); + } + + // Time setters that accept std::chrono::duration (use relaxed ordering) template void add_sync_time(std::chrono::duration duration) { auto ns = std::chrono::duration_cast(duration); - sync_total_time_ns += static_cast(ns.count()); + sync_total_time_ns.fetch_add(static_cast(ns.count()), + std::memory_order_relaxed); } template void add_async_time(std::chrono::duration duration) { auto ns = std::chrono::duration_cast(duration); - async_total_time_ns += static_cast(ns.count()); + async_total_time_ns.fetch_add(static_cast(ns.count()), + std::memory_order_relaxed); } template @@ -232,21 +324,27 @@ struct IssuerStats { add_failed_key_lookup_time(std::chrono::duration duration) { auto ns = std::chrono::duration_cast(duration); - failed_key_lookup_time_ns += static_cast(ns.count()); + failed_key_lookup_time_ns.fetch_add(static_cast(ns.count()), + std::memory_order_relaxed); } void inc_failed_key_lookup(std::chrono::nanoseconds duration) { - failed_key_lookups++; - failed_key_lookup_time_ns += static_cast(duration.count()); + failed_key_lookups.fetch_add(1, std::memory_order_relaxed); + failed_key_lookup_time_ns.fetch_add( + static_cast(duration.count()), std::memory_order_relaxed); } - // Time getters that return seconds as double + // Time getters that return seconds as double (use relaxed ordering) double get_sync_time_s() const { - return static_cast(sync_total_time_ns.load()) / 1e9; + return static_cast( + sync_total_time_ns.load(std::memory_order_relaxed)) / + 1e9; } double get_async_time_s() const { - return static_cast(async_total_time_ns.load()) / 1e9; + return static_cast( + async_total_time_ns.load(std::memory_order_relaxed)) / + 1e9; } double get_total_time_s() const { @@ -254,7 +352,9 @@ struct IssuerStats { } double get_failed_key_lookup_time_s() const { - return static_cast(failed_key_lookup_time_ns.load()) / 1e9; + return static_cast( + failed_key_lookup_time_ns.load(std::memory_order_relaxed)) / + 1e9; } }; @@ -305,6 +405,13 @@ class MonitoringStats { */ void maybe_write_monitoring_file() noexcept; + /** + * Same as maybe_write_monitoring_file(), but skips if background refresh + * thread is running. This should be called from verify() routines to + * avoid redundant writes when the background thread is handling them. + */ + void maybe_write_monitoring_file_from_verify() noexcept; + private: MonitoringStats() = default; ~MonitoringStats() = default; @@ -627,6 +734,8 @@ class SciToken { class Validator { + friend class internal::BackgroundRefreshManager; + typedef int (*StringValidatorFunction)(const char *value, char **err_msg); typedef bool (*ClaimValidatorFunction)(const jwt::claim &claim_value, void *data); @@ -656,8 +765,9 @@ class Validator { void verify(const SciToken &scitoken, time_t expiry_time) { // Check if monitoring file should be written (fast-path, relaxed - // atomic) - internal::MonitoringStats::instance().maybe_write_monitoring_file(); + // atomic). Skip if background thread is running. + internal::MonitoringStats::instance() + .maybe_write_monitoring_file_from_verify(); std::string issuer = ""; auto start_time = std::chrono::steady_clock::now(); @@ -752,8 +862,9 @@ class Validator { void verify(const jwt::decoded_jwt &jwt) { // Check if monitoring file should be written (fast-path, relaxed - // atomic) - internal::MonitoringStats::instance().maybe_write_monitoring_file(); + // atomic). Skip if background thread is running. + internal::MonitoringStats::instance() + .maybe_write_monitoring_file_from_verify(); std::string issuer = ""; auto start_time = std::chrono::steady_clock::now(); @@ -810,6 +921,13 @@ class Validator { std::unique_ptr verify_async(const jwt::decoded_jwt &jwt) { + // Start background refresh thread if configured on first verification + std::call_once(m_background_refresh_once, []() { + if (configurer::Configuration::get_background_refresh_enabled()) { + internal::BackgroundRefreshManager::get_instance().start(); + } + }); + // If token has a typ header claim (RFC8725 Section 3.11), trust that in // COMPAT mode. if (jwt.has_type()) { @@ -1156,6 +1274,14 @@ class Validator { */ static std::string get_jwks(const std::string &issuer); + /** + * Get all issuers from the database along with their next_update times. + * Returns a vector of pairs (issuer, next_update). + * Only returns non-expired entries. + */ + static std::vector> + get_all_issuers_from_db(int64_t now); + private: static std::unique_ptr get_public_key_pem(const std::string &issuer, const std::string &kid, @@ -1226,6 +1352,9 @@ class Validator { std::vector m_critical_claims; std::vector m_allowed_issuers; + + // Once flag for starting background refresh on first verification + static std::once_flag m_background_refresh_once; }; class Enforcer { diff --git a/src/scitokens_monitoring.cpp b/src/scitokens_monitoring.cpp index b264287..ad7216f 100644 --- a/src/scitokens_monitoring.cpp +++ b/src/scitokens_monitoring.cpp @@ -84,18 +84,22 @@ std::string MonitoringStats::get_json() const { const IssuerStats &stats = entry.second; picojson::object issuer_obj; - issuer_obj["successful_validations"] = picojson::value( - static_cast(stats.successful_validations.load())); + issuer_obj["successful_validations"] = + picojson::value(static_cast( + stats.successful_validations.load(std::memory_order_relaxed))); issuer_obj["unsuccessful_validations"] = picojson::value( - static_cast(stats.unsuccessful_validations.load())); - issuer_obj["expired_tokens"] = - picojson::value(static_cast(stats.expired_tokens.load())); + static_cast(stats.unsuccessful_validations.load( + std::memory_order_relaxed))); + issuer_obj["expired_tokens"] = picojson::value(static_cast( + stats.expired_tokens.load(std::memory_order_relaxed))); // Validation started counters issuer_obj["sync_validations_started"] = picojson::value( - static_cast(stats.sync_validations_started.load())); + static_cast(stats.sync_validations_started.load( + std::memory_order_relaxed))); issuer_obj["async_validations_started"] = picojson::value( - static_cast(stats.async_validations_started.load())); + static_cast(stats.async_validations_started.load( + std::memory_order_relaxed))); // Duration tracking issuer_obj["sync_total_time_s"] = @@ -106,20 +110,29 @@ std::string MonitoringStats::get_json() const { picojson::value(stats.get_total_time_s()); // Web lookup statistics - issuer_obj["successful_key_lookups"] = picojson::value( - static_cast(stats.successful_key_lookups.load())); - issuer_obj["failed_key_lookups"] = picojson::value( - static_cast(stats.failed_key_lookups.load())); + issuer_obj["successful_key_lookups"] = + picojson::value(static_cast( + stats.successful_key_lookups.load(std::memory_order_relaxed))); + issuer_obj["failed_key_lookups"] = picojson::value(static_cast( + stats.failed_key_lookups.load(std::memory_order_relaxed))); issuer_obj["failed_key_lookup_time_s"] = picojson::value(stats.get_failed_key_lookup_time_s()); // Key refresh statistics - issuer_obj["expired_keys"] = - picojson::value(static_cast(stats.expired_keys.load())); - issuer_obj["failed_refreshes"] = - picojson::value(static_cast(stats.failed_refreshes.load())); - issuer_obj["stale_key_uses"] = - picojson::value(static_cast(stats.stale_key_uses.load())); + issuer_obj["expired_keys"] = picojson::value(static_cast( + stats.expired_keys.load(std::memory_order_relaxed))); + issuer_obj["failed_refreshes"] = picojson::value(static_cast( + stats.failed_refreshes.load(std::memory_order_relaxed))); + issuer_obj["stale_key_uses"] = picojson::value(static_cast( + stats.stale_key_uses.load(std::memory_order_relaxed))); + + // Background refresh statistics + issuer_obj["background_successful_refreshes"] = picojson::value( + static_cast(stats.background_successful_refreshes.load( + std::memory_order_relaxed))); + issuer_obj["background_failed_refreshes"] = picojson::value( + static_cast(stats.background_failed_refreshes.load( + std::memory_order_relaxed))); std::string sanitized_issuer = sanitize_issuer_for_json(issuer); issuers_obj[sanitized_issuer] = picojson::value(issuer_obj); @@ -135,7 +148,7 @@ std::string MonitoringStats::get_json() const { sanitize_issuer_for_json(entry.first); picojson::object lookup_stats; lookup_stats["count"] = - picojson::value(static_cast(entry.second.count)); + picojson::value(static_cast(entry.second.count)); lookup_stats["total_time_s"] = picojson::value(entry.second.total_time_s); failed_obj[sanitized_issuer] = picojson::value(lookup_stats); @@ -190,6 +203,15 @@ void MonitoringStats::maybe_write_monitoring_file() noexcept { } } +void MonitoringStats::maybe_write_monitoring_file_from_verify() noexcept { + // If background refresh thread is running, it will handle the writes + // This avoids redundant writes and potential contention + if (BackgroundRefreshManager::get_instance().is_running()) { + return; + } + maybe_write_monitoring_file(); +} + void MonitoringStats::write_monitoring_file_impl() noexcept { try { std::string monitoring_file = diff --git a/test/integration_test.cpp b/test/integration_test.cpp index 15922cc..e07f444 100644 --- a/test/integration_test.cpp +++ b/test/integration_test.cpp @@ -39,6 +39,9 @@ class MonitoringStats { uint64_t expired_keys{0}; uint64_t failed_refreshes{0}; uint64_t stale_key_uses{0}; + // Background refresh statistics + uint64_t background_successful_refreshes{0}; + uint64_t background_failed_refreshes{0}; }; struct FailedIssuerLookup { @@ -157,6 +160,19 @@ class MonitoringStats { static_cast(it->second.get()); } + // Background refresh statistics + it = stats_obj.find("background_successful_refreshes"); + if (it != stats_obj.end() && it->second.is()) { + stats.background_successful_refreshes = + static_cast(it->second.get()); + } + + it = stats_obj.find("background_failed_refreshes"); + if (it != stats_obj.end() && it->second.is()) { + stats.background_failed_refreshes = + static_cast(it->second.get()); + } + issuers_[issuer_entry.first] = stats; } } @@ -1106,6 +1122,220 @@ TEST_F(IntegrationTest, MonitoringFileOutput) { std::remove(test_file.c_str()); } +// ============================================================================= +// Background JWKS Refresh Test +// ============================================================================= + +TEST_F(IntegrationTest, BackgroundRefreshTest) { + char *err_msg = nullptr; + + // Reset monitoring stats to get a clean baseline + scitoken_reset_monitoring_stats(&err_msg); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Set smaller intervals for testing (1 second refresh interval, 2 seconds + // threshold) + int rv = + scitoken_config_set_int("keycache.refresh_interval_ms", 1000, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set refresh interval: " + << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + rv = scitoken_config_set_int("keycache.refresh_threshold_ms", 2000, + &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set refresh threshold: " + << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Set update interval to 1 second BEFORE first verification so the + // cache entry will have next_update just 1 second in the future. + // This ensures the background thread can refresh within the test window. + rv = scitoken_config_set_int("keycache.update_interval_s", 1, &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Create a key and token + std::unique_ptr key( + scitoken_key_create("test-key-1", "ES256", public_key_.c_str(), + private_key_.c_str(), &err_msg), + scitoken_key_destroy); + ASSERT_TRUE(key.get() != nullptr); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::unique_ptr token( + scitoken_create(key.get()), scitoken_destroy); + ASSERT_TRUE(token.get() != nullptr); + + rv = scitoken_set_claim_string(token.get(), "iss", issuer_url_.c_str(), + &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + rv = + scitoken_set_claim_string(token.get(), "sub", "test-subject", &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + rv = + scitoken_set_claim_string(token.get(), "scope", "read:/test", &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + scitoken_set_lifetime(token.get(), 3600); + + char *token_value = nullptr; + rv = scitoken_serialize(token.get(), &token_value, &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + std::unique_ptr token_value_ptr(token_value, free); + + // First verification - this will fetch JWKS and track the issuer + std::unique_ptr verify_token( + scitoken_create(nullptr), scitoken_destroy); + ASSERT_TRUE(verify_token.get() != nullptr); + + rv = scitoken_deserialize_v2(token_value, verify_token.get(), nullptr, + &err_msg); + ASSERT_EQ(rv, 0) << "Failed to verify token: " + << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Get the current JWKS to verify it exists + char *jwks_before = nullptr; + rv = keycache_get_cached_jwks(issuer_url_.c_str(), &jwks_before, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to get cached JWKS: " + << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(jwks_before != nullptr); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::cout << "Initial JWKS fetched successfully" << std::endl; + + // Re-set the JWKS to force a fresh cache entry with the current + // update_interval (1 second). This ensures next_update is just 1 second + // in the future so the background thread will refresh it. + rv = keycache_set_jwks(issuer_url_.c_str(), jwks_before, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set JWKS: " + << (err_msg ? err_msg : "unknown error"); + free(jwks_before); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::cout << "JWKS re-set with 1-second update interval" << std::endl; + + // Get monitoring stats before background refresh + auto before_stats = getCurrentMonitoringStats(); + auto before_issuer_stats = before_stats.getIssuerStats(issuer_url_); + std::cout << "Before background refresh:" << std::endl; + std::cout << " background_successful_refreshes: " + << before_issuer_stats.background_successful_refreshes + << std::endl; + std::cout << " background_failed_refreshes: " + << before_issuer_stats.background_failed_refreshes << std::endl; + + // Enable background refresh + rv = keycache_set_background_refresh(1, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to enable background refresh: " + << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::cout << "Background refresh enabled" << std::endl; + + // Wait for background refresh to trigger (threshold is 2 seconds, interval + // is 1 second) We need to wait at least 3 seconds: 1s for next_update to be + // within threshold + 2s for detection Note: Using sleep() is acceptable for + // integration tests as we're verifying real-time behavior of the background + // thread against an actual HTTPS server + std::cout << "Waiting 4 seconds for background refresh..." << std::endl; + sleep(4); + + // Stop background refresh + rv = keycache_stop_background_refresh(&err_msg); + ASSERT_EQ(rv, 0) << "Failed to stop background refresh: " + << (err_msg ? err_msg : "unknown error"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::cout << "Background refresh stopped successfully" << std::endl; + + // Verify we can still access the JWKS + char *jwks_after = nullptr; + rv = keycache_get_cached_jwks(issuer_url_.c_str(), &jwks_after, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to get cached JWKS after background refresh: " + << (err_msg ? err_msg : "unknown error"); + ASSERT_TRUE(jwks_after != nullptr); + std::unique_ptr jwks_after_ptr(jwks_after, free); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Verify that background refresh statistics increased for our issuer + auto after_stats = getCurrentMonitoringStats(); + auto after_issuer_stats = after_stats.getIssuerStats(issuer_url_); + + std::cout << "After background refresh:" << std::endl; + std::cout << " background_successful_refreshes: " + << after_issuer_stats.background_successful_refreshes + << std::endl; + std::cout << " background_failed_refreshes: " + << after_issuer_stats.background_failed_refreshes << std::endl; + + // The background thread should have performed at least one refresh + // for our issuer (either successful or failed) + uint64_t total_background_refreshes = + after_issuer_stats.background_successful_refreshes + + after_issuer_stats.background_failed_refreshes; + uint64_t before_total = + before_issuer_stats.background_successful_refreshes + + before_issuer_stats.background_failed_refreshes; + + EXPECT_GT(total_background_refreshes, before_total) + << "Background refresh thread should have performed at least one " + "refresh attempt for our issuer"; + + std::cout << "Test completed successfully" << std::endl; +} + } // namespace int main(int argc, char **argv) { diff --git a/test/monitoring_test.cpp b/test/monitoring_test.cpp index c27c3b0..327d35b 100644 --- a/test/monitoring_test.cpp +++ b/test/monitoring_test.cpp @@ -77,34 +77,34 @@ class MonitoringStats { issuer_entry.second.get(); auto it = stats_obj.find("successful_validations"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.successful_validations = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("unsuccessful_validations"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.unsuccessful_validations = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("expired_tokens"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.expired_tokens = - static_cast(it->second.get()); + static_cast(it->second.get()); } // Validation started counters it = stats_obj.find("sync_validations_started"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.sync_validations_started = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("async_validations_started"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.async_validations_started = - static_cast(it->second.get()); + static_cast(it->second.get()); } // Duration tracking @@ -126,15 +126,15 @@ class MonitoringStats { // Key lookup statistics it = stats_obj.find("successful_key_lookups"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.successful_key_lookups = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("failed_key_lookups"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.failed_key_lookups = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("failed_key_lookup_time_s"); @@ -145,21 +145,21 @@ class MonitoringStats { // Key refresh statistics it = stats_obj.find("expired_keys"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.expired_keys = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("failed_refreshes"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.failed_refreshes = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = stats_obj.find("stale_key_uses"); - if (it != stats_obj.end() && it->second.is()) { + if (it != stats_obj.end() && it->second.is()) { stats.stale_key_uses = - static_cast(it->second.get()); + static_cast(it->second.get()); } issuers_[issuer_entry.first] = stats; @@ -179,9 +179,9 @@ class MonitoringStats { auto &lookup_obj = entry.second.get(); auto it = lookup_obj.find("count"); - if (it != lookup_obj.end() && it->second.is()) { + if (it != lookup_obj.end() && it->second.is()) { lookup.count = - static_cast(it->second.get()); + static_cast(it->second.get()); } it = lookup_obj.find("total_time_s");