Skip to content

Commit e98d355

Browse files
committed
wip, works
1 parent 8eebc80 commit e98d355

File tree

3 files changed

+120
-11
lines changed

3 files changed

+120
-11
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,11 +767,36 @@ pub(crate) unsafe fn llvm_optimize(
767767
};
768768

769769
if cgcx.target_is_like_gpu && config.offload.contains(&config::Offload::Enable) {
770+
let lib_bc_c = CString::new("/p/lustre1/drehwald1/prog/offload/r/lib.bc").unwrap();
771+
let host_out_c = CString::new("/p/lustre1/drehwald1/prog/offload/r/host.out").unwrap();
772+
let out_obj_c = CString::new("/p/lustre1/drehwald1/prog/offload/r/host.o").unwrap();
773+
770774
unsafe {
771-
llvm::LLVMRustBundleImages(module.module_llvm.llmod(), module.module_llvm.tm.raw());
775+
llvm::LLVMRustBundleImages(
776+
module.module_llvm.llmod(),
777+
module.module_llvm.tm.raw(),
778+
host_out_c.as_ptr(),
779+
);
772780
}
773-
}
781+
unsafe {
782+
// 1) Bundle device module into offload image host.out (device TM)
783+
let ok = llvm::LLVMRustBundleImages(
784+
module.module_llvm.llmod(),
785+
module.module_llvm.tm.raw(),
786+
host_out_c.as_ptr(),
787+
);
788+
assert!(ok, "LLVMRustBundleImages (device -> host.out) failed");
774789

790+
// 2) Finalize host: lib.bc + host.out -> host.offload.o (host TM created in C++)
791+
let ok = llvm::LLVMRustFinalizeOffload(
792+
lib_bc_c.as_ptr(),
793+
host_out_c.as_ptr(),
794+
out_obj_c.as_ptr(),
795+
);
796+
assert!(ok, "LLVMRustFinalizeOffload (host finalize) failed");
797+
}
798+
dbg!("done");
799+
}
775800
result.into_result().unwrap_or_else(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
776801
}
777802

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,8 +1726,17 @@ mod Offload {
17261726
use super::*;
17271727
unsafe extern "C" {
17281728
/// Processes the module and writes it in an offload compatible way into a "host.out" file.
1729-
pub(crate) fn LLVMRustBundleImages<'a>(M: &'a Module, TM: &'a TargetMachine) -> bool;
17301729
pub(crate) fn LLVMRustOffloadMapper<'a>(OldFn: &'a Value, NewFn: &'a Value);
1730+
pub(crate) fn LLVMRustBundleImages<'a>(
1731+
M: &'a Module,
1732+
TM: &'a TargetMachine,
1733+
host_out: *const c_char,
1734+
) -> bool;
1735+
pub(crate) fn LLVMRustFinalizeOffload(
1736+
lib_bc_path: *const c_char,
1737+
host_out_path: *const c_char,
1738+
out_obj_path: *const c_char,
1739+
) -> bool;
17311740
}
17321741
}
17331742

