diff --git a/src/scitokens_cache.cpp b/src/scitokens_cache.cpp index 36173e2..a418923 100644 --- a/src/scitokens_cache.cpp +++ b/src/scitokens_cache.cpp @@ -18,6 +18,10 @@ namespace { +// Timeout in milliseconds to wait when database is locked +// This handles concurrent access from multiple threads/processes +constexpr int SQLITE_BUSY_TIMEOUT_MS = 5000; + void initialize_cachedb(const std::string &keycache_file) { sqlite3 *db; @@ -27,6 +31,8 @@ void initialize_cachedb(const std::string &keycache_file) { sqlite3_close(db); return; } + // Set busy timeout to handle concurrent access + sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS); char *err_msg = nullptr; rc = sqlite3_exec(db, "CREATE TABLE IF NOT EXISTS keycache (" @@ -161,6 +167,8 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer, sqlite3_close(db); return false; } + // Set busy timeout to handle concurrent access + sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS); sqlite3_stmt *stmt; rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?", @@ -260,6 +268,8 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer, sqlite3_close(db); return false; } + // Set busy timeout to handle concurrent access + sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS); if ((rc = sqlite3_exec(db, "BEGIN", 0, 0, 0)) != SQLITE_OK) { sqlite3_close(db); diff --git a/src/scitokens_internal.cpp b/src/scitokens_internal.cpp index 025717b..ffd7be5 100644 --- a/src/scitokens_internal.cpp +++ b/src/scitokens_internal.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,47 @@ CurlRaii myCurl; std::mutex key_refresh_mutex; +// Per-issuer mutex map for preventing thundering herd on new issuers +std::mutex issuer_mutex_map_lock; +std::unordered_map> issuer_mutexes; +constexpr size_t MAX_ISSUER_MUTEXES = 1000; + +// Get or create a mutex for a specific issuer +std::shared_ptr get_issuer_mutex(const std::string &issuer) { + std::lock_guard 
guard(issuer_mutex_map_lock); + + auto it = issuer_mutexes.find(issuer); + if (it != issuer_mutexes.end()) { + return it->second; + } + + // Prevent resource exhaustion: limit the number of cached mutexes + if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) { + // Remove mutexes that are no longer in use + // Since we hold issuer_mutex_map_lock, no other thread can acquire + // a reference to these mutexes, making this check safe + for (auto iter = issuer_mutexes.begin(); + iter != issuer_mutexes.end();) { + if (iter->second.use_count() == 1) { + // Only we hold a reference, safe to remove + iter = issuer_mutexes.erase(iter); + } else { + ++iter; + } + } + + // If still at capacity after cleanup, fail rather than unbounded growth + if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) { + throw std::runtime_error( + "Too many concurrent issuers - resource exhaustion prevented"); + } + } + + auto mutex_ptr = std::make_shared(); + issuer_mutexes[issuer] = mutex_ptr; + return mutex_ptr; +} + } // namespace namespace scitokens { @@ -948,16 +990,36 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, result->m_done = true; } } else { - // No keys in the DB, or they are expired + // No keys in the DB, or they are expired, so get them from the web. // Record that we had expired keys if the issuer was previously known - // (This is tracked by having an entry in issuer stats) auto &issuer_stats = internal::MonitoringStats::instance().get_issuer_stats(issuer); issuer_stats.inc_expired_key(); - // Get keys from the web. 
- result = get_public_keys_from_web( - issuer, internal::SimpleCurlGet::default_timeout); + // Use per-issuer lock to prevent thundering herd for new issuers + auto issuer_mutex = get_issuer_mutex(issuer); + std::unique_lock issuer_lock(*issuer_mutex); + + // Check again if keys are now in DB (another thread may have fetched + // them while we were waiting for the lock) + if (get_public_keys_from_db(issuer, now, result->m_keys, + result->m_next_update)) { + // Keys are now available, use them + result->m_continue_fetch = false; + result->m_do_store = false; + result->m_done = true; + // Lock released here - no need to hold it + } else { + // Still no keys, fetch them from the web + result = get_public_keys_from_web( + issuer, internal::SimpleCurlGet::default_timeout); + + // Transfer ownership of the lock to the async status + // The lock will be held until keys are stored in + // get_public_key_pem_continue + result->m_issuer_mutex = issuer_mutex; + result->m_issuer_lock = std::move(issuer_lock); + } } result->m_issuer = issuer; result->m_kid = kid; @@ -973,21 +1035,56 @@ Validator::get_public_key_pem_continue(std::unique_ptr status, std::string &algorithm) { if (status->m_continue_fetch) { - status = get_public_keys_from_web_continue(std::move(status)); - if (status->m_continue_fetch) { - return std::move(status); + // Save issuer and lock info before potentially moving status + std::string issuer = status->m_issuer; + auto issuer_mutex = status->m_issuer_mutex; + std::unique_lock issuer_lock( + std::move(status->m_issuer_lock)); + + try { + status = get_public_keys_from_web_continue(std::move(status)); + if (status->m_continue_fetch) { + // Restore the lock to status before returning + status->m_issuer_mutex = issuer_mutex; + status->m_issuer_lock = std::move(issuer_lock); + return std::move(status); + } + // Success - restore the lock to status for later release + status->m_issuer_mutex = issuer_mutex; + status->m_issuer_lock = std::move(issuer_lock); + } catch 
(...) { + // Web fetch failed - store empty keys as negative cache entry + // This prevents thundering herd on repeated failed lookups + if (issuer_lock.owns_lock()) { + // Store empty keys with short TTL for negative caching + auto now = std::time(NULL); + int negative_cache_ttl = + configurer::Configuration::get_next_update_delta(); + picojson::value empty_keys; + picojson::object keys_obj; + keys_obj["keys"] = picojson::value(picojson::array()); + empty_keys = picojson::value(keys_obj); + store_public_keys(issuer, empty_keys, now + negative_cache_ttl, + now + negative_cache_ttl); + issuer_lock.unlock(); + } + throw; // Re-throw the original exception } } if (status->m_do_store) { // Async web fetch completed successfully - record monitoring - if (status->m_is_refresh) { - auto &issuer_stats = - internal::MonitoringStats::instance().get_issuer_stats( - status->m_issuer); - issuer_stats.inc_successful_key_lookup(); - } + // This counts both initial fetches and refreshes + auto &issuer_stats = + internal::MonitoringStats::instance().get_issuer_stats( + status->m_issuer); + issuer_stats.inc_successful_key_lookup(); store_public_keys(status->m_issuer, status->m_keys, status->m_next_update, status->m_expires); + // Release the per-issuer lock now that keys are stored + // Other threads waiting on this issuer can now proceed + if (status->m_issuer_lock.owns_lock()) { + status->m_issuer_lock.unlock(); + } } status->m_done = true; diff --git a/src/scitokens_internal.h b/src/scitokens_internal.h index 68f2050..9c22018 100644 --- a/src/scitokens_internal.h +++ b/src/scitokens_internal.h @@ -552,6 +552,10 @@ class AsyncStatus { bool m_is_refresh{false}; // True if this is a refresh of an existing key AsyncState m_state{DOWNLOAD_METADATA}; std::unique_lock m_refresh_lock; + // Per-issuer lock to prevent thundering herd on new issuers + // We store both the shared_ptr (to keep mutex alive) and the lock + std::shared_ptr m_issuer_mutex; + std::unique_lock m_issuer_lock; 
int64_t m_next_update{-1}; int64_t m_expires{-1}; @@ -776,6 +780,8 @@ class Validator { try { auto result = verify_async(scitoken); + // Note: m_is_sync flag no longer needed since counting is only done + // in verify_async_continue // Extract issuer from the result's JWT string after decoding starts const jwt::decoded_jwt *jwt_decoded = @@ -834,7 +840,8 @@ class Validator { std::chrono::duration_cast( end_time - last_duration_update); issuer_stats->add_sync_time(delta); - issuer_stats->inc_successful_validation(); + // Note: inc_successful_validation() is called in + // verify_async_continue } } catch (const std::exception &e) { // Record failure (final duration update) @@ -882,6 +889,8 @@ class Validator { } auto result = verify_async(jwt); + // Note: m_is_sync flag no longer needed since counting is only done + // in verify_async_continue while (!result->m_done) { result = verify_async_continue(std::move(result)); } @@ -893,7 +902,8 @@ class Validator { std::chrono::duration_cast( end_time - start_time); issuer_stats->add_sync_time(duration); - issuer_stats->inc_successful_validation(); + // Note: inc_successful_validation() is called in + // verify_async_continue } } catch (const std::exception &e) { // Record failure if we have an issuer @@ -1004,6 +1014,7 @@ class Validator { // Start monitoring timing and record async validation started status->m_start_time = std::chrono::steady_clock::now(); status->m_monitoring_started = true; + status->m_issuer = jwt.get_issuer(); auto &stats = internal::MonitoringStats::instance().get_issuer_stats( jwt.get_issuer()); stats.inc_async_validation_started(); @@ -1181,9 +1192,8 @@ class Validator { } } - // Record successful validation (only for async API, sync handles its - // own) - if (status->m_monitoring_started && !status->m_is_sync) { + // Record successful validation + if (status->m_monitoring_started) { auto end_time = std::chrono::steady_clock::now(); auto duration = std::chrono::duration_cast( @@ -1195,8 +1205,11 
@@ class Validator { stats.add_async_time(duration); } + // Create new result, preserving monitoring flags std::unique_ptr result(new AsyncStatus()); result->m_done = true; + result->m_is_sync = status->m_is_sync; + result->m_monitoring_started = status->m_monitoring_started; return result; } diff --git a/test/integration_test.cpp b/test/integration_test.cpp index e07f444..21cb081 100644 --- a/test/integration_test.cpp +++ b/test/integration_test.cpp @@ -1,5 +1,6 @@ #include "../src/scitokens.h" +#include #include #include #include @@ -8,7 +9,9 @@ #include #include #include +#include #include +#include #ifndef PICOJSON_USE_INT64 #define PICOJSON_USE_INT64 @@ -1336,6 +1339,545 @@ TEST_F(IntegrationTest, BackgroundRefreshTest) { std::cout << "Test completed successfully" << std::endl; } +// Test that concurrent threads validating tokens from the same new issuer +// all succeed even when there's no pre-existing cache entry. +// Note: The per-issuer lock prevents the worst thundering herd scenarios +// by serializing DB checks after initial discovery, but the current +// implementation may still make multiple web requests if the fetch is async. +TEST_F(IntegrationTest, ConcurrentNewIssuerLookup) { + char *err_msg = nullptr; + + // Use a unique cache directory to ensure no cached keys exist + // This forces the code path where keys must be fetched from the server + std::string unique_cache_dir = "/tmp/scitokens_concurrent_test_" + + std::to_string(time(nullptr)) + "_" + + std::to_string(getpid()); + int rv = scitoken_config_set_str("keycache.cache_home", + unique_cache_dir.c_str(), &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set cache_home: " + << (err_msg ? 
err_msg : "unknown"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Reset monitoring stats before the test + rv = scitoken_reset_monitoring_stats(&err_msg); + ASSERT_EQ(rv, 0) << "Failed to reset monitoring stats"; + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Create a token with the test issuer + std::unique_ptr key( + scitoken_key_create("test-key-1", "ES256", public_key_.c_str(), + private_key_.c_str(), &err_msg), + scitoken_key_destroy); + ASSERT_TRUE(key.get() != nullptr); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::unique_ptr token( + scitoken_create(key.get()), scitoken_destroy); + ASSERT_TRUE(token.get() != nullptr); + + rv = scitoken_set_claim_string(token.get(), "iss", issuer_url_.c_str(), + &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + scitoken_set_lifetime(token.get(), 300); + + char *token_value = nullptr; + rv = scitoken_serialize(token.get(), &token_value, &err_msg); + ASSERT_EQ(rv, 0); + std::string token_str(token_value); + free(token_value); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Get initial counts before the concurrent test + auto stats_before = getCurrentMonitoringStats(); + auto initial_successful_validations = + stats_before.getIssuerStats(issuer_url_).successful_validations; + auto initial_expired_keys = + stats_before.getIssuerStats(issuer_url_).expired_keys; + auto initial_key_lookups = + stats_before.getIssuerStats(issuer_url_).successful_key_lookups; + + std::cout << "Using unique cache directory: " << unique_cache_dir + << std::endl; + std::cout << "Initial successful_validations: " + << initial_successful_validations << std::endl; + std::cout << "Initial expired_keys: " << initial_expired_keys << std::endl; + std::cout << "Initial successful_key_lookups: " << initial_key_lookups + << std::endl; + + // Launch multiple threads to concurrently validate the same token + const int NUM_THREADS = 10; + 
std::vector<std::thread> threads;
successful_key_lookups: " + << issuer_stats.successful_key_lookups + << " (new: " << new_key_lookups << ")" << std::endl; + + // The per-issuer lock should ensure only ONE thread fetches keys from web. + // All other threads should wait for the lock, then find keys in the cache. + // This is the key assertion that proves the thundering herd prevention + // works. + EXPECT_EQ(new_key_lookups, 1u) + << "Per-issuer lock should ensure only ONE web fetch for " + << NUM_THREADS << " concurrent requests"; + + // The expired_keys counter tracks entries into the "no cached keys" path. + // With a fresh cache, all threads should hit this path because they all + // check the DB before acquiring the per-issuer lock. + EXPECT_EQ(new_expired_keys, static_cast(NUM_THREADS)) + << "All threads should enter the expired_keys code path"; + + // Cleanup: remove the temporary cache directory + std::string rm_cmd = "rm -rf " + unique_cache_dir; + (void)system(rm_cmd.c_str()); +} + +// Stress test: repeatedly deserialize a valid token across multiple threads +// for a fixed duration and verify monitoring counters match actual counts +TEST_F(IntegrationTest, StressTestValidToken) { + char *err_msg = nullptr; + + // Use a unique cache directory to ensure no cached keys exist from prior + // tests This forces fresh key lookup and prevents background refresh + // interference + std::string unique_cache_dir = "/tmp/scitokens_stress_valid_" + + std::to_string(time(nullptr)) + "_" + + std::to_string(getpid()); + int rv = scitoken_config_set_str("keycache.cache_home", + unique_cache_dir.c_str(), &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set cache_home: " + << (err_msg ? 
err_msg : "unknown"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Ensure background refresh is disabled so it doesn't interfere + rv = keycache_stop_background_refresh(&err_msg); + ASSERT_EQ(rv, 0) << "Failed to stop background refresh"; + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Reset update interval to default (600 seconds) - BackgroundRefreshTest + // may have set it to 1 second + rv = scitoken_config_set_int("keycache.update_interval_s", 600, &err_msg); + ASSERT_EQ(rv, 0) << "Failed to set update_interval_s"; + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Reset monitoring stats before the test + rv = scitoken_reset_monitoring_stats(&err_msg); + ASSERT_EQ(rv, 0) << "Failed to reset monitoring stats"; + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Create a valid token + std::unique_ptr key( + scitoken_key_create("test-key-1", "ES256", public_key_.c_str(), + private_key_.c_str(), &err_msg), + scitoken_key_destroy); + ASSERT_TRUE(key.get() != nullptr); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::unique_ptr token( + scitoken_create(key.get()), scitoken_destroy); + ASSERT_TRUE(token.get() != nullptr); + + rv = scitoken_set_claim_string(token.get(), "iss", issuer_url_.c_str(), + &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + rv = scitoken_set_claim_string(token.get(), "sub", "stress-test", &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + scitoken_set_lifetime(token.get(), 3600); + + char *token_value = nullptr; + rv = scitoken_serialize(token.get(), &token_value, &err_msg); + ASSERT_EQ(rv, 0); + std::string token_str(token_value); + free(token_value); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Get initial stats + auto stats_before = getCurrentMonitoringStats(); + auto initial_successful = + 
stats_before.getIssuerStats(issuer_url_).successful_validations; + auto initial_unsuccessful = + stats_before.getIssuerStats(issuer_url_).unsuccessful_validations; + auto initial_key_lookups = + stats_before.getIssuerStats(issuer_url_).successful_key_lookups; + + // Stress test parameters + const int NUM_THREADS = 10; + const int TEST_DURATION_MS = 5000; // 5 seconds + + std::atomic total_attempts{0}; + std::atomic total_successes{0}; + std::atomic total_failures{0}; + std::atomic stop_flag{false}; + + std::vector threads; + for (int i = 0; i < NUM_THREADS; i++) { + threads.emplace_back([&]() { + while (!stop_flag.load()) { + total_attempts++; + + char *thread_err = nullptr; + std::unique_ptr verify_token( + scitoken_create(nullptr), scitoken_destroy); + + int result = scitoken_deserialize_v2(token_str.c_str(), + verify_token.get(), + nullptr, &thread_err); + + if (result == 0) { + total_successes++; + } else { + total_failures++; + if (thread_err) { + std::cerr << "Unexpected error: " << thread_err + << std::endl; + } + } + if (thread_err) + free(thread_err); + } + }); + } + + // Run for the test duration + std::this_thread::sleep_for(std::chrono::milliseconds(TEST_DURATION_MS)); + stop_flag.store(true); + + // Wait for all threads to complete + for (auto &t : threads) { + t.join(); + } + + // Get final stats + auto stats_after = getCurrentMonitoringStats(); + auto issuer_stats = stats_after.getIssuerStats(issuer_url_); + auto new_successful = + issuer_stats.successful_validations - initial_successful; + auto new_unsuccessful = + issuer_stats.unsuccessful_validations - initial_unsuccessful; + auto new_key_lookups = + issuer_stats.successful_key_lookups - initial_key_lookups; + + std::cout << "Stress test (valid token) results:" << std::endl; + std::cout << " Test duration: " << TEST_DURATION_MS << " ms" << std::endl; + std::cout << " Threads: " << NUM_THREADS << std::endl; + std::cout << " Total attempts: " << total_attempts.load() << std::endl; + std::cout << " 
Total successes: " << total_successes.load() << std::endl; + std::cout << " Total failures: " << total_failures.load() << std::endl; + std::cout << " Monitoring successful_validations: " << new_successful + << std::endl; + std::cout << " Monitoring unsuccessful_validations: " << new_unsuccessful + << std::endl; + std::cout << " Monitoring successful_key_lookups: " << new_key_lookups + << std::endl; + + // Verify all attempts succeeded + EXPECT_EQ(total_failures.load(), 0u) + << "All deserializations of valid token should succeed"; + + // Verify monitoring counters match actual counts + EXPECT_EQ(new_successful, total_successes.load()) + << "Monitoring successful_validations should match actual success " + "count"; + + EXPECT_EQ(new_unsuccessful, 0u) + << "There should be no unsuccessful validations for valid token"; + + // Verify at most one key lookup (keys should be cached after first fetch) + // Using a fresh cache directory ensures no interference from prior tests + EXPECT_LE(new_key_lookups, 1u) + << "Should have at most one key lookup (cached after first)"; + + // Sanity check: we should have done a meaningful number of validations + EXPECT_GT(total_attempts.load(), 100u) + << "Should have completed at least 100 validations in " + << TEST_DURATION_MS << "ms"; + + // Cleanup: remove the temporary cache directory + std::string rm_cmd = "rm -rf " + unique_cache_dir; + (void)system(rm_cmd.c_str()); +} + +// Stress test: repeatedly deserialize a token with an invalid issuer (404) +// across multiple threads and verify monitoring counters match actual failure +// counts +TEST_F(IntegrationTest, StressTestInvalidIssuer) { + char *err_msg = nullptr; + + // Use a unique cache directory to ensure no cached keys exist from prior + // tests + std::string unique_cache_dir = "/tmp/scitokens_stress_invalid_" + + std::to_string(time(nullptr)) + "_" + + std::to_string(getpid()); + int rv = scitoken_config_set_str("keycache.cache_home", + unique_cache_dir.c_str(), &err_msg); + 
ASSERT_EQ(rv, 0) << "Failed to set cache_home: " + << (err_msg ? err_msg : "unknown"); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Reset monitoring stats before the test + rv = scitoken_reset_monitoring_stats(&err_msg); + ASSERT_EQ(rv, 0) << "Failed to reset monitoring stats"; + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Create a token with an issuer path that returns 404 + // The server returns 404 for paths like /nonexistent-path + std::string invalid_issuer = issuer_url_ + "/nonexistent-path"; + + std::unique_ptr key( + scitoken_key_create("test-key-1", "ES256", public_key_.c_str(), + private_key_.c_str(), &err_msg), + scitoken_key_destroy); + ASSERT_TRUE(key.get() != nullptr); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + std::unique_ptr token( + scitoken_create(key.get()), scitoken_destroy); + ASSERT_TRUE(token.get() != nullptr); + + rv = scitoken_set_claim_string(token.get(), "iss", invalid_issuer.c_str(), + &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + rv = scitoken_set_claim_string(token.get(), "sub", "stress-test-invalid", + &err_msg); + ASSERT_EQ(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + scitoken_set_lifetime(token.get(), 3600); + + char *token_value = nullptr; + rv = scitoken_serialize(token.get(), &token_value, &err_msg); + ASSERT_EQ(rv, 0); + std::string token_str(token_value); + free(token_value); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Get initial stats for the invalid issuer + auto stats_before = getCurrentMonitoringStats(); + auto initial_successful = + stats_before.getIssuerStats(invalid_issuer).successful_validations; + auto initial_unsuccessful = + stats_before.getIssuerStats(invalid_issuer).unsuccessful_validations; + auto initial_key_lookups = + stats_before.getIssuerStats(invalid_issuer).successful_key_lookups; + + // Stress test parameters + const int NUM_THREADS = 10; + const int 
TEST_DURATION_MS = 5000; // 5 seconds + + std::atomic total_attempts{0}; + std::atomic total_successes{0}; + std::atomic total_failures{0}; + std::atomic stop_flag{false}; + + std::vector threads; + for (int i = 0; i < NUM_THREADS; i++) { + threads.emplace_back([&]() { + while (!stop_flag.load()) { + total_attempts++; + + char *thread_err = nullptr; + std::unique_ptr verify_token( + scitoken_create(nullptr), scitoken_destroy); + + int result = scitoken_deserialize_v2(token_str.c_str(), + verify_token.get(), + nullptr, &thread_err); + + if (result == 0) { + total_successes++; + } else { + total_failures++; + } + if (thread_err) + free(thread_err); + } + }); + } + + // Run for the test duration + std::this_thread::sleep_for(std::chrono::milliseconds(TEST_DURATION_MS)); + stop_flag.store(true); + + // Wait for all threads to complete + for (auto &t : threads) { + t.join(); + } + + // Get final stats for the invalid issuer + auto stats_after = getCurrentMonitoringStats(); + auto issuer_stats = stats_after.getIssuerStats(invalid_issuer); + auto new_successful = + issuer_stats.successful_validations - initial_successful; + auto new_unsuccessful = + issuer_stats.unsuccessful_validations - initial_unsuccessful; + auto new_key_lookups = + issuer_stats.successful_key_lookups - initial_key_lookups; + + std::cout << "Stress test (invalid issuer - 404) results:" << std::endl; + std::cout << " Test duration: " << TEST_DURATION_MS << " ms" << std::endl; + std::cout << " Threads: " << NUM_THREADS << std::endl; + std::cout << " Invalid issuer: " << invalid_issuer << std::endl; + std::cout << " Total attempts: " << total_attempts.load() << std::endl; + std::cout << " Total successes: " << total_successes.load() << std::endl; + std::cout << " Total failures: " << total_failures.load() << std::endl; + std::cout << " Monitoring successful_validations: " << new_successful + << std::endl; + std::cout << " Monitoring unsuccessful_validations: " << new_unsuccessful + << std::endl; + 
std::cout << " Monitoring successful_key_lookups: " << new_key_lookups + << std::endl; + + // Verify all attempts failed (issuer returns 404) + EXPECT_EQ(total_successes.load(), 0u) + << "All deserializations with invalid issuer should fail"; + + // Verify monitoring counters match actual counts + EXPECT_EQ(new_successful, 0u) + << "There should be no successful validations for invalid issuer"; + + EXPECT_EQ(new_unsuccessful, total_failures.load()) + << "Monitoring unsuccessful_validations should match actual failure " + "count"; + + // No successful key lookups expected (issuer returns 404) + EXPECT_EQ(new_key_lookups, 0u) + << "Should have no successful key lookups (issuer returns 404)"; + + // Sanity check: we should have done a meaningful number of validations + EXPECT_GT(total_attempts.load(), 100u) + << "Should have completed at least 100 validations in " + << TEST_DURATION_MS << "ms"; + + // Cleanup: remove the temporary cache directory + std::string rm_cmd = "rm -rf " + unique_cache_dir; + (void)system(rm_cmd.c_str()); +} + } // namespace int main(int argc, char **argv) { diff --git a/test/main.cpp b/test/main.cpp index 1e1a54f..9564d8b 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1028,6 +1028,43 @@ TEST_F(EnvConfigTest, StringConfigFromEnv) { free(err_msg); } +// Test for thundering herd prevention with per-issuer locks +TEST_F(IssuerSecurityTest, ThunderingHerdPrevention) { + char *err_msg = nullptr; + + // Create tokens for a new issuer and pre-populate the cache + std::string test_issuer = "https://thundering-herd-test.example.org/gtest"; + + auto rv = scitoken_set_claim_string(m_token.get(), "iss", + test_issuer.c_str(), &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + + // Store public key for this issuer in the cache + rv = scitoken_store_public_ec_key(test_issuer.c_str(), "1", ec_public, + &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + + char *token_value = nullptr; + rv = scitoken_serialize(m_token.get(), &token_value, &err_msg); + 
ASSERT_TRUE(rv == 0) << err_msg; + std::unique_ptr token_value_ptr(token_value, free); + + // Successfully deserialize - the per-issuer lock should prevent thundering + // herd Since we pre-populated the cache, this should succeed without + // network access + rv = scitoken_deserialize_v2(token_value, m_read_token.get(), nullptr, + &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + + // Verify the issuer claim + char *value; + rv = scitoken_get_claim_string(m_read_token.get(), "iss", &value, &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + ASSERT_TRUE(value != nullptr); + std::unique_ptr value_ptr(value, free); + EXPECT_STREQ(value, test_issuer.c_str()); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS();