diff --git a/tfjs-layers/src/engine/container.ts b/tfjs-layers/src/engine/container.ts index 088c774dc83..f6d1d74508a 100644 --- a/tfjs-layers/src/engine/container.ts +++ b/tfjs-layers/src/engine/container.ts @@ -594,12 +594,22 @@ export abstract class Container extends Layer { loadWeights(weights: NamedTensorMap, strict = true) { const nameToWeight: {[name: string]: LayerVariable} = {}; let totalWeightsCount = 0; + // get weights key from tensor map in order to check if it is from keras v3. + // e.g. dense/0 + const key = Object.keys(weights)[0].split('/'); + const isKerasSavedModelFormat = !isNaN(parseInt(key[key.length - 1], 10)); + // Check if weights from keras v3. for (const layer of this.layers) { - for (const weight of layer.weights) { - if (nameToWeight[weight.originalName] != null) { - throw new ValueError(`Duplicate weight name: ${weight.originalName}`); + for (const [index, weight] of layer.weights.entries()) { + // Parse the name to layerName/index. + // e.g. dense/0, dense/1, dense_1/0, dense_1/1 + const parsedName = isKerasSavedModelFormat ? + `${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` : + weight.originalName; + if (nameToWeight[parsedName] != null) { + throw new ValueError(`Duplicate weight name: ${parsedName}`); } - nameToWeight[weight.originalName] = weight; + nameToWeight[parsedName] = weight; totalWeightsCount++; } } diff --git a/tfjs-layers/src/engine/topology_test.ts b/tfjs-layers/src/engine/topology_test.ts index 0b7260ddb43..eaa93449eaf 100644 --- a/tfjs-layers/src/engine/topology_test.ts +++ b/tfjs-layers/src/engine/topology_test.ts @@ -1087,6 +1087,23 @@ describeMathCPUAndGPU('loadWeightsFromNamedTensorMap', () => { expectTensorsClose(denseLayer.weights[1].read(), tensor1d([10, 20])); }); + it('load keras weights', () => { + const denseLayer = + tfl.layers.dense({units: 2, useBias: true, name: 'dense_layer'}); + const output = denseLayer.apply(inputTensor) as SymbolicTensor; + const model = tfl.model({inputs: inputTensor, outputs: output}); + + const namedWeightsMap: NamedTensorMap = {}; + namedWeightsMap[denseLayer.weights[0].originalName.split('/')[0] + '/0'] = + tensor2d([1, 2, 3, 4, 5, 6], [3, 2]); + namedWeightsMap[denseLayer.weights[1].originalName.split('/')[0] + '/1'] = + tensor1d([10, 20]); + model.loadWeights(namedWeightsMap); + expectTensorsClose( + denseLayer.weights[0].read(), tensor2d([1, 2, 3, 4, 5, 6], [3, 2])); + expectTensorsClose(denseLayer.weights[1].read(), tensor1d([10, 20])); + }); + it('Mismatching shape throws an error even in non-strict mode', () => { const denseLayer = tfl.layers.dense({units: 2, useBias: true, name: 'dense_layer'});