From 1fbd69d0f0e56fc8efe1143e9d84168aeb24a540 Mon Sep 17 00:00:00 2001
From: Haroun Elleuch
Date: Mon, 4 Mar 2024 14:45:46 +0100
Subject: [PATCH] Custom checkpoint loading hook

---
 speechbrain/decoders/seq2seq.py               |  7 ++--
 .../huggingface_transformers/speecht5.py      | 39 +++++++++++++++++--
 2 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/speechbrain/decoders/seq2seq.py b/speechbrain/decoders/seq2seq.py
index f60f43a9b1..54ea94d644 100644
--- a/speechbrain/decoders/seq2seq.py
+++ b/speechbrain/decoders/seq2seq.py
@@ -7,7 +7,6 @@
  * Mirco Ravanelli 2020
  * Sung-Lin Yeh 2020
 """
-
 import torch
 from speechbrain.decoders.utils import (
     inflate_tensor,
@@ -18,7 +17,7 @@
 
 
 class AlivedHypotheses(torch.nn.Module):
-    """This class handle the data for the hypotheses during the decoding.
+    """ This class handle the data for the hypotheses during the decoding.
 
     Arguments
     ---------
@@ -721,7 +720,7 @@ def _max_attn_shift_permute_memory_step(self, prev_attn_peak, predecessors):
         return prev_attn_peak
 
     def _update_reset_memory(self, enc_states, enc_lens):
-        """Call reset memory for each module.
+        """ Call reset memory for each module.
 
         Arguments
         ---------
@@ -1813,7 +1812,7 @@ def set_n_out(self):
     def forward_step(self, inp_tokens, memory, enc_states, enc_lens):
         """Performs a step in the implemented beamsearcher."""
         memory = _update_mem(inp_tokens, memory)
-        (dec_out, attn,) = self.model.forward_decoder(enc_states, memory)
+        dec_out, attn, = self.model.forward_decoder(enc_states, memory)
         log_probs = self.softmax(dec_out[:, -1] / self.temperature)
         return log_probs, memory, attn
 
diff --git a/speechbrain/lobes/models/huggingface_transformers/speecht5.py b/speechbrain/lobes/models/huggingface_transformers/speecht5.py
index 5739b570d0..0bec7ad09b 100644
--- a/speechbrain/lobes/models/huggingface_transformers/speecht5.py
+++ b/speechbrain/lobes/models/huggingface_transformers/speecht5.py
@@ -8,6 +8,7 @@
  * Haroun Elleuch 2024
 """
 
+import pathlib
 import torch
 import logging
 
@@ -16,12 +17,17 @@
 )
 
 from transformers import SpeechT5ForSpeechToText, SpeechT5Config
 
+from speechbrain.utils.checkpoints import (
+    mark_as_loader,
+    register_checkpoint_hooks,
+)
 from speechbrain.utils.fetching import fetch
 
 logger = logging.getLogger(__name__)
 
 
+@register_checkpoint_hooks
 class SpeechT5ForASR(HFTransformersInterface):
     """This lobe enables the integration of HuggingFace and SpeechBrain
     pretrained SpeechT5 models for Automatic Speech Recognition.
@@ -286,16 +292,19 @@ def _from_pretrained(self, source: str, save_path: str, cache_dir: str):
         is_sb, ckpt_file, _ = self._check_model_source(source, save_path)
 
         if is_sb or self.for_pretraining:
-            self.model = SpeechT5ForSpeechToText.from_config(self.config)
+            self.model = SpeechT5ForSpeechToText._from_config(self.config)
             if is_sb:
                 self.model.gradient_checkpointing_disable()  # Required by DDP
                 # fetch the checkpoint file
                 ckpt_full_path = fetch(
-                    filename=ckpt_file, source=source, savedir=save_path,
+                    filename=ckpt_file,
+                    source=source,
+                    savedir=save_path,
+                    huggingface_cache_dir=cache_dir,
                 )
                 # We transfer the parameters from the checkpoint.
-                self._load_sb_pretrained_parameters(ckpt_full_path)
+                self._load_sb_pretrained_parameters(path=ckpt_full_path,)
         elif not self.for_pretraining:
             self.model = SpeechT5ForSpeechToText.from_pretrained(
                 source,
@@ -305,6 +314,30 @@ def _from_pretrained(self, source: str, save_path: str, cache_dir: str):
                 ignore_mismatched_sizes=True,
             )
 
+    @mark_as_loader
+    def _on_load_checkpoint(
+        self, path: pathlib.Path | str, end_of_epoch: bool
+    ) -> None:
+        loaded_state_dict = torch.load(path)
+        model_state_dict = self.state_dict()
+        is_changed = False
+        for k in loaded_state_dict:
+            if k in model_state_dict:
+                if loaded_state_dict[k].shape != model_state_dict[k].shape:
+                    logger.warning(
+                        f"Skip loading parameter: {k}, "
+                        f"required shape: {model_state_dict[k].shape}, "
+                        f"loaded shape: {loaded_state_dict[k].shape}"
+                    )
+                    loaded_state_dict[k] = model_state_dict[k]
+                    is_changed = True
+            else:
+                logger.warning(f"Dropping parameter {k}")
+                is_changed = True
+
+        if is_changed:
+            loaded_state_dict.pop("optimizer_states", None)
+
 
 def custom_padding(x, org_pad, custom_pad):
     """
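
A minimal sketch of how the new loader hook can be exercised through SpeechBrain's Checkpointer. The constructor arguments, directories, and the HuggingFace model id below are assumptions chosen for illustration, not taken from this patch; adapt them to the actual lobe signature.

from speechbrain.lobes.models.huggingface_transformers.speecht5 import (
    SpeechT5ForASR,
)
from speechbrain.utils.checkpoints import Checkpointer

# Assumed constructor arguments (source and save_path, as suggested by
# _from_pretrained); both values are placeholders.
asr_model = SpeechT5ForASR(
    source="microsoft/speecht5_asr",         # placeholder HF source
    save_path="pretrained_models/speecht5",  # placeholder save directory
)

# Saving still goes through the default torch.nn.Module saver; only loading
# is customized by the @mark_as_loader hook registered on the class.
checkpointer = Checkpointer(
    checkpoints_dir="results/save",          # placeholder directory
    recoverables={"speecht5": asr_model},
)
checkpointer.save_checkpoint()

# On recovery, the Checkpointer calls _on_load_checkpoint(path, end_of_epoch)
# instead of the default loader, so shape-mismatched parameters are logged
# and skipped rather than raising an error.
checkpointer.recover_if_possible()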