Fix key names for Transformer #2529

Merged: 8 commits, Jun 5, 2024
7 changes: 7 additions & 0 deletions speechbrain/lobes/models/transformer/Transformer.py
@@ -15,6 +15,7 @@
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import RelPosEncXL
from speechbrain.nnet.CNN import Conv1d
from speechbrain.utils.checkpoints import map_old_state_dict_weights

from .Branchformer import BranchformerEncoder
from .Conformer import ConformerEncoder
@@ -779,6 +780,12 @@ def forward(

        return tgt, self_attn, multihead_attention

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load the model from a state_dict and map the old keys to the new keys."""
        mapping = {"mutihead_attention": "multihead_attention"}
        state_dict = map_old_state_dict_weights(state_dict, mapping)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class TransformerDecoder(nn.Module):
"""This class implements the Transformer decoder.
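
With this override in place, checkpoints saved under the old misspelled attribute name load transparently. Below is a minimal sketch of the effect, assuming TransformerDecoderLayer's usual constructor arguments and that the renamed multihead_attention submodule appears in the layer's state_dict keys; the sizes are illustrative only.

import torch

from speechbrain.lobes.models.transformer.Transformer import TransformerDecoderLayer

# Illustrative sizes; any valid configuration behaves the same way.
layer = TransformerDecoderLayer(d_ffn=1024, nhead=4, d_model=64)

# Simulate a checkpoint written before the rename by reintroducing the
# old, misspelled key fragment.
legacy = {
    k.replace("multihead_attention", "mutihead_attention"): v
    for k, v in layer.state_dict().items()
}

# _load_from_state_dict maps "mutihead_attention" back to
# "multihead_attention", so strict loading succeeds.
layer.load_state_dict(legacy, strict=True)
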
113 changes: 107 additions & 6 deletions speechbrain/utils/checkpoints.py
@@ -46,6 +46,7 @@

Authors
* Aku Rouhe 2020
* Adel Moumen 2024
"""

import collections
@@ -57,6 +58,7 @@
import shutil
import time
import warnings
from typing import Dict

import torch
import yaml
@@ -75,6 +77,7 @@
CKPT_PREFIX = "CKPT"
METAFNAME = f"{CKPT_PREFIX}.yaml" # Important that this is not .ckpt
PARAMFILE_EXT = ".ckpt" # ...because these files will be
# some keys have been renamed in the new version of the code
KEYS_MAPPING: Dict[str, str] = {
".mutihead_attn": ".multihead_attn", # see PR #2489
}


def map_old_state_dict_weights(
    state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]
) -> Dict[str, torch.Tensor]:
"""
Maps the keys in the old state dictionary according to the provided mapping.

NOTE: This function will remap all state_dict keys that contain the old key.
For instance, if the state_dict is {'model.encoder.layer.0.atn.self.query.weight': ...}
and the mapping is {'.atn': '.attn'}, the resulting state_dict will be
{'model.encoder.layer.0.attn.self.query.weight': ...}.

Since this effectively works as a mass substring replacement, partial key
matches (e.g. in the middle of one layer name) will also work, so be
careful to avoid false positives.

Parameters
----------
state_dict : dict
The old state dictionary to be mapped.
mapping : dict
A dictionary specifying the mapping between old and new keys.

    Returns
    -------
    dict
        The modified state dictionary with mapped keys.
    """
    for replacement_old, replacement_new in mapping.items():
        for old_key in list(state_dict.keys()):
            if replacement_old in old_key:
                new_key = old_key.replace(replacement_old, replacement_new)
                state_dict[new_key] = state_dict.pop(old_key)
                logger.info(
                    "Due to replacement compatibility rule '%s'->'%s', renamed "
                    "`state_dict['%s']`->`state_dict['%s']`",
                    replacement_old,
                    replacement_new,
                    old_key,
                    new_key,
                )
    return state_dict
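
# Illustrative usage sketch: the substring replacement renames the
# misspelled fragment wherever it occurs in a key (key names invented).
#
#     >>> sd = {"enc.layers.0.mutihead_attn.in_proj_weight": torch.zeros(1)}
#     >>> out = map_old_state_dict_weights(sd, {".mutihead_attn": ".multihead_attn"})
#     >>> list(out.keys())
#     ['enc.layers.0.multihead_attn.in_proj_weight']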


def hook_on_loading_state_dict_checkpoint(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Hook to be called when loading a state_dict checkpoint.

This hook is called when loading a state_dict checkpoint. It can be used
to modify the state_dict before it is loaded into the model.

By default, this hook will map the old state_dict keys to the new ones.

Arguments
---------
state_dict : dict
The state_dict to be loaded.

Returns
-------
dict
The modified state_dict.
"""
altered_state_dict = map_old_state_dict_weights(state_dict, KEYS_MAPPING)
return altered_state_dict


def torch_recovery(obj, path, end_of_epoch):
@@ -94,10 +167,33 @@ def torch_recovery(obj, path, end_of_epoch):
"""
del end_of_epoch # Unused
device = "cpu"

state_dict = torch_patched_state_dict_load(path, device)
try:
obj.load_state_dict(torch.load(path, map_location=device), strict=True)
obj.load_state_dict(state_dict, strict=True)
except TypeError:
obj.load_state_dict(torch.load(path, map_location=device))
obj.load_state_dict(state_dict)


def torch_patched_state_dict_load(path, device="cpu"):
    """Loads a `state_dict` from the given path using :func:`torch.load` and
    calls the SpeechBrain `state_dict` loading hooks, e.g. to apply key name
    patching rules for compatibility.

    The `state_dict` receives no further preprocessing and is not applied to
    a model; see :func:`~torch_recovery` or :func:`~torch_parameter_transfer`.

    Arguments
    ---------
    path : str, pathlib.Path
        Path where to load from.
    device
        Device where the loaded `state_dict` tensors should reside. This is
        forwarded to :func:`torch.load`; see its documentation for details.
    """
    state_dict = torch.load(path, map_location=device)
    state_dict = hook_on_loading_state_dict_checkpoint(state_dict)
    return state_dict


@main_process_only
@@ -136,9 +232,8 @@ def torch_parameter_transfer(obj, path):
        Path where to load from.
    """
device = "cpu"
incompatible_keys = obj.load_state_dict(
torch.load(path, map_location=device), strict=False
)
state_dict = torch_patched_state_dict_load(path, device)
incompatible_keys = obj.load_state_dict(state_dict, strict=False)
    for missing_key in incompatible_keys.missing_keys:
        logger.warning(
            f"During parameter transfer to {obj} loading from "
@@ -1247,4 +1342,10 @@ def average_checkpoints(
        parameter_loader(ckpt.paramfiles[recoverable_name], map_location=device)
        for ckpt in checkpoint_list
    )
-    return averager(parameter_iterator)
+    parameter_iterator = (
+        hook_on_loading_state_dict_checkpoint(state_dict)
+        for state_dict in parameter_iterator
+    )
+
+    avg_ckpt = averager(parameter_iterator)
+    return avg_ckpt
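
Taken together, any recovery path that goes through torch_patched_state_dict_load now repairs legacy key names on the fly. A short end-to-end sketch follows; the file name and key are invented for illustration.

import torch

from speechbrain.utils.checkpoints import torch_patched_state_dict_load

# A checkpoint that still uses the pre-rename spelling (hypothetical key).
torch.save(
    {"decoder.layers.0.mutihead_attn.in_proj_weight": torch.zeros(3)},
    "legacy.ckpt",
)

# Loading through the patched helper applies KEYS_MAPPING, so the key
# comes back under the corrected ".multihead_attn" spelling.
state_dict = torch_patched_state_dict_load("legacy.ckpt", device="cpu")
assert "decoder.layers.0.multihead_attn.in_proj_weight" in state_dict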