Skip to content

Commit 24333c7

Browse files
committed
LLVMCodeBuilder: Implement repeat loop
1 parent ab5d461 commit 24333c7

File tree

5 files changed

+187
-1
lines changed

5 files changed

+187
-1
lines changed

src/dev/engine/internal/llvmcodebuilder.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
3939
size_t functionIndex = 0;
4040
llvm::Function *currentFunc = beginFunction(functionIndex);
4141
std::vector<IfStatement> ifStatements;
42+
std::vector<Loop> loops;
4243

4344
// Execute recorded steps
4445
for (const Step &step : m_steps) {
@@ -134,6 +135,77 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
134135
ifStatements.pop_back();
135136
break;
136137
}
138+
139+
case Step::Type::BeginRepeatLoop: {
140+
Loop loop;
141+
loop.isRepeatLoop = true;
142+
143+
// index = 0
144+
llvm::Constant *zero = llvm::ConstantInt::get(m_builder.getInt64Ty(), 0, true);
145+
loop.index = m_builder.CreateAlloca(m_builder.getInt64Ty());
146+
m_builder.CreateStore(zero, loop.index);
147+
148+
// Create branches
149+
llvm::BasicBlock *roundBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
150+
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
151+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
152+
153+
// Convert last reg to double
154+
assert(step.args.size() == 1);
155+
llvm::Value *count = m_builder.CreateCall(resolve_value_toDouble(), step.args[0]->value);
156+
157+
// Clamp count if <= 0 (we can skip the loop if count is not positive)
158+
llvm::Value *comparison = m_builder.CreateFCmpULE(count, llvm::ConstantFP::get(m_ctx, llvm::APFloat(0.0)));
159+
m_builder.CreateCondBr(comparison, loop.afterLoop, roundBranch);
160+
161+
// Round (Scratch-specific behavior)
162+
m_builder.SetInsertPoint(roundBranch);
163+
llvm::Function *roundFunc = llvm::Intrinsic::getDeclaration(m_module.get(), llvm::Intrinsic::round, { count->getType() });
164+
count = m_builder.CreateCall(roundFunc, { count });
165+
count = m_builder.CreateFPToSI(count, m_builder.getInt64Ty()); // cast to signed integer
166+
167+
// Jump to condition branch
168+
m_builder.CreateBr(loop.conditionBranch);
169+
170+
// Check index
171+
m_builder.SetInsertPoint(loop.conditionBranch);
172+
173+
llvm::BasicBlock *body = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
174+
175+
if (!loop.afterLoop)
176+
loop.afterLoop = llvm::BasicBlock::Create(m_ctx, "", currentFunc);
177+
178+
llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index);
179+
comparison = m_builder.CreateICmpULT(currentIndex, count);
180+
m_builder.CreateCondBr(comparison, body, loop.afterLoop);
181+
182+
// Switch to body branch
183+
m_builder.SetInsertPoint(body);
184+
185+
loops.push_back(loop);
186+
break;
187+
}
188+
189+
case Step::Type::EndLoop: {
190+
assert(!loops.empty());
191+
Loop &loop = loops.back();
192+
193+
if (loop.isRepeatLoop) {
194+
// Increment index
195+
llvm::Value *currentIndex = m_builder.CreateLoad(m_builder.getInt64Ty(), loop.index);
196+
llvm::Value *incremented = m_builder.CreateAdd(currentIndex, llvm::ConstantInt::get(m_builder.getInt64Ty(), 1, true));
197+
m_builder.CreateStore(incremented, loop.index);
198+
}
199+
200+
// Jump to the condition branch
201+
m_builder.CreateBr(loop.conditionBranch);
202+
203+
// Switch to the branch after the loop
204+
m_builder.SetInsertPoint(loop.afterLoop);
205+
206+
loops.pop_back();
207+
break;
208+
}
137209
}
138210
}
139211

@@ -222,10 +294,16 @@ void LLVMCodeBuilder::endIf()
222294

223295
void LLVMCodeBuilder::beginRepeatLoop()
224296
{
297+
Step step(Step::Type::BeginRepeatLoop);
298+
assert(!m_tmpRegs.empty());
299+
step.args.push_back(m_tmpRegs.back());
300+
m_tmpRegs.pop_back();
301+
m_steps.push_back(step);
225302
}
226303

227304
void LLVMCodeBuilder::endLoop()
228305
{
306+
m_steps.push_back(Step(Step::Type::EndLoop));
229307
}
230308

231309
void LLVMCodeBuilder::yield()
@@ -360,6 +438,11 @@ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_assign_special()
360438
return resolveFunction("value_assign_special", llvm::FunctionType::get(m_builder.getVoidTy(), { m_valueDataType->getPointerTo(), m_builder.getInt32Ty() }, false));
361439
}
362440

441+
llvm::FunctionCallee LLVMCodeBuilder::resolve_value_toDouble()
442+
{
443+
return resolveFunction("value_toDouble", llvm::FunctionType::get(m_builder.getDoubleTy(), m_valueDataType->getPointerTo(), false));
444+
}
445+
363446
llvm::FunctionCallee LLVMCodeBuilder::resolve_value_toBool()
364447
{
365448
return resolveFunction("value_toBool", llvm::FunctionType::get(m_builder.getInt1Ty(), m_valueDataType->getPointerTo(), false));

src/dev/engine/internal/llvmcodebuilder.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ class LLVMCodeBuilder : public ICodeBuilder
5252
Yield,
5353
BeginIf,
5454
BeginElse,
55-
EndIf
55+
EndIf,
56+
BeginRepeatLoop,
57+
EndLoop
5658
};
5759