@@ -1740,11 +1749,23 @@ mod Offload_fallback {
17401749
/// Processes the module and writes it in an offload compatible way into a "host.out" file.
17411750
/// Marked as unsafe to match the real offload wrapper which is unsafe due to FFI.
17421751
#[allow(unused_unsafe)]
1743-
pub(crate) unsafe fn LLVMRustBundleImages<'a>(_M: &'a Module, _TM: &'a TargetMachine) -> bool {
1752+
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(_OldFn: &'a Value, _NewFn: &'a Value) {
17441753
unimplemented!("This rustc version was not built with LLVM Offload support!");
17451754
}
17461755
#[allow(unused_unsafe)]
1747-
pub(crate) unsafe fn LLVMRustOffloadMapper<'a>(_OldFn: &'a Value, _NewFn: &'a Value) {
1756+
pub(crate) unsafe fn LLVMRustBundleImages<'a>(
1757+
M: &'a Module,
1758+
TM: &'a TargetMachine,
1759+
host_out: *const c_char,
1760+
) -> bool {
1761+
unimplemented!("This rustc version was not built with LLVM Offload support!");
1762+
}
1763+
#[allow(unused_unsafe)]
1764+
pub(crate) unsafe fn LLVMRustFinalizeOffload(
1765+
lib_bc_path: *const c_char,
1766+
host_out_path: *const c_char,
1767+
out_obj_path: *const c_char,
1768+
) -> bool {
17481769
unimplemented!("This rustc version was not built with LLVM Offload support!");
17491770
}
17501771
}

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ static Error writeFile(StringRef Filename, StringRef Data) {
186186
// --image=file=device.bc,triple=amdgcn-amd-amdhsa,arch=gfx90a,kind=openmp
187187
// The input module is the rust code compiled for a gpu target like amdgpu.
188188
// Based on clang/tools/clang-offload-packager/ClangOffloadPackager.cpp
189-
extern "C" bool LLVMRustBundleImages(LLVMModuleRef M, TargetMachine &TM) {
189+
extern "C" bool LLVMRustBundleImages(LLVMModuleRef M, TargetMachine &TM, const char *HostOutPath) {
190190
std::string Storage;
191191
llvm::raw_string_ostream OS1(Storage);
192192
llvm::WriteBitcodeToFile(*unwrap(M), OS1);
@@ -211,16 +211,16 @@ extern "C" bool LLVMRustBundleImages(LLVMModuleRef M, TargetMachine &TM) {
211211
// Offload binary has invalid size alignment
212212
return false;
213213
OS2 << Buffer;
214-
if (Error E = writeFile("host.out",
214+
if (Error E = writeFile(HostOutPath,
215215
StringRef(BinaryData.begin(), BinaryData.size())))
216216
return false;
217217
return true;
218218
}
219219

220220
#include "llvm/Bitcode/BitcodeReader.h"
221221
Expected<std::unique_ptr<Module>>
222-
loadHostModuleFromBitcode(LLVMContext &Ctx) {
223-
auto MBOrErr = MemoryBuffer::getFile("/g/todo");
222+
loadHostModuleFromBitcode(LLVMContext &Ctx, StringRef LibBCPath) {
223+
auto MBOrErr = MemoryBuffer::getFile(LibBCPath);
224224
if (!MBOrErr)
225225
return errorCodeToError(MBOrErr.getError());
226226

@@ -252,11 +252,15 @@ extern "C" void embedBufferInModule(Module &M, MemoryBufferRef Buf) {
252252
}
253253

254254
Error embedHostOutIntoHostModule(Module &HostM, StringRef HostOutPath) {
255+
llvm::errs() << "embedHostOutIntoHostModule step 1:\n";
255256
auto MBOrErr = MemoryBuffer::getFile(HostOutPath);
257+
llvm::errs() << "embedHostOutIntoHostModule step 2:\n";
256258
if (!MBOrErr)
257259
return errorCodeToError(MBOrErr.getError());
258260

261+
llvm::errs() << "embedHostOutIntoHostModule step 3:\n";
259262
MemoryBufferRef Buf = (*MBOrErr)->getMemBufferRef();
263+
llvm::errs() << "embedHostOutIntoHostModule step 4:\n";
260264
embedBufferInModule(HostM, Buf);
261265
return Error::success();
262266
}
@@ -276,8 +280,8 @@ Error emitHostObjectWithTM(Module &HostM,
276280
TargetMachine &TM,
277281
StringRef OutObjPath) {
278282
// Make sure module matches the TM
279-
HostM.setDataLayout(TM.createDataLayout());
280-
HostM.setTargetTriple(TM.getTargetTriple().str());
283+
//HostM.setDataLayout(TM.createDataLayout());
284+
//HostM.setTargetTriple(TM.getTargetTriple().str());
281285

282286
legacy::PassManager PM;
283287
std::error_code EC;
@@ -314,6 +318,65 @@ extern "C" void LLVMRustOffloadMapper(LLVMValueRef OldFn, LLVMValueRef NewFn) {
314318
}
315319
#endif
316320

321+
// Create a host TargetMachine with HARDCODED triple/CPU
322+
static std::unique_ptr<TargetMachine> createHostTargetMachine() {
323+
static bool Initialized = false;
324+
if (!Initialized) {
325+
InitializeAllTargets();
326+
InitializeAllTargetMCs();
327+
InitializeAllAsmPrinters();
328+
InitializeAllAsmParsers();
329+
Initialized = true;
330+
}
331+
332+
// Hardcoded host triple + CPU (adapt if your CI/host differs)
333+
std::string TripleStr = "x86_64-unknown-linux-gnu";
334+
std::string CPU = "x86-64"; // OK for X86
335+
336+
std::string Err;
337+
const Target *T = TargetRegistry::lookupTarget(TripleStr, Err);
338+
if (!T) {
339+
// Could log Err here
340+
return nullptr;
341+
}
342+
343+
TargetOptions Opts;
344+
auto RM = std::optional<Reloc::Model>(Reloc::PIC_);
345+
346+
std::unique_ptr<TargetMachine> TM(
347+
T->createTargetMachine(TripleStr, CPU, /*Features*/"", Opts, RM));
348+
349+
return TM;
350+
}
351+
352+
// Top-level entry: host finalize in second rustc invocation
353+
// lib.bc (from first rustc) + host.out (from LLVMRustBundleImages) => host.offload.o
354+
extern "C" bool LLVMRustFinalizeOffload(const char *LibBCPath,
355+
const char *HostOutPath,
356+
const char *OutObjPath) {
357+
LLVMContext Ctx;
358+
359+
// 1. Load host lib.bc
360+
auto ModOrErr = loadHostModuleFromBitcode(Ctx, LibBCPath);
361+
if (!ModOrErr)
362+
return !errorToBool(ModOrErr.takeError());
363+
std::unique_ptr<Module> HostM = std::move(*ModOrErr);
364+
365+
// 2. Embed host.out
366+
if (Error E = embedHostOutIntoHostModule(*HostM, HostOutPath))
367+
return !errorToBool(std::move(E));
368+
369+
// 3. Create host TM and emit host object
370+
auto HostTM = createHostTargetMachine();
371+
if (!HostTM)
372+
return false;
373+
374+
if (Error E = emitHostObjectWithTM(*HostM, *HostTM, OutObjPath))
375+
return !errorToBool(std::move(E));
376+
377+
return true;
378+
}
379+
317380
extern "C" LLVMValueRef LLVMRustGetNamedValue(LLVMModuleRef M, const char *Name,
318381
size_t NameLen) {
319382
return wrap(unwrap(M)->getNamedValue(StringRef(Name, NameLen)));

0 commit comments

Comments
 (0)