This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Add string dtype to Tensor (#1408)
Add `string` dtype to `Tensor`.

This opens the door for adding Python's [`string_ops`](https://www.tensorflow.org/api_docs/python/tf/strings) to TensorFlow.js, which are used for text-based models, and for adding pre-processing layers that operate on strings.

Details:
- The dtype was not added as a generic parameter to the `Tensor` class, in order to keep compiler errors simple and the code backwards compatible.
- `dataSync()` can optionally be typed to cast its result: `t.dataSync<'string'>()` returns `string[]`, while `t.dataSync()` still returns `TypedArray` for backwards compatibility (see the sketch after this list).
- `layers` and `converter` pass with this build. `node` has ~30 failing tests since `string` is an unknown dtype there.
- Only `clone`, `reshape`, and `cast` work with strings at this point, to keep this PR small. Other ops will gain string support in a follow-up PR.
- Added unit tests asserting that numeric ops throw on string tensors.
- Backends should now support the `string` dtype in their `register`/`write`/`read` methods.
- Added a VS Code launch config for debugging directly from VS Code.
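
For orientation, a minimal sketch (not part of the diff) of what string tensors support after this change, assuming the usual `@tensorflow/tfjs-core` entry point:

```ts
import * as tf from '@tensorflow/tfjs-core';

// dtype is inferred as 'string' from the values.
const t = tf.tensor([['a', 'bb'], ['c', 'd']]);
console.log(t.dtype);  // 'string'

// Only clone, reshape and cast accept string tensors in this PR.
const flat = t.reshape([4]);

// dataSync() can be explicitly typed to get string[] back.
const values = flat.dataSync<'string'>();  // ['a', 'bb', 'c', 'd']

// Numeric ops such as tf.square(t) throw on string tensors,
// as asserted by the new unit tests.
```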

FEATURE
dsmilkov committed Nov 27, 2018
1 parent a25776b commit f217045
Showing 44 changed files with 1,764 additions and 468 deletions.
19 changes: 19 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,19 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "chrome",
"request": "attach",
"name": "Attach Karma Chrome",
"address": "localhost",
"port": 9333,
"pathMapping": {
"/": "${workspaceRoot}",
"/base/": "${workspaceRoot}/"
}
}
]
}
4 changes: 3 additions & 1 deletion karma.conf.js
@@ -89,7 +89,9 @@ module.exports = function(config) {
chrome_with_swift_shader: {
base: 'Chrome',
flags: ['--blacklist-accelerated-compositing', '--blacklist-webgl']
}
},
chrome_debugging:
{base: 'Chrome', flags: ['--remote-debugging-port=9333']}
},
client: {jasmine: {random: false}, args: args}
});
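
Usage note (not part of the diff): with these two additions, one can start Karma using the new `chrome_debugging` launcher, which opens Chrome with `--remote-debugging-port=9333`, and then run the `Attach Karma Chrome` configuration above so the VS Code debugger attaches to that port.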
73 changes: 73 additions & 0 deletions src/buffer_test.ts
@@ -0,0 +1,73 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {ALL_ENVS, expectArraysClose, expectArraysEqual} from './test_util';

describeWithFlags('tf.buffer', ALL_ENVS, () => {
it('float32', () => {
const buff = tf.buffer([1, 2, 3], 'float32');
buff.set(1.3, 0, 0, 0);
buff.set(2.9, 0, 1, 0);
expect(buff.get(0, 0, 0)).toBeCloseTo(1.3);
expect(buff.get(0, 0, 1)).toBeCloseTo(0);
expect(buff.get(0, 0, 2)).toBeCloseTo(0);
expect(buff.get(0, 1, 0)).toBeCloseTo(2.9);
expect(buff.get(0, 1, 1)).toBeCloseTo(0);
expect(buff.get(0, 1, 2)).toBeCloseTo(0);
expectArraysClose(buff.toTensor(), [1.3, 0, 0, 2.9, 0, 0]);
expectArraysClose(buff.values, new Float32Array([1.3, 0, 0, 2.9, 0, 0]));
});

it('int32', () => {
const buff = tf.buffer([2, 3], 'int32');
buff.set(1.3, 0, 0);
buff.set(2.1, 1, 1);
expect(buff.get(0, 0)).toEqual(1);
expect(buff.get(0, 1)).toEqual(0);
expect(buff.get(0, 2)).toEqual(0);
expect(buff.get(1, 0)).toEqual(0);
expect(buff.get(1, 1)).toEqual(2);
expect(buff.get(1, 2)).toEqual(0);
expectArraysClose(buff.toTensor(), [1, 0, 0, 0, 2, 0]);
expectArraysClose(buff.values, new Int32Array([1, 0, 0, 0, 2, 0]));
});

it('bool', () => {
const buff = tf.buffer([4], 'bool');
buff.set(true, 1);
buff.set(true, 2);
expect(buff.get(0)).toBeFalsy();
expect(buff.get(1)).toBeTruthy();
expect(buff.get(2)).toBeTruthy();
expect(buff.get(3)).toBeFalsy();
expectArraysClose(buff.toTensor(), [0, 1, 1, 0]);
expectArraysClose(buff.values, new Uint8Array([0, 1, 1, 0]));
});

it('string', () => {
const buff = tf.buffer([2, 2], 'string');
buff.set('first', 0, 0);
buff.set('third', 1, 0);
expect(buff.get(0, 0)).toEqual('first');
expect(buff.get(0, 1)).toBeFalsy();
expect(buff.get(1, 0)).toEqual('third');
expect(buff.get(1, 1)).toBeFalsy();
expectArraysEqual(buff.toTensor(), ['first', null, 'third', null]);
});
});
71 changes: 51 additions & 20 deletions src/engine.ts
@@ -21,9 +21,9 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode
import {DataId, Tensor, Tensor3D, Variable} from './tensor';
import {NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer, isTensorInList} from './tensor_util';
import {DataType, TypedArray} from './types';
import {DataType, DataValues} from './types';
import * as util from './util';
import {makeOnesTypedArray, now, sizeFromShape} from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';

/**
* A function that computes an output. The save function is for saving tensors
@@ -42,7 +42,7 @@ export type CustomGradientFunc<T extends Tensor> = (...args: Tensor[]) => {

export type MemoryInfo = {
numTensors: number; numDataBuffers: number; numBytes: number;
unreliable?: boolean;
unreliable?: boolean; reasons: string[];
};

type KernelProfile = {
@@ -85,6 +85,7 @@ export class Engine implements TensorManager, DataMover {
private nextTapeNodeId = 0;
private numBytes = 0;
private numTensors = 0;
private numStringTensors = 0;
private numDataBuffers = 0;

private profiling = false;
@@ -102,6 +103,7 @@

private tensorInfo = new WeakMap<DataId, {
backend: KernelBackend,
bytes: number,
dtype: DataType,
shape: number[],
refCount: number
@@ -250,18 +252,26 @@
this.tensorInfo.get(a.dataId).refCount :
0;
this.numTensors++;
if (a.dtype === 'string') {
this.numStringTensors++;
}
if (refCount === 0) {
this.numDataBuffers++;

// Don't count bytes for complex numbers as they are counted by their
// components.
if (a.dtype !== 'complex64') {
this.numBytes +=
util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
// Bytes for complex numbers are counted by their components. Bytes for
// string tensors are counted when writing values.
let bytes = 0;
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
bytes = util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
}
this.tensorInfo.set(
a.dataId,
{backend: this.backend, dtype: a.dtype, shape: a.shape, refCount: 0});
this.tensorInfo.set(a.dataId, {
backend: this.backend,
dtype: a.dtype,
shape: a.shape,
bytes,
refCount: 0
});
this.numBytes += bytes;
this.backend.register(a.dataId, a.shape, a.dtype);
}
this.tensorInfo.get(a.dataId).refCount++;
@@ -285,17 +295,19 @@
this.keepTensors.delete(a.id);
}
this.numTensors--;
const refCount = this.tensorInfo.get(a.dataId).refCount;
if (a.dtype === 'string') {
this.numStringTensors--;
}
const info = this.tensorInfo.get(a.dataId);
const refCount = info.refCount;
if (refCount <= 1) {
const info = this.tensorInfo.get(a.dataId);
info.backend.disposeData(a.dataId);
this.numDataBuffers--;
// Don't count bytes for complex numbers as they are counted by their
// components.
if (a.dtype !== 'complex64') {
this.numBytes -=
util.sizeFromShape(a.shape) * util.bytesPerElement(a.dtype);
this.numBytes -= info.bytes;
}
this.numDataBuffers--;
info.backend.disposeData(a.dataId);
this.tensorInfo.delete(a.dataId);
} else {
this.tensorInfo.get(a.dataId).refCount--;
@@ -318,6 +330,15 @@
info.numTensors = this.numTensors;
info.numDataBuffers = this.numDataBuffers;
info.numBytes = this.numBytes;
if (this.numStringTensors > 0) {
info.unreliable = true;
if (info.reasons == null) {
info.reasons = [];
}
info.reasons.push(
'Memory usage by string tensors is approximate ' +
'(2 bytes per character)');
}
return info;
}

@@ -457,6 +478,9 @@
f: () => T, xs: Tensor[], dy?: T,
allowNoGradients = false): {value: T, grads: Tensor[]} {
util.assert(xs.length > 0, 'gradients() received an empty list of xs.');
if (dy != null && dy.dtype !== 'float32') {
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
}

return this.tidy('gradients', () => {
const y = f();
@@ -537,8 +561,15 @@
}

// Forwarding to backend.
write(dataId: DataId, values: TypedArray): void {
write(dataId: DataId, values: DataValues): void {
const info = this.tensorInfo.get(dataId);
// Bytes for string tensors are counted when writing.
if (info.dtype === 'string') {
const newBytes = bytesFromStringArray(values as string[]);
this.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}

if (this.backend !== info.backend) {
// Delete the tensor from the old backend and move it to the new backend.
info.backend.disposeData(dataId);
Expand All @@ -547,12 +578,12 @@ export class Engine implements TensorManager, DataMover {
}
this.backend.write(dataId, values);
}
readSync(dataId: DataId): TypedArray {
readSync(dataId: DataId): DataValues {
// Route the read to the correct backend.
const info = this.tensorInfo.get(dataId);
return info.backend.readSync(dataId);
}
read(dataId: DataId): Promise<TypedArray> {
read(dataId: DataId): Promise<DataValues> {
// Route the read to the correct backend.
const info = this.tensorInfo.get(dataId);
return info.backend.read(dataId);
70 changes: 69 additions & 1 deletion src/engine_test.ts
@@ -20,7 +20,7 @@ import {describeWithFlags} from './jasmine_util';
import {MathBackendCPU} from './kernels/backend_cpu';
import {MathBackendWebGL} from './kernels/backend_webgl';
import {Tensor} from './tensor';
import {ALL_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';
import {ALL_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';

describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
it('debug mode does not error when no nans', () => {
@@ -361,6 +361,74 @@ describeWithFlags('memory', ALL_ENVS, () => {
expect(sum.dtype).toBe('int32');
expectArraysClose(sum, [1 + 1 + 0 + 1]);
});

it('string tensor', () => {
const a = tf.tensor([['a', 'bb'], ['c', 'd']]);

expect(tf.memory().numTensors).toBe(1);
expect(tf.memory().numBytes).toBe(10); // 5 letters, each 2 bytes.

a.dispose();

expect(tf.memory().numTensors).toBe(0);
expect(tf.memory().numBytes).toBe(0);
});

it('unreliable is true for string tensors', () => {
tf.tensor('a');
const mem = tf.memory();
expect(mem.unreliable).toBe(true);
const expectedReason = 'Memory usage by string tensors is approximate ' +
'(2 bytes per character)';
expect(mem.reasons.indexOf(expectedReason) >= 0).toBe(true);
});
});

describeWithFlags('memory webgl', WEBGL_ENVS, () => {
it('unreliable is falsy/not present when all tensors are numeric', () => {
tf.tensor(1);
const mem = tf.memory();
expect(mem.numTensors).toBe(1);
expect(mem.numDataBuffers).toBe(1);
expect(mem.numBytes).toBe(4);
expect(mem.unreliable).toBeFalsy();
});
});

describeWithFlags('memory cpu', CPU_ENVS, () => {
it('unreliable is true due to auto gc', () => {
tf.tensor(1);
const mem = tf.memory();
expect(mem.numTensors).toBe(1);
expect(mem.numDataBuffers).toBe(1);
expect(mem.numBytes).toBe(4);
expect(mem.unreliable).toBe(true);

const expectedReason =
'The reported memory is an upper bound. Due to automatic garbage ' +
'collection, the true allocated memory may be less.';
expect(mem.reasons.indexOf(expectedReason) >= 0).toBe(true);
});

it('unreliable is true due to both auto gc and string tensors', () => {
tf.tensor(1);
tf.tensor('a');

const mem = tf.memory();
expect(mem.numTensors).toBe(2);
expect(mem.numDataBuffers).toBe(2);
expect(mem.numBytes).toBe(6);
expect(mem.unreliable).toBe(true);

const expectedReasonGC =
'The reported memory is an upper bound. Due to automatic garbage ' +
'collection, the true allocated memory may be less.';
expect(mem.reasons.indexOf(expectedReasonGC) >= 0).toBe(true);
const expectedReasonString =
'Memory usage by string tensors is approximate ' +
'(2 bytes per character)';
expect(mem.reasons.indexOf(expectedReasonString) >= 0).toBe(true);
});
});

describeWithFlags('profile', ALL_ENVS, () => {
9 changes: 4 additions & 5 deletions src/environment.ts
@@ -101,11 +101,10 @@ export class Environment {
* (undisposed) at this time, which is ≤ the number of tensors
* (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
* data buffer with `a`).
* - `unreliable`: `Optional` `boolean`:
* - On WebGL, not present (always reliable).
* - On CPU, true. Due to automatic garbage collection, these numbers
* represent undisposed tensors, i.e. not wrapped in `tidy()`, or
* lacking a call to `tensor.dispose()`.
* - `unreliable`: True if the memory usage is unreliable. See `reasons` when
* `unreliable` is true.
* - `reasons`: `string[]`, reasons why the memory is unreliable, present if
* `unreliable` is true.
*/
/** @doc {heading: 'Performance', subheading: 'Memory'} */
static memory(): MemoryInfo {
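
To illustrate the documented `unreliable`/`reasons` fields, a sketch based on the tests above (not part of the diff), again assuming the `@tensorflow/tfjs-core` entry point:

```ts
import * as tf from '@tensorflow/tfjs-core';

tf.tensor(['a', 'bb']);       // allocate a string tensor
const mem = tf.memory();

console.log(mem.numTensors);  // 1
console.log(mem.unreliable);  // true, string byte counts are approximate
console.log(mem.reasons);
// includes 'Memory usage by string tensors is approximate (2 bytes per character)'
```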
2 changes: 1 addition & 1 deletion src/index.ts
@@ -48,7 +48,7 @@ export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
export {NamedTensorMap} from './tensor_types';
export {DataType, Rank, ShapeMap} from './types';
export {DataType, DataTypeMap, DataValues, Rank, ShapeMap} from './types';

export * from './ops/ops';
export {LSTMCellFunc} from './ops/lstm';
6 changes: 3 additions & 3 deletions src/jasmine_util.ts
@@ -93,7 +93,7 @@ export interface TestEnv {

export let TEST_ENVS: TestEnv[] = [
{
name: 'test-webgl1',
name: 'webgl1',
factory: () => new MathBackendWebGL(),
features: {
'WEBGL_VERSION': 1,
@@ -102,7 +102,7 @@ export let TEST_ENVS: TestEnv[] = [
}
},
{
name: 'test-webgl2',
name: 'webgl2',
factory: () => new MathBackendWebGL(),
features: {
'WEBGL_VERSION': 2,
@@ -111,7 +111,7 @@
}
},
{
name: 'test-cpu',
name: 'cpu',
factory: () => new MathBackendCPU(),
features: {'HAS_WEBGL': false}
}
