diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift
index 9278f6534ec8f..a02c6ccf32eb1 100644
--- a/stdlib/public/TensorFlow/CompilerRuntime.swift
+++ b/stdlib/public/TensorFlow/CompilerRuntime.swift
@@ -218,7 +218,8 @@ private class TraceContext {
 
   /// Execute the trace graph function, and return the list of output tensors
   /// from the trace execution. These output tensors are owned by the caller.
-  func execute(traceeInputs: [_AnyTensorHandle]) -> [CTensorHandle] {
+  func execute(
+    traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] {
     // We must be in the `notTracing` enum mode.
     internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
     internalConsistencyCheck(traceGraphFn != nil)
@@ -236,6 +237,11 @@ private class TraceContext {
      checkOk(status)
    }
 
+    if useXla {
+      debugLog("Enabling XLA compilation")
+      TFE_OpSetAttrBool(op, "_XlaCompile", 1)
+    }
+
    debugLog("Adding \(traceeInputs.count) tracee input tensors.")
    internalConsistencyCheck(symbolicInputs.count == traceeInputs.count
      + Int(additionalInputTensorCount))
@@ -993,6 +999,47 @@ public func _tffunc(
+public func _graph<In : TensorGroup, Out : TensorGroup>(
+  _ fn: (In) -> Out, useXla: Bool = false
+) -> (In) -> Out {
+  let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
+    let wrappedFn = {
+      (inputs: [CTensorHandle]) -> [CTensorHandle] in
+      let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(
+        capacity: Int(inputs.count))
+      var ptr = buffer
+      for input in inputs {
+        ptr.initialize(to: input)
+        ptr = ptr.advanced(by: 1)
+      }
+      let symbolicIn = In(_owning: buffer)
+      let symbolicOut = escapableFn(symbolicIn)
+      return symbolicOut.cTensorHandles
+    }
+    let dtypes = In._typeList.map { $0._cDataType }
+    return _trace(with: dtypes, in: wrappedFn)
+  }
+  // The result is a closure that captures and executes the trace graph
+  // function in the trace context.
+  return { (input: In) -> (Out) in
+    debugLog("Running trace function over input \(input).")
+
+    debugLog("Getting input state tensor handles.")
+    let inputStateTensorHandles = input.cTensorHandles
+    let inputTensors = inputStateTensorHandles.map {
+      _TFCCreateTensorHandleFromC($0)
+    }
+    debugLog("Executing trace graph function.")
+    let returnValues = traceContext.execute(
+      traceeInputs: inputTensors, useXla: useXla)
+
+    debugLog("Creating output model instance.")
+    return Out(_copying: returnValues)
+  }
+}
+
 /// Trace the given function and return the name of the corresponding
 /// `TF_Function: In -> Out` that was created.
 public func _tffunc<In : TensorGroup, Out : TensorGroup>(
diff --git a/test/TensorFlowRuntime/tracer.swift b/test/TensorFlowRuntime/tracer.swift
index a19f8eeb464c4..f3cfffe2b44b1 100644
--- a/test/TensorFlowRuntime/tracer.swift
+++ b/test/TensorFlowRuntime/tracer.swift
@@ -129,6 +129,16 @@ TracerTests.testAllBackends("TraceWithNoResult") {
   expectNearlyEqualWithScalarTensor(8.0, tracedAdd(Tensor(5.0), three))
 }
 
+TracerTests.testAllBackends("TracerWithInOut") {
+  func addOne(state: Tensor<Float>) -> (Tensor<Float>) {
+    return state + 1
+  }
+  let addOneGraph = _graph(addOne)
+  expectEqual(addOneGraph(Tensor(5)), Tensor(6))
+  expectEqual(addOneGraph(Tensor(0)), Tensor(1))
+  expectEqual(addOneGraph(Tensor(-1)), Tensor(0))
+}
+
 TracerTests.testAllBackends("Basic_IntermediateTensors") {
   func tracee(state: Tensor<Float>, data: Data) -> (Tensor<Float>, Result) {
     // Create an intermediate tensor value, which the tracing infra needs to
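
For context, here is a minimal usage sketch of the new `useXla` flag, written against the `_graph(_:useXla:)` signature added in this diff. It assumes the traced function operates on `Tensor<Float>` (the element type is not pinned down by the diff) and that the `double` helper below is hypothetical; only the `_graph` entry point and the `useXla` parameter come from the change itself.

```swift
import TensorFlow

// Hypothetical traced function; any (In) -> Out over TensorGroup types works.
func double(_ x: Tensor<Float>) -> Tensor<Float> {
  return x * 2
}

// Passing `useXla: true` makes TraceContext.execute set the `_XlaCompile`
// attribute on the trace graph op before launching it (see the first hunk).
let doubleGraph = _graph(double, useXla: true)
let result = doubleGraph(Tensor<Float>(3.0))  // expected: Tensor(6.0)
```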