
Commit

Merge pull request #2529 from Adel-Moumen/fix/fix_keys
Fix keys name for Transformer
asumagic committed Jun 5, 2024
2 parents cd86624 + 62ad576 commit e627b42
Showing 2 changed files with 114 additions and 6 deletions.
7 changes: 7 additions & 0 deletions speechbrain/lobes/models/transformer/Transformer.py
@@ -15,6 +15,7 @@
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import RelPosEncXL
from speechbrain.nnet.CNN import Conv1d
from speechbrain.utils.checkpoints import map_old_state_dict_weights

from .Branchformer import BranchformerEncoder
from .Conformer import ConformerEncoder
@@ -779,6 +780,12 @@ def forward(

        return tgt, self_attn, multihead_attention

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load the model from a state_dict and map the old keys to the new keys."""
        mapping = {"mutihead_attention": "multihead_attention"}
        state_dict = map_old_state_dict_weights(state_dict, mapping)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class TransformerDecoder(nn.Module):
"""This class implements the Transformer decoder.
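The change above adds a `_load_from_state_dict` override so that old checkpoints, saved when the attribute was still misspelled `mutihead_attention`, keep loading after the rename. A minimal sketch of the remapping it performs (the key and tensor shape below are hypothetical, and a SpeechBrain build containing this commit is assumed):

import torch

from speechbrain.utils.checkpoints import map_old_state_dict_weights

# Hypothetical old-style checkpoint entry using the misspelled attribute name.
old_state = {"mutihead_attention.in_proj_weight": torch.zeros(4, 4)}

# Same mapping as in the override above; the key is renamed on the fly.
new_state = map_old_state_dict_weights(
    old_state, {"mutihead_attention": "multihead_attention"}
)
print(list(new_state))  # ['multihead_attention.in_proj_weight']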
113 changes: 107 additions & 6 deletions speechbrain/utils/checkpoints.py
@@ -46,6 +46,7 @@
Authors
* Aku Rouhe 2020
* Adel Moumen 2024
"""

import collections
@@ -57,6 +58,7 @@
import shutil
import time
import warnings
from typing import Dict

import torch
import yaml
@@ -75,6 +77,77 @@
CKPT_PREFIX = "CKPT"
METAFNAME = f"{CKPT_PREFIX}.yaml" # Important that this is not .ckpt
PARAMFILE_EXT = ".ckpt" # ...because these files will be
# some keys have been renamed in the new version of the code
KEYS_MAPPING: Dict[str, str] = {
    ".mutihead_attn": ".multihead_attn",  # see PR #2489
}


def map_old_state_dict_weights(
    state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]
) -> Dict[str, torch.Tensor]:
    """
    Maps the keys in the old state dictionary according to the provided mapping.

    NOTE: This function will remap all state_dict keys that contain the old key.
    For instance, if the state_dict is {'model.encoder.layer.0.atn.self.query.weight': ...}
    and the mapping is {'.atn': '.attn'}, the resulting state_dict will be
    {'model.encoder.layer.0.attn.self.query.weight': ...}.

    Since this effectively works as a mass substring replacement, partial key
    matches (e.g. in the middle of one layer name) will also work, so be
    careful to avoid false positives.

    Parameters
    ----------
    state_dict : dict
        The old state dictionary to be mapped.
    mapping : dict
        A dictionary specifying the mapping between old and new keys.

    Returns
    -------
    dict
        The modified state dictionary with mapped keys.
    """
    for replacement_old, replacement_new in mapping.items():
        for old_key in list(state_dict.keys()):
            if replacement_old in old_key:
                new_key = old_key.replace(replacement_old, replacement_new)
                state_dict[new_key] = state_dict.pop(old_key)
                logger.info(
                    "Due to replacement compatibility rule '%s'->'%s', renamed "
                    "`state_dict['%s']`->`state_dict['%s']`",
                    replacement_old,
                    replacement_new,
                    old_key,
                    new_key,
                )
    return state_dict
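A runnable version of the docstring's '.atn' -> '.attn' example, showing that the mapping is a substring replacement over every matching key (assumes a SpeechBrain build containing this commit):

import torch

from speechbrain.utils.checkpoints import map_old_state_dict_weights

old = {"model.encoder.layer.0.atn.self.query.weight": torch.zeros(1)}
new = map_old_state_dict_weights(old, {".atn": ".attn"})
print(list(new))  # ['model.encoder.layer.0.attn.self.query.weight']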


def hook_on_loading_state_dict_checkpoint(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Hook to be called when loading a state_dict checkpoint.

    It can be used to modify the state_dict before it is loaded into the
    model. By default, this hook will map the old state_dict keys to the
    new ones.

    Arguments
    ---------
    state_dict : dict
        The state_dict to be loaded.

    Returns
    -------
    dict
        The modified state_dict.
    """
    altered_state_dict = map_old_state_dict_weights(state_dict, KEYS_MAPPING)
    return altered_state_dict


def torch_recovery(obj, path, end_of_epoch):
@@ -94,10 +167,33 @@ def torch_recovery(obj, path, end_of_epoch):
"""
del end_of_epoch # Unused
device = "cpu"

state_dict = torch_patched_state_dict_load(path, device)
try:
obj.load_state_dict(torch.load(path, map_location=device), strict=True)
obj.load_state_dict(state_dict, strict=True)
except TypeError:
obj.load_state_dict(torch.load(path, map_location=device))
obj.load_state_dict(state_dict)


def torch_patched_state_dict_load(path, device="cpu"):
    """Loads a `state_dict` from the given path using :func:`torch.load` and
    calls the SpeechBrain `state_dict` loading hooks, e.g. to apply key name
    patching rules for compatibility.

    The `state_dict` sees no further preprocessing and is not applied to a
    model; see :func:`~torch_recovery` or :func:`~torch_parameter_transfer`.

    Arguments
    ---------
    path : str, pathlib.Path
        Path where to load from.
    device
        Device where the loaded `state_dict` tensors should reside. This is
        forwarded to :func:`torch.load`; see its documentation for details.
    """
    state_dict = torch.load(path, map_location=device)
    state_dict = hook_on_loading_state_dict_checkpoint(state_dict)
    return state_dict
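A small usage sketch of the helper above; the checkpoint file written here is hypothetical and only serves to show that keys matching KEYS_MAPPING are patched during loading (assumes a SpeechBrain build containing this commit):

import torch

from speechbrain.utils.checkpoints import torch_patched_state_dict_load

# Hypothetical old-style checkpoint, saved before PR #2489 with ".mutihead_attn" keys.
torch.save({"decoder.mutihead_attn.in_proj_weight": torch.zeros(4, 4)}, "old_model.ckpt")

state_dict = torch_patched_state_dict_load("old_model.ckpt", device="cpu")
print(list(state_dict))  # ['decoder.multihead_attn.in_proj_weight']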


@main_process_only
@@ -136,9 +232,8 @@ def torch_parameter_transfer(obj, path):
Path where to load from.
"""
device = "cpu"
incompatible_keys = obj.load_state_dict(
torch.load(path, map_location=device), strict=False
)
state_dict = torch_patched_state_dict_load(path, device)
incompatible_keys = obj.load_state_dict(state_dict, strict=False)
for missing_key in incompatible_keys.missing_keys:
logger.warning(
f"During parameter transfer to {obj} loading from "
@@ -1247,4 +1342,10 @@ def average_checkpoints(
        parameter_loader(ckpt.paramfiles[recoverable_name], map_location=device)
        for ckpt in checkpoint_list
    )
-    return averager(parameter_iterator)
+    parameter_iterator = (
+        hook_on_loading_state_dict_checkpoint(state_dict)
+        for state_dict in parameter_iterator
+    )

+    avg_ckpt = averager(parameter_iterator)
+    return avg_ckpt
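A short sketch of the pattern introduced above: every state_dict yielded by the loader is passed through the loading hook lazily, so key patching happens before the averaging function consumes it. The toy state_dicts are hypothetical, and a SpeechBrain build containing this commit is assumed:

import torch

from speechbrain.utils.checkpoints import hook_on_loading_state_dict_checkpoint

state_dicts = [{"layers.0.mutihead_attn.weight": torch.ones(2)} for _ in range(3)]

# Lazily patch each state_dict, mirroring the generator wrapping in average_checkpoints.
patched = (hook_on_loading_state_dict_checkpoint(sd) for sd in state_dicts)
for sd in patched:
    print(list(sd))  # ['layers.0.multihead_attn.weight']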
