Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Couple of refactorings to make adding dynamic linker support easier #111

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/program.rs
Expand Up @@ -90,6 +90,13 @@ pub(super) unsafe extern "C" fn entry(mem: *mut usize) -> ! {
// Compute `argc`, `argv`, and `envp`.
let (argc, argv, envp) = compute_args(mem);

// Before doing anything else, perform dynamic relocations.
#[cfg(all(feature = "experimental-relocate", feature = "origin-start"))]
#[cfg(relocation_model = "pic")]
{
crate::relocate::relocate(envp);
}

// Initialize program state before running any user code.
init_runtime(mem, envp);

Expand Down Expand Up @@ -187,8 +194,7 @@ unsafe fn compute_args(mem: *mut usize) -> (i32, *mut *mut u8, *mut *mut u8) {
(argc, argv, envp)
}

/// Perform dynamic relocation (if enabled), and initialize `origin` and
/// `rustix` runtime state.
/// Initialize `origin` and `rustix` runtime state.
///
/// # Safety
///
Expand All @@ -197,11 +203,6 @@ unsafe fn compute_args(mem: *mut usize) -> (i32, *mut *mut u8, *mut *mut u8) {
#[cfg(feature = "origin-program")]
#[allow(unused_variables)]
unsafe fn init_runtime(mem: *mut usize, envp: *mut *mut u8) {
// Before doing anything else, perform dynamic relocations.
#[cfg(all(feature = "experimental-relocate", feature = "origin-start"))]
#[cfg(relocation_model = "pic")]
crate::relocate::relocate(envp);

// Explicitly initialize `rustix`. This is needed for things like
// `page_size()` to work.
#[cfg(feature = "param")]
Expand Down
180 changes: 65 additions & 115 deletions src/thread/linux_raw.rs
Expand Up @@ -84,7 +84,9 @@ struct ThreadData {
stack_addr: *mut c_void,
stack_size: usize,
guard_size: usize,
map_size: usize,
stack_map_size: usize,
tls_addr: *mut u8,
tls_map_size: usize,
return_value: AtomicPtr<c_void>,

// Support a few dtors before using dynamic allocation.
Expand All @@ -100,21 +102,22 @@ const ABANDONED: u8 = 2;
impl ThreadData {
#[inline]
fn new(
tid: Option<ThreadId>,
stack_addr: *mut c_void,
stack_size: usize,
guard_size: usize,
map_size: usize,
stack_map_size: usize,
) -> Self {
Self {
thread_id: AtomicI32::new(ThreadId::as_raw(tid)),
thread_id: AtomicI32::new(0),
#[cfg(feature = "unstable-errno")]
errno_val: Cell::new(0),
detached: AtomicU8::new(INITIAL),
stack_addr,
stack_size,
guard_size,
map_size,
stack_map_size,
tls_addr: null_mut(),
tls_map_size: 0,
return_value: AtomicPtr::new(null_mut()),
#[cfg(feature = "alloc")]
dtors: smallvec::SmallVec::new(),
Expand Down Expand Up @@ -336,8 +339,26 @@ pub(super) unsafe fn initialize_main(mem: *mut c_void) {
let stack_least = stack_base.cast::<u8>().sub(stack_map_size);
let stack_size = stack_least.offset_from(mem.cast::<u8>()) as usize;
let guard_size = page_size();
let map_size = 0;

// Initialize the canary value from the OS-provided random bytes.
let random_ptr = rustix::runtime::random().cast::<usize>();
let canary = random_ptr.read_unaligned();
__stack_chk_guard = canary;

let (newtls, metadata) = allocate_tls(
canary,
ThreadData::new(stack_least.cast(), stack_size, guard_size, 0),
);

let thread_id_ptr = (*metadata).thread.thread_id.as_ptr();
let tid = rustix::runtime::set_tid_address(thread_id_ptr.cast());
*thread_id_ptr = tid.as_raw_nonzero().get();

// Point the platform thread-pointer register at the new thread metadata.
set_thread_pointer(newtls);
}

unsafe fn allocate_tls(canary: usize, thread: ThreadData) -> (*mut c_void, *mut Metadata) {
// Compute relevant alignments.
let tls_data_align = STARTUP_TLS_INFO.align;
let header_align = align_of::<Metadata>();
Expand Down Expand Up @@ -390,14 +411,6 @@ pub(super) unsafe fn initialize_main(mem: *mut c_void) {
let metadata: *mut Metadata = new.add(header).cast();
let newtls: *mut c_void = (*metadata).abi.thread_pointee.as_mut_ptr().cast();

let thread_id_ptr = (*metadata).thread.thread_id.as_ptr();
let tid = rustix::runtime::set_tid_address(thread_id_ptr.cast());

// Initialize the canary value from the OS-provided random bytes.
let random_ptr = rustix::runtime::random().cast::<usize>();
let canary = random_ptr.read_unaligned();
__stack_chk_guard = canary;

// Initialize the thread metadata.
metadata.write(Metadata {
abi: Abi {
Expand All @@ -408,15 +421,12 @@ pub(super) unsafe fn initialize_main(mem: *mut c_void) {
_pad: Default::default(),
thread_pointee: [],
},
thread: ThreadData::new(
Some(tid),
stack_least.cast(),
stack_size,
guard_size,
map_size,
),
thread,
});

(*metadata).thread.tls_addr = new;
(*metadata).thread.tls_map_size = alloc_size;

// Initialize the TLS data with explicit initializer data.
slice::from_raw_parts_mut(tls_data, STARTUP_TLS_INFO.file_size).copy_from_slice(
slice::from_raw_parts(
Expand All @@ -425,15 +435,12 @@ pub(super) unsafe fn initialize_main(mem: *mut c_void) {
),
);

// Initialize the TLS data beyond `file_size` which is zero-filled.
slice::from_raw_parts_mut(
tls_data.add(STARTUP_TLS_INFO.file_size),
STARTUP_TLS_INFO.mem_size - STARTUP_TLS_INFO.file_size,
)
.fill(0);
// The TLS region includes additional data beyond `file_size` which is
// expected to be zero-initialized, but we don't need to do anything
// here since we allocated the memory with `mmap_anonymous` so it's
// already zeroed.

// Point the platform thread-pointer register at the new thread metadata.
set_thread_pointer(newtls);
(newtls, metadata)
}

/// Creates a new thread.
Expand All @@ -452,19 +459,10 @@ pub unsafe fn create(
stack_size: usize,
guard_size: usize,
) -> io::Result<Thread> {
// SAFETY: `STARTUP_TLS_INFO` is initialized at program startup before
// we come here creating new threads.
let (startup_tls_align, startup_tls_mem_size) =
unsafe { (STARTUP_TLS_INFO.align, STARTUP_TLS_INFO.mem_size) };

// Compute relevant alignments.
let tls_data_align = startup_tls_align;
let page_align = page_size();
let stack_align = 16;
let header_align = align_of::<Metadata>();
let metadata_align = max(tls_data_align, header_align);
let stack_metadata_align = max(stack_align, metadata_align);
debug_assert!(stack_metadata_align <= page_align);
debug_assert!(stack_align <= page_align);

// Compute the `mmap` size.
let mut map_size = 0;
Expand All @@ -473,37 +471,10 @@ pub unsafe fn create(

let stack_bottom = map_size;

map_size += round_up(stack_size, stack_metadata_align);
map_size += round_up(stack_size, stack_align);

let stack_top = map_size;

// Variant II: TLS data goes below the TCB.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
let tls_data_bottom = map_size;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
map_size += round_up(startup_tls_mem_size, tls_data_align);
}

let header = map_size;

map_size += size_of::<Metadata>();

// Variant I: TLS data goes above the TCB.
#[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
{
map_size = round_up(map_size, tls_data_align);
}

#[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
let tls_data_bottom = map_size;

#[cfg(any(target_arch = "aarch64", target_arch = "arm", target_arch = "riscv64"))]
{
map_size += round_up(startup_tls_mem_size, tls_data_align);
}

// Now we'll `mmap` the memory, initialize it, and create the OS thread.
unsafe {
// Allocate address space for the thread, including guard pages.
Expand All @@ -515,8 +486,8 @@ pub unsafe fn create(
)?
.cast::<u8>();

// Make the thread metadata and stack readable and writeable, leaving
// the guard region inaccessible.
// Make the stack readable and writeable, leaving the guard region
// inaccessible.
mprotect(
map.add(stack_bottom).cast(),
map_size - stack_bottom,
Expand All @@ -527,38 +498,12 @@ pub unsafe fn create(
let stack = map.add(stack_top);
let stack_least = map.add(stack_bottom);

let tls_data = map.add(tls_data_bottom);
let metadata: *mut Metadata = map.add(header).cast();
let newtls: *mut c_void = (*metadata).abi.thread_pointee.as_mut_ptr().cast();

// Copy the current thread's canary to the new thread.
let canary = (*current_metadata()).abi.canary;

// Initialize the thread metadata.
metadata.write(Metadata {
abi: Abi {
canary,
dtv: null(),
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
this: newtls,
_pad: Default::default(),
thread_pointee: [],
},
thread: ThreadData::new(
None, // the real tid will be written by `clone`.
stack_least.cast(),
stack_size,
guard_size,
map_size,
),
});

// Initialize the TLS data with explicit initializer data.
slice::from_raw_parts_mut(tls_data, STARTUP_TLS_INFO.file_size).copy_from_slice(
slice::from_raw_parts(
STARTUP_TLS_INFO.addr.cast::<u8>(),
STARTUP_TLS_INFO.file_size,
),
let (newtls, metadata) = allocate_tls(
canary,
ThreadData::new(stack_least.cast(), stack_size, guard_size, map_size),
);

// Allocate space for the thread arguments on the child's stack.
Expand All @@ -570,11 +515,6 @@ pub unsafe fn create(
// Store the thread arguments on the child's stack.
copy_nonoverlapping(args.as_ptr(), stack, args.len());

// The TLS region includes additional data beyond `file_size` which is
// expected to be zero-initialized, but we don't need to do anything
// here since we allocated the memory with `mmap_anonymous` so it's
// already zeroed.

// Create the OS thread. In Linux, this is a process that shares much
// of its state with the current process. We also pass additional
// flags:
Expand Down Expand Up @@ -725,9 +665,11 @@ unsafe fn exit(return_value: Option<NonNull<c_void>>) -> ! {
// all the fields that we'll need before freeing it.
#[cfg(feature = "log")]
let current_thread_id = current.0.as_ref().thread_id.load(SeqCst);
let current_map_size = current.0.as_ref().map_size;
let current_stack_map_size = current.0.as_ref().stack_map_size;
let current_stack_addr = current.0.as_ref().stack_addr;
let current_guard_size = current.0.as_ref().guard_size;
let current_tls_addr = current.0.as_ref().tls_addr;
let current_tls_map_size = current.0.as_ref().tls_map_size;

#[cfg(feature = "log")]
log::trace!("Thread[{:?}] exiting as detached", current_thread_id);
Expand All @@ -736,13 +678,16 @@ unsafe fn exit(return_value: Option<NonNull<c_void>>) -> ! {
// Deallocate the `ThreadData`.
drop_in_place(current.0.as_ptr());

// Free the thread's `mmap` region, if we allocated it.
let map_size = current_map_size;
if map_size != 0 {
// Null out the tid address so that the kernel doesn't write to
// memory that we've freed trying to clear our tid when we exit.
let _ = set_tid_address(null_mut());
// Null out the tid address so that the kernel doesn't write to
// memory that we've freed trying to clear our tid when we exit.
let _ = set_tid_address(null_mut());

// Free the thread's TLS data.
rustix::mm::munmap(current_tls_addr.cast(), current_tls_map_size).unwrap();

// Free the thread's stack, if we allocated it.
let map_size = current_stack_map_size;
if current_stack_map_size != 0 {
// `munmap` the memory, which also frees the stack we're currently
// on, and do an `exit` carefully without touching the stack.
let map = current_stack_addr.cast::<u8>().sub(current_guard_size);
Expand Down Expand Up @@ -923,17 +868,22 @@ unsafe fn free_memory(thread: Thread) {

// The thread was detached. Prepare to free the memory. First read out
// all the fields that we'll need before freeing it.
let map_size = thread.0.as_ref().map_size;
let stack_map_size = thread.0.as_ref().stack_map_size;
let stack_addr = thread.0.as_ref().stack_addr;
let guard_size = thread.0.as_ref().guard_size;
let tls_addr = thread.0.as_ref().tls_addr;
let tls_map_size = thread.0.as_ref().tls_map_size;

// Deallocate the `ThreadData`.
drop_in_place(thread.0.as_ptr());

// Free the thread's `mmap` region, if we allocated it.
if map_size != 0 {
let map = stack_addr.cast::<u8>().sub(guard_size);
munmap(map.cast(), map_size).unwrap();
// Free the thread's TLS data.
munmap(tls_addr.cast(), tls_map_size).unwrap();

// Free the thread's stack, if we allocated it.
if stack_map_size != 0 {
let stack_map = stack_addr.cast::<u8>().sub(guard_size);
munmap(stack_map.cast(), stack_map_size).unwrap();
}
}

Expand Down