Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ IDENTIFIER_(tensorHandleCount)
IDENTIFIER_(typeList)
// AdditiveArithmetic, VectorProtocol
IDENTIFIER(zero)
IDENTIFIER(Scalar)
IDENTIFIER(VectorSpaceScalar)
// Differentiable
IDENTIFIER(AllDifferentiableVariables)
IDENTIFIER(TangentVector)
Expand Down
52 changes: 28 additions & 24 deletions lib/Sema/DerivedConformanceAdditiveArithmeticVectorProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ enum MathOperator {
Add,
// `-(Self, Self)`, `AdditiveArithmetic` requirement
Subtract,
// `*(Scalar, Self)`, `VectorProtocol` requirement
// `*(VectorSpaceScalar, Self)`, `VectorProtocol` requirement
ScalarMultiply
};

Expand Down Expand Up @@ -76,52 +76,54 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
return lookup.front();
}

// Return the `Scalar` associated type for the given `ValueDecl` if it conforms
// to `VectorProtocol` in the given context. Otherwise, return `nullptr`.
static Type getVectorProtocolScalarAssocType(VarDecl *varDecl, DeclContext *DC) {
// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if
// it conforms to `VectorProtocol` in the given context. Otherwise, return
// `nullptr`.
static Type getVectorProtocolVectorSpaceScalarAssocType(
VarDecl *varDecl, DeclContext *DC) {
auto &C = varDecl->getASTContext();
auto *vectorNumericProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
if (!varDecl->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(varDecl);
if (!varDecl->hasInterfaceType())
return nullptr;
auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
auto conf = TypeChecker::conformsToProtocol(varType, vectorNumericProto, DC,
None);
auto conf = TypeChecker::conformsToProtocol(varType, vectorProto, DC, None);
if (!conf)
return nullptr;
Type scalarType = conf->getTypeWitnessByName(varType, C.Id_Scalar);
assert(scalarType && "'Scalar' associated type not found");
return scalarType;
conf->dump();
return conf->getTypeWitnessByName(varType, C.Id_VectorSpaceScalar);
}

// Return the `Scalar` associated type for the given nominal type in the given
// context, or `nullptr` if `Scalar` cannot be derived.
static Type deriveVectorProtocol_Scalar(NominalTypeDecl *nominal,
DeclContext *DC) {
// Return the `VectorSpaceScalar` associated type for the given nominal type in
// the given context, or `nullptr` if `VectorSpaceScalar` cannot be derived.
static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
DeclContext *DC) {
auto &C = DC->getASTContext();
// Nominal type must be a struct. (Zero stored properties is okay.)
if (!isa<StructDecl>(nominal))
return nullptr;
// If all stored properties conform to `VectorProtocol` and have the same
// `Scalar` associated type, return that `Scalar` associated type.
// Otherwise, the `Scalar` type cannot be derived.
// `VectorSpaceScalar` associated type, return that `VectorSpaceScalar`
// associated type. Otherwise, the `VectorSpaceScalar` type cannot be derived.
Type sameScalarType;
for (auto member : nominal->getStoredProperties()) {
if (!member->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(member);
if (!member->hasInterfaceType())
return nullptr;
auto scalarType = getVectorProtocolScalarAssocType(member, DC);
auto scalarType = getVectorProtocolVectorSpaceScalarAssocType(member, DC);
// If stored property does not conform to `VectorProtocol`, return nullptr.
if (!scalarType)
return nullptr;
// If same `Scalar` type has not been set, set it for the first time.
// If same `VectorSpaceScalar` type has not been set, set it for the first
// time.
if (!sameScalarType) {
sameScalarType = scalarType;
continue;
}
// If stored property `Scalar` types do not match, return nullptr.
// If stored property `VectorSpaceScalar` types do not match, return
// nullptr.
if (!scalarType->isEqual(sameScalarType))
return nullptr;
}
Expand Down Expand Up @@ -169,8 +171,8 @@ bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
// value information.
if (hasLetStoredPropertyWithInitialValue(nominal))
return false;
// Must be able to derive `Scalar` associated type.
return bool(deriveVectorProtocol_Scalar(nominal, DC));
// Must be able to derive `VectorSpaceScalar` associated type.
return bool(deriveVectorProtocol_VectorSpaceScalar(nominal, DC));
}

// Synthesize body for the given math operator.
Expand Down Expand Up @@ -303,7 +305,8 @@ static ValueDecl *deriveMathOperator(DerivedConformance &derived,
return std::make_pair(selfInterfaceType, selfInterfaceType);
case ScalarMultiply:
return std::make_pair(
deriveVectorProtocol_Scalar(nominal, parentDC)->mapTypeOutOfContext(),
deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC)
->mapTypeOutOfContext(),
selfInterfaceType);
}
};
Expand Down Expand Up @@ -480,8 +483,9 @@ Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) {
// Diagnose conformances in disallowed contexts.
if (checkAndDiagnoseDisallowedContext(requirement))
return nullptr;
if (requirement->getBaseName() == TC.Context.Id_Scalar)
return deriveVectorProtocol_Scalar(Nominal, getConformanceContext());
if (requirement->getBaseName() == TC.Context.Id_VectorSpaceScalar)
return deriveVectorProtocol_VectorSpaceScalar(
Nominal, getConformanceContext());
TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement);
return nullptr;
}
15 changes: 8 additions & 7 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,8 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType());
auto *vecNumProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
auto vecNumType = TypeLoc::withoutLoc(vecNumProto->getDeclaredType());
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType());
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());

