Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,12 +543,16 @@ class alignas(1 << DeclAlignInBits) Decl {
HasUnreferenceableStorage : 1
);

SWIFT_INLINE_BITFIELD(EnumDecl, NominalTypeDecl, 2+2,
SWIFT_INLINE_BITFIELD(EnumDecl, NominalTypeDecl, 2+2+1,
/// The stage of the raw type circularity check for this class.
Circularity : 2,

/// True if the enum has cases and at least one case has associated values.
HasAssociatedValues : 2
HasAssociatedValues : 2,
/// True if the enum has at least one case that has some availability
/// attribute. A single bit because it's lazily computed along with the
/// HasAssociatedValues bit.
HasAnyUnavailableValues : 1
);

SWIFT_INLINE_BITFIELD(PrecedenceGroupDecl, Decl, 1+2,
Expand Down Expand Up @@ -3220,6 +3224,11 @@ class EnumDecl final : public NominalTypeDecl {
/// Note that this is true for enums with absolutely no cases.
bool hasOnlyCasesWithoutAssociatedValues() const;

/// True if any of the enum cases have availability annotations.
///
/// Note that this is false for enums with absolutely no cases.
bool hasPotentiallyUnavailableCaseValue() const;

/// True if the enum has cases.
bool hasCases() const {
return !getAllElements().empty();
Expand Down
3 changes: 2 additions & 1 deletion include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2096,10 +2096,11 @@ NOTE(construct_raw_representable_from_unwrapped_value,none,
"construct %0 from unwrapped %1 value", (Type, Type))

// Derived conformances

ERROR(cannot_synthesize_in_extension,none,
"implementation of %0 cannot be automatically synthesized in an extension", (Type))

ERROR(broken_case_iterable_requirement,none,
"CaseIterable protocol is broken: unexpected requirement", ())
ERROR(broken_raw_representable_requirement,none,
"RawRepresentable protocol is broken: unexpected requirement", ())
ERROR(broken_equatable_requirement,none,
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#define IDENTIFIER(name) IDENTIFIER_WITH_NAME(name, #name)
#define IDENTIFIER_(name) IDENTIFIER_WITH_NAME(name, "_" #name)

IDENTIFIER(AllCases)
IDENTIFIER(allCases)
IDENTIFIER(alloc)
IDENTIFIER(allocWithZone)
IDENTIFIER(allZeros)
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ PROTOCOL(Comparable)
PROTOCOL(Error)
PROTOCOL_(ErrorCodeProtocol)
PROTOCOL(OptionSet)
PROTOCOL(CaseIterable)

PROTOCOL_(BridgedNSError)
PROTOCOL_(BridgedStoredNSError)
Expand Down
22 changes: 22 additions & 0 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2689,6 +2689,8 @@ EnumDecl::EnumDecl(SourceLoc EnumLoc,
= static_cast<unsigned>(CircularityCheck::Unchecked);
Bits.EnumDecl.HasAssociatedValues
= static_cast<unsigned>(AssociatedValueCheck::Unchecked);
Bits.EnumDecl.HasAnyUnavailableValues
= false;
}

StructDecl::StructDecl(SourceLoc StructLoc, Identifier Name, SourceLoc NameLoc,
Expand Down Expand Up @@ -3041,6 +3043,17 @@ EnumElementDecl *EnumDecl::getElement(Identifier Name) const {
return nullptr;
}

bool EnumDecl::hasPotentiallyUnavailableCaseValue() const {
switch (static_cast<AssociatedValueCheck>(Bits.EnumDecl.HasAssociatedValues)) {
case AssociatedValueCheck::Unchecked:
// Compute below
this->hasOnlyCasesWithoutAssociatedValues();
LLVM_FALLTHROUGH;
default:
return static_cast<bool>(Bits.EnumDecl.HasAnyUnavailableValues);
}
}

bool EnumDecl::hasOnlyCasesWithoutAssociatedValues() const {
// Check whether we already have a cached answer.
switch (static_cast<AssociatedValueCheck>(
Expand All @@ -3056,6 +3069,15 @@ bool EnumDecl::hasOnlyCasesWithoutAssociatedValues() const {
return false;
}
for (auto elt : getAllElements()) {
for (auto Attr : elt->getAttrs()) {
if (auto AvAttr = dyn_cast<AvailableAttr>(Attr)) {
if (!AvAttr->isInvalid()) {
const_cast<EnumDecl*>(this)->Bits.EnumDecl.HasAnyUnavailableValues
= true;
}
}
}

if (elt->hasAssociatedValues()) {
const_cast<EnumDecl*>(this)->Bits.EnumDecl.HasAssociatedValues
= static_cast<unsigned>(AssociatedValueCheck::HasAssociatedValues);
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5856,6 +5856,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::RawRepresentable:
case KnownProtocolKind::Equatable:
case KnownProtocolKind::Hashable:
case KnownProtocolKind::CaseIterable:
case KnownProtocolKind::Comparable:
case KnownProtocolKind::ObjectiveCBridgeable:
case KnownProtocolKind::DestructorSafeContainer:
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_swift_library(swiftSema STATIC
ConstraintGraph.cpp
ConstraintLocator.cpp
ConstraintSystem.cpp
DerivedConformanceCaseIterable.cpp
DerivedConformanceCodable.cpp
DerivedConformanceCodingKey.cpp
DerivedConformanceEquatableHashable.cpp
Expand Down
164 changes: 164 additions & 0 deletions lib/Sema/DerivedConformanceCaseIterable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2016 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See http://swift.org/LICENSE.txt for license information
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements implicit derivation of the CaseIterable protocol.
//
//===----------------------------------------------------------------------===//

#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Types.h"
#include "llvm/Support/raw_ostream.h"
#include "DerivedConformances.h"

using namespace swift;
using namespace DerivedConformance;

/// Common preconditions for CaseIterable.
static bool canDeriveConformance(NominalTypeDecl *type) {
// The type must be an enum.
auto enumDecl = dyn_cast<EnumDecl>(type);
if (!enumDecl)
return false;

// "Simple" enums without availability attributes can derive
// a CaseIterable conformance.
//
// FIXME: Lift the availability restriction.
return !enumDecl->hasPotentiallyUnavailableCaseValue()
&& enumDecl->hasOnlyCasesWithoutAssociatedValues();
}

/// Derive the implementation of allCases for a "simple" no-payload enum.
void deriveCaseIterable_enum_getter(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getDeclContext();
auto *parentEnum = parentDC->getAsEnumOrEnumExtensionContext();
auto enumTy = parentEnum->getDeclaredTypeInContext();
auto &C = parentDC->getASTContext();

SmallVector<Expr *, 8> elExprs;
for (EnumElementDecl *elt : parentEnum->getAllElements()) {
auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
auto *base = TypeExpr::createImplicit(enumTy, C);
auto *apply = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);
elExprs.push_back(apply);
}
auto *arrayExpr = ArrayExpr::create(C, SourceLoc(), elExprs, {}, SourceLoc());

auto *returnStmt = new (C) ReturnStmt(SourceLoc(), arrayExpr);
auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
SourceLoc());
funcDecl->setBody(body);
}

static ArraySliceType *computeAllCasesType(NominalTypeDecl *enumType) {
auto metaTy = enumType->getDeclaredInterfaceType();
if (!metaTy || metaTy->hasError())
return nullptr;

return ArraySliceType::get(metaTy->getRValueInstanceType());
}

static Type deriveCaseIterable_AllCases(TypeChecker &tc, Decl *parentDecl,
EnumDecl *enumDecl) {
// enum SomeEnum : CaseIterable {
// @derived
// typealias AllCases = [SomeEnum]
// }
auto *rawInterfaceType = computeAllCasesType(enumDecl);
return cast<DeclContext>(parentDecl)->mapTypeIntoContext(rawInterfaceType);
}

ValueDecl *DerivedConformance::deriveCaseIterable(TypeChecker &tc,
Decl *parentDecl,
NominalTypeDecl *targetDecl,
ValueDecl *requirement) {
// Conformance can't be synthesized in an extension.
auto caseIterableProto
= tc.Context.getProtocol(KnownProtocolKind::CaseIterable);
auto caseIterableType = caseIterableProto->getDeclaredType();
if (targetDecl != parentDecl) {
tc.diagnose(parentDecl->getLoc(), diag::cannot_synthesize_in_extension,
caseIterableType);
return nullptr;
}

// Check that we can actually derive CaseIterable for this type.
if (!canDeriveConformance(targetDecl))
return nullptr;

// Build the necessary decl.
if (requirement->getBaseName() != tc.Context.Id_allCases) {
tc.diagnose(requirement->getLoc(),
diag::broken_case_iterable_requirement);
return nullptr;
}

auto enumDecl = cast<EnumDecl>(targetDecl);
ASTContext &C = tc.Context;


// Define the property.
auto *returnTy = computeAllCasesType(targetDecl);

VarDecl *propDecl;
PatternBindingDecl *pbDecl;
std::tie(propDecl, pbDecl)
= declareDerivedProperty(tc, parentDecl, enumDecl, C.Id_allCases,
returnTy, returnTy,
/*isStatic=*/true, /*isFinal=*/true);

// Define the getter.
auto *getterDecl = addGetterToReadOnlyDerivedProperty(tc, propDecl, returnTy);

getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);

auto dc = cast<IterableDeclContext>(parentDecl);
dc->addMember(getterDecl);
dc->addMember(propDecl);
dc->addMember(pbDecl);

return propDecl;
}

Type DerivedConformance::deriveCaseIterable(TypeChecker &tc, Decl *parentDecl,
NominalTypeDecl *targetDecl,
AssociatedTypeDecl *assocType) {
// Conformance can't be synthesized in an extension.
auto caseIterableProto
= tc.Context.getProtocol(KnownProtocolKind::CaseIterable);
auto caseIterableType = caseIterableProto->getDeclaredType();
if (targetDecl != parentDecl) {
tc.diagnose(parentDecl->getLoc(), diag::cannot_synthesize_in_extension,
caseIterableType);
return nullptr;
}

// We can only synthesize CaseIterable for enums.
auto enumDecl = dyn_cast<EnumDecl>(targetDecl);
if (!enumDecl)
return nullptr;

// Check that we can actually derive CaseIterable for this type.
if (!canDeriveConformance(targetDecl))
return nullptr;

if (assocType->getName() == tc.Context.Id_AllCases) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the caller already check this?

return deriveCaseIterable_AllCases(tc, parentDecl, enumDecl);
}

tc.diagnose(assocType->getLoc(),
diag::broken_case_iterable_requirement);
return nullptr;
}

17 changes: 16 additions & 1 deletion lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ bool DerivedConformance::derivesProtocolConformance(TypeChecker &tc,
return enumDecl->hasRawType();

// Enums without associated values can implicitly derive Equatable and
// Hashable conformance.
// Hashable conformances.
case KnownProtocolKind::Equatable:
return canDeriveEquatable(tc, enumDecl, protocol);
case KnownProtocolKind::Hashable:
return canDeriveHashable(tc, enumDecl, protocol);
// "Simple" enums without availability attributes can explicitly derive
// a CaseIterable conformance.
//
// FIXME: Lift the availability restriction.
case KnownProtocolKind::CaseIterable:
return !enumDecl->hasPotentiallyUnavailableCaseValue()
&& enumDecl->hasOnlyCasesWithoutAssociatedValues();

// @objc enums can explicitly derive their _BridgedNSError conformance.
case KnownProtocolKind::BridgedNSError:
Expand Down Expand Up @@ -135,6 +142,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (name.isSimpleName(ctx.Id_hashValue))
return getRequirement(KnownProtocolKind::Hashable);

// CaseIterable.allValues
if (name.isSimpleName(ctx.Id_allCases))
return getRequirement(KnownProtocolKind::CaseIterable);

// _BridgedNSError._nsErrorDomain
if (name.isSimpleName(ctx.Id_nsErrorDomain))
return getRequirement(KnownProtocolKind::BridgedNSError);
Expand Down Expand Up @@ -192,6 +203,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (name.isSimpleName(ctx.Id_RawValue))
return getRequirement(KnownProtocolKind::RawRepresentable);

// CaseIterable.AllCases
if (name.isSimpleName(ctx.Id_AllCases))
return getRequirement(KnownProtocolKind::CaseIterable);

return nullptr;
}

Expand Down
19 changes: 19 additions & 0 deletions lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ ValueDecl *getDerivableRequirement(TypeChecker &tc,
NominalTypeDecl *nominal,
ValueDecl *requirement);


/// Derive a CaseIterable requirement for an enum if it has no associated
/// values for any of its cases.
///
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveCaseIterable(TypeChecker &tc,
Decl *parentDecl,
NominalTypeDecl *type,
ValueDecl *requirement);

/// Derive a CaseIterable type witness for an enum if it has no associated
/// values for any of its cases.
///
/// \returns the derived member, which will also be added to the type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the comments on these two functions supposed to be identical? (I also find it a bit strange that the functions are named identically, but I realize that's a pattern that already exists in this file — nonetheless, did you consider something like deriveCaseIterable_property_allCases and deriveCaseIterable_associatedtype_AllCases? Is there some motivation I'm missing here for this naming convention?)

Type deriveCaseIterable(TypeChecker &tc,
Decl *parentDecl,
NominalTypeDecl *type,
AssociatedTypeDecl *assocType);

/// Derive a RawRepresentable requirement for an enum, if it has a valid
/// raw type and raw values for all of its cases.
///
Expand Down
7 changes: 7 additions & 0 deletions lib/Sema/TypeCheckDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9012,6 +9012,13 @@ void TypeChecker::synthesizeMemberForLookup(NominalTypeDecl *target,
auto *encodableProto = Context.getProtocol(KnownProtocolKind::Encodable);
if (!evaluateTargetConformanceTo(decodableProto))
(void)evaluateTargetConformanceTo(encodableProto);
} else if (baseName.getIdentifier() == Context.Id_allCases ||
baseName.getIdentifier() == Context.Id_AllCases) {
// If the target should conform to the CaseIterable protocol, check the
// conformance here to attempt synthesis.
auto *caseIterableProto
= Context.getProtocol(KnownProtocolKind::CaseIterable);
(void)evaluateTargetConformanceTo(caseIterableProto);
}
} else {
auto argumentNames = member.getArgumentNames();
Expand Down
14 changes: 11 additions & 3 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4712,11 +4712,17 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
return DerivedConformance::deriveRawRepresentable(*this, Decl,
TypeDecl, Requirement);

case KnownProtocolKind::CaseIterable:
return DerivedConformance::deriveCaseIterable(*this, Decl,
TypeDecl, Requirement);

case KnownProtocolKind::Equatable:
return DerivedConformance::deriveEquatable(*this, Decl, TypeDecl, Requirement);
return DerivedConformance::deriveEquatable(*this, Decl, TypeDecl,
Requirement);

case KnownProtocolKind::Hashable:
return DerivedConformance::deriveHashable(*this, Decl, TypeDecl, Requirement);
return DerivedConformance::deriveHashable(*this, Decl, TypeDecl,
Requirement);

case KnownProtocolKind::BridgedNSError:
return DerivedConformance::deriveBridgedNSError(*this, Decl, TypeDecl,
Expand Down Expand Up @@ -4752,7 +4758,9 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC,
case KnownProtocolKind::RawRepresentable:
return DerivedConformance::deriveRawRepresentable(*this, Decl,
TypeDecl, AssocType);

case KnownProtocolKind::CaseIterable:
return DerivedConformance::deriveCaseIterable(*this, Decl,
TypeDecl, AssocType);
default:
return nullptr;
}
Expand Down
Loading