Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tfjs-node/src/io/file_system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,19 @@ export class NodeFileSystem implements tfc.io.IOHandler {
paths: [this.WEIGHTS_BINARY_FILENAME],
weights: modelArtifacts.weightSpecs
}];
const modelJSON = {
const modelJSON: tfc.io.ModelJSON = {
modelTopology: modelArtifacts.modelTopology,
weightsManifest,
format: modelArtifacts.format,
generatedBy: modelArtifacts.generatedBy,
convertedBy: modelArtifacts.convertedBy
};
if (modelArtifacts.trainingConfig != null) {
modelJSON.trainingConfig = modelArtifacts.trainingConfig;
}
if (modelArtifacts.userDefinedMetadata != null) {
modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata;
}
const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME);
await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8');
await writeFile(
Expand Down Expand Up @@ -165,13 +174,22 @@ export class NodeFileSystem implements tfc.io.IOHandler {

const modelArtifacts: tfc.io.ModelArtifacts = {
modelTopology: modelJSON.modelTopology,
format: modelJSON.format,
generatedBy: modelJSON.generatedBy,
convertedBy: modelJSON.convertedBy
};
if (modelJSON.weightsManifest != null) {
const [weightSpecs, weightData] =
await this.loadWeights(modelJSON.weightsManifest, path);
modelArtifacts.weightSpecs = weightSpecs;
modelArtifacts.weightData = weightData;
}
if (modelJSON.trainingConfig != null) {
modelArtifacts.trainingConfig = modelJSON.trainingConfig;
}
if (modelJSON.userDefinedMetadata != null) {
modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
}
return modelArtifacts;
} else {
throw new Error(
Expand Down
60 changes: 60 additions & 0 deletions tfjs-node/src/io/file_system_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import * as tfc from '@tensorflow/tfjs-core';
import * as tfl from '@tensorflow/tfjs-layers';
import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util';
import * as fs from 'fs';
import * as path from 'path';
Expand Down Expand Up @@ -451,6 +452,65 @@ describe('File system IOHandler', () => {
expect(typeof handler.load).toEqual('function');
});

it('Save and load model with loss and optimizer', async () => {
const model = tfl.sequential();
model.add(tfl.layers.dense({
units: 1,
kernelInitializer: 'zeros',
inputShape: [1]
}));
model.compile({
loss: 'meanSquaredError',
optimizer: tfc.train.adam(2.5e-2)
});

const xs = tfc.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tfc.tensor2d([-1, -3, -5, -7], [4, 1]);
await model.fit(xs, ys, {
epochs: 2,
shuffle: false,
verbose: 0
});

const saveURL = `file://${testDir}`;
const loadURL = `file://${testDir}/model.json`;

await model.save(saveURL, {includeOptimizer: true});
const model2 = await tfl.loadLayersModel(loadURL);
const optimizerConfig = model2.optimizer.getConfig();
expect(model2.optimizer.getClassName()).toEqual('Adam');
expect(optimizerConfig['learningRate']).toEqual(2.5e-2);

// Test that model2 can be trained immediately, without a compile() call
// due to the loaded optimizer and loss information.
const history2 = await model2.fit(xs, ys, {
epochs: 2,
shuffle: false,
verbose: 0
});
// The final loss value from training the model twice, 2 epochs
// at a time, should be equal to the final loss of trainig the
// model only once with 4 epochs.
expect(history2.history.loss[1]).toBeCloseTo(18.603);
});

it('Save and load model with user-defined metadata', async () => {
const model = tfl.sequential();
model.add(tfl.layers.dense({units: 3, inputShape: [4]}));
model.setUserDefinedMetadata({
'outputLabels': ['Label1', 'Label2', 'Label3']
});

const saveURL = `file://${testDir}`;
const loadURL = `file://${testDir}/model.json`;

await model.save(saveURL);
const model2 = await tfl.loadLayersModel(loadURL);
expect(model2.getUserDefinedMetadata()).toEqual({
'outputLabels': ['Label1', 'Label2', 'Label3']
});
});

describe('nodeFileSystemRouter', () => {
it('should handle single path', () => {
expect(nodeFileSystemRouter('file://model.json')).toBeDefined();
Expand Down