Skip to content
5 changes: 5 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ class RecursiveTypeProperties {
Bits &= ~HasDependentMember;
}

/// Remove the IsUnsafe property from this set.
void removeIsUnsafe() {
Bits &= ~IsUnsafe;
}

/// Test for a particular property in this set.
bool operator&(Property prop) const {
return Bits & prop;
Expand Down
66 changes: 60 additions & 6 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3543,6 +3543,12 @@ TypeAliasType *TypeAliasType::get(TypeAliasDecl *typealias, Type parent,
auto &ctx = underlying->getASTContext();
auto arena = getArena(properties);

// Typealiases can't meaningfully be unsafe; it's the underlying type that
// matters.
properties.removeIsUnsafe();
if (underlying->isUnsafe())
properties |= RecursiveTypeProperties::IsUnsafe;

// Profile the type.
llvm::FoldingSetNodeID id;
TypeAliasType::Profile(id, typealias, parent, genericArgs, underlying);
Expand Down Expand Up @@ -4190,6 +4196,54 @@ void UnboundGenericType::Profile(llvm::FoldingSetNodeID &ID,
ID.AddPointer(Parent.getPointer());
}

/// The safety of a parent type does not have an impact on a nested type within
/// it. This produces the recursive properties of a given type that should
/// be propagated to a nested type, which won't include any "IsUnsafe" bit
/// determined based on the declaration itself.
static RecursiveTypeProperties getRecursivePropertiesAsParent(Type type) {
if (!type)
return RecursiveTypeProperties();

// We only need to do anything interesting at all for unsafe types.
auto properties = type->getRecursiveProperties();
if (!properties.isUnsafe())
return properties;

if (auto nominal = type->getAnyNominal()) {
// If the nominal wasn't itself unsafe, then we got the unsafety from
// something else (e.g., a generic argument), so it won't change.
if (nominal->getExplicitSafety() != ExplicitSafety::Unsafe)
return properties;
}

// Drop the "unsafe" bit. We have to recompute it without considering the
// enclosing nominal type.
properties.removeIsUnsafe();

// Check generic arguments of parent types.
while (type) {
// Merge from the generic arguments.
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
for (auto genericArg : boundGeneric->getGenericArgs())
properties |= genericArg->getRecursiveProperties();
}

if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>()) {
type = nominalOrBound->getParent();
continue;
}

if (auto unbound = type->getAs<UnboundGenericType>()) {
type = unbound->getParent();
continue;
}

break;
};

return properties;
}

UnboundGenericType *UnboundGenericType::
get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
llvm::FoldingSetNodeID ID;
Expand All @@ -4198,7 +4252,7 @@ get(GenericTypeDecl *TheDecl, Type Parent, const ASTContext &C) {
RecursiveTypeProperties properties;
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);

auto arena = getArena(properties);

Expand Down Expand Up @@ -4252,7 +4306,7 @@ BoundGenericType *BoundGenericType::get(NominalTypeDecl *TheDecl,
RecursiveTypeProperties properties;
if (TheDecl->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);
for (Type Arg : GenericArgs) {
properties |= Arg->getRecursiveProperties();
}
Expand Down Expand Up @@ -4335,7 +4389,7 @@ EnumType *EnumType::get(EnumDecl *D, Type Parent, const ASTContext &C) {
RecursiveTypeProperties properties;
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);
auto arena = getArena(properties);

auto *&known = C.getImpl().getArena(arena).EnumTypes[{D, Parent}];
Expand All @@ -4353,7 +4407,7 @@ StructType *StructType::get(StructDecl *D, Type Parent, const ASTContext &C) {
RecursiveTypeProperties properties;
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);
auto arena = getArena(properties);

auto *&known = C.getImpl().getArena(arena).StructTypes[{D, Parent}];
Expand All @@ -4371,7 +4425,7 @@ ClassType *ClassType::get(ClassDecl *D, Type Parent, const ASTContext &C) {
RecursiveTypeProperties properties;
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);
auto arena = getArena(properties);

auto *&known = C.getImpl().getArena(arena).ClassTypes[{D, Parent}];
Expand Down Expand Up @@ -5538,7 +5592,7 @@ ProtocolType *ProtocolType::get(ProtocolDecl *D, Type Parent,
RecursiveTypeProperties properties;
if (D->getExplicitSafety() == ExplicitSafety::Unsafe)
properties |= RecursiveTypeProperties::IsUnsafe;
if (Parent) properties |= Parent->getRecursiveProperties();
properties |= getRecursivePropertiesAsParent(Parent);
auto arena = getArena(properties);

auto *&known = C.getImpl().getArena(arena).ProtocolTypes[{D, Parent}];
Expand Down
14 changes: 8 additions & 6 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1212,12 +1212,14 @@ ExplicitSafety Decl::getExplicitSafety() const {
ExplicitSafety::Unspecified);
}

// Inference: Check the enclosing context.
if (auto enclosingDC = getDeclContext()) {
// Is this an extension with @safe or @unsafe on it?
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
return *extSafety;
// Inference: Check the enclosing context, unless this is a type.
if (!isa<TypeDecl>(this)) {
if (auto enclosingDC = getDeclContext()) {
// Is this an extension with @safe or @unsafe on it?
if (auto ext = dyn_cast<ExtensionDecl>(enclosingDC)) {
if (auto extSafety = getExplicitSafetyFromAttrs(ext))
return *extSafety;
}
}
}

Expand Down
25 changes: 25 additions & 0 deletions lib/Sema/TypeCheckEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,8 @@ class EffectsHandlingWalker : public ASTWalker {
recurse = asImpl().checkForEach(forEach);
} else if (auto labeled = dyn_cast<LabeledConditionalStmt>(S)) {
asImpl().noteLabeledConditionalStmt(labeled);
} else if (auto defer = dyn_cast<DeferStmt>(S)) {
recurse = asImpl().checkDefer(defer);
}

if (!recurse)
Expand Down Expand Up @@ -2110,6 +2112,10 @@ class ApplyClassifier {
return ShouldRecurse;
}

ShouldRecurse_t checkDefer(DeferStmt *S) {
return ShouldNotRecurse;
}

ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
return ShouldRecurse;
}
Expand Down Expand Up @@ -2255,6 +2261,10 @@ class ApplyClassifier {
return ShouldRecurse;
}

ShouldRecurse_t checkDefer(DeferStmt *S) {
return ShouldNotRecurse;
}

ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
return ShouldRecurse;
}
Expand Down Expand Up @@ -2354,6 +2364,10 @@ class ApplyClassifier {
return ShouldNotRecurse;
}

