Skip to content

Commit

Permalink
Update weights loading (#7872)
Browse files Browse the repository at this point in the history
* Update weights loading

* fix tests

* remove

* fix

* fix comments

* fix lint
  • Loading branch information
fengwuyao committed Aug 4, 2023
1 parent ba8ee4d commit aaa637e
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tfjs-layers/src/engine/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ export abstract class Container extends Layer {
// e.g. dense/0
const key = Object.keys(weights)[0].split('/');
const isKerasSavedModelFormat = !isNaN(parseInt(key[key.length - 1], 10));
if (isKerasSavedModelFormat) {
this.parseWeights(weights);
}
// Check if weights from keras v3.
for (const layer of this.layers) {
for (const [index, weight] of layer.weights.entries()) {
Expand Down Expand Up @@ -652,6 +655,33 @@ export abstract class Container extends Layer {
batchSetValue(weightValueTuples);
}

protected parseWeights(weights: NamedTensorMap) {
for (const key in Object.keys(weights)) {
const listParts = key.split('/');
const list = ['vars', 'layer_checkpoint_dependencies'];
// For keras v3, the weights name are saved based on the folder structure.
// e.g. _backbone/_layer_checkpoint_dependencies/transformer/_self../
// _output_dense/vars/0
// Therefore we discard the `vars` and `layer_checkpoint_depencies` within
// the saved name and only keeps the layer name and weights.
// This can help to mapping the actual name of the layers and load each
// weight accordingly.
const newKey = listParts
.map(str => {
if (str.startsWith('_')) {
return str.slice(1);
}
return str;
})
.filter(str => !list.includes(str))
.join('/');
if (newKey !== key) {
weights[newKey] = weights[key];
delete weights[key];
}
}
}

/**
* Util shared between different serialization methods.
* @returns LayersModel config with Keras version information added.
Expand Down

0 comments on commit aaa637e

Please sign in to comment.