@@ -128,6 +128,14 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
128128 break ;
129129 }
130130
131+ case Step::Type::CmpEQ: {
132+ assert (step.args .size () == 2 );
133+ const auto &arg1 = step.args [0 ].second ;
134+ const auto &arg2 = step.args [1 ].second ;
135+ step.functionReturnReg ->value = createComparison (arg1, arg2, Comparison::EQ);
136+ break ;
137+ }
138+
131139 case Step::Type::Yield:
132140 if (!m_warp) {
133141 freeHeap ();
@@ -466,6 +474,11 @@ void LLVMCodeBuilder::createDiv()
466474 createOp (Step::Type::Div, Compiler::StaticType::Number, 2 );
467475}
468476
477+ void LLVMCodeBuilder::createCmpEQ ()
478+ {
479+ createOp (Step::Type::CmpEQ, Compiler::StaticType::Bool, 2 );
480+ }
481+
469482void LLVMCodeBuilder::beginIfStatement ()
470483{
471484 Step step (Step::Type::BeginIf);
@@ -831,11 +844,15 @@ llvm::Type *LLVMCodeBuilder::getType(Compiler::StaticType type)
831844 }
832845}
833846
847+ llvm::Value *LLVMCodeBuilder::isNaN (llvm::Value *num)
848+ {
849+ return m_builder.CreateFCmpUNO (num, num);
850+ }
851+
834852llvm::Value *LLVMCodeBuilder::removeNaN (llvm::Value *num)
835853{
836854 // Replace NaN with zero
837- llvm::Value *isNaN = m_builder.CreateFCmpUNO (num, num);
838- return m_builder.CreateSelect (isNaN, llvm::ConstantFP::get (m_ctx, llvm::APFloat (0.0 )), num);
855+ return m_builder.CreateSelect (isNaN (num), llvm::ConstantFP::get (m_ctx, llvm::APFloat (0.0 )), num);
839856}
840857
841858void LLVMCodeBuilder::createOp (Step::Type type, Compiler::StaticType retType, size_t argCount)
@@ -859,6 +876,184 @@ void LLVMCodeBuilder::createOp(Step::Type type, Compiler::StaticType retType, si
859876 m_steps.push_back (step);
860877}
861878
879+ llvm::Value *LLVMCodeBuilder::createValue (std::shared_ptr<Register> reg)
880+ {
881+ if (reg->isConstValue ) {
882+ // Create a constant ValueData instance and store it
883+ llvm::Constant *value = castConstValue (reg->constValue , TYPE_MAP[reg->constValue .type ()]);
884+ llvm::Value *ret = m_builder.CreateAlloca (m_valueDataType);
885+
886+ if (reg->constValue .type () == ValueType::String)
887+ value = llvm::ConstantExpr::getPtrToInt (value, m_valueDataType->getElementType (0 ));
888+ else
889+ value = llvm::ConstantExpr::getBitCast (value, m_valueDataType->getElementType (0 ));
890+
891+ llvm::Constant *type = m_builder.getInt32 (static_cast <uint32_t >(reg->constValue .type ()));
892+ llvm::Constant *constValue = llvm::ConstantStruct::get (m_valueDataType, { value, type, m_builder.getInt64 (0 ) });
893+ m_builder.CreateStore (constValue, ret);
894+
895+ return ret;
896+ } else if (reg->isRawValue ) {
897+ llvm::Value *value = castRawValue (reg, reg->type );
898+ llvm::Value *ret = m_builder.CreateAlloca (m_valueDataType);
899+
900+ // Store value
901+ llvm::Value *valueField = m_builder.CreateStructGEP (m_valueDataType, ret, 0 );
902+ m_builder.CreateStore (value, valueField);
903+
904+ auto it = std::find_if (TYPE_MAP.begin (), TYPE_MAP.end (), [®](const std::pair<ValueType, Compiler::StaticType> &pair) { return pair.second == reg->type ; });
905+
906+ if (it == TYPE_MAP.end ()) {
907+ assert (false );
908+ return nullptr ;
909+ }
910+
911+ // Store type
912+ llvm::Value *typeField = m_builder.CreateStructGEP (m_valueDataType, ret, 1 );
913+ ValueType type = it->first ;
914+ m_builder.CreateStore (m_builder.getInt32 (static_cast <uint32_t >(type)), typeField);
915+
916+ return ret;
917+ } else
918+ return reg->value ;
919+ }
920+
921+ llvm::Value *LLVMCodeBuilder::createComparison (std::shared_ptr<Register> arg1, std::shared_ptr<Register> arg2, Comparison type)
922+ {
923+ auto type1 = arg1->type ;
924+ auto type2 = arg2->type ;
925+
926+ if (arg1->isConstValue && arg2->isConstValue ) {
927+ // If both operands are constant, perform the comparison at compile time
928+ bool result = false ;
929+
930+ switch (type) {
931+ case Comparison::EQ:
932+ result = arg1->constValue == arg2->constValue ;
933+ break ;
934+
935+ case Comparison::GT:
936+ result = arg1->constValue > arg2->constValue ;
937+ break ;
938+
939+ case Comparison::LT:
940+ result = arg1->constValue < arg2->constValue ;
941+ break ;
942+
943+ default :
944+ assert (false );
945+ return nullptr ;
946+ }
947+
948+ return m_builder.getInt1 (result);
949+ } else {
950+ // Optimize comparison of constant with number/bool
951+ if (arg1->isConstValue && arg1->constValue .isValidNumber () && (type2 == Compiler::StaticType::Number || type2 == Compiler::StaticType::Bool))
952+ type1 = Compiler::StaticType::Number;
953+
954+ if (arg2->isConstValue && arg2->constValue .isValidNumber () && (type1 == Compiler::StaticType::Number || type1 == Compiler::StaticType::Bool))
955+ type2 = Compiler::StaticType::Number;
956+
957+ // Optimize number and bool comparison
958+ if (type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::Bool)
959+ type2 = Compiler::StaticType::Number;
960+
961+ if (type1 == Compiler::StaticType::Bool && type2 == Compiler::StaticType::Number)
962+ type1 = Compiler::StaticType::Number;
963+
964+ if (type1 != type2 || type1 == Compiler::StaticType::Unknown || type2 == Compiler::StaticType::Unknown) {
965+ // If the types are different or at least one of them
966+ // is unknown, we must use value functions
967+ llvm::Value *value1 = createValue (arg1);
968+ llvm::Value *value2 = createValue (arg2);
969+
970+ switch (type) {
971+ case Comparison::EQ:
972+ return m_builder.CreateCall (resolve_value_equals (), { value1, value2 });
973+
974+ case Comparison::GT:
975+ return m_builder.CreateCall (resolve_value_greater (), { value1, value2 });
976+
977+ case Comparison::LT:
978+ return m_builder.CreateCall (resolve_value_lower (), { value1, value2 });
979+
980+ default :
981+ assert (false );
982+ return nullptr ;
983+ }
984+ } else {
985+ // Compare raw values
986+ llvm::Value *value1 = castValue (arg1, type1);
987+ llvm::Value *value2 = castValue (arg2, type2);
988+ assert (type1 == type2);
989+
990+ switch (type1) {
991+ case Compiler::StaticType::Number: {
992+ // Compare two numbers
993+ switch (type) {
994+ case Comparison::EQ: {
995+ llvm::Value *nan = m_builder.CreateAnd (isNaN (value1), isNaN (value2)); // NaN == NaN
996+ llvm::Value *cmp = m_builder.CreateFCmpOEQ (value1, value2);
997+ return m_builder.CreateSelect (nan, m_builder.getInt1 (true ), cmp);
998+ }
999+
1000+ case Comparison::GT:
1001+ return m_builder.CreateFCmpOGT (value1, value2);
1002+
1003+ case Comparison::LT:
1004+ return m_builder.CreateFCmpOLT (value1, value2);
1005+
1006+ default :
1007+ assert (false );
1008+ return nullptr ;
1009+ }
1010+ }
1011+
1012+ case Compiler::StaticType::Bool:
1013+ // Compare two booleans
1014+ switch (type) {
1015+ case Comparison::EQ:
1016+ return m_builder.CreateICmpEQ (value1, value2);
1017+
1018+ case Comparison::GT:
1019+ return m_builder.CreateICmpSGT (value1, value2);
1020+
1021+ case Comparison::LT:
1022+ return m_builder.CreateICmpSLT (value1, value2);
1023+
1024+ default :
1025+ assert (false );
1026+ return nullptr ;
1027+ }
1028+
1029+ case Compiler::StaticType::String: {
1030+ // Compare two strings
1031+ llvm::Value *cmpRet = m_builder.CreateCall (resolve_strcasecmp (), { value1, value2 });
1032+
1033+ switch (type) {
1034+ case Comparison::EQ:
1035+ return m_builder.CreateICmpEQ (cmpRet, m_builder.getInt32 (0 ));
1036+
1037+ case Comparison::GT:
1038+ return m_builder.CreateICmpSGT (cmpRet, m_builder.getInt32 (0 ));
1039+
1040+ case Comparison::LT:
1041+ return m_builder.CreateICmpSLT (cmpRet, m_builder.getInt32 (0 ));
1042+
1043+ default :
1044+ assert (false );
1045+ return nullptr ;
1046+ }
1047+ }
1048+
1049+ default :
1050+ assert (false );
1051+ return nullptr ;
1052+ }
1053+ }
1054+ }
1055+ }
1056+
8621057llvm::FunctionCallee LLVMCodeBuilder::resolveFunction (const std::string name, llvm::FunctionType *type)
8631058{
8641059 return m_module->getOrInsertFunction (name, type);
@@ -933,3 +1128,27 @@ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_stringToBool()
9331128{
9341129 return resolveFunction (" value_stringToBool" , llvm::FunctionType::get (m_builder.getInt1Ty (), llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 ), false ));
9351130}
1131+
1132+ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_equals ()
1133+ {
1134+ llvm::Type *valuePtr = m_valueDataType->getPointerTo ();
1135+ return resolveFunction (" value_equals" , llvm::FunctionType::get (m_builder.getInt1Ty (), { valuePtr, valuePtr }, false ));
1136+ }
1137+
1138+ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_greater ()
1139+ {
1140+ llvm::Type *valuePtr = m_valueDataType->getPointerTo ();
1141+ return resolveFunction (" value_greater" , llvm::FunctionType::get (m_builder.getInt1Ty (), { valuePtr, valuePtr }, false ));
1142+ }
1143+
1144+ llvm::FunctionCallee LLVMCodeBuilder::resolve_value_lower ()
1145+ {
1146+ llvm::Type *valuePtr = m_valueDataType->getPointerTo ();
1147+ return resolveFunction (" value_lower" , llvm::FunctionType::get (m_builder.getInt1Ty (), { valuePtr, valuePtr }, false ));
1148+ }
1149+
1150+ llvm::FunctionCallee LLVMCodeBuilder::resolve_strcasecmp ()
1151+ {
1152+ llvm::Type *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
1153+ return resolveFunction (" strcasecmp" , llvm::FunctionType::get (m_builder.getInt32Ty (), { pointerType, pointerType }, false ));
1154+ }
0 commit comments