Skip to content

Commit 29859b5

Browse files
committed
LLVMCodeBuilder: Implement equal comparison
1 parent ff36261 commit 29859b5

File tree

5 files changed

+429
-2
lines changed

5 files changed

+429
-2
lines changed

src/dev/engine/internal/icodebuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class ICodeBuilder
2929
virtual void createMul() = 0;
3030
virtual void createDiv() = 0;
3131

32+
virtual void createCmpEQ() = 0;
3233
virtual void beginIfStatement() = 0;
3334
virtual void beginElseBranch() = 0;
3435
virtual void endIf() = 0;

src/dev/engine/internal/llvmcodebuilder.cpp

Lines changed: 221 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
128128
break;
129129
}
130130

131+
case Step::Type::CmpEQ: {
132+
assert(step.args.size() == 2);
133+
const auto &arg1 = step.args[0].second;
134+
const auto &arg2 = step.args[1].second;
135+
step.functionReturnReg->value = createComparison(arg1, arg2, Comparison::EQ);
136+
break;
137+
}
138+
131139
case Step::Type::Yield:
132140
if (!m_warp) {
133141
freeHeap();
@@ -466,6 +474,11 @@ void LLVMCodeBuilder::createDiv()
466474
createOp(Step::Type::Div, Compiler::StaticType::Number, 2);
467475
}
468476

477+
void LLVMCodeBuilder::createCmpEQ()
478+
{
479+
createOp(Step::Type::CmpEQ, Compiler::StaticType::Bool, 2);
480+
}
481+
469482
void LLVMCodeBuilder::beginIfStatement()
470483
{
471484
Step step(Step::Type::BeginIf);
@@ -831,11 +844,15 @@ llvm::Type *LLVMCodeBuilder::getType(Compiler::StaticType type)
831844
}
832845
}
833846

847+
llvm::Value *LLVMCodeBuilder::isNaN(llvm::Value *num)
848+
{
849+
return m_builder.CreateFCmpUNO(num, num);
850+
}
851+
834852
llvm::Value *LLVMCodeBuilder::removeNaN(llvm::Value *num)
835853
{
836854
// Replace NaN with zero
837-
llvm::Value *isNaN = m_builder.CreateFCmpUNO(num, num);
838-
return m_builder.CreateSelect(isNaN, llvm::ConstantFP::get(m_ctx, llvm::APFloat(0.0)), num);
855+
return m_builder.CreateSelect(isNaN(num), llvm::ConstantFP::get(m_ctx, llvm::APFloat(0.0)), num);
839856
}
840857

841858
void LLVMCodeBuilder::createOp(Step::Type type, Compiler::StaticType retType, size_t argCount)
@@ -859,6 +876,184 @@ void LLVMCodeBuilder::createOp(Step::Type type, Compiler::StaticType retType, si
859876
m_steps.push_back(step);
860877
}
861878

