diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift index a02c6ccf32eb1..adcd0be88a039 100644 --- a/stdlib/public/TensorFlow/CompilerRuntime.swift +++ b/stdlib/public/TensorFlow/CompilerRuntime.swift @@ -219,7 +219,7 @@ private class TraceContext { /// Execute the trace graph function, and return the list of output tensors /// from the trace execution. These output tensors are owned by the caller. func execute( - traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] { + traceeInputs: [_AnyTensorHandle], useXLA: Bool = false) -> [CTensorHandle] { // We must be in the `notTracing` enum mode. internalConsistencyCheck(_RuntimeConfig.traceState.context == nil) internalConsistencyCheck(traceGraphFn != nil) @@ -237,7 +237,7 @@ private class TraceContext { checkOk(status) } - if useXla { + if useXLA { debugLog("Enabling XLA compilation") TFE_OpSetAttrBool(op, "_XlaCompile", 1) } @@ -1002,11 +1002,10 @@ public func _tffunc( - _ fn: (In) -> Out, useXla: Bool = false + _ fn: (In) -> Out, useXLA: Bool = false ) -> (In) -> Out { let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in - let wrappedFn = { - (inputs: [CTensorHandle]) -> [CTensorHandle] in + let wrappedFn = { (inputs: [CTensorHandle]) -> [CTensorHandle] in let buffer = UnsafeMutablePointer.allocate( capacity: Int(inputs.count)) var ptr = buffer @@ -1033,7 +1032,7 @@ public func _graph( } debugLog("Executing trace graph function.") let returnValues = traceContext.execute( - traceeInputs: inputTensors, useXla: useXla) + traceeInputs: inputTensors, useXLA: useXLA) debugLog("Creating output model instance.") return Out(_copying: returnValues)