Skip to content

Commit

Permalink
Support the inc and dec operators for pointers in unsafe context (#514)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcauberer committed Apr 7, 2024
1 parent dec0730 commit c4c4acd
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .run/spice.run.xml
@@ -1,5 +1,5 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="spice" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="build -O2 -d ../../media/test-project/test.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<configuration default="false" name="spice" type="CMakeRunConfiguration" factoryName="Application" PROGRAM_PARAMS="run -O0 -d -ir ../../media/test-project/test.spice" REDIRECT_INPUT="false" ELEVATE="false" USE_EXTERNAL_CONSOLE="false" EMULATE_TERMINAL="false" PASS_PARENT_ENVS_2="true" PROJECT_NAME="Spice" TARGET_NAME="spice" CONFIG_NAME="Debug" RUN_TARGET_PROJECT_NAME="Spice" RUN_TARGET_NAME="spice">
<envs>
<env name="LLVM_ADDITIONAL_FLAGS" value="-lole32 -lws2_32" />
<env name="LLVM_BUILD_INCLUDE_DIR" value="$PROJECT_DIR$/../llvm-project-latest/build/include" />
Expand Down
18 changes: 6 additions & 12 deletions media/test-project/test.spice
@@ -1,17 +1,11 @@
import "std/os/env";
import "bootstrap/lexer/lexer";

f<int> main() {
String filePath = getEnv("SPICE_STD_DIR") + "/../test/test-files/bootstrap-compiler/standalone-lexer-test/test-file.spice";
Lexer lexer = Lexer(filePath.getRaw());
unsigned long tokenCount = 0l;
while (!lexer.isEOF()) {
Token token = lexer.getToken();
token.print();
lexer.advance();
tokenCount++;
int[3] a = [1, 2, 3];
int* aPtr = a;
printf("%d\n", *aPtr);
unsafe {
aPtr++;
}
printf("\nLexed tokens: %d\n", tokenCount);
printf("%d\n", *aPtr);
}

/*import "bootstrap/util/block-allocator";
Expand Down
25 changes: 22 additions & 3 deletions src/irgenerator/OpRuleConversionManager.cpp
Expand Up @@ -62,7 +62,8 @@ LLVMExprResult OpRuleConversionManager::getPlusEqualInst(const ASTNode *node, LL
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG): {
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementTy, lhsV(), rhsV())};
llvm::Value* rhsVExt = builder.CreateSExt(rhsV(), builder.getInt64Ty());
return {.value = builder.CreateGEP(elementTy, lhsV(), rhsVExt)};
}
default: // GCOV_EXCL_LINE
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: +="); // GCOV_EXCL_LINE
Expand Down Expand Up @@ -116,8 +117,10 @@ LLVMExprResult OpRuleConversionManager::getMinusEqualInst(const ASTNode *node, L
case COMB(TY_PTR, TY_INT): // fallthrough
case COMB(TY_PTR, TY_SHORT): // fallthrough
case COMB(TY_PTR, TY_LONG): {
llvm::Type *elementType = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementType, lhsV(), rhsV())};
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
llvm::Value* rhsVExt = builder.CreateSExt(rhsV(), builder.getInt64Ty());
llvm::Value* rhsVNeg = builder.CreateNeg(rhsVExt);
return {.value = builder.CreateGEP(elementTy, lhsV(), rhsVNeg)};
}
default: // GCOV_EXCL_LINE
throw CompilerError(UNHANDLED_BRANCH, "Operator fallthrough: -="); // GCOV_EXCL_LINE
Expand Down Expand Up @@ -1481,6 +1484,10 @@ LLVMExprResult OpRuleConversionManager::getPrefixPlusPlusInst(const ASTNode *nod
return {.value = builder.CreateAdd(lhsV(), builder.getInt16(1), "", false, lhsSTy.isSigned())};
case TY_LONG:
return {.value = builder.CreateAdd(lhsV(), builder.getInt64(1), "", false, lhsSTy.isSigned())};
case TY_PTR: {
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementTy, lhsV(), builder.getInt64(1))};
}
default:
break;
}
Expand All @@ -1499,6 +1506,10 @@ LLVMExprResult OpRuleConversionManager::getPrefixMinusMinusInst(const ASTNode *n
return {.value = builder.CreateSub(lhsV(), builder.getInt16(1), "", false, lhsSTy.isSigned())};
case TY_LONG:
return {.value = builder.CreateSub(lhsV(), builder.getInt64(1), "", false, lhsSTy.isSigned())};
case TY_PTR: {
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementTy, lhsV(), builder.getInt64(-1))};
}
default:
break;
}
Expand Down Expand Up @@ -1552,6 +1563,10 @@ LLVMExprResult OpRuleConversionManager::getPostfixPlusPlusInst(const ASTNode *no
return {.value = builder.CreateAdd(lhsV(), builder.getInt16(1), "", false, lhsSTy.isSigned())};
case TY_LONG:
return {.value = builder.CreateAdd(lhsV(), builder.getInt64(1), "", false, lhsSTy.isSigned())};
case TY_PTR: {
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementTy, lhsV(), builder.getInt64(1))};
}
default:
break;
}
Expand All @@ -1575,6 +1590,10 @@ LLVMExprResult OpRuleConversionManager::getPostfixMinusMinusInst(const ASTNode *
return {.value = builder.CreateSub(lhsV(), builder.getInt16(1), "", false, lhsSTy.isSigned())};
case TY_LONG:
return {.value = builder.CreateSub(lhsV(), builder.getInt64(1), "", false, lhsSTy.isSigned())};
case TY_PTR: {
llvm::Type *elementTy = lhsSTy.getContainedTy().toLLVMType(context, accessScope);
return {.value = builder.CreateGEP(elementTy, lhsV(), builder.getInt64(-1))};
}
default:
break;
}
Expand Down
37 changes: 35 additions & 2 deletions src/typechecker/OpRuleManager.cpp
Expand Up @@ -512,6 +512,12 @@ SymbolType OpRuleManager::getPrefixPlusPlusResultType(const ASTNode *node, const
// Remove reference wrappers
SymbolType lhsType = lhs.type.removeReferenceWrapper();

// Check if this is an unsafe operation
if (lhsType.isPtr()) {
ensureUnsafeAllowed(node, "++", lhsType);
return lhsType;
}

return validateUnaryOperation(node, PREFIX_PLUS_PLUS_OP_RULES, ARRAY_LENGTH(PREFIX_PLUS_PLUS_OP_RULES), "++", lhsType);
}

Expand All @@ -522,6 +528,12 @@ SymbolType OpRuleManager::getPrefixMinusMinusResultType(const ASTNode *node, con
// Remove reference wrappers
SymbolType lhsType = lhs.type.removeReferenceWrapper();

// Check if this is an unsafe operation
if (lhsType.isPtr()) {
ensureUnsafeAllowed(node, "--", lhsType);
return lhsType;
}

return validateUnaryOperation(node, PREFIX_MINUS_MINUS_OP_RULES, ARRAY_LENGTH(PREFIX_MINUS_MINUS_OP_RULES), "--", lhsType);
}

Expand Down Expand Up @@ -567,8 +579,13 @@ ExprResult OpRuleManager::getPostfixPlusPlusResultType(ASTNode *node, const Expr
// Remove reference wrappers
SymbolType lhsType = lhs.type.removeReferenceWrapper();

return ExprResult(
validateUnaryOperation(node, POSTFIX_PLUS_PLUS_OP_RULES, ARRAY_LENGTH(POSTFIX_PLUS_PLUS_OP_RULES), "++", lhsType));
// Check if this is an unsafe operation
if (lhsType.isPtr()) {
ensureUnsafeAllowed(node, "++", lhsType);
return {lhs};
}

return {validateUnaryOperation(node, POSTFIX_PLUS_PLUS_OP_RULES, ARRAY_LENGTH(POSTFIX_PLUS_PLUS_OP_RULES), "++", lhsType)};
}

ExprResult OpRuleManager::getPostfixMinusMinusResultType(ASTNode *node, const ExprResult &lhs, size_t opIdx) {
Expand All @@ -583,6 +600,12 @@ ExprResult OpRuleManager::getPostfixMinusMinusResultType(ASTNode *node, const Ex
// Remove reference wrappers
SymbolType lhsType = lhs.type.removeReferenceWrapper();

// Check if this is an unsafe operation
if (lhsType.isPtr()) {
ensureUnsafeAllowed(node, "--", lhsType);
return {lhs};
}

return {validateUnaryOperation(node, POSTFIX_MINUS_MINUS_OP_RULES, ARRAY_LENGTH(POSTFIX_MINUS_MINUS_OP_RULES), "--", lhsType)};
}

Expand Down Expand Up @@ -713,6 +736,16 @@ SemanticError OpRuleManager::getExceptionBinary(const ASTNode *node, const char
return {node, OPERATOR_WRONG_DATA_TYPE, errorMsg.str()};
}

void OpRuleManager::ensureUnsafeAllowed(const ASTNode *node, const char *name, const SymbolType &lhs) const {
if (typeChecker->currentScope->doesAllowUnsafeOperations())
return;
// Print error message
const std::string lhsName = lhs.getName(true);
const std::string errorMsg = "Cannot apply '" + std::string(name) + "' operator on type " + lhsName +
" as this is an unsafe operation. Please use unsafe blocks if you know what you are doing.";
SOFT_ERROR_VOID(node, UNSAFE_OPERATION_IN_SAFE_CONTEXT, errorMsg)
}

void OpRuleManager::ensureUnsafeAllowed(const ASTNode *node, const char *name, const SymbolType &lhs,
const SymbolType &rhs) const {
if (typeChecker->currentScope->doesAllowUnsafeOperations())
Expand Down
5 changes: 3 additions & 2 deletions src/typechecker/OpRuleManager.h
Expand Up @@ -648,8 +648,8 @@ class OpRuleManager {
ExprResult getDivResultType(ASTNode *node, const ExprResult &lhs, const ExprResult &rhs, size_t opIdx);
static ExprResult getRemResultType(const ASTNode *node, const ExprResult &lhs, const ExprResult &rhs);
static SymbolType getPrefixMinusResultType(const ASTNode *node, const ExprResult &lhs);
static SymbolType getPrefixPlusPlusResultType(const ASTNode *node, const ExprResult &lhs);
static SymbolType getPrefixMinusMinusResultType(const ASTNode *node, const ExprResult &lhs);
SymbolType getPrefixPlusPlusResultType(const ASTNode *node, const ExprResult &lhs);
SymbolType getPrefixMinusMinusResultType(const ASTNode *node, const ExprResult &lhs);
static SymbolType getPrefixNotResultType(const ASTNode *node, const ExprResult &lhs);
static SymbolType getPrefixBitwiseNotResultType(const ASTNode *node, const ExprResult &lhs);
static SymbolType getPrefixMulResultType(const ASTNode *node, const ExprResult &lhs);
Expand All @@ -676,6 +676,7 @@ class OpRuleManager {
static SemanticError getExceptionUnary(const ASTNode *node, const char *name, const SymbolType &lhs);
static SemanticError getExceptionBinary(const ASTNode *node, const char *name, const SymbolType &lhs, const SymbolType &rhs,
const char *messagePrefix);
void ensureUnsafeAllowed(const ASTNode *node, const char *name, const SymbolType &lhs) const;
void ensureUnsafeAllowed(const ASTNode *node, const char *name, const SymbolType &lhs, const SymbolType &rhs) const;
static void ensureNoConstAssign(const ASTNode *node, const SymbolType &lhs);
};
Expand Down
4 changes: 2 additions & 2 deletions src/typechecker/TypeChecker.cpp
Expand Up @@ -1250,7 +1250,7 @@ std::any TypeChecker::visitPrefixUnaryExpr(PrefixUnaryExprNode *node) {
operandType = OpRuleManager::getPrefixMinusResultType(node, operand);
break;
case PrefixUnaryExprNode::OP_PLUS_PLUS:
operandType = OpRuleManager::getPrefixPlusPlusResultType(node, operand);
operandType = opRuleManager.getPrefixPlusPlusResultType(node, operand);

if (operandEntry) {
// In case the lhs is captured, notify the capture about the write access
Expand All @@ -1263,7 +1263,7 @@ std::any TypeChecker::visitPrefixUnaryExpr(PrefixUnaryExprNode *node) {

break;
case PrefixUnaryExprNode::OP_MINUS_MINUS:
operandType = OpRuleManager::getPrefixMinusMinusResultType(node, operand);
operandType = opRuleManager.getPrefixMinusMinusResultType(node, operand);

if (operandEntry) {
// In case the lhs is captured, notify the capture about the write access
Expand Down
Expand Up @@ -14,7 +14,7 @@ Find the GDB manual and other documentation resources online at:
For help, type "help".
Type "apropos word" to search for commands related to "word"...
Reading symbols from ./source...
Haltepunkt 1: file source.spice, line 42.
Breakpoint 1: file source.spice, line 42.
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".

Expand All @@ -27,4 +27,4 @@ pair = {first = 2, second = }
$1 = {contents = , capacity = 5, size = 5}
$2 = 5
All assertions passed!
[Inferior 1 (process 685913) exited normally]
[Inferior 1 (process 75215) exited normally]
Expand Up @@ -496,7 +496,7 @@ attributes #3 = { cold noreturn nounwind }
!0 = !DIGlobalVariableExpression(var: !1, expr: !DIExpression())
!1 = distinct !DIGlobalVariable(name: "printf.str.0", linkageName: "printf.str.0", scope: !2, file: !5, line: 68, type: !6, isLocal: true, isDefinition: true)
!2 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !3, producer: "spice version dev (https://github.com/spicelang/spice)", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, globals: !4, splitDebugInlining: false, nameTableKind: None)
!3 = !DIFile(filename: "/home/marc/Dokumente/Dev/spice/cmake-build-release/test/./test-files/irgenerator/debug-info/success-dbg-info-complex/source.spice", directory: "./test-files/irgenerator/debug-info/success-dbg-info-complex")
!3 = !DIFile(filename: "/home/marc/Dokumente/Dev/spice/cmake-build-debug/test/./test-files/irgenerator/debug-info/success-dbg-info-complex/source.spice", directory: "./test-files/irgenerator/debug-info/success-dbg-info-complex")
!4 = !{!0}
!5 = !DIFile(filename: "source.spice", directory: "./test-files/irgenerator/debug-info/success-dbg-info-complex")
!6 = !DIStringType(name: "printf.str.0", size: 192)
Expand Down
Expand Up @@ -92,7 +92,7 @@ attributes #3 = { nofree nounwind }
!0 = !DIGlobalVariableExpression(var: !1, expr: !DIExpression())
!1 = distinct !DIGlobalVariable(name: "anon.string.0", linkageName: "anon.string.0", scope: !2, file: !7, line: 8, type: !15, isLocal: true, isDefinition: true)
!2 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !3, producer: "spice version dev (https://github.com/spicelang/spice)", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, globals: !4, splitDebugInlining: false, nameTableKind: None)
!3 = !DIFile(filename: "/home/marc/Dokumente/Dev/spice/cmake-build-release/test/./test-files/irgenerator/debug-info/success-dbg-info-simple/source.spice", directory: "./test-files/irgenerator/debug-info/success-dbg-info-simple")
!3 = !DIFile(filename: "/home/marc/Dokumente/Dev/spice/cmake-build-debug/test/./test-files/irgenerator/debug-info/success-dbg-info-simple/source.spice", directory: "./test-files/irgenerator/debug-info/success-dbg-info-simple")
!4 = !{!0, !5, !9, !12}
!5 = !DIGlobalVariableExpression(var: !6, expr: !DIExpression())
!6 = distinct !DIGlobalVariable(name: "printf.str.0", linkageName: "printf.str.0", scope: !2, file: !7, line: 15, type: !8, isLocal: true, isDefinition: true)
Expand Down
@@ -0,0 +1 @@
All assertions passed!
@@ -0,0 +1,103 @@
; ModuleID = 'source.spice'
source_filename = "source.spice"

@anon.array.0 = private unnamed_addr constant [3 x i32] [i32 1, i32 2, i32 3]
@anon.string.0 = private unnamed_addr constant [61 x i8] c"Assertion failed: Condition '*aPtr == 1' evaluated to false.\00", align 1
@anon.string.1 = private unnamed_addr constant [61 x i8] c"Assertion failed: Condition '*aPtr == 2' evaluated to false.\00", align 1
@anon.string.2 = private unnamed_addr constant [61 x i8] c"Assertion failed: Condition '*aPtr == 1' evaluated to false.\00", align 1
@anon.string.3 = private unnamed_addr constant [61 x i8] c"Assertion failed: Condition '*aPtr == 3' evaluated to false.\00", align 1
@anon.string.4 = private unnamed_addr constant [61 x i8] c"Assertion failed: Condition '*aPtr == 1' evaluated to false.\00", align 1
@printf.str.0 = private unnamed_addr constant [23 x i8] c"All assertions passed!\00", align 1

; Function Attrs: noinline nounwind optnone uwtable
define dso_local i32 @main() #0 {
%result = alloca i32, align 4
%a = alloca [3 x i32], align 4
%aPtr = alloca ptr, align 8
store i32 0, ptr %result, align 4
store [3 x i32] [i32 1, i32 2, i32 3], ptr %a, align 4
%1 = getelementptr inbounds [3 x i32], ptr %a, i32 0, i32 0
store ptr %1, ptr %aPtr, align 8
%2 = load ptr, ptr %aPtr, align 8
%3 = load i32, ptr %2, align 4
%4 = icmp eq i32 %3, 1
br i1 %4, label %assert.exit.L4, label %assert.then.L4, !prof !0

assert.then.L4: ; preds = %0
%5 = call i32 (ptr, ...) @printf(ptr @anon.string.0)
call void @exit(i32 1)
unreachable

assert.exit.L4: ; preds = %0
%6 = load ptr, ptr %aPtr, align 8
%7 = getelementptr i32, ptr %6, i64 1
store ptr %7, ptr %aPtr, align 8
%8 = load ptr, ptr %aPtr, align 8
%9 = load i32, ptr %8, align 4
%10 = icmp eq i32 %9, 2
br i1 %10, label %assert.exit.L6, label %assert.then.L6, !prof !0

assert.then.L6: ; preds = %assert.exit.L4
%11 = call i32 (ptr, ...) @printf(ptr @anon.string.1)
call void @exit(i32 1)
unreachable

assert.exit.L6: ; preds = %assert.exit.L4
%12 = load ptr, ptr %aPtr, align 8
%13 = getelementptr i32, ptr %12, i64 -1
store ptr %13, ptr %aPtr, align 8
%14 = load ptr, ptr %aPtr, align 8
%15 = load i32, ptr %14, align 4
%16 = icmp eq i32 %15, 1
br i1 %16, label %assert.exit.L8, label %assert.then.L8, !prof !0

assert.then.L8: ; preds = %assert.exit.L6
%17 = call i32 (ptr, ...) @printf(ptr @anon.string.2)
call void @exit(i32 1)
unreachable

assert.exit.L8: ; preds = %assert.exit.L6
%18 = load ptr, ptr %aPtr, align 8
%19 = getelementptr i32, ptr %18, i64 2
store ptr %19, ptr %aPtr, align 8
%20 = load ptr, ptr %aPtr, align 8
%21 = load i32, ptr %20, align 4
%22 = icmp eq i32 %21, 3
br i1 %22, label %assert.exit.L10, label %assert.then.L10, !prof !0

assert.then.L10: ; preds = %assert.exit.L8
%23 = call i32 (ptr, ...) @printf(ptr @anon.string.3)
call void @exit(i32 1)
unreachable

assert.exit.L10: ; preds = %assert.exit.L8
%24 = load ptr, ptr %aPtr, align 8
%25 = getelementptr i32, ptr %24, i64 -2
store ptr %25, ptr %aPtr, align 8
%26 = load ptr, ptr %aPtr, align 8
%27 = load i32, ptr %26, align 4
%28 = icmp eq i32 %27, 1
br i1 %28, label %assert.exit.L12, label %assert.then.L12, !prof !0

assert.then.L12: ; preds = %assert.exit.L10
%29 = call i32 (ptr, ...) @printf(ptr @anon.string.4)
call void @exit(i32 1)
unreachable

assert.exit.L12: ; preds = %assert.exit.L10
%30 = call i32 (ptr, ...) @printf(ptr noundef @printf.str.0)
%31 = load i32, ptr %result, align 4
ret i32 %31
}

; Function Attrs: nofree nounwind
declare noundef i32 @printf(ptr nocapture noundef readonly, ...) #1

; Function Attrs: cold noreturn nounwind
declare void @exit(i32) #2

attributes #0 = { noinline nounwind optnone uwtable }
attributes #1 = { nofree nounwind }
attributes #2 = { cold noreturn nounwind }

!0 = !{!"branch_weights", i32 2000, i32 1}
@@ -0,0 +1,14 @@
f<int> main() {
int[3] a = [1, 2, 3];
int* aPtr = &a[0];
assert *aPtr == 1;
unsafe { aPtr++; }
assert *aPtr == 2;
unsafe { aPtr--; }
assert *aPtr == 1;
unsafe { aPtr += 2; }
assert *aPtr == 3;
unsafe { aPtr -= 2; }
assert *aPtr == 1;
printf("All assertions passed!");
}

0 comments on commit c4c4acd

Please sign in to comment.