@@ -26,26 +26,206 @@ using namespace swift;
2626
2727namespace {
2828
29- // / Checks that the body of a function is compiler evaluable.
30- // / Currently a skeleton implementation that only rejects while loops.
29+ // / Checks that a type is compiler representable.
30+ // / Currently a skeleton implementation that only rejects types named Float,
31+ // / Double and String.
3132// / TODO(marcrasi): Fill in a real implementation.
33+ static bool checkCompilerRepresentable (const Type &type) {
34+ return type.getString () != " Double" && type.getString () != " Float" &&
35+ type.getString () != " String" ;
36+ }
37+
38+ // / Checks that the body of a function is compiler evaluable.
3239class CheckCompilerEvaluableBody : public ASTWalker {
3340 TypeChecker &TC;
34- bool compilerEvaluable = true ;
3541
36- public:
37- CheckCompilerEvaluableBody (TypeChecker &TC) : TC(TC) {}
42+ // The function whose body we are checking.
43+ const AbstractFunctionDecl *CheckingFunc;
44+
45+ // Whether the body has passed the check.
46+ bool CompilerEvaluable = true ;
47+
48+ public:
49+ CheckCompilerEvaluableBody (TypeChecker &TC,
50+ const AbstractFunctionDecl *CheckingFunc)
51+ : TC(TC), CheckingFunc(CheckingFunc) {}
52+
53+ std::pair<bool , Expr *> walkToExprPre (Expr *E) override {
54+ // If this is the ignored part of a DotSyntaxBaseIgnored, then we can accept
55+ // it without walking it.
56+ if (auto *parentDotSyntaxBaseIgnored =
57+ dyn_cast_or_null<DotSyntaxBaseIgnoredExpr>(Parent.getAsExpr ()))
58+ if (parentDotSyntaxBaseIgnored->getLHS () == E)
59+ return {false , E};
60+
61+ if (!checkCompilerRepresentable (E->getType ())) {
62+ TC.diagnose (E->getLoc (), diag::compiler_evaluable_forbidden_type,
63+ E->getType ())
64+ .highlight (E->getSourceRange ());
65+ CompilerEvaluable = false ;
66+ return {false , E};
67+ }
68+
69+ switch (E->getKind ()) {
70+ #define ALWAYS_ALLOWED (ID ) \
71+ case ExprKind::ID: \
72+ return {true , E};
73+ #define SOMETIMES_ALLOWED (ID ) \
74+ case ExprKind::ID: \
75+ return checkExpr##ID (cast<ID##Expr>(E));
76+
77+ ALWAYS_ALLOWED (NilLiteral)
78+ ALWAYS_ALLOWED (IntegerLiteral)
79+ ALWAYS_ALLOWED (BooleanLiteral)
80+ ALWAYS_ALLOWED (MagicIdentifierLiteral)
81+ ALWAYS_ALLOWED (DiscardAssignment)
82+ SOMETIMES_ALLOWED (DeclRef)
83+ ALWAYS_ALLOWED (Type)
84+ SOMETIMES_ALLOWED (OtherConstructorDeclRef)
85+ ALWAYS_ALLOWED (DotSyntaxBaseIgnored)
86+ ALWAYS_ALLOWED (MemberRef)
87+ ALWAYS_ALLOWED (Paren)
88+ ALWAYS_ALLOWED (DotSelf)
89+ ALWAYS_ALLOWED (Try)
90+ ALWAYS_ALLOWED (ForceTry)
91+ ALWAYS_ALLOWED (OptionalTry)
92+ ALWAYS_ALLOWED (Tuple)
93+ ALWAYS_ALLOWED (Subscript)
94+ ALWAYS_ALLOWED (TupleElement)
95+ ALWAYS_ALLOWED (CaptureList)
96+ ALWAYS_ALLOWED (Closure)
97+ ALWAYS_ALLOWED (AutoClosure)
98+ ALWAYS_ALLOWED (InOut)
99+ ALWAYS_ALLOWED (DynamicType)
100+ ALWAYS_ALLOWED (RebindSelfInConstructor)
101+ ALWAYS_ALLOWED (BindOptional)
102+ ALWAYS_ALLOWED (OptionalEvaluation)
103+ ALWAYS_ALLOWED (ForceValue)
104+ SOMETIMES_ALLOWED (Call)
105+ ALWAYS_ALLOWED (PrefixUnary)
106+ ALWAYS_ALLOWED (PostfixUnary)
107+ ALWAYS_ALLOWED (Binary)
108+ ALWAYS_ALLOWED (DotSyntaxCall)
109+ ALWAYS_ALLOWED (ConstructorRefCall)
110+ ALWAYS_ALLOWED (Load)
111+ ALWAYS_ALLOWED (TupleShuffle)
112+ ALWAYS_ALLOWED (InjectIntoOptional)
113+ ALWAYS_ALLOWED (Coerce)
114+ ALWAYS_ALLOWED (If)
115+ ALWAYS_ALLOWED (Assign)
116+ ALWAYS_ALLOWED (CodeCompletion)
117+ ALWAYS_ALLOWED (EditorPlaceholder)
118+
119+ // Allow all errors and unchecked expressions so that we don't put errors
120+ // on top of expressions that alrady have errors.
121+ ALWAYS_ALLOWED (Error)
122+ ALWAYS_ALLOWED (UnresolvedTypeConversion)
123+ #define UNCHECKED_EXPR (ID, PARENT ) ALWAYS_ALLOWED(ID)
124+ #include " swift/AST/ExprNodes.def"
125+
126+ default :
127+ TC.diagnose (E->getStartLoc (),
128+ diag::compiler_evaluable_forbidden_expression)
129+ .highlight (E->getSourceRange ());
130+ CompilerEvaluable = false ;
131+ return {false , E};
132+
133+ #undef ALWAYS_ALLOWED
134+ #undef SOMETIMES_ALLOWED
135+ }
136+ }
137+
138+ std::pair<bool , Expr *> checkExprCall (CallExpr *call) {
139+ // TODO(SR-8035): Eliminate this special case.
140+ // Allow calls to some stdlib assertion functions without walking them
141+ // further, because the calls do currently-forbidden things. (They use
142+ // Strings and they call functions imported from C).
143+ if (auto *calleeRef = dyn_cast<DeclRefExpr>(call->getDirectCallee ()))
144+ if (auto *callee = dyn_cast<AbstractFunctionDecl>(calleeRef->getDecl ()))
145+ if (callee->isChildContextOf (TC.Context .TheStdlibModule ) &&
146+ (callee->getNameStr () == " _precondition" ||
147+ callee->getNameStr () == " _preconditionFailure" ||
148+ callee->getNameStr () == " _sanityCheck" ||
149+ callee->getNameStr () == " fatalError" ))
150+ return {false , call};
151+
152+ // Otherwise, walk everything in the expression.
153+ return {true , call};
154+ }
155+
156+ std::pair<bool , Expr *> checkExprDeclRef (DeclRefExpr *declRef) {
157+ auto *decl = declRef->getDeclRef ().getDecl ();
158+ if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
159+ // DeclRefs to immutable variables are always allowed.
160+ if (varDecl->isImmutable ())
161+ return {true , declRef};
162+
163+ // DeclRefs to mutable variables are only allowed if they are declared
164+ // within the @compilerEvaluable function.
165+ if (varDecl->getDeclContext () == CheckingFunc ||
166+ varDecl->getDeclContext ()->isChildContextOf (CheckingFunc))
167+ return {true , declRef};
168+
169+ TC.diagnose (declRef->getLoc (),
170+ diag::compiler_evaluable_non_local_mutable);
171+ CompilerEvaluable = false ;
172+ return {false , declRef};
173+ } else if (auto *functionDecl = dyn_cast<AbstractFunctionDecl>(decl)) {
174+ return checkAbstractFunctionDeclRef (declRef, functionDecl);
175+ } else if (isa<EnumElementDecl>(decl)) {
176+ return {true , declRef};
177+ } else {
178+ TC.diagnose (declRef->getLoc (),
179+ diag::compiler_evaluable_forbidden_expression)
180+ .highlight (declRef->getSourceRange ());
181+ CompilerEvaluable = false ;
182+ return {false , declRef};
183+ }
184+ }
185+
186+ std::pair<bool , Expr *>
187+ checkExprOtherConstructorDeclRef (OtherConstructorDeclRefExpr *declRef) {
188+ return checkAbstractFunctionDeclRef (declRef, declRef->getDecl ());
189+ }
190+
191+ std::pair<bool , Expr *>
192+ checkAbstractFunctionDeclRef (Expr *declRef, AbstractFunctionDecl *decl) {
193+ // If the function is @compilerEvaluable, allow it.
194+ if (decl->getAttrs ().hasAttribute <CompilerEvaluableAttr>(
195+ /* AllowInvalid=*/ true ))
196+ return {true , declRef};
197+
198+ // If the function is nested within the function that we are checking, allow
199+ // it.
200+ if (decl->isChildContextOf (CheckingFunc))
201+ return {true , declRef};
202+
203+ // For now, allow all builtins.
204+ // TODO: Mark which builtins are actually compiler evaluable.
205+ if (decl->isChildContextOf (TC.Context .TheBuiltinModule ))
206+ return {true , declRef};
207+
208+ // Allow all protocol methods. Later, the interpreter looks up the actual
209+ // function and emits an error when it is not @compilerEvaluable.
210+ if (isa<ProtocolDecl>(decl->getDeclContext ()))
211+ return {true , declRef};
212+
213+ TC.diagnose (declRef->getLoc (),
214+ diag::compiler_evaluable_ref_non_compiler_evaluable);
215+ CompilerEvaluable = false ;
216+ return {false , declRef};
217+ }
38218
39219 std::pair<bool , Stmt *> walkToStmtPre (Stmt *S) override {
40220 if (S->getKind () == StmtKind::While) {
41221 TC.diagnose (S->getStartLoc (), diag::compiler_evaluable_loop);
42- compilerEvaluable = false ;
222+ CompilerEvaluable = false ;
43223 return {false , S};
44224 }
45225 return {true , S};
46226 }
47227
48- bool getCompilerEvaluable () const { return compilerEvaluable ; }
228+ bool getCompilerEvaluable () const { return CompilerEvaluable ; }
49229};
50230
51231} // namespace
@@ -62,7 +242,7 @@ void TypeChecker::checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D) {
62242 assert (D->getBodyKind () == AbstractFunctionDecl::BodyKind::TypeChecked &&
63243 " cannot check @compilerEvaluable body that is not type checked" );
64244
65- CheckCompilerEvaluableBody Checker (*this );
245+ CheckCompilerEvaluableBody Checker (*this , D );
66246 D->getBody ()->walk (Checker);
67247 if (!Checker.getCompilerEvaluable ()) {
68248 compilerEvaluableAttr->setInvalid ();
0 commit comments