10 changes: 10 additions & 0 deletions src/scitokens_cache.cpp
@@ -18,6 +18,10 @@

namespace {

// Timeout in milliseconds to wait when database is locked
// This handles concurrent access from multiple threads/processes
constexpr int SQLITE_BUSY_TIMEOUT_MS = 5000;

void initialize_cachedb(const std::string &keycache_file) {

sqlite3 *db;
@@ -27,6 +31,8 @@ void initialize_cachedb(const std::string &keycache_file) {
sqlite3_close(db);
return;
}
// Set busy timeout to handle concurrent access
sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);
char *err_msg = nullptr;
rc = sqlite3_exec(db,
"CREATE TABLE IF NOT EXISTS keycache ("
@@ -161,6 +167,8 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer,
sqlite3_close(db);
return false;
}
// Set busy timeout to handle concurrent access
sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?",
@@ -260,6 +268,8 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer,
sqlite3_close(db);
return false;
}
// Set busy timeout to handle concurrent access
sqlite3_busy_timeout(db, SQLITE_BUSY_TIMEOUT_MS);

if ((rc = sqlite3_exec(db, "BEGIN", 0, 0, 0)) != SQLITE_OK) {
sqlite3_close(db);
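All three call sites follow the same two-line pattern: open the connection, then install the busy timeout before issuing any statement. As a rough illustration of what the timeout changes, here is a minimal, self-contained sketch against the stock SQLite C API (not taken from the patch; the database path is a placeholder):

```cpp
#include <sqlite3.h>

// Minimal sketch (not from the patch): open a connection the way the cache
// code does and install the busy timeout before running any statement.
int open_cache_connection(const char *path, sqlite3 **out_db) {
    sqlite3 *db = nullptr;
    int rc = sqlite3_open(path, &db);
    if (rc != SQLITE_OK) {
        sqlite3_close(db); // per SQLite docs, close even a failed open
        return rc;
    }
    // Without this, a write lock held by another process/thread makes
    // sqlite3_exec/sqlite3_step fail immediately with SQLITE_BUSY; with it,
    // SQLite retries internally for up to 5000 ms before reporting busy.
    sqlite3_busy_timeout(db, 5000);
    *out_db = db;
    return SQLITE_OK;
}

int main() {
    sqlite3 *db = nullptr;
    if (open_cache_connection("/tmp/keycache.sqlite", &db) == SQLITE_OK) {
        // ... CREATE TABLE / SELECT / INSERT as in the cache code ...
        sqlite3_close(db);
    }
}
```

sqlite3_busy_timeout is itself a built-in busy handler; a custom retry policy could be installed with sqlite3_busy_handler instead, but a fixed timeout is enough for the short lock contention expected between cache writers.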
125 changes: 111 additions & 14 deletions src/scitokens_internal.cpp
@@ -4,6 +4,7 @@
#include <memory>
#include <sstream>
#include <sys/stat.h>
#include <unordered_map>

#include <jwt-cpp/base.h>
#include <jwt-cpp/jwt.h>
@@ -34,6 +35,47 @@ CurlRaii myCurl;

std::mutex key_refresh_mutex;

// Per-issuer mutex map for preventing thundering herd on new issuers
std::mutex issuer_mutex_map_lock;
std::unordered_map<std::string, std::shared_ptr<std::mutex>> issuer_mutexes;
constexpr size_t MAX_ISSUER_MUTEXES = 1000;

// Get or create a mutex for a specific issuer
std::shared_ptr<std::mutex> get_issuer_mutex(const std::string &issuer) {
std::lock_guard<std::mutex> guard(issuer_mutex_map_lock);

auto it = issuer_mutexes.find(issuer);
if (it != issuer_mutexes.end()) {
return it->second;
}

// Prevent resource exhaustion: limit the number of cached mutexes
if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) {
// Remove mutexes that are no longer in use
// Since we hold issuer_mutex_map_lock, no other thread can acquire
// a reference to these mutexes, making this check safe
for (auto iter = issuer_mutexes.begin();
iter != issuer_mutexes.end();) {
if (iter->second.use_count() == 1) {
// Only we hold a reference, safe to remove
iter = issuer_mutexes.erase(iter);
} else {
++iter;
}
}

// If still at capacity after cleanup, fail rather than unbounded growth
if (issuer_mutexes.size() >= MAX_ISSUER_MUTEXES) {
throw std::runtime_error(
"Too many concurrent issuers - resource exhaustion prevented");
}
}

auto mutex_ptr = std::make_shared<std::mutex>();
issuer_mutexes[issuer] = mutex_ptr;
return mutex_ptr;
}

} // namespace

namespace scitokens {
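get_issuer_mutex only hands out a lock object; the thundering-herd protection comes from how the lookup path uses it further down: check the key cache, and only on a miss take the per-issuer lock and check the cache a second time before going to the network. A condensed, self-contained sketch of that double-checked shape (the cache map and fetch_keys below are stand-ins, not the library's API):

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

// Stand-ins for the key cache and the web fetch; not the library's API.
std::mutex cache_lock;
std::unordered_map<std::string, std::string> key_cache;

std::mutex issuer_map_lock;
std::unordered_map<std::string, std::shared_ptr<std::mutex>> issuer_mutexes;

std::shared_ptr<std::mutex> issuer_mutex(const std::string &issuer) {
    std::lock_guard<std::mutex> guard(issuer_map_lock);
    auto &ptr = issuer_mutexes[issuer];
    if (!ptr)
        ptr = std::make_shared<std::mutex>();
    return ptr;
}

bool cache_lookup(const std::string &issuer, std::string &keys) {
    std::lock_guard<std::mutex> guard(cache_lock);
    auto it = key_cache.find(issuer);
    if (it == key_cache.end())
        return false;
    keys = it->second;
    return true;
}

std::string fetch_keys(const std::string &issuer) {
    std::cout << "fetching " << issuer << " from the web\n";
    return "{\"keys\":[]}";
}

std::string get_keys(const std::string &issuer) {
    std::string keys;
    if (cache_lookup(issuer, keys))   // fast path: cache hit
        return keys;

    auto mtx = issuer_mutex(issuer);  // serialize per issuer only
    std::lock_guard<std::mutex> guard(*mtx);

    if (cache_lookup(issuer, keys))   // re-check: another thread fetched first
        return keys;

    keys = fetch_keys(issuer);        // at most one thread per issuer gets here
    std::lock_guard<std::mutex> cguard(cache_lock);
    key_cache[issuer] = keys;
    return keys;
}

int main() {
    std::thread a([] { get_keys("https://issuer.example"); });
    std::thread b([] { get_keys("https://issuer.example"); });
    a.join();
    b.join();
}
```

Because the lock is per issuer rather than the global key_refresh_mutex, a slow or unresponsive issuer only blocks the threads that need that issuer's keys.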
@@ -948,16 +990,36 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
result->m_done = true;
}
} else {
// No keys in the DB, or they are expired
// No keys in the DB, or they are expired, so get them from the web.
// Record that we had expired keys if the issuer was previously known
// (This is tracked by having an entry in issuer stats)
auto &issuer_stats =
internal::MonitoringStats::instance().get_issuer_stats(issuer);
issuer_stats.inc_expired_key();

// Get keys from the web.
result = get_public_keys_from_web(
issuer, internal::SimpleCurlGet::default_timeout);
// Use per-issuer lock to prevent thundering herd for new issuers
auto issuer_mutex = get_issuer_mutex(issuer);
std::unique_lock<std::mutex> issuer_lock(*issuer_mutex);

// Check again if keys are now in DB (another thread may have fetched
// them while we were waiting for the lock)
if (get_public_keys_from_db(issuer, now, result->m_keys,
result->m_next_update)) {
// Keys are now available, use them
result->m_continue_fetch = false;
result->m_do_store = false;
result->m_done = true;
// Lock released here - no need to hold it
} else {
// Still no keys, fetch them from the web
result = get_public_keys_from_web(
issuer, internal::SimpleCurlGet::default_timeout);

// Transfer ownership of the lock to the async status
// The lock will be held until keys are stored in
// get_public_key_pem_continue
result->m_issuer_mutex = issuer_mutex;
result->m_issuer_lock = std::move(issuer_lock);
}
}
result->m_issuer = issuer;
result->m_kid = kid;
@@ -973,21 +1035,56 @@ Validator::get_public_key_pem_continue(std::unique_ptr<AsyncStatus> status,
std::string &algorithm) {

if (status->m_continue_fetch) {
status = get_public_keys_from_web_continue(std::move(status));
if (status->m_continue_fetch) {
return std::move(status);
// Save issuer and lock info before potentially moving status
std::string issuer = status->m_issuer;
auto issuer_mutex = status->m_issuer_mutex;
std::unique_lock<std::mutex> issuer_lock(
std::move(status->m_issuer_lock));

try {
status = get_public_keys_from_web_continue(std::move(status));
if (status->m_continue_fetch) {
// Restore the lock to status before returning
status->m_issuer_mutex = issuer_mutex;
status->m_issuer_lock = std::move(issuer_lock);
return std::move(status);
}
// Success - restore the lock to status for later release
status->m_issuer_mutex = issuer_mutex;
status->m_issuer_lock = std::move(issuer_lock);
} catch (...) {
// Web fetch failed - store empty keys as negative cache entry
// This prevents thundering herd on repeated failed lookups
if (issuer_lock.owns_lock()) {
// Store empty keys with short TTL for negative caching
auto now = std::time(NULL);
int negative_cache_ttl =
configurer::Configuration::get_next_update_delta();
picojson::value empty_keys;
picojson::object keys_obj;
keys_obj["keys"] = picojson::value(picojson::array());
empty_keys = picojson::value(keys_obj);
store_public_keys(issuer, empty_keys, now + negative_cache_ttl,
now + negative_cache_ttl);
issuer_lock.unlock();
}
throw; // Re-throw the original exception
}
}
if (status->m_do_store) {
// Async web fetch completed successfully - record monitoring
if (status->m_is_refresh) {
auto &issuer_stats =
internal::MonitoringStats::instance().get_issuer_stats(
status->m_issuer);
issuer_stats.inc_successful_key_lookup();
}
// This counts both initial fetches and refreshes
auto &issuer_stats =
internal::MonitoringStats::instance().get_issuer_stats(
status->m_issuer);
issuer_stats.inc_successful_key_lookup();
store_public_keys(status->m_issuer, status->m_keys,
status->m_next_update, status->m_expires);
// Release the per-issuer lock now that keys are stored
// Other threads waiting on this issuer can now proceed
if (status->m_issuer_lock.owns_lock()) {
status->m_issuer_lock.unlock();
}
}
status->m_done = true;
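The catch block above turns a failed web fetch into a short-lived negative cache entry: an empty JWKS document stored with the configured next-update delta, so later lookups for the same failing issuer are answered from the cache instead of re-hitting the remote endpoint. A sketch of just the entry construction with picojson (store_keys, the 600-second TTL, and the header path are placeholders, not the library's interface):

```cpp
#include <ctime>
#include <iostream>
#include <string>

#include <picojson/picojson.h> // header location varies; jwt-cpp bundles a copy

// Placeholder for the real cache write (SQLite in the library).
void store_keys(const std::string &issuer, const picojson::value &jwks,
                std::time_t next_update, std::time_t expires) {
    (void)issuer;
    (void)jwks;
    (void)next_update;
    (void)expires;
}

void negative_cache(const std::string &issuer) {
    picojson::object keys_obj;
    keys_obj["keys"] = picojson::value(picojson::array()); // {"keys": []}
    picojson::value empty_keys(keys_obj);

    std::time_t now = std::time(nullptr);
    std::time_t ttl = 600; // placeholder; the patch derives this from the
                           // configured next-update delta
    store_keys(issuer, empty_keys, now + ttl, now + ttl);
    std::cout << empty_keys.serialize() << "\n"; // prints {"keys":[]}
}

int main() { negative_cache("https://issuer.example"); }
```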

23 changes: 18 additions & 5 deletions src/scitokens_internal.h
@@ -552,6 +552,10 @@ class AsyncStatus {
bool m_is_refresh{false}; // True if this is a refresh of an existing key
AsyncState m_state{DOWNLOAD_METADATA};
std::unique_lock<std::mutex> m_refresh_lock;
// Per-issuer lock to prevent thundering herd on new issuers
// We store both the shared_ptr (to keep mutex alive) and the lock
std::shared_ptr<std::mutex> m_issuer_mutex;
std::unique_lock<std::mutex> m_issuer_lock;

int64_t m_next_update{-1};
int64_t m_expires{-1};
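Carrying both members matters because a std::unique_lock only refers to its mutex; the shared_ptr is what keeps that mutex alive for the duration of an async fetch, even if the per-issuer map evicts its entry in the meantime. A small stand-alone sketch of handing a held lock to a longer-lived object (Status is a stand-in for AsyncStatus):

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <utility>

// "Status" is a stand-in for the library's AsyncStatus.
struct Status {
    std::shared_ptr<std::mutex> issuer_mutex; // keeps the mutex alive
    std::unique_lock<std::mutex> issuer_lock; // ownership of the held lock
};

int main() {
    auto mtx = std::make_shared<std::mutex>();
    std::unique_lock<std::mutex> lock(*mtx);  // acquired in the first step

    Status status;
    status.issuer_mutex = mtx;                // extend the mutex lifetime
    status.issuer_lock = std::move(lock);     // hand the lock to the status

    assert(!lock.owns_lock());                // the local no longer owns it
    assert(status.issuer_lock.owns_lock());   // the status object does

    status.issuer_lock.unlock();              // released once keys are stored
}
```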
@@ -776,6 +780,8 @@ class Validator {

try {
auto result = verify_async(scitoken);
// Note: m_is_sync flag no longer needed since counting is only done
// in verify_async_continue

// Extract issuer from the result's JWT string after decoding starts
const jwt::decoded_jwt<jwt::traits::kazuho_picojson> *jwt_decoded =
@@ -834,7 +840,8 @@
std::chrono::duration_cast<std::chrono::nanoseconds>(
end_time - last_duration_update);
issuer_stats->add_sync_time(delta);
issuer_stats->inc_successful_validation();
// Note: inc_successful_validation() is called in
// verify_async_continue
}
} catch (const std::exception &e) {
// Record failure (final duration update)
@@ -882,6 +889,8 @@
}

auto result = verify_async(jwt);
// Note: m_is_sync flag no longer needed since counting is only done
// in verify_async_continue
while (!result->m_done) {
result = verify_async_continue(std::move(result));
}
@@ -893,7 +902,8 @@
std::chrono::duration_cast<std::chrono::nanoseconds>(
end_time - start_time);
issuer_stats->add_sync_time(duration);
issuer_stats->inc_successful_validation();
// Note: inc_successful_validation() is called in
// verify_async_continue
}
} catch (const std::exception &e) {
// Record failure if we have an issuer
@@ -1004,6 +1014,7 @@ class Validator {
// Start monitoring timing and record async validation started
status->m_start_time = std::chrono::steady_clock::now();
status->m_monitoring_started = true;
status->m_issuer = jwt.get_issuer();
auto &stats = internal::MonitoringStats::instance().get_issuer_stats(
jwt.get_issuer());
stats.inc_async_validation_started();
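Recording the issuer on the status up front lets later steps attribute their timing and counters to the right per-issuer bucket. The timing itself is the usual steady_clock pattern used throughout these wrappers; a minimal sketch (IssuerStats is a placeholder, not the library's MonitoringStats interface):

```cpp
#include <chrono>
#include <iostream>

// Placeholder stats sink, not the library's MonitoringStats interface.
struct IssuerStats {
    long long async_ns = 0;
    void add_async_time(std::chrono::nanoseconds d) { async_ns += d.count(); }
};

int main() {
    IssuerStats stats;
    auto start = std::chrono::steady_clock::now(); // recorded when validation begins
    // ... validation work happens here ...
    auto end = std::chrono::steady_clock::now();   // recorded on completion
    auto duration =
        std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
    stats.add_async_time(duration);                // accumulated per issuer
    std::cout << stats.async_ns << " ns\n";
}
```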
@@ -1181,9 +1192,8 @@
}
}

// Record successful validation (only for async API, sync handles its
// own)
if (status->m_monitoring_started && !status->m_is_sync) {
// Record successful validation
if (status->m_monitoring_started) {
auto end_time = std::chrono::steady_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::nanoseconds>(
@@ -1195,8 +1205,11 @@
stats.add_async_time(duration);
}

// Create new result, preserving monitoring flags
std::unique_ptr<AsyncStatus> result(new AsyncStatus());
result->m_done = true;
result->m_is_sync = status->m_is_sync;
result->m_monitoring_started = status->m_monitoring_started;
return result;
}