879+
llvm::Value *LLVMCodeBuilder::createValue(std::shared_ptr<Register> reg)
880+
{
881+
if (reg->isConstValue) {
882+
// Create a constant ValueData instance and store it
883+
llvm::Constant *value = castConstValue(reg->constValue, TYPE_MAP[reg->constValue.type()]);
884+
llvm::Value *ret = m_builder.CreateAlloca(m_valueDataType);
885+
886+
if (reg->constValue.type() == ValueType::String)
887+
value = llvm::ConstantExpr::getPtrToInt(value, m_valueDataType->getElementType(0));
888+
else
889+
value = llvm::ConstantExpr::getBitCast(value, m_valueDataType->getElementType(0));
890+
891+
llvm::Constant *type = m_builder.getInt32(static_cast<uint32_t>(reg->constValue.type()));
892+
llvm::Constant *constValue = llvm::ConstantStruct::get(m_valueDataType, { value, type, m_builder.getInt64(0) });
893+
m_builder.CreateStore(constValue, ret);
894+
895+
return ret;
896+
} else if (reg->isRawValue) {
897+
llvm::Value *value = castRawValue(reg, reg->type);
898+
llvm::Value *ret = m_builder.CreateAlloca(m_valueDataType);
899+
900+
// Store value
901+
llvm::Value *valueField = m_builder.CreateStructGEP(m_valueDataType, ret, 0);
902+
m_builder.CreateStore(value, valueField);
903+
904+
auto it = std::find_if(TYPE_MAP.begin(), TYPE_MAP.end(), [&reg](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == reg->type; });
905+
906+
if (it == TYPE_MAP.end()) {
907+
assert(false);
908+
return nullptr;
909+
}
910+
911+
// Store type
912+
llvm::Value *typeField = m_builder.CreateStructGEP(m_valueDataType, ret, 1);
913+
ValueType type = it->first;
914+
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(type)), typeField);
915+
916+
return ret;
917+
} else
918+
return reg->value;
919+
}
920+
921+
llvm::Value *LLVMCodeBuilder::createComparison(std::shared_ptr<Register> arg1, std::shared_ptr<Register> arg2, Comparison type)
922+
{
923+
auto type1 = arg1->type;
924+
auto type2 = arg2->type;
925+
926+
if (arg1->isConstValue && arg2->isConstValue) {
927+
// If both operands are constant, perform the comparison at compile time
928+
bool result = false;
929+
930+
switch (type) {
931+
case Comparison::EQ:
932+
result = arg1->constValue == arg2->constValue;
933+
break;
934+
935+
case Comparison::GT:
936+
result = arg1->constValue > arg2->constValue;
937+
break;
938+
939+
case Comparison::LT:
940+
result = arg1->constValue < arg2->constValue;
941+
break;
942+
943+
default:
944+
assert(false);
945+
return nullptr;
946+
}
947+
948+
return m_builder.getInt1(result);
949+
} else {
950+
// Optimize comparison of constant with number/bool
951+
if (arg1->isConstValue && arg1->constValue.isValidNumber() && (type2 == Compiler::StaticType::Number || type2 == Compiler::StaticType::Bool))
952+
type1 = Compiler::StaticType::Number;
953+
954+
if (arg2->isConstValue && arg2->constValue.isValidNumber() && (type1 == Compiler::StaticType::Number || type1 == Compiler::StaticType::Bool))
955+
type2 = Compiler::StaticType::Number;
956+
957+
// Optimize number and bool comparison
958+
if (type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::Bool)
959+
type2 = Compiler::StaticType::Number;
960+
961+
if (type1 == Compiler::StaticType::Bool && type2 == Compiler::StaticType::Number)
962+
type1 = Compiler::StaticType::Number;
963+
964+
if (type1 != type2 || type1 == Compiler::StaticType::Unknown || type2 == Compiler::StaticType::Unknown) {
965+
// If the types are different or at least one of them
966+
// is unknown, we must use value functions
967+
llvm::Value *value1 = createValue(arg1);
968+
llvm::Value *value2 = createValue(arg2);
969+
970+
switch (type) {
971+
case Comparison::EQ:
972+
return m_builder.CreateCall(resolve_value_equals(), { value1, value2 });
973+
974+
case Comparison::GT:
975+
return m_builder.CreateCall(resolve_value_greater(), { value1, value2 });
976+
977+
case Comparison::LT:
978+
return m_builder.CreateCall(resolve_value_lower(), { value1, value2 });
979+
980+
default:
981+
assert(false);
982+
return nullptr;
983+
}
984+
} else {
985+
// Compare raw values
986+
llvm::Value *value1 = castValue(arg1, type1);
987+
llvm::Value *value2 = castValue(arg2, type2);
988+
assert(type1 == type2);
989+
990+
switch (type1) {
991+
case Compiler::StaticType::Number: {
992+
// Compare two numbers
993+
switch (type) {
994+
case Comparison::EQ: {
995+
llvm::Value *nan = m_builder.CreateAnd(isNaN(value1), isNaN(value2)); // NaN == NaN
996+
llvm::Value *cmp = m_builder.CreateFCmpOEQ(value1, value2);
997+
return m_builder.CreateSelect(nan, m_builder.getInt1(true), cmp);
998+
}
999+
1000+
case Comparison::GT:
1001+
return m_builder.CreateFCmpOGT(value1, value2);
1002+
1003+
case Comparison::LT:
1004+
return m_builder.CreateFCmpOLT(value1, value2);
1005+
1006+
default:
1007+
assert(false);
1008+
return nullptr;
1009+
}
1010+
}
1011+
1012+
case Compiler::StaticType::Bool:
1013+
// Compare two booleans
1014+
switch (type) {
1015+
case Comparison::EQ:
1016+
return m_builder.CreateICmpEQ(value1, value2);
1017+
1018+
case Comparison::GT:
1019+
return m_builder.CreateICmpSGT(value1, value2);
1020+
1021+
case Comparison::LT:
1022+
return m_builder.CreateICmpSLT(value1, value2);
1023+
1024+
default:
1025+
assert(false);
1026+
return nullptr;
1027+
}
1028+
1029+
case Compiler::StaticType::String: {
1030+
// Compare two strings
1031+
llvm::Value *cmpRet = m_builder.CreateCall(resolve_strcasecmp(), { value1, value2 });
1032+
1033+
switch (type) {
1034+
case Comparison::EQ:
1035+
return m_builder.CreateICmpEQ(cmpRet, m_builder.getInt32(0));
1036+
1037+
case Comparison::GT:
1038+
return m_builder.CreateICmpSGT(cmpRet, m_builder.getInt32(0));
1039+
1040+
case Comparison::LT:
1041+
return m_builder.CreateICmpSLT(cmpRet, m_builder.getInt32(0));
1042+
1043+
default:
1044+
assert(false);
1045+
return nullptr;
1046+
}
1047+
}
1048+
1049+
default:
1050+
assert(false);
1051+
return nullptr;
1052+
}
1053+
}
1054+
}
1055+
}
1056+
8621057
llvm::FunctionCallee LLVMCodeBuilder::resolveFunction(const std::string name, llvm::FunctionType *type)
8631058
{
8641059
return m_module->getOrInsertFunction(name, type);
@@ -933,3 +1128,27 @@ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_stringToBool()
9331128
{
9341129
return resolveFunction("value_stringToBool", llvm::FunctionType::get(m_builder.getInt1Ty(), llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0), false));
9351130
}
1131+
1132+
llvm::FunctionCallee LLVMCodeBuilder::resolve_value_equals()
1133+
{
1134+
llvm::Type *valuePtr = m_valueDataType->getPointerTo();
1135+
return resolveFunction("value_equals", llvm::FunctionType::get(m_builder.getInt1Ty(), { valuePtr, valuePtr }, false));
1136+
}
1137+
1138+
llvm::FunctionCallee LLVMCodeBuilder::resolve_value_greater()
1139+
{
1140+
llvm::Type *valuePtr = m_valueDataType->getPointerTo();
1141+
return resolveFunction("value_greater", llvm::FunctionType::get(m_builder.getInt1Ty(), { valuePtr, valuePtr }, false));
1142+
}
1143+
1144+
llvm::FunctionCallee LLVMCodeBuilder::resolve_value_lower()
1145+
{
1146+
llvm::Type *valuePtr = m_valueDataType->getPointerTo();
1147+
return resolveFunction("value_lower", llvm::FunctionType::get(m_builder.getInt1Ty(), { valuePtr, valuePtr }, false));
1148+
}
1149+
1150+
llvm::FunctionCallee LLVMCodeBuilder::resolve_strcasecmp()
1151+
{
1152+
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);
1153+
return resolveFunction("strcasecmp", llvm::FunctionType::get(m_builder.getInt32Ty(), { pointerType, pointerType }, false));
1154+
}

