Skip to content

Commit 6d9a04d

Browse files
committed
Drop multiple functions in compiled code
1 parent d94bdc7 commit 6d9a04d

File tree

9 files changed

+76
-155
lines changed

9 files changed

+76
-155
lines changed

src/dev/engine/internal/llvmcodebuilder.cpp

Lines changed: 38 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,14 @@ LLVMCodeBuilder::LLVMCodeBuilder(const std::string &id) :
3131

3232
std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
3333
{
34-
size_t functionIndex = 0;
35-
llvm::Function *currentFunc = beginFunction(functionIndex);
34+
// Create function
35+
// void f(Target *)
36+
llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getVoidTy(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false);
37+
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get());
38+
39+
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func);
40+
m_builder.SetInsertPoint(entry);
41+
3642
std::vector<IfStatement> ifStatements;
3743
std::vector<Loop> loops;
3844
m_heap.clear();
@@ -45,9 +51,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
4551
std::vector<llvm::Value *> args;
4652

4753
// Add target pointer arg
48-
assert(currentFunc->arg_size() == 1);
54+
assert(func->arg_size() == 1);
4955
types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0));
50-
args.push_back(currentFunc->getArg(0));
56+
args.push_back(func->getArg(0));
5157

5258
// Args
5359
for (auto &arg : step.args) {
@@ -69,14 +75,13 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
6975

7076
case Step::Type::Yield:
7177
freeHeap();
72-
endFunction(currentFunc, functionIndex);
73-
currentFunc = beginFunction(++functionIndex);
78+
// TODO: Implement yielding
7479
break;
7580

7681
case Step::Type::BeginIf: {
7782
IfStatement statement;
7883
statement.beforeIf = m_builder.GetInsertBlock();
79-
statement.body = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
84+
statement.body = llvm::BasicBlock::Create(m_ctx, "", func);
8085

8186
// Use last reg
8287
assert(step.args.size() == 1);
@@ -98,13 +103,13 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
98103

99104
// Jump to the branch after the if statement
100105
assert(!statement.afterIf);
101-
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
106+
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
102107
freeHeap();
103108
m_builder.CreateBr(statement.afterIf);
104109

105110
// Create else branch
106111
assert(!statement.elseBranch);
107-
statement.elseBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
112+
statement.elseBranch = llvm::BasicBlock::Create(m_ctx, "", func);
108113

109114
// Since there's an else branch, the conditional instruction should jump to it
110115
m_builder.SetInsertPoint(statement.beforeIf);
@@ -121,7 +126,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
121126

122127
// Jump to the branch after the if statement
123128
if (!statement.afterIf)
124-
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
129+
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
125130

126131
freeHeap();
127132
m_builder.CreateBr(statement.afterIf);
@@ -150,9 +155,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
150155
m_builder.CreateStore(zero, loop.index);
151156

152157
// Create branches
153-
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
154-
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
155-
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
158+
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_ctx, "", func);
159+
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", func);
160+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func);
156161

157162
// Use last reg for count
158163
assert(step.args.size() == 1);
@@ -177,10 +182,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
177182
// Check index
178183
m_builder.SetInsertPoint(loop.conditionBranch);
179184

180-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
185+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func);
181186

182187
if (!loop.afterLoop)
183-
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
188+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func);
184189

185190
llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index);
186191
comparison = m_builder.CreateICmpULT(currentIndex, count);
@@ -198,8 +203,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
198203
Loop &loop = loops.back();
199204

200205
// Create branches
201-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
202-
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
206+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func);
207+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func);
203208

204209
// Use last reg
205210
assert(step.args.size() == 1);
@@ -219,8 +224,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
219224
Loop &loop = loops.back();
220225

221226
// Create branches
222-
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
223-
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
227+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", func);
228+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", func);
224229

