Skip to content

Commit 2b031c7

Browse files
committed
Implement suspendable code
1 parent 6d9a04d commit 2b031c7

15 files changed

+426
-80
lines changed

src/dev/engine/compiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ std::shared_ptr<libscratchcpp::Block> Compiler::block() const
3838
/*! Compiles the script starting with the given block. */
3939
std::shared_ptr<ExecutableCode> Compiler::compile(std::shared_ptr<Block> startBlock)
4040
{
41-
impl->builder = impl->builderFactory->create(startBlock->id());
41+
impl->builder = impl->builderFactory->create(startBlock->id(), false);
4242
impl->substackTree.clear();
4343
impl->substackHit = false;
4444
impl->warp = false;

src/dev/engine/internal/codebuilderfactory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ std::shared_ptr<CodeBuilderFactory> CodeBuilderFactory::instance()
1212
return m_instance;
1313
}
1414

15-
std::shared_ptr<ICodeBuilder> CodeBuilderFactory::create(const std::string &id) const
15+
std::shared_ptr<ICodeBuilder> CodeBuilderFactory::create(const std::string &id, bool warp) const
1616
{
17-
return std::make_shared<LLVMCodeBuilder>(id);
17+
return std::make_shared<LLVMCodeBuilder>(id, warp);
1818
}

src/dev/engine/internal/codebuilderfactory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CodeBuilderFactory : public ICodeBuilderFactory
1111
{
1212
public:
1313
static std::shared_ptr<CodeBuilderFactory> instance();
14-
std::shared_ptr<ICodeBuilder> create(const std::string &id) const override;
14+
std::shared_ptr<ICodeBuilder> create(const std::string &id, bool warp) const override;
1515

1616
private:
1717
static std::shared_ptr<CodeBuilderFactory> m_instance;

src/dev/engine/internal/icodebuilderfactory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class ICodeBuilderFactory
1414
public:
1515
virtual ~ICodeBuilderFactory() { }
1616

17-
virtual std::shared_ptr<ICodeBuilder> create(const std::string &id) const = 0;
17+
virtual std::shared_ptr<ICodeBuilder> create(const std::string &id, bool warp) const = 0;
1818
};
1919

2020
} // namespace libscratchcpp

src/dev/engine/internal/llvmcodebuilder.cpp

Lines changed: 147 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <llvm/Support/TargetSelect.h>
44
#include <llvm/IR/Verifier.h>
55
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
6+
#include <llvm/Passes/PassBuilder.h>
67

78
#include "llvmcodebuilder.h"
89
#include "llvmexecutablecode.h"
@@ -15,10 +16,12 @@ static std::unordered_map<ValueType, Compiler::StaticType> TYPE_MAP = {
1516
{ ValueType::NegativeInfinity, Compiler::StaticType::Number }, { ValueType::NaN, Compiler::StaticType::Number }
1617
};
1718

18-
LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) :
19+
LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id, bool warp) :
1920
m_id(id),
2021
m_module(std::make_unique<llvm::Module>(id, m_ctx)),
21-
m_builder(m_ctx)
22+
m_builder(m_ctx),
23+
m_defaultWarp(warp),
24+
m_warp(warp)
2225
{
2326
llvm::InitializeNativeTarget();
2427
llvm::InitializeNativeTargetAsmPrinter();
@@ -32,13 +35,20 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) :
3235
std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
3336
{
3437
// Create function
35-
// void f(Target *)
36-
llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false);
38+
// void *f(Target *)
39+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
40+
llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, pointerType, false);
3741
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get());
3842

3943
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func);
4044
m_builder.SetInsertPoint(entry);
4145

46+
// Init coroutine
47+
Coroutine coro;
48+
49+
if (!m_warp)
50+
coro = initCoroutine(func);
51+
4252
std::vector<IfStatement> ifStatements;
4353
std::vector<Loop> loops;
4454
m_heap.clear();
@@ -74,8 +84,17 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
7484
}
7585

