From bfa5c83f2a728e4b1022e61dd53257404043ce57 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Thu, 7 Jun 2018 19:42:18 -0700 Subject: [PATCH] expression checking for @compilerEvaluable funcs --- include/swift/AST/DiagnosticsSema.def | 8 + lib/Sema/TypeCheckCompilerEvaluable.cpp | 196 +++++++++++++- test/Sema/compiler_evaluable.swift | 327 +++++++++++++++++++++++- 3 files changed, 522 insertions(+), 9 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index bea059a687e8b..0f42727977290 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2529,6 +2529,14 @@ ERROR(compiler_evaluable_bad_context,none, "@compilerEvaluable functions not allowed here", ()) ERROR(compiler_evaluable_loop,none, "loops not allowed in @compilerEvaluable functions", ()) +ERROR(compiler_evaluable_forbidden_expression,none, + "expression not allowed in @compilerEvaluable functions", ()) +ERROR(compiler_evaluable_non_local_mutable,none, + "referencing non-local mutable variables not allowed in @compilerEvaluable functions", ()) +ERROR(compiler_evaluable_forbidden_type,none, + "type %0 cannot be used in @compilerEvaluable functions", (Type)) +ERROR(compiler_evaluable_ref_non_compiler_evaluable,none, + "@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ()) //------------------------------------------------------------------------------ // Type Check Expressions diff --git a/lib/Sema/TypeCheckCompilerEvaluable.cpp b/lib/Sema/TypeCheckCompilerEvaluable.cpp index 4c7c51dcc7a6e..3a5257674855d 100644 --- a/lib/Sema/TypeCheckCompilerEvaluable.cpp +++ b/lib/Sema/TypeCheckCompilerEvaluable.cpp @@ -26,26 +26,206 @@ using namespace swift; namespace { -/// Checks that the body of a function is compiler evaluable. -/// Currently a skeleton implementation that only rejects while loops. +/// Checks that a type is compiler representable. +/// Currently a skeleton implementation that only rejects types named Float, +/// Double and String. /// TODO(marcrasi): Fill in a real implementation. +static bool checkCompilerRepresentable(const Type &type) { + return type.getString() != "Double" && type.getString() != "Float" && + type.getString() != "String"; +} + +/// Checks that the body of a function is compiler evaluable. class CheckCompilerEvaluableBody : public ASTWalker { TypeChecker &TC; - bool compilerEvaluable = true; - public: - CheckCompilerEvaluableBody(TypeChecker &TC) : TC(TC) {} + // The function whose body we are checking. + const AbstractFunctionDecl *CheckingFunc; + + // Whether the body has passed the check. + bool CompilerEvaluable = true; + +public: + CheckCompilerEvaluableBody(TypeChecker &TC, + const AbstractFunctionDecl *CheckingFunc) + : TC(TC), CheckingFunc(CheckingFunc) {} + + std::pair walkToExprPre(Expr *E) override { + // If this is the ignored part of a DotSyntaxBaseIgnored, then we can accept + // it without walking it. + if (auto *parentDotSyntaxBaseIgnored = + dyn_cast_or_null(Parent.getAsExpr())) + if (parentDotSyntaxBaseIgnored->getLHS() == E) + return {false, E}; + + if (!checkCompilerRepresentable(E->getType())) { + TC.diagnose(E->getLoc(), diag::compiler_evaluable_forbidden_type, + E->getType()) + .highlight(E->getSourceRange()); + CompilerEvaluable = false; + return {false, E}; + } + + switch (E->getKind()) { + #define ALWAYS_ALLOWED(ID) \ + case ExprKind::ID: \ + return {true, E}; + #define SOMETIMES_ALLOWED(ID) \ + case ExprKind::ID: \ + return checkExpr##ID(cast(E)); + + ALWAYS_ALLOWED(NilLiteral) + ALWAYS_ALLOWED(IntegerLiteral) + ALWAYS_ALLOWED(BooleanLiteral) + ALWAYS_ALLOWED(MagicIdentifierLiteral) + ALWAYS_ALLOWED(DiscardAssignment) + SOMETIMES_ALLOWED(DeclRef) + ALWAYS_ALLOWED(Type) + SOMETIMES_ALLOWED(OtherConstructorDeclRef) + ALWAYS_ALLOWED(DotSyntaxBaseIgnored) + ALWAYS_ALLOWED(MemberRef) + ALWAYS_ALLOWED(Paren) + ALWAYS_ALLOWED(DotSelf) + ALWAYS_ALLOWED(Try) + ALWAYS_ALLOWED(ForceTry) + ALWAYS_ALLOWED(OptionalTry) + ALWAYS_ALLOWED(Tuple) + ALWAYS_ALLOWED(Subscript) + ALWAYS_ALLOWED(TupleElement) + ALWAYS_ALLOWED(CaptureList) + ALWAYS_ALLOWED(Closure) + ALWAYS_ALLOWED(AutoClosure) + ALWAYS_ALLOWED(InOut) + ALWAYS_ALLOWED(DynamicType) + ALWAYS_ALLOWED(RebindSelfInConstructor) + ALWAYS_ALLOWED(BindOptional) + ALWAYS_ALLOWED(OptionalEvaluation) + ALWAYS_ALLOWED(ForceValue) + SOMETIMES_ALLOWED(Call) + ALWAYS_ALLOWED(PrefixUnary) + ALWAYS_ALLOWED(PostfixUnary) + ALWAYS_ALLOWED(Binary) + ALWAYS_ALLOWED(DotSyntaxCall) + ALWAYS_ALLOWED(ConstructorRefCall) + ALWAYS_ALLOWED(Load) + ALWAYS_ALLOWED(TupleShuffle) + ALWAYS_ALLOWED(InjectIntoOptional) + ALWAYS_ALLOWED(Coerce) + ALWAYS_ALLOWED(If) + ALWAYS_ALLOWED(Assign) + ALWAYS_ALLOWED(CodeCompletion) + ALWAYS_ALLOWED(EditorPlaceholder) + + // Allow all errors and unchecked expressions so that we don't put errors + // on top of expressions that alrady have errors. + ALWAYS_ALLOWED(Error) + ALWAYS_ALLOWED(UnresolvedTypeConversion) + #define UNCHECKED_EXPR(ID, PARENT) ALWAYS_ALLOWED(ID) + #include "swift/AST/ExprNodes.def" + + default: + TC.diagnose(E->getStartLoc(), + diag::compiler_evaluable_forbidden_expression) + .highlight(E->getSourceRange()); + CompilerEvaluable = false; + return {false, E}; + + #undef ALWAYS_ALLOWED + #undef SOMETIMES_ALLOWED + } + } + + std::pair checkExprCall(CallExpr *call) { + // TODO(SR-8035): Eliminate this special case. + // Allow calls to some stdlib assertion functions without walking them + // further, because the calls do currently-forbidden things. (They use + // Strings and they call functions imported from C). + if (auto *calleeRef = dyn_cast(call->getDirectCallee())) + if (auto *callee = dyn_cast(calleeRef->getDecl())) + if (callee->isChildContextOf(TC.Context.TheStdlibModule) && + (callee->getNameStr() == "_precondition" || + callee->getNameStr() == "_preconditionFailure" || + callee->getNameStr() == "_sanityCheck" || + callee->getNameStr() == "fatalError")) + return {false, call}; + + // Otherwise, walk everything in the expression. + return {true, call}; + } + + std::pair checkExprDeclRef(DeclRefExpr *declRef) { + auto *decl = declRef->getDeclRef().getDecl(); + if (auto *varDecl = dyn_cast(decl)) { + // DeclRefs to immutable variables are always allowed. + if (varDecl->isImmutable()) + return {true, declRef}; + + // DeclRefs to mutable variables are only allowed if they are declared + // within the @compilerEvaluable function. + if (varDecl->getDeclContext() == CheckingFunc || + varDecl->getDeclContext()->isChildContextOf(CheckingFunc)) + return {true, declRef}; + + TC.diagnose(declRef->getLoc(), + diag::compiler_evaluable_non_local_mutable); + CompilerEvaluable = false; + return {false, declRef}; + } else if (auto *functionDecl = dyn_cast(decl)) { + return checkAbstractFunctionDeclRef(declRef, functionDecl); + } else if (isa(decl)) { + return {true, declRef}; + } else { + TC.diagnose(declRef->getLoc(), + diag::compiler_evaluable_forbidden_expression) + .highlight(declRef->getSourceRange()); + CompilerEvaluable = false; + return {false, declRef}; + } + } + + std::pair + checkExprOtherConstructorDeclRef(OtherConstructorDeclRefExpr *declRef) { + return checkAbstractFunctionDeclRef(declRef, declRef->getDecl()); + } + + std::pair + checkAbstractFunctionDeclRef(Expr *declRef, AbstractFunctionDecl *decl) { + // If the function is @compilerEvaluable, allow it. + if (decl->getAttrs().hasAttribute( + /*AllowInvalid=*/true)) + return {true, declRef}; + + // If the function is nested within the function that we are checking, allow + // it. + if (decl->isChildContextOf(CheckingFunc)) + return {true, declRef}; + + // For now, allow all builtins. + // TODO: Mark which builtins are actually compiler evaluable. + if (decl->isChildContextOf(TC.Context.TheBuiltinModule)) + return {true, declRef}; + + // Allow all protocol methods. Later, the interpreter looks up the actual + // function and emits an error when it is not @compilerEvaluable. + if (isa(decl->getDeclContext())) + return {true, declRef}; + + TC.diagnose(declRef->getLoc(), + diag::compiler_evaluable_ref_non_compiler_evaluable); + CompilerEvaluable = false; + return {false, declRef}; + } std::pair walkToStmtPre(Stmt *S) override { if (S->getKind() == StmtKind::While) { TC.diagnose(S->getStartLoc(), diag::compiler_evaluable_loop); - compilerEvaluable = false; + CompilerEvaluable = false; return {false, S}; } return {true, S}; } - bool getCompilerEvaluable() const { return compilerEvaluable; } + bool getCompilerEvaluable() const { return CompilerEvaluable; } }; } // namespace @@ -62,7 +242,7 @@ void TypeChecker::checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D) { assert(D->getBodyKind() == AbstractFunctionDecl::BodyKind::TypeChecked && "cannot check @compilerEvaluable body that is not type checked"); - CheckCompilerEvaluableBody Checker(*this); + CheckCompilerEvaluableBody Checker(*this, D); D->getBody()->walk(Checker); if (!Checker.getCompilerEvaluable()) { compilerEvaluableAttr->setInvalid(); diff --git a/test/Sema/compiler_evaluable.swift b/test/Sema/compiler_evaluable.swift index 6a40940ae4460..8258117d91b5b 100644 --- a/test/Sema/compiler_evaluable.swift +++ b/test/Sema/compiler_evaluable.swift @@ -1,4 +1,8 @@ -// RUN: %target-typecheck-verify-swift +// RUN: %target-typecheck-verify-swift -module-name ThisModule + +// ---------------------------------------------------------------------------- +// Test what contexts @compilerEvaluable is allowed in. +// ---------------------------------------------------------------------------- protocol AProtocol { @compilerEvaluable @@ -106,6 +110,327 @@ func aGenericFunction(t: T) { @compilerEvaluable func funcTopLevel() {} +// ---------------------------------------------------------------------------- +// Test the AST expression checker. +// ---------------------------------------------------------------------------- + +// ---------------------------------------------------------------------------- +// Helper decls for expression tests. +// ---------------------------------------------------------------------------- + +let globalLet = 1 +var globalVar = 1 +let globalStringLet = "global string" +var globalStringVar = "global string" + +@compilerEvaluable +func compilerEvaluable() {} + +func nonCompilerEvaluable() {} + +@compilerEvaluable +func genericCompilerEvaluable(t: T) {} + +func genericNonCompilerEvaluable(t: T) {} + +protocol ProtocolWithAFunction { + func protocolFunction() -> Int +} + +@compilerEvaluable +func autocloses(_ x: @autoclosure () -> Int) {} + +struct SimpleStruct { + let field: Int + + @compilerEvaluable + init() { + field = 1 + } + + @compilerEvaluable + static postfix func ... (x: SimpleStruct) -> SimpleStruct { + return x + } + + @compilerEvaluable + func foo() -> Int { + return 1 + } +} + +struct SubscriptableStruct { + @compilerEvaluable + init() {} + + @compilerEvaluable + subscript(x: Int) -> Int { return x } +} + +enum SimpleEnum { + case case1 + case case2 +} + +enum PayloadEnum { + case case1(x: Int) + case case2(x: Int) +} + +@compilerEvaluable +func throwingFunction() throws {} + +@compilerEvaluable +func mutate(_ x: inout Int) { + x = 3 +} + +// ---------------------------------------------------------------------------- +// Actual expression tests. +// ---------------------------------------------------------------------------- + +@compilerEvaluable +func literals() { + // NilLiteral + let _: Int? = nil + + // IntegerLiteral + let _ = 1 + + // FloatLiteral + let _ = 1.0 // expected-error{{type 'Double' cannot be used in @compilerEvaluable functions}} + + // BooleanLiteral + let _ = true + + // StringLiteral + let _ = "hello world" // expected-error{{type 'String' cannot be used in @compilerEvaluable functions}} + + // InterpolatedStringLiteral + let _ = "hello world \(1)" // expected-error{{type 'String' cannot be used in @compilerEvaluable functions}} + + // MagicIdentifierLiteral + let _ = #line +} + +@compilerEvaluable +func declRef(arg1: Int, arg2: inout Int) { + let _ = arg1 + let _ = arg2 + arg2 = 1 + + let x = 1 + let _ = x + + var y = 1 + let _ = y + y = 2 + + let _ = globalLet + let _ = globalVar // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + globalVar = 2 // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + + let _ = globalStringLet // expected-error{{type 'String' cannot be used in @compilerEvaluable functions}} + let _ = globalStringVar // expected-error{{type 'String' cannot be used in @compilerEvaluable functions}} + + let _ = compilerEvaluable + let _ = nonCompilerEvaluable // expected-error{{@compilerEvaluable functions may not reference non-@compilerEvaluable functions}} + + compilerEvaluable() + nonCompilerEvaluable() // expected-error{{@compilerEvaluable functions may not reference non-@compilerEvaluable functions}} + + let _: SimpleEnum = .case1 + let _: PayloadEnum = .case1(x: 2) + + func inner() { + let _ = arg1 + let _ = arg2 + arg2 = 1 + + let _ = x + let _ = y + + let _ = globalLet + let _ = globalVar // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + } + + inner() +} + +func declRef2(outerArg1: Int, outerArg2: inout Int) { + let outerLet = 1 + var outerVar = 1 + + @compilerEvaluable + func inner() { + let _ = outerArg1 + let _ = outerArg2 // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + outerArg2 = 2 // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + + let _ = outerLet + let _ = outerVar // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + outerVar = 2 // expected-error{{referencing non-local mutable variables not allowed in @compilerEvaluable functions}} + } +} + +struct DeclRefInStruct { + let field: Int + + @compilerEvaluable + init() { + self.init(field: 1) + } + + @compilerEvaluable + init(field: Int) { + self.field = field + } + + @compilerEvaluable + init(b: Bool) { + self.init(field: 1, b: b) // expected-error{{@compilerEvaluable functions may not reference non-@compilerEvaluable functions}} + } + + init(field: Int, b: Bool) { + self.init(field: 1) + } + + @compilerEvaluable + func method() -> Int { + return field + } +} + +enum DeclRefInEnum { + case thing1 + case thing2 + + @compilerEvaluable + init() { + self = .thing1 + } + + @compilerEvaluable + func value() -> Int { + switch(self) { + case .thing1: + return 5 + case .thing2: + return 10 + } + } +} + +@compilerEvaluable +func declRefToProtocolFunction(t: T) -> Int { + return t.protocolFunction() +} + +@compilerEvaluable +func closures() { + let x = 2 + let closure = { [x] (i: Int) -> (Int, Int) in + return (x, i) + } + let _ = closure(2) + + autocloses(1) + + var mutable = 3 + func mutate() { + mutable += 1 + } + mutate() +} + +@compilerEvaluable +func dynamicTypeExpr(t: T) { + let _ = type(of: t) +} + +@compilerEvaluable +func miscAllowedExpressions() throws { + // DiscardAssignment + _ = 1 + + // DotSyntaxBaseIgnored + let _ = ThisModule.funcTopLevel + + // MemberRef + let _ = SimpleStruct().field + + // Try, ForceTry, and OptionalTry + let _ = try throwingFunction() + let _ = try! throwingFunction() + let _ = try? throwingFunction() + + // Paren + let _ = (1) + + // DotSelf + let _ = AStruct.self + + // Tuple + let _ = (1, 2) + let _ = (a: 1, 2) + + // Subscript + let _ = SubscriptableStruct()[1] + + // TupleElement + let _ = (1, 2).0 + + // InOut + var x = 1 + mutate(&x) + + // InjectIntoOptional, BindOptional, OptionalEvaluation, and ForceValue + let opt: SimpleStruct? = SimpleStruct() + let _ = opt?.field + let _ = opt! + + // ConstructorRefCall, Call, PrefixUnary, PostfixUnary, Binary, and + // DotSyntaxCall + let simpleStruct = SimpleStruct() + let _ = compilerEvaluable() + let _ = !true + let _ = simpleStruct... + let _ = 1 + 1 + let _ = simpleStruct.foo() + + // Load + let _ = x + + // TupleShuffle + let _: (y: Int, x: Int) = (x: 1, y: 2) + + // Coerce + let _ = 2 as Int + + // If + let _ = true ? 1 : 2 + + // Assign + let _ = (x = 2) +} + +@compilerEvaluable +func miscForbiddenExpressions() { + // Array + let _ = [1] // expected-error{{expression not allowed in @compilerEvaluable functions}} + + // Dictionary + let _ = ["a": 1] // expected-error{{expression not allowed in @compilerEvaluable functions}} + + // KeyPath and KeyPathApplication + let keyPath = \SimpleStruct.field // expected-error{{expression not allowed in @compilerEvaluable functions}} + let _ = SimpleStruct()[keyPath: keyPath] // expected-error{{expression not allowed in @compilerEvaluable functions}} +} + +// ---------------------------------------------------------------------------- +// Test the AST statement checker. +// ---------------------------------------------------------------------------- + @compilerEvaluable func funcWithLoop() -> Int { var x = 1