-
Notifications
You must be signed in to change notification settings - Fork 950
Serialize string tensors as encoded (raw bytes) #1816
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 24 of 24 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @dsmilkov, @nsthorat, and @pyu10055)
src/tensor_util_env.ts, line 114 at r1 (raw file):
const values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype as DataType, ENV.getBool('DEBUG')) : flatten(x as string[], [], true) as string[];
true => named
src/io/io_utils.ts, line 65 at r1 (raw file):
const vals = await t.bytes() as Uint8Array[]; const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) + 4 * vals.length;
can you pull the 4 and zero to named constants here and elsewhere
src/ops/tensor_ops.ts, line 54 at r1 (raw file):
*/ /** @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor<R extends Rank>(
I think it would be good to document somewhere that if you pass a string literal then we encode in utf8
src/ops/tensor_ops.ts, line 140 at r1 (raw file):
throw new Error( 'When making a scalar from encoded string, ' + 'the value must be Uint8Array');
period
src/platforms/platform_browser.ts, line 37 at r1 (raw file):
} decode(bytes: Uint8Array, encoding: string): string { return new TextDecoder(encoding).decode(bytes);
do you want to create and destroy this each time? how come this isnt the same as the encoder? same question for node
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nsthorat and @pyu10055)
src/tensor_util_env.ts, line 114 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
true => named
Done.
src/io/io_utils.ts, line 65 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
can you pull the 4 and zero to named constants here and elsewhere
Done.
src/ops/tensor_ops.ts, line 54 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
I think it would be good to document somewhere that if you pass a string literal then we encode in utf8
Done.
src/ops/tensor_ops.ts, line 140 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
period
Done.
src/platforms/platform_browser.ts, line 37 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
do you want to create and destroy this each time? how come this isnt the same as the encoder? same question for node
That unfortunately the API according to the spec. Text decoders take encoding scheme as part of their c-tor, while text encoder doesn't (can only encode utf-8)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 3 of 24 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nsthorat and @pyu10055)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this.
Reviewed 1 of 3 files at r2, 1 of 2 files at r3.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @nsthorat)
Allow serialization and deserialization of string weights in the binary weights format. This is needed to enable execution of AutoML models, which store the vocab as a weight (`Const` of dtype `string`). Corresponding change in tfjs-core: tensorflow/tfjs-core#1816 Fixes tensorflow/tfjs#1598
To align with TensorFlow Python/C++, this PR changes the way we serialize strings in the weights format, and in our engine.
Uint8Array
. Thus a string tensors (which has multiple strings) is backed byUint8Array[]
.tensor.data()
returnsstring[]
, which means that we try to utf-8 decode a string.[4 bytes][string1...][4 bytes][string2...][4 bytes][string3....]
util.encodeString(text: string, encoding?: string)
andutil.decodeString(bytes: Uint8Array, encoding?: string)
, along with the respectivePlatform
methodstensor.bytes()
which gives the underlying bytes of the data -Uint8Array
for any numeric tensor, andUint8Array[]
for string tensors.Corresponding change in tfjs-converter: tensorflow/tfjs-converter#386
This change is