Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use string objects for extending strings #199

Merged
merged 7 commits into from
Sep 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions media/test-project/os-test.spice
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@ f<int> main() {
import "std/runtime/string_rt" as _rt_str;

f<int> main() {
_rt_str::String s1 = _rt_str::String('H');
s1.append("ello");

printf("Equals: %d", s1.opEquals("Hell2"));
// Plus
printf("String: %s\n", "Hello " + "World!");
string s = "Hello " + "World!";
printf("String: %s\n", s);
// Equals
printf("String: %d\n", "Hello World!" == "Hello Programmers!");
printf("String: %d\n", "Hello" == "Hell2");
// Not equals
printf("String: %d\n", "Hello World!" != "Hello Programmers!");
printf("String: %d\n", "Hello" != "Hell2");
}
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ set(SOURCES
analyzer/OpRuleManager.h
generator/GeneratorVisitor.cpp
generator/GeneratorVisitor.h
generator/StdFunctionManager.cpp
generator/StdFunctionManager.h
generator/OpRuleConversionsManager.cpp
generator/OpRuleConversionsManager.h
linker/LinkerInterface.cpp
Expand Down
6 changes: 3 additions & 3 deletions src/analyzer/AnalyzerVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ std::any AnalyzerVisitor::visitStructDef(StructDefNode *node) {
}

// Build struct specifiers
SymbolType symbolType = SymbolType(TY_STRUCT, node->structName, templateTypes);
SymbolType symbolType = SymbolType(TY_STRUCT, node->structName, {}, templateTypes);
auto structSymbolSpecifiers = SymbolSpecifiers(symbolType);
if (SpecifierLstNode *specifierLst = node->specifierLst(); specifierLst) {
for (const auto &specifier : specifierLst->specifiers()) {
Expand Down Expand Up @@ -1324,7 +1324,7 @@ std::any AnalyzerVisitor::visitAssignExpr(AssignExprNode *node) {
"The variable '" + variableName + "' was referenced before defined");

// Perform type inference
if (lhsTy.is(TY_DYN))
if (lhsTy.is(TY_DYN) || (lhsTy.is(TY_STRING) && rhsTy.is(TY_STRING)))
currentEntry->updateType(rhsTy, false);

// Update state in symbol table
Expand Down Expand Up @@ -2296,7 +2296,7 @@ std::any AnalyzerVisitor::visitCustomDataType(CustomDataTypeNode *node) {
throw SemanticError(node->codeLoc, UNKNOWN_DATATYPE, "Unknown datatype '" + identifier + "'");
structSymbol->setUsed();

return node->setEvaluatedSymbolType(SymbolType(TY_STRUCT, identifier, concreteTemplateTypes));
return node->setEvaluatedSymbolType(SymbolType(TY_STRUCT, identifier, {.arraySize = 0}, concreteTemplateTypes));
}

void AnalyzerVisitor::insertDestructorCall(const CodeLoc &codeLoc, SymbolTableEntry *varEntry) {
Expand Down
7 changes: 7 additions & 0 deletions src/analyzer/OpRuleManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ SymbolType OpRuleManager::getAssignResultType(const CodeLoc &codeLoc, const Symb
// Allow char* = string
if (lhs.isPointerOf(TY_CHAR) && rhs.is(TY_STRING))
return lhs;
// Allow string = string_object
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return rhs;
// Check primitive type combinations
return validateBinaryOperation(codeLoc, ASSIGN_OP_RULES, "=", lhs, rhs);
}
Expand Down Expand Up @@ -155,6 +158,10 @@ SymbolType OpRuleManager::getPlusResultType(const CodeLoc &codeLoc, const Symbol
throw printErrorMessageUnsafe(codeLoc, "+", lhs, rhs);
}

// Allow string + string
if (lhs.is(TY_STRING) && rhs.is(TY_STRING))
return SymbolType(TY_STRING, "", { .isStringStruct = true }, {});

return validateBinaryOperation(codeLoc, PLUS_OP_RULES, "+", lhs, rhs);
}

Expand Down
2 changes: 0 additions & 2 deletions src/analyzer/OpRuleManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ const std::vector<BinaryOpRule> ASSIGN_OP_RULES = {
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long = long -> long
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte = byte -> byte
BinaryOpRule(TY_CHAR, TY_CHAR, TY_CHAR, false), // char = char -> char
BinaryOpRule(TY_STRING, TY_STRING, TY_STRING, false), // string = string -> string
BinaryOpRule(TY_BOOL, TY_BOOL, TY_BOOL, false), // bool = bool -> bool
};

Expand Down Expand Up @@ -412,7 +411,6 @@ const std::vector<BinaryOpRule> PLUS_OP_RULES = {
BinaryOpRule(TY_LONG, TY_SHORT, TY_LONG, false), // long + short -> long
BinaryOpRule(TY_LONG, TY_LONG, TY_LONG, false), // long + long -> long
BinaryOpRule(TY_BYTE, TY_BYTE, TY_BYTE, false), // byte + byte -> byte
BinaryOpRule(TY_STRING, TY_STRING, TY_STRING, false), // string + string -> string
};

// Minus op rules
Expand Down
73 changes: 22 additions & 51 deletions src/generator/GeneratorVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <symbol/Function.h>
#include <symbol/Struct.h>
#include <symbol/SymbolTable.h>
#include <util/CommonUtil.h>
#include <util/FileUtil.h>
#include <util/ThreadFactory.h>

Expand Down Expand Up @@ -42,7 +41,10 @@ GeneratorVisitor::GeneratorVisitor(const std::shared_ptr<llvm::LLVMContext> &con

// Create LLVM base components
module = std::make_unique<llvm::Module>(FileUtil::getFileName(sourceFile.filePath), *context);
conversionsManager = std::make_unique<OpRuleConversionsManager>(context, builder);

// Initialize generator helper objects
stdFunctionManager = std::make_unique<StdFunctionManager>(this);
conversionsManager = std::make_unique<OpRuleConversionsManager>(this);

// Initialize LLVM
llvm::InitializeAllTargetInfos();
Expand Down Expand Up @@ -286,7 +288,7 @@ std::any GeneratorVisitor::visitMainFctDef(MainFctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -473,7 +475,7 @@ std::any GeneratorVisitor::visitFctDef(FctDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -660,7 +662,7 @@ std::any GeneratorVisitor::visitProcDef(ProcDefNode *node) {

// Restore stack if necessary
if (stackState != nullptr) {
builder->CreateCall(retrieveStackRestoreFct(), {stackState});
builder->CreateCall(stdFunctionManager->getStackRestoreFct(), {stackState});
stackState = nullptr;
}

Expand Down Expand Up @@ -1300,12 +1302,12 @@ std::any GeneratorVisitor::visitAssertStmt(AssertStmtNode *node) {
parentFct->getBasicBlockList().push_back(bThen);
moveInsertPointToBlock(bThen);
// Generate IR for assertion error
llvm::Function *printfFct = retrievePrintfFct();
llvm::Function *printfFct = stdFunctionManager->getPrintfFct();
std::string errorMsg = "Assertion failed: Condition '" + node->expressionString + "' evaluated to false.";
llvm::Value *templateString = builder->CreateGlobalStringPtr(errorMsg);
builder->CreateCall(printfFct, templateString);
// Generate call to exit
llvm::Function *exitFct = retrieveExitFct();
llvm::Function *exitFct = stdFunctionManager->getExitFct();
builder->CreateCall(exitFct, builder->getInt32(1));
// Create unreachable instruction
builder->CreateUnreachable();
Expand Down Expand Up @@ -1456,7 +1458,7 @@ std::any GeneratorVisitor::visitContinueStmt(ContinueStmtNode *node) {

std::any GeneratorVisitor::visitPrintfCall(PrintfCallNode *node) {
// Declare if not declared already
llvm::Function *printfFct = retrievePrintfFct();
llvm::Function *printfFct = stdFunctionManager->getPrintfFct();

std::vector<llvm::Value *> printfArgs;
printfArgs.push_back(builder->CreateGlobalStringPtr(node->templatedString));
Expand All @@ -1471,6 +1473,9 @@ std::any GeneratorVisitor::visitPrintfCall(PrintfCallNode *node) {
if (argSymbolType.isArray()) { // Convert array type to pointer type
llvm::Value *indices[2] = {builder->getInt32(0), builder->getInt32(0)};
argVal = builder->CreateInBoundsGEP(targetType, argValPtr, indices);
} else if (argSymbolType.isStringStruct()) {
argValPtr = materializeString(argValPtr);
argVal = builder->CreateLoad(targetType, argValPtr);
} else {
argVal = builder->CreateLoad(targetType, argValPtr);
}
Expand Down Expand Up @@ -1982,6 +1987,10 @@ std::any GeneratorVisitor::visitAdditiveExpr(AdditiveExprNode *node) {

switch (opQueue.front().first) {
case AdditiveExprNode::OP_PLUS:
/*if (lhsSymbolType.isStringStruct())
lhs = materializeString(lhsPtr);
if (rhsSymbolType.isStringStruct())
rhs = materializeString(rhsPtr);*/
lhs = conversionsManager->getPlusInst(lhs, rhs, lhsSymbolType, rhsSymbolType, currentScope, node->codeLoc);
break;
case AdditiveExprNode::OP_MINUS:
Expand Down Expand Up @@ -2976,7 +2985,7 @@ llvm::Value *GeneratorVisitor::insertAlloca(llvm::Type *llvmType, const std::str

llvm::Value *GeneratorVisitor::allocateDynamicallySizedArray(llvm::Type *itemType) {
// Call llvm.stacksave intrinsic
llvm::Function *stackSaveFct = retrieveStackSaveFct();
llvm::Function *stackSaveFct = stdFunctionManager->getStackSaveFct();
if (stackState == nullptr)
stackState = builder->CreateCall(stackSaveFct);
// Allocate array
Expand Down Expand Up @@ -3043,48 +3052,10 @@ bool GeneratorVisitor::insertDestructorCall(const CodeLoc &codeLoc, SymbolTableE
return true;
}

llvm::Function *GeneratorVisitor::retrievePrintfFct() {
std::string printfFctName = "printf";
llvm::Function *printfFct = module->getFunction(printfFctName);
if (printfFct)
return printfFct;
// Not found -> declare it for linkage
llvm::FunctionType *printfFctTy = llvm::FunctionType::get(builder->getInt32Ty(), builder->getInt8PtrTy(), true);
module->getOrInsertFunction(printfFctName, printfFctTy);
return module->getFunction(printfFctName);
}

llvm::Function *GeneratorVisitor::retrieveExitFct() {
std::string exitFctName = "exit";
llvm::Function *exitFct = module->getFunction(exitFctName);
if (exitFct)
return exitFct;
// Not found -> declare it for linkage
llvm::FunctionType *exitFctTy = llvm::FunctionType::get(builder->getVoidTy(), builder->getInt32Ty(), false);
module->getOrInsertFunction(exitFctName, exitFctTy);
return module->getFunction(exitFctName);
}

llvm::Function *GeneratorVisitor::retrieveStackSaveFct() {
std::string stackSaveFctName = "llvm.stacksave";
llvm::Function *stackSaveFct = module->getFunction(stackSaveFctName);
if (stackSaveFct)
return stackSaveFct;
// Not found -> declare it for linkage
llvm::FunctionType *stackSaveFctTy = llvm::FunctionType::get(builder->getInt8PtrTy(), {}, false);
module->getOrInsertFunction(stackSaveFctName, stackSaveFctTy);
return module->getFunction(stackSaveFctName);
}

llvm::Function *GeneratorVisitor::retrieveStackRestoreFct() {
std::string stackRestoreFctName = "llvm.stackrestore";
llvm::Function *stackRestoreFct = module->getFunction(stackRestoreFctName);
if (stackRestoreFct)
return stackRestoreFct;
// Not found -> declare it for linkage
llvm::FunctionType *stackRestoreFctTy = llvm::FunctionType::get(builder->getVoidTy(), builder->getInt8PtrTy(), false);
module->getOrInsertFunction(stackRestoreFctName, stackRestoreFctTy);
return module->getFunction(stackRestoreFctName);
llvm::Value *GeneratorVisitor::materializeString(llvm::Value *stringStructPtr) {
assert(stringStructPtr->getType()->isPointerTy());
llvm::Value *rawStringValue = builder->CreateCall(stdFunctionManager->getStringRawFct(), stringStructPtr);
return rawStringValue;
}

llvm::Constant *GeneratorVisitor::getDefaultValueForSymbolType(const SymbolType &symbolType) {
Expand Down
13 changes: 8 additions & 5 deletions src/generator/GeneratorVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <ast/AstNodes.h>
#include <ast/AstVisitor.h>

#include <generator/StdFunctionManager.h>
#include <generator/OpRuleConversionsManager.h>
#include <symbol/ScopePath.h>
#include <symbol/SymbolType.h>
Expand All @@ -22,7 +23,6 @@ class ThreadFactory;
class LinkerInterface;
struct CliOptions;
class LinkerInterface;
class OpRuleConversionsManager;
class SymbolTable;
class SymbolTableEntry;
class Function;
Expand All @@ -43,6 +43,10 @@ class GeneratorVisitor : public AstVisitor {
ThreadFactory &threadFactory, const LinkerInterface &linker, const CliOptions &cliOptions,
const SourceFile &sourceFile, const std::string &objectFile);

// Friend classes
friend class StdFunctionManager;
friend class OpRuleConversionsManager;

// Public methods
void optimize();
void emit();
Expand Down Expand Up @@ -101,6 +105,7 @@ class GeneratorVisitor : public AstVisitor {

private:
// Members
std::unique_ptr<StdFunctionManager> stdFunctionManager;
std::unique_ptr<OpRuleConversionsManager> conversionsManager;
const std::string &objectFile;
llvm::TargetMachine *targetMachine{};
Expand Down Expand Up @@ -163,10 +168,8 @@ class GeneratorVisitor : public AstVisitor {
llvm::Value *allocateDynamicallySizedArray(llvm::Type *itemType);
llvm::Value *createGlobalArray(llvm::Constant *constArray);
bool insertDestructorCall(const CodeLoc &codeLoc, SymbolTableEntry *varEntry);
llvm::Function *retrievePrintfFct();
llvm::Function *retrieveExitFct();
llvm::Function *retrieveStackSaveFct();
llvm::Function *retrieveStackRestoreFct();

llvm::Value *materializeString(llvm::Value *stringStruct);
llvm::Constant *getDefaultValueForSymbolType(const SymbolType &symbolType);
SymbolTableEntry *initExtGlobal(const std::string &globalName, const std::string &fqGlobalName);
llvm::Value *doImplicitCast(llvm::Value *src, llvm::Type *dstTy, SymbolType srcType);
Expand Down
37 changes: 28 additions & 9 deletions src/generator/OpRuleConversionsManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
#include <stdexcept>

#include <exception/IRError.h>
#include <generator/GeneratorVisitor.h>
#include <generator/StdFunctionManager.h>
#include <util/CodeLoc.h>

OpRuleConversionsManager::OpRuleConversionsManager(GeneratorVisitor *generator) : generator(generator) {
builder = generator->builder.get();
context = generator->context.get();
stdFunctionManager = generator->stdFunctionManager.get();
}

llvm::Value *OpRuleConversionsManager::getPlusEqualInst(llvm::Value *lhs, llvm::Value *rhs, const SymbolType &lhsSTy,
const SymbolType &rhsSTy, SymbolTable *accessScope,
const CodeLoc &codeLoc) {
Expand Down Expand Up @@ -465,9 +473,12 @@ llvm::Value *OpRuleConversionsManager::getEqualInst(llvm::Value *lhs, llvm::Valu
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateICmpEQ(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to opEquals in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '==' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the function isRawEqual(string, string) of the string std
llvm::Function *opFct = stdFunctionManager->getStringLitEqualsOpStringLitFct();
llvm::Value *result = builder->CreateCall(opFct, {lhs, rhs});
return result;
}
case COMB(TY_BOOL, TY_BOOL):
return builder->CreateICmpEQ(lhs, rhs);
}
Expand Down Expand Up @@ -570,9 +581,13 @@ llvm::Value *OpRuleConversionsManager::getNotEqualInst(llvm::Value *lhs, llvm::V
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateICmpNE(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to opNotEquals in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '!=' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the function isRawEqual(string, string) of the string std
llvm::Function *opFct = stdFunctionManager->getStringLitEqualsOpStringLitFct();
llvm::Value *result = builder->CreateCall(opFct, {lhs, rhs});
// Negate the result
return builder->CreateNot(result);
}
case COMB(TY_BOOL, TY_BOOL):
return builder->CreateICmpNE(lhs, rhs);
}
Expand Down Expand Up @@ -943,9 +958,13 @@ llvm::Value *OpRuleConversionsManager::getPlusInst(llvm::Value *lhs, llvm::Value
case COMB(TY_BYTE, TY_BYTE): // fallthrough
case COMB(TY_CHAR, TY_CHAR):
return builder->CreateAdd(lhs, rhs);
case COMB(TY_STRING, TY_STRING):
// ToDo(@marcauberer): Insert call to append in the runtime lib
throw IRError(codeLoc, COMING_SOON_IR, "The compiler does not support the '+' operator for lhs=string and rhs=string yet");
case COMB(TY_STRING, TY_STRING): {
// Generate call to the constructor ctor(string, string) of the String struct
llvm::Function *opFct = stdFunctionManager->getStringLitPlusOpStringLitFct();
llvm::Value *thisPtr = generator->insertAlloca(stdFunctionManager->getStringStructType());
builder->CreateCall(opFct, {thisPtr, lhs, rhs});
return thisPtr;
}
case COMB(TY_PTR, TY_INT): // fallthrough
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG):
Expand Down
13 changes: 8 additions & 5 deletions src/generator/OpRuleConversionsManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
#include <llvm/IR/Value.h>

// Forward declarations
class ErrorFactory;
class GeneratorVisitor;
class StdFunctionManager;
struct CodeLoc;

#define COMB(en1, en2) ((en1) | ((en2) << 16))

class OpRuleConversionsManager {
public:
explicit OpRuleConversionsManager(const std::shared_ptr<llvm::LLVMContext> &context, std::shared_ptr<llvm::IRBuilder<>> builder)
: context(context), builder(std::move(builder)) {}
// Constructors
explicit OpRuleConversionsManager(GeneratorVisitor *generator);

// Public methods
llvm::Value *getPlusEqualInst(llvm::Value *lhs, llvm::Value *rhs, const SymbolType &lhsTy, const SymbolType &rhsTy,
Expand Down Expand Up @@ -64,6 +65,8 @@ class OpRuleConversionsManager {

private:
// Members
const std::shared_ptr<llvm::LLVMContext> &context;
std::shared_ptr<llvm::IRBuilder<>> builder;
GeneratorVisitor *generator;
llvm::LLVMContext *context;
llvm::IRBuilder<> *builder;
const StdFunctionManager *stdFunctionManager;
};
Loading