diff --git a/tfjs-core/src/io/http.ts b/tfjs-core/src/io/http.ts index d2f339754af..68f60509679 100644 --- a/tfjs-core/src/io/http.ts +++ b/tfjs-core/src/io/http.ts @@ -36,6 +36,8 @@ export class HTTPRequest implements IOHandler { protected readonly requestInit: RequestInit; private readonly fetch: Function; + private readonly weightUrlConverter: + (weightName: string) => Promise; readonly DEFAULT_METHOD = 'POST'; @@ -50,6 +52,7 @@ export class HTTPRequest implements IOHandler { } this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; + this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { assert( @@ -215,11 +218,21 @@ export class HTTPRequest implements IOHandler { } const fetchURLs: string[] = []; - weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(path => { - fetchURLs.push(pathPrefix + path + suffix); - }); - }); + const urlPromises: Array> = []; + for (const weightsGroup of weightsManifest) { + for (const path of weightsGroup.paths) { + if (this.weightUrlConverter != null) { + urlPromises.push(this.weightUrlConverter(path)); + } else { + fetchURLs.push(pathPrefix + path + suffix); + } + } + } + + if (this.weightUrlConverter) { + fetchURLs.push(...await Promise.all(urlPromises)); + } + const buffers = await loadWeightsAsArrayBuffer(fetchURLs, { requestInit: this.requestInit, fetchFunc: this.fetch, diff --git a/tfjs-core/src/io/http_test.ts b/tfjs-core/src/io/http_test.ts index 0a4971f3cbc..ddeb5340c82 100644 --- a/tfjs-core/src/io/http_test.ts +++ b/tfjs-core/src/io/http_test.ts @@ -777,6 +777,60 @@ describeWithFlags('http-load', BROWSER_ENVS, () => { done(); } }); + it('Provide WeightFileTranslateFunc', async () => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles( + { + './model.json': { + data: JSON.stringify({ + modelTopology: modelTopology1, + weightsManifest: weightManifest1 + }), + contentType: 'application/json' + }, + 'auth_weightfile0': + {data: floatData, contentType: 'application/octet-stream'}, + }, + requestInits); + async function prefixWeightUrlConverter(weightFile: string): + Promise { + // Add 'auth_' prefix to the weight file url. + return new Promise( + resolve => setTimeout(resolve, 1, 'auth_' + weightFile)); + } + + const handler = tf.io.http('./model.json', { + requestInit: {headers: {'header_key_1': 'header_value_1'}}, + weightUrlConverter: prefixWeightUrlConverter + }); + const modelArtifacts = await handler.load(); + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); + expect(Object.keys(requestInits).length).toEqual(2); + expect(Object.keys(requestInits).length).toEqual(2); + expect(requestInits['./model.json'].headers['header_key_1']) + .toEqual('header_value_1'); + expect(requestInits['auth_weightfile0'].headers['header_key_1']) + .toEqual('header_value_1'); + + expect(fetchSpy.calls.mostRecent().object).toEqual(window); + }); }); it('Overriding BrowserHTTPRequest fetchFunc', async () => { diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index 63c5e3deff0..3e7bc818bbe 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -103,9 +103,9 @@ export declare interface WeightsManifestEntry { * Information for dequantization of the weight. */ quantization?: { - scale?: number, // The scaling constant to multiply by. - min?: number, // The (possibly nudged) minimum weight to add. - dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights. + scale?: number, // The scaling constant to multiply by. + min?: number, // The (possibly nudged) minimum weight to add. + dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights. }; } @@ -464,6 +464,18 @@ export interface LoadOptions { * Default: `false`. */ fromTFHub?: boolean; + + /** + * An async function to convert weight file name to URL. The weight file + * names are stored in model.json's weightsManifest.paths field. By default we + * consider weight files are colocated with the model.json file. For example: + * model.json URL: https://www.google.com/models/1/model.json + * group1-shard1of1.bin url: + * https://www.google.com/models/1/group1-shard1of1.bin + * + * With this func you can convert the weight file name to any URL. + */ + weightUrlConverter?: (weightFileName: string) => Promise; } /**