Expand All @@ -651,17 +651,18 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
});

// Associated struct can derive `VectorProtocol` if the associated types of
// all members conform to `VectorProtocol` and share the same scalar type.
// all members conform to `VectorProtocol` and share the same
// `VectorSpaceScalar` type.
Type sameScalarType;
bool canDeriveVectorProtocol =
canDeriveAdditiveArithmetic && !diffProperties.empty() &&
llvm::all_of(diffProperties, [&](VarDecl *vd) {
auto conf = TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
vecNumProto, nominal,
None);
vectorProto, nominal, None);
if (!conf)
return false;
Type scalarType = conf->getTypeWitnessByName(vd->getType(), C.Id_Scalar);
auto scalarType =
conf->getTypeWitnessByName(vd->getType(), C.Id_VectorSpaceScalar);
if (!sameScalarType) {
sameScalarType = scalarType;
return true;
Expand All @@ -685,7 +686,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
// type, make the associated struct conform to `VectorProtocol` instead of
// just `AdditiveArithmetic`.
if (canDeriveVectorProtocol)
inherited.push_back(vecNumType);
inherited.push_back(vectorType);

auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(),
/*Inherited*/ C.AllocateCopy(inherited),
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
return getRequirement(KnownProtocolKind::Differentiable);

// SWIFT_ENABLE_TENSORFLOW
// VectorProtocol.Scalar
if (name.isSimpleName(ctx.Id_Scalar))
// VectorProtocol.VectorSpaceScalar
if (name.isSimpleName(ctx.Id_VectorSpaceScalar))
return getRequirement(KnownProtocolKind::VectorProtocol);

return nullptr;
Expand Down
10 changes: 5 additions & 5 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@
/// elements in this vector space and have either no shape or a static shape.
public protocol VectorProtocol : AdditiveArithmetic {
/// The type of scalars in the vector space.
associatedtype Scalar : AdditiveArithmetic
associatedtype VectorSpaceScalar : AdditiveArithmetic

static func * (lhs: Scalar, rhs: Self) -> Self
static func *= (lhs: inout Self, rhs: Scalar)
static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self
static func *= (lhs: inout Self, rhs: VectorSpaceScalar)
}

public extension VectorProtocol {
static func * (lhs: Self, rhs: Scalar) -> Self {
static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self {
return rhs * lhs
}

static func *= (lhs: inout Self, rhs: Scalar) {
static func *= (lhs: inout Self, rhs: VectorSpaceScalar) {
lhs = rhs * lhs
}
}
Expand Down
2 changes: 1 addition & 1 deletion stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,7 @@ extension ${Self} : Strideable {
//===----------------------------------------------------------------------===//

extension ${Self} : VectorProtocol {
public typealias Scalar = ${Self}
public typealias VectorSpaceScalar = ${Self}
}

extension ${Self} : Differentiable {
Expand Down
2 changes: 1 addition & 1 deletion utils/update_checkout/update-checkout-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@
"clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a",
"libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a",
"tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86",
"tensorflow-swift-apis": "b9bcebc7dfd1497e324c24f4c0d8c0b212580d1f",
"tensorflow-swift-apis": "eea90b71a7bf01ad031e11bb4a753c40e19e55d9",
"indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a",
"sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a"
}
Expand Down