From f8d99d8a96e14b5f27873e23ccf33802260eea62 Mon Sep 17 00:00:00 2001 From: Rob Johnson Date: Wed, 27 May 2026 16:34:46 -0700 Subject: [PATCH 1/2] automatically register threads Signed-off-by: Rob Johnson --- docs/site/content/docs/usage.md | 15 +- docs/site/content/docs/v0.0.1/usage.md | 15 +- docs/usage.md | 15 +- .../platform_linux/public_platform.h | 10 +- include/splinterdb/splinterdb.h | 5 +- src/platform_linux/platform_threads.c | 129 +++++++++++++++--- src/platform_linux/platform_threads.h | 16 +++ src/splinterdb.c | 53 +++++++ tests/unit/splinterdb_forked_child_test.c | 2 +- tests/unit/splinterdb_quick_test.c | 51 ++++++- 10 files changed, 253 insertions(+), 58 deletions(-) diff --git a/docs/site/content/docs/usage.md b/docs/site/content/docs/usage.md index f6f558ca1..ec4238573 100644 --- a/docs/site/content/docs/usage.md +++ b/docs/site/content/docs/usage.md @@ -54,17 +54,12 @@ looks out-of-date, please open an issue or pull request. - The initial thread should be the one to call `splinterdb_close()` when the `splinterdb` instance is no longer needed. - - Threads (other than the initial thread) that will use the `splinterdb` - must be registered before use and unregistered before exiting: + - Threads are registered automatically the first time they call a SplinterDB + public API, and are deregistered automatically when they exit. Internally, + SplinterDB allocates per-thread scratch space at registration time. - - From a non-initial thread, call `splinterdb_register_thread()`. - Internally, SplinterDB will allocate scratch space for use by that thread. - - - To avoid leaking memory, a non-initial thread should call - `splinterdb_deregister_thread()` before it exits. - - - Known issue: In a pinch, non-initial, registered threads may call - `splinterdb_close()`, but their scratch memory would be leaked. + - The low-level thread registration APIs remain available for tests and + callers that want to reserve or release a thread ID explicitly. - Note these rules apply to system threads, not [runtime-managed threading](https://en.wikipedia.org/wiki/Green_threads) available in higher-level languages. diff --git a/docs/site/content/docs/v0.0.1/usage.md b/docs/site/content/docs/v0.0.1/usage.md index e8ffac7c6..a7192df1a 100644 --- a/docs/site/content/docs/v0.0.1/usage.md +++ b/docs/site/content/docs/v0.0.1/usage.md @@ -54,17 +54,12 @@ looks out-of-date, please open an issue or pull request. - The initial thread should be the one to call `splinterdb_close()` when the `splinterdb` instance is no longer needed. - - Threads (other than the initial thread) that will use the `splinterdb` - must be registered before use and unregistered before exiting: + - Threads are registered automatically the first time they call a SplinterDB + public API, and are deregistered automatically when they exit. Internally, + SplinterDB allocates per-thread scratch space at registration time. - - From a non-initial thread, call `splinterdb_register_thread()`. - Internally, SplinterDB will allocate scratch space for use by that thread. - - - To avoid leaking memory, a non-initial thread should call - `splinterdb_deregister_thread()` before it exits. - - - Known issue: In a pinch, non-initial, registered threads may call - `splinterdb_close()`, but their scratch memory would be leaked. + - The low-level thread registration APIs remain available for tests and + callers that want to reserve or release a thread ID explicitly. - Note these rules apply to system threads, not [runtime-managed threading](https://en.wikipedia.org/wiki/Green_threads) available in higher-level languages. diff --git a/docs/usage.md b/docs/usage.md index f6f558ca1..ec4238573 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -54,17 +54,12 @@ looks out-of-date, please open an issue or pull request. - The initial thread should be the one to call `splinterdb_close()` when the `splinterdb` instance is no longer needed. - - Threads (other than the initial thread) that will use the `splinterdb` - must be registered before use and unregistered before exiting: + - Threads are registered automatically the first time they call a SplinterDB + public API, and are deregistered automatically when they exit. Internally, + SplinterDB allocates per-thread scratch space at registration time. - - From a non-initial thread, call `splinterdb_register_thread()`. - Internally, SplinterDB will allocate scratch space for use by that thread. - - - To avoid leaking memory, a non-initial thread should call - `splinterdb_deregister_thread()` before it exits. - - - Known issue: In a pinch, non-initial, registered threads may call - `splinterdb_close()`, but their scratch memory would be leaked. + - The low-level thread registration APIs remain available for tests and + callers that want to reserve or release a thread ID explicitly. - Note these rules apply to system threads, not [runtime-managed threading](https://en.wikipedia.org/wiki/Green_threads) available in higher-level languages. diff --git a/include/splinterdb/platform_linux/public_platform.h b/include/splinterdb/platform_linux/public_platform.h index 015fd951a..b554b75f5 100644 --- a/include/splinterdb/platform_linux/public_platform.h +++ b/include/splinterdb/platform_linux/public_platform.h @@ -82,10 +82,9 @@ platform_set_log_streams(platform_log_handle *info_stream, // Register the current thread so that it can be used with splinterdb. // -// Any thread that uses a splinterdb must first be registered with it. -// -// The only exception is the initial thread which called create or open, -// as that thread is implicitly registered. Re-registering it is an error. +// SplinterDB public APIs register threads automatically on first use, so most +// callers do not need this function. It remains available for lower-level tests +// and callers that want to reserve a thread ID before using lower-level APIs. // // A thread should not be registered more than once; that is an error. // @@ -97,7 +96,8 @@ platform_register_thread(void); // Deregister the current thread. // -// Call this function before exiting a registered thread. +// Registered threads are deregistered automatically when they exit. This +// function remains available for callers that want to release a thread ID early. void platform_deregister_thread(void); diff --git a/include/splinterdb/splinterdb.h b/include/splinterdb/splinterdb.h index 471d0b4e0..c9830f3dc 100644 --- a/include/splinterdb/splinterdb.h +++ b/include/splinterdb/splinterdb.h @@ -9,9 +9,8 @@ * A data_config must be provided at the time of create/open. * See default_data_config.h for a basic reference implementation. * - * Each thread must call splinterdb_register_thread() before making any calls to - * SplinterDB. Each thread must call splinterdb_deregister_thread() before - * exiting. + * Threads are registered with SplinterDB automatically on first use and are + * deregistered automatically when they exit. */ #pragma once diff --git a/src/platform_linux/platform_threads.c b/src/platform_linux/platform_threads.c index 9f076cd18..e6c4e6fcb 100644 --- a/src/platform_linux/platform_threads.c +++ b/src/platform_linux/platform_threads.c @@ -13,6 +13,10 @@ __thread threadid xxxtid = INVALID_TID; threadid xxxpid; pid_t ospid; +static pthread_key_t thread_registration_key; +static pthread_once_t thread_registration_key_once = PTHREAD_ONCE_INIT; +static char thread_registration_key_value; + /**************************************** * Thread ID allocation and management * @@ -53,6 +57,68 @@ typedef struct id_allocator { // processes. static id_allocator *id_alloc = NULL; +static void +thread_registration_after_fork_child(void) +{ + /* + * The child inherits TLS values, but the parent's thread ID still belongs to + * the parent in the shared allocator. Force the child through registration. + */ + xxxtid = INVALID_TID; + xxxpid = INVALID_TID; + ospid = 0; +} + +static void +thread_registration_destructor(void *arg) +{ + (void)arg; + + if (xxxtid != INVALID_TID) { + platform_deregister_thread(); + } +} + +static void +thread_registration_key_init(void) +{ + int ret = + pthread_key_create(&thread_registration_key, + thread_registration_destructor); + platform_assert(ret == 0); + + ret = pthread_atfork(NULL, NULL, thread_registration_after_fork_child); + platform_assert(ret == 0); +} + +static platform_status +thread_registration_cleanup_set(void *value) +{ + int ret; + + ret = pthread_once(&thread_registration_key_once, + thread_registration_key_init); + if (ret != 0) { + return CONST_STATUS(ret); + } + + ret = pthread_setspecific(thread_registration_key, value); + return CONST_STATUS(ret); +} + +static platform_status +thread_registration_cleanup_arm(void) +{ + return thread_registration_cleanup_set(&thread_registration_key_value); +} + +static void +thread_registration_cleanup_disarm(void) +{ + platform_status rc = thread_registration_cleanup_set(NULL); + platform_assert_status_ok(rc); +} + /* * task_init_tid_bitmask() - Initialize the global bitmask of active threads * in the task system structure to indicate that no threads are currently @@ -304,6 +370,7 @@ platform_deregister_thread() platform_assert(tid != INVALID_TID, "Error! Attempt to deregister unregistered thread.\n"); + thread_registration_cleanup_disarm(); deallocate_threadid(tid); decref_xxxpid(); } @@ -311,6 +378,8 @@ platform_deregister_thread() static void thread_registration_cleanup_function(void *arg) { + (void)arg; + if (xxxtid == INVALID_TID) { platform_error_log("Thread registration cleanup function called for " "unregistered thread %lu", @@ -320,6 +389,38 @@ thread_registration_cleanup_function(void *arg) } } +static platform_status +register_thread_common(void) +{ + platform_status status; + threadid thread_tid; + + id_allocator_init_if_needed(); + + status = ensure_xxxpid_is_setup(); + if (!SUCCESS(status)) { + return status; + } + + thread_tid = allocate_threadid(); + // Unavailable threads is a temporary state that could go away. + if (thread_tid == INVALID_TID) { + decref_xxxpid(); + return STATUS_BUSY; + } + + platform_assert(thread_tid < MAX_THREADS); + xxxtid = thread_tid; + + status = thread_registration_cleanup_arm(); + if (!SUCCESS(status)) { + deallocate_threadid(thread_tid); + decref_xxxpid(); + return status; + } + + return STATUS_OK; +} /* * platform_register_thread(): Register this new thread. @@ -330,13 +431,6 @@ thread_registration_cleanup_function(void *arg) int platform_register_thread(void) { - id_allocator_init_if_needed(); - - platform_status status = ensure_xxxpid_is_setup(); - if (!SUCCESS(status)) { - return -1; - } - threadid thread_tid = xxxtid; // Before registration, all SplinterDB threads' tid will be its default @@ -346,17 +440,17 @@ platform_register_thread(void) "registered as thread %lu\n", thread_tid); - thread_tid = allocate_threadid(); - // Unavailable threads is a temporary state that could go away. - if (thread_tid == INVALID_TID) { - decref_xxxpid(); - return -1; - } + return SUCCESS(register_thread_common()) ? 0 : -1; +} - platform_assert(thread_tid < MAX_THREADS); - xxxtid = thread_tid; +platform_status +platform_register_thread_auto(void) +{ + if (xxxtid != INVALID_TID) { + return STATUS_OK; + } - return 0; + return register_thread_common(); } @@ -371,6 +465,8 @@ thread_worker_function(void *arg) thread_invocation *thread_inv = (thread_invocation *)arg; threadid tid = thread_inv - id_alloc->thread_invocations; xxxtid = tid; + platform_status rc = thread_registration_cleanup_arm(); + platform_assert_status_ok(rc); thread_inv->worker(thread_inv->arg); pthread_cleanup_pop(1); return NULL; @@ -398,6 +494,7 @@ platform_thread_create(platform_thread *thread, // so that we can report an error if the threadid allocation fails. threadid tid = allocate_threadid(); if (tid == INVALID_TID) { + decref_xxxpid(); return STATUS_BUSY; } thread_invocation *thread_inv = &id_alloc->thread_invocations[tid]; diff --git a/src/platform_linux/platform_threads.h b/src/platform_linux/platform_threads.h index 9eb06d7f5..c1b4f3dcb 100644 --- a/src/platform_linux/platform_threads.h +++ b/src/platform_linux/platform_threads.h @@ -6,6 +6,7 @@ #include "splinterdb/platform_linux/public_platform.h" #include "platform_status.h" #include "platform_heap.h" +#include "platform_util.h" #include #include @@ -27,6 +28,21 @@ platform_get_tid() return xxxtid; } +platform_status +platform_register_thread_auto(void); + +static inline platform_status +platform_ensure_thread_registered() +{ + extern __thread threadid xxxtid; + + if (LIKELY(xxxtid != INVALID_TID)) { + return STATUS_OK; + } + + return platform_register_thread_auto(); +} + /* This is not part of the platform API. It is used internally to this platform * implementation. Specifically, it is used in laio.c. */ static inline threadid diff --git a/src/splinterdb.c b/src/splinterdb.c index 6a1c55ffa..9df80d85f 100644 --- a/src/splinterdb.c +++ b/src/splinterdb.c @@ -24,6 +24,7 @@ #include "platform_typed_alloc.h" #include "platform_assert.h" #include "platform_units.h" +#include "platform_threads.h" #include "poison.h" const char *BUILD_VERSION = "splinterdb_build_version " GIT_VERSION; @@ -73,6 +74,19 @@ platform_status_to_int(const platform_status status) // IN return status.r; } +static inline int +splinterdb_ensure_thread_registered(void) +{ + return platform_status_to_int(platform_ensure_thread_registered()); +} + +static inline void +splinterdb_assert_thread_registered(void) +{ + platform_status rc = platform_ensure_thread_registered(); + platform_assert_status_ok(rc); +} + static void splinterdb_config_set_defaults(splinterdb_config *cfg) { @@ -255,6 +269,11 @@ splinterdb_create_or_open(const splinterdb_config *kvs_cfg, // IN bool we_created_heap = FALSE; platform_heap_id use_this_heap_id = kvs_cfg->heap_id; + status = platform_ensure_thread_registered(); + if (!SUCCESS(status)) { + return platform_status_to_int(status); + } + // Allocate a shared segment if so requested. For now, we hard-code // the required size big enough to run most tests. Eventually this // has to be calculated here based on other run-time params. @@ -447,6 +466,7 @@ splinterdb_close(splinterdb **kvs_in) // IN { splinterdb *kvs = *kvs_in; platform_assert(kvs != NULL); + splinterdb_assert_thread_registered(); // Print stats if shared memory is enabled. if (kvs->heap_id) { @@ -550,6 +570,11 @@ splinterdb_lookup(splinterdb *kvs, // IN key target = key_create_from_slice(TRUE, user_key); platform_assert(kvs != NULL); + status = platform_ensure_thread_registered(); + if (!SUCCESS(status)) { + return platform_status_to_int(status); + } + status = core_lookup(&kvs->spl, target, _result); return platform_status_to_int(status); } @@ -575,6 +600,11 @@ splinterdb_insert_message(splinterdb *kvs, // IN lookup_result *old_result // IN/OUT ) { + int rc = splinterdb_ensure_thread_registered(); + if (rc != 0) { + return rc; + } + key tuple_key = key_create_from_slice(FALSE, user_key); platform_assert(kvs != NULL); platform_status status = core_insert(&kvs->spl, tuple_key, msg, old_result); @@ -669,6 +699,11 @@ splinterdb_iterator_init_with_bounds(splinterdb *kvs, // IN return platform_status_to_int(STATUS_BAD_PARAM); } + int auto_register_rc = splinterdb_ensure_thread_registered(); + if (auto_register_rc != 0) { + return auto_register_rc; + } + splinterdb_iterator *it = TYPED_MALLOC(kvs->spl.heap_id, it); if (it == NULL) { platform_error_log("TYPED_MALLOC error\n"); @@ -720,6 +755,8 @@ splinterdb_iterator_init_with_bounds(splinterdb *kvs, // IN void splinterdb_iterator_deinit(splinterdb_iterator *iter) { + splinterdb_assert_thread_registered(); + core_range_iterator *range_itor = &(iter->sri); core_range_iterator_deinit(range_itor); @@ -760,6 +797,11 @@ splinterdb_iterator_can_next(splinterdb_iterator *kvi) void splinterdb_iterator_next(splinterdb_iterator *kvi) { + kvi->last_rc = platform_ensure_thread_registered(); + if (!SUCCESS(kvi->last_rc)) { + return; + } + iterator *itor = &(kvi->sri.super); kvi->last_rc = iterator_next(itor); } @@ -767,6 +809,11 @@ splinterdb_iterator_next(splinterdb_iterator *kvi) void splinterdb_iterator_prev(splinterdb_iterator *kvi) { + kvi->last_rc = platform_ensure_thread_registered(); + if (!SUCCESS(kvi->last_rc)) { + return; + } + iterator *itor = &(kvi->sri.super); kvi->last_rc = iterator_prev(itor); } @@ -783,6 +830,8 @@ splinterdb_iterator_get_current(splinterdb_iterator *iter, // IN slice *value // OUT ) { + splinterdb_assert_thread_registered(); + key result_key; message msg; iterator *itor = &(iter->sri.super); @@ -795,18 +844,21 @@ splinterdb_iterator_get_current(splinterdb_iterator *iter, // IN void splinterdb_stats_print_insertion(const splinterdb *kvs) { + splinterdb_assert_thread_registered(); core_print_insertion_stats(Platform_default_log_handle, &kvs->spl); } void splinterdb_stats_print_lookup(splinterdb *kvs) { + splinterdb_assert_thread_registered(); core_print_lookup_stats(Platform_default_log_handle, &kvs->spl); } void splinterdb_stats_reset(splinterdb *kvs) { + splinterdb_assert_thread_registered(); core_reset_stats(&kvs->spl); } @@ -826,6 +878,7 @@ splinterdb_close_print_stats(splinterdb *kvs) void splinterdb_cache_flush(splinterdb *kvs) { + splinterdb_assert_thread_registered(); cache_flush(kvs->spl.cc); } diff --git a/tests/unit/splinterdb_forked_child_test.c b/tests/unit/splinterdb_forked_child_test.c index 42d181fed..1aa78ab50 100644 --- a/tests/unit/splinterdb_forked_child_test.c +++ b/tests/unit/splinterdb_forked_child_test.c @@ -132,6 +132,7 @@ CTEST2(splinterdb_forked_child, test_data_structures_handles) ASSERT_EQUAL(1, platform_get_tid()); // After deregistering w/Splinter, child process is back to invalid value + platform_deregister_thread(); ASSERT_EQUAL(INVALID_TID, platform_get_tid()); } @@ -146,7 +147,6 @@ CTEST2(splinterdb_forked_child, test_data_structures_handles) splinterdb_close(&spl_handle); } else { // Child should not attempt to run the rest of the tests - platform_deregister_thread(); exit(0); } } diff --git a/tests/unit/splinterdb_quick_test.c b/tests/unit/splinterdb_quick_test.c index 9179d15a2..8a1c63333 100644 --- a/tests/unit/splinterdb_quick_test.c +++ b/tests/unit/splinterdb_quick_test.c @@ -27,6 +27,7 @@ #include // Needed for system calls; e.g. free #include #include +#include #include "platform_threads.h" #include "splinterdb/splinterdb.h" @@ -119,11 +120,21 @@ static void assert_lookup_result_matches_slice(const splinterdb_lookup_result *result, slice expected); +static void * +auto_registered_lookup_thread(void *arg); + typedef struct { data_config super; uint64 num_comparisons; } comparison_counting_data_config; +typedef struct { + splinterdb *kvsb; + int rc; + threadid tid_before; + threadid tid_after; +} auto_registration_thread_args; + /* * Global data declaration macro: * @@ -149,14 +160,12 @@ CTEST_DATA(splinterdb_quick) // Optional setup function for suite, called before every test in suite CTEST_SETUP(splinterdb_quick) { - int rc = platform_register_thread(); - ASSERT_EQUAL(0, rc); default_data_config_init(TEST_MAX_KEY_SIZE, &data->default_data_cfg.super); create_default_cfg(&data->cfg, &data->default_data_cfg.super); data->cfg.use_shmem = config_parse_use_shmem(Ctest_argc, (char **)Ctest_argv); - rc = splinterdb_create(&data->cfg, &data->kvsb); + int rc = splinterdb_create(&data->cfg, &data->kvsb); ASSERT_EQUAL(0, rc); ASSERT_TRUE(TEST_MAX_VALUE_SIZE < MAX_INLINE_MESSAGE_SIZE(IO_DEFAULT_PAGE_SIZE)); @@ -238,6 +247,22 @@ CTEST2(splinterdb_quick, test_basic_flow) splinterdb_lookup_result_deinit(&result); } +CTEST2(splinterdb_quick, test_pthread_auto_registration) +{ + pthread_t thread; + auto_registration_thread_args args = {.kvsb = data->kvsb}; + + int rc = pthread_create(&thread, NULL, auto_registered_lookup_thread, &args); + ASSERT_EQUAL(0, rc); + + rc = pthread_join(thread, NULL); + ASSERT_EQUAL(0, rc); + + ASSERT_EQUAL(0, args.rc); + ASSERT_EQUAL(INVALID_TID, args.tid_before); + ASSERT_TRUE(args.tid_after < MAX_THREADS); +} + /* * Basic test case that exercises and validates the basic flow of the * Splinter APIs for key of max-key-length. @@ -1612,6 +1637,26 @@ assert_lookup_result_matches_slice(const splinterdb_lookup_result *result, ASSERT_EQUAL(0, slice_lex_cmp(value, expected)); } +static void * +auto_registered_lookup_thread(void *arg) +{ + auto_registration_thread_args *args = + (auto_registration_thread_args *)arg; + char key_data[] = "thread-key"; + slice user_key = slice_create(sizeof(key_data), key_data); + + splinterdb_lookup_result result; + splinterdb_lookup_result_init( + args->kvsb, &result, SPLINTERDB_LOOKUP_VALUE, 0, NULL); + + args->tid_before = platform_get_tid(); + args->rc = splinterdb_lookup(args->kvsb, user_key, &result); + args->tid_after = platform_get_tid(); + + splinterdb_lookup_result_deinit(&result); + return NULL; +} + /* * Helper function to insert n-keys (num_inserts), using pre-formatted * key and value strings. From c40d98a281bbb9a43e557d772762a7943573fa9e Mon Sep 17 00:00:00 2001 From: Rob Johnson Date: Wed, 27 May 2026 21:54:03 -0700 Subject: [PATCH 2/2] formatting Signed-off-by: Rob Johnson --- include/splinterdb/platform_linux/public_platform.h | 3 ++- src/platform_linux/platform_threads.c | 9 ++++----- tests/unit/splinterdb_quick_test.c | 7 +++---- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/include/splinterdb/platform_linux/public_platform.h b/include/splinterdb/platform_linux/public_platform.h index b554b75f5..f65ca7d61 100644 --- a/include/splinterdb/platform_linux/public_platform.h +++ b/include/splinterdb/platform_linux/public_platform.h @@ -97,7 +97,8 @@ platform_register_thread(void); // Deregister the current thread. // // Registered threads are deregistered automatically when they exit. This -// function remains available for callers that want to release a thread ID early. +// function remains available for callers that want to release a thread ID +// early. void platform_deregister_thread(void); diff --git a/src/platform_linux/platform_threads.c b/src/platform_linux/platform_threads.c index e6c4e6fcb..451d876b9 100644 --- a/src/platform_linux/platform_threads.c +++ b/src/platform_linux/platform_threads.c @@ -82,9 +82,8 @@ thread_registration_destructor(void *arg) static void thread_registration_key_init(void) { - int ret = - pthread_key_create(&thread_registration_key, - thread_registration_destructor); + int ret = pthread_key_create(&thread_registration_key, + thread_registration_destructor); platform_assert(ret == 0); ret = pthread_atfork(NULL, NULL, thread_registration_after_fork_child); @@ -96,8 +95,8 @@ thread_registration_cleanup_set(void *value) { int ret; - ret = pthread_once(&thread_registration_key_once, - thread_registration_key_init); + ret = + pthread_once(&thread_registration_key_once, thread_registration_key_init); if (ret != 0) { return CONST_STATUS(ret); } diff --git a/tests/unit/splinterdb_quick_test.c b/tests/unit/splinterdb_quick_test.c index 8a1c63333..3711a1499 100644 --- a/tests/unit/splinterdb_quick_test.c +++ b/tests/unit/splinterdb_quick_test.c @@ -1640,10 +1640,9 @@ assert_lookup_result_matches_slice(const splinterdb_lookup_result *result, static void * auto_registered_lookup_thread(void *arg) { - auto_registration_thread_args *args = - (auto_registration_thread_args *)arg; - char key_data[] = "thread-key"; - slice user_key = slice_create(sizeof(key_data), key_data); + auto_registration_thread_args *args = (auto_registration_thread_args *)arg; + char key_data[] = "thread-key"; + slice user_key = slice_create(sizeof(key_data), key_data); splinterdb_lookup_result result; splinterdb_lookup_result_init(