
Commit

Custom checkpoint loading hook
helleuch committed Mar 4, 2024
1 parent c0e71f4 commit 1fbd69d
Showing 2 changed files with 39 additions and 7 deletions.
7 changes: 3 additions & 4 deletions speechbrain/decoders/seq2seq.py
@@ -7,7 +7,6 @@
  * Mirco Ravanelli 2020
  * Sung-Lin Yeh 2020
 """
-
 import torch
 from speechbrain.decoders.utils import (
     inflate_tensor,
@@ -18,7 +17,7 @@


 class AlivedHypotheses(torch.nn.Module):
-    """This class handle the data for the hypotheses during the decoding.
+    """ This class handle the data for the hypotheses during the decoding.
     Arguments
     ---------
@@ -721,7 +720,7 @@ def _max_attn_shift_permute_memory_step(self, prev_attn_peak, predecessors):
         return prev_attn_peak

     def _update_reset_memory(self, enc_states, enc_lens):
-        """Call reset memory for each module.
+        """ Call reset memory for each module.
         Arguments
         ---------
@@ -1813,7 +1812,7 @@ def set_n_out(self):
     def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
         """Performs a step in the implemented beamsearcher."""
         memory = _update_mem(inp_tokens, memory)
-        (dec_out, attn,) = self.model.forward_decoder(enc_states, memory)
+        dec_out, attn, = self.model.forward_decoder(enc_states, memory)
         log_probs = self.softmax(dec_out[:, -1] / self.temperature)
         return log_probs, memory, attn

39 changes: 36 additions & 3 deletions speechbrain/lobes/models/huggingface_transformers/speecht5.py
@@ -8,6 +8,7 @@
  * Haroun Elleuch 2024
 """

+import pathlib
 import torch
 import logging

@@ -16,12 +17,17 @@
 )

 from transformers import SpeechT5ForSpeechToText, SpeechT5Config
+from speechbrain.utils.checkpoints import (
+    mark_as_loader,
+    register_checkpoint_hooks,
+)
 from speechbrain.utils.fetching import fetch


 logger = logging.getLogger(__name__)


+@register_checkpoint_hooks
 class SpeechT5ForASR(HFTransformersInterface):
     """This lobe enables the integration of HuggingFace and SpeechBrain
     pretrained SpeechT5 models for Automatic Speech Recognition.
@@ -286,16 +292,19 @@ def _from_pretrained(self, source: str, save_path: str, cache_dir: str):
         is_sb, ckpt_file, _ = self._check_model_source(source, save_path)

         if is_sb or self.for_pretraining:
-            self.model = SpeechT5ForSpeechToText.from_config(self.config)
+            self.model = SpeechT5ForSpeechToText._from_config(self.config)

         if is_sb:
             self.model.gradient_checkpointing_disable()  # Required by DDP
             # fetch the checkpoint file
             ckpt_full_path = fetch(
-                filename=ckpt_file, source=source, savedir=save_path,
+                filename=ckpt_file,
+                source=source,
+                savedir=save_path,
+                huggingface_cache_dir=cache_dir,
             )
             # We transfer the parameters from the checkpoint.
-            self._load_sb_pretrained_parameters(ckpt_full_path)
+            self._load_sb_pretrained_parameters(path=ckpt_full_path,)
         elif not self.for_pretraining:
             self.model = SpeechT5ForSpeechToText.from_pretrained(
                 source,
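The switch from from_config to _from_config matters here: from_config is an AutoModel-style constructor, while a concrete class such as SpeechT5ForSpeechToText exposes _from_config, which builds the architecture from a configuration object with freshly initialised weights, so the parameters can then be supplied by the SpeechBrain checkpoint. A minimal sketch of that behaviour, not part of the commit (the default SpeechT5Config is illustrative; the lobe itself uses self.config resolved from the source):

# Sketch only: _from_config instantiates the architecture from a config
# without downloading pretrained weights, which is the right starting point
# when the weights will be loaded from a SpeechBrain checkpoint afterwards.
from transformers import SpeechT5Config, SpeechT5ForSpeechToText

config = SpeechT5Config()  # illustrative default config
model = SpeechT5ForSpeechToText._from_config(config)  # randomly initialised weights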
@@ -305,6 +314,30 @@ def _from_pretrained(self, source: str, save_path: str, cache_dir: str):
                 ignore_mismatched_sizes=True,
             )

+    @mark_as_loader
+    def _on_load_checkpoint(
+        self, path: pathlib.Path | str, end_of_epoch: bool
+    ) -> None:
+        loaded_state_dict = torch.load(path)
+        model_state_dict = self.state_dict()
+        is_changed = False
+        for k in loaded_state_dict:
+            if k in model_state_dict:
+                if loaded_state_dict[k].shape != model_state_dict[k].shape:
+                    logger.warning(
+                        f"Skip loading parameter: {k}, "
+                        f"required shape: {model_state_dict[k].shape}, "
+                        f"loaded shape: {loaded_state_dict[k].shape}"
+                    )
+                    loaded_state_dict[k] = model_state_dict[k]
+                    is_changed = True
+            else:
+                logger.warning(f"Dropping parameter {k}")
+                is_changed = True
+
+        if is_changed:
+            loaded_state_dict.pop("optimizer_states", None)
+

 def custom_padding(x, org_pad, custom_pad):
     """
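With the class registered through register_checkpoint_hooks and the loader marked with mark_as_loader, SpeechBrain's Checkpointer dispatches to _on_load_checkpoint when it recovers this module from a checkpoint, rather than using the generic torch recovery. A hedged usage sketch, assuming the standard Checkpointer API; the constructor arguments, directory names, and the "model" key are illustrative assumptions rather than values from the commit:

# Usage sketch, not from the commit: exercise the custom loader hook through
# SpeechBrain's Checkpointer.
from speechbrain.utils.checkpoints import Checkpointer
from speechbrain.lobes.models.huggingface_transformers.speecht5 import SpeechT5ForASR

asr_model = SpeechT5ForASR(
    source="microsoft/speecht5_asr",            # illustrative HF source
    save_path="pretrained_models/speecht5",     # illustrative save directory
)
checkpointer = Checkpointer("results/ckpts", recoverables={"model": asr_model})

checkpointer.save_checkpoint()      # presumably uses the default torch saver (no saver hook is marked)
checkpointer.recover_if_possible()  # recovery calls _on_load_checkpoint(path, end_of_epoch)

The Checkpointer passes the checkpoint file path and the end_of_epoch flag to the hook, which matches the signature the commit declares under mark_as_loader.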