ShouldRecurse_t checkDefer(DeferStmt *S) {
return ShouldNotRecurse;
}

ShouldRecurse_t checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
return ShouldRecurse;
}
Expand Down Expand Up @@ -4398,6 +4412,17 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
return ShouldRecurse;
}

ShouldRecurse_t checkDefer(DeferStmt *S) {
// Pretend we're in an 'unsafe'.
ContextScope scope(*this, std::nullopt);
scope.enterUnsafe(S->getDeferLoc());

// Walk the call expression. We don't care about the rest.
S->getCallExpr()->walk(*this);

return ShouldNotRecurse;
}

void diagnoseRedundantTry(AnyTryExpr *E) const {
if (auto *SVE = SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(E)) {
// For an if/switch expression, produce a tailored warning.
Expand Down
79 changes: 68 additions & 11 deletions lib/Sema/TypeCheckUnsafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,12 @@ bool swift::enumerateUnsafeUses(ArrayRef<ProtocolConformanceRef> conformances,
bool swift::enumerateUnsafeUses(SubstitutionMap subs,
SourceLoc loc,
llvm::function_ref<bool(UnsafeUse)> fn) {
// FIXME: Check replacement types?
// Replacement types.
for (auto replacementType : subs.getReplacementTypes()) {
if (replacementType->isUnsafe() &&
fn(UnsafeUse::forReferenceToUnsafe(nullptr, false, replacementType, loc)))
return true;
}

// Check conformances.
if (enumerateUnsafeUses(subs.getConformances(), loc, fn))
Expand Down Expand Up @@ -375,21 +380,73 @@ void swift::diagnoseUnsafeType(ASTContext &ctx, SourceLoc loc, Type type,
if (!ctx.LangOpts.hasFeature(Feature::StrictMemorySafety))
return;

if (!type->isUnsafe() && !type->getCanonicalType()->isUnsafe())
if (!type->isUnsafe())
return;

// Look for a specific @unsafe nominal type.
Type specificType;
type.findIf([&specificType](Type type) {
if (auto typeDecl = type->getAnyNominal()) {
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
specificType = type;
return false;
// Look for a specific @unsafe nominal type along the way.
class Walker : public TypeWalker {
public:
Type specificType;

Action walkToTypePre(Type type) override {
if (specificType)
return Action::Stop;

// If this refers to a nominal type that is @unsafe, store that.
if (auto typeDecl = type->getAnyNominal()) {
if (typeDecl->getExplicitSafety() == ExplicitSafety::Unsafe) {
specificType = type;
return Action::Stop;
}
}

// Do not recurse into nominal types, because we do not want to visit
// their "parent" types.
if (isa<NominalOrBoundGenericNominalType>(type.getPointer()) ||
isa<UnboundGenericType>(type.getPointer())) {
// Recurse into the generic arguments. This operation is recursive,
// because we also need to see the generic arguments of parent types.
walkGenericArguments(type);

return Action::SkipNode;
}

return Action::Continue;
}

private:
/// Recursively walk the generic arguments of this type and its parent
/// types.
void walkGenericArguments(Type type) {
if (!type)
return;

// Walk the generic arguments.
if (auto boundGeneric = type->getAs<BoundGenericType>()) {
for (auto genericArg : boundGeneric->getGenericArgs())
genericArg.walk(*this);
}

if (auto nominalOrBound = type->getAs<NominalOrBoundGenericNominalType>())
return walkGenericArguments(nominalOrBound->getParent());

if (auto unbound = type->getAs<UnboundGenericType>())
return walkGenericArguments(unbound->getParent());
}
};

return false;
});
// Look for a canonical unsafe type.
Walker walker;
type->getCanonicalType().walk(walker);
Type specificType = walker.specificType;

// Look for an unsafe type in the non-canonical type, which is a better answer
// if we can find it.
walker.specificType = Type();
type.walk(walker);
if (specificType && walker.specificType &&
specificType->isEqual(walker.specificType))
specificType = walker.specificType;

diagnose(specificType ? specificType : type);
}
Expand Down
Loading