diff --git a/Cargo.lock b/Cargo.lock index 9dce64ce66ab6..679e2866c7d15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3574,6 +3574,7 @@ dependencies = [ "gimli 0.31.1", "itertools", "libc", + "libloading 0.9.0", "measureme", "object 0.37.3", "rustc-demangle", diff --git a/compiler/rustc_codegen_llvm/Cargo.toml b/compiler/rustc_codegen_llvm/Cargo.toml index 67bd1e59bb0c2..5eb65c01b4b8d 100644 --- a/compiler/rustc_codegen_llvm/Cargo.toml +++ b/compiler/rustc_codegen_llvm/Cargo.toml @@ -14,6 +14,7 @@ bitflags = "2.4.1" gimli = "0.31" itertools = "0.12" libc = "0.2" +libloading = "0.9.0" measureme = "12.0.1" object = { version = "0.37.0", default-features = false, features = ["std", "read"] } rustc-demangle = "0.1.21" diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index b820b992105fd..92614558588b8 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -528,31 +528,34 @@ fn thin_lto( } } -fn enable_autodiff_settings(ad: &[config::AutoDiff]) { +fn enable_autodiff_settings(cgcx: &CodegenContext, ad: &[config::AutoDiff]) { + use std::sync::Mutex; + let enzyme: &'static Mutex = llvm::EnzymeWrapper::init(cgcx); + for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. 
match val { config::AutoDiff::PrintPerf => { - llvm::set_print_perf(true); + enzyme.lock().unwrap().set_print_perf(true); } config::AutoDiff::PrintAA => { - llvm::set_print_activity(true); + enzyme.lock().unwrap().set_print_activity(true); } config::AutoDiff::PrintTA => { - llvm::set_print_type(true); + enzyme.lock().unwrap().set_print_type(true); } config::AutoDiff::PrintTAFn(fun) => { - llvm::set_print_type(true); // Enable general type printing - llvm::set_print_type_fun(&fun); // Set specific function to analyze + enzyme.lock().unwrap().set_print_type(true); // Enable general type printing + enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze } config::AutoDiff::Inline => { - llvm::set_inline(true); + enzyme.lock().unwrap().set_inline(true); } config::AutoDiff::LooseTypes => { - llvm::set_loose_types(true); + enzyme.lock().unwrap().set_loose_types(true); } config::AutoDiff::PrintSteps => { - llvm::set_print(true); + enzyme.lock().unwrap().set_print(true); } // We handle this in the PassWrapper.cpp config::AutoDiff::PrintPasses => {} @@ -571,9 +574,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { } } // This helps with handling enums for now. - llvm::set_strict_aliasing(false); + enzyme.lock().unwrap().set_strict_aliasing(false); // FIXME(ZuseZ4): Test this, since it was added a long time ago. 
- llvm::set_rust_rules(true); + enzyme.lock().unwrap().set_rust_rules(true); } pub(crate) fn run_pass_manager( @@ -609,7 +612,7 @@ pub(crate) fn run_pass_manager( }; if enable_ad { - enable_autodiff_settings(&config.autodiff); + enable_autodiff_settings(&cgcx, &config.autodiff); } unsafe { diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index fde7dd6ef7a85..3d5da6ac70747 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -2,7 +2,7 @@ use std::ffi::{CStr, CString}; use std::io::{self, Write}; use std::path::{Path, PathBuf}; use std::ptr::null_mut; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::{fs, slice, str}; use libc::{c_char, c_int, c_void, size_t}; @@ -726,6 +726,13 @@ pub(crate) unsafe fn llvm_optimize( let llvm_plugins = config.llvm_plugins.join(","); + let enzyme_fn = if consider_ad { + let wrapper: &'static Mutex = llvm::EnzymeWrapper::init(cgcx); + wrapper.lock().unwrap().registerEnzymeAndPassPipeline + } else { + std::ptr::null() + }; + let result = unsafe { llvm::LLVMRustOptimize( module.module_llvm.llmod(), @@ -745,7 +752,7 @@ pub(crate) unsafe fn llvm_optimize( vectorize_loop, config.no_builtins, config.emit_lifetime_markers, - run_enzyme, + enzyme_fn, print_before_enzyme, print_after_enzyme, print_passes, diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index e63043b21227f..f43bb4747bb24 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -1,6 +1,11 @@ #![expect(dead_code)] use libc::{c_char, c_uint}; +// I am going to delete this declaration. +// I have only added it to avoid conflicts with the main branch and to let CI run. +// I will make libloading optional later, before merging this PR. 
+#[cfg(not(feature = "llvm_enzyme"))] +use libloading as _; use super::MetadataKindId; use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value}; @@ -91,102 +96,355 @@ pub(crate) use self::Enzyme_AD::*; #[cfg(feature = "llvm_enzyme")] pub(crate) mod Enzyme_AD { - use std::ffi::{CString, c_char}; + use std::ffi::{c_char, c_void}; + use std::sync::{Mutex, OnceLock}; - use libc::c_void; + use rustc_codegen_ssa::back::write::CodegenContext; + use rustc_codegen_ssa::traits::WriteBackendMethods; + use rustc_middle::bug; + use rustc_session::config::{Sysroot, host_tuple}; + use rustc_session::filesearch; use super::{CConcreteType, CTypeTreeRef, Context}; - - unsafe extern "C" { - pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); - pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); + use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor}; + + type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8); + type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char); + + type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef; + type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef; + type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef; + type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef); + type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool; + type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64); + type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef); + type EnzymeTypeTreeShiftIndiciesEqFn = + unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64); + type EnzymeTypeTreeInsertEqFn = + unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context); + type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char; + type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" 
fn(*const c_char); + + #[allow(non_snake_case)] + pub(crate) struct EnzymeWrapper { + EnzymeNewTypeTree: EnzymeNewTypeTreeFn, + EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn, + EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn, + EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn, + EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn, + EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn, + EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn, + EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, + EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, + EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + + EnzymePrintPerf: *mut c_void, + EnzymePrintActivity: *mut c_void, + EnzymePrintType: *mut c_void, + EnzymeFunctionToAnalyze: *mut c_void, + EnzymePrint: *mut c_void, + EnzymeStrictAliasing: *mut c_void, + EnzymeInline: *mut c_void, + EnzymeMaxTypeDepth: *mut c_void, + RustTypeRules: *mut c_void, + looseTypeAnalysis: *mut c_void, + + EnzymeSetCLBool: EnzymeSetCLBoolFn, + EnzymeSetCLString: EnzymeSetCLStringFn, + pub registerEnzymeAndPassPipeline: *const c_void, + lib: libloading::Library, + } + + unsafe impl Sync for EnzymeWrapper {} + unsafe impl Send for EnzymeWrapper {} + + fn load_ptr_by_symbol_mut_void( + lib: &libloading::Library, + bytes: &[u8], + ) -> Result<*mut c_void, Box> { + unsafe { + let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?; + // libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some + let s = s.try_as_raw_ptr().unwrap(); + Ok(s) + } } - // TypeTree functions - unsafe extern "C" { - pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef; - pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; - pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; - pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); - pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; - pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: 
i64); - pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); - pub(crate) fn EnzymeTypeTreeShiftIndiciesEq( - arg1: CTypeTreeRef, + // e.g. + // load_ptrs_by_symbols_mut_void!(lib, ABC, XYZ); + // => + // let ABC = load_ptr_by_symbol_mut_void(&lib, b"ABC")?; + // let XYZ = load_ptr_by_symbol_mut_void(&lib, b"XYZ")?; + macro_rules! load_ptrs_by_symbols_mut_void { + ($lib:expr, $($name:ident),* $(,)?) => { + $( + #[allow(non_snake_case)] + let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?; + )* + }; + } + + // e.g. + // load_ptrs_by_symbols_fn!(lib, ABC: ABCFn, XYZ: XYZFn); + // => + // let ABC: ABCFn = *unsafe { lib.get::<ABCFn>(b"ABC")? }; + // let XYZ: XYZFn = *unsafe { lib.get::<XYZFn>(b"XYZ")? }; + macro_rules! load_ptrs_by_symbols_fn { + ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => { + $( + #[allow(non_snake_case)] + let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? }; + )* + }; + } + + static ENZYME_INSTANCE: OnceLock> = OnceLock::new(); + + impl EnzymeWrapper { + pub(crate) fn init<'a, B: WriteBackendMethods>( + cgcx: &'a CodegenContext, + ) -> &'static Mutex { + ENZYME_INSTANCE.get_or_init(|| { + Self::call_dynamic(cgcx) + .unwrap_or_else(|e| bug!("failed to load Enzyme: {e}")) + .into() + }) + } + + pub(crate) fn get_instance() -> &'static Mutex { + ENZYME_INSTANCE.get().expect("EnzymeWrapper not initialized") + } + + pub(crate) fn new_type_tree(&self) -> CTypeTreeRef { + unsafe { (self.EnzymeNewTypeTree)() } + } + + pub(crate) fn new_type_tree_ct( + &self, + t: CConcreteType, + ctx: &Context, + ) -> *mut EnzymeTypeTree { + unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) } + } + + pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef { + unsafe { (self.EnzymeNewTypeTreeTR)(tree) } + } + + pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) { + unsafe { (self.EnzymeFreeTypeTree)(tree) } + } + + pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool { + 
unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) } + } + + pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) { + unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) } + } + + pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) { + unsafe { (self.EnzymeTypeTreeData0Eq)(tree) } + } + + pub(crate) fn shift_indicies_eq( + &self, + tree: CTypeTreeRef, data_layout: *const c_char, offset: i64, max_size: i64, add_offset: u64, - ); - pub(crate) fn EnzymeTypeTreeInsertEq( - CTT: CTypeTreeRef, + ) { + unsafe { + (self.EnzymeTypeTreeShiftIndiciesEq)( + tree, + data_layout, + offset, + max_size, + add_offset, + ) + } + } + + pub(crate) fn tree_insert_eq( + &self, + tree: CTypeTreeRef, indices: *const i64, len: usize, ct: CConcreteType, ctx: &Context, - ); - pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; - pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - } + ) { + unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) } + } - unsafe extern "C" { - static mut EnzymePrintPerf: c_void; - static mut EnzymePrintActivity: c_void; - static mut EnzymePrintType: c_void; - static mut EnzymeFunctionToAnalyze: c_void; - static mut EnzymePrint: c_void; - static mut EnzymeStrictAliasing: c_void; - static mut looseTypeAnalysis: c_void; - static mut EnzymeInline: c_void; - static mut RustTypeRules: c_void; - } - pub(crate) fn set_print_perf(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); + pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char { + unsafe { (self.EnzymeTypeTreeToString)(tree) } } - } - pub(crate) fn set_print_activity(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); + + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { + unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } - } - pub(crate) fn set_print_type(print: bool) { - unsafe { - 
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); + + pub(crate) fn get_max_type_depth(&self) -> usize { + unsafe { std::ptr::read::(self.EnzymeMaxTypeDepth as *const u32) as usize } } - } - pub(crate) fn set_print_type_fun(fun_name: &str) { - let c_fun_name = CString::new(fun_name).unwrap(); - unsafe { - EnzymeSetCLString( - std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), - c_fun_name.as_ptr() as *const c_char, - ); + + pub(crate) fn set_print_perf(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8); + } } - } - pub(crate) fn set_print(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); + + pub(crate) fn set_print_activity(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8); + } } - } - pub(crate) fn set_strict_aliasing(strict: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); + + pub(crate) fn set_print_type(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8); + } } - } - pub(crate) fn set_loose_types(loose: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + + pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) { + let c_fun_name = std::ffi::CString::new(fun_name) + .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}")); + unsafe { + (self.EnzymeSetCLString)( + self.EnzymeFunctionToAnalyze, + c_fun_name.as_ptr() as *const c_char, + ); + } } - } - pub(crate) fn set_inline(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + + pub(crate) fn set_print(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8); + } } - } - pub(crate) fn set_rust_rules(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8); + + pub(crate) fn 
set_strict_aliasing(&mut self, strict: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8); + } + } + + pub(crate) fn set_loose_types(&mut self, loose: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8); + } + } + + pub(crate) fn set_inline(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8); + } + } + + pub(crate) fn set_rust_rules(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8); + } + } + + #[allow(non_snake_case)] + fn call_dynamic<'a, B: WriteBackendMethods>( + cgcx: &'a CodegenContext, + ) -> Result> { + let enzyme_path = Self::get_enzyme_path(&cgcx.sysroot)?; + let lib = unsafe { libloading::Library::new(enzyme_path)? }; + + load_ptrs_by_symbols_fn!( + lib, + EnzymeNewTypeTree: EnzymeNewTypeTreeFn, + EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn, + EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn, + EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn, + EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn, + EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn, + EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn, + EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, + EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, + EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + EnzymeSetCLBool: EnzymeSetCLBoolFn, + EnzymeSetCLString: EnzymeSetCLStringFn, + ); + + load_ptrs_by_symbols_mut_void!( + lib, + registerEnzymeAndPassPipeline, + EnzymePrintPerf, + EnzymePrintActivity, + EnzymePrintType, + EnzymeFunctionToAnalyze, + EnzymePrint, + EnzymeStrictAliasing, + EnzymeInline, + EnzymeMaxTypeDepth, + RustTypeRules, + looseTypeAnalysis, + ); + + Ok(Self { + EnzymeNewTypeTree, + EnzymeNewTypeTreeCT, + EnzymeNewTypeTreeTR, + EnzymeFreeTypeTree, + EnzymeMergeTypeTree, + EnzymeTypeTreeOnlyEq, + EnzymeTypeTreeData0Eq, + EnzymeTypeTreeShiftIndiciesEq, + EnzymeTypeTreeInsertEq, + EnzymeTypeTreeToString, + 
EnzymeTypeTreeToStringFree, + EnzymePrintPerf, + EnzymePrintActivity, + EnzymePrintType, + EnzymeFunctionToAnalyze, + EnzymePrint, + EnzymeStrictAliasing, + EnzymeInline, + EnzymeMaxTypeDepth, + RustTypeRules, + looseTypeAnalysis, + EnzymeSetCLBool, + EnzymeSetCLString, + registerEnzymeAndPassPipeline, + lib, + }) + } + + fn get_enzyme_path(sysroot: &Sysroot) -> Result { + let llvm_version_major = unsafe { LLVMRustVersionMajor() }; + + let path_buf = sysroot + .all_paths() + .map(|sysroot_path| { + filesearch::make_target_lib_path(sysroot_path, host_tuple()) + .join("lib") + .with_file_name(format!("libEnzyme-{llvm_version_major}")) + .with_extension(std::env::consts::DLL_EXTENSION) + }) + .find(|f| f.exists()) + .ok_or_else(|| { + let candidates = sysroot + .all_paths() + .map(|p| p.join("lib").display().to_string()) + .collect::>() + .join("\n* "); + format!( + "failed to find a `libEnzyme-{llvm_version_major}` folder \ + in the sysroot candidates:\n* {candidates}" + ) + })?; + + Ok(path_buf + .to_str() + .ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))? 
+ .to_string()) } } } @@ -198,111 +456,150 @@ pub(crate) use self::Fallback_AD::*; pub(crate) mod Fallback_AD { #![allow(unused_variables)] + use std::ffi::c_void; + use std::sync::Mutex; + use libc::c_char; + use rustc_codegen_ssa::back::write::CodegenContext; + use rustc_codegen_ssa::traits::WriteBackendMethods; - use super::{CConcreteType, CTypeTreeRef, Context}; + use super::{CConcreteType, CTypeTreeRef, Context, EnzymeTypeTree}; - // TypeTree function fallbacks - pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef { - unimplemented!() + pub(crate) struct EnzymeWrapper { + pub registerEnzymeAndPassPipeline: *const c_void, } - pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { - unimplemented!() - } + impl EnzymeWrapper { + pub(crate) fn init<'a, B: WriteBackendMethods>( + _cgcx: &'a CodegenContext, + ) -> &'static Mutex { + unimplemented!("Enzyme not available: build with llvm_enzyme feature") + } - pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { - unimplemented!() - } + pub(crate) fn get_instance() -> &'static Mutex { + unimplemented!("Enzyme not available: build with llvm_enzyme feature") + } - pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { - unimplemented!() - } + pub(crate) fn new_type_tree(&self) -> CTypeTreeRef { + unimplemented!() + } - pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { - unimplemented!() - } + pub(crate) fn new_type_tree_ct( + &self, + t: CConcreteType, + ctx: &Context, + ) -> *mut EnzymeTypeTree { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { - unimplemented!() - } + pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { - unimplemented!() - } + pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) { + unimplemented!() + } - pub(crate) 
unsafe fn EnzymeTypeTreeShiftIndiciesEq( - arg1: CTypeTreeRef, - data_layout: *const c_char, - offset: i64, - max_size: i64, - add_offset: u64, - ) { - unimplemented!() - } + pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeInsertEq( - CTT: CTypeTreeRef, - indices: *const i64, - len: usize, - ct: CConcreteType, - ctx: &Context, - ) { - unimplemented!() - } + pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { - unimplemented!() - } + pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { - unimplemented!() - } + pub(crate) fn shift_indicies_eq( + &self, + tree: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } - pub(crate) fn set_inline(val: bool) { - unimplemented!() - } - pub(crate) fn set_print_perf(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_activity(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_type(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_type_fun(fun_name: &str) { - unimplemented!() - } - pub(crate) fn set_print(print: bool) { - unimplemented!() - } - pub(crate) fn set_strict_aliasing(strict: bool) { - unimplemented!() - } - pub(crate) fn set_loose_types(loose: bool) { - unimplemented!() - } - pub(crate) fn set_rust_rules(val: bool) { - unimplemented!() + pub(crate) fn tree_insert_eq( + &self, + tree: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ) { + unimplemented!() + } + + pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char { + unimplemented!() + } + + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { + unimplemented!() + } 
+ + pub(crate) fn get_max_type_depth(&self) -> usize { + unimplemented!() + } + + pub(crate) fn set_inline(&mut self, val: bool) { + unimplemented!() + } + + pub(crate) fn set_print_perf(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_activity(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_type(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) { + unimplemented!() + } + + pub(crate) fn set_print(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_strict_aliasing(&mut self, strict: bool) { + unimplemented!() + } + + pub(crate) fn set_loose_types(&mut self, loose: bool) { + unimplemented!() + } + + pub(crate) fn set_rust_rules(&mut self, val: bool) { + unimplemented!() + } } } impl TypeTree { pub(crate) fn new() -> TypeTree { - let inner = unsafe { EnzymeNewTypeTree() }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.lock().unwrap().new_type_tree(); TypeTree { inner } } pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { - let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.lock().unwrap().new_type_tree_ct(t, ctx); TypeTree { inner } } pub(crate) fn merge(self, other: Self) -> Self { - unsafe { - EnzymeMergeTypeTree(self.inner, other.inner); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.lock().unwrap().merge_type_tree(self.inner, other.inner); drop(other); self } @@ -316,37 +613,42 @@ impl TypeTree { add_offset: usize, ) -> Self { let layout = std::ffi::CString::new(layout).unwrap(); - - unsafe { - EnzymeTypeTreeShiftIndiciesEq( - self.inner, - layout.as_ptr(), - offset as i64, - max_size as i64, - add_offset as u64, - ); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.lock().unwrap().shift_indicies_eq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + 
); self } pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) { - unsafe { - EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.lock().unwrap().tree_insert_eq( + self.inner, + indices.as_ptr(), + indices.len(), + ct, + ctx, + ); } } impl Clone for TypeTree { fn clone(&self) -> Self { - let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.lock().unwrap().new_type_tree_tr(self.inner); TypeTree { inner } } } impl std::fmt::Display for TypeTree { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let wrapper = EnzymeWrapper::get_instance(); + let ptr = wrapper.lock().unwrap().tree_to_string(self.inner); let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) }; match cstr.to_str() { Ok(x) => write!(f, "{}", x)?, @@ -354,9 +656,7 @@ impl std::fmt::Display for TypeTree { } // delete C string pointer - unsafe { - EnzymeTypeTreeToStringFree(ptr); - } + wrapper.lock().unwrap().tree_to_string_free(ptr); Ok(()) } @@ -370,6 +670,7 @@ impl std::fmt::Debug for TypeTree { impl Drop for TypeTree { fn drop(&mut self) { - unsafe { EnzymeFreeTypeTree(self.inner) } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.lock().unwrap().free_type_tree(self.inner) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index ca64d96c2a33c..be68d52330341 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2384,7 +2384,7 @@ unsafe extern "C" { LoopVectorize: bool, DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, - RunEnzyme: bool, + RunEnzyme: *const c_void, PrintBeforeEnzyme: bool, PrintAfterEnzyme: bool, PrintPasses: bool, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs 
b/compiler/rustc_codegen_llvm/src/typetree.rs index 7e2635037008e..422248991f4b6 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -2,6 +2,7 @@ use rustc_ast::expand::typetree::FncTree; #[cfg(feature = "llvm_enzyme")] use { crate::attributes, + crate::llvm::EnzymeWrapper, rustc_ast::expand::typetree::TypeTree as RustTypeTree, std::ffi::{CString, c_char, c_uint}, }; @@ -74,10 +75,12 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + let enzyme_wrapper = EnzymeWrapper::get_instance().lock().unwrap(); + for (i, input) in inputs.iter().enumerate() { unsafe { let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); - let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); let attr = llvm::LLVMCreateStringAttribute( @@ -89,13 +92,13 @@ pub(crate) fn add_tt<'ll>( ); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); - llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } unsafe { let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); - let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); let ret_attr = llvm::LLVMCreateStringAttribute( @@ -107,7 +110,7 @@ pub(crate) fn add_tt<'ll>( ); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); - llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index fc1edec8de843..c0abc9f2fbb90 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ 
b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -29,6 +29,7 @@ use rustc_middle::ty::TyCtxt; use rustc_session::Session; use rustc_session::config::{ self, CrateType, Lto, OutFileName, OutputFilenames, OutputType, Passes, SwitchWithOptPath, + Sysroot, }; use rustc_span::source_map::SourceMap; use rustc_span::{FileName, InnerSpan, Span, SpanData, sym}; @@ -346,6 +347,7 @@ pub struct CodegenContext { pub split_debuginfo: rustc_target::spec::SplitDebuginfo, pub split_dwarf_kind: rustc_session::config::SplitDwarfKind, pub pointer_size: Size, + pub sysroot: Sysroot, /// Emitter to use for diagnostics produced during codegen. pub diag_emitter: SharedEmitter, @@ -1316,6 +1318,7 @@ fn start_executing_work( parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend, pointer_size: tcx.data_layout.pointer_size(), invocation_temp: sess.invocation_temp.clone(), + sysroot: sess.opts.sysroot.clone(), }; // This is the "main loop" of parallel work happening for parallel codegen. diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index 143cc94790890..7fc7b5f8b6f7f 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -540,17 +540,8 @@ struct LLVMRustSanitizerOptions { bool SanitizeKernelAddressRecover; }; -// This symbol won't be available or used when Enzyme is not enabled. -// Always set AugmentPassBuilder to true, since it registers optimizations which -// will improve the performance for Enzyme. 
-#ifdef ENZYME -extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, - /* augmentPassBuilder */ bool); - -extern "C" { -extern llvm::cl::opt EnzymeFunctionToAnalyze; -} -#endif +extern "C" typedef void (*registerEnzymeAndPassPipelineFn)( + llvm::PassBuilder &PB, bool augment); extern "C" LLVMRustResult LLVMRustOptimize( LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef, @@ -559,8 +550,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO, bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme, - bool PrintAfterEnzyme, bool PrintPasses, + bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr, + bool PrintBeforeEnzyme, bool PrintAfterEnzyme, bool PrintPasses, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, const char *InstrProfileOutput, const char *PGOSampleUsePath, @@ -898,7 +889,7 @@ extern "C" LLVMRustResult LLVMRustOptimize( // now load "-enzyme" pass: #ifdef ENZYME - if (RunEnzyme) { + if (EnzymePtr) { if (PrintBeforeEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModBefore`. 
@@ -906,22 +897,13 @@ extern "C" LLVMRustResult LLVMRustOptimize( MPM.addPass(PrintModulePass(outs(), Banner, true, false)); } - registerEnzymeAndPassPipeline(PB, false); + EnzymePtr(PB, false); if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { std::string ErrMsg = toString(std::move(Err)); LLVMRustSetLastError(ErrMsg.c_str()); return LLVMRustResult::Failure; } - // Check if PrintTAFn was used and add type analysis pass if needed - if (!EnzymeFunctionToAnalyze.empty()) { - if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) { - std::string ErrMsg = toString(std::move(Err)); - LLVMRustSetLastError(ErrMsg.c_str()); - return LLVMRustResult::Failure; - } - } - if (PrintAfterEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModAfter`. std::string Banner = "Module after EnzymeNewPM"; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8823c83922822..c27cd013fe32e 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1700,18 +1700,6 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) { GV.setSanitizerMetadata(MD); } -#ifdef ENZYME -extern "C" { -extern llvm::cl::opt EnzymeMaxTypeDepth; -} - -extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; } -#else -extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { - return 6; // Default fallback depth -} -#endif - // Statically assert that the fixed metadata kind IDs declared in // `metadata_kind.rs` match the ones actually used by LLVM. #define FIXED_MD_KIND(VARIANT, VALUE) \ diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 6857a40ada81b..7865d456c8dad 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1228,19 +1228,6 @@ pub fn rustc_cargo( // . 
cargo.rustflag("-Zon-broken-pipe=kill"); - // We want to link against registerEnzyme and in the future we want to use additional - // functionality from Enzyme core. For that we need to link against Enzyme. - if builder.config.llvm_enzyme { - let arch = builder.build.host_target; - let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib"); - cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path")); - - if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) { - let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config); - cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}")); - } - } - // Building with protected visibility reduces the number of dynamic relocations needed, giving // us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object // with direct references to protected symbols, so for now we only use protected symbols if diff --git a/typos.toml b/typos.toml index 758239ffe751c..b9d9c6c3522cf 100644 --- a/typos.toml +++ b/typos.toml @@ -50,6 +50,8 @@ unstalled = "unstalled" debug_aranges = "debug_aranges" DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME = "DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME" EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq" +EnzymeTypeTreeShiftIndiciesEqFn = "EnzymeTypeTreeShiftIndiciesEqFn" +shift_indicies_eq = "shift_indicies_eq" ERRNO_ACCES = "ERRNO_ACCES" ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS = "ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS" ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC"