diff --git a/tfjs-backend-webgl/src/webgl_topixels_test.ts b/tfjs-backend-webgl/src/webgl_topixels_test.ts new file mode 100644 index 00000000000..e0dc2bb9120 --- /dev/null +++ b/tfjs-backend-webgl/src/webgl_topixels_test.ts @@ -0,0 +1,41 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '@tensorflow/tfjs-core'; +// tslint:disable-next-line: no-imports-from-dist +import {describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; + +import {WebGLMemoryInfo} from './backend_webgl'; +import {WEBGL_ENVS} from './backend_webgl_test_registry'; + +describeWithFlags('toPixels', WEBGL_ENVS, () => { + it('does not leak memory', async () => { + const x = tf.tensor2d([[.1], [.2]], [2, 1]); + const startNumBytesInGPU = (tf.memory() as WebGLMemoryInfo).numBytesInGPU; + await tf.browser.toPixels(x); + expect((tf.memory() as WebGLMemoryInfo).numBytesInGPU) + .toEqual(startNumBytesInGPU); + }); + + it('does not leak memory given a tensor-like object', async () => { + const x = [[10], [20]]; // 2x1; + const startNumBytesInGPU = (tf.memory() as WebGLMemoryInfo).numBytesInGPU; + await tf.browser.toPixels(x); + expect((tf.memory() as WebGLMemoryInfo).numBytesInGPU) + .toEqual(startNumBytesInGPU); + }); +}); diff --git a/tfjs-core/src/ops/browser.ts b/tfjs-core/src/ops/browser.ts index e060db10fa1..5589c3d3c92 100644 --- a/tfjs-core/src/ops/browser.ts +++ b/tfjs-core/src/ops/browser.ts @@ -24,8 +24,6 @@ import {convertToTensor} from '../tensor_util_env'; import {PixelData, TensorLike} from '../types'; import {cast} from './cast'; -import {max} from './max'; -import {min} from './min'; import {op} from './operation'; import {tensor3d} from './tensor3d'; @@ -196,60 +194,50 @@ export async function toPixels( `1, 3 or 4 but got ${depth}`); } - const data = await $img.data(); - const minTensor = min($img); - const maxTensor = max($img); - const vals = await Promise.all([minTensor.data(), maxTensor.data()]); - const minVals = vals[0]; - const maxVals = vals[1]; - const minVal = minVals[0]; - const maxVal = maxVals[0]; - minTensor.dispose(); - maxTensor.dispose(); - if ($img.dtype === 'float32') { - if (minVal < 0 || maxVal > 1) { - throw new Error( - `Tensor values for a float32 Tensor must be in the ` + - `range [0 - 1] but got range [${minVal} - ${maxVal}].`); - } - } else if ($img.dtype === 'int32') { - if (minVal < 0 || maxVal > 255) { - throw new Error( - `Tensor values for a int32 Tensor must be in the ` + - `range [0 - 255] but got range [${minVal} - ${maxVal}].`); - } - } else { + if ($img.dtype !== 'float32' && $img.dtype !== 'int32') { throw new Error( `Unsupported type for toPixels: ${$img.dtype}.` + ` Please use float32 or int32 tensors.`); } + + const data = await $img.data(); const multiplier = $img.dtype === 'float32' ? 255 : 1; const bytes = new Uint8ClampedArray(width * height * 4); for (let i = 0; i < height * width; ++i) { - let r, g, b, a; - if (depth === 1) { - r = data[i] * multiplier; - g = data[i] * multiplier; - b = data[i] * multiplier; - a = 255; - } else if (depth === 3) { - r = data[i * 3] * multiplier; - g = data[i * 3 + 1] * multiplier; - b = data[i * 3 + 2] * multiplier; - a = 255; - } else if (depth === 4) { - r = data[i * 4] * multiplier; - g = data[i * 4 + 1] * multiplier; - b = data[i * 4 + 2] * multiplier; - a = data[i * 4 + 3] * multiplier; + const rgba = [0, 0, 0, 255]; + + for (let d = 0; d < depth; d++) { + const value = data[i * depth + d]; + + if ($img.dtype === 'float32') { + if (value < 0 || value > 1) { + throw new Error( + `Tensor values for a float32 Tensor must be in the ` + + `range [0 - 1] but encountered ${value}.`); + } + } else if ($img.dtype === 'int32') { + if (value < 0 || value > 255) { + throw new Error( + `Tensor values for a int32 Tensor must be in the ` + + `range [0 - 255] but encountered ${value}.`); + } + } + + if (depth === 1) { + rgba[0] = value * multiplier; + rgba[1] = value * multiplier; + rgba[2] = value * multiplier; + } else { + rgba[d] = value * multiplier; + } } const j = i * 4; - bytes[j + 0] = Math.round(r); - bytes[j + 1] = Math.round(g); - bytes[j + 2] = Math.round(b); - bytes[j + 3] = Math.round(a); + bytes[j + 0] = Math.round(rgba[0]); + bytes[j + 1] = Math.round(rgba[1]); + bytes[j + 2] = Math.round(rgba[2]); + bytes[j + 3] = Math.round(rgba[3]); } if (canvas != null) {