Skip to content
Merged
324 changes: 320 additions & 4 deletions lib/Sema/DerivedConformanceTensorArrayProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
auto *structDecl = dyn_cast<StructDecl>(nominal);
if (!structDecl)
return false;
// All stored properties must conform to `TensorArrayProtocol`.
// All stored properties must conform to `TensorGroup`.
auto &C = nominal->getASTContext();
auto *tensorArrayProto =
C.getProtocol(KnownProtocolKind::TensorArrayProtocol);
auto *tensorGroupProto =
C.getProtocol(KnownProtocolKind::TensorGroup);
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
if (!v->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(v);
if (!v->hasInterfaceType())
return false;
auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC,
return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC,
ConformanceCheckFlags::Used);
});
}
Expand All @@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
return lookup.front();
}

// Return the protocol requirement with the specified name.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
auto lookup = proto->lookupDirect(name);
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
[](ValueDecl *v) {
return !isa<ProtocolDecl>(
v->getDeclContext()) ||
!v->isProtocolRequirement();
}),
lookup.end());
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
return lookup.front();
}

// Synthesize body for `_unpackTensorHandles(into:)`.
static void
deriveBodyTensorArrayProtocol_unpackTensorHandles(
Expand Down Expand Up @@ -349,12 +363,314 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
return tensorHandleCountDecl;
}


/// Derive the body for the '_typeList' getter.
static void
deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getParent();
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);

// Concatenate all member `_typeList` arrays.
Type arrayType = BoundGenericType::get(
C.getArrayDecl(), Type(),
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
ValueDecl *plusOpDecl = plusOpLookup.front();
auto plusOpDRE = new (C)
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
auto plusOpExpr = new (C)
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
for (auto member : nominal->getStoredProperties()) {
auto memberType =
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
auto *memberTypeListExpr = new (C)
MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
DeclNameLoc(), /*Implicit*/ true);
// Create expression `lhsArg + rhsArg`.
auto *plusOpArgs =
TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
{}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
/*Implicit*/ true);
typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
/*Implicit*/ true);
}

// Return the resulting data types array.
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
/*Implicit*/ true);
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
/*Implicit*/ true));
}

/// Derive a '_typeList' implementation.
static ValueDecl *deriveTensorArrayProtocol_typeList(
DerivedConformance &derived) {
auto nominal = derived.Nominal;
auto &TC = derived.TC;
ASTContext &C = TC.Context;

auto parentDC = derived.getConformanceContext();
Type dataTypeArrayType = BoundGenericType::get(
C.getArrayDecl(), Type(),
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);

// Create `_typeList` property declaration.
VarDecl *typeListDecl;
PatternBindingDecl *patDecl;
std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
C.Id_typeList, returnType, returnType, /*isStatic*/ false,
/*isFinal*/ false);

// Add `@inlinable` to the `_typeList` declaration.
if (nominal->getEffectiveAccess() > AccessLevel::Internal)
typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));

// Create `_typeList` getter.
auto *getterDecl = derived.declareDerivedPropertyGetter(
TC, typeListDecl, returnType);
getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList);
typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
SourceLoc(), {getterDecl}, SourceLoc());
derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl});

return typeListDecl;
}

// Synthesize body for `init(_owning:count:)`.
static void
deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

// Obtain the address type.
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
auto baseAddressType = BoundGenericType::get(
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
auto addressType = BoundGenericType::get(
C.getOptionalDecl(), Type(), {baseAddressType});
auto *addressTE = TypeExpr::createImplicit(addressType, C);

// Get references to `self` and parameter declarations.
auto *selfDecl = funcDecl->getImplicitSelfDecl();
auto *selfDRE = new (C)
DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
auto *paramDecl = funcDecl->getParameters()->get(0);
auto *paramDRE = new (C)
DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);

// Create an `if var` statement for the current address.
VarDecl *currAddressDecl = new (C) VarDecl(
/*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false,
SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
currAddressDecl->setImplicit();
currAddressDecl->setHasNonPatternBindingInit(true);
currAddressDecl->setInterfaceType(baseAddressType);
currAddressDecl->setValidationToChecked();

Pattern *currAddressPat = new (C)
NamedPattern(currAddressDecl, /*implicit*/ true);
currAddressPat = new (C)
VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat,
/*implicit*/ true);
currAddressPat = new (C)
OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(),
/*implicit*/ true);
StmtConditionElement cond[] = {
StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};

// Get the necessary protocol requirements.
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
auto *tensorArrayProto = C.getProtocol(
KnownProtocolKind::TensorArrayProtocol);
auto initName = DeclName(
C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
auto *tensorHandleCountReq = getProtocolRequirement(
tensorArrayProto, C.Id_tensorHandleCount);

Type intType = C.getIntDecl()->getDeclaredType();
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);

// Iterate over members and call `self.t = T(_owning:)`.
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
llvm::SmallVector<ASTNode, 2> elseMemberExprs;
for (auto member : nominal->getStoredProperties()) {
auto memberType = parentDC->mapTypeIntoContext(
member->getValueInterfaceType());
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
auto module = nominal->getModuleContext();
auto confRef = module->lookupConformance(
memberType, tensorGroupProto);
assert(confRef && "Member does not conform to `TensorGroup`");

// Get member type's constructor, e.g. `MemberType.init(_owning:)`.
// Use protocol requirement declaration for the method by default: this
// will be dynamically dispatched.
ValueDecl *memberInitDecl = initReq;
// If conformance reference is concrete, then use concrete witness
// declaration for the constructor.
if (confRef->isConcrete())
memberInitDecl = confRef->getConcrete()->getWitnessDecl(
initReq, C.getLazyResolver());
assert(memberInitDecl && "Member constructor declaration must exist");
auto memberInitDRE = new (C) DeclRefExpr(
memberInitDecl, DeclNameLoc(), /*implicit*/ true);
memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);

