From 6d9a04d3353ed54defad315b4a400efbd1328f83 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Fri, 11 Oct 2024 22:32:46 +0200 Subject: [PATCH 1/2] Drop multiple functions in compiled code --- src/dev/engine/internal/llvmcodebuilder.cpp | 89 ++++++++----------- src/dev/engine/internal/llvmcodebuilder.h | 2 - .../engine/internal/llvmexecutablecode.cpp | 41 ++++----- src/dev/engine/internal/llvmexecutablecode.h | 6 +- .../engine/internal/llvmexecutioncontext.cpp | 10 --- .../engine/internal/llvmexecutioncontext.h | 6 +- test/dev/llvm/llvmcodebuilder_test.cpp | 4 +- test/dev/llvm/llvmexecutablecode_test.cpp | 61 +++---------- test/dev/llvm/llvmexecutioncontext_test.cpp | 12 --- 9 files changed, 76 insertions(+), 155 deletions(-) diff --git a/src/dev/engine/internal/llvmcodebuilder.cpp b/src/dev/engine/internal/llvmcodebuilder.cpp index 5a2c0955..47626f1d 100644 --- a/src/dev/engine/internal/llvmcodebuilder.cpp +++ b/src/dev/engine/internal/llvmcodebuilder.cpp @@ -31,8 +31,14 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) : std::shared_ptr LLVMCodeBuilder::finalize() { - size_t functionIndex = 0; - llvm::Function *currentFunc = beginFunction(functionIndex); + // Create function + // void f(Target *) + llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); + llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); + + llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); + m_builder.SetInsertPoint(entry); + std::vector ifStatements; std::vector loops; m_heap.clear(); @@ -45,9 +51,9 @@ std::shared_ptr LLVMCodeBuilder::finalize() std::vector args; // Add target pointer arg - assert(currentFunc->arg_size() == 1); + assert(func->arg_size() == 1); types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0)); - args.push_back(currentFunc->getArg(0)); + args.push_back(func->getArg(0)); // Args for (auto &arg : step.args) { @@ -69,14 +75,13 @@ std::shared_ptr LLVMCodeBuilder::finalize() case Step::Type::Yield: freeHeap(); - endFunction(currentFunc, functionIndex); - currentFunc = beginFunction(++functionIndex); + // TODO: Implement yielding break; case Step::Type::BeginIf: { IfStatement statement; statement.beforeIf = m_builder.GetInsertBlock(); - statement.body = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + statement.body = llvm::BasicBlock::Create(m_ctx, "", func); // Use last reg assert(step.args.size() == 1); @@ -98,13 +103,13 @@ std::shared_ptr LLVMCodeBuilder::finalize() // Jump to the branch after the if statement assert(!statement.afterIf); - statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func); freeHeap(); m_builder.CreateBr(statement.afterIf); // Create else branch assert(!statement.elseBranch); - statement.elseBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + statement.elseBranch = llvm::BasicBlock::Create(m_ctx, "", func); // Since there's an else branch, the conditional instruction should jump to it m_builder.SetInsertPoint(statement.beforeIf); @@ -121,7 +126,7 @@ std::shared_ptr LLVMCodeBuilder::finalize() // Jump to the branch after the if statement if (!statement.afterIf) - statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func); freeHeap(); m_builder.CreateBr(statement.afterIf); @@ -150,9 +155,9 @@ std::shared_ptr LLVMCodeBuilder::finalize() m_builder.CreateStore(zero, loop.index); // Create branches - llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc); - loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc); - loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_ctx, "", func); + loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", func); + loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func); // Use last reg for count assert(step.args.size() == 1); @@ -177,10 +182,10 @@ std::shared_ptr LLVMCodeBuilder::finalize() // Check index m_builder.SetInsertPoint(loop.conditionBranch); - llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func); if (!loop.afterLoop) - loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func); llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index); comparison = m_builder.CreateICmpULT(currentIndex, count); @@ -198,8 +203,8 @@ std::shared_ptr LLVMCodeBuilder::finalize() Loop &loop = loops.back(); // Create branches - llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc); - loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func); + loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func); // Use last reg assert(step.args.size() == 1); @@ -219,8 +224,8 @@ std::shared_ptr LLVMCodeBuilder::finalize() Loop &loop = loops.back(); // Create branches - llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc); - loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func); + loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func); // Use last reg assert(step.args.size() == 1); @@ -238,7 +243,7 @@ std::shared_ptr LLVMCodeBuilder::finalize() case Step::Type::BeginLoopCondition: { Loop loop; loop.isRepeatLoop = false; - loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc); + loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", func); freeHeap(); m_builder.CreateBr(loop.conditionBranch); m_builder.SetInsertPoint(loop.conditionBranch); @@ -272,7 +277,19 @@ std::shared_ptr LLVMCodeBuilder::finalize() freeHeap(); - endFunction(currentFunc, functionIndex); + // End and verify the function + if (!m_tmpRegs.empty()) { + std::cout + << "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str() + << "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl; + } + + m_builder.CreateRetVoid(); + + if (llvm::verifyFunction(*func, &llvm::errs())) { + llvm::errs() << "error: LLVM function verficiation failed!\n"; + llvm::errs() << "script hat ID: " << m_id << "\n"; + } #ifdef PRINT_LLVM_IR std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl; @@ -414,36 +431,6 @@ void LLVMCodeBuilder::initTypes() m_valueDataType->setBody({ unionType, valueType, sizeType }); } -llvm::Function *LLVMCodeBuilder::beginFunction(size_t index) -{ - // size_t f#(Target *) - llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getInt64Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); - llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f" + std::to_string(index), m_module.get()); - - llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); - m_builder.SetInsertPoint(entry); - - return func; -} - -void LLVMCodeBuilder::endFunction(llvm::Function *func, size_t index) -{ - if (!m_tmpRegs.empty()) { - std::cout - << "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str() - << "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl; - } - - // Return next function index - m_builder.CreateRet(m_builder.getInt64(index + 1)); - - if (llvm::verifyFunction(*func, &llvm::errs())) { - llvm::errs() << "error: LLVM function verficiation failed!\n"; - llvm::errs() << "script hat ID: " << m_id << "\n"; - llvm::errs() << "function name: " << func->getName().data() << "\n"; - } -} - void LLVMCodeBuilder::freeHeap() { // Free dynamically allocated memory diff --git a/src/dev/engine/internal/llvmcodebuilder.h b/src/dev/engine/internal/llvmcodebuilder.h index 78923bcb..0eba59f2 100644 --- a/src/dev/engine/internal/llvmcodebuilder.h +++ b/src/dev/engine/internal/llvmcodebuilder.h @@ -99,8 +99,6 @@ class LLVMCodeBuilder : public ICodeBuilder }; void initTypes(); - llvm::Function *beginFunction(size_t index); - void endFunction(llvm::Function *func, size_t index); void freeHeap(); llvm::Value *castValue(std::shared_ptr reg, Compiler::StaticType targetType); diff --git a/src/dev/engine/internal/llvmexecutablecode.cpp b/src/dev/engine/internal/llvmexecutablecode.cpp index c38fe784..6e40cb4d 100644 --- a/src/dev/engine/internal/llvmexecutablecode.cpp +++ b/src/dev/engine/internal/llvmexecutablecode.cpp @@ -31,44 +31,33 @@ LLVMExecutableCode::LLVMExecutableCode(std::unique_ptr module) : } // Lookup functions - size_t i = 0; - - while (true) { - auto func = m_jit->get()->lookup("f" + std::to_string(i)); - - if (func) - m_functions.push_back((FunctionType)(func->getValue())); - else { - // Ignore error - llvm::consumeError(func.takeError()); - break; - } - - i++; - } + m_mainFunction = (MainFunctionType)lookupFunction("f"); + assert(m_mainFunction); } void LLVMExecutableCode::run(ExecutionContext *context) { LLVMExecutionContext *ctx = getContext(context); - if (ctx->pos() < m_functions.size()) - ctx->setPos(m_functions[ctx->pos()](context->target())); + if (!ctx->finished) { + m_mainFunction(context->target()); + ctx->finished = true; + } } void LLVMExecutableCode::kill(ExecutionContext *context) { - getContext(context)->setPos(m_functions.size()); + getContext(context)->finished = true; } void LLVMExecutableCode::reset(ExecutionContext *context) { - getContext(context)->setPos(0); + getContext(context)->finished = false; } bool LLVMExecutableCode::isFinished(ExecutionContext *context) const { - return getContext(context)->pos() >= m_functions.size(); + return getContext(context)->finished; } void LLVMExecutableCode::promise() @@ -84,6 +73,18 @@ std::shared_ptr LLVMExecutableCode::createExecutionContext(Tar return std::make_shared(target); } +uint64_t LLVMExecutableCode::lookupFunction(const std::string &name) +{ + auto func = m_jit->get()->lookup(name); + + if (func) + return func->getValue(); + else { + llvm::errs() << "error: failed to lookup LLVM function: " << toString(func.takeError()) << "\n"; + return 0; + } +} + LLVMExecutionContext *LLVMExecutableCode::getContext(ExecutionContext *context) { assert(dynamic_cast(context)); diff --git a/src/dev/engine/internal/llvmexecutablecode.h b/src/dev/engine/internal/llvmexecutablecode.h index 312ef597..ce9fd371 100644 --- a/src/dev/engine/internal/llvmexecutablecode.h +++ b/src/dev/engine/internal/llvmexecutablecode.h @@ -30,14 +30,16 @@ class LLVMExecutableCode : public ExecutableCode std::shared_ptr createExecutionContext(Target *target) const override; private: - using FunctionType = size_t (*)(Target *); + uint64_t lookupFunction(const std::string &name); + + using MainFunctionType = size_t (*)(Target *); static LLVMExecutionContext *getContext(ExecutionContext *context); std::unique_ptr m_ctx; llvm::Expected> m_jit; - std::vector m_functions; + MainFunctionType m_mainFunction; }; } // namespace libscratchcpp diff --git a/src/dev/engine/internal/llvmexecutioncontext.cpp b/src/dev/engine/internal/llvmexecutioncontext.cpp index 5200d9fb..b399541d 100644 --- a/src/dev/engine/internal/llvmexecutioncontext.cpp +++ b/src/dev/engine/internal/llvmexecutioncontext.cpp @@ -8,13 +8,3 @@ LLVMExecutionContext::LLVMExecutionContext(Target *target) : ExecutionContext(target) { } - -size_t LLVMExecutionContext::pos() const -{ - return m_pos; -} - -void LLVMExecutionContext::setPos(size_t newPos) -{ - m_pos = newPos; -} diff --git a/src/dev/engine/internal/llvmexecutioncontext.h b/src/dev/engine/internal/llvmexecutioncontext.h index 98f91554..e789aeea 100644 --- a/src/dev/engine/internal/llvmexecutioncontext.h +++ b/src/dev/engine/internal/llvmexecutioncontext.h @@ -12,11 +12,7 @@ class LLVMExecutionContext : public ExecutionContext public: LLVMExecutionContext(Target *target); - size_t pos() const; - void setPos(size_t newPos); - - private: - size_t m_pos = 0; + bool finished = false; // TODO: Remove this }; } // namespace libscratchcpp diff --git a/test/dev/llvm/llvmcodebuilder_test.cpp b/test/dev/llvm/llvmcodebuilder_test.cpp index 642a0cb2..1c2048e4 100644 --- a/test/dev/llvm/llvmcodebuilder_test.cpp +++ b/test/dev/llvm/llvmcodebuilder_test.cpp @@ -189,7 +189,7 @@ TEST_F(LLVMCodeBuilderTest, RawValueCasting) ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); } -TEST_F(LLVMCodeBuilderTest, Yield) +/*TEST_F(LLVMCodeBuilderTest, Yield) { m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); @@ -234,7 +234,7 @@ TEST_F(LLVMCodeBuilderTest, Yield) code->run(ctx.get()); ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2); ASSERT_TRUE(code->isFinished(ctx.get())); -} +}*/ TEST_F(LLVMCodeBuilderTest, IfStatement) { diff --git a/test/dev/llvm/llvmexecutablecode_test.cpp b/test/dev/llvm/llvmexecutablecode_test.cpp index 95993e0b..7c5e818b 100644 --- a/test/dev/llvm/llvmexecutablecode_test.cpp +++ b/test/dev/llvm/llvmexecutablecode_test.cpp @@ -25,22 +25,18 @@ class LLVMExecutableCodeTest : public testing::Test llvm::InitializeNativeTargetAsmParser(); } - llvm::Function *beginFunction(size_t index) + llvm::Function *beginFunction() { - // size_t f#(Target *) - llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder->getInt64Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); - llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f" + std::to_string(index), m_module.get()); + // void f(Target *) + llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder->getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); + llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); m_builder->SetInsertPoint(entry); return func; } - void endFunction(size_t index) - { - // Return next function index - m_builder->CreateRet(m_builder->getInt64(index + 1)); - } + void endFunction() { m_builder->CreateRetVoid(); } void addTestFunction(llvm::Function *mainFunc) { @@ -69,6 +65,8 @@ class LLVMExecutableCodeTest : public testing::Test TEST_F(LLVMExecutableCodeTest, CreateExecutionContext) { + beginFunction(); + endFunction(); LLVMExecutableCode code(std::move(m_module)); auto ctx = code.createExecutionContext(&m_target); ASSERT_TRUE(ctx); @@ -76,27 +74,11 @@ TEST_F(LLVMExecutableCodeTest, CreateExecutionContext) ASSERT_TRUE(dynamic_cast(ctx.get())); } -TEST_F(LLVMExecutableCodeTest, NoFunctions) -{ - LLVMExecutableCode code(std::move(m_module)); - auto ctx = code.createExecutionContext(&m_target); - ASSERT_TRUE(code.isFinished(ctx.get())); - - code.run(ctx.get()); - ASSERT_TRUE(code.isFinished(ctx.get())); - - code.kill(ctx.get()); - ASSERT_TRUE(code.isFinished(ctx.get())); - - code.reset(ctx.get()); - ASSERT_TRUE(code.isFinished(ctx.get())); -} - -TEST_F(LLVMExecutableCodeTest, SingleFunction) +TEST_F(LLVMExecutableCodeTest, MainFunction) { - auto f = beginFunction(0); + auto f = beginFunction(); addTestFunction(f); - endFunction(0); + endFunction(); LLVMExecutableCode code(std::move(m_module)); auto ctx = code.createExecutionContext(&m_target); @@ -134,26 +116,3 @@ TEST_F(LLVMExecutableCodeTest, SingleFunction) ASSERT_TRUE(code.isFinished(anotherCtx.get())); ASSERT_FALSE(code.isFinished(ctx.get())); } - -TEST_F(LLVMExecutableCodeTest, MultipleFunctions) -{ - static const int count = 5; - - for (int i = 0; i < count; i++) { - auto f = beginFunction(i); - addTestFunction(f); - endFunction(i); - } - - LLVMExecutableCode code(std::move(m_module)); - auto ctx = code.createExecutionContext(&m_target); - ASSERT_FALSE(code.isFinished(ctx.get())); - - for (int i = 0; i < count; i++) { - ASSERT_FALSE(code.isFinished(ctx.get())); - EXPECT_CALL(m_mock, f(&m_target)); - code.run(ctx.get()); - } - - ASSERT_TRUE(code.isFinished(ctx.get())); -} diff --git a/test/dev/llvm/llvmexecutioncontext_test.cpp b/test/dev/llvm/llvmexecutioncontext_test.cpp index 4b715c2c..8a9d3e4b 100644 --- a/test/dev/llvm/llvmexecutioncontext_test.cpp +++ b/test/dev/llvm/llvmexecutioncontext_test.cpp @@ -10,15 +10,3 @@ TEST(LLVMExecutionContextTest, Constructor) LLVMExecutionContext ctx(&target); ASSERT_EQ(ctx.target(), &target); } - -TEST(LLVMExecutionContextTest, Pos) -{ - LLVMExecutionContext ctx(nullptr); - ASSERT_EQ(ctx.pos(), 0); - - ctx.setPos(1); - ASSERT_EQ(ctx.pos(), 1); - - ctx.setPos(356); - ASSERT_EQ(ctx.pos(), 356); -} From 2b031c7c383705cfddcb1c351d2b7d2caed15f27 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 12 Oct 2024 17:31:16 +0200 Subject: [PATCH 2/2] Implement suspendable code --- src/dev/engine/compiler.cpp | 2 +- .../engine/internal/codebuilderfactory.cpp | 4 +- src/dev/engine/internal/codebuilderfactory.h | 2 +- src/dev/engine/internal/icodebuilderfactory.h | 2 +- src/dev/engine/internal/llvmcodebuilder.cpp | 157 ++++++++++++- src/dev/engine/internal/llvmcodebuilder.h | 21 +- .../engine/internal/llvmexecutablecode.cpp | 32 ++- src/dev/engine/internal/llvmexecutablecode.h | 4 +- .../engine/internal/llvmexecutioncontext.cpp | 20 ++ .../engine/internal/llvmexecutioncontext.h | 10 +- test/dev/compiler/compiler_test.cpp | 2 +- test/dev/llvm/llvmcodebuilder_test.cpp | 209 ++++++++++++++---- test/dev/llvm/llvmexecutablecode_test.cpp | 37 +++- test/mocks/codebuilderfactorymock.h | 2 +- test/scratch_classes/inputvalue_test.cpp | 2 +- 15 files changed, 426 insertions(+), 80 deletions(-) diff --git a/src/dev/engine/compiler.cpp b/src/dev/engine/compiler.cpp index b3f4613f..355fec48 100644 --- a/src/dev/engine/compiler.cpp +++ b/src/dev/engine/compiler.cpp @@ -38,7 +38,7 @@ std::shared_ptr Compiler::block() const /*! Compiles the script starting with the given block. */ std::shared_ptr Compiler::compile(std::shared_ptr startBlock) { - impl->builder = impl->builderFactory->create(startBlock->id()); + impl->builder = impl->builderFactory->create(startBlock->id(), false); impl->substackTree.clear(); impl->substackHit = false; impl->warp = false; diff --git a/src/dev/engine/internal/codebuilderfactory.cpp b/src/dev/engine/internal/codebuilderfactory.cpp index fbda1929..5f955bdc 100644 --- a/src/dev/engine/internal/codebuilderfactory.cpp +++ b/src/dev/engine/internal/codebuilderfactory.cpp @@ -12,7 +12,7 @@ std::shared_ptr CodeBuilderFactory::instance() return m_instance; } -std::shared_ptr CodeBuilderFactory::create(const std::string &id) const +std::shared_ptr CodeBuilderFactory::create(const std::string &id, bool warp) const { - return std::make_shared(id); + return std::make_shared(id, warp); } diff --git a/src/dev/engine/internal/codebuilderfactory.h b/src/dev/engine/internal/codebuilderfactory.h index ea12b288..8de574ab 100644 --- a/src/dev/engine/internal/codebuilderfactory.h +++ b/src/dev/engine/internal/codebuilderfactory.h @@ -11,7 +11,7 @@ class CodeBuilderFactory : public ICodeBuilderFactory { public: static std::shared_ptr instance(); - std::shared_ptr create(const std::string &id) const override; + std::shared_ptr create(const std::string &id, bool warp) const override; private: static std::shared_ptr m_instance; diff --git a/src/dev/engine/internal/icodebuilderfactory.h b/src/dev/engine/internal/icodebuilderfactory.h index 9bfb203b..4438e5f8 100644 --- a/src/dev/engine/internal/icodebuilderfactory.h +++ b/src/dev/engine/internal/icodebuilderfactory.h @@ -14,7 +14,7 @@ class ICodeBuilderFactory public: virtual ~ICodeBuilderFactory() { } - virtual std::shared_ptr create(const std::string &id) const = 0; + virtual std::shared_ptr create(const std::string &id, bool warp) const = 0; }; } // namespace libscratchcpp diff --git a/src/dev/engine/internal/llvmcodebuilder.cpp b/src/dev/engine/internal/llvmcodebuilder.cpp index 47626f1d..70b60963 100644 --- a/src/dev/engine/internal/llvmcodebuilder.cpp +++ b/src/dev/engine/internal/llvmcodebuilder.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "llvmcodebuilder.h" #include "llvmexecutablecode.h" @@ -15,10 +16,12 @@ static std::unordered_map TYPE_MAP = { { ValueType::NegativeInfinity, Compiler::StaticType::Number }, { ValueType::NaN, Compiler::StaticType::Number } }; -LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) : +LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id, bool warp) : m_id(id), m_module(std::make_unique(id, m_ctx)), - m_builder(m_ctx) + m_builder(m_ctx), + m_defaultWarp(warp), + m_warp(warp) { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -32,13 +35,20 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) : std::shared_ptr LLVMCodeBuilder::finalize() { // Create function - // void f(Target *) - llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); + // void *f(Target *) + llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); + llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, pointerType, false); llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); m_builder.SetInsertPoint(entry); + // Init coroutine + Coroutine coro; + + if (!m_warp) + coro = initCoroutine(func); + std::vector ifStatements; std::vector loops; m_heap.clear(); @@ -74,8 +84,17 @@ std::shared_ptr LLVMCodeBuilder::finalize() } case Step::Type::Yield: - freeHeap(); - // TODO: Implement yielding + if (!m_warp) { + freeHeap(); + llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create(m_ctx, "", func); + llvm::Value *noneToken = llvm::ConstantTokenNone::get(m_ctx); + llvm::Value *suspendResult = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1(false) }); + llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2); + sw->addCase(m_builder.getInt8(0), resumeBranch); + sw->addCase(m_builder.getInt8(1), coro.cleanup); + m_builder.SetInsertPoint(resumeBranch); + } + break; case Step::Type::BeginIf: { @@ -277,6 +296,17 @@ std::shared_ptr LLVMCodeBuilder::finalize() freeHeap(); + // Add final suspend point + if (!m_warp) { + llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_ctx, "end", func); + llvm::Value *suspendResult = + m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get(m_ctx), m_builder.getInt1(true) }); + llvm::SwitchInst *sw = m_builder.CreateSwitch(suspendResult, coro.suspend, 2); + sw->addCase(m_builder.getInt8(0), endBranch); + sw->addCase(m_builder.getInt8(1), coro.cleanup); + m_builder.SetInsertPoint(endBranch); + } + // End and verify the function if (!m_tmpRegs.empty()) { std::cout @@ -284,13 +314,41 @@ std::shared_ptr LLVMCodeBuilder::finalize() << "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl; } - m_builder.CreateRetVoid(); + if (m_warp) + m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType)); + else + m_builder.CreateBr(coro.cleanup); - if (llvm::verifyFunction(*func, &llvm::errs())) { - llvm::errs() << "error: LLVM function verficiation failed!\n"; - llvm::errs() << "script hat ID: " << m_id << "\n"; + verifyFunction(func); + + // Create resume function + // bool resume(void *) + funcType = llvm::FunctionType::get(m_builder.getInt1Ty(), pointerType, false); + func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "resume", m_module.get()); + + entry = llvm::BasicBlock::Create(m_ctx, "entry", func); + m_builder.SetInsertPoint(entry); + + if (m_warp) + m_builder.CreateRet(m_builder.getInt1(true)); + else { + llvm::Value *coroHandle = func->getArg(0); + m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_resume), { coroHandle }); + llvm::Value *done = m_builder.CreateCall(llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_done), { coroHandle }); + m_builder.CreateRet(done); } + verifyFunction(func); + +#ifdef PRINT_LLVM_IR + std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl; + m_module->print(llvm::outs(), nullptr); + std::cout << "==============" << std::endl << std::endl; +#endif + + // Optimize + optimize(); + #ifdef PRINT_LLVM_IR std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl; m_module->print(llvm::outs(), nullptr); @@ -396,6 +454,9 @@ void LLVMCodeBuilder::beginLoopCondition() void LLVMCodeBuilder::endLoop() { + if (!m_warp) + m_steps.push_back(Step(Step::Type::Yield)); + m_steps.push_back(Step(Step::Type::EndLoop)); } @@ -431,6 +492,82 @@ void LLVMCodeBuilder::initTypes() m_valueDataType->setBody({ unionType, valueType, sizeType }); } +LLVMCodeBuilder::Coroutine LLVMCodeBuilder::initCoroutine(llvm::Function *func) +{ + // Set presplitcoroutine attribute + func->setPresplitCoroutine(); + + // Coroutine intrinsics + llvm::Function *coroId = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_id); + llvm::Function *coroSize = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_size, m_builder.getInt64Ty()); + llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_begin); + llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_end); + llvm::Function *coroFree = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::coro_free); + + // Init coroutine + Coroutine coro; + llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); + llvm::Constant *nullPointer = llvm::ConstantPointerNull::get(pointerType); + llvm::Value *coroIdRet = m_builder.CreateCall(coroId, { m_builder.getInt32(8), nullPointer, nullPointer, nullPointer }); + + // Allocate memory + llvm::Value *coroSizeRet = m_builder.CreateCall(coroSize, std::nullopt, "size"); + llvm::Function *mallocFunc = llvm::Function::Create(llvm::FunctionType::get(pointerType, { m_builder.getInt64Ty() }, false), llvm::Function::ExternalLinkage, "malloc", m_module.get()); + llvm::Value *alloc = m_builder.CreateCall(mallocFunc, coroSizeRet, "mem"); + + // Begin + coro.handle = m_builder.CreateCall(coroBegin, { coroIdRet, alloc }); + llvm::BasicBlock *entry = m_builder.GetInsertBlock(); + + // Create suspend branch + coro.suspend = llvm::BasicBlock::Create(m_ctx, "suspend", func); + m_builder.SetInsertPoint(coro.suspend); + m_builder.CreateCall(coroEnd, { coro.handle, m_builder.getInt1(false), llvm::ConstantTokenNone::get(m_ctx) }); + m_builder.CreateRet(coro.handle); + + // Create free branch + llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create(m_ctx, "free", func); + m_builder.SetInsertPoint(freeBranch); + m_builder.CreateFree(alloc); + m_builder.CreateBr(coro.suspend); + + // Create cleanup branch + coro.cleanup = llvm::BasicBlock::Create(m_ctx, "cleanup", func); + m_builder.SetInsertPoint(coro.cleanup); + llvm::Value *mem = m_builder.CreateCall(coroFree, { coroIdRet, coro.handle }); + llvm::Value *needFree = m_builder.CreateIsNotNull(mem); + m_builder.CreateCondBr(needFree, freeBranch, coro.suspend); + + m_builder.SetInsertPoint(entry); + return coro; +} + +void LLVMCodeBuilder::verifyFunction(llvm::Function *func) +{ + if (llvm::verifyFunction(*func, &llvm::errs())) { + llvm::errs() << "error: LLVM function verficiation failed!\n"; + llvm::errs() << "script hat ID: " << m_id << "\n"; + } +} + +void LLVMCodeBuilder::optimize() +{ + llvm::PassBuilder passBuilder; + llvm::LoopAnalysisManager loopAnalysisManager; + llvm::FunctionAnalysisManager functionAnalysisManager; + llvm::CGSCCAnalysisManager cGSCCAnalysisManager; + llvm::ModuleAnalysisManager moduleAnalysisManager; + + passBuilder.registerModuleAnalyses(moduleAnalysisManager); + passBuilder.registerCGSCCAnalyses(cGSCCAnalysisManager); + passBuilder.registerFunctionAnalyses(functionAnalysisManager); + passBuilder.registerLoopAnalyses(loopAnalysisManager); + passBuilder.crossRegisterProxies(loopAnalysisManager, functionAnalysisManager, cGSCCAnalysisManager, moduleAnalysisManager); + + llvm::ModulePassManager modulePassManager = passBuilder.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); + modulePassManager.run(*m_module, moduleAnalysisManager); +} + void LLVMCodeBuilder::freeHeap() { // Free dynamically allocated memory diff --git a/src/dev/engine/internal/llvmcodebuilder.h b/src/dev/engine/internal/llvmcodebuilder.h index 0eba59f2..c61404fb 100644 --- a/src/dev/engine/internal/llvmcodebuilder.h +++ b/src/dev/engine/internal/llvmcodebuilder.h @@ -17,7 +17,7 @@ class Target; class LLVMCodeBuilder : public ICodeBuilder { public: - LLVMCodeBuilder(const std::string &id); + LLVMCodeBuilder(const std::string &id, bool warp); std::shared_ptr finalize() override; @@ -98,8 +98,25 @@ class LLVMCodeBuilder : public ICodeBuilder llvm::BasicBlock *afterLoop = nullptr; }; + struct Coroutine + { + llvm::Value *handle = nullptr; + llvm::BasicBlock *suspend = nullptr; + llvm::BasicBlock *cleanup = nullptr; + }; + + struct Procedure + { + // TODO: Implement procedures + bool warp = false; + }; + void initTypes(); + Coroutine initCoroutine(llvm::Function *func); + void verifyFunction(llvm::Function *func); + void optimize(); + void freeHeap(); llvm::Value *castValue(std::shared_ptr reg, Compiler::StaticType targetType); llvm::Value *castRawValue(std::shared_ptr reg, Compiler::StaticType targetType); @@ -134,6 +151,8 @@ class LLVMCodeBuilder : public ICodeBuilder std::vector m_constValues; std::vector>> m_regs; std::vector> m_tmpRegs; + bool m_defaultWarp = false; + bool m_warp = false; std::vector m_heap; diff --git a/src/dev/engine/internal/llvmexecutablecode.cpp b/src/dev/engine/internal/llvmexecutablecode.cpp index 6e40cb4d..6c2059ba 100644 --- a/src/dev/engine/internal/llvmexecutablecode.cpp +++ b/src/dev/engine/internal/llvmexecutablecode.cpp @@ -33,31 +33,51 @@ LLVMExecutableCode::LLVMExecutableCode(std::unique_ptr module) : // Lookup functions m_mainFunction = (MainFunctionType)lookupFunction("f"); assert(m_mainFunction); + m_resumeFunction = (ResumeFunctionType)lookupFunction("resume"); + assert(m_resumeFunction); } void LLVMExecutableCode::run(ExecutionContext *context) { LLVMExecutionContext *ctx = getContext(context); - if (!ctx->finished) { - m_mainFunction(context->target()); - ctx->finished = true; + if (ctx->finished()) + return; + + if (ctx->coroutineHandle()) { + bool done = m_resumeFunction(ctx->coroutineHandle()); + + if (done) + ctx->setCoroutineHandle(nullptr); + + ctx->setFinished(done); + } else { + void *handle = m_mainFunction(context->target()); + + if (!handle) + ctx->setFinished(true); + + ctx->setCoroutineHandle(handle); } } void LLVMExecutableCode::kill(ExecutionContext *context) { - getContext(context)->finished = true; + LLVMExecutionContext *ctx = getContext(context); + ctx->setCoroutineHandle(nullptr); + ctx->setFinished(true); } void LLVMExecutableCode::reset(ExecutionContext *context) { - getContext(context)->finished = false; + LLVMExecutionContext *ctx = getContext(context); + ctx->setCoroutineHandle(nullptr); + ctx->setFinished(false); } bool LLVMExecutableCode::isFinished(ExecutionContext *context) const { - return getContext(context)->finished; + return getContext(context)->finished(); } void LLVMExecutableCode::promise() diff --git a/src/dev/engine/internal/llvmexecutablecode.h b/src/dev/engine/internal/llvmexecutablecode.h index ce9fd371..4094fee8 100644 --- a/src/dev/engine/internal/llvmexecutablecode.h +++ b/src/dev/engine/internal/llvmexecutablecode.h @@ -32,7 +32,8 @@ class LLVMExecutableCode : public ExecutableCode private: uint64_t lookupFunction(const std::string &name); - using MainFunctionType = size_t (*)(Target *); + using MainFunctionType = void *(*)(Target *); + using ResumeFunctionType = bool (*)(void *); static LLVMExecutionContext *getContext(ExecutionContext *context); @@ -40,6 +41,7 @@ class LLVMExecutableCode : public ExecutableCode llvm::Expected> m_jit; MainFunctionType m_mainFunction; + ResumeFunctionType m_resumeFunction; }; } // namespace libscratchcpp diff --git a/src/dev/engine/internal/llvmexecutioncontext.cpp b/src/dev/engine/internal/llvmexecutioncontext.cpp index b399541d..5934a90a 100644 --- a/src/dev/engine/internal/llvmexecutioncontext.cpp +++ b/src/dev/engine/internal/llvmexecutioncontext.cpp @@ -8,3 +8,23 @@ LLVMExecutionContext::LLVMExecutionContext(Target *target) : ExecutionContext(target) { } + +void *LLVMExecutionContext::coroutineHandle() const +{ + return m_coroutineHandle; +} + +void LLVMExecutionContext::setCoroutineHandle(void *newCoroutineHandle) +{ + m_coroutineHandle = newCoroutineHandle; +} + +bool LLVMExecutionContext::finished() const +{ + return m_finished; +} + +void LLVMExecutionContext::setFinished(bool newFinished) +{ + m_finished = newFinished; +} diff --git a/src/dev/engine/internal/llvmexecutioncontext.h b/src/dev/engine/internal/llvmexecutioncontext.h index e789aeea..c6357819 100644 --- a/src/dev/engine/internal/llvmexecutioncontext.h +++ b/src/dev/engine/internal/llvmexecutioncontext.h @@ -12,7 +12,15 @@ class LLVMExecutionContext : public ExecutionContext public: LLVMExecutionContext(Target *target); - bool finished = false; // TODO: Remove this + void *coroutineHandle() const; + void setCoroutineHandle(void *newCoroutineHandle); + + bool finished() const; + void setFinished(bool newFinished); + + private: + void *m_coroutineHandle = nullptr; + bool m_finished = false; }; } // namespace libscratchcpp diff --git a/test/dev/compiler/compiler_test.cpp b/test/dev/compiler/compiler_test.cpp index 9a461bb5..e96d2cc5 100644 --- a/test/dev/compiler/compiler_test.cpp +++ b/test/dev/compiler/compiler_test.cpp @@ -41,7 +41,7 @@ class CompilerTest : public testing::Test { ASSERT_EQ(compiler.block(), nullptr); // TODO: Test warp - EXPECT_CALL(m_builderFactory, create(block->id())).WillOnce(Return(m_builder)); + EXPECT_CALL(m_builderFactory, create(block->id(), false)).WillOnce(Return(m_builder)); EXPECT_CALL(*m_builder, finalize()).WillOnce(Return(m_code)); ASSERT_EQ(compiler.compile(block), m_code); ASSERT_EQ(compiler.block(), nullptr); diff --git a/test/dev/llvm/llvmcodebuilder_test.cpp b/test/dev/llvm/llvmcodebuilder_test.cpp index 1c2048e4..656acdeb 100644 --- a/test/dev/llvm/llvmcodebuilder_test.cpp +++ b/test/dev/llvm/llvmcodebuilder_test.cpp @@ -15,52 +15,60 @@ class LLVMCodeBuilderTest : public testing::Test public: void SetUp() override { - m_builder = std::make_unique("test"); test_function(nullptr, nullptr); // force dependency } + void createBuilder(bool warp) { m_builder = std::make_unique("test", warp); } + std::unique_ptr m_builder; TargetMock m_target; // NOTE: isStage() is used for call expectations }; TEST_F(LLVMCodeBuilderTest, FunctionCalls) { - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); + static const std::vector warpList = { false, true }; - m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); + for (bool warp : warpList) { + createBuilder(warp); + m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); - m_builder->addConstValue("1"); - m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }); - m_builder->addConstValue("2"); - m_builder->addConstValue("3"); - m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); + m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}); + m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); - m_builder->addConstValue("test"); - m_builder->addConstValue("4"); - m_builder->addConstValue("5"); - m_builder->addFunctionCall("test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); - auto code = m_builder->finalize(); - auto ctx = code->createExecutionContext(&m_target); + m_builder->addConstValue("1"); + m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }); + m_builder->addConstValue("2"); + m_builder->addConstValue("3"); + m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); - static const std::string expected = - "no_args\n" - "no_args_ret\n" - "1_arg no_args_output\n" - "1_arg_ret 1\n" - "3_args 1_arg_output 2 3\n" - "3_args test 4 5\n" - "1_arg 3_args_output\n"; - - EXPECT_CALL(m_target, isStage()).Times(7); - testing::internal::CaptureStdout(); - code->run(ctx.get()); - ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + m_builder->addConstValue("test"); + m_builder->addConstValue("4"); + m_builder->addConstValue("5"); + m_builder->addFunctionCall("test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); + m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); + auto code = m_builder->finalize(); + auto ctx = code->createExecutionContext(&m_target); + + static const std::string expected = + "no_args\n" + "no_args_ret\n" + "1_arg no_args_output\n" + "1_arg_ret 1\n" + "3_args 1_arg_output 2 3\n" + "3_args test 4 5\n" + "1_arg 3_args_output\n"; + + EXPECT_CALL(m_target, isStage()).Times(7); + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + } } TEST_F(LLVMCodeBuilderTest, ConstCasting) { + createBuilder(true); + m_builder->addConstValue(5.2); m_builder->addFunctionCall("test_print_number", Compiler::StaticType::Void, { Compiler::StaticType::Number }); m_builder->addConstValue("-24.156"); @@ -100,6 +108,8 @@ TEST_F(LLVMCodeBuilderTest, ConstCasting) TEST_F(LLVMCodeBuilderTest, RawValueCasting) { + createBuilder(true); + // Number -> number m_builder->addConstValue(5.2); m_builder->addFunctionCall("test_const_number", Compiler::StaticType::Number, { Compiler::StaticType::Number }); @@ -189,26 +199,33 @@ TEST_F(LLVMCodeBuilderTest, RawValueCasting) ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); } -/*TEST_F(LLVMCodeBuilderTest, Yield) +TEST_F(LLVMCodeBuilderTest, Yield) { - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); + auto build = [this]() { + m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); - m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); + m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}); + m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); - m_builder->yield(); + m_builder->yield(); - m_builder->addConstValue("1"); - m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }); - m_builder->addConstValue("2"); - m_builder->addConstValue(3); - m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); + m_builder->addConstValue("1"); + m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }); + m_builder->addConstValue("2"); + m_builder->addConstValue(3); + m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); + + m_builder->addConstValue("test"); + m_builder->addConstValue("4"); + m_builder->addConstValue("5"); + m_builder->addFunctionCall("test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); + m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); + }; + + // Without warp + createBuilder(false); + build(); - m_builder->addConstValue("test"); - m_builder->addConstValue("4"); - m_builder->addConstValue("5"); - m_builder->addFunctionCall("test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }); auto code = m_builder->finalize(); auto ctx = code->createExecutionContext(&m_target); @@ -234,10 +251,32 @@ TEST_F(LLVMCodeBuilderTest, RawValueCasting) code->run(ctx.get()); ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2); ASSERT_TRUE(code->isFinished(ctx.get())); -}*/ + + // With warp + createBuilder(true); + build(); + code = m_builder->finalize(); + ctx = code->createExecutionContext(&m_target); + + static const std::string expected = + "no_args\n" + "no_args_ret\n" + "1_arg no_args_output\n" + "1_arg_ret 1\n" + "3_args 1_arg_output 2 3\n" + "3_args test 4 5\n" + "1_arg 3_args_output\n"; + + EXPECT_CALL(m_target, isStage()).Times(7); + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); +} TEST_F(LLVMCodeBuilderTest, IfStatement) { + createBuilder(true); + // Without else branch (const condition) m_builder->addConstValue("true"); m_builder->beginIfStatement(); @@ -405,8 +444,11 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); } +// TODO: Write a test for count rounding TEST_F(LLVMCodeBuilderTest, RepeatLoop) { + createBuilder(true); + // Const count m_builder->addConstValue("-5"); m_builder->beginRepeatLoop(); @@ -489,10 +531,39 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop) testing::internal::CaptureStdout(); code->run(ctx.get()); ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + + // Yield + createBuilder(false); + + m_builder->addConstValue(3); + m_builder->beginRepeatLoop(); + m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); + m_builder->endLoop(); + + code = m_builder->finalize(); + ctx = code->createExecutionContext(&m_target); + + static const std::string expected1 = "no_args\n"; + + EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false)); + + for (int i = 0; i < 3; i++) { + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected1); + ASSERT_FALSE(code->isFinished(ctx.get())); + } + + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_TRUE(testing::internal::GetCapturedStdout().empty()); + ASSERT_TRUE(code->isFinished(ctx.get())); } TEST_F(LLVMCodeBuilderTest, WhileLoop) { + createBuilder(true); + // Const condition m_builder->beginLoopCondition(); m_builder->addConstValue("false"); @@ -566,10 +637,35 @@ TEST_F(LLVMCodeBuilderTest, WhileLoop) testing::internal::CaptureStdout(); code->run(ctx.get()); ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + + // Yield + createBuilder(false); + + m_builder->beginLoopCondition(); + m_builder->addConstValue(true); + m_builder->beginWhileLoop(); + m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); + m_builder->endLoop(); + + code = m_builder->finalize(); + ctx = code->createExecutionContext(&m_target); + + static const std::string expected1 = "no_args\n"; + + EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false)); + + for (int i = 0; i < 10; i++) { + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected1); + ASSERT_FALSE(code->isFinished(ctx.get())); + } } TEST_F(LLVMCodeBuilderTest, RepeatUntilLoop) { + createBuilder(true); + // Const condition m_builder->beginLoopCondition(); m_builder->addConstValue("true"); @@ -646,4 +742,27 @@ TEST_F(LLVMCodeBuilderTest, RepeatUntilLoop) testing::internal::CaptureStdout(); code->run(ctx.get()); ASSERT_EQ(testing::internal::GetCapturedStdout(), expected); + + // Yield + createBuilder(false); + + m_builder->beginLoopCondition(); + m_builder->addConstValue(false); + m_builder->beginRepeatUntilLoop(); + m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}); + m_builder->endLoop(); + + code = m_builder->finalize(); + ctx = code->createExecutionContext(&m_target); + + static const std::string expected1 = "no_args\n"; + + EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false)); + + for (int i = 0; i < 10; i++) { + testing::internal::CaptureStdout(); + code->run(ctx.get()); + ASSERT_EQ(testing::internal::GetCapturedStdout(), expected1); + ASSERT_FALSE(code->isFinished(ctx.get())); + } } diff --git a/test/dev/llvm/llvmexecutablecode_test.cpp b/test/dev/llvm/llvmexecutablecode_test.cpp index 7c5e818b..fea3f778 100644 --- a/test/dev/llvm/llvmexecutablecode_test.cpp +++ b/test/dev/llvm/llvmexecutablecode_test.cpp @@ -25,10 +25,13 @@ class LLVMExecutableCodeTest : public testing::Test llvm::InitializeNativeTargetAsmParser(); } - llvm::Function *beginFunction() + inline llvm::Constant *nullPointer() { return llvm::ConstantPointerNull::get(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0)); } + + llvm::Function *beginMainFunction() { - // void f(Target *) - llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder->getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); + // void *f(Target *) + llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); + llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, pointerType, false); llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); @@ -36,7 +39,18 @@ class LLVMExecutableCodeTest : public testing::Test return func; } - void endFunction() { m_builder->CreateRetVoid(); } + llvm::Function *beginResumeFunction() + { + // bool f(void *) + llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder->getInt1Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false); + llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "resume", m_module.get()); + + llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); + m_builder->SetInsertPoint(entry); + return func; + } + + void endFunction(llvm::Value *ret) { m_builder->CreateRet(ret); } void addTestFunction(llvm::Function *mainFunc) { @@ -65,8 +79,12 @@ class LLVMExecutableCodeTest : public testing::Test TEST_F(LLVMExecutableCodeTest, CreateExecutionContext) { - beginFunction(); - endFunction(); + beginMainFunction(); + endFunction(nullPointer()); + + beginResumeFunction(); + endFunction(m_builder->getInt1(true)); + LLVMExecutableCode code(std::move(m_module)); auto ctx = code.createExecutionContext(&m_target); ASSERT_TRUE(ctx); @@ -76,9 +94,12 @@ TEST_F(LLVMExecutableCodeTest, CreateExecutionContext) TEST_F(LLVMExecutableCodeTest, MainFunction) { - auto f = beginFunction(); + auto f = beginMainFunction(); addTestFunction(f); - endFunction(); + endFunction(nullPointer()); + + beginResumeFunction(); + endFunction(m_builder->getInt1(true)); LLVMExecutableCode code(std::move(m_module)); auto ctx = code.createExecutionContext(&m_target); diff --git a/test/mocks/codebuilderfactorymock.h b/test/mocks/codebuilderfactorymock.h index d8a030c5..7bf04a5b 100644 --- a/test/mocks/codebuilderfactorymock.h +++ b/test/mocks/codebuilderfactorymock.h @@ -8,5 +8,5 @@ using namespace libscratchcpp; class CodeBuilderFactoryMock : public ICodeBuilderFactory { public: - MOCK_METHOD(std::shared_ptr, create, (const std::string &), (const, override)); + MOCK_METHOD(std::shared_ptr, create, (const std::string &, bool), (const, override)); }; diff --git a/test/scratch_classes/inputvalue_test.cpp b/test/scratch_classes/inputvalue_test.cpp index e58e6b49..d26dd9d5 100644 --- a/test/scratch_classes/inputvalue_test.cpp +++ b/test/scratch_classes/inputvalue_test.cpp @@ -130,7 +130,7 @@ TEST(InputValueTest, Compile) Compiler compiler(&engine, &target); auto block = std::make_shared("", ""); - EXPECT_CALL(builderFactory, create(_)).WillOnce(Return(builder)); + EXPECT_CALL(builderFactory, create(_, _)).WillOnce(Return(builder)); EXPECT_CALL(*builder, finalize); compiler.compile(block); CompilerPrivate::builderFactory = nullptr;