diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift
index 07d1c1b17b0aa..26bb0fc055d50 100644
--- a/stdlib/public/TensorFlow/CompilerRuntime.swift
+++ b/stdlib/public/TensorFlow/CompilerRuntime.swift
@@ -100,7 +100,19 @@ private class TraceContext {
   /// (TF_Function) upon finalizing.
   let graph = TF_NewGraph()
 
-  /// The list of inputs to the trace graph function.
+  /// The list of inputs to the trace graph function. It starts with the inputs
+  /// to the function that we trace (referred to as the "tracee function" or
+  /// "tracee"), followed by possible additional inputs that correspond to
+  /// concrete tensors produced within the trace function.
+  ///
+  /// For example, if the tracee is:
+  ///   func foo(x: TensorPair) -> Tensor {
+  ///     let y = Tensor(1.0)
+  ///     return x.first + x.second + y
+  ///   }
+  ///
+  /// Then the generated trace graph function has 3 input tensors: x.first,
+  /// x.second, and y.
   ///
   /// These symbolic tensors corresond to PlaceHolder nodes in the trace graph,
   /// and will be filled in when we execute the trace graph function.
@@ -116,6 +128,11 @@ private class TraceContext {
   /// The trace graph function created by `finalize()`.
   var traceGraphFn: CTFFunction?
 
+  /// The number of additional input tensors to the trace graph function,
+  /// created from concrete intermediate tensors in the tracee, such as `y` in
+  /// the code snippet above.
+  var additionalInputTensorCount: Int32 = -1
+
   /// `inputValueCount` is the length of the (flattened) list of input tensors
   /// to the trace function.
   init(inputValueCount: Int) {
@@ -150,18 +167,41 @@ private class TraceContext {
                 outputs: [CTensorHandle]) {
     internalConsistencyCheck(traceGraphFn == nil)
     var symbolicOutputs: [TF_Output] = []
-    for (i, output) in outputs.enumerated() {
+    // Only add symbolic output tensors as the outputs of the trace graph function.
+    // For example, let the tracee be:
+    //   func foo(x: Tensor) -> (Tensor, Tensor) {
+    //     let y = Tensor(1.0)
+    //     return (x + x, y)
+    //   }
+    //
+    // Here foo() returns 2 tensors, but only the first one (as computed by x +
+    // x) is symbolic. The second one for y is concrete, and is computed at
+    // trace creation time, not trace execution time.
+    // Also see the comment block above finalizeAndExecuteTraceFn().
+    for (i, output) in outputs.enumerated()
+      where TFE_TensorHandleIsConcrete(output) == 0 {
      debugLog("Adding symbolic output \(i) as a trace graph func output.")
      symbolicOutputs.append(TFE_GetTFOutputFromTensorHandle(output ,status))
      checkOk(status)
    }
+    let traceeInputCount = symbolicInputs.count
+    // Append concrete tensors created within the tracee as symbolic inputs to
+    // the generated trace graph function.
+    additionalInputTensorCount = TFE_FinalizeInputTensorsFromTraceContext(
+      cTraceContext)
+    for i in 0..<
[...]
-  func execute(inputs: [Tensor<Float>],
+  func execute(traceeInputs: [Tensor<Float>],
               outputs: [CTensorHandle]) -> [CTensorHandle] {
    // We must be in the `notTracing` enum mode.
    internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
@@ -210,13 +252,28 @@ private class TraceContext {
      checkOk(status)
    }
-    debugLog("Adding \(inputs.count) tracee input tensors.")
-    internalConsistencyCheck(symbolicInputs.count == inputs.count)
-    for input in inputs {
+    debugLog("Adding \(traceeInputs.count) tracee input tensors.")
+    internalConsistencyCheck(symbolicInputs.count == traceeInputs.count
+      + Int(additionalInputTensorCount))
+    for input in traceeInputs {
      _TFCOpAddInputFromTensorHandle(op, input.handle, status)
      checkOk(status)
    }
+    debugLog("Adding \(additionalInputTensorCount) additional input tensors.")
+    for i in 0..<
[...]
+  func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
+    // Create an intermediate tensor value, which the tracing infra needs to
+    // convert into a placeholder input into the generated trace graph function.
+    let tmp = Tensor<Float>(1.0)
+    return (tmp, tmp + data)
+  }
+
+  let state = Tensor<Float>(2.0)
+  let data = Tensor<Float>(3.0)
+  let tracedFn = _graph(with: state, in: tracee)
+  let (newState, result) = tracedFn(state, data)
+
+  _hostOp(newState)
+  expectNearlyEqualWithScalarTensor(1.0, newState)
+
+  _hostOp(result)
+  expectNearlyEqualWithScalarTensor(4.0, result)
+}
+
+TracerTests.testAllBackends("Advanced") {
+  typealias Model = [Tensor<Float>]
+
+  typealias Optimizer = [Tensor<Float>]
+
+  struct State : _TensorArrayProtocolEnhanced {
+    var model: Model = [Tensor(1.0), Tensor(2.0)]
+    var optimizer: Optimizer = [Tensor(1.0), Tensor(2.0)]
+
+    public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
+      print("Calling State._unpackTensorHandles().")
+      var ptr = address
+      model._unpackTensorHandles(into: ptr)
+      ptr = ptr!.advanced(by: Int(model._tensorHandleCount))
+      optimizer._unpackTensorHandles(into: ptr)
+    }
+    public var _tensorHandleCount: Int32 {
+      return model._tensorHandleCount + optimizer._tensorHandleCount
+    }
+
+    func _makeInstance<C: Collection>(owning inputs: C) -> State
+      where C.Element == CTensorHandle {
+      assert(inputs.count == 4)
+      var abstractState = State()
+      let index0 = inputs.startIndex
+      let index1 = inputs.index(after: index0)
+      abstractState.model = [Tensor(handle: TensorHandle(_owning: inputs[index0])),
+                             Tensor(handle: TensorHandle(_owning: inputs[index1]))]
+      let index2 = inputs.index(after: index1)
+      let index3 = inputs.index(after: index2)
+      abstractState.optimizer = [Tensor(handle: TensorHandle(_owning: inputs[index2])),
+                                 Tensor(handle: TensorHandle(_owning: inputs[index3]))]
+      return abstractState
+    }
+  }
+
+  func tracee(state: State, data: Data) -> (State, Result) {
+    print("Running tracee()")
+    var tmp = Tensor<Float>(0.0)
+    for i in 0..<
[...]
+  let data = Tensor<Float>(3.0)
+  let tracedFn = _graph(with: state, in: tracee)
+  let (newState, result) = tracedFn(state, data)
+
+  _hostOp(newState) // should be State(model: [3.0, 2.0], optimizer: [1.0, 2.0])
+  expectNearlyEqualWithScalarTensor(3.0, newState.model[0])
+  expectNearlyEqualWithScalarTensor(2.0, newState.model[1])
+  expectNearlyEqualWithScalarTensor(1.0, newState.optimizer[0])
+  expectNearlyEqualWithScalarTensor(2.0, newState.optimizer[1])
+
+  _hostOp(result) // should be 8.0
+  expectNearlyEqualWithScalarTensor(8.0, result)
+}
+
+runAllTests()
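
A usage sketch follows; it is not part of the patch itself. The tracee shape and the `_graph(with:in:)` call follow the tests above, and the names and scalar values are illustrative. It shows how concrete tensors created inside a tracee surface as additional placeholder inputs of the generated trace graph function:

func tracee(state: Tensor<Float>, data: Tensor<Float>) -> (Tensor<Float>, Tensor<Float>) {
  // `bias` and `scale` are concrete tensors created inside the tracee, so the
  // trace graph function would get one extra placeholder input for each of
  // them (additionalInputTensorCount of 2), on top of the placeholders for the
  // tracee's own inputs `state` and `data`.
  let bias = Tensor<Float>(1.0)
  let scale = Tensor<Float>(2.0)
  // Both returned tensors are symbolic, so the trace graph function has 2 outputs.
  return (state * scale, state + data + bias)
}

let state = Tensor<Float>(2.0)
let data = Tensor<Float>(3.0)
let tracedFn = _graph(with: state, in: tracee)
let (newState, result) = tracedFn(state, data)
// newState == 4.0 (2 * 2), result == 6.0 (2 + 3 + 1).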