src/dev/engine/internal/llvmcodebuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class LLVMCodeBuilder : public ICodeBuilder
3131
void createMul() override;
3232
void createDiv() override;
3333

34+
void createCmpEQ() override;
3435
void beginIfStatement() override;
3536
void beginElseBranch() override;
3637
void endIf() override;
@@ -67,6 +68,7 @@ class LLVMCodeBuilder : public ICodeBuilder
6768
Sub,
6869
Mul,
6970
Div,
71+
CmpEQ,
7072
Yield,
7173
BeginIf,
7274
BeginElse,
@@ -122,6 +124,13 @@ class LLVMCodeBuilder : public ICodeBuilder
122124
bool warp = false;
123125
};
124126

127+
enum class Comparison
128+
{
129+
EQ,
130+
GT,
131+
LT
132+
};
133+
125134
void initTypes();
126135

127136
Coroutine initCoroutine(llvm::Function *func);
@@ -133,10 +142,12 @@ class LLVMCodeBuilder : public ICodeBuilder
133142
llvm::Value *castRawValue(std::shared_ptr<Register> reg, Compiler::StaticType targetType);
134143
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType);
135144
llvm::Type *getType(Compiler::StaticType type);
145+
llvm::Value *isNaN(llvm::Value *num);
136146
llvm::Value *removeNaN(llvm::Value *num);
137147

138148
void createOp(Step::Type type, Compiler::StaticType retType, size_t argCount);
139149
llvm::Value *createValue(std::shared_ptr<Register> reg);
150+
llvm::Value *createComparison(std::shared_ptr<Register> arg1, std::shared_ptr<Register> arg2, Comparison type);
140151

141152
llvm::FunctionCallee resolveFunction(const std::string name, llvm::FunctionType *type);
142153
llvm::FunctionCallee resolve_value_init();
@@ -153,6 +164,10 @@ class LLVMCodeBuilder : public ICodeBuilder
153164
llvm::FunctionCallee resolve_value_boolToCString();
154165
llvm::FunctionCallee resolve_value_stringToDouble();
155166
llvm::FunctionCallee resolve_value_stringToBool();
167+
llvm::FunctionCallee resolve_value_equals();
168+
llvm::FunctionCallee resolve_value_greater();
169+
llvm::FunctionCallee resolve_value_lower();
170+
llvm::FunctionCallee resolve_strcasecmp();
156171

157172
std::string m_id;
158173
llvm::LLVMContext m_ctx;

0 commit comments

Comments
 (0)