diff --git a/package.json b/package.json index 9977482cd9..83283a65ae 100644 --- a/package.json +++ b/package.json @@ -82,6 +82,7 @@ "seedrandom": "2.4.3" }, "browser": { - "node-fetch": false + "node-fetch": false, + "util": false } } diff --git a/rollup.config.js b/rollup.config.js index 229e398f38..4938bdcda9 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -59,7 +59,7 @@ function config({plugins = [], output = {}, external = [], visualize = false}) { node(), // Polyfill require() from dependencies. commonjs({ - ignore: ['crypto', 'node-fetch'], + ignore: ['crypto', 'node-fetch', 'util'], include: 'node_modules/**', namedExports: { './node_modules/seedrandom/index.js': ['alea'], diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 56e5424c03..62e80a065f 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -42,6 +42,7 @@ import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -2065,26 +2066,7 @@ export class MathBackendCPU implements KernelBackend { tile(x: T, reps: number[]): T { this.assertNotComplex(x, 'tile'); - - const newShape: number[] = new Array(x.rank); - for (let i = 0; i < newShape.length; i++) { - newShape[i] = x.shape[i] * reps[i]; - } - const result = ops.buffer(newShape, x.dtype); - const xBuf = this.bufferSync(x); - for (let i = 0; i < result.values.length; ++i) { - const newLoc = result.indexToLoc(i); - - const originalLoc: number[] = new Array(x.rank); - for (let i = 0; i < originalLoc.length; i++) { - originalLoc[i] = newLoc[i] % x.shape[i]; - } - - const originalIndex = xBuf.locToIndex(originalLoc); - - result.values[i] = xBuf.values[originalIndex]; - } - return result.toTensor() as T; + return tile(this.bufferSync(x), reps) as T; } pad( diff --git a/src/backends/tile_impl.ts b/src/backends/tile_impl.ts new file mode 100644 index 0000000000..a1cb8c9121 --- /dev/null +++ b/src/backends/tile_impl.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** + * An implementation of the tile kernel shared between webgl and cpu for string + * tensors only. + */ + +import {buffer} from '../ops/array_ops'; +import {Tensor, TensorBuffer} from '../tensor'; +import {DataType, Rank} from '../types'; + +export function tile( + xBuf: TensorBuffer, reps: number[]): Tensor { + const newShape: number[] = new Array(xBuf.rank); + for (let i = 0; i < newShape.length; i++) { + newShape[i] = xBuf.shape[i] * reps[i]; + } + const result = buffer(newShape, xBuf.dtype); + for (let i = 0; i < result.values.length; ++i) { + const newLoc = result.indexToLoc(i); + + const originalLoc: number[] = new Array(xBuf.rank); + for (let i = 0; i < originalLoc.length; i++) { + originalLoc[i] = newLoc[i] % xBuf.shape[i]; + } + + const originalIndex = xBuf.locToIndex(originalLoc); + + result.values[i] = xBuf.values[originalIndex]; + } + return result.toTensor() as Tensor; +} diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index b3541e08c4..1f4fd823b0 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -23,6 +23,7 @@ import {ENGINE, MemoryInfo, TimingInfo} from '../../engine'; import {ENV} from '../../environment'; import {tidy} from '../../globals'; import {warn} from '../../log'; +import {buffer} from '../../ops/array_ops'; import * as array_ops_util from '../../ops/array_ops_util'; import * as axis_util from '../../ops/axis_util'; import {computeOutShape} from '../../ops/concat_util'; @@ -44,6 +45,7 @@ import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -974,6 +976,10 @@ export class MathBackendWebGL implements KernelBackend { } tile(x: T, reps: number[]): T { + if (x.dtype === 'string') { + const buf = buffer(x.shape, x.dtype, this.readSync(x.dataId) as string[]); + return tile(buf, reps) as T; + } const program = new TileProgram(x.shape, reps); return this.compileAndRun(program, [x]); } diff --git a/src/io/io.ts b/src/io/io.ts index 685a77c9bf..3e6fa1a067 100644 --- a/src/io/io.ts +++ b/src/io/io.ts @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http'; import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils'; import {fromMemory, withSaveHandler} from './passthrough'; import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry'; -import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry, WeightGroup} from './types'; +import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeights, weightsLoaderFactory} from './weights_loader'; export {copyModel, listModels, moveModel, removeModel} from './model_management'; @@ -54,9 +54,10 @@ export { SaveConfig, SaveHandler, SaveResult, + StringWeightsManifestEntry, + WeightGroup, weightsLoaderFactory, WeightsManifestConfig, WeightsManifestEntry, - WeightGroup, withSaveHandler }; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 86540c2eb8..cd890cc6c9 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -15,13 +15,16 @@ * ============================================================================= */ +import {ENV} from '../environment'; import {tensor} from '../ops/tensor_ops'; -import {Tensor} from '../tensor'; import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {TypedArray} from '../types'; import {sizeFromShape} from '../util'; -import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types'; +import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, StringWeightsManifestEntry, WeightGroup, WeightsManifestEntry} from './types'; + +/** Used to delimit neighboring strings when encoding string tensors. */ +export const STRING_DELIMITER = '\x00'; /** * Encode a map from names to weight values as an ArrayBuffer, along with an @@ -54,15 +57,28 @@ export async function encodeWeights( for (let i = 0; i < names.length; ++i) { const name = names[i]; const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name]; - if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool') { + if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && + t.dtype !== 'string') { throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype}; + if (t.dtype === 'string') { + const utf8bytes = new Promise(async resolve => { + const stringSpec = spec as StringWeightsManifestEntry; + const data = await t.data(); + const bytes = ENV.platform.encodeUTF8(data.join(STRING_DELIMITER)); + stringSpec.byteLength = bytes.length; + stringSpec.delimiter = STRING_DELIMITER; + resolve(bytes); + }); + dataPromises.push(utf8bytes); + } else { + dataPromises.push(t.data()); + } if (group != null) { spec.group = group; } specs.push(spec); - dataPromises.push(t.data()); } const tensorValues = await Promise.all(dataPromises); @@ -94,7 +110,7 @@ export function decodeWeights( const dtype = spec.dtype; const shape = spec.shape; const size = sizeFromShape(shape); - let typedArray: TypedArray; + let values: TypedArray|string[]; if ('quantization' in spec) { const quantization = spec.quantization; @@ -111,43 +127,39 @@ export function decodeWeights( new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === 'float32') { - typedArray = Float32Array.from( + values = Float32Array.from( quantizedArray, v => v * quantization.scale + quantization.min); } else if (dtype === 'int32') { - typedArray = Int32Array.from( + values = Int32Array.from( quantizedArray, v => Math.round(v * quantization.scale + quantization.min)); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * quantizationSizeFactor; + } else if (dtype === 'string') { + const stringSpec = spec as StringWeightsManifestEntry; + const bytes = + new Uint8Array(buffer.slice(offset, offset + stringSpec.byteLength)); + values = ENV.platform.decodeUTF8(bytes).split(stringSpec.delimiter); + offset += stringSpec.byteLength; } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; const byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); if (dtype === 'float32') { - typedArray = new Float32Array(byteBuffer); + values = new Float32Array(byteBuffer); } else if (dtype === 'int32') { - typedArray = new Int32Array(byteBuffer); + values = new Int32Array(byteBuffer); } else if (dtype === 'bool') { - typedArray = new Uint8Array(byteBuffer); + values = new Uint8Array(byteBuffer); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * dtypeFactor; } - let value: Tensor; - if (dtype === 'float32') { - value = tensor(typedArray, shape, 'float32'); - } else if (dtype === 'int32') { - value = tensor(typedArray, shape, 'int32'); - } else if (dtype === 'bool') { - value = tensor(typedArray, shape, 'bool'); - } else { - throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); - } - out[name] = value; + out[name] = tensor(values, shape, dtype); } return out; } diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 5a6feaa07d..0a56f48bd2 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -22,7 +22,7 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {expectArraysEqual} from '../test_util'; import {expectArraysClose} from '../test_util'; -import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, STRING_DELIMITER, stringByteLength} from './io_utils'; import {WeightsManifestEntry} from './types'; describe('concatenateTypedArrays', () => { @@ -329,6 +329,81 @@ describe('encodeWeights', () => { ]); }); + it('String tensors', async () => { + const tensors: NamedTensorMap = { + x1: tensor2d([['a', 'bc'], ['def', 'g']], [2, 2]), + x2: scalar(''), // Empty string. + x3: tensor1d(['здраво', 'поздрав']), // Cyrillic. + x4: scalar('正常'), // East Asian. + x5: scalar('hello') // Single string. + }; + const dataAndSpecs = await tf.io.encodeWeights(tensors); + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs as tf.io.StringWeightsManifestEntry[]; + const x1ByteLength = 7 + 3; // 7 ascii chars + 3 delimiters. + const x2ByteLength = 0; // No chars. + const x3ByteLength = 13 * 2 + 1; // 13 cyrillic letters + 1 delimiter. + const x4ByteLength = 6; // 2 chinese letters. + const x5ByteLength = 5; // 5 ascii chars. + expect(data.byteLength) + .toEqual( + x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength + + x5ByteLength); + let delim = specs[0].delimiter; + expect(new Uint8Array(data, 0, x1ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8(`a${delim}bc${delim}def${delim}g`)); + // The middle string takes up 0 bytes. + delim = specs[2].delimiter; + expect(new Uint8Array(data, x1ByteLength + x2ByteLength, x3ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8(`здраво${delim}поздрав`)); + delim = specs[3].delimiter; + expect(new Uint8Array( + data, x1ByteLength + x2ByteLength + x3ByteLength, x4ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8('正常')); + delim = specs[4].delimiter; + expect(new Uint8Array( + data, x1ByteLength + x2ByteLength + x3ByteLength + x4ByteLength, + x5ByteLength)) + .toEqual(tf.ENV.platform.encodeUTF8('hello')); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'string', + shape: [2, 2], + byteLength: x1ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x2', + dtype: 'string', + shape: [], + byteLength: x2ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x3', + dtype: 'string', + shape: [2], + byteLength: x3ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x4', + dtype: 'string', + shape: [], + byteLength: x4ByteLength, + delimiter: STRING_DELIMITER, + }, + { + name: 'x5', + dtype: 'string', + shape: [], + byteLength: x5ByteLength, + delimiter: STRING_DELIMITER, + } + ]); + }); + it('Mixed dtype tensors', async () => { const tensors: NamedTensorMap = { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), @@ -370,17 +445,29 @@ describeWithFlags('decodeWeights', {}, () => { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false], 'bool'), + x4: tensor2d([['здраво', 'a'], ['b', 'c']], [2, 2], 'string'), + x5: tensor1d([''], 'string'), // Empty string. + x6: scalar('hello'), // Single string. y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), }; const dataAndSpecs = await tf.io.encodeWeights(tensors); const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; - expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 3 + 4 * 3); + // 12 bytes from cyrillic (6 letters) + 3 bytes from ascii + 3 delimiters. + const x4Bytes = 12 + 3 + 3; + const x5Bytes = 0; + // 5 bytes from ascii. + const x6Bytes = 5; + expect(data.byteLength) + .toEqual(4 * 4 + 4 * 1 + 1 * 3 + x4Bytes + x5Bytes + x6Bytes + 4 * 3); const decoded = tf.io.decodeWeights(data, specs); - expect(Object.keys(decoded).length).toEqual(4); + expect(Object.keys(decoded).length).toEqual(7); expectArraysEqual(await decoded['x1'].data(), await tensors['x1'].data()); expectArraysEqual(await decoded['x2'].data(), await tensors['x2'].data()); expectArraysEqual(await decoded['x3'].data(), await tensors['x3'].data()); + expectArraysEqual(await decoded['x4'].data(), await tensors['x4'].data()); + expectArraysEqual(await decoded['x5'].data(), await tensors['x5'].data()); + expectArraysEqual(await decoded['x6'].data(), await tensors['x6'].data()); expectArraysEqual(await decoded['y1'].data(), await tensors['y1'].data()); }); diff --git a/src/io/types.ts b/src/io/types.ts index 0e393480bb..0de401a950 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -85,7 +85,7 @@ export declare interface WeightsManifestEntry { /** * Data type of the weight. */ - dtype: 'float32'|'int32'|'bool'; + dtype: 'float32'|'int32'|'bool'|'string'; /** * Type of the weight. @@ -107,6 +107,22 @@ export declare interface WeightsManifestEntry { }; } +export declare interface StringWeightsManifestEntry extends + WeightsManifestEntry { + dtype: 'string'; + /** + * Used for delimiting neighboring strings. If the tensor has no strings or + * only 1 string, there will be no delimiter. If the tensor has N strings + * (N>0), there will be N-1 delimiters. + */ + delimiter: string; + /** + * Number of bytes used by the whole tensor, including the delimiters (N-1 + * delimiters for N strings). + */ + byteLength: number; +} + /** * Options for saving a model. * @innamespace io diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index d8b1c80a0b..01af4ecad7 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -418,7 +418,8 @@ function cast_(x: T|TensorLike, dtype: DataType): T { */ /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function tile_(x: T|TensorLike, reps: number[]): T { - const $x = convertToTensor(x, 'x', 'tile'); + const parseAs: DataType = null; + const $x = convertToTensor(x, 'x', 'tile', parseAs); util.assert( $x.rank === reps.length, @@ -881,7 +882,8 @@ function cumsum_( /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ function expandDims_( x: Tensor|TensorLike, axis = 0): Tensor { - const $x = convertToTensor(x, 'x', 'expandDims'); + const parseAs: DataType = null; + const $x = convertToTensor(x, 'x', 'expandDims', parseAs); util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); const newShape = $x.shape.slice(); diff --git a/src/ops/array_ops_test.ts b/src/ops/array_ops_test.ts index 88862da30b..5be54dd0dc 100644 --- a/src/ops/array_ops_test.ts +++ b/src/ops/array_ops_test.ts @@ -1946,6 +1946,23 @@ describeWithFlags('tile', ALL_ENVS, () => { await t2.data(), [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]); }); + it('1d string tensor', async () => { + const a = tf.tensor(['a', 'b', 'c']); + const res = tf.tile(a, [2]); + expect(res.shape).toEqual([6]); + expectArraysEqual(await res.data(), ['a', 'b', 'c', 'a', 'b', 'c']); + }); + + it('2d string tensor', async () => { + const a = tf.tensor([['a', 'b'], ['c', 'd']]); + const res = tf.tile(a, [2, 3]); + expect(res.shape).toEqual([4, 6]); + expectArraysEqual(await res.data(), [ + 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'c', 'd', 'c', 'd', + 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'c', 'd', 'c', 'd' + ]); + }); + it('propagates NaNs', async () => { const t = tf.tensor1d([1, 2, NaN]); @@ -3512,6 +3529,20 @@ describeWithFlags('expandDims', ALL_ENVS, () => { expectArraysClose(await res.data(), [4]); }); + it('1d string tensor', async () => { + const t = tf.tensor(['hello', 'world']); + const res = t.expandDims(); + expect(res.shape).toEqual([1, 2]); + expectArraysClose(await res.data(), ['hello', 'world']); + }); + + it('2d string tensor, axis=1', async () => { + const t = tf.tensor([['a', 'b'], ['c', 'd']]); + const res = t.expandDims(1); + expect(res.shape).toEqual([2, 1, 2]); + expectArraysClose(await res.data(), ['a', 'b', 'c', 'd']); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.expandDims({} as tf.Tensor)) .toThrowError(/Argument 'x' passed to 'expandDims' must be a Tensor/); diff --git a/src/platforms/platform.ts b/src/platforms/platform.ts index c60621a3e3..31037d1521 100644 --- a/src/platforms/platform.ts +++ b/src/platforms/platform.ts @@ -28,4 +28,9 @@ export interface Platform { * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request */ fetch(path: string, requestInits?: RequestInit): Promise; + + /** UTF-8 encode the provided string into an array of bytes. */ + encodeUTF8(text: string): Uint8Array; + /** UTF-8 decode the provided bytes into a string. */ + decodeUTF8(bytes: Uint8Array): string; } diff --git a/src/platforms/platform_browser.ts b/src/platforms/platform_browser.ts index 5573483506..41968373a0 100644 --- a/src/platforms/platform_browser.ts +++ b/src/platforms/platform_browser.ts @@ -18,6 +18,21 @@ import {ENV} from '../environment'; import {Platform} from './platform'; export class PlatformBrowser implements Platform { + private textEncoder: TextEncoder; + private textDecoder: TextDecoder; + + constructor() { + // The built-in encoder and the decoder use UTF-8 encoding. + this.textEncoder = new TextEncoder(); + this.textDecoder = new TextDecoder(); + } + + encodeUTF8(text: string): Uint8Array { + return this.textEncoder.encode(text); + } + decodeUTF8(bytes: Uint8Array): string { + return this.textDecoder.decode(bytes); + } fetch(path: string, init?: RequestInit): Promise { return fetch(path, init); } diff --git a/src/platforms/platform_browser_test.ts b/src/platforms/platform_browser_test.ts index fed0fd224b..74549fe260 100644 --- a/src/platforms/platform_browser_test.ts +++ b/src/platforms/platform_browser_test.ts @@ -29,4 +29,50 @@ describeWithFlags('PlatformBrowser', BROWSER_ENVS, async () => { expect(self.fetch).toHaveBeenCalledWith('test/url', {method: 'GET'}); }); + + it('encodeUTF8 single string', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('hello'); + expect(bytes.length).toBe(5); + expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); + }); + + it('encodeUTF8 two strings delimited', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('hello\x00world'); + expect(bytes.length).toBe(11); + expect(bytes).toEqual( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + }); + + it('encodeUTF8 cyrillic', () => { + const platform = new PlatformBrowser(); + const bytes = platform.encodeUTF8('Здраво'); + expect(bytes.length).toBe(12); + expect(bytes).toEqual(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + }); + + it('decodeUTF8 single string', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + expect(s.length).toBe(5); + expect(s).toEqual('hello'); + }); + + it('decodeUTF8 two strings delimited', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + expect(s.length).toBe(11); + expect(s).toEqual('hello\x00world'); + }); + + it('decodeUTF8 cyrillic', () => { + const platform = new PlatformBrowser(); + const s = platform.decodeUTF8(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + expect(s.length).toBe(6); + expect(s).toEqual('Здраво'); + }); }); diff --git a/src/platforms/platform_node.ts b/src/platforms/platform_node.ts index 56dacd805c..e027ec8e77 100644 --- a/src/platforms/platform_node.ts +++ b/src/platforms/platform_node.ts @@ -25,7 +25,28 @@ export const getNodeFetch = { }; export let systemFetch: (url: string, init?: RequestInit) => Promise; + export class PlatformNode implements Platform { + private textEncoder: TextEncoder; + private textDecoder: TextDecoder; + + constructor() { + // tslint:disable-next-line: no-require-imports + const util = require('util'); + // The built-in encoder and the decoder use UTF-8 encoding. + this.textEncoder = new util.TextEncoder(); + this.textDecoder = new util.TextDecoder(); + } + + encodeUTF8(text: string): Uint8Array { + return this.textEncoder.encode(text); + } + decodeUTF8(bytes: Uint8Array): string { + if (bytes.length === 0) { + return ''; + } + return this.textDecoder.decode(bytes); + } fetch(path: string, requestInits?: RequestInit): Promise { if (ENV.global.fetch != null) { return ENV.global.fetch(path, requestInits); diff --git a/src/platforms/platform_node_test.ts b/src/platforms/platform_node_test.ts index 60d8cb1d91..bb9f1a438e 100644 --- a/src/platforms/platform_node_test.ts +++ b/src/platforms/platform_node_test.ts @@ -67,4 +67,50 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => { platform_node.systemFetch = savedFetch; ENV.global.fetch = globalFetch; }); + + it('encodeUTF8 single string', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('hello'); + expect(bytes.length).toBe(5); + expect(bytes).toEqual(new Uint8Array([104, 101, 108, 108, 111])); + }); + + it('encodeUTF8 two strings delimited', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('hello\x00world'); + expect(bytes.length).toBe(11); + expect(bytes).toEqual( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + }); + + it('encodeUTF8 cyrillic', () => { + const platform = new PlatformNode(); + const bytes = platform.encodeUTF8('Здраво'); + expect(bytes.length).toBe(12); + expect(bytes).toEqual(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + }); + + it('decodeUTF8 single string', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8(new Uint8Array([104, 101, 108, 108, 111])); + expect(s.length).toBe(5); + expect(s).toEqual('hello'); + }); + + it('decodeUTF8 two strings delimited', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8( + new Uint8Array([104, 101, 108, 108, 111, 0, 119, 111, 114, 108, 100])); + expect(s.length).toBe(11); + expect(s).toEqual('hello\x00world'); + }); + + it('decodeUTF8 cyrillic', () => { + const platform = new PlatformNode(); + const s = platform.decodeUTF8(new Uint8Array( + [208, 151, 208, 180, 209, 128, 208, 176, 208, 178, 208, 190])); + expect(s.length).toBe(6); + expect(s).toEqual('Здраво'); + }); });