diff --git a/tfjs-node/src/io/file_system.ts b/tfjs-node/src/io/file_system.ts index 891a7a1128c..8880042974c 100644 --- a/tfjs-node/src/io/file_system.ts +++ b/tfjs-node/src/io/file_system.ts @@ -99,10 +99,19 @@ export class NodeFileSystem implements tfc.io.IOHandler { paths: [this.WEIGHTS_BINARY_FILENAME], weights: modelArtifacts.weightSpecs }]; - const modelJSON = { + const modelJSON: tfc.io.ModelJSON = { modelTopology: modelArtifacts.modelTopology, weightsManifest, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy }; + if (modelArtifacts.trainingConfig != null) { + modelJSON.trainingConfig = modelArtifacts.trainingConfig; + } + if (modelArtifacts.userDefinedMetadata != null) { + modelJSON.userDefinedMetadata = modelArtifacts.userDefinedMetadata; + } const modelJSONPath = join(this.path, this.MODEL_JSON_FILENAME); await writeFile(modelJSONPath, JSON.stringify(modelJSON), 'utf8'); await writeFile( @@ -165,6 +174,9 @@ export class NodeFileSystem implements tfc.io.IOHandler { const modelArtifacts: tfc.io.ModelArtifacts = { modelTopology: modelJSON.modelTopology, + format: modelJSON.format, + generatedBy: modelJSON.generatedBy, + convertedBy: modelJSON.convertedBy }; if (modelJSON.weightsManifest != null) { const [weightSpecs, weightData] = @@ -172,6 +184,12 @@ export class NodeFileSystem implements tfc.io.IOHandler { modelArtifacts.weightSpecs = weightSpecs; modelArtifacts.weightData = weightData; } + if (modelJSON.trainingConfig != null) { + modelArtifacts.trainingConfig = modelJSON.trainingConfig; + } + if (modelJSON.userDefinedMetadata != null) { + modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata; + } return modelArtifacts; } else { throw new Error( diff --git a/tfjs-node/src/io/file_system_test.ts b/tfjs-node/src/io/file_system_test.ts index 67fda2d72b8..e132a85280b 100644 --- a/tfjs-node/src/io/file_system_test.ts +++ b/tfjs-node/src/io/file_system_test.ts @@ -16,6 +16,7 @@ */ import * as tfc from '@tensorflow/tfjs-core'; +import * as tfl from '@tensorflow/tfjs-layers'; import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util'; import * as fs from 'fs'; import * as path from 'path'; @@ -451,6 +452,65 @@ describe('File system IOHandler', () => { expect(typeof handler.load).toEqual('function'); }); + it('Save and load model with loss and optimizer', async () => { + const model = tfl.sequential(); + model.add(tfl.layers.dense({ + units: 1, + kernelInitializer: 'zeros', + inputShape: [1] + })); + model.compile({ + loss: 'meanSquaredError', + optimizer: tfc.train.adam(2.5e-2) + }); + + const xs = tfc.tensor2d([1, 2, 3, 4], [4, 1]); + const ys = tfc.tensor2d([-1, -3, -5, -7], [4, 1]); + await model.fit(xs, ys, { + epochs: 2, + shuffle: false, + verbose: 0 + }); + + const saveURL = `file://${testDir}`; + const loadURL = `file://${testDir}/model.json`; + + await model.save(saveURL, {includeOptimizer: true}); + const model2 = await tfl.loadLayersModel(loadURL); + const optimizerConfig = model2.optimizer.getConfig(); + expect(model2.optimizer.getClassName()).toEqual('Adam'); + expect(optimizerConfig['learningRate']).toEqual(2.5e-2); + + // Test that model2 can be trained immediately, without a compile() call + // due to the loaded optimizer and loss information. + const history2 = await model2.fit(xs, ys, { + epochs: 2, + shuffle: false, + verbose: 0 + }); + // The final loss value from training the model twice, 2 epochs + // at a time, should be equal to the final loss of trainig the + // model only once with 4 epochs. + expect(history2.history.loss[1]).toBeCloseTo(18.603); + }); + + it('Save and load model with user-defined metadata', async () => { + const model = tfl.sequential(); + model.add(tfl.layers.dense({units: 3, inputShape: [4]})); + model.setUserDefinedMetadata({ + 'outputLabels': ['Label1', 'Label2', 'Label3'] + }); + + const saveURL = `file://${testDir}`; + const loadURL = `file://${testDir}/model.json`; + + await model.save(saveURL); + const model2 = await tfl.loadLayersModel(loadURL); + expect(model2.getUserDefinedMetadata()).toEqual({ + 'outputLabels': ['Label1', 'Label2', 'Label3'] + }); + }); + describe('nodeFileSystemRouter', () => { it('should handle single path', () => { expect(nodeFileSystemRouter('file://model.json')).toBeDefined();