5860
Step(Type type) :
@@ -76,6 +78,14 @@ class LLVMCodeBuilder : public ICodeBuilder
7678
llvm::BasicBlock *afterIf = nullptr;
7779
};
7880

81+
struct Loop
82+
{
83+
bool isRepeatLoop = false;
84+
llvm::Value *index = nullptr;
85+
llvm::BasicBlock *conditionBranch = nullptr;
86+
llvm::BasicBlock *afterLoop = nullptr;
87+
};
88+
7989
void initTypes();
8090
llvm::Function *beginFunction(size_t index);
8191
void endFunction(llvm::Function *func, size_t index);
@@ -88,6 +98,7 @@ class LLVMCodeBuilder : public ICodeBuilder
8898
llvm::FunctionCallee resolve_value_assign_bool();
8999
llvm::FunctionCallee resolve_value_assign_cstring();
90100
llvm::FunctionCallee resolve_value_assign_special();
101+
llvm::FunctionCallee resolve_value_toDouble();
91102
llvm::FunctionCallee resolve_value_toBool();
92103

93104
std::string m_id;

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,89 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
274274
code->run(ctx.get());
275275
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);
276276
}
277+
278+
TEST_F(LLVMCodeBuilderTest, RepeatLoop)
279+
{
280+
// Const count
281+
m_builder->addConstValue("-5");
282+
m_builder->beginRepeatLoop();
283+
m_builder->addFunctionCall("test_function_no_args", 0, false);
284+
m_builder->endLoop();
285+
286+
m_builder->addConstValue(0);
287+
m_builder->beginRepeatLoop();
288+
m_builder->addFunctionCall("test_function_no_args", 0, false);
289+
m_builder->endLoop();
290+
291+
m_builder->addConstValue(3);
292+
m_builder->beginRepeatLoop();
293+
m_builder->addFunctionCall("test_function_no_args", 0, false);
294+
m_builder->endLoop();
295+
296+
m_builder->addConstValue("2");
297+
m_builder->beginRepeatLoop();
298+
m_builder->addConstValue(0);
299+
m_builder->addFunctionCall("test_function_1_arg", 1, false);
300+
m_builder->endLoop();
301+
302+
// Count returned by function
303+
m_builder->addConstValue(2);
304+
m_builder->addFunctionCall("test_const", 1, true);
305+
m_builder->beginRepeatLoop();
306+
m_builder->addFunctionCall("test_function_no_args", 0, false);
307+
m_builder->endLoop();
308+
309+
// Nested
310+
m_builder->addConstValue(2);
311+
m_builder->beginRepeatLoop();
312+
{
313+
m_builder->addConstValue(2);
314+
m_builder->beginRepeatLoop();
315+
{
316+
m_builder->addConstValue(1);
317+
m_builder->addFunctionCall("test_function_1_arg", 1, false);
318+
}
319+
m_builder->endLoop();
320+
321+
m_builder->addConstValue(2);
322+
m_builder->addFunctionCall("test_function_1_arg", 1, false);
323+
324+
m_builder->addConstValue(3);
325+
m_builder->beginRepeatLoop();
326+
{
327+
m_builder->addConstValue(3);
328+
m_builder->addFunctionCall("test_function_1_arg", 1, false);
329+
}
330+
m_builder->endLoop();
331+
}
332+
m_builder->endLoop();
333+
334+
auto code = m_builder->finalize();
335+
auto ctx = code->createExecutionContext(&m_target);
336+
337+
static const std::string expected =
338+
"no_args\n"
339+
"no_args\n"
340+
"no_args\n"
341+
"1_arg 0\n"
342+
"1_arg 0\n"
343+
"no_args\n"
344+
"no_args\n"
345+
"1_arg 1\n"
346+
"1_arg 1\n"
347+
"1_arg 2\n"
348+
"1_arg 3\n"
349+
"1_arg 3\n"
350+
"1_arg 3\n"
351+
"1_arg 1\n"
352+
"1_arg 1\n"
353+
"1_arg 2\n"
354+
"1_arg 3\n"
355+
"1_arg 3\n"
356+
"1_arg 3\n";
357+
358+
EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false));
359+
testing::internal::CaptureStdout();
360+
code->run(ctx.get());
361+
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected);
362+
}

test/dev/llvm/testfunctions.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,9 @@ extern "C"
7878
{
7979
value_assign_bool(ret, value_equals(a, b));
8080
}
81+
82+
void test_const(Target *target, ValueData *ret, ValueData *v)
83+
{
84+
value_assign_copy(ret, v);
85+
}
8186
}

test/dev/llvm/testfunctions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ extern "C"
2020
void test_function_3_args_ret(Target *target, ValueData *ret, const ValueData *arg1, const ValueData *arg2, const ValueData *arg3);
2121

2222
void test_equals(Target *target, ValueData *ret, ValueData *a, ValueData *b);
23+
void test_const(Target *target, ValueData *ret, ValueData *v);
2324
}
2425

2526
} // namespace libscratchcpp

0 commit comments

Comments
 (0)