Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions tfjs-core/src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ export class HTTPRequest implements IOHandler {
protected readonly requestInit: RequestInit;

private readonly fetch: Function;
private readonly weightUrlConverter:
(weightName: string) => Promise<string>;

readonly DEFAULT_METHOD = 'POST';

Expand All @@ -50,6 +52,7 @@ export class HTTPRequest implements IOHandler {
}
this.weightPathPrefix = loadOptions.weightPathPrefix;
this.onProgress = loadOptions.onProgress;
this.weightUrlConverter = loadOptions.weightUrlConverter;

if (loadOptions.fetchFunc != null) {
assert(
Expand Down Expand Up @@ -215,11 +218,21 @@ export class HTTPRequest implements IOHandler {
}

const fetchURLs: string[] = [];
weightsManifest.forEach(weightsGroup => {
weightsGroup.paths.forEach(path => {
fetchURLs.push(pathPrefix + path + suffix);
});
});
const urlPromises: Array<Promise<string>> = [];
for (const weightsGroup of weightsManifest) {
for (const path of weightsGroup.paths) {
if (this.weightUrlConverter != null) {
urlPromises.push(this.weightUrlConverter(path));
} else {
fetchURLs.push(pathPrefix + path + suffix);
}
}
}

if (this.weightUrlConverter) {
fetchURLs.push(...await Promise.all(urlPromises));
}

const buffers = await loadWeightsAsArrayBuffer(fetchURLs, {
requestInit: this.requestInit,
fetchFunc: this.fetch,
Expand Down
54 changes: 54 additions & 0 deletions tfjs-core/src/io/http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,60 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
done();
}
});
it('Provide WeightFileTranslateFunc', async () => {
const weightManifest1: tf.io.WeightsManifestConfig = [{
paths: ['weightfile0'],
weights: [
{
name: 'dense/kernel',
shape: [3, 1],
dtype: 'float32',
},
{
name: 'dense/bias',
shape: [2],
dtype: 'float32',
}
]
}];
const floatData = new Float32Array([1, 3, 3, 7, 4]);
setupFakeWeightFiles(
{
'./model.json': {
data: JSON.stringify({
modelTopology: modelTopology1,
weightsManifest: weightManifest1
}),
contentType: 'application/json'
},
'auth_weightfile0':
{data: floatData, contentType: 'application/octet-stream'},
},
requestInits);
async function prefixWeightUrlConverter(weightFile: string):
Promise<string> {
// Add 'auth_' prefix to the weight file url.
return new Promise(
resolve => setTimeout(resolve, 1, 'auth_' + weightFile));
}

const handler = tf.io.http('./model.json', {
requestInit: {headers: {'header_key_1': 'header_value_1'}},
weightUrlConverter: prefixWeightUrlConverter
});
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
expect(Object.keys(requestInits).length).toEqual(2);
expect(requestInits['./model.json'].headers['header_key_1'])
.toEqual('header_value_1');
expect(requestInits['auth_weightfile0'].headers['header_key_1'])
.toEqual('header_value_1');

expect(fetchSpy.calls.mostRecent().object).toEqual(window);
});
});

it('Overriding BrowserHTTPRequest fetchFunc', async () => {
Expand Down
18 changes: 15 additions & 3 deletions tfjs-core/src/io/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ export declare interface WeightsManifestEntry {
* Information for dequantization of the weight.
*/
quantization?: {
scale?: number, // The scaling constant to multiply by.
min?: number, // The (possibly nudged) minimum weight to add.
dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights.
scale?: number, // The scaling constant to multiply by.
min?: number, // The (possibly nudged) minimum weight to add.
dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights.
};
}

Expand Down Expand Up @@ -464,6 +464,18 @@ export interface LoadOptions {
* Default: `false`.
*/
fromTFHub?: boolean;

/**
* An async function to convert weight file name to URL. The weight file
* names are stored in model.json's weightsManifest.paths field. By default we
* consider weight files are colocated with the model.json file. For example:
* model.json URL: https://www.google.com/models/1/model.json
* group1-shard1of1.bin url:
* https://www.google.com/models/1/group1-shard1of1.bin
*
* With this func you can convert the weight file name to any URL.
*/
weightUrlConverter?: (weightFileName: string) => Promise<string>;
}

/**
Expand Down