diff --git a/src/scitokens_cache.cpp b/src/scitokens_cache.cpp index a418923..a2094b1 100644 --- a/src/scitokens_cache.cpp +++ b/src/scitokens_cache.cpp @@ -209,6 +209,35 @@ bool scitokens::Validator::get_public_keys_from_db(const std::string issuer, return false; } auto keys_local = iter->second; + + // Check if this is a negative cache entry (empty keys array) + if (keys_local.is()) { + auto jwks_obj = keys_local.get(); + auto keys_iter = jwks_obj.find("keys"); + if (keys_iter != jwks_obj.end() && + keys_iter->second.is()) { + auto keys_array = keys_iter->second.get(); + if (keys_array.empty()) { + // Check if negative cache has expired + iter = top_obj.find("expires"); + if (iter != top_obj.end() && iter->second.is()) { + auto expiry = iter->second.get(); + if (now > expiry) { + // Negative cache expired, remove and return false + if (remove_issuer_entry(db, issuer, true) != 0) { + return false; + } + sqlite3_close(db); + return false; + } + } + // Negative cache still valid - throw exception + sqlite3_close(db); + throw NegativeCacheHitException(issuer); + } + } + } + iter = top_obj.find("expires"); if (iter == top_obj.end() || !iter->second.is()) { if (remove_issuer_entry(db, issuer, true) != 0) { diff --git a/src/scitokens_internal.cpp b/src/scitokens_internal.cpp index ffd7be5..034d4e4 100644 --- a/src/scitokens_internal.cpp +++ b/src/scitokens_internal.cpp @@ -921,8 +921,14 @@ std::string Validator::get_jwks(const std::string &issuer) { auto now = std::time(NULL); picojson::value jwks; int64_t next_update; - if (get_public_keys_from_db(issuer, now, jwks, next_update)) { - return jwks.serialize(); + try { + if (get_public_keys_from_db(issuer, now, jwks, next_update)) { + return jwks.serialize(); + } + } catch (const NegativeCacheHitException &) { + // Negative cache hit - return empty keys without incrementing counter + // (counter is incremented elsewhere for validation failures) + return std::string("{\"keys\": []}"); } return std::string("{\"keys\": []}"); } diff --git a/src/scitokens_internal.h b/src/scitokens_internal.h index 9c22018..f682739 100644 --- a/src/scitokens_internal.h +++ b/src/scitokens_internal.h @@ -264,6 +264,9 @@ struct IssuerStats { std::atomic background_successful_refreshes{0}; std::atomic background_failed_refreshes{0}; + // Negative cache statistics + std::atomic negative_cache_hits{0}; + // Increment methods for atomic counters (use relaxed ordering for stats) void inc_successful_validation() { successful_validations.fetch_add(1, std::memory_order_relaxed); @@ -301,6 +304,9 @@ struct IssuerStats { void inc_background_failed_refresh() { background_failed_refreshes.fetch_add(1, std::memory_order_relaxed); } + void inc_negative_cache_hit() { + negative_cache_hits.fetch_add(1, std::memory_order_relaxed); + } // Time setters that accept std::chrono::duration (use relaxed ordering) template @@ -476,6 +482,14 @@ class InvalidIssuerException : public std::runtime_error { InvalidIssuerException(const std::string &msg) : std::runtime_error(msg) {} }; +class NegativeCacheHitException : public InvalidIssuerException { + public: + explicit NegativeCacheHitException(const std::string &issuer) + : InvalidIssuerException("Issuer is in negative cache (recently failed " + "to retrieve keys): " + + issuer) {} +}; + class JsonException : public std::runtime_error { public: JsonException(const std::string &msg) : std::runtime_error(msg) {} @@ -1350,6 +1364,8 @@ class Validator { const std::exception &e) { if (dynamic_cast(&e)) { stats.inc_expired_token(); + } else if (dynamic_cast(&e)) { + stats.inc_negative_cache_hit(); } stats.inc_unsuccessful_validation(); diff --git a/src/scitokens_monitoring.cpp b/src/scitokens_monitoring.cpp index ad7216f..39889a9 100644 --- a/src/scitokens_monitoring.cpp +++ b/src/scitokens_monitoring.cpp @@ -134,6 +134,11 @@ std::string MonitoringStats::get_json() const { static_cast(stats.background_failed_refreshes.load( std::memory_order_relaxed))); + // Negative cache statistics + issuer_obj["negative_cache_hits"] = + picojson::value(static_cast( + stats.negative_cache_hits.load(std::memory_order_relaxed))); + std::string sanitized_issuer = sanitize_issuer_for_json(issuer); issuers_obj[sanitized_issuer] = picojson::value(issuer_obj); } diff --git a/test/integration_test.cpp b/test/integration_test.cpp index 21cb081..6ec5a95 100644 --- a/test/integration_test.cpp +++ b/test/integration_test.cpp @@ -1,7 +1,9 @@ #include "../src/scitokens.h" +#include "test_utils.h" #include #include +#include #include #include #include @@ -9,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -18,6 +21,8 @@ #endif #include +using scitokens_test::SecureTempDir; + namespace { // Helper class to parse monitoring JSON @@ -1037,15 +1042,15 @@ TEST_F(IntegrationTest, MonitoringDurationTracking) { TEST_F(IntegrationTest, MonitoringFileOutput) { char *err_msg = nullptr; + // Create a secure temp directory for the monitoring file + SecureTempDir temp_dir("monitoring_test_"); + ASSERT_TRUE(temp_dir.valid()) << "Failed to create temp directory"; + // Set up a test file path and zero interval for immediate write - std::string test_file = "/tmp/scitokens_monitoring_integration_" + - std::to_string(time(nullptr)) + ".json"; + std::string test_file = temp_dir.path() + "/monitoring.json"; scitoken_config_set_str("monitoring.file", test_file.c_str(), &err_msg); scitoken_config_set_int("monitoring.file_interval_s", 0, &err_msg); - // Clean up any existing file - std::remove(test_file.c_str()); - // Reset monitoring stats scitoken_reset_monitoring_stats(&err_msg); @@ -1119,10 +1124,10 @@ TEST_F(IntegrationTest, MonitoringFileOutput) { std::cout << content << std::endl; } - // Clean up + // Clean up - disable monitoring file scitoken_config_set_str("monitoring.file", "", &err_msg); scitoken_config_set_int("monitoring.file_interval_s", 60, &err_msg); - std::remove(test_file.c_str()); + // temp_dir destructor will clean up the directory and file } // ============================================================================= @@ -1347,13 +1352,14 @@ TEST_F(IntegrationTest, BackgroundRefreshTest) { TEST_F(IntegrationTest, ConcurrentNewIssuerLookup) { char *err_msg = nullptr; - // Use a unique cache directory to ensure no cached keys exist + // Use a unique secure cache directory to ensure no cached keys exist // This forces the code path where keys must be fetched from the server - std::string unique_cache_dir = "/tmp/scitokens_concurrent_test_" + - std::to_string(time(nullptr)) + "_" + - std::to_string(getpid()); + SecureTempDir unique_cache("concurrent_test_"); + ASSERT_TRUE(unique_cache.valid()) + << "Failed to create temp cache directory"; + int rv = scitoken_config_set_str("keycache.cache_home", - unique_cache_dir.c_str(), &err_msg); + unique_cache.path().c_str(), &err_msg); ASSERT_EQ(rv, 0) << "Failed to set cache_home: " << (err_msg ? err_msg : "unknown"); if (err_msg) { @@ -1413,7 +1419,7 @@ TEST_F(IntegrationTest, ConcurrentNewIssuerLookup) { auto initial_key_lookups = stats_before.getIssuerStats(issuer_url_).successful_key_lookups; - std::cout << "Using unique cache directory: " << unique_cache_dir + std::cout << "Using unique cache directory: " << unique_cache.path() << std::endl; std::cout << "Initial successful_validations: " << initial_successful_validations << std::endl; @@ -1504,9 +1510,7 @@ TEST_F(IntegrationTest, ConcurrentNewIssuerLookup) { EXPECT_EQ(new_expired_keys, static_cast(NUM_THREADS)) << "All threads should enter the expired_keys code path"; - // Cleanup: remove the temporary cache directory - std::string rm_cmd = "rm -rf " + unique_cache_dir; - (void)system(rm_cmd.c_str()); + // unique_cache destructor will clean up the temporary cache directory } // Stress test: repeatedly deserialize a valid token across multiple threads @@ -1514,14 +1518,15 @@ TEST_F(IntegrationTest, ConcurrentNewIssuerLookup) { TEST_F(IntegrationTest, StressTestValidToken) { char *err_msg = nullptr; - // Use a unique cache directory to ensure no cached keys exist from prior - // tests This forces fresh key lookup and prevents background refresh - // interference - std::string unique_cache_dir = "/tmp/scitokens_stress_valid_" + - std::to_string(time(nullptr)) + "_" + - std::to_string(getpid()); + // Use a unique secure cache directory to ensure no cached keys exist from + // prior tests. This forces fresh key lookup and prevents background refresh + // interference. + SecureTempDir unique_cache("stress_valid_"); + ASSERT_TRUE(unique_cache.valid()) + << "Failed to create temp cache directory"; + int rv = scitoken_config_set_str("keycache.cache_home", - unique_cache_dir.c_str(), &err_msg); + unique_cache.path().c_str(), &err_msg); ASSERT_EQ(rv, 0) << "Failed to set cache_home: " << (err_msg ? err_msg : "unknown"); if (err_msg) { @@ -1564,7 +1569,6 @@ TEST_F(IntegrationTest, StressTestValidToken) { free(err_msg); err_msg = nullptr; } - std::unique_ptr token( scitoken_create(key.get()), scitoken_destroy); ASSERT_TRUE(token.get() != nullptr); @@ -1697,9 +1701,7 @@ TEST_F(IntegrationTest, StressTestValidToken) { << "Should have completed at least 100 validations in " << TEST_DURATION_MS << "ms"; - // Cleanup: remove the temporary cache directory - std::string rm_cmd = "rm -rf " + unique_cache_dir; - (void)system(rm_cmd.c_str()); + // unique_cache destructor will clean up the temporary cache directory } // Stress test: repeatedly deserialize a token with an invalid issuer (404) @@ -1708,13 +1710,14 @@ TEST_F(IntegrationTest, StressTestValidToken) { TEST_F(IntegrationTest, StressTestInvalidIssuer) { char *err_msg = nullptr; - // Use a unique cache directory to ensure no cached keys exist from prior - // tests - std::string unique_cache_dir = "/tmp/scitokens_stress_invalid_" + - std::to_string(time(nullptr)) + "_" + - std::to_string(getpid()); + // Use a unique secure cache directory to ensure no cached keys exist from + // prior tests + SecureTempDir unique_cache("stress_invalid_"); + ASSERT_TRUE(unique_cache.valid()) + << "Failed to create temp cache directory"; + int rv = scitoken_config_set_str("keycache.cache_home", - unique_cache_dir.c_str(), &err_msg); + unique_cache.path().c_str(), &err_msg); ASSERT_EQ(rv, 0) << "Failed to set cache_home: " << (err_msg ? err_msg : "unknown"); if (err_msg) { @@ -1873,9 +1876,7 @@ TEST_F(IntegrationTest, StressTestInvalidIssuer) { << "Should have completed at least 100 validations in " << TEST_DURATION_MS << "ms"; - // Cleanup: remove the temporary cache directory - std::string rm_cmd = "rm -rf " + unique_cache_dir; - (void)system(rm_cmd.c_str()); + // unique_cache destructor will clean up the temporary cache directory } } // namespace diff --git a/test/main.cpp b/test/main.cpp index 9564d8b..b301c4a 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1,9 +1,15 @@ #include "../src/scitokens.h" +#include "test_utils.h" +#include +#include #include #include +#include #include +using scitokens_test::SecureTempDir; + namespace { const char ec_private[] = @@ -724,14 +730,15 @@ TEST_F(KeycacheTest, SetGetTest) { } TEST_F(KeycacheTest, SetGetConfiguredCacheHome) { - // Set cache home - char cache_path[FILENAME_MAX]; - ASSERT_TRUE(getcwd(cache_path, sizeof(cache_path)) != - nullptr); // Side effect gets cwd + // Create a secure temporary directory for the cache + SecureTempDir temp_cache("cache_home_test_"); + ASSERT_TRUE(temp_cache.valid()) << "Failed to create temp directory"; + char *err_msg = nullptr; std::string key = "keycache.cache_home"; - auto rv = scitoken_config_set_str(key.c_str(), cache_path, &err_msg); + auto rv = scitoken_config_set_str(key.c_str(), temp_cache.path().c_str(), + &err_msg); ASSERT_TRUE(rv == 0) << err_msg; // Set the jwks at the new cache home @@ -753,12 +760,14 @@ TEST_F(KeycacheTest, SetGetConfiguredCacheHome) { char *output; rv = scitoken_config_get_str(key.c_str(), &output, &err_msg); ASSERT_TRUE(rv == 0) << err_msg; - EXPECT_EQ(*output, *cache_path); + EXPECT_STREQ(output, temp_cache.path().c_str()); free(output); // Reset cache home to whatever it was before by setting empty config rv = scitoken_config_set_str(key.c_str(), "", &err_msg); ASSERT_TRUE(rv == 0) << err_msg; + + // temp_cache destructor will clean up the directory } TEST_F(KeycacheTest, InvalidConfigKeyTest) { @@ -843,6 +852,109 @@ TEST_F(KeycacheTest, RefreshExpiredTest) { EXPECT_EQ(jwks_str, "{\"keys\": []}"); } +TEST_F(KeycacheTest, NegativeCacheTest) { + // This test verifies that failed issuer lookups are cached as negative + // entries and that subsequent attempts fail quickly with the right counter + char *err_msg = nullptr; + + // Reset monitoring stats for clean baseline + scitoken_reset_monitoring_stats(&err_msg); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Create a token with an issuer that will fail to lookup + std::unique_ptr mykey( + scitoken_key_create("1", "ES256", ec_public, ec_private, &err_msg), + scitoken_key_destroy); + ASSERT_TRUE(mykey.get() != nullptr) << err_msg; + + std::unique_ptr mytoken( + scitoken_create(mykey.get()), scitoken_destroy); + ASSERT_TRUE(mytoken.get() != nullptr); + + // Use a unique issuer that doesn't exist (will fail to fetch keys) + // Include timestamp to avoid interference from previous test runs + std::string invalid_issuer = "https://invalid-issuer-negative-cache-" + + std::to_string(std::time(nullptr)) + + ".example.com"; + auto rv = scitoken_set_claim_string(mytoken.get(), "iss", + invalid_issuer.c_str(), &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + + char *token_value = nullptr; + rv = scitoken_serialize(mytoken.get(), &token_value, &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + std::unique_ptr token_value_ptr(token_value, free); + + // First attempt should fail to fetch keys (DNS failure or connection + // refused). This is a cache MISS (creates negative cache entry). + std::unique_ptr read_token( + scitoken_create(nullptr), scitoken_destroy); + ASSERT_TRUE(read_token.get() != nullptr); + + rv = scitoken_deserialize_v2(token_value, read_token.get(), nullptr, + &err_msg); + ASSERT_FALSE(rv == 0); // Should fail + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Check that a negative cache entry was created (returns empty keys) + char *jwks; + rv = keycache_get_cached_jwks(invalid_issuer.c_str(), &jwks, &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + ASSERT_TRUE(jwks != nullptr); + std::string jwks_str(jwks); + free(jwks); + + // Should return empty keys array (negative cache) + EXPECT_EQ(jwks_str, "{\"keys\": []}"); + + // Second attempt should fail quickly using negative cache + rv = scitoken_deserialize_v2(token_value, read_token.get(), nullptr, + &err_msg); + ASSERT_FALSE(rv == 0); // Should still fail + ASSERT_TRUE(err_msg != nullptr); + std::string error_msg(err_msg); + free(err_msg); + err_msg = nullptr; + + // Error message should indicate it's from negative cache + EXPECT_NE(error_msg.find("negative cache"), std::string::npos) + << "Error message should mention negative cache: " << error_msg; + + // Third attempt to verify counter increments correctly + rv = scitoken_deserialize_v2(token_value, read_token.get(), nullptr, + &err_msg); + ASSERT_FALSE(rv == 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Get monitoring stats and verify negative_cache_hits counter + char *json_out = nullptr; + rv = scitoken_get_monitoring_json(&json_out, &err_msg); + ASSERT_TRUE(rv == 0) << err_msg; + ASSERT_TRUE(json_out != nullptr); + std::string json_str(json_out); + free(json_out); + + // Parse JSON and check negative_cache_hits + // Only the second and third attempts should hit the negative cache: + // - First attempt: creates negative cache (cache miss, not hit) + // - Second and third attempts: hit existing negative cache + EXPECT_NE(json_str.find("\"negative_cache_hits\""), std::string::npos) + << "JSON should contain negative_cache_hits field"; + + // Verify 2 negative cache hits (attempts 2 and 3 only) + EXPECT_NE(json_str.find("\"negative_cache_hits\":2"), std::string::npos) + << "Should have 2 negative cache hits. JSON: " << json_str; +} + class IssuerSecurityTest : public ::testing::Test { protected: void SetUp() override { @@ -1011,21 +1123,26 @@ TEST_F(EnvConfigTest, IntConfigFromEnv) { TEST_F(EnvConfigTest, StringConfigFromEnv) { // Test that we can manually set and get string config values + // Use a secure temp directory instead of hardcoded /tmp path + SecureTempDir temp_cache("env_config_test_"); + ASSERT_TRUE(temp_cache.valid()) << "Failed to create temp directory"; + char *err_msg = nullptr; - const char *test_path = "/tmp/test_cache"; - auto rv = - scitoken_config_set_str("keycache.cache_home", test_path, &err_msg); + auto rv = scitoken_config_set_str("keycache.cache_home", + temp_cache.path().c_str(), &err_msg); ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); char *output = nullptr; rv = scitoken_config_get_str("keycache.cache_home", &output, &err_msg); ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); ASSERT_TRUE(output != nullptr); - EXPECT_STREQ(output, test_path); + EXPECT_STREQ(output, temp_cache.path().c_str()); free(output); if (err_msg) free(err_msg); + + // temp_cache destructor will clean up the directory } // Test for thundering herd prevention with per-issuer locks diff --git a/test/test_utils.h b/test/test_utils.h new file mode 100644 index 0000000..8b13f9f --- /dev/null +++ b/test/test_utils.h @@ -0,0 +1,129 @@ +#ifndef SCITOKENS_TEST_UTILS_H +#define SCITOKENS_TEST_UTILS_H + +#include +#include +#include +#include +#include +#include +#include + +namespace scitokens_test { + +/** + * Helper class to create and manage secure temporary directories. + * Uses mkdtemp for security and cleans up on destruction. + * + * Example usage: + * SecureTempDir temp_dir("my_test_"); + * ASSERT_TRUE(temp_dir.valid()); + * std::string cache_path = temp_dir.path() + "/cache"; + * // ... use the directory ... + * // Directory is automatically cleaned up when temp_dir goes out of scope + */ +class SecureTempDir { + public: + /** + * Create a temp directory under the specified base path. + * @param prefix Prefix for the directory name (default: "scitokens_test_") + * @param base_path Base path for the temp directory. If empty, uses + * BINARY_DIR/tests (from CMake) or falls back to cwd/tests + */ + explicit SecureTempDir(const std::string &prefix = "scitokens_test_", + const std::string &base_path = "") { + std::string base = base_path; + if (base.empty()) { + // Try to use build/tests directory (set by CMake) + const char *binary_dir = std::getenv("BINARY_DIR"); + if (binary_dir) { + base = std::string(binary_dir) + "/tests"; + } else { + // Fallback: use current working directory + tests + char cwd[PATH_MAX]; + if (getcwd(cwd, sizeof(cwd))) { + base = std::string(cwd) + "/tests"; + } else { + base = "/tmp"; // Last resort fallback + } + } + } + + // Ensure base directory exists + mkdir(base.c_str(), 0700); + + // Create template for mkdtemp + std::string tmpl = base + "/" + prefix + "XXXXXX"; + std::vector tmpl_buf(tmpl.begin(), tmpl.end()); + tmpl_buf.push_back('\0'); + + char *result = mkdtemp(tmpl_buf.data()); + if (result) { + path_ = result; + } + } + + ~SecureTempDir() { cleanup(); } + + // Delete copy constructor and assignment + SecureTempDir(const SecureTempDir &) = delete; + SecureTempDir &operator=(const SecureTempDir &) = delete; + + // Allow move + SecureTempDir(SecureTempDir &&other) noexcept + : path_(std::move(other.path_)) { + other.path_.clear(); + } + + SecureTempDir &operator=(SecureTempDir &&other) noexcept { + if (this != &other) { + cleanup(); + path_ = std::move(other.path_); + other.path_.clear(); + } + return *this; + } + + /** Get the path to the temporary directory */ + const std::string &path() const { return path_; } + + /** Check if the directory was created successfully */ + bool valid() const { return !path_.empty(); } + + /** Manually trigger cleanup (also called by destructor) */ + void cleanup() { + if (!path_.empty()) { + remove_directory_recursive(path_); + path_.clear(); + } + } + + private: + std::string path_; + + /** + * Safely remove a directory recursively using fork/execv. + * This prevents shell injection attacks that could occur with system(). + */ + static void remove_directory_recursive(const std::string &path) { + pid_t pid = fork(); + if (pid == 0) { + // Child process: exec rm -rf with path as direct argument + // Using execv prevents any shell interpretation of the path + char *const args[] = {const_cast("rm"), + const_cast("-rf"), + const_cast(path.c_str()), nullptr}; + execv("/bin/rm", args); + _exit(1); // execv failed + } else if (pid > 0) { + // Parent: wait for child to complete + int status; + waitpid(pid, &status, 0); + } + // If fork failed, silently ignore (cleanup is best-effort) + } +}; + +} // namespace scitokens_test + +#endif // SCITOKENS_TEST_UTILS_H