Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,7 @@ pub(crate) unsafe fn llvm_optimize(
if cgcx.target_is_like_gpu && config.offload.contains(&config::Offload::Enable) {
let cx =
SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size);
// For now we only support up to 10 kernels named kernel_0 ... kernel_9, a follow-up PR is
// introducing a proper offload intrinsic to solve this limitation.

for func in cx.get_functions() {
let offload_kernel = "offload-kernel";
if attributes::has_string_attr(func, offload_kernel) {
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_codegen_llvm/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ use rustc_middle::dep_graph;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrs, SanitizerFnAttrs};
use rustc_middle::mir::mono::Visibility;
use rustc_middle::ty::TyCtxt;
use rustc_session::config::DebugInfo;
use rustc_session::config::{DebugInfo, Offload};
use rustc_span::Symbol;
use rustc_target::spec::SanitizerSet;

use super::ModuleLlvm;
use crate::attributes;
use crate::builder::Builder;
use crate::builder::gpu_offload::OffloadGlobals;
use crate::context::CodegenCx;
use crate::llvm::{self, Value};

Expand Down Expand Up @@ -85,6 +86,13 @@ pub(crate) fn compile_codegen_unit(
let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str());
{
let mut cx = CodegenCx::new(tcx, cgu, &llvm_module);

if cx.sess().opts.unstable_opts.offload.contains(&Offload::Enable)
&& !cx.sess().target.is_like_gpu
{
cx.offload_globals.replace(Some(OffloadGlobals::declare(&cx)));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm a bit unsure about this location. we could also cache these globals and generate them on the first intrinsic call, but that felt like overloading intrinsic codegen a bit too much

i don't have a strong opinion though, so happy to go with whatever u think is best

}

let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx);
for &(mono_item, data) in &mono_items {
mono_item.predefine::<Builder<'_, '_, '_>>(
Expand Down
186 changes: 126 additions & 60 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,76 @@ use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::SBuilder;
use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};

// LLVM kernel-independent globals required for offloading
pub(crate) struct OffloadGlobals<'ll> {
pub launcher_fn: &'ll llvm::Value,
pub launcher_ty: &'ll llvm::Type,

pub bin_desc: &'ll llvm::Type,

pub kernel_args_ty: &'ll llvm::Type,

pub offload_entry_ty: &'ll llvm::Type,

pub begin_mapper: &'ll llvm::Value,
pub end_mapper: &'ll llvm::Value,
pub mapper_fn_ty: &'ll llvm::Type,

pub ident_t_global: &'ll llvm::Value,

pub register_lib: &'ll llvm::Value,
pub unregister_lib: &'ll llvm::Value,
pub init_rtls: &'ll llvm::Value,
}

impl<'ll> OffloadGlobals<'ll> {
pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
let (launcher_fn, launcher_ty) = generate_launcher(cx);
let kernel_args_ty = KernelArgsTy::new_decl(cx);
let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
let ident_t_global = generate_at_one(cx);

let tptr = cx.type_ptr();
let ti32 = cx.type_i32();
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);

let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
let init_ty = cx.type_func(&[], cx.type_void());
let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

OffloadGlobals {
launcher_fn,
launcher_ty,
bin_desc,
kernel_args_ty,
offload_entry_ty,
begin_mapper,
end_mapper,
mapper_fn_ty,
ident_t_global,
register_lib,
unregister_lib,
init_rtls,
}
}
}

// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
let ti32 = cx.type_i32();
Expand All @@ -30,7 +89,7 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
// offloaded was defined.
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
let unknown_txt = ";unknown;unknown;0;0;;";
let c_entry_name = CString::new(unknown_txt).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
Expand Down Expand Up @@ -68,7 +127,7 @@ pub(crate) struct TgtOffloadEntry {
}

