From e969b513135795a66db8f47f5b0896545b26b5a3 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Tue, 23 Apr 2019 19:17:25 -0400 Subject: [PATCH 01/11] Changed 'TensorArrayProtocol' such that it can be used to support output tensor arrays in raw ops. --- .../DerivedConformanceTensorArrayProtocol.cpp | 256 +++++++++++++++++- lib/Sema/DerivedConformances.cpp | 7 + stdlib/public/TensorFlow/CMakeLists.txt | 4 +- stdlib/public/TensorFlow/TensorGroup.swift | 23 +- .../tensor_array_protocol.swift | 6 +- test/TensorFlowRuntime/tracer.swift | 12 + utils/build-script-impl | 4 +- utils/build_swift/driver_arguments.py | 3 +- 8 files changed, 302 insertions(+), 13 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index 990d65c5532a1..f9998b896a3f6 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal, auto *structDecl = dyn_cast(nominal); if (!structDecl) return false; - // All stored properties must conform to `TensorArrayProtocol`. + // All stored properties must conform to `TensorGroup`. auto &C = nominal->getASTContext(); - auto *tensorArrayProto = - C.getProtocol(KnownProtocolKind::TensorArrayProtocol); + auto *tensorGroupProto = + C.getProtocol(KnownProtocolKind::TensorGroup); return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) { if (!v->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(v); if (!v->hasInterfaceType()) return false; auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); - return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC, + return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC, ConformanceCheckFlags::Used); }); } @@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { return lookup.front(); } +// Return the protocol requirement with the specified name. +static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) { + auto lookup = proto->lookupDirect(name); + lookup.erase(std::remove_if(lookup.begin(), lookup.end(), + [](ValueDecl *v) { + return !isa( + v->getDeclContext()) || + !v->isProtocolRequirement(); + }), + lookup.end()); + assert(lookup.size() == 1 && "Ambiguous protocol requirement"); + return lookup.front(); +} + // Synthesize body for `_unpackTensorHandles(into:)`. static void deriveBodyTensorArrayProtocol_unpackTensorHandles( @@ -349,12 +363,246 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount( return tensorHandleCountDecl; } +// Synthesize body for `init(_owning:count:)`. +static void +deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + // Obtain the address type. + auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); + auto baseAddressType = BoundGenericType::get( + C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); + auto addressType = BoundGenericType::get( + C.getOptionalDecl(), Type(), {baseAddressType}); + auto *addressTE = TypeExpr::createImplicit(addressType, C); + + // Get references to `self` and parameter declarations. + auto *selfDecl = funcDecl->getImplicitSelfDecl(); + auto *selfDRE = new (C) + DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + auto *paramDecl = funcDecl->getParameters()->get(0); + auto *paramDRE = new (C) + DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); + + // Create an `if var` statement for the current address. + VarDecl *currAddressDecl = new (C) VarDecl( + /*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false, + SourceLoc(), C.getIdentifier("currentAddress"), funcDecl); + currAddressDecl->setImplicit(); + currAddressDecl->setHasNonPatternBindingInit(true); + currAddressDecl->setInterfaceType(baseAddressType); + currAddressDecl->setValidationToChecked(); + + Pattern *currAddressPat = new (C) + NamedPattern(currAddressDecl, /*implicit*/ true); + currAddressPat = new (C) + VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat, + /*implicit*/ true); + currAddressPat = new (C) + OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(), + /*implicit*/ true); + StmtConditionElement cond[] = { + StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)}; + + // Get the necessary protocol requirements. + auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup); + auto *tensorArrayProto = C.getProtocol( + KnownProtocolKind::TensorArrayProtocol); + auto initName = DeclName( + C, DeclBaseName::createConstructor(), + {C.getIdentifier("_owning"), C.getIdentifier("count")}); + auto *initReq = getProtocolRequirement(tensorArrayProto, initName); + auto *tensorHandleCountReq = getProtocolRequirement( + tensorArrayProto, C.Id_tensorHandleCount); + + Type intType = C.getIntDecl()->getDeclaredType(); + TypeExpr *intTE = TypeExpr::createImplicit(intType, C); + + // Goes through the member TensorGroups and call + // `self.t = T(_owning:count:)`. + llvm::SmallVector thenMemberExprs; + llvm::SmallVector elseMemberExprs; + for (auto member : nominal->getStoredProperties()) { + auto memberType = parentDC->mapTypeIntoContext( + member->getValueInterfaceType()); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto module = nominal->getModuleContext(); + auto confRef = module->lookupConformance( + memberType, tensorGroupProto); + assert(confRef && "Member does not conform to `TensorGroup`"); + + // Get member type's constructor, e.g. `MemberType.init(_owning:)`. + // Use protocol requirement declaration for the method by default: this + // will be dynamically dispatched. + ValueDecl *memberInitDecl = initReq; + // If conformance reference is concrete, then use concrete witness + // declaration for the constructor. + if (confRef->isConcrete()) + memberInitDecl = confRef->getConcrete()->getWitnessDecl( + initReq, C.getLazyResolver()); + assert(memberInitDecl && "Member constructor declaration must exist"); + auto memberInitDRE = new (C) DeclRefExpr( + memberInitDecl, DeclNameLoc(), /*implicit*/ true); + memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + + // Create reference to member constructor: `MemberType.init(_owning:)`. + auto *memberInitExpr = new (C) ConstructorRefCallExpr( + memberInitDRE, memberTypeExpr); + + auto *addressDRE = new (C) DeclRefExpr( + currAddressDecl, DeclNameLoc(), /*implicit*/ true); + auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType); + + // Initialize the member using its TensorGroup constructor. + // Note that, initialization is dependent on the branch of the + // if-statement taken. + auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType); + auto *thenInitCallExpr = CallExpr::createImplicit( + C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")}); + + // Create a nil expression with type UnsafePointer? for the + // `else` branch. + auto *nilDecl = C.getOptionalNoneDecl(); + auto *nilDRE = new (C) DeclRefExpr( + nilDecl, DeclNameLoc(), /*implicit*/ true); + auto *elseInitExpr = new (C) DotSyntaxCallExpr( + nilDRE, SourceLoc(), addressTE); + auto *elseInitCallExpr = CallExpr::createImplicit( + C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")}); + + // Assign the current member to the result of the initializer call. + auto *memberDRE = new (C) MemberRefExpr( + selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); + + auto *thenAssignMemberExpr = new (C) AssignExpr( + memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true); + auto *elseAssignMemberExpr = new (C) AssignExpr( + memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true); + + thenMemberExprs.push_back(thenAssignMemberExpr); + elseMemberExprs.push_back(elseAssignMemberExpr); + + // Advance the current address. + DeclName advancedName(C, C.getIdentifier("advanced"), + {C.getIdentifier("by")}); + auto *advancedMethodExpr = + new (C) UnresolvedDotExpr(addressDRE, SourceLoc(), + advancedName, DeclNameLoc(), + /*Implicit*/ true); + + // Obtain `MemberType._tensorHandleCount`. + auto *memberCountMRE = new (C) MemberRefExpr( + memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(), + /*Implicit*/ true); + + // Cast the tensor handle count to Int. + auto intInitName = DeclName(C, DeclBaseName::createConstructor(), + {Identifier()}); + auto *intInitExpr = + new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName, + DeclNameLoc(), /*Implicit*/ true); + auto *intInitCallExpr = CallExpr::createImplicit( + C, intInitExpr, {memberCountMRE}, {Identifier()}); + + // Assign the new address. + auto *assignAddrCallExpr = CallExpr::createImplicit( + C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")}); + auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(), + assignAddrCallExpr, + /*Implicit*/ true); + + thenMemberExprs.push_back(assignAddrExpr); + } + + auto *thenBody = BraceStmt::create( + C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(), + /*implicit*/ true); + + auto *elseBody = BraceStmt::create( + C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(), + /*implicit*/ true); + + auto *ifStmt = new (C) + IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(), + /*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody, + /*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true); + + funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(), + /*implicit*/ true)); +} + +// Synthesize a constructor declaration for a `TensorArrayProtocol` +// method requirement. +static ValueDecl *deriveTensorArrayProtocol_constructor( + DerivedConformance &derived, Identifier argument1Name, + Identifier parameter1Name, Type parameter1Type, + Identifier parameter2Name, Type parameter2Type, Type returnType, + AbstractFunctionDecl::BodySynthesizer bodySynthesizer) { + auto nominal = derived.Nominal; + auto &C = derived.TC.Context; + auto parentDC = derived.getConformanceContext(); + + auto *param1 = + new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + argument1Name, SourceLoc(), parameter1Name, parentDC); + param1->setInterfaceType(parameter1Type); + auto *param2 = + new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + parameter2Name, SourceLoc(), parameter2Name, parentDC); + param2->setInterfaceType(parameter2Type); + ParameterList *params = ParameterList::create(C, {param1, param2}); + + DeclName name(C, DeclBaseName::createConstructor(), params); + auto *initDecl = + new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(), + /*Throws*/ false, SourceLoc(), params, + /*GenericParams*/ nullptr, parentDC); + initDecl->setImplicit(); + initDecl->setSynthesized(); + initDecl->setBodySynthesizer(bodySynthesizer); + + if (auto env = parentDC->getGenericEnvironmentOfContext()) + initDecl->setGenericEnvironment(env); + initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false)); + initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); + initDecl->setValidationToChecked(); + + derived.addMembersToConformanceContext({initDecl}); + C.addSynthesizedDecl(initDecl); + + return initDecl; +} + +// Synthesize the `init(_owning:count:)` function declaration. +static ValueDecl +*deriveTensorArrayProtocol_init(DerivedConformance &derived) { + auto &C = derived.TC.Context; + + // Obtain the address type. + auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); + Type baseAddressType = BoundGenericType::get( + C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); + Type addressType = BoundGenericType::get( + C.getOptionalDecl(), Type(), {baseAddressType}); + Type intType = C.getIntDecl()->getDeclaredType(); + Type voidType = C.getVoidDecl()->getDeclaredInterfaceType(); + + return deriveTensorArrayProtocol_constructor( + derived, C.getIdentifier("_owning"), C.getIdentifier("tensorHandles"), + addressType, C.getIdentifier("count"), intType, voidType, + deriveBodyTensorArrayProtocol_init); +} + ValueDecl *DerivedConformance::deriveTensorArrayProtocol( ValueDecl *requirement) { if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles) return deriveTensorArrayProtocol_unpackTensorHandles(*this); if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount) return deriveTensorArrayProtocol_tensorHandleCount(*this); + if (requirement->getBaseName() == DeclBaseName::createConstructor()) + return deriveTensorArrayProtocol_init(*this); TC.diagnose(requirement->getLoc(), diag::broken_tensor_array_protocol_requirement); return nullptr; diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 9f16124c59127..8fdecb20a0397 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -340,6 +340,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, if (argumentNames[0] == ctx.getIdentifier("_owning")) { return getRequirement(KnownProtocolKind::TensorGroup); } + } else if (argumentNames.size() == 2) { + // SWIFT_ENABLE_TENSORFLOW + // TensorArrayProtocol.init(_owning:count) + if (argumentNames[0] == ctx.getIdentifier("_owning") && + argumentNames[0] == ctx.getIdentifier("count")) { + return getRequirement(KnownProtocolKind::TensorArrayProtocol); + } } return nullptr; diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index e3bb11fff4ceb..a40d900f24f45 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -56,7 +56,9 @@ set(SOURCES # Copy TensorFlow bindings file, if it exists. if (TENSORFLOW_SWIFT_BINDINGS) - list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS}") + file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES + "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") + list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}") endif() # Copy TensorFlow high-level API sources, if they exist. diff --git a/stdlib/public/TensorFlow/TensorGroup.swift b/stdlib/public/TensorFlow/TensorGroup.swift index af2b4f0dbe457..0982a3d765f56 100644 --- a/stdlib/public/TensorFlow/TensorGroup.swift +++ b/stdlib/public/TensorFlow/TensorGroup.swift @@ -21,6 +21,13 @@ import CTensorFlow /// This protocol is defined separately from `TensorGroup` in order for the /// number of tensors to be determined at runtime. For example, /// `[Tensor]` may have an unknown number of elements at compile time. +/// +/// This protocol can be derived automatically for structs whose stored +/// properties all conform to the `TensorGroup` protocol. It cannot be derived +/// automatically for structs whose properties all conform to +/// `TensorArrayProtocol` due to the constructor requirement (i.e., in such +/// cases it would be impossible to know how to break down `count` among the +/// stored properties). public protocol TensorArrayProtocol { /// Writes the tensor handles to `address`, which must be allocated /// with enough capacity to hold `_tensorHandleCount` handles. The tensor @@ -29,6 +36,8 @@ public protocol TensorArrayProtocol { func _unpackTensorHandles(into address: UnsafeMutablePointer?) var _tensorHandleCount: Int32 { get } + + init(_owning tensorHandles: UnsafePointer?, count: Int) } /// A protocol representing types that can be mapped to and from @@ -67,6 +76,11 @@ public extension TensorGroup { static var _unknownShapeList: [TensorShape?] { return Array(repeating: nil, count: _typeList.count) } + + init(_owning tensorHandles: UnsafePointer?, count: Int) { + precondition(count == Self._typeList.count) + self.init(_owning: tensorHandles) + } } //===----------------------------------------------------------------------===// @@ -199,7 +213,7 @@ extension StringTensor : TensorGroup { } } -extension Array : TensorArrayProtocol where Element : TensorArrayProtocol { +extension Array : TensorArrayProtocol where Element : TensorGroup { public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { var ptr = address for elem in self { @@ -213,4 +227,11 @@ extension Array : TensorArrayProtocol where Element : TensorArrayProtocol { for elem in self { count += elem._tensorHandleCount } return count } + + public init(_owning tensorHandles: UnsafePointer?, count: Int) { + let size = count / Int(Element._tensorHandleCount) + self = Array((0.. } @@ -32,7 +32,7 @@ struct Nested : TensorArrayProtocol { var mixed: Mixed } -struct Generic : TensorArrayProtocol { +struct Generic : TensorArrayProtocol { var t: T var u: U } diff --git a/test/TensorFlowRuntime/tracer.swift b/test/TensorFlowRuntime/tracer.swift index f3cfffe2b44b1..3aa5ba4aae8e3 100644 --- a/test/TensorFlowRuntime/tracer.swift +++ b/test/TensorFlowRuntime/tracer.swift @@ -168,6 +168,17 @@ TracerTests.testAllBackends("Advanced") { var model: Model = [Tensor(1.0), Tensor(2.0)] var optimizer: Optimizer = [Tensor(1.0), Tensor(2.0)] + public init() {} + + public init(_owning tensorHandles: UnsafePointer?, count: Int) { + self.model = [ + Tensor(_owning: tensorHandles), + Tensor(_owning: tensorHandles?.advanced(by: 1))] + self.optimizer = [ + Tensor(_owning: tensorHandles?.advanced(by: 2)), + Tensor(_owning: tensorHandles?.advanced(by: 3))] + } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { print("Calling State._unpackTensorHandles().") var ptr = address @@ -175,6 +186,7 @@ TracerTests.testAllBackends("Advanced") { ptr = ptr!.advanced(by: Int(model._tensorHandleCount)) optimizer._unpackTensorHandles(into: ptr) } + public var _tensorHandleCount: Int32 { return model._tensorHandleCount + optimizer._tensorHandleCount } diff --git a/utils/build-script-impl b/utils/build-script-impl index 4bf140f028519..414f8cb611d77 100755 --- a/utils/build-script-impl +++ b/utils/build-script-impl @@ -280,7 +280,7 @@ KNOWN_SETTINGS=( tensorflow-host-include-dir "" "Path to host TensorFlow headers" tensorflow-target-include-dir "" "Path to target Tensorflow headers" tensorflow-target-lib-dir "" "Path to target TensorFlow libraries" - tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings file" + tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings repository" tensorflow-swift-apis "" "Path to TensorFlow deep learning library repository" ) @@ -2476,7 +2476,7 @@ for host in "${ALL_HOSTS[@]}"; do # Handle TensorFlow Swift bindings file. if [[ ! "${TENSORFLOW_SWIFT_BINDINGS}" && -d "${TENSORFLOW_SWIFT_BINDINGS_DIR}" ]] ; then - TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}/RawOpsGenerated.swift" + TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}" fi if [[ "${TENSORFLOW_SWIFT_BINDINGS}" ]] ; then cmake_options=( diff --git a/utils/build_swift/driver_arguments.py b/utils/build_swift/driver_arguments.py index 3c927ef0a2b22..be01ca8e29de4 100644 --- a/utils/build_swift/driver_arguments.py +++ b/utils/build_swift/driver_arguments.py @@ -974,8 +974,7 @@ def create_argument_parser(): 'Used for linking Swift programs.') option('--tensorflow-swift-bindings', store_path, default=None, - help='Path to a TensorFlow Swift bindings file ' - '(RawOpsGenerated.swift).') + help='Path to a TensorFlow Swift bindings repository.') option('--tensorflow-swift-apis', store_path, default=None, help='Path to a TensorFlow deep learning library repository.') From 6765b945fcb06c698a4a72d89096d57f25bea034 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Tue, 23 Apr 2019 21:14:45 -0400 Subject: [PATCH 02/11] Bug fix. --- lib/Sema/DerivedConformanceTensorArrayProtocol.cpp | 5 ++--- lib/Sema/DerivedConformances.cpp | 2 +- .../dynamic_compilation_tensor_group.swift | 2 +- test/TensorFlowRuntime/tensor_array_protocol.swift | 10 +++++----- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index f9998b896a3f6..f92bb4cc320c0 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -411,9 +411,8 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { auto *tensorArrayProto = C.getProtocol( KnownProtocolKind::TensorArrayProtocol); auto initName = DeclName( - C, DeclBaseName::createConstructor(), - {C.getIdentifier("_owning"), C.getIdentifier("count")}); - auto *initReq = getProtocolRequirement(tensorArrayProto, initName); + C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")}); + auto *initReq = getProtocolRequirement(tensorGroupProto, initName); auto *tensorHandleCountReq = getProtocolRequirement( tensorArrayProto, C.Id_tensorHandleCount); diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 8fdecb20a0397..78d821285aa0c 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -344,7 +344,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // SWIFT_ENABLE_TENSORFLOW // TensorArrayProtocol.init(_owning:count) if (argumentNames[0] == ctx.getIdentifier("_owning") && - argumentNames[0] == ctx.getIdentifier("count")) { + argumentNames[1] == ctx.getIdentifier("count")) { return getRequirement(KnownProtocolKind::TensorArrayProtocol); } } diff --git a/test/TensorFlowRuntime/dynamic_compilation_tensor_group.swift b/test/TensorFlowRuntime/dynamic_compilation_tensor_group.swift index e16ff52b94479..59939eccdd04c 100644 --- a/test/TensorFlowRuntime/dynamic_compilation_tensor_group.swift +++ b/test/TensorFlowRuntime/dynamic_compilation_tensor_group.swift @@ -142,7 +142,7 @@ func some_tf_op(n : Int) { } let actual: Tensor = #tfop("Pack", - [arr], T$dtype: Float.tensorFlowDataType, axis: Int64(0)) + arr, T$dtype: Float.tensorFlowDataType, axis: Int64(0)) let expected = ShapedArray(shape: [n, 3], scalars: arr_exp) expectEqual(expected, actual.array) } diff --git a/test/TensorFlowRuntime/tensor_array_protocol.swift b/test/TensorFlowRuntime/tensor_array_protocol.swift index 9669c27b05051..64200dcabc524 100644 --- a/test/TensorFlowRuntime/tensor_array_protocol.swift +++ b/test/TensorFlowRuntime/tensor_array_protocol.swift @@ -17,7 +17,7 @@ struct Simple : TensorGroup { var w, b: Tensor } -struct Mixed : TensorArrayProtocol { +struct Mixed : TensorGroup { // Mutable. var string: StringTensor var float: Tensor @@ -25,14 +25,14 @@ struct Mixed : TensorArrayProtocol { let int: Tensor } -struct Nested : TensorArrayProtocol { +struct Nested : TensorGroup { // Immutable. let simple: Simple // Mutable. var mixed: Mixed } -struct Generic : TensorArrayProtocol { +struct Generic : TensorGroup { var t: T var u: U } @@ -157,7 +157,7 @@ TensorArrayProtocolTests.test("GenericUnpackTensorHandles") { TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") { struct NestedGeneric { func function() { - struct UltraNested : TensorArrayProtocol { + struct UltraNested : TensorArrayProtocol { var a: Generic var b: Generic } @@ -181,7 +181,7 @@ TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") { TensorArrayProtocolTests.test("NestedGenericUnpackTensorHandles") { struct NestedGeneric { func function() { - struct UltraNested : TensorArrayProtocol { + struct UltraNested : TensorArrayProtocol { var a: Generic var b: Generic } From 82431d4838070bc8ce58f9d1d2bf8a6c1c593ef6 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Wed, 24 Apr 2019 13:33:52 -0400 Subject: [PATCH 03/11] Addressed Dan's comments. --- .../DerivedConformanceTensorArrayProtocol.cpp | 62 +++++++------------ 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index f92bb4cc320c0..304efd5795c47 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -419,8 +419,7 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { Type intType = C.getIntDecl()->getDeclaredType(); TypeExpr *intTE = TypeExpr::createImplicit(intType, C); - // Goes through the member TensorGroups and call - // `self.t = T(_owning:count:)`. + // Iterate over members and call `self.t = T(_owning:)`. llvm::SmallVector thenMemberExprs; llvm::SmallVector elseMemberExprs; for (auto member : nominal->getStoredProperties()) { @@ -532,25 +531,32 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { /*implicit*/ true)); } -// Synthesize a constructor declaration for a `TensorArrayProtocol` -// method requirement. -static ValueDecl *deriveTensorArrayProtocol_constructor( - DerivedConformance &derived, Identifier argument1Name, - Identifier parameter1Name, Type parameter1Type, - Identifier parameter2Name, Type parameter2Type, Type returnType, - AbstractFunctionDecl::BodySynthesizer bodySynthesizer) { +// Synthesize the `init(_owning:count:)` function declaration. +static ValueDecl +*deriveTensorArrayProtocol_init(DerivedConformance &derived) { + auto &C = derived.TC.Context; + + // Obtain the address type. + auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); + Type baseAddressType = BoundGenericType::get( + C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); + Type addressType = BoundGenericType::get( + C.getOptionalDecl(), Type(), {baseAddressType}); + Type intType = C.getIntDecl()->getDeclaredType(); + auto nominal = derived.Nominal; auto &C = derived.TC.Context; auto parentDC = derived.getConformanceContext(); - auto *param1 = - new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), - argument1Name, SourceLoc(), parameter1Name, parentDC); - param1->setInterfaceType(parameter1Type); - auto *param2 = - new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), - parameter2Name, SourceLoc(), parameter2Name, parentDC); - param2->setInterfaceType(parameter2Type); + auto *param1 = new (C) ParamDecl( + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"), + parentDC); + param1->setInterfaceType(addressType); + auto *param2 = new (C) ParamDecl( + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC); + param2->setInterfaceType(intType); ParameterList *params = ParameterList::create(C, {param1, param2}); DeclName name(C, DeclBaseName::createConstructor(), params); @@ -560,7 +566,7 @@ static ValueDecl *deriveTensorArrayProtocol_constructor( /*GenericParams*/ nullptr, parentDC); initDecl->setImplicit(); initDecl->setSynthesized(); - initDecl->setBodySynthesizer(bodySynthesizer); + initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init); if (auto env = parentDC->getGenericEnvironmentOfContext()) initDecl->setGenericEnvironment(env); @@ -574,26 +580,6 @@ static ValueDecl *deriveTensorArrayProtocol_constructor( return initDecl; } -// Synthesize the `init(_owning:count:)` function declaration. -static ValueDecl -*deriveTensorArrayProtocol_init(DerivedConformance &derived) { - auto &C = derived.TC.Context; - - // Obtain the address type. - auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); - Type baseAddressType = BoundGenericType::get( - C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); - Type addressType = BoundGenericType::get( - C.getOptionalDecl(), Type(), {baseAddressType}); - Type intType = C.getIntDecl()->getDeclaredType(); - Type voidType = C.getVoidDecl()->getDeclaredInterfaceType(); - - return deriveTensorArrayProtocol_constructor( - derived, C.getIdentifier("_owning"), C.getIdentifier("tensorHandles"), - addressType, C.getIdentifier("count"), intType, voidType, - deriveBodyTensorArrayProtocol_init); -} - ValueDecl *DerivedConformance::deriveTensorArrayProtocol( ValueDecl *requirement) { if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles) From 2e4487c4b46ad91fc87b94835445a7905d291433 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Wed, 24 Apr 2019 13:40:04 -0400 Subject: [PATCH 04/11] Minor fix. --- lib/Sema/DerivedConformanceTensorArrayProtocol.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index 304efd5795c47..ef7a79d188739 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -535,6 +535,8 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { static ValueDecl *deriveTensorArrayProtocol_init(DerivedConformance &derived) { auto &C = derived.TC.Context; + auto nominal = derived.Nominal; + auto parentDC = derived.getConformanceContext(); // Obtain the address type. auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); @@ -544,10 +546,6 @@ static ValueDecl C.getOptionalDecl(), Type(), {baseAddressType}); Type intType = C.getIntDecl()->getDeclaredType(); - auto nominal = derived.Nominal; - auto &C = derived.TC.Context; - auto parentDC = derived.getConformanceContext(); - auto *param1 = new (C) ParamDecl( VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"), From 0647ce2f99b52a242b16e0d0e90da09b9849532a Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Wed, 24 Apr 2019 13:58:46 -0400 Subject: [PATCH 05/11] Minor edit. --- .../DerivedConformanceTensorArrayProtocol.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index ef7a79d188739..751ea6640d0d1 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -515,12 +515,12 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { } auto *thenBody = BraceStmt::create( - C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(), - /*implicit*/ true); + C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(), + /*implicit*/ true); auto *elseBody = BraceStmt::create( - C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(), - /*implicit*/ true); + C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(), + /*implicit*/ true); auto *ifStmt = new (C) IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(), @@ -547,13 +547,13 @@ static ValueDecl Type intType = C.getIntDecl()->getDeclaredType(); auto *param1 = new (C) ParamDecl( - VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), - C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"), - parentDC); + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"), + parentDC); param1->setInterfaceType(addressType); auto *param2 = new (C) ParamDecl( - VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), - C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC); + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC); param2->setInterfaceType(intType); ParameterList *params = ParameterList::create(C, {param1, param2}); From bde610a62f5062c8180c1aadd49d896e02336e5f Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Wed, 24 Apr 2019 18:53:19 -0400 Subject: [PATCH 06/11] Added a '_typeList' property to 'TensorArrayProtocol'. --- .../DerivedConformanceTensorArrayProtocol.cpp | 85 +++++++++++++++++++ lib/Sema/DerivedConformances.cpp | 5 ++ stdlib/public/TensorFlow/TensorGroup.swift | 16 +++- test/TensorFlowRuntime/tracer.swift | 4 + 4 files changed, 106 insertions(+), 4 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index 751ea6640d0d1..7447a4aa59f1e 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -363,6 +363,89 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount( return tensorHandleCountDecl; } + +/// Derive the body for the '_typeList' getter. +static void +deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup); + auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList); + + // Concatenate all member `_typeList` arrays. + Type arrayType = BoundGenericType::get( + C.getArrayDecl(), Type(), + {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); + auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C); + auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+")); + assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator."); + ValueDecl *plusOpDecl = plusOpLookup.front(); + auto plusOpDRE = new (C) + DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true); + auto plusOpExpr = new (C) + DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr); + Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc()); + for (auto member : nominal->getStoredProperties()) { + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto *memberTypeListExpr = new (C) + MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq, + DeclNameLoc(), /*Implicit*/ true); + // Create expression `lhsArg + rhsArg`. + auto *plusOpArgs = + TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr}, + {}, {}, SourceLoc(), /*HasTrailingClosure*/ false, + /*Implicit*/ true); + typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs, + /*Implicit*/ true); + } + + // Return the resulting data types array. + auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr); + auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(), + /*Implicit*/ true); + funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(), + /*Implicit*/ true)); +} + +/// Derive a '_typeList' implementation. +static ValueDecl *deriveTensorArrayProtocol_typeList( + DerivedConformance &derived) { + auto nominal = derived.Nominal; + auto &TC = derived.TC; + ASTContext &C = TC.Context; + + auto parentDC = derived.getConformanceContext(); + Type dataTypeArrayType = BoundGenericType::get( + C.getArrayDecl(), Type(), + {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); + auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType); + + // Create `_typeList` property declaration. + VarDecl *typeListDecl; + PatternBindingDecl *patDecl; + std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty( + C.Id_typeList, returnType, returnType, /*isStatic*/ false, + /*isFinal*/ false); + + // Add `@inlinable` to the `_typeList` declaration. + if (nominal->getEffectiveAccess() > AccessLevel::Internal) + typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true)); + + // Create `_typeList` getter. + auto *getterDecl = derived.declareDerivedPropertyGetter( + TC, typeListDecl, returnType); + getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList); + typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(), + SourceLoc(), {getterDecl}, SourceLoc()); + derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl}); + + return typeListDecl; +} + // Synthesize body for `init(_owning:count:)`. static void deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { @@ -584,6 +667,8 @@ ValueDecl *DerivedConformance::deriveTensorArrayProtocol( return deriveTensorArrayProtocol_unpackTensorHandles(*this); if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount) return deriveTensorArrayProtocol_tensorHandleCount(*this); + if (requirement->getBaseName() == TC.Context.Id_typeList) + return deriveTensorArrayProtocol_typeList(*this); if (requirement->getBaseName() == DeclBaseName::createConstructor()) return deriveTensorArrayProtocol_init(*this); TC.diagnose(requirement->getLoc(), diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 78d821285aa0c..9aac23e0e4298 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // TensorArrayProtocol._tensorHandleCount if (name.isSimpleName(ctx.Id_tensorHandleCount)) return getRequirement(KnownProtocolKind::TensorArrayProtocol); + + // SWIFT_ENABLE_TENSORFLOW + // TensorArrayProtocol._typeList + if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic()) + return getRequirement(KnownProtocolKind::TensorArrayProtocol); // SWIFT_ENABLE_TENSORFLOW // TensorGroup._typeList diff --git a/stdlib/public/TensorFlow/TensorGroup.swift b/stdlib/public/TensorFlow/TensorGroup.swift index 0982a3d765f56..02577eb73d8c4 100644 --- a/stdlib/public/TensorFlow/TensorGroup.swift +++ b/stdlib/public/TensorFlow/TensorGroup.swift @@ -36,6 +36,7 @@ public protocol TensorArrayProtocol { func _unpackTensorHandles(into address: UnsafeMutablePointer?) var _tensorHandleCount: Int32 { get } + var _typeList: [TensorDataType] { get } init(_owning tensorHandles: UnsafePointer?, count: Int) } @@ -69,13 +70,16 @@ public protocol TensorGroup : TensorArrayProtocol { public extension TensorGroup { /// The number of tensor fields in this type. static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } - var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } /// An array of `nil`s with the same number of elements as `_outputTypeList`. /// The `nil` represents unknown shape. static var _unknownShapeList: [TensorShape?] { return Array(repeating: nil, count: _typeList.count) } + + // The following instance properties are from `TensorArrayProtocol`. + var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } + var _typeList: [TensorDataType] { return Self._typeList } init(_owning tensorHandles: UnsafePointer?, count: Int) { precondition(count == Self._typeList.count) @@ -223,9 +227,13 @@ extension Array : TensorArrayProtocol where Element : TensorGroup { } public var _tensorHandleCount: Int32 { - var count: Int32 = 0 - for elem in self { count += elem._tensorHandleCount } - return count + return Element._tensorHandleCount * Int32(count) + } + + public var _typeList: [TensorDataType] { + return Array([[TensorDataType]]( + repeating: Element._typeList, + count: Int(Element._tensorHandleCount)).joined()) } public init(_owning tensorHandles: UnsafePointer?, count: Int) { diff --git a/test/TensorFlowRuntime/tracer.swift b/test/TensorFlowRuntime/tracer.swift index 3aa5ba4aae8e3..d5b6c1b30a3e6 100644 --- a/test/TensorFlowRuntime/tracer.swift +++ b/test/TensorFlowRuntime/tracer.swift @@ -191,6 +191,10 @@ TracerTests.testAllBackends("Advanced") { return model._tensorHandleCount + optimizer._tensorHandleCount } + public var _typeList: [TensorDataType] { + return model._typeList + optimizer._typeList + } + func _makeInstance(owning inputs: C) -> State where C.Element == CTensorHandle { assert(inputs.count == 4) From e0681db652968aa2bd04ef5b28f23931a08ee41b Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Thu, 25 Apr 2019 21:08:28 -0400 Subject: [PATCH 07/11] Minor formatting edit. --- lib/Sema/DerivedConformanceTensorArrayProtocol.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index 7447a4aa59f1e..f86257e85bfe1 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -420,8 +420,8 @@ static ValueDecl *deriveTensorArrayProtocol_typeList( auto parentDC = derived.getConformanceContext(); Type dataTypeArrayType = BoundGenericType::get( - C.getArrayDecl(), Type(), - {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); + C.getArrayDecl(), Type(), + {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType); // Create `_typeList` property declaration. From 87e3f96e639b8dfba0a40791dad9f1109275ce5f Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 26 Apr 2019 10:29:55 -0400 Subject: [PATCH 08/11] Reverted a change that's part of a different PR. --- stdlib/public/TensorFlow/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index a40d900f24f45..e3bb11fff4ceb 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -56,9 +56,7 @@ set(SOURCES # Copy TensorFlow bindings file, if it exists. if (TENSORFLOW_SWIFT_BINDINGS) - file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES - "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") - list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}") + list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS}") endif() # Copy TensorFlow high-level API sources, if they exist. From 13a9fd687d0fc92cfb3306c0ed27fdf68a88009b Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 26 Apr 2019 10:32:37 -0400 Subject: [PATCH 09/11] Revert "Reverted a change that's part of a different PR." This reverts commit 87e3f96e639b8dfba0a40791dad9f1109275ce5f. --- stdlib/public/TensorFlow/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index e3bb11fff4ceb..a40d900f24f45 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -56,7 +56,9 @@ set(SOURCES # Copy TensorFlow bindings file, if it exists. if (TENSORFLOW_SWIFT_BINDINGS) - list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS}") + file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES + "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") + list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}") endif() # Copy TensorFlow high-level API sources, if they exist. From fb7eff398f48f7554d7c0297015e733c8dde5bef Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 26 Apr 2019 10:34:30 -0400 Subject: [PATCH 10/11] Formatting change. --- stdlib/public/TensorFlow/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index a40d900f24f45..a85b487cf1075 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -57,7 +57,7 @@ set(SOURCES # Copy TensorFlow bindings file, if it exists. if (TENSORFLOW_SWIFT_BINDINGS) file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES - "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") + "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}") endif() From c6bd0ee9569ca641d239a0f120781a59622f2a47 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Fri, 26 Apr 2019 17:54:05 -0400 Subject: [PATCH 11/11] Reverted the 'swift-bindings' changes. --- stdlib/public/TensorFlow/CMakeLists.txt | 4 +--- utils/build-script-impl | 4 ++-- utils/build_swift/driver_arguments.py | 3 ++- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index a85b487cf1075..e3bb11fff4ceb 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -56,9 +56,7 @@ set(SOURCES # Copy TensorFlow bindings file, if it exists. if (TENSORFLOW_SWIFT_BINDINGS) - file(GLOB_RECURSE TENSORFLOW_SWIFT_BINDINGS_SOURCES - "${TENSORFLOW_SWIFT_BINDINGS}/*.swift") - list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS_SOURCES}") + list(APPEND SOURCES "${TENSORFLOW_SWIFT_BINDINGS}") endif() # Copy TensorFlow high-level API sources, if they exist. diff --git a/utils/build-script-impl b/utils/build-script-impl index 414f8cb611d77..4bf140f028519 100755 --- a/utils/build-script-impl +++ b/utils/build-script-impl @@ -280,7 +280,7 @@ KNOWN_SETTINGS=( tensorflow-host-include-dir "" "Path to host TensorFlow headers" tensorflow-target-include-dir "" "Path to target Tensorflow headers" tensorflow-target-lib-dir "" "Path to target TensorFlow libraries" - tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings repository" + tensorflow-swift-bindings "" "Path to TensorFlow Swift bindings file" tensorflow-swift-apis "" "Path to TensorFlow deep learning library repository" ) @@ -2476,7 +2476,7 @@ for host in "${ALL_HOSTS[@]}"; do # Handle TensorFlow Swift bindings file. if [[ ! "${TENSORFLOW_SWIFT_BINDINGS}" && -d "${TENSORFLOW_SWIFT_BINDINGS_DIR}" ]] ; then - TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}" + TENSORFLOW_SWIFT_BINDINGS="${TENSORFLOW_SWIFT_BINDINGS_DIR}/RawOpsGenerated.swift" fi if [[ "${TENSORFLOW_SWIFT_BINDINGS}" ]] ; then cmake_options=( diff --git a/utils/build_swift/driver_arguments.py b/utils/build_swift/driver_arguments.py index be01ca8e29de4..3c927ef0a2b22 100644 --- a/utils/build_swift/driver_arguments.py +++ b/utils/build_swift/driver_arguments.py @@ -974,7 +974,8 @@ def create_argument_parser(): 'Used for linking Swift programs.') option('--tensorflow-swift-bindings', store_path, default=None, - help='Path to a TensorFlow Swift bindings repository.') + help='Path to a TensorFlow Swift bindings file ' + '(RawOpsGenerated.swift).') option('--tensorflow-swift-apis', store_path, default=None, help='Path to a TensorFlow deep learning library repository.')