7686
case Step::Type::Yield:
77-
freeHeap();
78-
// TODO: Implement yielding
87+
if (!m_warp) {
88+
freeHeap();
89+
llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(m_ctx, "", func);
90+
llvm::Value *noneToken = llvm::ConstantTokenNone::get(m_ctx);
91+
llvm::Value *suspendResult = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1(false) });
92+
llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2);
93+
sw->addCase(m_builder.getInt8(0), resumeBranch);
94+
sw->addCase(m_builder.getInt8(1), coro.cleanup);
95+
m_builder.SetInsertPoint(resumeBranch);
96+
}
97+
7998
break;
8099

81100
case Step::Type::BeginIf: {
@@ -277,20 +296,59 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
277296

278297
freeHeap();
279298

299+
// Add final suspend point
300+
if (!m_warp) {
301+
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_ctx, "end", func);
302+
llvm::Value *suspendResult =
303+
m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(m_ctx), m_builder.getInt1(true) });
304+
llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2);
305+
sw->addCase(m_builder.getInt8(0), endBranch);
306+
sw->addCase(m_builder.getInt8(1), coro.cleanup);
307+
m_builder.SetInsertPoint(endBranch);
308+
}
309+
280310
// End and verify the function
281311
if (!m_tmpRegs.empty()) {
282312
std::cout
283313
<< "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str()
284314
<< "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
285315
}
286316

287-
m_builder.CreateRetVoid();
317+
if (m_warp)
318+
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
319+
else
320+
m_builder.CreateBr(coro.cleanup);
288321

289-
if (llvm::verifyFunction(*func, &llvm::errs())) {
290-
llvm::errs() << "error: LLVM function verficiation failed!\n";
291-
llvm::errs() << "script hat ID: " << m_id << "\n";
322+
verifyFunction(func);
323+
324+
// Create resume function
325+
// bool resume(void *)
326+
funcType = llvm::FunctionType::get(m_builder.getInt1Ty(), pointerType, false);
327+
func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "resume", m_module.get());
328+
329+
entry = llvm::BasicBlock::Create(m_ctx, "entry", func);
330+
m_builder.SetInsertPoint(entry);
331+
332+
if (m_warp)
333+
m_builder.CreateRet(m_builder.getInt1(true));
334+
else {
335+
llvm::Value *coroHandle = func->getArg(0);
336+
m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_resume), { coroHandle });
337+
llvm::Value *done = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_done), { coroHandle });
338+
m_builder.CreateRet(done);
292339
}
293340

341+
verifyFunction(func);
342+
343+
#ifdef PRINT_LLVM_IR
344+
std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl;
345+
m_module->print(llvm::outs(), nullptr);
346+
std::cout << "==============" << std::endl << std::endl;
347+
#endif
348+
349+
// Optimize
350+
optimize();
351+
294352
#ifdef PRINT_LLVM_IR
295353
std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl;
296354
m_module->print(llvm::outs(), nullptr);
@@ -396,6 +454,9 @@ void LLVMCodeBuilder::beginLoopCondition()
396454

397455
void LLVMCodeBuilder::endLoop()
398456
{
457+
if (!m_warp)
458+
m_steps.push_back(Step(Step::Type::Yield));
459+
399460
m_steps.push_back(Step(Step::Type::EndLoop));
400461
}
401462

@@ -431,6 +492,82 @@ void LLVMCodeBuilder::initTypes()
431492
m_valueDataType->setBody({ unionType, valueType, sizeType });
432493
}
433494

