100 changes: 84 additions & 16 deletions stdlib/public/TensorFlow/CompilerRuntime.swift
@@ -100,7 +100,19 @@ private class TraceContext {
/// (TF_Function) upon finalizing.
let graph = TF_NewGraph()

/// The list of inputs to the trace graph function.
/// The list of inputs to the trace graph function. It starts with the inputs
/// to the function that we trace (referred to as the "tracee function" or
/// "tracee"), followed by possible additional inputs that correspond to
/// concrete tensors produced within the trace function.
///
/// For example, if the tracee is:
/// func foo(x: TensorPair) -> Tensor {
/// let y = Tensor<Float>(1.0)
/// return x.first + x.second + y
/// }
///
/// Then the generated trace graph function has 3 input tensors: x.first,
/// x.second, and y.
///
/// These symbolic tensors correspond to Placeholder nodes in the trace graph,
/// and will be filled in when we execute the trace graph function.
@@ -116,6 +128,11 @@ private class TraceContext {
/// The trace graph function created by `finalize()`.
var traceGraphFn: CTFFunction?

/// The number of additional input tensors to the trace graph function,
/// created from concrete intermediate tensors in the tracee, such as `y` in
/// the code snippet above.
var additionalInputTensorCount: Int32 = -1
Contributor:
Is there a reason why the initial value is -1?

Author:
This increases the chance of catching bugs (since -1 is never valid), if we forget to set it to some legal value.

Contributor:
I see where you are coming from, however, there's a safer way to do that in Swift -- since you expect additionalInputTensorCount to always be set during class initialization, removing the default value makes you able to catch initialization bugs at compile-time. When you forget to set it in init, Swift gives you a compile-time error saying "member is not initialized", which is significantly better than catching this at runtime.
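
Below is a minimal Swift sketch of the reviewer's suggestion (the class name is a hypothetical stand-in, not the PR's TraceContext): without a default value, Swift's definite-initialization checking forces every initializer to assign the property, so forgetting the assignment becomes a compile-time error rather than a -1 sentinel caught at runtime.

private final class TraceContextSketch {
  // No default value: Swift requires every initializer to assign this
  // property before returning.
  var additionalInputTensorCount: Int32

  init(inputValueCount: Int) {
    // Deleting the next line is a compile-time error ("... not initialized"),
    // instead of a -1 sentinel that can only be caught at runtime.
    additionalInputTensorCount = Int32(inputValueCount)
  }
}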


/// `inputValueCount` is the length of the (flattened) list of input tensors
/// to the trace function.
init(inputValueCount: Int) {
@@ -150,18 +167,41 @@ private class TraceContext {
outputs: [CTensorHandle]) {
internalConsistencyCheck(traceGraphFn == nil)
var symbolicOutputs: [TF_Output] = []
for (i, output) in outputs.enumerated() {
// Only add symbolic output tensors as the outputs of the trace graph function.
// For example, let the tracee be:
// func foo(x: Tensor) -> (Tensor, Tensor) {
// let y = Tensor<Float>(1.0)
// return (x + x, y)
Contributor:
The test added here does not exercise this scenario?

Author:
Discussed in person, and TracerTests.testAllBackends("Basic_IntermediateTensors") should cover this case.
// }
//
// Here foo() returns 2 tensors, but only the first one (as computed by x +
// x) is symbolic. The second one for y is concrete, and is computed at
// trace creation time, not trace execution time.
// Also see the comment block above finalizeAndExecuteTraceFn().
for (i, output) in outputs.enumerated()
where TFE_TensorHandleIsConcrete(output) == 0 {
debugLog("Adding symbolic output \(i) as a trace graph func output.")
symbolicOutputs.append(TFE_GetTFOutputFromTensorHandle(output, status))
checkOk(status)
}

let traceeInputCount = symbolicInputs.count
// Append concrete tensors created within the tracee as symbolic inputs to
// the generated trace graph function.
additionalInputTensorCount = TFE_FinalizeInputTensorsFromTraceContext(
cTraceContext)
for i in 0..<additionalInputTensorCount {
symbolicInputs.append(TFE_GetInputGraphNodeFromTraceContext(
cTraceContext, UInt32(i)))
Contributor:
Suggested change (indentation of the continuation line):
cTraceContext, UInt32(i)))

Author:
cTraceContext is indented two spaces after the previous line on TFE_GetInputGraphNodeFromTraceContext( -- I believe that's the correct indentation?

Contributor:
Oh, I misread.
}

let tracedFunctionName =
"\(traceeBasicName)_\(_RuntimeConfig.traceGraphFunctionCounter)"
_RuntimeConfig.traceGraphFunctionCounter += 1
debugLog("""
Finalizing trace graph func \(tracedFunctionName), with \
\(symbolicInputs.count) tracee inputs, and \
\(traceeInputCount) tracee inputs and \
\(additionalInputTensorCount) additional inputs, and up to \
\(outputs.count) return values.
""")
traceGraphFn =
@@ -184,14 +224,16 @@ private class TraceContext {
free(funcDebugStr)
}

// TODO: Consider garbage-collecting these trace graph functions if we end
// up with many of them.
let eagerContext = _TFCGetGlobalEagerContext()
TFE_ContextAddFunction(eagerContext, traceGraphFn, status)
checkOk(status)
}

/// Execute the trace graph function, and return the list of output tensors
/// from the trace execution. These output tensors are owned by the caller.
func execute(inputs: [Tensor<Float>],
func execute(traceeInputs: [Tensor<Float>],
outputs: [CTensorHandle]) -> [CTensorHandle] {
// We must be in the `notTracing` enum mode.
internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
Expand All @@ -210,13 +252,28 @@ private class TraceContext {
checkOk(status)
}

debugLog("Adding \(inputs.count) tracee input tensors.")
internalConsistencyCheck(symbolicInputs.count == inputs.count)
for input in inputs {
debugLog("Adding \(traceeInputs.count) tracee input tensors.")
internalConsistencyCheck(symbolicInputs.count == traceeInputs.count
+ Int(additionalInputTensorCount))
for input in traceeInputs {
_TFCOpAddInputFromTensorHandle(op, input.handle, status)
checkOk(status)
}

debugLog("Adding \(additionalInputTensorCount) additional input tensors.")
for i in 0..<additionalInputTensorCount {
let input = TFE_ConsumeInputConcreteTensorFromTraceContext(cTraceContext,
UInt32(i))
internalConsistencyCheck(input != nil)
debugLog("""
Adding additional input tensor of idx \
\(traceeInputs.count + Int(i)):\
\(input!).
""")
TFE_OpAddInput(op, input, status)
checkOk(status)
}

// Tell TensorFlow to execute the graph function we built, containing
// the trace.
let maxReturnValueCount = outputs.count
Expand All @@ -240,14 +297,23 @@ private class TraceContext {
var traceGraphOutputs: [CTensorHandle] = []
// Points to an element in `returnValues`.
var returnValueIdx = 0
// We manually increment `returnValueIdx` below instead of using
// `outputs.enumerated()`, because the logic will be extended in a future PR
// that requires manual counting.
// See the comment block within finalize() below on why we handle concrete
// and symbolic output tensors differently.
for output in outputs {
internalConsistencyCheck(TFE_TensorHandleIsConcrete(output) == 0)
internalConsistencyCheck(returnValues[returnValueIdx] != nil)
traceGraphOutputs.append(returnValues[returnValueIdx]!)
returnValueIdx += 1
if TFE_TensorHandleIsConcrete(output) != 0 {
// These concrete tensors are owned by some other objects, so we make a
// copy here.
let newOutput = TFE_TensorHandleCopySharingTensor(output, status)
checkOk(status)
internalConsistencyCheck(newOutput != nil)
traceGraphOutputs.append(newOutput!)
} else {
// These symbolic tensors are produced by TFE_Execute() above, and we
// need not make an extra copy.
internalConsistencyCheck(returnValues[returnValueIdx] != nil)
traceGraphOutputs.append(returnValues[returnValueIdx]!)
returnValueIdx += 1
}
}
internalConsistencyCheck(returnValueIdx == outputReturnValueCount)
return traceGraphOutputs
@@ -694,7 +760,9 @@ public func _graph<State : _TensorArrayProtocolEnhanced,
_copying: inputSymbolicTensors.dropFirst(
Int(state._tensorHandleCount)))
// Run tracee to build the trace, adding ops to the trace graph function.
debugLog("Running tracee in tracing mode.")
// The tracee output can contain a mixture of symbolic and concrete tensors
// (see the comment block within TraceContext.finalize()).
debugLog("Running tracee in tracing mode.")
Contributor:
Suggested change (indentation only):
debugLog("Running tracee in tracing mode.")

Author:
Good catch, thanks!

let (outputState, outputResult) = fn(symbolicState, symbolicData)

debugLog("Assembling output tensor handles.")
Expand Down Expand Up @@ -725,7 +793,7 @@ public func _graph<State : _TensorArrayProtocolEnhanced,
})

debugLog("Executing trace graph function.")
let returnValues = traceContext.execute(inputs: inputTensors,
let returnValues = traceContext.execute(traceeInputs: inputTensors,
outputs: outputTensorHandles)

debugLog("Creating output model instance.")
85 changes: 85 additions & 0 deletions test/TensorFlowRuntime/tracer.swift
@@ -60,4 +60,89 @@ TracerTests.testAllBackends("Basic") {
expectNearlyEqualWithScalarTensor(1.0, result2)
}

TracerTests.testAllBackends("Basic_IntermediateTensors") {
func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
// Create an intermediate tensor value, which the tracing infra needs to
// convert into a placeholder input into the generated trace graph function.
let tmp = Tensor<Float>(1.0)
return (tmp, tmp + data)
}

let state = Tensor<Float>(2.0)
let data = Tensor<Float>(3.0)
let tracedFn = _graph(with: state, in: tracee)
let (newState, result) = tracedFn(state, data)

_hostOp(newState)
expectNearlyEqualWithScalarTensor(1.0, newState)

_hostOp(result)
expectNearlyEqualWithScalarTensor(4.0, result)
}

TracerTests.testAllBackends("Advanced") {
typealias Model = [Tensor<Float>]

typealias Optimizer = [Tensor<Float>]

struct State : _TensorArrayProtocolEnhanced {
var model: Model = [Tensor<Float>(1.0), Tensor<Float>(2.0)]
var optimizer: Optimizer = [Tensor<Float>(1.0), Tensor<Float>(2.0)]

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
print("Calling State._unpackTensorHandles().")
var ptr = address
model._unpackTensorHandles(into: ptr)
ptr = ptr!.advanced(by: Int(model._tensorHandleCount))
optimizer._unpackTensorHandles(into: ptr)
}
public var _tensorHandleCount: Int32 {
return model._tensorHandleCount + optimizer._tensorHandleCount
}

func _makeInstance<C: Collection>(owning inputs: C) -> State
where C.Element == CTensorHandle {
assert(inputs.count == 4)
var abstractState = State()
let index0 = inputs.startIndex
let index1 = inputs.index(after: index0)
abstractState.model = [Tensor(handle: TensorHandle<Float>(_owning: inputs[index0])),
Tensor(handle: TensorHandle<Float>(_owning: inputs[index1]))]
let index2 = inputs.index(after: index1)
let index3 = inputs.index(after: index2)
abstractState.optimizer = [Tensor(handle: TensorHandle<Float>(_owning: inputs[index2])),
Tensor(handle: TensorHandle<Float>(_owning: inputs[index3]))]
return abstractState
}
}

func tracee(state: State, data: Data) -> (State, Result) {
print("Running tracee()")
var tmp = Tensor<Float>(0.0)
for i in 0..<state.model.count {
tmp += state.model[i] * state.optimizer[i]
}

print("Creating return value()")
var newState = state
newState.model[0] = state.model[0] + state.model[1]
let ret = (newState, tmp + data)
return ret
}

let state = State()
let data = Tensor<Float>(3.0)
let tracedFn = _graph(with: state, in: tracee)
let (newState, result) = tracedFn(state, data)

_hostOp(newState) // should be State(model: [3.0, 2.0], optimizer: [1.0, 2.0])
expectNearlyEqualWithScalarTensor(3.0, newState.model[0])
expectNearlyEqualWithScalarTensor(2.0, newState.model[1])
expectNearlyEqualWithScalarTensor(1.0, newState.optimizer[0])
expectNearlyEqualWithScalarTensor(2.0, newState.optimizer[1])

_hostOp(result) // should be 8.0
expectNearlyEqualWithScalarTensor(8.0, result)
}
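
Below is a hedged sketch, not part of this PR's diff, of a test exercising the output ordering from the TraceContext.finalize() comment and the reviewer's question above: the first tuple element is symbolic and the second is a concrete intermediate tensor, the reverse of Basic_IntermediateTensors. It assumes Data and Result are typealiases for Tensor<Float>, as the surrounding tests suggest, and the test name is illustrative.

TracerTests.testAllBackends("Basic_SymbolicThenConcrete_Sketch") {
  func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
    // The first output (state + data) is symbolic; the second (y) is a
    // concrete tensor created at trace time and returned as-is.
    let y = Tensor<Float>(1.0)
    return (state + data, y)
  }

  let state = Tensor<Float>(2.0)
  let data = Tensor<Float>(3.0)
  let tracedFn = _graph(with: state, in: tracee)
  let (newState, result) = tracedFn(state, data)

  _hostOp(newState)
  expectNearlyEqualWithScalarTensor(5.0, newState)  // state + data = 2.0 + 3.0

  _hostOp(result)
  expectNearlyEqualWithScalarTensor(1.0, result)    // the concrete tensor y
}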

runAllTests()