Skip to content

Commit 31e656c

Browse files
marcrasiMarc Rasi
authored andcommitted
expression checking for @compilerEvaluable funcs (#17256)
It's mostly just a list of allowed expressions. DeclRefExpr has interesting logic.
1 parent d0f0729 commit 31e656c

File tree

3 files changed

+522
-9
lines changed

3 files changed

+522
-9
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,6 +2541,14 @@ ERROR(compiler_evaluable_bad_context,none,
25412541
"@compilerEvaluable functions not allowed here", ())
25422542
ERROR(compiler_evaluable_loop,none,
25432543
"loops not allowed in @compilerEvaluable functions", ())
2544+
ERROR(compiler_evaluable_forbidden_expression,none,
2545+
"expression not allowed in @compilerEvaluable functions", ())
2546+
ERROR(compiler_evaluable_non_local_mutable,none,
2547+
"referencing non-local mutable variables not allowed in @compilerEvaluable functions", ())
2548+
ERROR(compiler_evaluable_forbidden_type,none,
2549+
"type %0 cannot be used in @compilerEvaluable functions", (Type))
2550+
ERROR(compiler_evaluable_ref_non_compiler_evaluable,none,
2551+
"@compilerEvaluable functions may not reference non-@compilerEvaluable functions", ())
25442552

25452553
//------------------------------------------------------------------------------
25462554
// MARK: Type Check Expressions

lib/Sema/TypeCheckCompilerEvaluable.cpp

Lines changed: 188 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,206 @@ using namespace swift;
2626

2727
namespace {
2828

29-
/// Checks that the body of a function is compiler evaluable.
30-
/// Currently a skeleton implementation that only rejects while loops.
29+
/// Checks that a type is compiler representable.
30+
/// Currently a skeleton implementation that only rejects types named Float,
31+
/// Double and String.
3132
/// TODO(marcrasi): Fill in a real implementation.
33+
static bool checkCompilerRepresentable(const Type &type) {
34+
return type.getString() != "Double" && type.getString() != "Float" &&
35+
type.getString() != "String";
36+
}
37+
38+
/// Checks that the body of a function is compiler evaluable.
3239
class CheckCompilerEvaluableBody : public ASTWalker {
3340
TypeChecker &TC;
34-
bool compilerEvaluable = true;
3541

36-
public:
37-
CheckCompilerEvaluableBody(TypeChecker &TC) : TC(TC) {}
42+
// The function whose body we are checking.
43+
const AbstractFunctionDecl *CheckingFunc;
44+
45+
// Whether the body has passed the check.
46+
bool CompilerEvaluable = true;
47+
48+
public:
49+
CheckCompilerEvaluableBody(TypeChecker &TC,
50+
const AbstractFunctionDecl *CheckingFunc)
51+
: TC(TC), CheckingFunc(CheckingFunc) {}
52+
53+
std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
54+
// If this is the ignored part of a DotSyntaxBaseIgnored, then we can accept
55+
// it without walking it.
56+
if (auto *parentDotSyntaxBaseIgnored =
57+
dyn_cast_or_null<DotSyntaxBaseIgnoredExpr>(Parent.getAsExpr()))
58+
if (parentDotSyntaxBaseIgnored->getLHS() == E)
59+
return {false, E};
60+
61+
if (!checkCompilerRepresentable(E->getType())) {
62+
TC.diagnose(E->getLoc(), diag::compiler_evaluable_forbidden_type,
63+
E->getType())
64+
.highlight(E->getSourceRange());
65+
CompilerEvaluable = false;
66+
return {false, E};
67+
}
68+
69+
switch (E->getKind()) {
70+
#define ALWAYS_ALLOWED(ID) \
71+
case ExprKind::ID: \
72+
return {true, E};
73+
#define SOMETIMES_ALLOWED(ID) \
74+
case ExprKind::ID: \
75+
return checkExpr##ID(cast<ID##Expr>(E));
76+
77+
ALWAYS_ALLOWED(NilLiteral)
78+
ALWAYS_ALLOWED(IntegerLiteral)
79+
ALWAYS_ALLOWED(BooleanLiteral)
80+
ALWAYS_ALLOWED(MagicIdentifierLiteral)
81+
ALWAYS_ALLOWED(DiscardAssignment)
82+
SOMETIMES_ALLOWED(DeclRef)
83+
ALWAYS_ALLOWED(Type)
84+
SOMETIMES_ALLOWED(OtherConstructorDeclRef)
85+
ALWAYS_ALLOWED(DotSyntaxBaseIgnored)
86+
ALWAYS_ALLOWED(MemberRef)
87+
ALWAYS_ALLOWED(Paren)
88+
ALWAYS_ALLOWED(DotSelf)
89+
ALWAYS_ALLOWED(Try)
90+
ALWAYS_ALLOWED(ForceTry)
91+
ALWAYS_ALLOWED(OptionalTry)
92+
ALWAYS_ALLOWED(Tuple)
93+
ALWAYS_ALLOWED(Subscript)
94+
ALWAYS_ALLOWED(TupleElement)
95+
ALWAYS_ALLOWED(CaptureList)
96+
ALWAYS_ALLOWED(Closure)
97+
ALWAYS_ALLOWED(AutoClosure)
98+
ALWAYS_ALLOWED(InOut)
99+
ALWAYS_ALLOWED(DynamicType)
100+
ALWAYS_ALLOWED(RebindSelfInConstructor)
101+
ALWAYS_ALLOWED(BindOptional)
102+
ALWAYS_ALLOWED(OptionalEvaluation)
103+
ALWAYS_ALLOWED(ForceValue)
104+
SOMETIMES_ALLOWED(Call)
105+
ALWAYS_ALLOWED(PrefixUnary)
106+
ALWAYS_ALLOWED(PostfixUnary)
107+
ALWAYS_ALLOWED(Binary)
108+
ALWAYS_ALLOWED(DotSyntaxCall)
109+
ALWAYS_ALLOWED(ConstructorRefCall)
110+
ALWAYS_ALLOWED(Load)
111+
ALWAYS_ALLOWED(TupleShuffle)
112+
ALWAYS_ALLOWED(InjectIntoOptional)
113+
ALWAYS_ALLOWED(Coerce)
114+
ALWAYS_ALLOWED(If)
115+
ALWAYS_ALLOWED(Assign)
116+
ALWAYS_ALLOWED(CodeCompletion)
117+
ALWAYS_ALLOWED(EditorPlaceholder)
118+
119+
// Allow all errors and unchecked expressions so that we don't put errors
120+
// on top of expressions that alrady have errors.
121+
ALWAYS_ALLOWED(Error)
122+
ALWAYS_ALLOWED(UnresolvedTypeConversion)
123+
#define UNCHECKED_EXPR(ID, PARENT) ALWAYS_ALLOWED(ID)
124+
#include "swift/AST/ExprNodes.def"
125+
126+
default:
127+
TC.diagnose(E->getStartLoc(),
128+
diag::compiler_evaluable_forbidden_expression)
129+
.highlight(E->getSourceRange());
130+
CompilerEvaluable = false;
131+
return {false, E};
132+
133+
#undef ALWAYS_ALLOWED
134+
#undef SOMETIMES_ALLOWED
135+
}
136+
}
137+
138+
std::pair<bool, Expr *> checkExprCall(CallExpr *call) {
139+
// TODO(SR-8035): Eliminate this special case.
140+
// Allow calls to some stdlib assertion functions without walking them
141+
// further, because the calls do currently-forbidden things. (They use
142+
// Strings and they call functions imported from C).
143+
if (auto *calleeRef = dyn_cast<DeclRefExpr>(call->getDirectCallee()))
144+
if (auto *callee = dyn_cast<AbstractFunctionDecl>(calleeRef->getDecl()))
145+
if (callee->isChildContextOf(TC.Context.TheStdlibModule) &&
146+
(callee->getNameStr() == "_precondition" ||
147+
callee->getNameStr() == "_preconditionFailure" ||
148+
callee->getNameStr() == "_sanityCheck" ||
149+
callee->getNameStr() == "fatalError"))
150+
return {false, call};
151+
152+
// Otherwise, walk everything in the expression.
153+
return {true, call};
154+
}
155+
156+
std::pair<bool, Expr *> checkExprDeclRef(DeclRefExpr *declRef) {
157+
auto *decl = declRef->getDeclRef().getDecl();
158+
if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
159+
// DeclRefs to immutable variables are always allowed.
160+
if (varDecl->isImmutable())
161+
return {true, declRef};
162+
163+
// DeclRefs to mutable variables are only allowed if they are declared
164+
// within the @compilerEvaluable function.
165+
if (varDecl->getDeclContext() == CheckingFunc ||
166+
varDecl->getDeclContext()->isChildContextOf(CheckingFunc))
167+
return {true, declRef};
168+
169+
TC.diagnose(declRef->getLoc(),
170+
diag::compiler_evaluable_non_local_mutable);
171+
CompilerEvaluable = false;
172+
return {false, declRef};
173+
} else if (auto *functionDecl = dyn_cast<AbstractFunctionDecl>(decl)) {
174+
return checkAbstractFunctionDeclRef(declRef, functionDecl);
175+
} else if (isa<EnumElementDecl>(decl)) {
176+
return {true, declRef};
177+
} else {
178+
TC.diagnose(declRef->getLoc(),
179+
diag::compiler_evaluable_forbidden_expression)
180+
.highlight(declRef->getSourceRange());
181+
CompilerEvaluable = false;
182+
return {false, declRef};
183+
}
184+
}
185+
186+
std::pair<bool, Expr *>
187+
checkExprOtherConstructorDeclRef(OtherConstructorDeclRefExpr *declRef) {
188+
return checkAbstractFunctionDeclRef(declRef, declRef->getDecl());
189+
}
190+
191+
std::pair<bool, Expr *>
192+
checkAbstractFunctionDeclRef(Expr *declRef, AbstractFunctionDecl *decl) {
193+
// If the function is @compilerEvaluable, allow it.
194+
if (decl->getAttrs().hasAttribute<CompilerEvaluableAttr>(
195+
/*AllowInvalid=*/true))
196+
return {true, declRef};
197+
198+
// If the function is nested within the function that we are checking, allow
199+
// it.
200+
if (decl->isChildContextOf(CheckingFunc))
201+
return {true, declRef};
202+
203+
// For now, allow all builtins.
204+
// TODO: Mark which builtins are actually compiler evaluable.
205+
if (decl->isChildContextOf(TC.Context.TheBuiltinModule))
206+
return {true, declRef};
207+
208+
// Allow all protocol methods. Later, the interpreter looks up the actual
209+
// function and emits an error when it is not @compilerEvaluable.
210+
if (isa<ProtocolDecl>(decl->getDeclContext()))
211+
return {true, declRef};
212+
213+
TC.diagnose(declRef->getLoc(),
214+
diag::compiler_evaluable_ref_non_compiler_evaluable);
215+
CompilerEvaluable = false;
216+
return {false, declRef};
217+
}
38218

39219
std::pair<bool, Stmt *> walkToStmtPre(Stmt *S) override {
40220
if (S->getKind() == StmtKind::While) {
41221
TC.diagnose(S->getStartLoc(), diag::compiler_evaluable_loop);
42-
compilerEvaluable = false;
222+
CompilerEvaluable = false;
43223
return {false, S};
44224
}
45225
return {true, S};
46226
}
47227

48-
bool getCompilerEvaluable() const { return compilerEvaluable; }
228+
bool getCompilerEvaluable() const { return CompilerEvaluable; }
49229
};
50230

51231
} // namespace
@@ -62,7 +242,7 @@ void TypeChecker::checkFunctionBodyCompilerEvaluable(AbstractFunctionDecl *D) {
62242
assert(D->getBodyKind() == AbstractFunctionDecl::BodyKind::TypeChecked &&
63243
"cannot check @compilerEvaluable body that is not type checked");
64244

65-
CheckCompilerEvaluableBody Checker(*this);
245+
CheckCompilerEvaluableBody Checker(*this, D);
66246
D->getBody()->walk(Checker);
67247
if (!Checker.getCompilerEvaluable()) {
68248
compilerEvaluableAttr->setInvalid();

0 commit comments

Comments
 (0)