From 1adea1d1f437506b2c1ba950a9fca7843ae08c49 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Mon, 27 May 2019 18:23:46 -0400 Subject: [PATCH 1/8] save --- src/io/io_utils.ts | 29 +++++++++++------------------ src/io/types.ts | 6 +++++- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 86540c2eb8..aebf0babbc 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -16,7 +16,6 @@ */ import {tensor} from '../ops/tensor_ops'; -import {Tensor} from '../tensor'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {sizeFromShape} from '../util'; @@ -94,7 +93,7 @@ export function decodeWeights( const dtype = spec.dtype; const shape = spec.shape; const size = sizeFromShape(shape); - let typedArray: TypedArray; + let values: TypedArray|string[]; if ('quantization' in spec) { const quantization = spec.quantization; @@ -111,43 +110,37 @@ export function decodeWeights( new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === 'float32') { - typedArray = Float32Array.from( + values = Float32Array.from( quantizedArray, v => v * quantization.scale + quantization.min); } else if (dtype === 'int32') { - typedArray = Int32Array.from( + values = Int32Array.from( quantizedArray, v => Math.round(v * quantization.scale + quantization.min)); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * quantizationSizeFactor; + } else if (dtype === 'string') { + const decoder = new TextDecoder('utf-8'); + const bytes = buffer.slice(offset, offset + spec.byte_length); + values = decoder.decode(bytes).split('\x00'); } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); if (dtype === 'float32') { - typedArray = new Float32Array(byteBuffer); + values = new Float32Array(byteBuffer); } else if (dtype === 'int32') { - typedArray = new Int32Array(byteBuffer); + values = new Int32Array(byteBuffer); } else if (dtype === 'bool') { - typedArray = new Uint8Array(byteBuffer); + values = new Uint8Array(byteBuffer); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * dtypeFactor; } - let value: Tensor; - if (dtype === 'float32') { - value = tensor(typedArray, shape, 'float32'); - } else if (dtype === 'int32') { - value = tensor(typedArray, shape, 'int32'); - } else if (dtype === 'bool') { - value = tensor(typedArray, shape, 'bool'); - } else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); - } - out[name] = value; + out[name] = tensor(values, shape, dtype); } return out; } diff --git a/src/io/types.ts b/src/io/types.ts index 0e393480bb..97a8833a68 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -85,7 +85,7 @@ export declare interface WeightsManifestEntry { /** * Data type of the weight. */ - dtype: 'float32'|'int32'|'bool'; + dtype: 'float32'|'int32'|'bool'|'string'; /** * Type of the weight. @@ -105,6 +105,10 @@ export declare interface WeightsManifestEntry { min: number, // The (possibly nudged) minimum weight to add. dtype: 'uint16'|'uint8' // The dtype of the quantized weights. }; + + // Available for string tensors. + delimiter?: string; + byte_length?: number; } /** From 1fc472bd8242e45a2f289e24dd17ad2bfd336109 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Fri, 31 May 2019 07:11:07 -0400 Subject: [PATCH 2/8] save --- package.json | 3 +- rollup.config.js | 44 +++++++++++------------ src/io/io.ts | 5 +-- src/io/io_utils.ts | 28 +++++++++++---- src/io/io_utils_test.ts | 59 +++++++++++++++++++++++++++++-- src/io/types.ts | 9 +++-- src/platforms/platform.ts | 5 +++ src/platforms/platform_browser.ts | 15 ++++++++ src/platforms/platform_node.ts | 20 +++++++++++ 9 files changed, 151 insertions(+), 37 deletions(-) diff --git a/package.json b/package.json index 675c8f2bf0..ac403d026c 100644 --- a/package.json +++ b/package.json @@ -81,6 +81,7 @@ "seedrandom": "2.4.3" }, "browser": { - "node-fetch": false + "node-fetch": false, + "util": false } } diff --git a/rollup.config.js b/rollup.config.js index 229e398f38..fa85badd00 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -59,7 +59,7 @@ function config({plugins = [], output = {}, external = [], visualize = false}) { node(), // Polyfill require() from dependencies. commonjs({ - ignore: ['crypto', 'node-fetch'], + ignore: ['crypto', 'node-fetch', 'util'], include: 'node_modules/**', namedExports: { './node_modules/seedrandom/index.js': ['alea'], @@ -99,27 +99,27 @@ module.exports = cmdOptions => { })); } - // tf-core.min.js - bundles.push(config({ - plugins: [terser({output: {preamble: PREAMBLE}})], - output: { - format: 'umd', - name: 'tf', - extend: true, - file: 'dist/tf-core.min.js', - }, - visualize: cmdOptions.visualize - })); + // // tf-core.min.js + // bundles.push(config({ + // plugins: [terser({output: {preamble: PREAMBLE}})], + // output: { + // format: 'umd', + // name: 'tf', + // extend: true, + // file: 'dist/tf-core.min.js', + // }, + // visualize: cmdOptions.visualize + // })); - if (!cmdOptions.ci) { - // tf-core.esm.js - bundles.push(config({ - plugins: [terser({output: {preamble: PREAMBLE}})], - output: { - format: 'es', - file: 'dist/tf-core.esm.js', - } - })); - } + // if (!cmdOptions.ci) { + // // tf-core.esm.js + // bundles.push(config({ + // plugins: [terser({output: {preamble: PREAMBLE}})], + // output: { + // format: 'es', + // file: 'dist/tf-core.esm.js', + // } + // })); + // } return bundles; }; diff --git a/src/io/io.ts b/src/io/io.ts index 685a77c9bf..3e6fa1a067 100644 --- a/src/io/io.ts +++ b/src/io/io.ts @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http'; import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils'; import {fromMemory, withSaveHandler} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; -import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry, WeightGroup} from './types'; +import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; @@ -54,9 +54,10 @@ export { SaveConfig, SaveHandler, SaveResult, + StringWeightsManifestEntry, + WeightGroup, weightsLoaderFactory, WeightsManifestConfig, WeightsManifestEntry, - WeightGroup, withSaveHandler }; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index aebf0babbc..fb19545328 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -15,12 +15,16 @@ * ============================================================================= */ +import {ENV} from '../environment'; import {tensor} from '../ops/tensor_ops'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {sizeFromShape} from '../util'; -import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types'; +import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, StringWeightsManifestEntry, WeightGroup, WeightsManifestEntry} from './types'; + +/** Used to delimit neighboring strings when encoding string tensors. */ +export const STRING_DELIMITER = '\x00'; /** * Encode a map from names to weight values as an ArrayBuffer, along with an @@ -53,15 +57,25 @@ export async function encodeWeights( for (let i = 0; i < names.length; ++i) { const name = names[i]; const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name]; - if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool') { + if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && + t.dtype !== 'string') { throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype}; + if (t.dtype === 'string') { + const utf8bytes = + ENV.platform.encodeUTF8((await t.data()).join(STRING_DELIMITER)); + dataPromises.push(Promise.resolve(utf8bytes)); + const stringSpec = spec as StringWeightsManifestEntry; + stringSpec.byteLength = utf8bytes.length; + stringSpec.delimiter = STRING_DELIMITER; + } else { + dataPromises.push(t.data()); + } if (group != null) { spec.group = group; } specs.push(spec); - dataPromises.push(t.data()); } const tensorValues = await Promise.all(dataPromises); @@ -121,9 +135,11 @@ export function decodeWeights( } offset += size * quantizationSizeFactor; } else if (dtype === 'string') { - const decoder = new TextDecoder('utf-8'); - const bytes = buffer.slice(offset, offset + spec.byte_length); - values = decoder.decode(bytes).split('\x00'); + const stringSpec = spec as StringWeightsManifestEntry; + const bytes = + new Uint8Array(buffer.slice(offset, offset + stringSpec.byteLength)); + values = ENV.platform.decodeUTF8(bytes).split(stringSpec.delimiter); + offset += stringSpec.byteLength; } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 5a6feaa07d..ce8fd9bf67 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -22,7 +22,7 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {expectArraysEqual} from '../test_util'; import {expectArraysClose} from '../test_util'; -import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, STRING_DELIMITER, stringByteLength} from './io_utils'; import {WeightsManifestEntry} from './types'; describe('concatenateTypedArrays', () => { @@ -329,6 +329,51 @@ describe('encodeWeights', () => { ]); }); + it('String tensors', async () => { + const tensors: NamedTensorMap = { + x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2], 'string'), + x2: scalar('', 'string'), // Empty string. + x3: tensor1d(['здраво', 'поздрав'], 'string'), // Cyrillic. + }; + const dataAndSpecs = await tf.io.encodeWeights(tensors); + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs as tf.io.StringWeightsManifestEntry[]; + const x1ByteLength = 7 + 3; // 7 chars + 3 delimiters. + const x2ByteLength = 0; // No chars. + const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter. + expect(data.byteLength).toEqual(x1ByteLength + x2ByteLength + x3ByteLength); + let delim = specs[0].delimiter; + expect(new Uint8Array(data, 0, x1ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`)); + // The middle string takes up 0 bytes. + delim = specs[2].delimiter; + expect(new Uint8Array(data, x1ByteLength, x3ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8(`здраво${delim}поздрав`)); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'string', + shape: [2, 2], + byteLength: x1ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x2', + dtype: 'string', + shape: [], + byteLength: x2ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x3', + dtype: 'string', + shape: [2], + byteLength: x3ByteLength, + delimiter: STRING_DELIMITER, + } + ]); + }); + it('Mixed dtype tensors', async () => { const tensors: NamedTensorMap = { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), @@ -370,17 +415,25 @@ describeWithFlags('decodeWeights', {}, () => { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false], 'bool'), + x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'), + x5: tensor1d([''], 'string'), y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; - expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 3 + 4 * 3); + // 12 bytes from cyrillic (6 letters) + 3 bytes from ascii + 3 delimiters. + const x4Bytes = 12 + 3 + 3; + const x5Bytes = 0; + expect(data.byteLength) + .toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + 4 * 3); const decoded = tf.io.decodeWeights(data, specs); - expect(Object.keys(decoded).length).toEqual(4); + expect(Object.keys(decoded).length).toEqual(6); expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data()); expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data()); + expectArraysEqual(await decoded['x4'].data(), await tensors['x4'].data()); + expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data()); expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data()); }); diff --git a/src/io/types.ts b/src/io/types.ts index 97a8833a68..d05892fa52 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -105,10 +105,13 @@ export declare interface WeightsManifestEntry { min: number, // The (possibly nudged) minimum weight to add. dtype: 'uint16'|'uint8' // The dtype of the quantized weights. }; +} - // Available for string tensors. - delimiter?: string; - byte_length?: number; +export declare interface StringWeightsManifestEntry extends + WeightsManifestEntry { + dtype: 'string'; + delimiter: string; + byteLength: number; } /** diff --git a/src/platforms/platform.ts b/src/platforms/platform.ts index c60621a3e3..31037d1521 100644 --- a/src/platforms/platform.ts +++ b/src/platforms/platform.ts @@ -28,4 +28,9 @@ export interface Platform { * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request */ fetch(path: string, requestInits?: RequestInit): Promise; + + /** UTF-8 encode the provided string into an array of bytes. */ + encodeUTF8(text: string): Uint8Array; + /** UTF-8 decode the provided bytes into a string. */ + decodeUTF8(bytes: Uint8Array): string; } diff --git a/src/platforms/platform_browser.ts b/src/platforms/platform_browser.ts index 5573483506..7ad880aeda 100644 --- a/src/platforms/platform_browser.ts +++ b/src/platforms/platform_browser.ts @@ -18,6 +18,21 @@ import {ENV} from '../environment'; import {Platform} from './platform'; export class PlatformBrowser implements Platform { + private textEncoder: TextEncoder; + private textDecoder: TextDecoder; + + constructor() { + // Both the encoder and the decoder use UTF-8 encoding by default. + this.textEncoder = new TextEncoder(); + this.textDecoder = new TextDecoder(); + } + + encodeUTF8(text: string): Uint8Array { + return this.textEncoder.encode(text); + } + decodeUTF8(bytes: Uint8Array): string { + return this.textDecoder.decode(bytes); + } fetch(path: string, init?: RequestInit): Promise { return fetch(path, init); } diff --git a/src/platforms/platform_node.ts b/src/platforms/platform_node.ts index 56dacd805c..baab5c8c52 100644 --- a/src/platforms/platform_node.ts +++ b/src/platforms/platform_node.ts @@ -26,6 +26,26 @@ export const getNodeFetch = { export let systemFetch: (url: string, init?: RequestInit) => Promise; export class PlatformNode implements Platform { + private textEncoder: TextEncoder; + private textDecoder: TextDecoder; + + constructor() { + // tslint:disable-next-line: no-require-imports + const util = require('util'); + // Both the encoder and the decoder use UTF-8 encoding by default. + this.textEncoder = new util.TextEncoder(); + this.textDecoder = new util.TextDecoder(); + } + + encodeUTF8(text: string): Uint8Array { + return this.textEncoder.encode(text); + } + decodeUTF8(bytes: Uint8Array): string { + if (bytes.length === 0) { + return ''; + } + return this.textDecoder.decode(bytes); + } fetch(path: string, requestInits?: RequestInit): Promise { if (ENV.global.fetch != null) { return ENV.global.fetch(path, requestInits); From ed348034efa731a4d46552a64643901acd0b5427 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 18 Jun 2019 07:26:57 -0400 Subject: [PATCH 3/8] save --- src/ops/array_ops.ts | 2 +- src/ops/array_ops_test.ts | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index d8b1c80a0b..d0191747ce 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -881,7 +881,7 @@ function cumsum_( /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ function expandDims_( x: Tensor|TensorLike, axis = 0): Tensor { - const $x = convertToTensor(x, 'x', 'expandDims'); + const $x = convertToTensor(x, 'x', 'expandDims', null); util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); const newShape = $x.shape.slice(); diff --git a/src/ops/array_ops_test.ts b/src/ops/array_ops_test.ts index 88862da30b..ba0ec07e75 100644 --- a/src/ops/array_ops_test.ts +++ b/src/ops/array_ops_test.ts @@ -3512,6 +3512,20 @@ describeWithFlags('expandDims', ALL_ENVS, () => { expectArraysClose(await res.data(), [4]); }); + it('1d string tensor', async () => { + const t = tf.tensor(['hello', 'world']); + const res = t.expandDims(); + expect(res.shape).toEqual([1, 2]); + expectArraysClose(await res.data(), ['hello', 'world']); + }); + + it('2d string tensor, axis=1', async () => { + const t = tf.tensor([['a', 'b'], ['c', 'd']]); + const res = t.expandDims(1); + expect(res.shape).toEqual([2, 1, 2]); + expectArraysClose(await res.data(), ['a', 'b', 'c', 'd']); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.expandDims({} as tf.Tensor)) .toThrowError(/Argument 'x' passed to 'expandDims' must be a Tensor/); From 5525199a463053906615c04ff70f938ea16e2709 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 18 Jun 2019 08:00:45 -0400 Subject: [PATCH 4/8] save --- src/backends/cpu/backend_cpu.ts | 22 ++------------ src/backends/tile_impl.ts | 47 +++++++++++++++++++++++++++++ src/backends/webgl/backend_webgl.ts | 10 ++++-- src/ops/array_ops.ts | 2 +- src/ops/array_ops_test.ts | 17 +++++++++++ 5 files changed, 75 insertions(+), 23 deletions(-) create mode 100644 src/backends/tile_impl.ts diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 56e5424c03..62e80a065f 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -42,6 +42,7 @@ import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -2065,26 +2066,7 @@ export class MathBackendCPU implements KernelBackend { tile(x: T, reps: number[]): T { this.assertNotComplex(x, 'tile'); - - const newShape: number[] = new Array(x.rank); - for (let i = 0; i < newShape.length; i++) { - newShape[i] = x.shape[i] * reps[i]; - } - const result = ops.buffer(newShape, x.dtype); - const xBuf = this.bufferSync(x); - for (let i = 0; i < result.values.length; ++i) { - const newLoc = result.indexToLoc(i); - - const originalLoc: number[] = new Array(x.rank); - for (let i = 0; i < originalLoc.length; i++) { - originalLoc[i] = newLoc[i] % x.shape[i]; - } - - const originalIndex = xBuf.locToIndex(originalLoc); - - result.values[i] = xBuf.values[originalIndex]; - } - return result.toTensor() as T; + return tile(this.bufferSync(x), reps) as T; } pad( diff --git a/src/backends/tile_impl.ts b/src/backends/tile_impl.ts new file mode 100644 index 0000000000..a1cb8c9121 --- /dev/null +++ b/src/backends/tile_impl.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * An implementation of the tile kernel shared between webgl and cpu for string + * tensors only. + */ + +import {buffer} from '../ops/array_ops'; +import {Tensor, TensorBuffer} from '../tensor'; +import {DataType, Rank} from '../types'; + +export function tile( + xBuf: TensorBuffer, reps: number[]): Tensor { + const newShape: number[] = new Array(xBuf.rank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xBuf.shape[i] * reps[i]; + } + const result = buffer(newShape, xBuf.dtype); + for (let i = 0; i < result.values.length; ++i) { + const newLoc = result.indexToLoc(i); + + const originalLoc: number[] = new Array(xBuf.rank); + for (let i = 0; i < originalLoc.length; i++) { + originalLoc[i] = newLoc[i] % xBuf.shape[i]; + } + + const originalIndex = xBuf.locToIndex(originalLoc); + + result.values[i] = xBuf.values[originalIndex]; + } + return result.toTensor() as Tensor; +} diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index e675390293..33ebfda3d7 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -23,6 +23,7 @@ import {ENGINE, MemoryInfo, TimingInfo} from '../../engine'; import {ENV} from '../../environment'; import {tidy} from '../../globals'; import {warn} from '../../log'; +import {buffer} from '../../ops/array_ops'; import * as array_ops_util from '../../ops/array_ops_util'; import * as axis_util from '../../ops/axis_util'; import {computeOutShape} from '../../ops/concat_util'; @@ -44,6 +45,7 @@ import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -974,6 +976,10 @@ export class MathBackendWebGL implements KernelBackend { } tile(x: T, reps: number[]): T { + if (x.dtype === 'string') { + const buf = buffer(x.shape, x.dtype, this.readSync(x.dataId) as string[]); + return tile(buf, reps) as T; + } const program = new TileProgram(x.shape, reps); return this.compileAndRun(program, [x]); } @@ -2471,8 +2477,8 @@ export class MathBackendWebGL implements KernelBackend { } else { this.canvas = null; } - if (this.fromPixels2DContext != null - && this.fromPixels2DContext.canvas.remove != null) { + if (this.fromPixels2DContext != null && + this.fromPixels2DContext.canvas.remove != null) { this.fromPixels2DContext.canvas.remove(); } if (this.gpgpuCreatedLocally) { diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index d0191747ce..793e16f30c 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -418,7 +418,7 @@ function cast_(x: T|TensorLike, dtype: DataType): T { */ /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function tile_(x: T|TensorLike, reps: number[]): T { - const $x = convertToTensor(x, 'x', 'tile'); + const $x = convertToTensor(x, 'x', 'tile', null); util.assert( $x.rank === reps.length, diff --git a/src/ops/array_ops_test.ts b/src/ops/array_ops_test.ts index ba0ec07e75..5be54dd0dc 100644 --- a/src/ops/array_ops_test.ts +++ b/src/ops/array_ops_test.ts @@ -1946,6 +1946,23 @@ describeWithFlags('tile', ALL_ENVS, () => { await t2.data(), [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]); }); + it('1d string tensor', async () => { + const a = tf.tensor(['a', 'b', 'c']); + const res = tf.tile(a, [2]); + expect(res.shape).toEqual([6]); + expectArraysEqual(await res.data(), ['a', 'b', 'c', 'a', 'b', 'c']); + }); + + it('2d string tensor', async () => { + const a = tf.tensor([['a', 'b'], ['c', 'd']]); + const res = tf.tile(a, [2, 3]); + expect(res.shape).toEqual([4, 6]); + expectArraysEqual(await res.data(), [ + 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'c', 'd', 'c', 'd', + 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'c', 'd', 'c', 'd' + ]); + }); + it('propagates NaNs', async () => { const t = tf.tensor1d([1, 2, NaN]); From 054f1936d2200411de85c5d03377389aa6b83c95 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 18 Jun 2019 09:30:24 -0400 Subject: [PATCH 5/8] save --- rollup.config.js | 42 +++++++++++++++---------------- src/io/io_utils_test.ts | 30 +++++++++++++++++----- src/io/types.ts | 4 +++ src/platforms/platform_browser.ts | 2 +- src/platforms/platform_node.ts | 3 ++- 5 files changed, 52 insertions(+), 29 deletions(-) diff --git a/rollup.config.js b/rollup.config.js index fa85badd00..4938bdcda9 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -99,27 +99,27 @@ module.exports = cmdOptions => { })); } - // // tf-core.min.js - // bundles.push(config({ - // plugins: [terser({output: {preamble: PREAMBLE}})], - // output: { - // format: 'umd', - // name: 'tf', - // extend: true, - // file: 'dist/tf-core.min.js', - // }, - // visualize: cmdOptions.visualize - // })); + // tf-core.min.js + bundles.push(config({ + plugins: [terser({output: {preamble: PREAMBLE}})], + output: { + format: 'umd', + name: 'tf', + extend: true, + file: 'dist/tf-core.min.js', + }, + visualize: cmdOptions.visualize + })); - // if (!cmdOptions.ci) { - // // tf-core.esm.js - // bundles.push(config({ - // plugins: [terser({output: {preamble: PREAMBLE}})], - // output: { - // format: 'es', - // file: 'dist/tf-core.esm.js', - // } - // })); - // } + if (!cmdOptions.ci) { + // tf-core.esm.js + bundles.push(config({ + plugins: [terser({output: {preamble: PREAMBLE}})], + output: { + format: 'es', + file: 'dist/tf-core.esm.js', + } + })); + } return bundles; }; diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index ce8fd9bf67..3de492be57 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -334,21 +334,28 @@ describe('encodeWeights', () => { x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2], 'string'), x2: scalar('', 'string'), // Empty string. x3: tensor1d(['здраво', 'поздрав'], 'string'), // Cyrillic. + x4: scalar('hello', 'string') // Single string. }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs as tf.io.StringWeightsManifestEntry[]; - const x1ByteLength = 7 + 3; // 7 chars + 3 delimiters. + const x1ByteLength = 7 + 3; // 7 ascii chars + 3 delimiters. const x2ByteLength = 0; // No chars. const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter. - expect(data.byteLength).toEqual(x1ByteLength + x2ByteLength + x3ByteLength); + const x4ByteLength = 5; // 5 ascii chars. + expect(data.byteLength) + .toEqual(x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength); let delim = specs[0].delimiter; expect(new Uint8Array(data, 0, x1ByteLength)) .toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`)); // The middle string takes up 0 bytes. delim = specs[2].delimiter; - expect(new Uint8Array(data, x1ByteLength, x3ByteLength)) + expect(new Uint8Array(data, x1ByteLength + x2ByteLength, x3ByteLength)) .toEqual(tf.ENV.platform.encodeUTF8(`здраво${delim}поздрав`)); + delim = specs[3].delimiter; + expect(new Uint8Array( + data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8('hello')); expect(specs).toEqual([ { name: 'x1', @@ -370,6 +377,13 @@ describe('encodeWeights', () => { shape: [2], byteLength: x3ByteLength, delimiter: STRING_DELIMITER, + }, + { + name: 'x4', + dtype: 'string', + shape: [], + byteLength: x4ByteLength, + delimiter: STRING_DELIMITER, } ]); }); @@ -416,7 +430,8 @@ describeWithFlags('decodeWeights', {}, () => { x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false], 'bool'), x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'), - x5: tensor1d([''], 'string'), + x5: tensor1d([''], 'string'), // Empty string. + x6: scalar('hello'), // Single string. y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), }; const dataAndSpecs = await tf.io.encodeWeights(tensors); @@ -425,15 +440,18 @@ describeWithFlags('decodeWeights', {}, () => { // 12 bytes from cyrillic (6 letters) + 3 bytes from ascii + 3 delimiters. const x4Bytes = 12 + 3 + 3; const x5Bytes = 0; + // 5 bytes from ascii. + const x6Bytes = 5; expect(data.byteLength) - .toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + 4 * 3); + .toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + x6Bytes + 4 * 3); const decoded = tf.io.decodeWeights(data, specs); - expect(Object.keys(decoded).length).toEqual(6); + expect(Object.keys(decoded).length).toEqual(7); expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data()); expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data()); expectArraysEqual(await decoded['x4'].data(), await tensors['x4'].data()); expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data()); + expectArraysEqual(await decoded['x6'].data(), await tensors['x6'].data()); expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data()); }); diff --git a/src/io/types.ts b/src/io/types.ts index d05892fa52..6c6d3bc192 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -110,7 +110,11 @@ export declare interface WeightsManifestEntry { export declare interface StringWeightsManifestEntry extends WeightsManifestEntry { dtype: 'string'; + // Used for delimiting neighboring strings. If the tensor has no strings or + // only 1 string, there will be no delimiter. If the tensor has N strings + // (N > 0), there will be N-1 delimiters. delimiter: string; + // Number of bytes used by the whole tensor, including the delimiters. byteLength: number; } diff --git a/src/platforms/platform_browser.ts b/src/platforms/platform_browser.ts index 7ad880aeda..41968373a0 100644 --- a/src/platforms/platform_browser.ts +++ b/src/platforms/platform_browser.ts @@ -22,7 +22,7 @@ export class PlatformBrowser implements Platform { private textDecoder: TextDecoder; constructor() { - // Both the encoder and the decoder use UTF-8 encoding by default. + // The built-in encoder and the decoder use UTF-8 encoding. this.textEncoder = new TextEncoder(); this.textDecoder = new TextDecoder(); } diff --git a/src/platforms/platform_node.ts b/src/platforms/platform_node.ts index baab5c8c52..e027ec8e77 100644 --- a/src/platforms/platform_node.ts +++ b/src/platforms/platform_node.ts @@ -25,6 +25,7 @@ export const getNodeFetch = { }; export let systemFetch: (url: string, init?: RequestInit) => Promise; + export class PlatformNode implements Platform { private textEncoder: TextEncoder; private textDecoder: TextDecoder; @@ -32,7 +33,7 @@ export class PlatformNode implements Platform { constructor() { // tslint:disable-next-line: no-require-imports const util = require('util'); - // Both the encoder and the decoder use UTF-8 encoding by default. + // The built-in encoder and the decoder use UTF-8 encoding. this.textEncoder = new util.TextEncoder(); this.textDecoder = new util.TextDecoder(); } From 179d7485b2668f006bac3c152d1c73fe1ef92917 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Tue, 18 Jun 2019 11:42:19 -0400 Subject: [PATCH 6/8] save --- src/io/io_utils.ts | 15 +++++---- src/ops/array_ops.ts | 6 ++-- src/platforms/platform_browser_test.ts | 46 ++++++++++++++++++++++++++ src/platforms/platform_node_test.ts | 46 ++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 8 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index fb19545328..cd890cc6c9 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -63,12 +63,15 @@ export async function encodeWeights( } const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype}; if (t.dtype === 'string') { - const utf8bytes = - ENV.platform.encodeUTF8((await t.data()).join(STRING_DELIMITER)); - dataPromises.push(Promise.resolve(utf8bytes)); - const stringSpec = spec as StringWeightsManifestEntry; - stringSpec.byteLength = utf8bytes.length; - stringSpec.delimiter = STRING_DELIMITER; + const utf8bytes = new Promise(async resolve => { + const stringSpec = spec as StringWeightsManifestEntry; + const data = await t.data(); + const bytes = ENV.platform.encodeUTF8(data.join(STRING_DELIMITER)); + stringSpec.byteLength = bytes.length; + stringSpec.delimiter = STRING_DELIMITER; + resolve(bytes); + }); + dataPromises.push(utf8bytes); } else { dataPromises.push(t.data()); } diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index 793e16f30c..01af4ecad7 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -418,7 +418,8 @@ function cast_(x: T|TensorLike, dtype: DataType): T { */ /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function tile_(x: T|TensorLike, reps: number[]): T { - const $x = convertToTensor(x, 'x', 'tile', null); + const parseAs: DataType = null; + const $x = convertToTensor(x, 'x', 'tile', parseAs); util.assert( $x.rank === reps.length, @@ -881,7 +882,8 @@ function cumsum_( /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ function expandDims_( x: Tensor|TensorLike, axis = 0): Tensor { - const $x = convertToTensor(x, 'x', 'expandDims', null); + const parseAs: DataType = null; + const $x = convertToTensor(x, 'x', 'expandDims', parseAs); util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); const newShape = $x.shape.slice(); diff --git a/src/platforms/platform_browser_test.ts b/src/platforms/platform_browser_test.ts index fed0fd224b..74549fe260 100644 --- a/src/platforms/platform_browser_test.ts +++ b/src/platforms/platform_browser_test.ts @@ -29,4 +29,50 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { expect(self.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'}); }); + + it('encodeUTF8 single string', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('hello'); + expect(bytes.length).toBe(5); + expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); + }); + + it('encodeUTF8 two strings delimited', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('hello\x00world'); + expect(bytes.length).toBe(11); + expect(bytes).toEqual( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + }); + + it('encodeUTF8 cyrillic', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('Здраво'); + expect(bytes.length).toBe(12); + expect(bytes).toEqual(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + }); + + it('decodeUTF8 single string', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + expect(s.length).toBe(5); + expect(s).toEqual('hello'); + }); + + it('decodeUTF8 two strings delimited', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + expect(s.length).toBe(11); + expect(s).toEqual('hello\x00world'); + }); + + it('decodeUTF8 cyrillic', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + expect(s.length).toBe(6); + expect(s).toEqual('Здраво'); + }); }); diff --git a/src/platforms/platform_node_test.ts b/src/platforms/platform_node_test.ts index 60d8cb1d91..bb9f1a438e 100644 --- a/src/platforms/platform_node_test.ts +++ b/src/platforms/platform_node_test.ts @@ -67,4 +67,50 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { platform_node.systemFetch = savedFetch; ENV.global.fetch = globalFetch; }); + + it('encodeUTF8 single string', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('hello'); + expect(bytes.length).toBe(5); + expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); + }); + + it('encodeUTF8 two strings delimited', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('hello\x00world'); + expect(bytes.length).toBe(11); + expect(bytes).toEqual( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + }); + + it('encodeUTF8 cyrillic', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('Здраво'); + expect(bytes.length).toBe(12); + expect(bytes).toEqual(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + }); + + it('decodeUTF8 single string', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + expect(s.length).toBe(5); + expect(s).toEqual('hello'); + }); + + it('decodeUTF8 two strings delimited', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + expect(s.length).toBe(11); + expect(s).toEqual('hello\x00world'); + }); + + it('decodeUTF8 cyrillic', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + expect(s.length).toBe(6); + expect(s).toEqual('Здраво'); + }); }); From 154fd382737d888628027630920f30367456ea0b Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 19 Jun 2019 10:39:25 -0400 Subject: [PATCH 7/8] save --- src/io/io_utils_test.ts | 25 +++++++++++++++++++++---- src/io/types.ts | 13 +++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 3de492be57..df1b00fcd8 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -329,12 +329,14 @@ describe('encodeWeights', () => { ]); }); - it('String tensors', async () => { + // tslint:disable-next-line: ban + fit('String tensors', async () => { const tensors: NamedTensorMap = { x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2], 'string'), x2: scalar('', 'string'), // Empty string. x3: tensor1d(['здраво', 'поздрав'], 'string'), // Cyrillic. - x4: scalar('hello', 'string') // Single string. + x4: scalar('正常'), // Chinese. + x5: scalar('hello') // Single string. }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; @@ -342,9 +344,12 @@ describe('encodeWeights', () => { const x1ByteLength = 7 + 3; // 7 ascii chars + 3 delimiters. const x2ByteLength = 0; // No chars. const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter. - const x4ByteLength = 5; // 5 ascii chars. + const x4ByteLength = 6; // 2 chinese letters. + const x5ByteLength = 5; // 5 ascii chars. expect(data.byteLength) - .toEqual(x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength); + .toEqual( + x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength + + x5ByteLength); let delim = specs[0].delimiter; expect(new Uint8Array(data, 0, x1ByteLength)) .toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`)); @@ -355,6 +360,11 @@ describe('encodeWeights', () => { delim = specs[3].delimiter; expect(new Uint8Array( data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8('正常')); + delim = specs[4].delimiter; + expect(new Uint8Array( + data, x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength, + x5ByteLength)) .toEqual(tf.ENV.platform.encodeUTF8('hello')); expect(specs).toEqual([ { @@ -384,6 +394,13 @@ describe('encodeWeights', () => { shape: [], byteLength: x4ByteLength, delimiter: STRING_DELIMITER, + }, + { + name: 'x5', + dtype: 'string', + shape: [], + byteLength: x5ByteLength, + delimiter: STRING_DELIMITER, } ]); }); diff --git a/src/io/types.ts b/src/io/types.ts index 6c6d3bc192..0de401a950 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -110,11 +110,16 @@ export declare interface WeightsManifestEntry { export declare interface StringWeightsManifestEntry extends WeightsManifestEntry { dtype: 'string'; - // Used for delimiting neighboring strings. If the tensor has no strings or - // only 1 string, there will be no delimiter. If the tensor has N strings - // (N > 0), there will be N-1 delimiters. + /** + * Used for delimiting neighboring strings. If the tensor has no strings or + * only 1 string, there will be no delimiter. If the tensor has N strings + * (N>0), there will be N-1 delimiters. + */ delimiter: string; - // Number of bytes used by the whole tensor, including the delimiters. + /** + * Number of bytes used by the whole tensor, including the delimiters (N-1 + * delimiters for N strings). + */ byteLength: number; } From 46ad7b25664d184a7c816ea8d412588b662c224d Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 19 Jun 2019 17:54:04 -0400 Subject: [PATCH 8/8] save --- src/io/io_utils_test.ts | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index df1b00fcd8..0a56f48bd2 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -329,14 +329,13 @@ describe('encodeWeights', () => { ]); }); - // tslint:disable-next-line: ban - fit('String tensors', async () => { + it('String tensors', async () => { const tensors: NamedTensorMap = { - x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2], 'string'), - x2: scalar('', 'string'), // Empty string. - x3: tensor1d(['здраво', 'поздрав'], 'string'), // Cyrillic. - x4: scalar('正常'), // Chinese. - x5: scalar('hello') // Single string. + x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2]), + x2: scalar(''), // Empty string. + x3: tensor1d(['здраво', 'поздрав']), // Cyrillic. + x4: scalar('正常'), // East Asian. + x5: scalar('hello') // Single string. }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data;