impl TgtOffloadEntry {
pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
Expand All @@ -82,7 +141,7 @@ impl TgtOffloadEntry {
}

fn new<'ll>(
cx: &'ll SimpleCx<'_>,
cx: &CodegenCx<'ll, '_>,
region_id: &'ll Value,
llglobal: &'ll Value,
) -> [&'ll Value; 9] {
Expand Down Expand Up @@ -126,7 +185,7 @@ impl KernelArgsTy {
const OFFLOAD_VERSION: u64 = 3;
const FLAGS: u64 = 0;
const TRIPCOUNT: u64 = 0;
fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
Expand All @@ -140,8 +199,8 @@ impl KernelArgsTy {
kernel_arguments_ty
}

fn new<'ll>(
cx: &'ll SimpleCx<'_>,
fn new<'ll, 'tcx>(
cx: &CodegenCx<'ll, 'tcx>,
num_args: u64,
memtransfer_types: &'ll Value,
geps: [&'ll Value; 3],
Expand Down Expand Up @@ -171,15 +230,16 @@ impl KernelArgsTy {
}

// Contains LLVM values needed to manage offloading for a single kernel.
pub(crate) struct OffloadKernelData<'ll> {
#[derive(Copy, Clone)]
pub(crate) struct OffloadKernelGlobals<'ll> {
pub offload_sizes: &'ll llvm::Value,
pub memtransfer_types: &'ll llvm::Value,
pub region_id: &'ll llvm::Value,
pub offload_entry: &'ll llvm::Value,
}

fn gen_tgt_data_mappers<'ll>(
cx: &'ll SimpleCx<'_>,
cx: &CodegenCx<'ll, '_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
Expand Down Expand Up @@ -241,12 +301,18 @@ pub(crate) fn add_global<'ll>(
// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
// concatenated into the list of region_ids.
pub(crate) fn gen_define_handling<'ll>(
cx: &SimpleCx<'ll>,
offload_entry_ty: &'ll llvm::Type,
cx: &CodegenCx<'ll, '_>,
metadata: &[OffloadMetadata],
types: &[&Type],
symbol: &str,
) -> OffloadKernelData<'ll> {
types: &[&'ll Type],
symbol: String,
offload_globals: &OffloadGlobals<'ll>,
) -> OffloadKernelGlobals<'ll> {
if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
return *entry;
}

let offload_entry_ty = offload_globals.offload_entry_ty;

// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
// reference) types.
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
Expand Down Expand Up @@ -274,7 +340,7 @@ pub(crate) fn gen_define_handling<'ll>(
let initializer = cx.get_const_i8(0);
let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);

let c_entry_name = CString::new(symbol).unwrap();
let c_entry_name = CString::new(symbol.clone()).unwrap();
let c_val = c_entry_name.as_bytes_with_nul();
let offload_entry_name = format!(".offloading.entry_name.{symbol}");

Expand All @@ -298,11 +364,16 @@ pub(crate) fn gen_define_handling<'ll>(
let c_section_name = CString::new("llvm_offload_entries").unwrap();
llvm::set_section(offload_entry, &c_section_name);

OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
let result =
OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };

cx.offload_kernel_cache.borrow_mut().insert(symbol, result);

result
}

