Skip to content

Commit

Permalink
unit test dataSync() error
Browse files Browse the repository at this point in the history
  • Loading branch information
vabarbosa committed Jun 11, 2019
1 parent 4eefddd commit c209006
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/backends/string_shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,50 @@
*/

import {arrayBufferToBase64String, arrayBufferToString, base64StringToArrayBuffer, stringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils';
import * as ops from '../ops/ops';
import {StringTensor, Tensor} from '../tensor';

/** Shared implementation of the encodeBase64 kernel across WebGL and CPU. */
export function encodeBase64<T extends StringTensor>(
str: StringTensor|Tensor, pad = false): T {
const resultValues = new Array(str.size);
const values = str.dataSync();
const buffer = ops.buffer(str.shape, str.dtype);
const strBuffer = str.bufferSync();

for (let i = 0; i < buffer.size; ++i) {
const loc = buffer.indexToLoc(i);
const value = strBuffer.get(...loc).toString();

for (let i = 0; i < values.length; ++i) {
// Convert from string to ArrayBuffer of UTF-8 multibyte sequence
// tslint:disable-next-line: max-line-length
// https://developer.mozilla.org/en-US/docs/Web/API/WindowBase64/Base64_encoding_and_decoding#The_Unicode_Problem
const aBuff = stringToArrayBuffer(values[i].toString());
const aBuff = stringToArrayBuffer(value);

// Encode to Base64 and make URL safe
const bVal = urlSafeBase64(arrayBufferToBase64String(aBuff));

// Remove padding
resultValues[i] = pad ? bVal : bVal.replace(/=/g, '');
buffer.values[i] = pad ? bVal : bVal.replace(/=/g, '');
}

return Tensor.make(str.shape, {values: resultValues}, str.dtype) as T;
return buffer.toTensor() as T;
}

/** Shared implementation of the decodeBase64 kernel across WebGL and CPU. */
export function decodeBase64<T extends StringTensor>(str: StringTensor|
Tensor): T {
const resultValues = new Array(str.size);
const values = str.dataSync();
const buffer = ops.buffer(str.shape, str.dtype);
const strBuffer = str.bufferSync();

for (let i = 0; i < buffer.size; ++i) {
const loc = buffer.indexToLoc(i);
const value = strBuffer.get(...loc).toString();

for (let i = 0; i < values.length; ++i) {
// Undo URL safe and decode from Base64 to ArrayBuffer
const aBuff =
base64StringToArrayBuffer(urlUnsafeBase64(values[i].toString()));
const aBuff = base64StringToArrayBuffer(urlUnsafeBase64(value));

// Convert from ArrayBuffer of UTF-8 multibyte sequence to string
resultValues[i] = arrayBufferToString(aBuff);
buffer.values[i] = arrayBufferToString(aBuff);
}

return Tensor.make(str.shape, {values: resultValues}, str.dtype) as T;
return buffer.toTensor() as T;
}

0 comments on commit c209006

Please sign in to comment.