225230
// Use last reg
226231
assert(step.args.size() == 1);
@@ -238,7 +243,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
238243
case Step::Type::BeginLoopCondition: {
239244
Loop loop;
240245
loop.isRepeatLoop = false;
241-
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
246+
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", func);
242247
freeHeap();
243248
m_builder.CreateBr(loop.conditionBranch);
244249
m_builder.SetInsertPoint(loop.conditionBranch);
@@ -272,7 +277,19 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
272277

273278
freeHeap();
274279

275-
endFunction(currentFunc, functionIndex);
280+
// End and verify the function
281+
if (!m_tmpRegs.empty()) {
282+
std::cout
283+
<< "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str()
284+
<< "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
285+
}
286+
287+
m_builder.CreateRetVoid();
288+
289+
if (llvm::verifyFunction(*func, &llvm::errs())) {
290+
llvm::errs() << "error: LLVM function verficiation failed!\n";
291+
llvm::errs() << "script hat ID: " << m_id << "\n";
292+
}
276293

277294
#ifdef PRINT_LLVM_IR
278295
std::cout << std::endl << "=== LLVM IR (" << m_module->getName().str() << ") ===" << std::endl;
@@ -414,36 +431,6 @@ void LLVMCodeBuilder::initTypes()
414431
m_valueDataType->setBody({ unionType, valueType, sizeType });
415432
}
416433

417-
llvm::Function *LLVMCodeBuilder::beginFunction(size_t index)
418-
{
419-
// size_t f#(Target *)
420-
llvm::FunctionType *funcType = llvm::FunctionType::get(m_builder.getInt64Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false);
421-
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f" + std::to_string(index), m_module.get());
422-
423-
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func);
424-
m_builder.SetInsertPoint(entry);
425-
426-
return func;
427-
}
428-
429-
void LLVMCodeBuilder::endFunction(llvm::Function *func, size_t index)
430-
{
431-
if (!m_tmpRegs.empty()) {
432-
std::cout
433-
<< "warning: " << m_tmpRegs.size() << " registers were leaked by script '" << m_module->getName().str() << "', function '" << func->getName().str()
434-
<< "' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
435-
}
436-
437-
// Return next function index
438-
m_builder.CreateRet(m_builder.getInt64(index + 1));
439-
440-
if (llvm::verifyFunction(*func, &llvm::errs())) {
441-
llvm::errs() << "error: LLVM function verficiation failed!\n";
442-
llvm::errs() << "script hat ID: " << m_id << "\n";
443-
llvm::errs() << "function name: " << func->getName().data() << "\n";
444-
}
445-
}
446-
447434
void LLVMCodeBuilder::freeHeap()
448435
{
449436
// Free dynamically allocated memory

src/dev/engine/internal/llvmcodebuilder.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ class LLVMCodeBuilder : public ICodeBuilder
9999
};
100100

101101
void initTypes();
102-
llvm::Function *beginFunction(size_t index);
103-
void endFunction(llvm::Function *func, size_t index);
104102

105103
void freeHeap();
106104
llvm::Value *castValue(std::shared_ptr<Register> reg, Compiler::StaticType targetType);

src/dev/engine/internal/llvmexecutablecode.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,44 +31,33 @@ LLVMExecutableCode::LLVMExecutableCode(std::unique_ptr<llvm::Module> module) :
3131
}
3232

3333
// Lookup functions
34-
size_t i = 0;
35-
36-
while (true) {
37-
auto func = m_jit->get()->lookup("f" + std::to_string(i));
38-
39-
if (func)
40-
m_functions.push_back((FunctionType)(func->getValue()));
41-
else {
42-
// Ignore error
43-
llvm::consumeError(func.takeError());
44-
break;
45-
}
46-
47-
i++;
48-
}
34+
m_mainFunction = (MainFunctionType)lookupFunction("f");
35+
assert(m_mainFunction);
4936
}
5037

5138
void LLVMExecutableCode::run(ExecutionContext *context)
5239
{
5340
LLVMExecutionContext *ctx = getContext(context);
5441

55-
if (ctx->pos() < m_functions.size())
56-
ctx->setPos(m_functions[ctx->pos()](context->target()));
42+
if (!ctx->finished) {
43+
m_mainFunction(context->target());
44+
ctx->finished = true;
45+
}
5746
}
5847

