-
Notifications
You must be signed in to change notification settings - Fork 10.6k
Taught tracer to support concrete, intermediate tensors created within tracee #22100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -100,7 +100,19 @@ private class TraceContext { | |||||
| /// (TF_Function) upon finalizing. | ||||||
| let graph = TF_NewGraph() | ||||||
|
|
||||||
| /// The list of inputs to the trace graph function. | ||||||
| /// The list of inputs to the trace graph function. It starts with the inputs | ||||||
| /// to the function that we trace (referred to as the "tracee function" or | ||||||
| /// "tracee"), followed by possible additional inputs that correspond to | ||||||
| /// concrete tensors produced within the trace function. | ||||||
| /// | ||||||
| /// For example, if the tracee is: | ||||||
| /// func foo(x: TensorPair) -> Tensor { | ||||||
| /// let y = Tensor<Float>(1.0) | ||||||
| /// return x.first + x.second + y | ||||||
| /// } | ||||||
| /// | ||||||
| /// Then the generated trace graph function has 3 input tensors: x.first, | ||||||
| /// x.second, and y. | ||||||
| /// | ||||||
| /// These symbolic tensors corresond to PlaceHolder nodes in the trace graph, | ||||||
| /// and will be filled in when we execute the trace graph function. | ||||||
|
|
@@ -116,6 +128,11 @@ private class TraceContext { | |||||
| /// The trace graph function created by `finalize()`. | ||||||
| var traceGraphFn: CTFFunction? | ||||||
|
|
||||||
| /// The number of additional input tensors to the trace graph function, | ||||||
| /// created from concrete intermediate tensors in the tracee, such as `y` in | ||||||
| /// the code snippet above. | ||||||
| var additionalInputTensorCount: Int32 = -1 | ||||||
|
|
||||||
| /// `inputValueCount` is the length of the (flattened) list of input tensors | ||||||
| /// to the trace function. | ||||||
| init(inputValueCount: Int) { | ||||||
|
|
@@ -150,18 +167,41 @@ private class TraceContext { | |||||
| outputs: [CTensorHandle]) { | ||||||
| internalConsistencyCheck(traceGraphFn == nil) | ||||||
| var symbolicOutputs: [TF_Output] = [] | ||||||
| for (i, output) in outputs.enumerated() { | ||||||
| // Only add symbolic output tensors as the outputs of the trace graph function. | ||||||
| // For example, let the tracee be: | ||||||
| // func foo(x: Tensor) -> (Tensor, Tensor) { | ||||||
| // let y = Tensor<Float>(1.0) | ||||||
| // return (x + x, y) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test added here does not exercise this scenario?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed in person, and |
||||||
| // } | ||||||
| // | ||||||
| // Here foo() returns 2 tensors, but only the first one (as computed by x + | ||||||
| // x) is symbolic. The second one for y is concrete, and is computed at | ||||||
| // trace creation time, not trace execution time. | ||||||
| // Also see the comment block above finalizeAndExecuteTraceFn(). | ||||||
| for (i, output) in outputs.enumerated() | ||||||
| where TFE_TensorHandleIsConcrete(output) == 0 { | ||||||
| debugLog("Adding symbolic output \(i) as a trace graph func output.") | ||||||
| symbolicOutputs.append(TFE_GetTFOutputFromTensorHandle(output ,status)) | ||||||
| checkOk(status) | ||||||
| } | ||||||
|
|
||||||
| let traceeInputCount = symbolicInputs.count | ||||||
| // Append concrete tensors created within the tracee as symbolic inputs to | ||||||
| // the generated trace graph function. | ||||||
| additionalInputTensorCount = TFE_FinalizeInputTensorsFromTraceContext( | ||||||
| cTraceContext) | ||||||
| for i in 0..<additionalInputTensorCount { | ||||||
| symbolicInputs.append(TFE_GetInputGraphNodeFromTraceContext( | ||||||
| cTraceContext, UInt32(i))) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I misread. |
||||||
| } | ||||||
|
|
||||||
| let tracedFunctionName = | ||||||
| "\(traceeBasicName)_\(_RuntimeConfig.traceGraphFunctionCounter)" | ||||||
| _RuntimeConfig.traceGraphFunctionCounter += 1 | ||||||
| debugLog(""" | ||||||
| Finalizing trace graph func \(tracedFunctionName), with \ | ||||||
| \(symbolicInputs.count) tracee inputs, and \ | ||||||
| \(traceeInputCount) tracee inputs and \ | ||||||
| \(additionalInputTensorCount) additional inputs, and up to \ | ||||||
| \(outputs.count) return values. | ||||||
| """) | ||||||
| traceGraphFn = | ||||||
|
|
@@ -184,14 +224,16 @@ private class TraceContext { | |||||
| free(funcDebugStr) | ||||||
| } | ||||||
|
|
||||||
| // TODO: Consider garbage-collecting these trace graph functions if we end | ||||||
| // up with many of them. | ||||||
| let eagerContext = _TFCGetGlobalEagerContext() | ||||||
| TFE_ContextAddFunction(eagerContext, traceGraphFn, status) | ||||||
| checkOk(status) | ||||||
| } | ||||||
|
|
||||||
| /// Execute the trace graph function, and return the list of output tensors | ||||||
| /// from the trace execution. These output tensors are owned by the caller. | ||||||
| func execute(inputs: [Tensor<Float>], | ||||||
| func execute(traceeInputs: [Tensor<Float>], | ||||||
| outputs: [CTensorHandle]) -> [CTensorHandle] { | ||||||
| // We must be in the `notTracing` enum mode. | ||||||
| internalConsistencyCheck(_RuntimeConfig.traceState.context == nil) | ||||||
|
|
@@ -210,13 +252,28 @@ private class TraceContext { | |||||
| checkOk(status) | ||||||
| } | ||||||
|
|
||||||
| debugLog("Adding \(inputs.count) tracee input tensors.") | ||||||
| internalConsistencyCheck(symbolicInputs.count == inputs.count) | ||||||
| for input in inputs { | ||||||
| debugLog("Adding \(traceeInputs.count) tracee input tensors.") | ||||||
| internalConsistencyCheck(symbolicInputs.count == traceeInputs.count | ||||||
| + Int(additionalInputTensorCount)) | ||||||
| for input in traceeInputs { | ||||||
| _TFCOpAddInputFromTensorHandle(op, input.handle, status) | ||||||
| checkOk(status) | ||||||
| } | ||||||
|
|
||||||
| debugLog("Adding \(additionalInputTensorCount) additional input tensors.") | ||||||
| for i in 0..<additionalInputTensorCount { | ||||||
| let input = TFE_ConsumeInputConcreteTensorFromTraceContext(cTraceContext, | ||||||
| UInt32(i)) | ||||||
| internalConsistencyCheck(input != nil) | ||||||
| debugLog(""" | ||||||
| Adding additional input tensor of idx \ | ||||||
| \(traceeInputs.count+Int(additionalInputTensorCount)):\ | ||||||
| \(input!). | ||||||
| """) | ||||||
| TFE_OpAddInput(op, input, status) | ||||||
| checkOk(status) | ||||||
| } | ||||||
|
|
||||||
| // Tell TensorFlow to execute the graph function we built, containing | ||||||
| // the trace. | ||||||
| let maxReturnValueCount = outputs.count | ||||||
|
|
@@ -240,14 +297,23 @@ private class TraceContext { | |||||
| var traceGraphOutputs: [CTensorHandle] = [] | ||||||
| // Points to an element in `returnValues`. | ||||||
| var returnValueIdx = 0 | ||||||
| // We manually increment `returnValueIdx` below instead of using | ||||||
| // `outputs.enumerated()`, because the logic will be extended in a future PR | ||||||
| // that requires manual counting. | ||||||
| // See the comment block within finalize() below on why we handle concrete | ||||||
| // and symbolic output tensors differently. | ||||||
| for output in outputs { | ||||||
| internalConsistencyCheck(TFE_TensorHandleIsConcrete(output) == 0) | ||||||
| internalConsistencyCheck(returnValues[returnValueIdx] != nil) | ||||||
| traceGraphOutputs.append(returnValues[returnValueIdx]!) | ||||||
| returnValueIdx += 1 | ||||||
| if TFE_TensorHandleIsConcrete(output) != 0 { | ||||||
| // These concrete tensors are owned by some other objects, so we make a | ||||||
| // copy here. | ||||||
| let newOutput = TFE_TensorHandleCopySharingTensor(output, status) | ||||||
| checkOk(status) | ||||||
| internalConsistencyCheck(newOutput != nil) | ||||||
| traceGraphOutputs.append(newOutput!) | ||||||
| } else { | ||||||
| // These symbolic tensors are produced by TFE_Execute() above, and we | ||||||
| // need not make an extra copy. | ||||||
| internalConsistencyCheck(returnValues[returnValueIdx] != nil) | ||||||
| traceGraphOutputs.append(returnValues[returnValueIdx]!) | ||||||
| returnValueIdx += 1 | ||||||
| } | ||||||
| } | ||||||
| internalConsistencyCheck(returnValueIdx == outputReturnValueCount) | ||||||
| return traceGraphOutputs | ||||||
|
|
@@ -694,7 +760,9 @@ public func _graph<State : _TensorArrayProtocolEnhanced, | |||||
| _copying: inputSymbolicTensors.dropFirst( | ||||||
| Int(state._tensorHandleCount))) | ||||||
| // Run tracee to build the trace, adding ops to the trace graph function. | ||||||
| debugLog("Running tracee in tracing mode.") | ||||||
| // The tracee output can contain a mixture of symbolic and concrete tensors | ||||||
| // (see the comment block within TraceContext.finalize()). | ||||||
| debugLog("Running tracee in tracing mode.") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, thanks! |
||||||
| let (outputState, outputResult) = fn(symbolicState, symbolicData) | ||||||
|
|
||||||
| debugLog("Assembling output tensor handles.") | ||||||
|
|
@@ -725,7 +793,7 @@ public func _graph<State : _TensorArrayProtocolEnhanced, | |||||
| }) | ||||||
|
|
||||||
| debugLog("Executing trace graph function.") | ||||||
| let returnValues = traceContext.execute(inputs: inputTensors, | ||||||
| let returnValues = traceContext.execute(traceeInputs: inputTensors, | ||||||
| outputs: outputTensorHandles) | ||||||
|
|
||||||
| debugLog("Creating output model instance.") | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason why the initial value is
-1?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This increases the chance of catching bugs (since
-1is never valid), if we forget to set it to some legal value.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see where you are coming from, however, there's a safer way to do that in Swift -- since you expect
additionalInputTensorCountto always be set during class initialization, removing the default value makes you able to catch initialization bugs at compile-time. When you forget to set it ininit, Swift gives you a compile-time error saying "member is not initialized", which is significantly better than catching this at runtime.