33#include < llvm/Support/TargetSelect.h>
44#include < llvm/IR/Verifier.h>
55#include < llvm/ExecutionEngine/Orc/LLJIT.h>
6+ #include < llvm/Passes/PassBuilder.h>
67
78#include " llvmcodebuilder.h"
89#include " llvmexecutablecode.h"
@@ -15,10 +16,12 @@ static std::unordered_map<ValueType, Compiler::StaticType> TYPE_MAP = {
1516 { ValueType::NegativeInfinity, Compiler::StaticType::Number }, { ValueType::NaN, Compiler::StaticType::Number }
1617};
1718
18- LLVMCodeBuilder::LLVMCodeBuilder (const std::string &id) :
19+ LLVMCodeBuilder::LLVMCodeBuilder (const std::string &id, bool warp ) :
1920 m_id(id),
2021 m_module(std::make_unique<llvm::Module>(id, m_ctx)),
21- m_builder(m_ctx)
22+ m_builder(m_ctx),
23+ m_defaultWarp(warp),
24+ m_warp(warp)
2225{
2326 llvm::InitializeNativeTarget ();
2427 llvm::InitializeNativeTargetAsmPrinter ();
@@ -32,13 +35,20 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) :
3235std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize ()
3336{
3437 // Create function
35- // void f(Target *)
36- llvm::FunctionType *funcType = llvm::FunctionType::get (m_builder.getVoidTy (), llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 ), false );
38+ // void *f(Target *)
39+ llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
40+ llvm::FunctionType *funcType = llvm::FunctionType::get (pointerType, pointerType, false );
3741 llvm::Function *func = llvm::Function::Create (funcType, llvm::Function::ExternalLinkage, " f" , m_module.get ());
3842
3943 llvm::BasicBlock *entry = llvm::BasicBlock::Create (m_ctx, " entry" , func);
4044 m_builder.SetInsertPoint (entry);
4145
46+ // Init coroutine
47+ Coroutine coro;
48+
49+ if (!m_warp)
50+ coro = initCoroutine (func);
51+
4252 std::vector<IfStatement> ifStatements;
4353 std::vector<Loop> loops;
4454 m_heap.clear ();
@@ -74,8 +84,17 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
7484 }
7585
7686 case Step::Type::Yield:
77- freeHeap ();
78- // TODO: Implement yielding
87+ if (!m_warp) {
88+ freeHeap ();
89+ llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create (m_ctx, " " , func);
90+ llvm::Value *noneToken = llvm::ConstantTokenNone::get (m_ctx);
91+ llvm::Value *suspendResult = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1 (false ) });
92+ llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
93+ sw->addCase (m_builder.getInt8 (0 ), resumeBranch);
94+ sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
95+ m_builder.SetInsertPoint (resumeBranch);
96+ }
97+
7998 break ;
8099
81100 case Step::Type::BeginIf: {
@@ -277,20 +296,59 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
277296
278297 freeHeap ();
279298
299+ // Add final suspend point
300+ if (!m_warp) {
301+ llvm::BasicBlock *endBranch = llvm::BasicBlock::Create (m_ctx, " end" , func);
302+ llvm::Value *suspendResult =
303+ m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get (m_ctx), m_builder.getInt1 (true ) });
304+ llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
305+ sw->addCase (m_builder.getInt8 (0 ), endBranch);
306+ sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
307+ m_builder.SetInsertPoint (endBranch);
308+ }
309+
280310 // End and verify the function
281311 if (!m_tmpRegs.empty ()) {
282312 std::cout
283313 << " warning: " << m_tmpRegs.size () << " registers were leaked by script '" << m_module->getName ().str () << " ', function '" << func->getName ().str ()
284314 << " ' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
285315 }
286316
287- m_builder.CreateRetVoid ();
317+ if (m_warp)
318+ m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
319+ else
320+ m_builder.CreateBr (coro.cleanup );
288321
289- if (llvm::verifyFunction (*func, &llvm::errs ())) {
290- llvm::errs () << " error: LLVM function verficiation failed!\n " ;
291- llvm::errs () << " script hat ID: " << m_id << " \n " ;
322+ verifyFunction (func);
323+
324+ // Create resume function
325+ // bool resume(void *)
326+ funcType = llvm::FunctionType::get (m_builder.getInt1Ty (), pointerType, false );
327+ func = llvm::Function::Create (funcType, llvm::Function::ExternalLinkage, " resume" , m_module.get ());
328+
329+ entry = llvm::BasicBlock::Create (m_ctx, " entry" , func);
330+ m_builder.SetInsertPoint (entry);
331+
332+ if (m_warp)
333+ m_builder.CreateRet (m_builder.getInt1 (true ));
334+ else {
335+ llvm::Value *coroHandle = func->getArg (0 );
336+ m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_resume), { coroHandle });
337+ llvm::Value *done = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_done), { coroHandle });
338+ m_builder.CreateRet (done);
292339 }
293340
341+ verifyFunction (func);
342+
343+ #ifdef PRINT_LLVM_IR
344+ std::cout << std::endl << " === LLVM IR (" << m_module->getName ().str () << " ) ===" << std::endl;
345+ m_module->print (llvm::outs (), nullptr );
346+ std::cout << " ==============" << std::endl << std::endl;
347+ #endif
348+
349+ // Optimize
350+ optimize ();
351+
294352#ifdef PRINT_LLVM_IR
295353 std::cout << std::endl << " === LLVM IR (" << m_module->getName ().str () << " ) ===" << std::endl;
296354 m_module->print (llvm::outs (), nullptr );
@@ -396,6 +454,9 @@ void LLVMCodeBuilder::beginLoopCondition()
396454
397455void LLVMCodeBuilder::endLoop ()
398456{
457+ if (!m_warp)
458+ m_steps.push_back (Step (Step::Type::Yield));
459+
399460 m_steps.push_back (Step (Step::Type::EndLoop));
400461}
401462
@@ -431,6 +492,82 @@ void LLVMCodeBuilder::initTypes()
431492 m_valueDataType->setBody ({ unionType, valueType, sizeType });
432493}
433494
495+ LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine (llvm::Function *func)
496+ {
497+ // Set presplitcoroutine attribute
498+ func->setPresplitCoroutine ();
499+
500+ // Coroutine intrinsics
501+ llvm::Function *coroId = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_id);
502+ llvm::Function *coroSize = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_size, m_builder.getInt64Ty ());
503+ llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_begin);
504+ llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_end);
505+ llvm::Function *coroFree = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_free);
506+
507+ // Init coroutine
508+ Coroutine coro;
509+ llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
510+ llvm::Constant *nullPointer = llvm::ConstantPointerNull::get (pointerType);
511+ llvm::Value *coroIdRet = m_builder.CreateCall (coroId, { m_builder.getInt32 (8 ), nullPointer, nullPointer, nullPointer });
512+
513+ // Allocate memory
514+ llvm::Value *coroSizeRet = m_builder.CreateCall (coroSize, std::nullopt , " size" );
515+ llvm::Function *mallocFunc = llvm::Function::Create (llvm::FunctionType::get (pointerType, { m_builder.getInt64Ty () }, false ), llvm::Function::ExternalLinkage, " malloc" , m_module.get ());
516+ llvm::Value *alloc = m_builder.CreateCall (mallocFunc, coroSizeRet, " mem" );
517+
518+ // Begin
519+ coro.handle = m_builder.CreateCall (coroBegin, { coroIdRet, alloc });
520+ llvm::BasicBlock *entry = m_builder.GetInsertBlock ();
521+
522+ // Create suspend branch
523+ coro.suspend = llvm::BasicBlock::Create (m_ctx, " suspend" , func);
524+ m_builder.SetInsertPoint (coro.suspend );
525+ m_builder.CreateCall (coroEnd, { coro.handle , m_builder.getInt1 (false ), llvm::ConstantTokenNone::get (m_ctx) });
526+ m_builder.CreateRet (coro.handle );
527+
528+ // Create free branch
529+ llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create (m_ctx, " free" , func);
530+ m_builder.SetInsertPoint (freeBranch);
531+ m_builder.CreateFree (alloc);
532+ m_builder.CreateBr (coro.suspend );
533+
534+ // Create cleanup branch
535+ coro.cleanup = llvm::BasicBlock::Create (m_ctx, " cleanup" , func);
536+ m_builder.SetInsertPoint (coro.cleanup );
537+ llvm::Value *mem = m_builder.CreateCall (coroFree, { coroIdRet, coro.handle });
538+ llvm::Value *needFree = m_builder.CreateIsNotNull (mem);
539+ m_builder.CreateCondBr (needFree, freeBranch, coro.suspend );
540+
541+ m_builder.SetInsertPoint (entry);
542+ return coro;
543+ }
544+
545+ void LLVMCodeBuilder::verifyFunction (llvm::Function *func)
546+ {
547+ if (llvm::verifyFunction (*func, &llvm::errs ())) {
548+ llvm::errs () << " error: LLVM function verficiation failed!\n " ;
549+ llvm::errs () << " script hat ID: " << m_id << " \n " ;
550+ }
551+ }
552+
553+ void LLVMCodeBuilder::optimize ()
554+ {
555+ llvm::PassBuilder passBuilder;
556+ llvm::LoopAnalysisManager loopAnalysisManager;
557+ llvm::FunctionAnalysisManager functionAnalysisManager;
558+ llvm::CGSCCAnalysisManager cGSCCAnalysisManager;
559+ llvm::ModuleAnalysisManager moduleAnalysisManager;
560+
561+ passBuilder.registerModuleAnalyses (moduleAnalysisManager);
562+ passBuilder.registerCGSCCAnalyses (cGSCCAnalysisManager);
563+ passBuilder.registerFunctionAnalyses (functionAnalysisManager);
564+ passBuilder.registerLoopAnalyses (loopAnalysisManager);
565+ passBuilder.crossRegisterProxies (loopAnalysisManager, functionAnalysisManager, cGSCCAnalysisManager, moduleAnalysisManager);
566+
567+ llvm::ModulePassManager modulePassManager = passBuilder.buildPerModuleDefaultPipeline (llvm::OptimizationLevel::O3);
568+ modulePassManager.run (*m_module, moduleAnalysisManager);
569+ }
570+
434571void LLVMCodeBuilder::freeHeap ()
435572{
436573 // Free dynamically allocated memory
0 commit comments