5948
void LLVMExecutableCode::kill(ExecutionContext *context)
6049
{
61-
getContext(context)->setPos(m_functions.size());
50+
getContext(context)->finished = true;
6251
}
6352

6453
void LLVMExecutableCode::reset(ExecutionContext *context)
6554
{
66-
getContext(context)->setPos(0);
55+
getContext(context)->finished = false;
6756
}
6857

6958
bool LLVMExecutableCode::isFinished(ExecutionContext *context) const
7059
{
71-
return getContext(context)->pos() >= m_functions.size();
60+
return getContext(context)->finished;
7261
}
7362

7463
void LLVMExecutableCode::promise()
@@ -84,6 +73,18 @@ std::shared_ptr<ExecutionContext> LLVMExecutableCode::createExecutionContext(Tar
8473
return std::make_shared<LLVMExecutionContext>(target);
8574
}
8675

76+
uint64_t LLVMExecutableCode::lookupFunction(const std::string &name)
77+
{
78+
auto func = m_jit->get()->lookup(name);
79+
80+
if (func)
81+
return func->getValue();
82+
else {
83+
llvm::errs() << "error: failed to lookup LLVM function: " << toString(func.takeError()) << "\n";
84+
return 0;
85+
}
86+
}
87+
8788
LLVMExecutionContext *LLVMExecutableCode::getContext(ExecutionContext *context)
8889
{
8990
assert(dynamic_cast<LLVMExecutionContext *>(context));

src/dev/engine/internal/llvmexecutablecode.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,16 @@ class LLVMExecutableCode : public ExecutableCode
3030
std::shared_ptr<ExecutionContext> createExecutionContext(Target *target) const override;
3131

3232
private:
33-
using FunctionType = size_t (*)(Target *);
33+
uint64_t lookupFunction(const std::string &name);
34+
35+
using MainFunctionType = size_t (*)(Target *);
3436

3537
static LLVMExecutionContext *getContext(ExecutionContext *context);
3638

3739
std::unique_ptr<llvm::LLVMContext> m_ctx;
3840
llvm::Expected<std::unique_ptr<llvm::orc::LLJIT>> m_jit;
3941

40-
std::vector<FunctionType> m_functions;
42+
MainFunctionType m_mainFunction;
4143
};
4244

4345
} // namespace libscratchcpp

src/dev/engine/internal/llvmexecutioncontext.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,3 @@ LLVMExecutionContext::LLVMExecutionContext(Target *target) :
88
ExecutionContext(target)
99
{
1010
}
11-
12-
size_t LLVMExecutionContext::pos() const
13-
{
14-
return m_pos;
15-
}
16-
17-
void LLVMExecutionContext::setPos(size_t newPos)
18-
{
19-
m_pos = newPos;
20-
}

src/dev/engine/internal/llvmexecutioncontext.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@ class LLVMExecutionContext : public ExecutionContext
1212
public:
1313
LLVMExecutionContext(Target *target);
1414

15-
size_t pos() const;
16-
void setPos(size_t newPos);
17-
18-
private:
19-
size_t m_pos = 0;
15+
bool finished = false; // TODO: Remove this
2016
};
2117

2218
} // namespace libscratchcpp

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ TEST_F(LLVMCodeBuilderTest, RawValueCasting)
189189
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);
190190
}
191191

192-
TEST_F(LLVMCodeBuilderTest, Yield)
192+
/*TEST_F(LLVMCodeBuilderTest, Yield)
193193
{
194194
m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {});
195195
@@ -234,7 +234,7 @@ TEST_F(LLVMCodeBuilderTest, Yield)
234234
code->run(ctx.get());
235235
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2);
236236
ASSERT_TRUE(code->isFinished(ctx.get()));
237-
}
237+
}*/
238238

239239
TEST_F(LLVMCodeBuilderTest, IfStatement)
240240
{

0 commit comments

Comments
 (0)