// Create reference to member constructor: `MemberType.init(_owning:)`.
auto *memberInitExpr = new (C) ConstructorRefCallExpr(
memberInitDRE, memberTypeExpr);

auto *addressDRE = new (C) DeclRefExpr(
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);

// Initialize the member using its TensorGroup constructor.
// Note that, initialization is dependent on the branch of the
// if-statement taken.
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
auto *thenInitCallExpr = CallExpr::createImplicit(
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});

// Create a nil expression with type UnsafePointer<CTensorHandle>? for the
// `else` branch.
auto *nilDecl = C.getOptionalNoneDecl();
auto *nilDRE = new (C) DeclRefExpr(
nilDecl, DeclNameLoc(), /*implicit*/ true);
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
nilDRE, SourceLoc(), addressTE);
auto *elseInitCallExpr = CallExpr::createImplicit(
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});

// Assign the current member to the result of the initializer call.
auto *memberDRE = new (C) MemberRefExpr(
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);

auto *thenAssignMemberExpr = new (C) AssignExpr(
memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
auto *elseAssignMemberExpr = new (C) AssignExpr(
memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);

thenMemberExprs.push_back(thenAssignMemberExpr);
elseMemberExprs.push_back(elseAssignMemberExpr);

// Advance the current address.
DeclName advancedName(C, C.getIdentifier("advanced"),
{C.getIdentifier("by")});
auto *advancedMethodExpr =
new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
advancedName, DeclNameLoc(),
/*Implicit*/ true);

// Obtain `MemberType._tensorHandleCount`.
auto *memberCountMRE = new (C) MemberRefExpr(
memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
/*Implicit*/ true);

// Cast the tensor handle count to Int.
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
{Identifier()});
auto *intInitExpr =
new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName,
DeclNameLoc(), /*Implicit*/ true);
auto *intInitCallExpr = CallExpr::createImplicit(
C, intInitExpr, {memberCountMRE}, {Identifier()});

// Assign the new address.
auto *assignAddrCallExpr = CallExpr::createImplicit(
C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
assignAddrCallExpr,
/*Implicit*/ true);

thenMemberExprs.push_back(assignAddrExpr);
}

auto *thenBody = BraceStmt::create(
C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
/*implicit*/ true);

auto *elseBody = BraceStmt::create(
C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
/*implicit*/ true);

auto *ifStmt = new (C)
IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
/*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
/*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);

funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
/*implicit*/ true));
}

// Synthesize the `init(_owning:count:)` function declaration.
static ValueDecl
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
auto &C = derived.TC.Context;
auto nominal = derived.Nominal;
auto parentDC = derived.getConformanceContext();

// Obtain the address type.
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
Type baseAddressType = BoundGenericType::get(
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
Type addressType = BoundGenericType::get(
C.getOptionalDecl(), Type(), {baseAddressType});
Type intType = C.getIntDecl()->getDeclaredType();

auto *param1 = new (C) ParamDecl(
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"),
parentDC);
param1->setInterfaceType(addressType);
auto *param2 = new (C) ParamDecl(
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC);
param2->setInterfaceType(intType);
ParameterList *params = ParameterList::create(C, {param1, param2});

DeclName name(C, DeclBaseName::createConstructor(), params);
auto *initDecl =
new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(),
/*Throws*/ false, SourceLoc(), params,
/*GenericParams*/ nullptr, parentDC);
initDecl->setImplicit();
initDecl->setSynthesized();
initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init);

if (auto env = parentDC->getGenericEnvironmentOfContext())
initDecl->setGenericEnvironment(env);
initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false));
initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
initDecl->setValidationToChecked();

derived.addMembersToConformanceContext({initDecl});
C.addSynthesizedDecl(initDecl);

return initDecl;
}

ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
ValueDecl *requirement) {
if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles)
return deriveTensorArrayProtocol_unpackTensorHandles(*this);
if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount)
return deriveTensorArrayProtocol_tensorHandleCount(*this);
if (requirement->getBaseName() == TC.Context.Id_typeList)
return deriveTensorArrayProtocol_typeList(*this);
if (requirement->getBaseName() == DeclBaseName::createConstructor())
return deriveTensorArrayProtocol_init(*this);
TC.diagnose(requirement->getLoc(),
diag::broken_tensor_array_protocol_requirement);
return nullptr;
Expand Down
12 changes: 12 additions & 0 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
// TensorArrayProtocol._tensorHandleCount
if (name.isSimpleName(ctx.Id_tensorHandleCount))
return getRequirement(KnownProtocolKind::TensorArrayProtocol);

// SWIFT_ENABLE_TENSORFLOW
// TensorArrayProtocol._typeList
if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic())
return getRequirement(KnownProtocolKind::TensorArrayProtocol);

// SWIFT_ENABLE_TENSORFLOW
// TensorGroup._typeList
Expand Down Expand Up @@ -340,6 +345,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (argumentNames[0] == ctx.getIdentifier("_owning")) {
return getRequirement(KnownProtocolKind::TensorGroup);
}
} else if (argumentNames.size() == 2) {
// SWIFT_ENABLE_TENSORFLOW
// TensorArrayProtocol.init(_owning:count)
if (argumentNames[0] == ctx.getIdentifier("_owning") &&
argumentNames[1] == ctx.getIdentifier("count")) {
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
}
}

return nullptr;
Expand Down
Loading