diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift index a35da35b9c9b8..e2870f42d0b63 100644 --- a/stdlib/public/TensorFlow/CompilerRuntime.swift +++ b/stdlib/public/TensorFlow/CompilerRuntime.swift @@ -228,6 +228,7 @@ private class TraceContext { internalConsistencyCheck(tracedFunctionName != nil) let eagerContext = _TFCGetGlobalEagerContext() let op: CTFEOp! = TFE_NewOp(eagerContext, tracedFunctionName, status) + defer { TFE_DeleteOp(op) } checkOk(status) let deviceName = _ExecutionContext.global.currentDeviceName @@ -829,6 +830,20 @@ private extension TensorGroup { TF_DeleteStatus(status) self.init(_owning: buffer) } + + init(_owning input: C) where C.Element == CTensorHandle { + assert(Self._tensorHandleCount == input.count) + let buffer = UnsafeMutablePointer.allocate( + capacity: input.count) + let status = TF_NewStatus() + // copy input to buffer + for (i, inputTensorHandle) in input.enumerated() { + let address = buffer.advanced(by: i) + address.initialize(to: inputTensorHandle) + } + TF_DeleteStatus(status) + self.init(_owning: buffer) + } } // TODO: Fold this protocol into TensorArrayProtocol. @@ -936,7 +951,7 @@ private func _graphInternal( traceeInputs: inputTensors, useXLA: useXLA) debugLog("Creating output model instance.") - return Out(_copying: returnValues) + return Out(_owning: returnValues) } }