Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ add_library(SciTokens SHARED src/scitokens.cpp src/scitokens_internal.cpp src/sc
target_compile_features(SciTokens PUBLIC cxx_std_11) # Use at least C++11 for building and when linking to scitokens
target_include_directories(SciTokens PUBLIC ${JWT_CPP_INCLUDES} "${PROJECT_SOURCE_DIR}/src" PRIVATE ${CURL_INCLUDE_DIRS} ${OPENSSL_INCLUDE_DIRS} ${LIBCRYPTO_INCLUDE_DIRS} ${SQLITE_INCLUDE_DIRS} ${UUID_INCLUDE_DIRS})

target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES})
# Find threading library
find_package(Threads REQUIRED)

target_link_libraries(SciTokens PUBLIC ${OPENSSL_LIBRARIES} ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES} ${UUID_LIBRARIES} Threads::Threads)
if (UNIX)
# pkg_check_modules fails to return an absolute path on RHEL7. Set the
# link directories accordingly.
Expand Down
69 changes: 67 additions & 2 deletions src/scitokens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ void load_config_from_environment() {
bool is_int;
};

const std::array<ConfigMapping, 6> known_configs = {
const std::array<ConfigMapping, 8> known_configs = {
{{"keycache.update_interval_s", "KEYCACHE_UPDATE_INTERVAL_S", true},
{"keycache.expiration_interval_s", "KEYCACHE_EXPIRATION_INTERVAL_S",
true},
{"keycache.cache_home", "KEYCACHE_CACHE_HOME", false},
{"tls.ca_file", "TLS_CA_FILE", false},
{"monitoring.file", "MONITORING_FILE", false},
{"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true}}};
{"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true},
{"keycache.refresh_interval_ms", "KEYCACHE_REFRESH_INTERVAL_MS", true},
{"keycache.refresh_threshold_ms", "KEYCACHE_REFRESH_THRESHOLD_MS",
true}}};

const char *prefix = "SCITOKEN_CONFIG_";

Expand Down Expand Up @@ -128,6 +131,13 @@ int configurer::Configuration::get_monitoring_file_interval() {
return m_monitoring_file_interval;
}

// Background refresh config
std::atomic_bool configurer::Configuration::m_background_refresh_enabled{false};
std::atomic_int configurer::Configuration::m_refresh_interval_ms{
60000}; // 60 seconds
std::atomic_int configurer::Configuration::m_refresh_threshold_ms{
600000}; // 10 minutes

SciTokenKey scitoken_key_create(const char *key_id, const char *alg,
const char *public_contents,
const char *private_contents, char **err_msg) {
Expand Down Expand Up @@ -1099,6 +1109,31 @@ int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg) {
return 0;
}

int keycache_set_background_refresh(int enabled, char **err_msg) {
try {
bool enable = (enabled != 0);
configurer::Configuration::set_background_refresh_enabled(enable);

if (enable) {
scitokens::internal::BackgroundRefreshManager::get_instance()
.start();
} else {
scitokens::internal::BackgroundRefreshManager::get_instance()
.stop();
}
} catch (std::exception &exc) {
if (err_msg) {
*err_msg = strdup(exc.what());
}
return -1;
}
return 0;
}

int keycache_stop_background_refresh(char **err_msg) {
return keycache_set_background_refresh(0, err_msg);
}

int config_set_int(const char *key, int value, char **err_msg) {
return scitoken_config_set_int(key, value, err_msg);
}
Expand Down Expand Up @@ -1145,6 +1180,28 @@ int scitoken_config_set_int(const char *key, int value, char **err_msg) {
return 0;
}

else if (_key == "keycache.refresh_interval_ms") {
if (value < 0) {
if (err_msg) {
*err_msg = strdup("Refresh interval must be positive.");
}
return -1;
}
configurer::Configuration::set_refresh_interval(value);
return 0;
}

else if (_key == "keycache.refresh_threshold_ms") {
if (value < 0) {
if (err_msg) {
*err_msg = strdup("Refresh threshold must be positive.");
}
return -1;
}
configurer::Configuration::set_refresh_threshold(value);
return 0;
}

else {
if (err_msg) {
*err_msg = strdup("Key not recognized.");
Expand Down Expand Up @@ -1178,6 +1235,14 @@ int scitoken_config_get_int(const char *key, char **err_msg) {
return configurer::Configuration::get_monitoring_file_interval();
}

else if (_key == "keycache.refresh_interval_ms") {
return configurer::Configuration::get_refresh_interval();
}

else if (_key == "keycache.refresh_threshold_ms") {
return configurer::Configuration::get_refresh_threshold();
}

else {
if (err_msg) {
*err_msg = strdup("Key not recognized.");
Expand Down
27 changes: 27 additions & 0 deletions src/scitokens.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,25 @@ int keycache_get_cached_jwks(const char *issuer, char **jwks, char **err_msg);
*/
int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg);

/**
* Enable or disable the background refresh thread for JWKS.
* - When enabled, a background thread will periodically check if any known
* issuers need their JWKS refreshed based on the configured refresh interval
* and threshold.
* - If enabled=1 and the thread is not running, it will be started.
* - If enabled=0 and the thread is running, it will be stopped gracefully.
* - Returns 0 on success, nonzero on failure.
*/
int keycache_set_background_refresh(int enabled, char **err_msg);

/**
* Stop the background refresh thread if it is running.
* - This is a convenience function equivalent to
* keycache_set_background_refresh(0, err_msg).
* - Returns 0 on success, nonzero on failure.
*/
int keycache_stop_background_refresh(char **err_msg);

/**
* APIs for managing scitokens configuration parameters.
*/
Expand All @@ -308,6 +327,10 @@ int config_set_int(const char *key, int value, char **err_msg);
* - "keycache.expiration_interval_s": Key cache expiration time (seconds)
* - "monitoring.file_interval_s": Interval between monitoring file writes
* (seconds, default 60)
* - "keycache.refresh_interval_ms": Background refresh thread check interval
* (milliseconds, default 60000)
* - "keycache.refresh_threshold_ms": Time before next_update when background
* refresh triggers (milliseconds, default 600000)
*/
int scitoken_config_set_int(const char *key, int value, char **err_msg);

Expand All @@ -325,6 +348,10 @@ int config_get_int(const char *key, char **err_msg);
* - "keycache.expiration_interval_s": Key cache expiration time (seconds)
* - "monitoring.file_interval_s": Interval between monitoring file writes
* (seconds, default 60)
* - "keycache.refresh_interval_ms": Background refresh thread check interval
* (milliseconds, default 60000)
* - "keycache.refresh_threshold_ms": Time before next_update when background
* refresh triggers (milliseconds, default 600000)
*/
int scitoken_config_get_int(const char *key, char **err_msg);

Expand Down
78 changes: 78 additions & 0 deletions src/scitokens_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,81 @@ bool scitokens::Validator::store_public_keys(const std::string &issuer,
sqlite3_close(db);
return true;
}

std::vector<std::pair<std::string, int64_t>>
scitokens::Validator::get_all_issuers_from_db(int64_t now) {
std::vector<std::pair<std::string, int64_t>> result;

auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {
return result;
}

sqlite3 *db;
int rc = sqlite3_open(cache_fname.c_str(), &db);
if (rc) {
sqlite3_close(db);
return result;
}

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "SELECT issuer, keys FROM keycache", -1, &stmt,
NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return result;
}

while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
const unsigned char *issuer_data = sqlite3_column_text(stmt, 0);
const unsigned char *keys_data = sqlite3_column_text(stmt, 1);

if (!issuer_data || !keys_data) {
continue;
}

std::string issuer(reinterpret_cast<const char *>(issuer_data));
std::string metadata(reinterpret_cast<const char *>(keys_data));

// Parse the metadata to get next_update and check expiry
picojson::value json_obj;
auto err = picojson::parse(json_obj, metadata);
if (!err.empty() || !json_obj.is<picojson::object>()) {
continue;
}

auto top_obj = json_obj.get<picojson::object>();

// Get expiry time
auto expires_iter = top_obj.find("expires");
if (expires_iter == top_obj.end() ||
!expires_iter->second.is<int64_t>()) {
continue;
}
auto expiry = expires_iter->second.get<int64_t>();

// Get next_update time
auto next_update_iter = top_obj.find("next_update");
int64_t next_update;
if (next_update_iter == top_obj.end() ||
!next_update_iter->second.is<int64_t>()) {
// If next_update is not set, default to 4 hours before expiry
next_update = expiry - 4 * 3600;
} else {
next_update = next_update_iter->second.get<int64_t>();
}

// Include expired entries - they should be refreshed after a long
// downtime If expired, set next_update to now so they get refreshed
// immediately
if (now > expiry) {
next_update = now;
}

result.push_back({issuer, next_update});
}

sqlite3_finalize(stmt);
sqlite3_close(db);
return result;
}
94 changes: 94 additions & 0 deletions src/scitokens_internal.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

#include <chrono>
#include <functional>
#include <memory>
#include <sstream>
Expand Down Expand Up @@ -37,8 +38,101 @@ std::mutex key_refresh_mutex;

namespace scitokens {

// Define the static once_flag for Validator
std::once_flag Validator::m_background_refresh_once;

namespace internal {

// BackgroundRefreshManager implementation
void BackgroundRefreshManager::start() {
std::lock_guard<std::mutex> lock(m_mutex);
if (m_running.load(std::memory_order_acquire)) {
return; // Already running
}
m_shutdown.store(false, std::memory_order_release);
m_running.store(true, std::memory_order_release);
m_thread = std::make_unique<std::thread>(
&BackgroundRefreshManager::refresh_loop, this);
}

void BackgroundRefreshManager::stop() {
std::unique_ptr<std::thread> thread_to_join;

{
std::lock_guard<std::mutex> lock(m_mutex);
if (!m_running.load(std::memory_order_acquire)) {
return; // Not running
}

m_shutdown.store(true, std::memory_order_release);
m_running.store(false, std::memory_order_release);
thread_to_join = std::move(m_thread);
}

m_cv.notify_all();

if (thread_to_join && thread_to_join->joinable()) {
thread_to_join->join();
}
}

void BackgroundRefreshManager::refresh_loop() {
while (!m_shutdown.load(std::memory_order_acquire)) {
auto interval = configurer::Configuration::get_refresh_interval();
auto threshold = configurer::Configuration::get_refresh_threshold();

// Wait for the interval or until shutdown
{
std::unique_lock<std::mutex> lock(m_mutex);
m_cv.wait_for(lock, std::chrono::milliseconds(interval), [this]() {
return m_shutdown.load(std::memory_order_acquire);
});
}

if (m_shutdown.load(std::memory_order_acquire)) {
break;
}

// Get list of issuers from the database
auto now = std::time(NULL);
auto issuers = scitokens::Validator::get_all_issuers_from_db(now);

for (const auto &issuer_pair : issuers) {
if (m_shutdown.load(std::memory_order_acquire)) {
break;
}

const auto &issuer = issuer_pair.first;
const auto &next_update = issuer_pair.second;

// Calculate time until next_update in milliseconds
int64_t time_until_update = (next_update - now) * 1000;

// If next update is within threshold, try to refresh
if (time_until_update <= threshold) {
auto &stats =
MonitoringStats::instance().get_issuer_stats(issuer);
try {
// Perform refresh (this will use the refresh_jwks method)
scitokens::Validator::refresh_jwks(issuer);
stats.inc_background_successful_refresh();
} catch (std::exception &) {
// Track failed refresh attempts
stats.inc_background_failed_refresh();
// Silently ignore errors in background refresh to avoid
// disrupting the application. Background refresh is a
// best-effort optimization. If it fails, the next token
// verification will trigger a foreground refresh as usual.
}
}
}

// Write monitoring file from background thread if configured
// This avoids writing from verify() when background thread is running
MonitoringStats::instance().maybe_write_monitoring_file();
}
}

SimpleCurlGet::GetStatus SimpleCurlGet::perform_start(const std::string &url) {
m_len = 0;

Expand Down
Loading
Loading