Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"seedrandom": "2.4.3"
},
"browser": {
"node-fetch": false
"node-fetch": false,
"util": false
}
}
2 changes: 1 addition & 1 deletion rollup.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function config({plugins = [], output = {}, external = [], visualize = false}) {
node(),
// Polyfill require() from dependencies.
commonjs({
ignore: ['crypto', 'node-fetch'],
ignore: ['crypto', 'node-fetch', 'util'],
include: 'node_modules/**',
namedExports: {
'./node_modules/seedrandom/index.js': ['alea'],
Expand Down
22 changes: 2 additions & 20 deletions src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import * as backend_util from '../backend_util';
import * as complex_util from '../complex_util';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';

Expand Down Expand Up @@ -2065,26 +2066,7 @@ export class MathBackendCPU implements KernelBackend {

tile<T extends Tensor>(x: T, reps: number[]): T {
this.assertNotComplex(x, 'tile');

const newShape: number[] = new Array(x.rank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[i] * reps[i];
}
const result = ops.buffer(newShape, x.dtype);
const xBuf = this.bufferSync(x);
for (let i = 0; i < result.values.length; ++i) {
const newLoc = result.indexToLoc(i);

const originalLoc: number[] = new Array(x.rank);
for (let i = 0; i < originalLoc.length; i++) {
originalLoc[i] = newLoc[i] % x.shape[i];
}

const originalIndex = xBuf.locToIndex(originalLoc);

result.values[i] = xBuf.values[originalIndex];
}
return result.toTensor() as T;
return tile(this.bufferSync(x), reps) as T;
}

pad<T extends Tensor>(
Expand Down
47 changes: 47 additions & 0 deletions src/backends/tile_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

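/**
* An implementation of the tile kernel shared between the cpu and webgl
* backends. The cpu backend uses it for every non-complex dtype, while the
* webgl backend falls back to it for string tensors only.
*/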

import {buffer} from '../ops/array_ops';
import {Tensor, TensorBuffer} from '../tensor';
import {DataType, Rank} from '../types';

/**
 * Repeats the contents of `xBuf` along every axis and returns the result as a
 * new tensor.
 *
 * Axis `k` of the output has size `xBuf.shape[k] * reps[k]`. Each output
 * element is read from the input at the output coordinate taken modulo the
 * input shape, axis by axis.
 *
 * @param xBuf Buffer holding the values to tile.
 * @param reps Repetition count per axis, indexed in step with `xBuf.shape`.
 * @returns A tensor with the tiled values and the same dtype as `xBuf`.
 */
export function tile<R extends Rank>(
    xBuf: TensorBuffer<R, DataType>, reps: number[]): Tensor<R> {
  // Output extent per axis is the input extent scaled by its repetition count.
  const tiledShape: number[] = Array.from(
      {length: xBuf.rank}, (_, axis) => xBuf.shape[axis] * reps[axis]);

  const tiled = buffer(tiledShape, xBuf.dtype);
  for (let flatIndex = 0; flatIndex < tiled.values.length; ++flatIndex) {
    // Wrap every output coordinate back into the bounds of the input.
    const sourceLoc: number[] =
        tiled.indexToLoc(flatIndex)
            .map((coord: number, axis: number) => coord % xBuf.shape[axis]);

    tiled.values[flatIndex] = xBuf.values[xBuf.locToIndex(sourceLoc)];
  }
  return tiled.toTensor() as Tensor<R>;
}
6 changes: 6 additions & 0 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {ENGINE, MemoryInfo, TimingInfo} from '../../engine';
import {ENV} from '../../environment';
import {tidy} from '../../globals';
import {warn} from '../../log';
import {buffer} from '../../ops/array_ops';
import * as array_ops_util from '../../ops/array_ops_util';
import * as axis_util from '../../ops/axis_util';
import {computeOutShape} from '../../ops/concat_util';
Expand All @@ -44,6 +45,7 @@ import * as backend_util from '../backend_util';
import {mergeRealAndImagArrays} from '../complex_util';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';

Expand Down Expand Up @@ -974,6 +976,10 @@ export class MathBackendWebGL implements KernelBackend {
}

/**
 * Tiles `x` by repeating it `reps[k]` times along axis `k`.
 *
 * Non-string tensors are tiled by running a `TileProgram`. String tensors
 * are read back synchronously and tiled by the shared `tile` implementation.
 */
tile<T extends Tensor>(x: T, reps: number[]): T {
  if (x.dtype !== 'string') {
    return this.compileAndRun(new TileProgram(x.shape, reps), [x]);
  }
  // NOTE(review): presumably strings have no shader path here — confirm.
  const strings = this.readSync(x.dataId) as string[];
  const stringBuf = buffer(x.shape, x.dtype, strings);
  return tile(stringBuf, reps) as T;
}
Expand Down
5 changes: 3 additions & 2 deletions src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry, WeightGroup} from './types';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeights, weightsLoaderFactory} from './weights_loader';

export {copyModel, listModels, moveModel, removeModel} from './model_management';
Expand Down Expand Up @@ -54,9 +54,10 @@ export {
SaveConfig,
SaveHandler,
SaveResult,
StringWeightsManifestEntry,
WeightGroup,
weightsLoaderFactory,
WeightsManifestConfig,
WeightsManifestEntry,
WeightGroup,
withSaveHandler
};
54 changes: 33 additions & 21 deletions src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
* =============================================================================
*/

import {ENV} from '../environment';
import {tensor} from '../ops/tensor_ops';
import {Tensor} from '../tensor';
import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {TypedArray} from '../types';
import {sizeFromShape} from '../util';

import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types';
import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, StringWeightsManifestEntry, WeightGroup, WeightsManifestEntry} from './types';

/** Used to delimit neighboring strings when encoding string tensors. */
export const STRING_DELIMITER = '\x00';

/**
* Encode a map from names to weight values as an ArrayBuffer, along with an
Expand Down Expand Up @@ -54,15 +57,28 @@ export async function encodeWeights(
for (let i = 0; i < names.length; ++i) {
const name = names[i];
const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool') {
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
t.dtype !== 'string') {
throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
}
const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype};
if (t.dtype === 'string') {
const utf8bytes = new Promise<TypedArray>(async resolve => {
const stringSpec = spec as StringWeightsManifestEntry;
const data = await t.data();
const bytes = ENV.platform.encodeUTF8(data.join(STRING_DELIMITER));
stringSpec.byteLength = bytes.length;
stringSpec.delimiter = STRING_DELIMITER;
resolve(bytes);
});
dataPromises.push(utf8bytes);
} else {
dataPromises.push(t.data());
}
if (group != null) {
spec.group = group;
}
specs.push(spec);
dataPromises.push(t.data());
}

const tensorValues = await Promise.all(dataPromises);
Expand Down Expand Up @@ -94,7 +110,7 @@ export function decodeWeights(
const dtype = spec.dtype;
const shape = spec.shape;
const size = sizeFromShape(shape);
let typedArray: TypedArray;
let values: TypedArray|string[];

if ('quantization' in spec) {
const quantization = spec.quantization;
Expand All @@ -111,43 +127,39 @@ export function decodeWeights(
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
typedArray = Float32Array.from(
values = Float32Array.from(
quantizedArray, v => v * quantization.scale + quantization.min);
} else if (dtype === 'int32') {
typedArray = Int32Array.from(
values = Int32Array.from(
quantizedArray,
v => Math.round(v * quantization.scale + quantization.min));
} else {
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}
offset += size * quantizationSizeFactor;
} else if (dtype === 'string') {
const stringSpec = spec as StringWeightsManifestEntry;
const bytes =
new Uint8Array(buffer.slice(offset, offset + stringSpec.byteLength));
values = ENV.platform.decodeUTF8(bytes).split(stringSpec.delimiter);
offset += stringSpec.byteLength;
} else {
const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);

if (dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
values = new Float32Array(byteBuffer);
} else if (dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);
values = new Int32Array(byteBuffer);
} else if (dtype === 'bool') {
typedArray = new Uint8Array(byteBuffer);
values = new Uint8Array(byteBuffer);
} else {
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}
offset += size * dtypeFactor;
}

let value: Tensor;
if (dtype === 'float32') {
value = tensor(typedArray, shape, 'float32');
} else if (dtype === 'int32') {
value = tensor(typedArray, shape, 'int32');
} else if (dtype === 'bool') {
value = tensor(typedArray, shape, 'bool');
} else {
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}
out[name] = value;
out[name] = tensor(values, shape, dtype);
}
return out;
}
Expand Down
93 changes: 90 additions & 3 deletions src/io/io_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {expectArraysEqual} from '../test_util';
import {expectArraysClose} from '../test_util';

import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils';
import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, STRING_DELIMITER, stringByteLength} from './io_utils';
import {WeightsManifestEntry} from './types';

describe('concatenateTypedArrays', () => {
Expand Down Expand Up @@ -329,6 +329,81 @@ describe('encodeWeights', () => {
]);
});

it('String tensors', async () => {
const tensors: NamedTensorMap = {
x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2]),
x2: scalar(''), // Empty string.
x3: tensor1d(['здраво', 'поздрав']), // Cyrillic.
x4: scalar('正常'), // East Asian.
x5: scalar('hello') // Single string.
};
const dataAndSpecs = await tf.io.encodeWeights(tensors);
const data = dataAndSpecs.data;
const specs = dataAndSpecs.specs as tf.io.StringWeightsManifestEntry[];
const x1ByteLength = 7 + 3; // 7 ascii chars + 3 delimiters.
const x2ByteLength = 0; // No chars.
const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter.
const x4ByteLength = 6; // 2 chinese letters.
const x5ByteLength = 5; // 5 ascii chars.
expect(data.byteLength)
.toEqual(
x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength +
x5ByteLength);
let delim = specs[0].delimiter;
expect(new Uint8Array(data, 0, x1ByteLength))
.toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`));
// The middle string takes up 0 bytes.
delim = specs[2].delimiter;
expect(new Uint8Array(data, x1ByteLength + x2ByteLength, x3ByteLength))
.toEqual(tf.ENV.platform.encodeUTF8(`здраво${delim}поздрав`));
delim = specs[3].delimiter;
expect(new Uint8Array(
data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength))
.toEqual(tf.ENV.platform.encodeUTF8('正常'));
delim = specs[4].delimiter;
expect(new Uint8Array(
data, x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength,
x5ByteLength))
.toEqual(tf.ENV.platform.encodeUTF8('hello'));
expect(specs).toEqual([
{
name: 'x1',
dtype: 'string',
shape: [2, 2],
byteLength: x1ByteLength,
delimiter: STRING_DELIMITER,
},
{
name: 'x2',
dtype: 'string',
shape: [],
byteLength: x2ByteLength,
delimiter: STRING_DELIMITER,
},
{
name: 'x3',
dtype: 'string',
shape: [2],
byteLength: x3ByteLength,
delimiter: STRING_DELIMITER,
},
{
name: 'x4',
dtype: 'string',
shape: [],
byteLength: x4ByteLength,
delimiter: STRING_DELIMITER,
},
{
name: 'x5',
dtype: 'string',
shape: [],
byteLength: x5ByteLength,
delimiter: STRING_DELIMITER,
}
]);
});

it('Mixed dtype tensors', async () => {
const tensors: NamedTensorMap = {
x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'),
Expand Down Expand Up @@ -370,17 +445,29 @@ describeWithFlags('decodeWeights', {}, () => {
x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'),
x2: scalar(13.37, 'float32'),
x3: tensor1d([true, false, false], 'bool'),
x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'),
x5: tensor1d([''], 'string'), // Empty string.
x6: scalar('hello'), // Single string.
y1: tensor2d([-10, -20, -30], [3, 1], 'float32'),
};
const dataAndSpecs = await tf.io.encodeWeights(tensors);
const data = dataAndSpecs.data;
const specs = dataAndSpecs.specs;
expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 3 + 4 * 3);
// 12 bytes from cyrillic (6 letters) + 3 bytes from ascii + 3 delimiters.
const x4Bytes = 12 + 3 + 3;
const x5Bytes = 0;
// 5 bytes from ascii.
const x6Bytes = 5;
expect(data.byteLength)
.toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + x6Bytes + 4 * 3);
const decoded = tf.io.decodeWeights(data, specs);
expect(Object.keys(decoded).length).toEqual(4);
expect(Object.keys(decoded).length).toEqual(7);
expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data());
expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data());
expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data());
expectArraysEqual(await decoded['x4'].data(), await tensors['x4'].data());
expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data());
expectArraysEqual(await decoded['x6'].data(), await tensors['x6'].data());
expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data());
});

Expand Down
Loading