
Fix keys name for Transformer #2529

Merged: 8 commits, Jun 5, 2024
Changes from 2 commits
7 changes: 7 additions & 0 deletions speechbrain/lobes/models/transformer/Transformer.py
@@ -15,6 +15,7 @@
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import RelPosEncXL
from speechbrain.nnet.CNN import Conv1d
from speechbrain.utils.checkpoints import map_old_state_dict_weights

from .Branchformer import BranchformerEncoder
from .Conformer import ConformerEncoder
@@ -779,6 +780,12 @@ def forward(

return tgt, self_attn, multihead_attention

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
"""Load the model from a state_dict and map the old keys to the new keys."""
mapping = {"mutihead_attention": "multihead_attention"}
state_dict = map_old_state_dict_weights(state_dict, mapping)
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


class TransformerDecoder(nn.Module):
"""This class implements the Transformer decoder.
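For reference, here is a minimal, self-contained sketch of the `_load_from_state_dict` override pattern added above. `ToyLayer`, `remap_keys`, and the checkpoint contents are illustrative stand-ins, not SpeechBrain code; the helper mirrors the substring-based renaming that `map_old_state_dict_weights` (added in `speechbrain/utils/checkpoints.py` below) performs.

```python
import torch
import torch.nn as nn


def remap_keys(state_dict, mapping):
    # Same idea as map_old_state_dict_weights: rename any key that
    # contains an old name as a substring.
    for old, new in mapping.items():
        for key in list(state_dict.keys()):
            if old in key:
                state_dict[key.replace(old, new)] = state_dict.pop(key)
    return state_dict


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # Correctly spelled attribute name.
        self.multihead_attention = nn.Linear(4, 4)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Remap checkpoints saved under the old, misspelled key before
        # the regular loading logic runs.
        remap_keys(state_dict, {"mutihead_attention": "multihead_attention"})
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


layer = ToyLayer()
old_ckpt = {
    "mutihead_attention.weight": torch.zeros(4, 4),
    "mutihead_attention.bias": torch.zeros(4),
}
layer.load_state_dict(old_ckpt)  # strict load succeeds despite old key names
```

Because the hook runs before the module's own parameters and children are matched, old checkpoints load transparently with `strict=True` and no user-facing change.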
52 changes: 49 additions & 3 deletions speechbrain/utils/checkpoints.py
@@ -46,6 +46,7 @@

Authors
* Aku Rouhe 2020
* Adel Moumen 2024
"""

import collections
@@ -57,6 +58,7 @@
import shutil
import time
import warnings
from typing import Dict

import torch
import yaml
@@ -75,6 +77,41 @@
CKPT_PREFIX = "CKPT"
METAFNAME = f"{CKPT_PREFIX}.yaml" # Important that this is not .ckpt
PARAMFILE_EXT = ".ckpt" # ...because these files will be
# some keys have been renamed in the new version of the code
KEYS_MAPPING: Dict[str, str] = {
"mutihead_attn": "multihead_attn", # see PR #2489
}


def map_old_state_dict_weights(
state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]
) -> Dict[str, torch.Tensor]:
"""
Maps the keys in the old state dictionary according to the provided
mapping: a key is renamed whenever it contains one of the old names as a
substring.

Parameters
----------
state_dict : dict
The old state dictionary to be mapped.
mapping : dict
A dictionary specifying the mapping between old and new keys.

Returns
-------
dict
The modified state dictionary with mapped keys.
"""
for checkpoint_name, attribute_name in mapping.items():
for full_checkpoint_name in list(state_dict.keys()):
if checkpoint_name in full_checkpoint_name:
full_attribute_name = full_checkpoint_name.replace(
checkpoint_name, attribute_name
)
state_dict[full_attribute_name] = state_dict.pop(
full_checkpoint_name
)
return state_dict


def torch_recovery(obj, path, end_of_epoch):
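A quick usage sketch of the helper above. The key names are made up for illustration, and the import assumes `KEYS_MAPPING` and `map_old_state_dict_weights` are module-level in `speechbrain.utils.checkpoints`, as this diff shows.

```python
import torch

from speechbrain.utils.checkpoints import KEYS_MAPPING, map_old_state_dict_weights

old_state_dict = {
    "layers.0.mutihead_attn.in_proj_weight": torch.zeros(12, 4),  # old typo
    "layers.0.norm1.weight": torch.ones(4),  # unaffected key
}
new_state_dict = map_old_state_dict_weights(old_state_dict, KEYS_MAPPING)
print(sorted(new_state_dict))
# ['layers.0.multihead_attn.in_proj_weight', 'layers.0.norm1.weight']
```

Note that the mapping matches substrings anywhere in the key, so nested prefixes like `layers.0.` need no special handling.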
@@ -94,10 +131,13 @@ def torch_recovery(obj, path, end_of_epoch):
"""
del end_of_epoch # Unused
device = "cpu"

state_dict = torch.load(path, map_location=device)
state_dict = map_old_state_dict_weights(state_dict, KEYS_MAPPING)
try:
obj.load_state_dict(torch.load(path, map_location=device), strict=True)
obj.load_state_dict(state_dict, strict=True)
except TypeError:
obj.load_state_dict(torch.load(path, map_location=device))
obj.load_state_dict(state_dict)


@main_process_only
@@ -1247,4 +1287,10 @@ def average_checkpoints(
parameter_loader(ckpt.paramfiles[recoverable_name], map_location=device)
for ckpt in checkpoint_list
)
return averager(parameter_iterator)
parameter_iterator = (
map_old_state_dict_weights(state_dict, KEYS_MAPPING)
for state_dict in parameter_iterator
)

avg_ckpt = averager(parameter_iterator)
return avg_ckpt
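To close the loop, a hypothetical end-to-end recovery with `torch_recovery`: `Stub` is a stand-in module whose parameter keys use the corrected spelling, and the temporary file simulates a parameter file saved before the rename.

```python
import tempfile

import torch
import torch.nn as nn

from speechbrain.utils.checkpoints import torch_recovery


class Stub(nn.Module):
    # Stand-in for a module saved before the rename: its parameters are
    # now registered as "multihead_attn.*".
    def __init__(self):
        super().__init__()
        self.multihead_attn = nn.Linear(4, 4)


# Simulate an old checkpoint that still uses the misspelled keys.
old_ckpt = {
    "mutihead_attn.weight": torch.zeros(4, 4),
    "mutihead_attn.bias": torch.zeros(4),
}
with tempfile.NamedTemporaryFile(suffix=".ckpt", delete=False) as f:
    torch.save(old_ckpt, f.name)

model = Stub()
# Keys are remapped via KEYS_MAPPING before load_state_dict runs,
# so the strict load succeeds.
torch_recovery(model, f.name, end_of_epoch=False)
```

The same remapping is applied inside `average_checkpoints`, so averaging a mix of old and new parameter files produces consistently named keys.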