495+
LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func)
496+
{
497+
// Set presplitcoroutine attribute
498+
func->setPresplitCoroutine();
499+
500+
// Coroutine intrinsics
501+
llvm::Function *coroId = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_id);
502+
llvm::Function *coroSize = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_size, m_builder.getInt64Ty());
503+
llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_begin);
504+
llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_end);
505+
llvm::Function *coroFree = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_free);
506+
507+
// Init coroutine
508+
Coroutine coro;
509+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
510+
llvm::Constant *nullPointer = llvm::ConstantPointerNull::get(pointerType);
511+
llvm::Value *coroIdRet = m_builder.CreateCall(coroId, { m_builder.getInt32(8), nullPointer, nullPointer, nullPointer });
512+
513+
// Allocate memory
514+
llvm::Value *coroSizeRet = m_builder.CreateCall(coroSize, std::nullopt, "size");
515+
llvm::Function *mallocFunc = llvm::Function::Create(llvm::FunctionType::get(pointerType, { m_builder.getInt64Ty() }, false), llvm::Function::ExternalLinkage, "malloc", m_module.get());
516+
llvm::Value *alloc = m_builder.CreateCall(mallocFunc, coroSizeRet, "mem");
517+
518+
// Begin
519+
coro.handle = m_builder.CreateCall(coroBegin, { coroIdRet, alloc });
520+
llvm::BasicBlock *entry = m_builder.GetInsertBlock();
521+
522+
// Create suspend branch
523+
coro.suspend = llvm::BasicBlock::Create(m_ctx, "suspend", func);
524+
m_builder.SetInsertPoint(coro.suspend);
525+
m_builder.CreateCall(coroEnd, { coro.handle, m_builder.getInt1(false), llvm::ConstantTokenNone::get(m_ctx) });
526+
m_builder.CreateRet(coro.handle);
527+
528+
// Create free branch
529+
llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(m_ctx, "free", func);
530+
m_builder.SetInsertPoint(freeBranch);
531+
m_builder.CreateFree(alloc);
532+
m_builder.CreateBr(coro.suspend);
533+
534+
// Create cleanup branch
535+
coro.cleanup = llvm::BasicBlock::Create(m_ctx, "cleanup", func);
536+
m_builder.SetInsertPoint(coro.cleanup);
537+
llvm::Value *mem = m_builder.CreateCall(coroFree, { coroIdRet, coro.handle });
538+
llvm::Value *needFree = m_builder.CreateIsNotNull(mem);
539+
m_builder.CreateCondBr(needFree, freeBranch, coro.suspend);
540+
541+
m_builder.SetInsertPoint(entry);
542+
return coro;
543+
}
544+
545+
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
546+
{
547+
if (llvm::verifyFunction(*func, &llvm::errs())) {
548+
llvm::errs() << "error: LLVM function verficiation failed!\n";
549+
llvm::errs() << "script hat ID: " << m_id << "\n";
550+
}
551+
}
552+
553+
void LLVMCodeBuilder::optimize()
554+
{
555+
llvm::PassBuilder passBuilder;
556+
llvm::LoopAnalysisManager loopAnalysisManager;
557+
llvm::FunctionAnalysisManager functionAnalysisManager;
558+
llvm::CGSCCAnalysisManager cGSCCAnalysisManager;
559+
llvm::ModuleAnalysisManager moduleAnalysisManager;
560+
561+
passBuilder.registerModuleAnalyses(moduleAnalysisManager);
562+
passBuilder.registerCGSCCAnalyses(cGSCCAnalysisManager);
563+
passBuilder.registerFunctionAnalyses(functionAnalysisManager);
564+
passBuilder.registerLoopAnalyses(loopAnalysisManager);
565+
passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager, cGSCCAnalysisManager, moduleAnalysisManager);
566+
567+
llvm::ModulePassManager modulePassManager = passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
568+
modulePassManager.run(*m_module, moduleAnalysisManager);
569+
}
570+
434571
void LLVMCodeBuilder::freeHeap()
435572
{
436573
// Free dynamically allocated memory

src/dev/engine/internal/llvmcodebuilder.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Target;
1717
class LLVMCodeBuilder : public ICodeBuilder
1818
{
1919
public:
20-
LLVMCodeBuilder(const std::string &id);
20+
LLVMCodeBuilder(const std::string &id, bool warp);
2121

2222
std::shared_ptr<ExecutableCode> finalize() override;
2323

@@ -98,8 +98,25 @@ class LLVMCodeBuilder : public ICodeBuilder
9898
llvm::BasicBlock *afterLoop = nullptr;
9999
};
100100

101+
struct Coroutine
102+
{
103+
llvm::Value *handle = nullptr;
104+
llvm::BasicBlock *suspend = nullptr;
105+
llvm::BasicBlock *cleanup = nullptr;
106+
};
107+
108+
struct Procedure
109+
{
110+
// TODO: Implement procedures
111+
bool warp = false;
112+
};
113+
101114
void initTypes();
102115

116+
Coroutine initCoroutine(llvm::Function *func);
117+
void verifyFunction(llvm::Function *func);
118+
void optimize();
119+
103120
void freeHeap();
104121
llvm::Value *castValue(std::shared_ptr<Register> reg, Compiler::StaticType targetType);
105122
llvm::Value *castRawValue(std::shared_ptr<Register> reg, Compiler::StaticType targetType);
@@ -134,6 +151,8 @@ class LLVMCodeBuilder : public ICodeBuilder
134151
std::vector<Value> m_constValues;
135152
std::vector<std::vector<std::shared_ptr<Register>>> m_regs;
136153
std::vector<std::shared_ptr<Register>> m_tmpRegs;
154+
bool m_defaultWarp = false;
155+
bool m_warp = false;
137156

138157
std::vector<llvm::Value *> m_heap;
139158

src/dev/engine/internal/llvmexecutablecode.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,31 +33,51 @@ LLVMExecutableCode::LLVMExecutableCode(std::unique_ptr<llvm::Module> module) :
3333
// Lookup functions
3434
m_mainFunction = (MainFunctionType)lookupFunction("f");
3535
assert(m_mainFunction);
36+
m_resumeFunction = (ResumeFunctionType)lookupFunction("resume");
37+
assert(m_resumeFunction);
3638
}
3739

