diff --git a/src/driver/jit.rs b/src/driver/jit.rs index 3d8f837a66c8a..39a39e764cb6d 100644 --- a/src/driver/jit.rs +++ b/src/driver/jit.rs @@ -37,6 +37,7 @@ enum UnsafeMessage { /// this message is sent. JitFn { instance_ptr: *const Instance<'static>, + trampoline_ptr: *const u8, tx: mpsc::Sender<*const u8>, }, } @@ -192,8 +193,8 @@ pub(crate) fn run_jit(tcx: TyCtxt<'_>, backend_config: BackendConfig) -> ! { loop { match rx.recv().unwrap() { // lazy JIT compilation request - compile requested instance and return pointer to result - UnsafeMessage::JitFn { instance_ptr, tx } => { - tx.send(jit_fn(instance_ptr)) + UnsafeMessage::JitFn { instance_ptr, trampoline_ptr, tx } => { + tx.send(jit_fn(instance_ptr, trampoline_ptr)) .expect("jitted runtime hung up before response to lazy JIT request was sent"); } } @@ -201,10 +202,10 @@ pub(crate) fn run_jit(tcx: TyCtxt<'_>, backend_config: BackendConfig) -> ! { } #[no_mangle] -extern "C" fn __clif_jit_fn(instance_ptr: *const Instance<'static>) -> *const u8 { +extern "C" fn __clif_jit_fn(instance_ptr: *const Instance<'static>, trampoline_ptr: *const u8) -> *const u8 { // send the JIT request to the rustc thread, with a channel for the response let (tx, rx) = mpsc::channel(); - UnsafeMessage::JitFn { instance_ptr, tx } + UnsafeMessage::JitFn { instance_ptr, trampoline_ptr, tx } .send() .expect("rustc thread hung up before lazy JIT request was sent"); @@ -213,7 +214,7 @@ extern "C" fn __clif_jit_fn(instance_ptr: *const Instance<'static>) -> *const u8 .expect("rustc thread hung up before responding to sent lazy JIT request") } -fn jit_fn(instance_ptr: *const Instance<'static>) -> *const u8 { +fn jit_fn(instance_ptr: *const Instance<'static>, trampoline_ptr: *const u8) -> *const u8 { rustc_middle::ty::tls::with(|tcx| { // lift is used to ensure the correct lifetime for instance. let instance = tcx.lift(unsafe { *instance_ptr }).unwrap(); @@ -227,6 +228,17 @@ fn jit_fn(instance_ptr: *const Instance<'static>) -> *const u8 { let name = tcx.symbol_name(instance).name; let sig = crate::abi::get_function_sig(tcx, jit_module.isa().triple(), instance); let func_id = jit_module.declare_function(name, Linkage::Export, &sig).unwrap(); + + let current_ptr = jit_module.read_got_entry(func_id); + + // If the function's GOT entry has already been updated to point at something other + // than the shim trampoline, don't re-jit but just return the new pointer instead. + // This does not need synchronization as this code is executed only by a sole rustc + // thread. + if current_ptr != trampoline_ptr { + return current_ptr; + } + jit_module.prepare_for_function_redefine(func_id).unwrap(); let mut cx = crate::CodegenCx::new(tcx, backend_config, jit_module.isa(), false); @@ -321,7 +333,7 @@ fn codegen_shim<'tcx>(cx: &mut CodegenCx<'tcx>, module: &mut JITModule, inst: In Linkage::Import, &Signature { call_conv: module.target_config().default_call_conv, - params: vec![AbiParam::new(pointer_type)], + params: vec![AbiParam::new(pointer_type), AbiParam::new(pointer_type)], returns: vec![AbiParam::new(pointer_type)], }, ) @@ -334,6 +346,7 @@ fn codegen_shim<'tcx>(cx: &mut CodegenCx<'tcx>, module: &mut JITModule, inst: In let mut builder_ctx = FunctionBuilderContext::new(); let mut trampoline_builder = FunctionBuilder::new(trampoline, &mut builder_ctx); + let trampoline_fn = module.declare_func_in_func(func_id, trampoline_builder.func); let jit_fn = module.declare_func_in_func(jit_fn, trampoline_builder.func); let sig_ref = trampoline_builder.func.import_signature(sig); @@ -343,7 +356,8 @@ fn codegen_shim<'tcx>(cx: &mut CodegenCx<'tcx>, module: &mut JITModule, inst: In trampoline_builder.switch_to_block(entry_block); let instance_ptr = trampoline_builder.ins().iconst(pointer_type, instance_ptr as u64 as i64); - let jitted_fn = trampoline_builder.ins().call(jit_fn, &[instance_ptr]); + let trampoline_ptr = trampoline_builder.ins().func_addr(pointer_type, trampoline_fn); + let jitted_fn = trampoline_builder.ins().call(jit_fn, &[instance_ptr, trampoline_ptr]); let jitted_fn = trampoline_builder.func.dfg.inst_results(jitted_fn)[0]; let call_inst = trampoline_builder.ins().call_indirect(sig_ref, jitted_fn, &fn_args); let ret_vals = trampoline_builder.func.dfg.inst_results(call_inst).to_vec();