fn declare_offload_fn<'ll>(
cx: &'ll SimpleCx<'_>,
cx: &CodegenCx<'ll, '_>,
name: &str,
ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
Expand Down Expand Up @@ -335,28 +406,28 @@ fn declare_offload_fn<'ll>(
// 4. set insert point after kernel call.
// 5. generate all the GEPS and stores, to be used in 6)
// 6. generate __tgt_target_data_end calls to move data from the GPU
pub(crate) fn gen_call_handling<'ll>(
cx: &SimpleCx<'ll>,
bb: &BasicBlock,
offload_data: &OffloadKernelData<'ll>,
pub(crate) fn gen_call_handling<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
offload_data: &OffloadKernelGlobals<'ll>,
args: &[&'ll Value],
types: &[&Type],
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
) {
let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
let tptr = cx.type_ptr();
let ti32 = cx.type_i32();
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;

let mut builder = SBuilder::build(cx, bb);
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
let tgt_bin_desc = offload_globals.bin_desc;

let tgt_kernel_decl = offload_globals.kernel_args_ty;
let begin_mapper_decl = offload_globals.begin_mapper;
let end_mapper_decl = offload_globals.end_mapper;
let fn_ty = offload_globals.mapper_fn_ty;

let num_args = types.len() as u64;
let ip = unsafe { llvm::LLVMRustGetInsertPoint(&builder.llbuilder) };
Expand All @@ -378,9 +449,8 @@ pub(crate) fn gen_call_handling<'ll>(
// Step 0)
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
// %6 = alloca %struct.__tgt_bin_desc, align 8
let llfn = unsafe { llvm::LLVMGetBasicBlockParent(bb) };
unsafe {
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, llfn);
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
}
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

Expand Down Expand Up @@ -413,16 +483,16 @@ pub(crate) fn gen_call_handling<'ll>(
}

let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
let register_lib_decl = offload_globals.register_lib;
let unregister_lib_decl = offload_globals.unregister_lib;
let init_ty = cx.type_func(&[], cx.type_void());
let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
let init_rtls_decl = offload_globals.init_rtls;

// FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
// call void @__tgt_register_lib(ptr noundef %6)
builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
// call void @__tgt_init_all_rtls()
builder.call(init_ty, init_rtls_decl, &[], None);
builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);

for i in 0..num_args {
let idx = cx.get_const_i32(i);
Expand All @@ -437,15 +507,15 @@ pub(crate) fn gen_call_handling<'ll>(

// For now we have a very simplistic indexing scheme into our
// offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
fn get_geps<'a, 'll>(
builder: &mut SBuilder<'a, 'll>,
cx: &'ll SimpleCx<'ll>,
fn get_geps<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
ty: &'ll Type,
ty2: &'ll Type,
a1: &'ll Value,
a2: &'ll Value,
a4: &'ll Value,
) -> [&'ll Value; 3] {
let cx = builder.cx;
let i32_0 = cx.get_const_i32(0);

let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
Expand All @@ -454,30 +524,29 @@ pub(crate) fn gen_call_handling<'ll>(
[gep1, gep2, gep3]
}

fn generate_mapper_call<'a, 'll>(
builder: &mut SBuilder<'a, 'll>,
cx: &'ll SimpleCx<'ll>,
fn generate_mapper_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
geps: [&'ll Value; 3],
o_type: &'ll Value,
fn_to_call: &'ll Value,
fn_ty: &'ll Type,
num_args: u64,
s_ident_t: &'ll Value,
) {
let cx = builder.cx;
let nullptr = cx.const_null(cx.type_ptr());
let i64_max = cx.get_const_i64(u64::MAX);
let num_args = cx.get_const_i32(num_args);
let args =
vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
builder.call(fn_ty, fn_to_call, &args, None);
builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
}

// Step 2)
let s_ident_t = generate_at_one(&cx);
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
let s_ident_t = offload_globals.ident_t_global;
let geps = get_geps(builder, ty, ty2, a1, a2, a4);
generate_mapper_call(
&mut builder,
&cx,
builder,
geps,
memtransfer_types,
begin_mapper_decl,
Expand All @@ -504,14 +573,13 @@ pub(crate) fn gen_call_handling<'ll>(
region_id,
a5,
];
builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)

// Step 4)
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
let geps = get_geps(builder, ty, ty2, a1, a2, a4);
generate_mapper_call(
&mut builder,
&cx,
builder,
geps,
memtransfer_types,
end_mapper_decl,
Expand All @@ -520,7 +588,5 @@ pub(crate) fn gen_call_handling<'ll>(
s_ident_t,
);

builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

drop(builder);
builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
}
Loading
Loading