3840
void LLVMExecutableCode::run(ExecutionContext *context)
3941
{
4042
LLVMExecutionContext *ctx = getContext(context);
4143

42-
if (!ctx->finished) {
43-
m_mainFunction(context->target());
44-
ctx->finished = true;
44+
if (ctx->finished())
45+
return;
46+
47+
if (ctx->coroutineHandle()) {
48+
bool done = m_resumeFunction(ctx->coroutineHandle());
49+
50+
if (done)
51+
ctx->setCoroutineHandle(nullptr);
52+
53+
ctx->setFinished(done);
54+
} else {
55+
void *handle = m_mainFunction(context->target());
56+
57+
if (!handle)
58+
ctx->setFinished(true);
59+
60+
ctx->setCoroutineHandle(handle);
4561
}
4662
}
4763

4864
void LLVMExecutableCode::kill(ExecutionContext *context)
4965
{
50-
getContext(context)->finished = true;
66+
LLVMExecutionContext *ctx = getContext(context);
67+
ctx->setCoroutineHandle(nullptr);
68+
ctx->setFinished(true);
5169
}
5270

5371
void LLVMExecutableCode::reset(ExecutionContext *context)
5472
{
55-
getContext(context)->finished = false;
73+
LLVMExecutionContext *ctx = getContext(context);
74+
ctx->setCoroutineHandle(nullptr);
75+
ctx->setFinished(false);
5676
}
5777

5878
bool LLVMExecutableCode::isFinished(ExecutionContext *context) const
5979
{
60-
return getContext(context)->finished;
80+
return getContext(context)->finished();
6181
}
6282

6383
void LLVMExecutableCode::promise()

src/dev/engine/internal/llvmexecutablecode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ class LLVMExecutableCode : public ExecutableCode
3232
private:
3333
uint64_t lookupFunction(const std::string &name);
3434

35-
using MainFunctionType = size_t (*)(Target *);
35+
using MainFunctionType = void *(*)(Target *);
36+
using ResumeFunctionType = bool (*)(void *);
3637

3738
static LLVMExecutionContext *getContext(ExecutionContext *context);
3839

3940
std::unique_ptr<llvm::LLVMContext> m_ctx;
4041
llvm::Expected<std::unique_ptr<llvm::orc::LLJIT>> m_jit;
4142

4243
MainFunctionType m_mainFunction;
44+
ResumeFunctionType m_resumeFunction;
4345
};
4446

4547
} // namespace libscratchcpp

0 commit comments

Comments
 (0)