From f12ac2ba5c8e49e3c5be556a764ae7d66eecca31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 19 Oct 2025 12:36:23 +0200 Subject: [PATCH 1/2] first definition of `offload` intrinsic (dirty code) --- .../src/builder/gpu_offload.rs | 138 ++++++++++++------ compiler/rustc_codegen_llvm/src/intrinsic.rs | 71 +++++++++ compiler/rustc_codegen_llvm/src/lib.rs | 2 + .../rustc_hir_analysis/src/check/intrinsic.rs | 2 + compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/mod.rs | 4 + .../gpu_offload/offload_intrinsic.rs | 37 +++++ 7 files changed, 210 insertions(+), 45 deletions(-) create mode 100644 tests/codegen-llvm/gpu_offload/offload_intrinsic.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 3d55064ea1304..1fef04a2c9c14 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -4,17 +4,18 @@ use llvm::Linkage::*; use rustc_abi::Align; use rustc_codegen_ssa::back::write::CodegenContext; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use crate::builder::SBuilder; -use crate::common::AsCCharPtr; use crate::llvm::AttributePlace::Function; -use crate::llvm::{self, Linkage, Type, Value}; +use crate::llvm::{self, BasicBlock, Linkage, Type, Value}; use crate::{LlvmCodegenBackend, SimpleCx, attributes}; pub(crate) fn handle_gpu_code<'ll>( _cgcx: &CodegenContext, - cx: &'ll SimpleCx<'_>, + _cx: &'ll SimpleCx<'_>, ) { + /* // The offload memory transfer type for each kernel let mut memtransfer_types = vec![]; let mut region_ids = vec![]; @@ -29,6 +30,7 @@ pub(crate) fn handle_gpu_code<'ll>( } gen_call_handling(&cx, &memtransfer_types, ®ion_ids); + */ } // ; Function Attrs: nounwind @@ -76,7 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { at_one } -struct TgtOffloadEntry { +pub(crate) struct TgtOffloadEntry { // uint64_t Reserved; // uint16_t Version; // uint16_t Kind; @@ -253,11 +255,14 @@ pub(crate) fn add_global<'ll>( // This function returns a memtransfer value which encodes how arguments to this kernel shall be // mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be // concatenated into the list of region_ids. -fn gen_define_handling<'ll>( - cx: &'ll SimpleCx<'_>, +pub(crate) fn gen_define_handling<'ll, 'tcx>( + cx: &SimpleCx<'ll>, + tcx: TyCtxt<'tcx>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, - num: i64, + // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need + tt: Vec>, + symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or @@ -267,11 +272,21 @@ fn gen_define_handling<'ll>( .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) .count(); + // TODO(Sa4dUs): Add typetrees here + let ptr_sizes = types + .iter() + .zip(tt) + .filter_map(|(&x, ty)| match cx.type_kind(x) { + rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)), + _ => None, + }) + .collect::>(); + // We do not know their size anymore at this level, so hardcode a placeholder. // A follow-up pr will track these from the frontend, where we still have Rust types. // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes. // I decided that 1024 bytes is a great placeholder value for now. - add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]); + add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten @@ -279,25 +294,28 @@ fn gen_define_handling<'ll>( // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). let memtransfer_types = add_priv_unnamed_arr( &cx, - &format!(".offload_maptypes.{num}"), + &format!(".offload_maptypes.{symbol}"), &vec![1 + 2 + 32; num_ptr_types], ); + // Next: For each function, generate these three entries. A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value - let name = format!(".kernel_{num}.region_id"); + let name = format!(".{symbol}.region_id"); let initializer = cx.get_const_i8(0); let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage); - let c_entry_name = CString::new(format!("kernel_{num}")).unwrap(); + let c_entry_name = CString::new(symbol).unwrap(); let c_val = c_entry_name.as_bytes_with_nul(); - let offload_entry_name = format!(".offloading.entry_name.{num}"); + let offload_entry_name = format!(".offloading.entry_name.{symbol}"); let initializer = crate::common::bytes_in_context(cx.llcx, c_val); let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); llvm::set_alignment(llglobal, Align::ONE); llvm::set_section(llglobal, c".llvm.rodata.offloading"); - let name = format!(".offloading.entry.kernel_{num}"); + + // Not actively used yet, for calling real kernels + let name = format!(".offloading.entry.{symbol}"); // See the __tgt_offload_entry documentation above. let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); @@ -314,7 +332,57 @@ fn gen_define_handling<'ll>( (memtransfer_types, region_id) } -pub(crate) fn declare_offload_fn<'ll>( +// TODO(Sa4dUs): move this to a proper place +fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { + match ty.kind() { + /* + rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), + */ + ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), + /* + rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), + */ + _ => { + tcx + // TODO(Sa4dUs): Maybe `.as_query_input()`? + .layout_of(PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: ty, + }) + .unwrap() + .size + .bytes() + } + } +} + +fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, ty: &'ll llvm::Type, @@ -349,10 +417,13 @@ pub(crate) fn declare_offload_fn<'ll>( // 4. set insert point after kernel call. // 5. generate all the GEPS and stores, to be used in 6) // 6. generate __tgt_target_data_end calls to move data from the GPU -fn gen_call_handling<'ll>( - cx: &'ll SimpleCx<'_>, +pub(crate) fn gen_call_handling<'ll>( + cx: &SimpleCx<'ll>, + bb: &BasicBlock, + kernels: &[&'ll llvm::Value], memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], + llfn: &'ll Value, ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } @@ -365,27 +436,14 @@ fn gen_call_handling<'ll>( let tgt_kernel_decl = KernelArgsTy::new_decl(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); - let main_fn = cx.get_function("main"); - let Some(main_fn) = main_fn else { return }; - let kernel_name = "kernel_1"; - let call = unsafe { - llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) - }; - let Some(kernel_call) = call else { - return; - }; - let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) }; - let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() }; - let mut builder = SBuilder::build(cx, kernel_call_bb); - - let types = cx.func_params_types(cx.get_type_of_global(called)); + let mut builder = SBuilder::build(cx, bb); + + let types = cx.func_params_types(cx.get_type_of_global(kernels[0])); let num_args = types.len() as u64; // Step 0) // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } // %6 = alloca %struct.__tgt_bin_desc, align 8 - unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) }; - let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc"); let ty = cx.type_array(cx.type_ptr(), num_args); @@ -401,15 +459,14 @@ fn gen_call_handling<'ll>( let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args"); // Step 1) - unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); // Now we allocate once per function param, a copy to be passed to one of our maps. let mut vals = vec![]; let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); - for index in 0..types.len() { - let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() }; + for index in 0..num_args { + let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) }; let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); vals.push(v); geps.push(gep); @@ -501,13 +558,8 @@ fn gen_call_handling<'ll>( region_ids[0], a5, ]; - let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); + builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) - unsafe { - let next = llvm::LLVMGetNextInstruction(offload_success).unwrap(); - llvm::LLVMRustPositionAfter(builder.llbuilder, next); - llvm::LLVMInstructionEraseFromParent(next); - } // Step 4) let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); @@ -516,8 +568,4 @@ fn gen_call_handling<'ll>( builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); drop(builder); - // FIXME(offload) The issue is that we right now add a call to the gpu version of the function, - // and then delete the call to the CPU version. In the future, we should use an intrinsic which - // directly resolves to a call to the GPU version. - unsafe { llvm::LLVMDeleteFunction(called) }; } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index c12383f19312d..01a8cee048dc3 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -23,6 +23,7 @@ use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; +use crate::builder::gpu_offload::TgtOffloadEntry; use crate::context::CodegenCx; use crate::errors::AutoDiffWithoutEnable; use crate::llvm::{self, Metadata, Type, Value}; @@ -195,6 +196,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { codegen_autodiff(self, tcx, instance, args, result); return Ok(()); } + sym::offload => { + codegen_offload(self, tcx, instance, args, result); + return Ok(()); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( @@ -1227,6 +1232,72 @@ fn codegen_autodiff<'ll, 'tcx>( ); } +fn codegen_offload<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + _args: &[OperandRef<'tcx, &'ll Value>], + _result: PlaceRef<'tcx, &'ll Value>, +) { + let cx = bx.cx; + let fn_args = instance.args; + + let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, params) => (def_id, params), + _ => bug!("invalid offload intrinsic arg"), + }; + + let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific offload instance", + target_id, + target_args + ), + Err(_) => { + // An error has already been emitted + return; + } + }; + + // TODO(Sa4dUs): Will need typetrees + let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); + let Some(kernel) = cx.get_function(&target_symbol) else { + bug!("could not find target function") + }; + + let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); + + // Build TypeTree (or something similar) + let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); + let inputs = sig.inputs(); + + // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory + let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( + cx, + tcx, + kernel, + offload_entry_ty, + inputs.to_vec(), + &target_symbol, + ); + + let kernels = &[kernel]; + + let llfn = bx.llfn(); + + // TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix + let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) }; + crate::builder::gpu_offload::gen_call_handling( + cx, + bb, + kernels, + &[memtransfer_type], + &[region_id], + llfn, + ); +} + fn get_args_from_tuple<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, tuple_op: OperandRef<'tcx, &'ll Value>, diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 982d5cd3ac418..84f7440c29593 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -4,6 +4,8 @@ //! //! This API is completely unstable and subject to change. +// TODO(Sa4dUs): remove this once we have a great version, just to ignore unused LLVM wrappers +#![allow(unused)] // tidy-alphabetical-start #![allow(internal_features)] #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")] diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index a6659912e3fb9..57714771e9a9e 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::minnumf128 | sym::mul_with_overflow | sym::needs_drop + | sym::offload | sym::powf16 | sym::powf32 | sym::powf64 @@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type( let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity(); (0, 0, vec![type_id, type_id], tcx.types.bool) } + sym::offload => (2, 0, vec![param(0)], param(1)), sym::offset => (2, 0, vec![param(0), param(1)], param(0)), sym::arith_offset => ( 1, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index ef72c478951b8..53e4850d0f4f9 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1558,6 +1558,7 @@ symbols! { object_safe_for_dispatch, of, off, + offload, offset, offset_of, offset_of_enum, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index c397e762d5589..585c4abbc673c 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3262,6 +3262,10 @@ pub const fn copysignf128(x: f128, y: f128) -> f128; #[rustc_intrinsic] pub const fn autodiff(f: F, df: G, args: T) -> R; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn offload(f: F) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] diff --git a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs new file mode 100644 index 0000000000000..739186abc4f45 --- /dev/null +++ b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs @@ -0,0 +1,37 @@ +//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=0 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the +// kernel_1. Better documentation to what each global or variable means is available in the gpu +// offlaod code, or the LLVM offload documentation. This code does not launch any GPU kernels yet, +// and will be rewritten once a proper offload frontend has landed. +// +// We currently only handle memory transfer for specific calls to functions named `kernel_{num}`, +// when inside of a function called main. This, too, is a temporary workaround for not having a +// frontend. + +// CHECK: ; +#![feature(core_intrinsics)] +#![no_main] + +#[unsafe(no_mangle)] +fn main() { + let mut x = [3.0; 256]; + kernel(&mut x); + core::hint::black_box(&x); +} + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn kernel(x: &mut [f32; 256]) { + core::intrinsics::offload(_kernel) +} + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn _kernel(x: &mut [f32; 256]) { + for i in 0..256 { + x[i] = 21.0; + } +} From 23722aa408bd7fd7aef12de8214a390432b0b93a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 21 Oct 2025 09:49:12 +0200 Subject: [PATCH 2/2] Add basic offload metadata --- .../src/builder/gpu_offload.rs | 65 ++--------------- compiler/rustc_codegen_llvm/src/intrinsic.rs | 12 ++-- compiler/rustc_middle/src/ty/mod.rs | 1 + compiler/rustc_middle/src/ty/offload_meta.rs | 70 +++++++++++++++++++ 4 files changed, 84 insertions(+), 64 deletions(-) create mode 100644 compiler/rustc_middle/src/ty/offload_meta.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 1fef04a2c9c14..a02d1af11e233 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -4,6 +4,7 @@ use llvm::Linkage::*; use rustc_abi::Align; use rustc_codegen_ssa::back::write::CodegenContext; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_middle::ty::offload_meta::OffloadMetadata; use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use crate::builder::SBuilder; @@ -260,8 +261,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( tcx: TyCtxt<'tcx>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, - // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need - tt: Vec>, + metadata: Vec, symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); @@ -272,12 +272,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) .count(); - // TODO(Sa4dUs): Add typetrees here let ptr_sizes = types .iter() - .zip(tt) - .filter_map(|(&x, ty)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)), + .zip(metadata) + .filter_map(|(&x, meta)| match cx.type_kind(x) { + rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size), _ => None, }) .collect::>(); @@ -332,56 +331,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( (memtransfer_types, region_id) } -// TODO(Sa4dUs): move this to a proper place -fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { - match ty.kind() { - /* - rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), - */ - ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), - /* - rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), - */ - _ => { - tcx - // TODO(Sa4dUs): Maybe `.as_query_input()`? - .layout_of(PseudoCanonicalInput { - typing_env: TypingEnv::fully_monomorphized(), - value: ty, - }) - .unwrap() - .size - .bytes() - } - } -} - fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, @@ -420,7 +369,7 @@ fn declare_offload_fn<'ll>( pub(crate) fn gen_call_handling<'ll>( cx: &SimpleCx<'ll>, bb: &BasicBlock, - kernels: &[&'ll llvm::Value], + kernel: &'ll llvm::Value, memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], llfn: &'ll Value, @@ -438,7 +387,7 @@ pub(crate) fn gen_call_handling<'ll>( let mut builder = SBuilder::build(cx, bb); - let types = cx.func_params_types(cx.get_type_of_global(kernels[0])); + let types = cx.func_params_types(cx.get_type_of_global(kernel)); let num_args = types.len() as u64; // Step 0) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 01a8cee048dc3..0b9ba5fd1822f 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE; use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; +use rustc_middle::ty::offload_meta::OffloadMetadata; use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; @@ -1260,7 +1261,6 @@ fn codegen_offload<'ll, 'tcx>( } }; - // TODO(Sa4dUs): Will need typetrees let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); let Some(kernel) = cx.get_function(&target_symbol) else { bug!("could not find target function") @@ -1272,26 +1272,26 @@ fn codegen_offload<'ll, 'tcx>( let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); let inputs = sig.inputs(); + let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::>(); + // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( cx, tcx, kernel, offload_entry_ty, - inputs.to_vec(), + metadata, &target_symbol, ); - let kernels = &[kernel]; - let llfn = bx.llfn(); - // TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix + // TODO(Sa4dUs): this is just to a void lifetime's issues let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) }; crate::builder::gpu_offload::gen_call_handling( cx, bb, - kernels, + kernel, &[memtransfer_type], &[region_id], llfn, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index c3e1defef809d..683d089af1c3f 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -131,6 +131,7 @@ pub mod fast_reject; pub mod inhabitedness; pub mod layout; pub mod normalize_erasing_regions; +pub mod offload_meta; pub mod pattern; pub mod print; pub mod relate; diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs new file mode 100644 index 0000000000000..e7159888a643d --- /dev/null +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -0,0 +1,70 @@ +use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; + +// TODO(Sa4dUs): it doesn't feel correct for me to place this on `rustc_ast::expand`, will look for a proper location +pub struct OffloadMetadata { + pub payload_size: u64, + pub mode: TransferKind, +} + +pub enum TransferKind { + FromGpu = 1, + ToGpu = 2, + Both = 3, +} + +impl OffloadMetadata { + pub fn new(payload_size: u64, mode: TransferKind) -> Self { + OffloadMetadata { payload_size, mode } + } + + pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self { + OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both } + } +} + +// TODO(Sa4dUs): WIP, rn we just have a naive logic for references +fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { + match ty.kind() { + /* + rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), + */ + ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), + /* + rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), + */ + _ => tcx + .layout_of(PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: ty, + }) + .unwrap() + .size + .bytes(), + } +}