diff --git a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp index 990d65c5532a1..f86257e85bfe1 100644 --- a/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp +++ b/lib/Sema/DerivedConformanceTensorArrayProtocol.cpp @@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal, auto *structDecl = dyn_cast(nominal); if (!structDecl) return false; - // All stored properties must conform to `TensorArrayProtocol`. + // All stored properties must conform to `TensorGroup`. auto &C = nominal->getASTContext(); - auto *tensorArrayProto = - C.getProtocol(KnownProtocolKind::TensorArrayProtocol); + auto *tensorGroupProto = + C.getProtocol(KnownProtocolKind::TensorGroup); return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) { if (!v->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(v); if (!v->hasInterfaceType()) return false; auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); - return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC, + return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC, ConformanceCheckFlags::Used); }); } @@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { return lookup.front(); } +// Return the protocol requirement with the specified name. +static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) { + auto lookup = proto->lookupDirect(name); + lookup.erase(std::remove_if(lookup.begin(), lookup.end(), + [](ValueDecl *v) { + return !isa( + v->getDeclContext()) || + !v->isProtocolRequirement(); + }), + lookup.end()); + assert(lookup.size() == 1 && "Ambiguous protocol requirement"); + return lookup.front(); +} + // Synthesize body for `_unpackTensorHandles(into:)`. static void deriveBodyTensorArrayProtocol_unpackTensorHandles( @@ -349,12 +363,314 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount( return tensorHandleCountDecl; } + +/// Derive the body for the '_typeList' getter. +static void +deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup); + auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList); + + // Concatenate all member `_typeList` arrays. + Type arrayType = BoundGenericType::get( + C.getArrayDecl(), Type(), + {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); + auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C); + auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+")); + assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator."); + ValueDecl *plusOpDecl = plusOpLookup.front(); + auto plusOpDRE = new (C) + DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true); + auto plusOpExpr = new (C) + DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr); + Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc()); + for (auto member : nominal->getStoredProperties()) { + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto *memberTypeListExpr = new (C) + MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq, + DeclNameLoc(), /*Implicit*/ true); + // Create expression `lhsArg + rhsArg`. + auto *plusOpArgs = + TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr}, + {}, {}, SourceLoc(), /*HasTrailingClosure*/ false, + /*Implicit*/ true); + typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs, + /*Implicit*/ true); + } + + // Return the resulting data types array. + auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr); + auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(), + /*Implicit*/ true); + funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(), + /*Implicit*/ true)); +} + +/// Derive a '_typeList' implementation. +static ValueDecl *deriveTensorArrayProtocol_typeList( + DerivedConformance &derived) { + auto nominal = derived.Nominal; + auto &TC = derived.TC; + ASTContext &C = TC.Context; + + auto parentDC = derived.getConformanceContext(); + Type dataTypeArrayType = BoundGenericType::get( + C.getArrayDecl(), Type(), + {C.getTensorDataTypeDecl()->getDeclaredInterfaceType()}); + auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType); + + // Create `_typeList` property declaration. + VarDecl *typeListDecl; + PatternBindingDecl *patDecl; + std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty( + C.Id_typeList, returnType, returnType, /*isStatic*/ false, + /*isFinal*/ false); + + // Add `@inlinable` to the `_typeList` declaration. + if (nominal->getEffectiveAccess() > AccessLevel::Internal) + typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true)); + + // Create `_typeList` getter. + auto *getterDecl = derived.declareDerivedPropertyGetter( + TC, typeListDecl, returnType); + getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList); + typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(), + SourceLoc(), {getterDecl}, SourceLoc()); + derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl}); + + return typeListDecl; +} + +// Synthesize body for `init(_owning:count:)`. +static void +deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) { + auto *parentDC = funcDecl->getParent(); + auto *nominal = parentDC->getSelfNominalTypeDecl(); + auto &C = nominal->getASTContext(); + + // Obtain the address type. + auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); + auto baseAddressType = BoundGenericType::get( + C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); + auto addressType = BoundGenericType::get( + C.getOptionalDecl(), Type(), {baseAddressType}); + auto *addressTE = TypeExpr::createImplicit(addressType, C); + + // Get references to `self` and parameter declarations. + auto *selfDecl = funcDecl->getImplicitSelfDecl(); + auto *selfDRE = new (C) + DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); + auto *paramDecl = funcDecl->getParameters()->get(0); + auto *paramDRE = new (C) + DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); + + // Create an `if var` statement for the current address. + VarDecl *currAddressDecl = new (C) VarDecl( + /*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false, + SourceLoc(), C.getIdentifier("currentAddress"), funcDecl); + currAddressDecl->setImplicit(); + currAddressDecl->setHasNonPatternBindingInit(true); + currAddressDecl->setInterfaceType(baseAddressType); + currAddressDecl->setValidationToChecked(); + + Pattern *currAddressPat = new (C) + NamedPattern(currAddressDecl, /*implicit*/ true); + currAddressPat = new (C) + VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat, + /*implicit*/ true); + currAddressPat = new (C) + OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(), + /*implicit*/ true); + StmtConditionElement cond[] = { + StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)}; + + // Get the necessary protocol requirements. + auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup); + auto *tensorArrayProto = C.getProtocol( + KnownProtocolKind::TensorArrayProtocol); + auto initName = DeclName( + C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")}); + auto *initReq = getProtocolRequirement(tensorGroupProto, initName); + auto *tensorHandleCountReq = getProtocolRequirement( + tensorArrayProto, C.Id_tensorHandleCount); + + Type intType = C.getIntDecl()->getDeclaredType(); + TypeExpr *intTE = TypeExpr::createImplicit(intType, C); + + // Iterate over members and call `self.t = T(_owning:)`. + llvm::SmallVector thenMemberExprs; + llvm::SmallVector elseMemberExprs; + for (auto member : nominal->getStoredProperties()) { + auto memberType = parentDC->mapTypeIntoContext( + member->getValueInterfaceType()); + auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); + auto module = nominal->getModuleContext(); + auto confRef = module->lookupConformance( + memberType, tensorGroupProto); + assert(confRef && "Member does not conform to `TensorGroup`"); + + // Get member type's constructor, e.g. `MemberType.init(_owning:)`. + // Use protocol requirement declaration for the method by default: this + // will be dynamically dispatched. + ValueDecl *memberInitDecl = initReq; + // If conformance reference is concrete, then use concrete witness + // declaration for the constructor. + if (confRef->isConcrete()) + memberInitDecl = confRef->getConcrete()->getWitnessDecl( + initReq, C.getLazyResolver()); + assert(memberInitDecl && "Member constructor declaration must exist"); + auto memberInitDRE = new (C) DeclRefExpr( + memberInitDecl, DeclNameLoc(), /*implicit*/ true); + memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply); + + // Create reference to member constructor: `MemberType.init(_owning:)`. + auto *memberInitExpr = new (C) ConstructorRefCallExpr( + memberInitDRE, memberTypeExpr); + + auto *addressDRE = new (C) DeclRefExpr( + currAddressDecl, DeclNameLoc(), /*implicit*/ true); + auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType); + + // Initialize the member using its TensorGroup constructor. + // Note that, initialization is dependent on the branch of the + // if-statement taken. + auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType); + auto *thenInitCallExpr = CallExpr::createImplicit( + C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")}); + + // Create a nil expression with type UnsafePointer? for the + // `else` branch. + auto *nilDecl = C.getOptionalNoneDecl(); + auto *nilDRE = new (C) DeclRefExpr( + nilDecl, DeclNameLoc(), /*implicit*/ true); + auto *elseInitExpr = new (C) DotSyntaxCallExpr( + nilDRE, SourceLoc(), addressTE); + auto *elseInitCallExpr = CallExpr::createImplicit( + C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")}); + + // Assign the current member to the result of the initializer call. + auto *memberDRE = new (C) MemberRefExpr( + selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); + + auto *thenAssignMemberExpr = new (C) AssignExpr( + memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true); + auto *elseAssignMemberExpr = new (C) AssignExpr( + memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true); + + thenMemberExprs.push_back(thenAssignMemberExpr); + elseMemberExprs.push_back(elseAssignMemberExpr); + + // Advance the current address. + DeclName advancedName(C, C.getIdentifier("advanced"), + {C.getIdentifier("by")}); + auto *advancedMethodExpr = + new (C) UnresolvedDotExpr(addressDRE, SourceLoc(), + advancedName, DeclNameLoc(), + /*Implicit*/ true); + + // Obtain `MemberType._tensorHandleCount`. + auto *memberCountMRE = new (C) MemberRefExpr( + memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(), + /*Implicit*/ true); + + // Cast the tensor handle count to Int. + auto intInitName = DeclName(C, DeclBaseName::createConstructor(), + {Identifier()}); + auto *intInitExpr = + new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName, + DeclNameLoc(), /*Implicit*/ true); + auto *intInitCallExpr = CallExpr::createImplicit( + C, intInitExpr, {memberCountMRE}, {Identifier()}); + + // Assign the new address. + auto *assignAddrCallExpr = CallExpr::createImplicit( + C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")}); + auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(), + assignAddrCallExpr, + /*Implicit*/ true); + + thenMemberExprs.push_back(assignAddrExpr); + } + + auto *thenBody = BraceStmt::create( + C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(), + /*implicit*/ true); + + auto *elseBody = BraceStmt::create( + C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(), + /*implicit*/ true); + + auto *ifStmt = new (C) + IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(), + /*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody, + /*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true); + + funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(), + /*implicit*/ true)); +} + +// Synthesize the `init(_owning:count:)` function declaration. +static ValueDecl +*deriveTensorArrayProtocol_init(DerivedConformance &derived) { + auto &C = derived.TC.Context; + auto nominal = derived.Nominal; + auto parentDC = derived.getConformanceContext(); + + // Obtain the address type. + auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType(); + Type baseAddressType = BoundGenericType::get( + C.getUnsafePointerDecl(), Type(), {cTensorHandleType}); + Type addressType = BoundGenericType::get( + C.getOptionalDecl(), Type(), {baseAddressType}); + Type intType = C.getIntDecl()->getDeclaredType(); + + auto *param1 = new (C) ParamDecl( + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"), + parentDC); + param1->setInterfaceType(addressType); + auto *param2 = new (C) ParamDecl( + VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), + C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC); + param2->setInterfaceType(intType); + ParameterList *params = ParameterList::create(C, {param1, param2}); + + DeclName name(C, DeclBaseName::createConstructor(), params); + auto *initDecl = + new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(), + /*Throws*/ false, SourceLoc(), params, + /*GenericParams*/ nullptr, parentDC); + initDecl->setImplicit(); + initDecl->setSynthesized(); + initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init); + + if (auto env = parentDC->getGenericEnvironmentOfContext()) + initDecl->setGenericEnvironment(env); + initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false)); + initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); + initDecl->setValidationToChecked(); + + derived.addMembersToConformanceContext({initDecl}); + C.addSynthesizedDecl(initDecl); + + return initDecl; +} + ValueDecl *DerivedConformance::deriveTensorArrayProtocol( ValueDecl *requirement) { if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles) return deriveTensorArrayProtocol_unpackTensorHandles(*this); if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount) return deriveTensorArrayProtocol_tensorHandleCount(*this); + if (requirement->getBaseName() == TC.Context.Id_typeList) + return deriveTensorArrayProtocol_typeList(*this); + if (requirement->getBaseName() == DeclBaseName::createConstructor()) + return deriveTensorArrayProtocol_init(*this); TC.diagnose(requirement->getLoc(), diag::broken_tensor_array_protocol_requirement); return nullptr; diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 9f16124c59127..9aac23e0e4298 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // TensorArrayProtocol._tensorHandleCount if (name.isSimpleName(ctx.Id_tensorHandleCount)) return getRequirement(KnownProtocolKind::TensorArrayProtocol); + + // SWIFT_ENABLE_TENSORFLOW + // TensorArrayProtocol._typeList + if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic()) + return getRequirement(KnownProtocolKind::TensorArrayProtocol); // SWIFT_ENABLE_TENSORFLOW // TensorGroup._typeList @@ -340,6 +345,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, if (argumentNames[0] == ctx.getIdentifier("_owning")) { return getRequirement(KnownProtocolKind::TensorGroup); } + } else if (argumentNames.size() == 2) { + // SWIFT_ENABLE_TENSORFLOW + // TensorArrayProtocol.init(_owning:count) + if (argumentNames[0] == ctx.getIdentifier("_owning") && + argumentNames[1] == ctx.getIdentifier("count")) { + return getRequirement(KnownProtocolKind::TensorArrayProtocol); + } } return nullptr; diff --git a/stdlib/public/TensorFlow/TensorGroup.swift b/stdlib/public/TensorFlow/TensorGroup.swift index af2b4f0dbe457..02577eb73d8c4 100644 --- a/stdlib/public/TensorFlow/TensorGroup.swift +++ b/stdlib/public/TensorFlow/TensorGroup.swift @@ -21,6 +21,13 @@ import CTensorFlow /// This protocol is defined separately from `TensorGroup` in order for the /// number of tensors to be determined at runtime. For example, /// `[Tensor]` may have an unknown number of elements at compile time. +/// +/// This protocol can be derived automatically for structs whose stored +/// properties all conform to the `TensorGroup` protocol. It cannot be derived +/// automatically for structs whose properties all conform to +/// `TensorArrayProtocol` due to the constructor requirement (i.e., in such +/// cases it would be impossible to know how to break down `count` among the +/// stored properties). public protocol TensorArrayProtocol { /// Writes the tensor handles to `address`, which must be allocated /// with enough capacity to hold `_tensorHandleCount` handles. The tensor @@ -29,6 +36,9 @@ public protocol TensorArrayProtocol { func _unpackTensorHandles(into address: UnsafeMutablePointer?) var _tensorHandleCount: Int32 { get } + var _typeList: [TensorDataType] { get } + + init(_owning tensorHandles: UnsafePointer?, count: Int) } /// A protocol representing types that can be mapped to and from @@ -60,13 +70,21 @@ public protocol TensorGroup : TensorArrayProtocol { public extension TensorGroup { /// The number of tensor fields in this type. static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } - var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } /// An array of `nil`s with the same number of elements as `_outputTypeList`. /// The `nil` represents unknown shape. static var _unknownShapeList: [TensorShape?] { return Array(repeating: nil, count: _typeList.count) } + + // The following instance properties are from `TensorArrayProtocol`. + var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) } + var _typeList: [TensorDataType] { return Self._typeList } + + init(_owning tensorHandles: UnsafePointer?, count: Int) { + precondition(count == Self._typeList.count) + self.init(_owning: tensorHandles) + } } //===----------------------------------------------------------------------===// @@ -199,7 +217,7 @@ extension StringTensor : TensorGroup { } } -extension Array : TensorArrayProtocol where Element : TensorArrayProtocol { +extension Array : TensorArrayProtocol where Element : TensorGroup { public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { var ptr = address for elem in self { @@ -209,8 +227,19 @@ extension Array : TensorArrayProtocol where Element : TensorArrayProtocol { } public var _tensorHandleCount: Int32 { - var count: Int32 = 0 - for elem in self { count += elem._tensorHandleCount } - return count + return Element._tensorHandleCount * Int32(count) + } + + public var _typeList: [TensorDataType] { + return Array([[TensorDataType]]( + repeating: Element._typeList, + count: Int(Element._tensorHandleCount)).joined()) + } + + public init(_owning tensorHandles: UnsafePointer?, count: Int) { + let size = count / Int(Element._tensorHandleCount) + self = Array((0.. = #tfop("Pack", - [arr], T$dtype: Float.tensorFlowDataType, axis: Int64(0)) + arr, T$dtype: Float.tensorFlowDataType, axis: Int64(0)) let expected = ShapedArray(shape: [n, 3], scalars: arr_exp) expectEqual(expected, actual.array) } diff --git a/test/TensorFlowRuntime/tensor_array_protocol.swift b/test/TensorFlowRuntime/tensor_array_protocol.swift index 15cf6014d0a56..64200dcabc524 100644 --- a/test/TensorFlowRuntime/tensor_array_protocol.swift +++ b/test/TensorFlowRuntime/tensor_array_protocol.swift @@ -11,13 +11,13 @@ import StdlibUnittest var TensorArrayProtocolTests = TestSuite("TensorArrayProtocol") -struct Empty : TensorArrayProtocol {} +struct Empty : TensorGroup {} -struct Simple : TensorArrayProtocol { +struct Simple : TensorGroup { var w, b: Tensor } -struct Mixed : TensorArrayProtocol { +struct Mixed : TensorGroup { // Mutable. var string: StringTensor var float: Tensor @@ -25,14 +25,14 @@ struct Mixed : TensorArrayProtocol { let int: Tensor } -struct Nested : TensorArrayProtocol { +struct Nested : TensorGroup { // Immutable. let simple: Simple // Mutable. var mixed: Mixed } -struct Generic : TensorArrayProtocol { +struct Generic : TensorGroup { var t: T var u: U } @@ -157,7 +157,7 @@ TensorArrayProtocolTests.test("GenericUnpackTensorHandles") { TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") { struct NestedGeneric { func function() { - struct UltraNested : TensorArrayProtocol { + struct UltraNested : TensorArrayProtocol { var a: Generic var b: Generic } @@ -181,7 +181,7 @@ TensorArrayProtocolTests.test("NestedGenericTensorHandleCount") { TensorArrayProtocolTests.test("NestedGenericUnpackTensorHandles") { struct NestedGeneric { func function() { - struct UltraNested : TensorArrayProtocol { + struct UltraNested : TensorArrayProtocol { var a: Generic var b: Generic } diff --git a/test/TensorFlowRuntime/tracer.swift b/test/TensorFlowRuntime/tracer.swift index f3cfffe2b44b1..d5b6c1b30a3e6 100644 --- a/test/TensorFlowRuntime/tracer.swift +++ b/test/TensorFlowRuntime/tracer.swift @@ -168,6 +168,17 @@ TracerTests.testAllBackends("Advanced") { var model: Model = [Tensor(1.0), Tensor(2.0)] var optimizer: Optimizer = [Tensor(1.0), Tensor(2.0)] + public init() {} + + public init(_owning tensorHandles: UnsafePointer?, count: Int) { + self.model = [ + Tensor(_owning: tensorHandles), + Tensor(_owning: tensorHandles?.advanced(by: 1))] + self.optimizer = [ + Tensor(_owning: tensorHandles?.advanced(by: 2)), + Tensor(_owning: tensorHandles?.advanced(by: 3))] + } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { print("Calling State._unpackTensorHandles().") var ptr = address @@ -175,10 +186,15 @@ TracerTests.testAllBackends("Advanced") { ptr = ptr!.advanced(by: Int(model._tensorHandleCount)) optimizer._unpackTensorHandles(into: ptr) } + public var _tensorHandleCount: Int32 { return model._tensorHandleCount + optimizer._tensorHandleCount } + public var _typeList: [TensorDataType] { + return model._typeList + optimizer._typeList + } + func _makeInstance(owning inputs: C) -> State where C.Element == CTensorHandle { assert(inputs.count == 4)