From 7f14c4fe49748d70a6b483e87de961d915342b34 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Thu, 13 Aug 2020 14:50:43 -0700 Subject: [PATCH 1/4] add weightUrlTranslationFunc to loadOptions --- tfjs-core/src/io/http.ts | 10 ++++++-- tfjs-core/src/io/http_test.ts | 48 +++++++++++++++++++++++++++++++++++ tfjs-core/src/io/types.ts | 17 ++++++++++--- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index d2f339754af..a6658827a56 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -36,6 +36,7 @@ export class HTTPRequest implements IOHandler { protected readonly requestInit: RequestInit; private readonly fetch: Function; + private readonly weightUrlTranslationFunc: (weightName: string) => Promise; readonly DEFAULT_METHOD = 'POST'; @@ -50,6 +51,7 @@ export class HTTPRequest implements IOHandler { } this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; + this.weightUrlTranslationFunc = loadOptions.weightUrlTranslateFunc; if (loadOptions.fetchFunc != null) { assert( @@ -216,8 +218,12 @@ export class HTTPRequest implements IOHandler { const fetchURLs: string[] = []; weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(path => { - fetchURLs.push(pathPrefix + path + suffix); + weightsGroup.paths.forEach(async (path) => { + if (this.weightUrlTranslationFunc) { + fetchURLs.push(await this.weightUrlTranslationFunc(path)); + } else { + fetchURLs.push(pathPrefix + path + suffix); + } }); }); const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 0a4971f3cbc..53deb4a3c48 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -829,6 +829,54 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(fetchInputs).toEqual(['./model.json', './weightfile0']); + expect(fetchInits.length).toEqual(2); + expect(fetchInits[0].credentials).toEqual('include'); + expect(fetchInits[1].credentials).toEqual('include'); + }); + it('Provide WeightFileTranslateFunc', async () => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles( + { + 'path1/model.json': { + data: JSON.stringify({weightManifest1}), + contentType: 'application/json' + }, + 'auth_weightfile0': + {data: floatData, contentType: 'application/octet-stream'} + }, + {}); + const fetchInputs: RequestInfo[] = []; + const fetchInits: RequestInit[] = []; + async function WeightFileTranslateFunc(weightFile: string): + Promise { + return 'auth_' + weightFile; + } + + const handler = tf.io.http('./model.json', { + requestInit: {credentials: 'include'}, + weightUrlTranslateFunc: WeightFileTranslateFunc + }); + const modelArtifacts = await handler.load(); + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); expect(fetchInits[0].credentials).toEqual('include'); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index 63c5e3deff0..9ae95815d93 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -103,9 +103,9 @@ export declare interface WeightsManifestEntry { * Information for dequantization of the weight. */ quantization?: { - scale?: number, // The scaling constant to multiply by. - min?: number, // The (possibly nudged) minimum weight to add. - dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights. + scale?: number, // The scaling constant to multiply by. + min?: number, // The (possibly nudged) minimum weight to add. + dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights. }; } @@ -464,6 +464,17 @@ export interface LoadOptions { * Default: `false`. */ fromTFHub?: boolean; + + /** + * A function to translate weight file name to URL. By default we consider + * weight files are colocated with the model.json file. For example: + * model.json URL: https://www.google.com/models/1/model.json + * group1-shard1of1.bin url: + * https://www.google.com/models/1/group1-shard1of1.bin + * + * With this func you can translate the weight file name to any URL. + */ + weightUrlTranslateFunc?: (weightFileName: string) => Promise; } /** From cbfb51cc7b2b1a309e94a279749ade273f4f9cf0 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Thu, 13 Aug 2020 18:20:06 -0700 Subject: [PATCH 2/4] fix lint --- tfjs-core/src/io/http.ts | 3 +- tfjs-core/src/io/http_test.ts | 97 ++++++++++++++++++----------------- 2 files changed, 51 insertions(+), 49 deletions(-) diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index a6658827a56..1ef4cc95e10 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -36,7 +36,8 @@ export class HTTPRequest implements IOHandler { protected readonly requestInit: RequestInit; private readonly fetch: Function; - private readonly weightUrlTranslationFunc: (weightName: string) => Promise; + private readonly weightUrlTranslationFunc: + (weightName: string) => Promise; readonly DEFAULT_METHOD = 'POST'; diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 53deb4a3c48..6a66edff136 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -777,6 +777,55 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { done(); } }); + it('Provide WeightFileTranslateFunc', async () => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles( + { + 'path1/model.json': { + data: JSON.stringify({weightManifest1}), + contentType: 'application/json' + }, + 'auth_weightfile0': + {data: floatData, contentType: 'application/octet-stream'} + }, + requestInits); + const fetchInputs: RequestInfo[] = []; + const fetchInits: RequestInit[] = []; + async function weightUrlTranslateFunc(weightFile: string): + Promise { + console.log(weightFile); + return 'auth_' + weightFile; + } + + const handler = tf.io.http('./model.json', { + requestInit: {credentials: 'include'}, + weightUrlTranslateFunc + }); + const modelArtifacts = await handler.load(); + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + + expect(fetchInputs).toEqual(['./model.json', './weightfile0']); + expect(fetchInits.length).toEqual(2); + expect(fetchInits[0].credentials).toEqual('include'); + expect(fetchInits[1].credentials).toEqual('include'); + }); }); it('Overriding BrowserHTTPRequest fetchFunc', async () => { @@ -829,54 +878,6 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); - expect(fetchInputs).toEqual(['./model.json', './weightfile0']); - expect(fetchInits.length).toEqual(2); - expect(fetchInits[0].credentials).toEqual('include'); - expect(fetchInits[1].credentials).toEqual('include'); - }); - it('Provide WeightFileTranslateFunc', async () => { - const weightManifest1: tf.io.WeightsManifestConfig = [{ - paths: ['weightfile0'], - weights: [ - { - name: 'dense/kernel', - shape: [3, 1], - dtype: 'float32', - }, - { - name: 'dense/bias', - shape: [2], - dtype: 'float32', - } - ] - }]; - const floatData = new Float32Array([1, 3, 3, 7, 4]); - setupFakeWeightFiles( - { - 'path1/model.json': { - data: JSON.stringify({weightManifest1}), - contentType: 'application/json' - }, - 'auth_weightfile0': - {data: floatData, contentType: 'application/octet-stream'} - }, - {}); - const fetchInputs: RequestInfo[] = []; - const fetchInits: RequestInit[] = []; - async function WeightFileTranslateFunc(weightFile: string): - Promise { - return 'auth_' + weightFile; - } - - const handler = tf.io.http('./model.json', { - requestInit: {credentials: 'include'}, - weightUrlTranslateFunc: WeightFileTranslateFunc - }); - const modelArtifacts = await handler.load(); - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); - expect(fetchInputs).toEqual(['./model.json', './weightfile0']); expect(fetchInits.length).toEqual(2); expect(fetchInits[0].credentials).toEqual('include'); From 63a736353ab701abbb5ac342bcd995a470675aa2 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 14 Aug 2020 08:52:11 -0700 Subject: [PATCH 3/4] fix async call for array iteration --- tfjs-core/src/io/http.ts | 13 +++++++------ tfjs-core/src/io/http_test.ts | 25 ++++++++++++++----------- tfjs-core/src/io/types.ts | 4 ++-- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 1ef4cc95e10..1b9575b4db2 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -218,15 +218,16 @@ export class HTTPRequest implements IOHandler { } const fetchURLs: string[] = []; - weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(async (path) => { - if (this.weightUrlTranslationFunc) { - fetchURLs.push(await this.weightUrlTranslationFunc(path)); + for await (const weightsGroup of weightsManifest) { + for await (const path of weightsGroup.paths) { + if (this.weightUrlTranslationFunc != null) { + const url = await this.weightUrlTranslationFunc(path); + fetchURLs.push(url); } else { fetchURLs.push(pathPrefix + path + suffix); } - }); - }); + } + } const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { requestInit: this.requestInit, fetchFunc: this.fetch, diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 6a66edff136..660e1d1b044 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -796,35 +796,38 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { const floatData = new Float32Array([1, 3, 3, 7, 4]); setupFakeWeightFiles( { - 'path1/model.json': { - data: JSON.stringify({weightManifest1}), + './model.json': { + data: JSON.stringify({ + modelTopology: modelTopology1, + weightsManifest: weightManifest1 + }), contentType: 'application/json' }, 'auth_weightfile0': - {data: floatData, contentType: 'application/octet-stream'} + {data: floatData, contentType: 'application/octet-stream'}, }, requestInits); - const fetchInputs: RequestInfo[] = []; - const fetchInits: RequestInit[] = []; async function weightUrlTranslateFunc(weightFile: string): Promise { - console.log(weightFile); return 'auth_' + weightFile; } const handler = tf.io.http('./model.json', { - requestInit: {credentials: 'include'}, + requestInit: {headers: {'header_key_1': 'header_value_1'}}, weightUrlTranslateFunc }); const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(Object.keys(requestInits).length).toEqual(2); + expect(Object.keys(requestInits).length).toEqual(2); + expect(requestInits['./model.json'].headers['header_key_1']) + .toEqual('header_value_1'); + expect(requestInits['auth_weightfile0'].headers['header_key_1']) + .toEqual('header_value_1'); - expect(fetchInputs).toEqual(['./model.json', './weightfile0']); - expect(fetchInits.length).toEqual(2); - expect(fetchInits[0].credentials).toEqual('include'); - expect(fetchInits[1].credentials).toEqual('include'); + expect(fetchSpy.calls.mostRecent().object).toEqual(window); }); }); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index 9ae95815d93..cd623cc3092 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -466,8 +466,8 @@ export interface LoadOptions { fromTFHub?: boolean; /** - * A function to translate weight file name to URL. By default we consider - * weight files are colocated with the model.json file. For example: + * An async function to translate weight file name to URL. By default we + * consider weight files are colocated with the model.json file. For example: * model.json URL: https://www.google.com/models/1/model.json * group1-shard1of1.bin url: * https://www.google.com/models/1/group1-shard1of1.bin From 41b7284f08d59552ef57d92d87b7daf4ec89ad38 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Fri, 14 Aug 2020 18:01:48 -0700 Subject: [PATCH 4/4] address comments --- tfjs-core/src/io/http.ts | 19 ++++++++++++------- tfjs-core/src/io/http_test.ts | 8 +++++--- tfjs-core/src/io/types.ts | 7 ++++--- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index 1b9575b4db2..68f60509679 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -36,7 +36,7 @@ export class HTTPRequest implements IOHandler { protected readonly requestInit: RequestInit; private readonly fetch: Function; - private readonly weightUrlTranslationFunc: + private readonly weightUrlConverter: (weightName: string) => Promise; readonly DEFAULT_METHOD = 'POST'; @@ -52,7 +52,7 @@ export class HTTPRequest implements IOHandler { } this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; - this.weightUrlTranslationFunc = loadOptions.weightUrlTranslateFunc; + this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { assert( @@ -218,16 +218,21 @@ export class HTTPRequest implements IOHandler { } const fetchURLs: string[] = []; - for await (const weightsGroup of weightsManifest) { - for await (const path of weightsGroup.paths) { - if (this.weightUrlTranslationFunc != null) { - const url = await this.weightUrlTranslationFunc(path); - fetchURLs.push(url); + const urlPromises: Array> = []; + for (const weightsGroup of weightsManifest) { + for (const path of weightsGroup.paths) { + if (this.weightUrlConverter != null) { + urlPromises.push(this.weightUrlConverter(path)); } else { fetchURLs.push(pathPrefix + path + suffix); } } } + + if (this.weightUrlConverter) { + fetchURLs.push(...await Promise.all(urlPromises)); + } + const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { requestInit: this.requestInit, fetchFunc: this.fetch, diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 660e1d1b044..ddeb5340c82 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -807,14 +807,16 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { {data: floatData, contentType: 'application/octet-stream'}, }, requestInits); - async function weightUrlTranslateFunc(weightFile: string): + async function prefixWeightUrlConverter(weightFile: string): Promise { - return 'auth_' + weightFile; + // Add 'auth_' prefix to the weight file url. + return new Promise( + resolve => setTimeout(resolve, 1, 'auth_' + weightFile)); } const handler = tf.io.http('./model.json', { requestInit: {headers: {'header_key_1': 'header_value_1'}}, - weightUrlTranslateFunc + weightUrlConverter: prefixWeightUrlConverter }); const modelArtifacts = await handler.load(); expect(modelArtifacts.modelTopology).toEqual(modelTopology1); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index cd623cc3092..3e7bc818bbe 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -466,15 +466,16 @@ export interface LoadOptions { fromTFHub?: boolean; /** - * An async function to translate weight file name to URL. By default we + * An async function to convert weight file name to URL. The weight file + * names are stored in model.json's weightsManifest.paths field. By default we * consider weight files are colocated with the model.json file. For example: * model.json URL: https://www.google.com/models/1/model.json * group1-shard1of1.bin url: * https://www.google.com/models/1/group1-shard1of1.bin * - * With this func you can translate the weight file name to any URL. + * With this func you can convert the weight file name to any URL. */ - weightUrlTranslateFunc?: (weightFileName: string) => Promise; + weightUrlConverter?: (weightFileName: string) => Promise; } /**