87 changes: 42 additions & 45 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
@@ -4,17 +4,19 @@ use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::back::write::CodegenContext;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_middle::ty::offload_meta::OffloadMetadata;
use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};

use crate::builder::SBuilder;
use crate::common::AsCCharPtr;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
use crate::{LlvmCodegenBackend, SimpleCx, attributes};

pub(crate) fn handle_gpu_code<'ll>(
_cgcx: &CodegenContext<LlvmCodegenBackend>,
cx: &'ll SimpleCx<'_>,
_cx: &'ll SimpleCx<'_>,
) {
/*
// The offload memory transfer type for each kernel
let mut memtransfer_types = vec![];
let mut region_ids = vec![];
@@ -29,6 +31,7 @@ pub(crate) fn handle_gpu_code<'ll>(
}

gen_call_handling(&cx, &memtransfer_types, &region_ids);
*/
}

// ; Function Attrs: nounwind
@@ -76,7 +79,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
at_one
}

struct TgtOffloadEntry {
pub(crate) struct TgtOffloadEntry {
// uint64_t Reserved;
// uint16_t Version;
// uint16_t Kind;
@@ -253,11 +256,13 @@ pub(crate) fn add_global<'ll>(
// This function returns a memtransfer value which encodes how arguments to this kernel shall be
// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
// concatenated into the list of region_ids.
fn gen_define_handling<'ll>(
cx: &'ll SimpleCx<'_>,
pub(crate) fn gen_define_handling<'ll, 'tcx>(
cx: &SimpleCx<'ll>,
tcx: TyCtxt<'tcx>,
kernel: &'ll llvm::Value,
offload_entry_ty: &'ll llvm::Type,
num: i64,
metadata: Vec<OffloadMetadata>,
symbol: &str,
) -> (&'ll llvm::Value, &'ll llvm::Value) {
let types = cx.func_params_types(cx.get_type_of_global(kernel));
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
@@ -267,37 +272,49 @@ fn gen_define_handling<'ll>(
.filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
.count();

let ptr_sizes = types
.iter()
.zip(metadata)
.filter_map(|(&x, meta)| match cx.type_kind(x) {
rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size),
_ => None,
})
.collect::<Vec<u64>>();

// We do not know their size anymore at this level, so hardcode a placeholder.
// A follow-up PR will track these from the frontend, where we still have Rust types.
// Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes.
// I decided that 1024 bytes is a great placeholder value for now.
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// A non-mutable reference or pointer will be 1; an array that's not read but fully overwritten
// will be 2. For now, everything is 3, until we have our frontend set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
let memtransfer_types = add_priv_unnamed_arr(
&cx,
&format!(".offload_maptypes.{num}"),
&format!(".offload_maptypes.{symbol}"),
&vec![1 + 2 + 32; num_ptr_types],
);

// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the llvm_offload_entries value

let name = format!(".kernel_{num}.region_id");
let name = format!(".{symbol}.region_id");
let initializer = cx.get_const_i8(0);
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);

let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
let c_entry_name = CString::new(symbol).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
let offload_entry_name = format!(".offloading.entry_name.{num}");
let offload_entry_name = format!(".offloading.entry_name.{symbol}");

let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
llvm::set_alignment(llglobal, Align::ONE);
llvm::set_section(llglobal, c".llvm.rodata.offloading");
let name = format!(".offloading.entry.kernel_{num}");

// Not actively used yet; will be needed for calling real kernels.
let name = format!(".offloading.entry.{symbol}");

// See the __tgt_offload_entry documentation above.
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
@@ -314,7 +331,7 @@ fn gen_define_handling<'ll>(
(memtransfer_types, region_id)
}

pub(crate) fn declare_offload_fn<'ll>(
fn declare_offload_fn<'ll>(
cx: &'ll SimpleCx<'_>,
name: &str,
ty: &'ll llvm::Type,
@@ -349,10 +366,13 @@ pub(crate) fn declare_offload_fn<'ll>(
// 4. set insert point after kernel call.
// 5. generate all the GEPS and stores, to be used in 6)
// 6. generate __tgt_target_data_end calls to move data from the GPU
fn gen_call_handling<'ll>(
cx: &'ll SimpleCx<'_>,
pub(crate) fn gen_call_handling<'ll>(
cx: &SimpleCx<'ll>,
bb: &BasicBlock,
kernel: &'ll llvm::Value,
memtransfer_types: &[&'ll llvm::Value],
region_ids: &[&'ll llvm::Value],
llfn: &'ll Value,
) {
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
@@ -365,27 +385,14 @@ fn gen_call_handling<'ll>(
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

let main_fn = cx.get_function("main");
let Some(main_fn) = main_fn else { return };
let kernel_name = "kernel_1";
let call = unsafe {
llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
};
let Some(kernel_call) = call else {
return;
};
let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
let mut builder = SBuilder::build(cx, kernel_call_bb);

let types = cx.func_params_types(cx.get_type_of_global(called));
let mut builder = SBuilder::build(cx, bb);

let types = cx.func_params_types(cx.get_type_of_global(kernel));
let num_args = types.len() as u64;

// Step 0)
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
// %6 = alloca %struct.__tgt_bin_desc, align 8
unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };

let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

let ty = cx.type_array(cx.type_ptr(), num_args);
@@ -401,15 +408,14 @@ fn gen_call_handling<'ll>(
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

// Step 1)
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

// Now we allocate once per function param, a copy to be passed to one of our maps.
let mut vals = vec![];
let mut geps = vec![];
let i32_0 = cx.get_const_i32(0);
for index in 0..types.len() {
let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
for index in 0..num_args {
let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) };
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
vals.push(v);
geps.push(gep);
@@ -501,13 +507,8 @@ fn gen_call_handling<'ll>(
region_ids[0],
a5,
];
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
unsafe {
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
llvm::LLVMInstructionEraseFromParent(next);
}

// Step 4)
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
@@ -516,8 +517,4 @@ fn gen_call_handling<'ll>(
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

drop(builder);
// FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
// and then delete the call to the CPU version. In the future, we should use an intrinsic which
// directly resolves to a call to the GPU version.
unsafe { llvm::LLVMDeleteFunction(called) };
}
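
The maptype word built above (1 + 2 + 32 = 35 per pointer argument) packs the transfer flags named in the gen_define_handling comments. A minimal standalone Rust sketch of that encoding; the flag names are assumptions, only the numeric values come from the comments:

    const MAP_TO: u64 = 1; // copy host -> device before the launch
    const MAP_FROM: u64 = 2; // copy device -> host after the kernel returns
    const MAP_EXTRA_PTR: u64 = 32; // one extra input ptr per function

    // For now every pointer argument is mapped both ways: 1 + 2 + 32 = 35.
    fn maptypes(num_ptr_args: usize) -> Vec<u64> {
        vec![MAP_TO | MAP_FROM | MAP_EXTRA_PTR; num_ptr_args]
    }

The launch itself goes through __tgt_target_kernel, whose shape can be read off the IR comment in gen_call_handling: (ptr, i64, i32, i32, ptr, ptr) -> i32. As a hedged extern declaration, with parameter names assumed:

    use core::ffi::c_void;

    unsafe extern "C" {
        // call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256,
        //                               ptr @.kernel_1.region_id, ptr %kernel_args)
        fn __tgt_target_kernel(
            loc: *const c_void,       // source-location ident (@1)
            device_id: i64,           // -1 selects the default device
            num_teams: i32,           // 2097152 in the example call
            thread_limit: i32,        // 256 in the example call
            host_ptr: *const c_void,  // this kernel's .region_id global
            kernel_args: *mut c_void, // pointer to the kernel_args alloca
        ) -> i32;
    }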
71 changes: 71 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::{self as hir};
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::offload_meta::OffloadMetadata;
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, span_bug};
use rustc_span::{Span, Symbol, sym};
@@ -23,6 +24,7 @@ use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::builder::gpu_offload::TgtOffloadEntry;
use crate::context::CodegenCx;
use crate::errors::AutoDiffWithoutEnable;
use crate::llvm::{self, Metadata, Type, Value};
@@ -195,6 +197,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
codegen_autodiff(self, tcx, instance, args, result);
return Ok(());
}
sym::offload => {
codegen_offload(self, tcx, instance, args, result);
return Ok(());
}
sym::is_val_statically_known => {
if let OperandValue::Immediate(imm) = args[0].val {
self.call_intrinsic(
@@ -1227,6 +1233,71 @@ fn codegen_autodiff<'ll, 'tcx>(
);
}

fn codegen_offload<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
instance: ty::Instance<'tcx>,
_args: &[OperandRef<'tcx, &'ll Value>],
_result: PlaceRef<'tcx, &'ll Value>,
) {
let cx = bx.cx;
let fn_args = instance.args;

let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() {
ty::FnDef(def_id, params) => (def_id, params),
_ => bug!("invalid offload intrinsic arg"),
};

let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) {
Ok(Some(instance)) => instance,
Ok(None) => bug!(
"could not resolve ({:?}, {:?}) to a specific offload instance",
target_id,
target_args
),
Err(_) => {
// An error has already been emitted
return;
}
};

let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE);
let Some(kernel) = cx.get_function(&target_symbol) else {
bug!("could not find target function")
};

let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);

// Build TypeTree (or something similar)
let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
let inputs = sig.inputs();

let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::<Vec<_>>();

// TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory
let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling(
cx,
tcx,
kernel,
offload_entry_ty,
metadata,
&target_symbol,
);

let llfn = bx.llfn();

// TODO(Sa4dUs): this is just to avoid lifetime issues
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
crate::builder::gpu_offload::gen_call_handling(
cx,
bb,
kernel,
&[memtransfer_type],
&[region_id],
llfn,
);
}

fn get_args_from_tuple<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
tuple_op: OperandRef<'tcx, &'ll Value>,
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_llvm/src/lib.rs
@@ -4,6 +4,8 @@
//!
//! This API is completely unstable and subject to change.
// TODO(Sa4dUs): remove this once we have a working version; for now it just silences warnings for unused LLVM wrappers
#![allow(unused)]
// tidy-alphabetical-start
#![allow(internal_features)]
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
2 changes: 2 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi
| sym::minnumf128
| sym::mul_with_overflow
| sym::needs_drop
| sym::offload
| sym::powf16
| sym::powf32
| sym::powf64
@@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type(
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
(0, 0, vec![type_id, type_id], tcx.types.bool)
}
sym::offload => (2, 0, vec![param(0)], param(1)),
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::arith_offset => (
1,
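
The tuple (2, 0, vec![param(0)], param(1)) registers offload with two generic parameters, no const parameters, a single argument of the first parameter's type, and the second parameter as the return type. A sketch of the declaration this implies, assuming the usual core::intrinsics surface (the exact declaration is not part of this diff), plus the call shape codegen_offload expects, where F must resolve to a concrete fn item (ty::FnDef):

    #[rustc_intrinsic]
    pub unsafe fn offload<F, R>(f: F) -> R;

    // Hypothetical call site: the kernel is passed as a zero-sized fn item,
    // which codegen_offload resolves to an Instance and a symbol name.
    fn kernel(data: &mut [f32; 256]) { /* device code */ }
    // let () = unsafe { offload::<_, ()>(kernel) };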
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/mod.rs
@@ -131,6 +131,7 @@ pub mod fast_reject;
pub mod inhabitedness;
pub mod layout;
pub mod normalize_erasing_regions;
pub mod offload_meta;
pub mod pattern;
pub mod print;
pub mod relate;
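
The new offload_meta module itself is not shown in this diff; the only surface the other files rely on is OffloadMetadata::from_ty and its payload_size field, which feeds the .offload_sizes globals in gen_define_handling. A standalone Rust sketch of the size computation under those assumptions, using the &[f32; 256] example from the gpu_offload.rs comments:

    // Minimal stand-in for rustc_middle::ty::offload_meta::OffloadMetadata;
    // the real type is produced from a Ty<'tcx> by OffloadMetadata::from_ty.
    pub struct OffloadMetadata {
        /// Bytes to transfer for one pointer argument's pointee.
        pub payload_size: u64,
    }

    impl OffloadMetadata {
        /// Sketch: payload size for a reference to an array, e.g. &[f32; 256].
        pub fn for_array(elem_size: u64, len: u64) -> Self {
            OffloadMetadata { payload_size: elem_size * len }
        }
    }

    fn main() {
        let meta = OffloadMetadata::for_array(4, 256); // &[f32; 256]
        assert_eq!(meta.payload_size, 1024); // the old hardcoded placeholder
    }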