Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions stdlib/public/TensorFlow/CompilerRuntime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ private class TraceContext {
/// Execute the trace graph function, and return the list of output tensors
/// from the trace execution. These output tensors are owned by the caller.
func execute(
traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] {
traceeInputs: [_AnyTensorHandle], useXLA: Bool = false) -> [CTensorHandle] {
// We must be in the `notTracing` enum mode.
internalConsistencyCheck(_RuntimeConfig.traceState.context == nil)
internalConsistencyCheck(traceGraphFn != nil)
Expand All @@ -237,7 +237,7 @@ private class TraceContext {
checkOk(status)
}

if useXla {
if useXLA {
debugLog("Enabling XLA compilation")
TFE_OpSetAttrBool(op, "_XlaCompile", 1)
}
Expand Down Expand Up @@ -1002,11 +1002,10 @@ public func _tffunc<State : _TensorArrayProtocolEnhanced,
// Trace the given function to generate a TF graph and return a closure
// that can be used to launch the graph.
public func _graph<In : TensorGroup, Out : TensorGroup>(
_ fn: (In) -> Out, useXla: Bool = false
_ fn: (In) -> Out, useXLA: Bool = false
) -> (In) -> Out {
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
let wrappedFn = {
(inputs: [CTensorHandle]) -> [CTensorHandle] in
let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in
let buffer = UnsafeMutablePointer<CTensorHandle>.allocate(
capacity: Int(inputs.count))
var ptr = buffer
Expand All @@ -1033,7 +1032,7 @@ public func _graph<In : TensorGroup, Out : TensorGroup>(
}
debugLog("Executing trace graph function.")
let returnValues = traceContext.execute(
traceeInputs: inputTensors, useXla: useXla)
traceeInputs: inputTensors, useXLA: useXLA)

debugLog("Creating output model instance.")
return Out(_copying: returnValues)
Expand Down