From 9cc69e3e9b3c8cbb2e49bcca7ba08b2a145a7d2d Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 24 Aug 2020 10:19:54 -0700 Subject: [PATCH] fix tensorarray disposing weight tensors --- tfjs-converter/src/executor/execution_context.ts | 6 +++--- tfjs-converter/src/executor/graph_executor.ts | 15 +++++++-------- .../src/executor/graph_executor_test.ts | 2 ++ tfjs-converter/src/executor/tensor_array.ts | 8 ++++++-- tfjs-converter/src/executor/tensor_array_test.ts | 16 ++++++++++++++++ tfjs-converter/src/executor/tensor_list.ts | 8 ++++++-- tfjs-converter/src/executor/tensor_list_test.ts | 14 ++++++++++++++ 7 files changed, 54 insertions(+), 15 deletions(-) diff --git a/tfjs-converter/src/executor/execution_context.ts b/tfjs-converter/src/executor/execution_context.ts index ead625fc22d..0771a6adadb 100644 --- a/tfjs-converter/src/executor/execution_context.ts +++ b/tfjs-converter/src/executor/execution_context.ts @@ -175,13 +175,13 @@ export class ExecutionContext { return this.tensorListMap[id]; } - dispose() { + dispose(keepIds: Set) { for (const key in this.tensorArrayMap) { - this.tensorArrayMap[key].clearAndClose(); + this.tensorArrayMap[key].clearAndClose(keepIds); } for (const key in this.tensorListMap) { - this.tensorListMap[key].clearAndClose(); + this.tensorListMap[key].clearAndClose(keepIds); } } } diff --git a/tfjs-converter/src/executor/graph_executor.ts b/tfjs-converter/src/executor/graph_executor.ts index dceb868a943..6eafd455446 100644 --- a/tfjs-converter/src/executor/graph_executor.ts +++ b/tfjs-converter/src/executor/graph_executor.ts @@ -224,7 +224,7 @@ export class GraphExecutor implements FunctionExecutor { } // dispose the context for the root executor if (this.parent == null) { - context.dispose(); + context.dispose(tensorsToKeep); } return outputs.map(name => getTensor(name, tensorsMap, context)); }); @@ -333,22 +333,21 @@ export class GraphExecutor implements FunctionExecutor { const results = outputs.map(name => getTensor(name, tensorMap, context)); // dispose all the intermediate tensors - const outputIds = new Set(results.map(t => t.id)); - const inputIds = - new Set(Object.keys(inputs).map(name => inputs[name].id)); + const outputIds = results.map(t => t.id); + const inputIds = Object.keys(inputs).map(name => inputs[name].id); + const keepIds = + new Set([...outputIds, ...inputIds, ...this.weightIds]); Object.keys(tensorMap).forEach(key => { const tensorArray = tensorMap[key]; tensorArray.forEach(tensor => { - if (tensor && !tensor.isDisposed && !outputIds.has(tensor.id) && - !inputIds.has(tensor.id) && - this.weightIds.indexOf(tensor.id) === -1) { + if (tensor && !tensor.isDisposed && !keepIds.has(tensor.id)) { tensor.dispose(); } }); }); // dispose the context for the root executor if (this.parent == null) { - context.dispose(); + context.dispose(keepIds); } return results; diff --git a/tfjs-converter/src/executor/graph_executor_test.ts b/tfjs-converter/src/executor/graph_executor_test.ts index 40793b9100f..d5654a32aa8 100644 --- a/tfjs-converter/src/executor/graph_executor_test.ts +++ b/tfjs-converter/src/executor/graph_executor_test.ts @@ -469,6 +469,7 @@ describe('GraphExecutor', () => { }; executor = new GraphExecutor(graphWithControlFlow); + executor.weightMap = {}; }); it('should execute control flow v2 graph', async () => { @@ -595,6 +596,7 @@ describe('GraphExecutor', () => { }; executor = new GraphExecutor(graphWithControlFlow); + executor.weightMap = {}; }); it('should execute control flow v2 graph', async () => { diff --git a/tfjs-converter/src/executor/tensor_array.ts b/tfjs-converter/src/executor/tensor_array.ts index 4ac46db6955..b69553ffb48 100644 --- a/tfjs-converter/src/executor/tensor_array.ts +++ b/tfjs-converter/src/executor/tensor_array.ts @@ -52,8 +52,12 @@ export class TensorArray { /** * Dispose the tensors and idTensor and mark the TensoryArray as closed. */ - clearAndClose() { - this.tensors.forEach(tensor => tensor.tensor.dispose()); + clearAndClose(keepIds?: Set) { + this.tensors.forEach(tensor => { + if (keepIds == null || !keepIds.has(tensor.tensor.id)) { + tensor.tensor.dispose(); + } + }); this.tensors = []; this.closed_ = true; this.idTensor.dispose(); diff --git a/tfjs-converter/src/executor/tensor_array_test.ts b/tfjs-converter/src/executor/tensor_array_test.ts index 425c2305976..0e3d77387b9 100644 --- a/tfjs-converter/src/executor/tensor_array_test.ts +++ b/tfjs-converter/src/executor/tensor_array_test.ts @@ -54,10 +54,26 @@ describe('TensorArray', () => { tensorArray.clearAndClose(); expect(tensorArray.size()).toBe(0); expect(tensorArray.closed).toBeTruthy(); + // disposed the tensor in the array and idTensor of the array expect(memory().numTensors).toEqual(numOfTensors - size - 1); }); + it('should not dispose keep tensors when close', () => { + const numOfTensors = memory().numTensors; + tensorArray.write(0, tensor); + tensorArray.write(1, tensor2); + const size = tensorArray.size(); + const keepIds = new Set([tensor.id]); + tensorArray.clearAndClose(keepIds); + expect(tensorArray.size()).toBe(0); + expect(tensorArray.closed).toBeTruthy(); + expect(tensor.isDisposed).toBeFalsy(); + expect(tensor2.isDisposed).toBeTruthy(); + // disposed the tensor in the array and idTensor of the array + expect(memory().numTensors).toEqual(numOfTensors - size); + }); + describe('write', () => { it('should add new tensor', () => { tensorArray.write(0, tensor); diff --git a/tfjs-converter/src/executor/tensor_list.ts b/tfjs-converter/src/executor/tensor_list.ts index 9eb3c45b27b..0ceb5615650 100644 --- a/tfjs-converter/src/executor/tensor_list.ts +++ b/tfjs-converter/src/executor/tensor_list.ts @@ -80,8 +80,12 @@ export class TensorList { /** * Dispose the tensors and idTensor and clear the tensor list. */ - clearAndClose() { - this.tensors.forEach(tensor => tensor.dispose()); + clearAndClose(keepIds?: Set) { + this.tensors.forEach(tensor => { + if (keepIds == null || !keepIds.has(tensor.id)) { + tensor.dispose(); + } + }); this.tensors.length = 0; this.idTensor.dispose(); } diff --git a/tfjs-converter/src/executor/tensor_list_test.ts b/tfjs-converter/src/executor/tensor_list_test.ts index 60455ddd8bd..ddc77cd59e6 100644 --- a/tfjs-converter/src/executor/tensor_list_test.ts +++ b/tfjs-converter/src/executor/tensor_list_test.ts @@ -40,6 +40,20 @@ describe('TensorList', () => { expect(tensorList.elementShape).toEqual(SHAPE); }); + it('should not dispose keep tensors when close', () => { + const numOfTensors = memory().numTensors; + tensorList.pushBack(tensor); + tensorList.pushBack(tensor2); + const size = tensorList.size(); + const keepIds = new Set([tensor.id]); + tensorList.clearAndClose(keepIds); + expect(tensorList.size()).toBe(0); + expect(tensor.isDisposed).toBeFalsy(); + expect(tensor2.isDisposed).toBeTruthy(); + // disposed the tensor in the array and idTensor of the array + expect(memory().numTensors).toEqual(numOfTensors - size); + }); + describe('pushBack', () => { it('should add new tensor', () => { tensorList.pushBack(tensor);