Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

EncodeBase64 and DecodeBase64 ops #1779

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {Activation} from '../ops/fused_util';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {Backend, DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';

export const EPSILON_FLOAT32 = 1e-7;
Expand Down Expand Up @@ -623,4 +623,13 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
dispose(): void {
throw new Error('Not yet implemented');
}

/**
 * Base64 encodes the values of a string tensor.
 *
 * This base implementation always throws `Not yet implemented`.
 *
 * @param str The tensor whose values should be encoded.
 * @param pad Padding flag; defaults to false.
 */
encodeBase64<T extends StringTensor>(
    str: StringTensor|Tensor,
    pad = false,
    ): T {
  throw new Error('Not yet implemented');
}

/**
 * Base64 decodes the values of a string tensor.
 *
 * This base implementation always throws `Not yet implemented`.
 *
 * @param str The tensor whose values should be decoded.
 */
decodeBase64<T extends StringTensor>(
    str: StringTensor|Tensor,
    ): T {
  throw new Error('Not yet implemented');
}
}
29 changes: 20 additions & 9 deletions src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
Expand All @@ -42,6 +42,7 @@ import * as backend_util from '../backend_util';
import * as complex_util from '../complex_util';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
Expand Down Expand Up @@ -80,8 +81,8 @@ export class MathBackendCPU implements KernelBackend {
public blockSize = 48;

private data: DataStorage<TensorData<DataType>>;
private fromPixels2DContext: CanvasRenderingContext2D
| OffscreenCanvasRenderingContext2D;
private fromPixels2DContext: CanvasRenderingContext2D|
OffscreenCanvasRenderingContext2D;
private firstUse = true;

constructor() {
Expand Down Expand Up @@ -133,12 +134,11 @@ export class MathBackendCPU implements KernelBackend {

const isPixelData = (pixels as PixelData).data instanceof Uint8Array;
const isImageData =
typeof(ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo =
typeof(HTMLVideoElement) !== 'undefined'
&& pixels instanceof HTMLVideoElement;
const isImage = typeof(HTMLImageElement) !== 'undefined'
&& pixels instanceof HTMLImageElement;
typeof (ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
pixels instanceof HTMLVideoElement;
const isImage = typeof (HTMLImageElement) !== 'undefined' &&
pixels instanceof HTMLImageElement;

let vals: Uint8ClampedArray|Uint8Array;
// tslint:disable-next-line:no-any
Expand Down Expand Up @@ -3194,6 +3194,17 @@ export class MathBackendCPU implements KernelBackend {

dispose() {}

/**
 * Base64 encodes the values of `str` by reading its data synchronously and
 * handing the result to the shared `encodeBase64Impl`.
 *
 * @param str The tensor whose values should be encoded.
 * @param pad Forwarded unchanged to `encodeBase64Impl`; defaults to false.
 */
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
    T {
  const {dataId, shape} = str;
  const bytesPerElement = this.readSync(dataId) as Uint8Array[];
  return encodeBase64Impl(bytesPerElement, shape, pad);
}

/**
 * Base64 decodes the values of `str` by reading its data synchronously and
 * handing the result to the shared `decodeBase64Impl`.
 *
 * @param str The tensor whose values should be decoded.
 */
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
  const {dataId, shape} = str;
  const bytesPerElement = this.readSync(dataId) as Uint8Array[];
  return decodeBase64Impl(bytesPerElement, shape);
}

floatPrecision(): 16|32 {
return 32;
}
Expand Down
57 changes: 57 additions & 0 deletions src/backends/string_shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {arrayBufferToBase64String, base64StringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils';
import {StringTensor, Tensor} from '../tensor';
import {decodeString} from '../util';

/**
 * Shared implementation of the encodeBase64 kernel across WebGL and CPU.
 *
 * Every element of `values` is Base64 encoded, made URL safe via
 * `urlSafeBase64`, and optionally stripped of its `=` padding.
 *
 * @param values The bytes of each element to encode.
 * @param shape Shape passed to `Tensor.make` for the resulting tensor.
 * @param pad When true the `=` padding is kept; when false (the default) all
 *     `=` characters are removed from the encoded value.
 * @returns The tensor created by `Tensor.make` with dtype `'string'`.
 */
export function encodeBase64Impl<T extends StringTensor>(
    values: Uint8Array[], shape: number[], pad = false): T {
  const resultValues = new Array(values.length);

  for (let i = 0; i < values.length; ++i) {
    const bytes = values[i];

    // A Uint8Array can be a view onto a slice of a larger ArrayBuffer. Only
    // hand over the backing buffer when the view spans all of it; otherwise
    // copy the viewed bytes so that no unrelated bytes end up in the output.
    const spansWholeBuffer =
        bytes.byteOffset === 0 && bytes.byteLength === bytes.buffer.byteLength;
    const exactBuffer =
        spansWholeBuffer ? bytes.buffer : new Uint8Array(bytes).buffer;

    const bStrUrl = urlSafeBase64(arrayBufferToBase64String(exactBuffer));

    // Remove the padding unless the caller asked to keep it.
    resultValues[i] = pad ? bStrUrl : bStrUrl.replace(/=/g, '');
  }

  return Tensor.make(shape, {values: resultValues}, 'string') as T;
}

/**
 * Shared implementation of the decodeBase64 kernel across WebGL and CPU.
 *
 * Each element is turned into a string with `decodeString`, has the URL safe
 * substitutions reverted with `urlUnsafeBase64`, is Base64 decoded into an
 * ArrayBuffer and finally converted back into a string with `decodeString`.
 *
 * @param values The bytes of each Base64 encoded element.
 * @param shape Shape passed to `Tensor.make` for the resulting tensor.
 * @returns The tensor created by `Tensor.make` with dtype `'string'`.
 */
export function decodeBase64Impl<T extends StringTensor>(
    values: Uint8Array[], shape: number[]): T {
  const resultValues = new Array(values.length);

  values.forEach((encodedBytes, index) => {
    // Bytes -> URL safe Base64 text -> regular Base64 text.
    const standardBase64 = urlUnsafeBase64(decodeString(encodedBytes));
    // Regular Base64 text -> raw bytes -> decoded string.
    const rawBytes = new Uint8Array(base64StringToArrayBuffer(standardBase64));
    resultValues[index] = decodeString(rawBytes);
  });

  return Tensor.make(shape, {values: resultValues}, 'string') as T;
}
63 changes: 37 additions & 26 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import * as segment_util from '../../ops/segment_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
Expand All @@ -45,6 +45,7 @@ import * as backend_util from '../backend_util';
import {mergeRealAndImagArrays} from '../complex_util';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {decodeBase64Impl, encodeBase64Impl} from '../string_shared';
import {tile} from '../tile_impl';
import {topkImpl} from '../topk_impl';
import {whereImpl} from '../where_impl';
Expand All @@ -62,7 +63,7 @@ import * as binaryop_gpu from './binaryop_gpu';
import {BinaryOpProgram} from './binaryop_gpu';
import * as binaryop_packed_gpu from './binaryop_packed_gpu';
import {BinaryOpPackedProgram} from './binaryop_packed_gpu';
import {getWebGLContext, createCanvas} from './canvas_util';
import {createCanvas, getWebGLContext} from './canvas_util';
import {ClipProgram} from './clip_gpu';
import {ClipPackedProgram} from './clip_packed_gpu';
import {ComplexAbsProgram} from './complex_abs_gpu';
Expand Down Expand Up @@ -222,8 +223,8 @@ export class MathBackendWebGL implements KernelBackend {
private numBytesInGPU = 0;

private canvas: HTMLCanvasElement;
private fromPixels2DContext: CanvasRenderingContext2D
| OffscreenCanvasRenderingContext2D;
private fromPixels2DContext: CanvasRenderingContext2D|
OffscreenCanvasRenderingContext2D;

private programTimersStack: TimerNode[];
private activeTimers: TimerNode[];
Expand Down Expand Up @@ -272,36 +273,35 @@ export class MathBackendWebGL implements KernelBackend {
}

fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor3D {
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
numChannels: number): Tensor3D {
if (pixels == null) {
throw new Error(
'pixels passed to tf.browser.fromPixels() can not be null');
}
const texShape: [number, number] = [pixels.height, pixels.width];
const outShape = [pixels.height, pixels.width, numChannels];

const isCanvas = (typeof(OffscreenCanvas) !== 'undefined'
&& pixels instanceof OffscreenCanvas)
|| (typeof(HTMLCanvasElement) !== 'undefined'
&& pixels instanceof HTMLCanvasElement);
const isCanvas = (typeof (OffscreenCanvas) !== 'undefined' &&
pixels instanceof OffscreenCanvas) ||
(typeof (HTMLCanvasElement) !== 'undefined' &&
pixels instanceof HTMLCanvasElement);
const isPixelData = (pixels as PixelData).data instanceof Uint8Array;
const isImageData =
typeof(ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo =
typeof(HTMLVideoElement) !== 'undefined'
&& pixels instanceof HTMLVideoElement;
const isImage = typeof(HTMLImageElement) !== 'undefined'
&& pixels instanceof HTMLImageElement;
typeof (ImageData) !== 'undefined' && pixels instanceof ImageData;
const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
pixels instanceof HTMLVideoElement;
const isImage = typeof (HTMLImageElement) !== 'undefined' &&
pixels instanceof HTMLImageElement;

if (!isCanvas && !isPixelData && !isImageData && !isVideo && !isImage) {
throw new Error(
'pixels passed to tf.browser.fromPixels() must be either an ' +
`HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
`in browser, or OffscreenCanvas, ImageData in webworker` +
` or {data: Uint32Array, width: number, height: number}, ` +
`but was ${(pixels as {}).constructor.name}`);
'pixels passed to tf.browser.fromPixels() must be either an ' +
`HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` +
`in browser, or OffscreenCanvas, ImageData in webworker` +
` or {data: Uint32Array, width: number, height: number}, ` +
`but was ${(pixels as {}).constructor.name}`);
}

if (isVideo) {
Expand All @@ -314,14 +314,14 @@ export class MathBackendWebGL implements KernelBackend {
'on the document object');
}
//@ts-ignore
this.fromPixels2DContext = createCanvas(ENV.getNumber('WEBGL_VERSION'))
.getContext('2d');
this.fromPixels2DContext =
createCanvas(ENV.getNumber('WEBGL_VERSION')).getContext('2d');
}
this.fromPixels2DContext.canvas.width = pixels.width;
this.fromPixels2DContext.canvas.height = pixels.height;
this.fromPixels2DContext.drawImage(
pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height);
//@ts-ignore
pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height);
//@ts-ignore
pixels = this.fromPixels2DContext.canvas;
}

Expand Down Expand Up @@ -2176,6 +2176,17 @@ export class MathBackendWebGL implements KernelBackend {
return split(x, sizeSplits, axis);
}

/**
 * Base64 encodes the values of `str` by reading its data synchronously and
 * handing the result to the shared `encodeBase64Impl`.
 *
 * @param str The tensor whose values should be encoded.
 * @param pad Forwarded unchanged to `encodeBase64Impl`; defaults to false.
 */
encodeBase64<T extends StringTensor>(str: StringTensor|Tensor, pad = false):
    T {
  const {dataId, shape} = str;
  const bytesPerElement = this.readSync(dataId) as Uint8Array[];
  return encodeBase64Impl(bytesPerElement, shape, pad);
}

/**
 * Base64 decodes the values of `str` by reading its data synchronously and
 * handing the result to the shared `decodeBase64Impl`.
 *
 * @param str The tensor whose values should be decoded.
 */
decodeBase64<T extends StringTensor>(str: StringTensor|Tensor): T {
  const {dataId, shape} = str;
  const bytesPerElement = this.readSync(dataId) as Uint8Array[];
  return decodeBase64Impl(bytesPerElement, shape);
}

scatterND<R extends Rank>(
indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor<R> {
const {sliceRank, numUpdates, sliceSize, strides, outputSize} =
Expand Down
19 changes: 19 additions & 0 deletions src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,22 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts):
modelArtifacts.weightData.byteLength,
};
}

/**
 * Converts a Base64 string to its URL safe form: every `+` becomes `-` and
 * every `/` becomes `_`. All other characters are left untouched.
 *
 * @param str Base64 string to make URL safe.
 * @returns The string with `+` and `/` substituted.
 */
export function urlSafeBase64(str: string): string {
  // Single pass over the string; only `+` and `/` can match.
  return str.replace(/[+/]/g, ch => (ch === '+' ? '-' : '_'));
}

/**
 * Reverts the URL safe Base64 substitutions: every `-` becomes `+` and every
 * `_` becomes `/`. All other characters are left untouched.
 *
 * @param str URL safe Base64 string to revert.
 * @returns The string with `-` and `_` substituted.
 */
export function urlUnsafeBase64(str: string): string {
  // Single pass over the string; only `-` and `_` can match.
  return str.replace(/[-_]/g, ch => (ch === '-' ? '+' : '/'));
}
1 change: 1 addition & 0 deletions src/io/io_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {expectArraysEqual} from '../test_util';
import {expectArraysClose} from '../test_util';
import {encodeString} from '../util';

import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils';
import {WeightsManifestEntry} from './types';

Expand Down
1 change: 1 addition & 0 deletions src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export * from './sparse_to_dense';
export * from './gather_nd';
export * from './dropout';
export * from './signal_ops';
export * from './string_ops';

export {op} from './operation';

Expand Down
78 changes: 78 additions & 0 deletions src/ops/string_ops.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../engine';
import {StringTensor, Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';

import {op} from './operation';

/**
 * Encodes the values of a `tf.Tensor` (of dtype `string`) to Base64.
 *
 * Takes a string tensor and produces a new tensor whose values are encoded
 * in the web-safe Base64 format, i.e. the encoder emits `-` and `_` in place
 * of `+` and `/`:
 *
 * en.wikipedia.org/wiki/Base64
 *
 * ```js
 * const x = tf.tensor1d(['Hello world!'], 'string');
 *
 * x.encodeBase64().print();
 * ```
 * @param str The input `tf.Tensor` of dtype `string` to encode.
 * @param pad Whether to add padding (`=`) to the end of the encoded string.
 */
/** @doc {heading: 'Operations', subheading: 'String'} */
function encodeBase64_<T extends StringTensor>(
    str: StringTensor|Tensor, pad = false): T {
  const $str = convertToTensor(str, 'str', 'encodeBase64', 'string');

  // The backward function maps `dy` through the inverse op.
  const grad = (dy: T) => {
    return {$str: () => decodeBase64(dy)};
  };

  return ENGINE.runKernel(
      backend => backend.encodeBase64($str, pad), {$str}, grad);
}

/**
 * Decodes the values of a `tf.Tensor` (of dtype `string`) from Base64.
 *
 * Takes a string tensor holding Base64 encoded values and produces a new
 * tensor holding the decoded values.
 *
 * en.wikipedia.org/wiki/Base64
 *
 * ```js
 * const y = tf.scalar('SGVsbG8gd29ybGQh', 'string');
 *
 * y.decodeBase64().print();
 * ```
 * @param str The input `tf.Tensor` of dtype `string` to decode.
 */
/** @doc {heading: 'Operations', subheading: 'String'} */
function decodeBase64_<T extends StringTensor>(str: StringTensor|Tensor): T {
  const $str = convertToTensor(str, 'str', 'decodeBase64', 'string');

  // The backward function maps `dy` through the inverse op.
  const grad = (dy: T) => {
    return {$str: () => encodeBase64(dy)};
  };

  return ENGINE.runKernel(
      backend => backend.decodeBase64($str), {$str}, grad);
}

// Exported ops: each `_`-suffixed implementation above is passed through `op`.
export const encodeBase64 = op({encodeBase64_});
export const decodeBase64 = op({decodeBase64_});
Loading