In [1]:
import sys
import torch
import speechbrain as sb
from torch.utils.data import DataLoader
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.parameter_transfer import Pretrainer

In [2]:
hparams_file = 'hparams/conformer_medium.yaml'

In [3]:
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin)

In [4]:
# model = hparams['model']

In [5]:
# Define training procedure
class ASR(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches
        to the output probabilities."""
        # print(f'compute_forward ----- 1')
        # print(f'type of batch : {batch}')
        batch = batch.to(self.device)
        # print(f'compute_forward ----- 2')
        wavs, wav_lens = batch.sig
        # print(f'wavs, wav_lens : {wavs}, {wav_lens}')
        # print(f'compute_forward ----- 3')
        # print(f'wavs : {wavs}')
        # print(f'wav_lens : {wav_lens}')
        tokens_bos, _ = batch.tokens_bos

        # Add augmentation if specified
        ### kdialectspeech, ksponspeech, librispeech 에서는 사용안함, template 예제에서 사용
        if stage == sb.Stage.TRAIN:
            if hasattr(self.modules, "env_corrupt"):
                wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
                wavs = torch.cat([wavs, wavs_noise], dim=0)
                wav_lens = torch.cat([wav_lens, wav_lens])
                tokens_bos = torch.cat([tokens_bos, tokens_bos], dim=0)

        # compute features
        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

        if stage == sb.Stage.TRAIN:
            if hasattr(self.hparams, "augmentation"):
                feats = self.hparams.augmentation(feats)

        # forward modules
        src = self.modules.CNN(feats)
        # print(f'tokens_bos : {tokens_bos}')
        # print(f'pad_idx : {self.hparams.pad_index}')
        enc_out, pred = self.modules.Transformer( # pred : decoder out
            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
        )

        # output layer for ctc log-probabilities
        logits = self.modules.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # output layer for seq2seq log-probabilities
        pred = self.modules.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # print(f'enc_out size : {enc_out.size()}')
        # Compute outputs
        hyps = None
        if stage == sb.Stage.TRAIN:
            hyps = None
        elif stage == sb.Stage.VALID:
            hyps = None
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch % self.hparams.valid_search_interval == 0:
                # for the sake of efficiency, we only perform beamsearch with
                # limited capacity and no LM to give user some idea of
                # how the AM is doing
                ####
                #### 시간이 많이 걸리는 부분 : 아래 valid_search
                ####
                # print(f' valid enc_out size : {enc_out.size()}')
                # print(f' valid wav_lens : {wav_lens}')
                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
                # print(f' valid hyps : {hyps}')
        elif stage == sb.Stage.TEST:
            # print(f'compute_forward ----- 4')
            # print(f' test enc_out size : {enc_out.size()}')
            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens) # test_search와 valid_search의 차이는 LM 사용 여부
            # print(f' test hyps : {hyps}')
            # print(f'compute_forward ----- 5')
        # print(f'compute_forward ------------------------------------')
        # print(f'compute_forward p_ctc ----- : {p_ctc}')
        # print(f'compute_forward p_seq ----- : {p_seq}')
        # print(f'compute_forward wav_lens ----- : {wav_lens}')
        # print(f'compute_forward hyps ----- : {hyps}')
        return p_ctc, p_seq, wav_lens, hyps

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets."""
        
        # print(f'compute_objectives ----- 1')
        (p_ctc, p_seq, wav_lens, hyps,) = predictions

        ids = batch.id
        # print(f'compute_objectives ids : {ids}')
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens
        
        # logger.info(f'compute_objectives tokens.size ----- : {tokens.size()}') # npark

        if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
            tokens_eos_lens = torch.cat(
                [tokens_eos_lens, tokens_eos_lens], dim=0
            )
            tokens = torch.cat([tokens, tokens], dim=0)
            tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)


        # print(f' compute_objectives tokens_eos : {tokens_eos}')
        # print(f' compute_objectives p_seq : {p_seq}')
        loss_seq = self.hparams.seq_cost(
            p_seq, tokens_eos, length=tokens_eos_lens
        )
        loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        loss = (
            self.hparams.ctc_weight * loss_ctc
            + (1 - self.hparams.ctc_weight) * loss_seq
        )
        if stage != sb.Stage.TRAIN:
            # print(f'compute_objectives stage is not train -------------')
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or (
                stage == sb.Stage.TEST
            ):
                # Decode token terms to words
                predicted_words = [
                    tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps
                ]
                target_words = [wrd.split(" ") for wrd in batch.wrd]

                ### predicted_swords = get_swords(hyps, wrd) -> space normalized words

                predicted_chars = [
                    list("".join(utt_seq)) for utt_seq in predicted_words
                ]
                target_chars = [list("".join(wrd.split())) for wrd in batch.wrd]
                self.wer_metric.append(ids, predicted_words, target_words)
                # self.swer_metric.append(ids, predicted_swords, target_words)
                self.cer_metric.append(ids, predicted_chars, target_chars)

            # compute the accuracy of the one-step-forward prediction
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
            
        # logger.info(f'compute_objectives loss ----- : {loss}') # npark
        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        # check if we need to switch optimizer
        # if so change the optimizer from Adam to SGD
        
        # print(f'train length of batch : {len(batch)}')
        self.check_and_reset_optimizer()

        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # normalize the loss by gradient_accumulation step
        (loss / self.hparams.gradient_accumulation).backward()

        if self.step % self.hparams.gradient_accumulation == 0:
            # gradient clipping & early stop if loss is not fini
            self.check_gradients(loss)

            self.optimizer.step()
            self.optimizer.zero_grad()

            # anneal lr every update
            self.hparams.noam_annealing(self.optimizer)

            if isinstance(
                self.hparams.train_logger,
                sb.utils.train_logger.TensorboardLogger,
            ):
                self.hparams.train_logger.log_stats(
                    stats_meta={"step": self.step}, train_stats={"loss": loss},
                )

        return loss.detach()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        # print(f'stage : {stage}')
        # print(f'length of batch : {len(batch)}')
        # print(f'batch.id type -------- : {type(batch.id)}')
        # print(f'batch.id -------- : {batch.id}')
        # print(f'batch sig type : {type(batch.sig)}')
        # print(f'batch sig : {batch.sig[0]}')
        # print(f'batch sig size : {batch.sig[0].size()}')
        
        
        # for k, v in batch.sig:
        #     print(k)
        #     print(v)
        
        
        with torch.no_grad():
            # print('########## compute_forward #########')
            predictions = self.compute_forward(batch, stage=stage)
            # print(f'########## compute_objectives ######### stage : {stage}')
            loss = self.compute_objectives(predictions, batch, stage=stage)
            # print('########## eval end #########')
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.wer_metric = self.hparams.error_rate_computer()
            # self.swer_metric = self.hparams.error_rate_computer()
            self.cer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a epoch."""
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if (
                current_epoch % valid_search_interval == 0
                or stage == sb.Stage.TEST
            ):
                stage_stats["WER"] = self.wer_metric.summarize("error_rate")
                # stage_stats["sWER"] = self.swer_metric.summarize("error_rate")
                stage_stats["CER"] = self.cer_metric.summarize("error_rate")

        # log stats and save checkpoint at end-of-epoch
        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():

            # report different epoch stages according current stage
            current_epoch = self.hparams.epoch_counter.current
            if current_epoch <= self.hparams.stage_one_epochs:
                lr = self.hparams.noam_annealing.current_lr
                steps = self.hparams.noam_annealing.n_steps
            else:
                lr = self.hparams.lr_sgd
                steps = -1

            epoch_stats = {"epoch": epoch, "lr": lr, "steps": steps}
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=5,
            )

        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w") as w:
                # self.swer_metric.write_stats(w)
                self.wer_metric.write_stats(w)
                self.cer_metric.write_stats(w)

            # save the averaged checkpoint at the end of the evaluation stage
            # delete the rest of the intermediate checkpoints
            # ACC is set to 1.1 so checkpointer
            # only keeps the averaged checkpoint
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )

    def check_and_reset_optimizer(self):
        """reset the optimizer if training enters stage 2"""
        current_epoch = self.hparams.epoch_counter.current
        if not hasattr(self, "switched"):
            self.switched = False
            if isinstance(self.optimizer, torch.optim.SGD):
                self.switched = True

        if self.switched is True:
            return

        if current_epoch > self.hparams.stage_one_epochs:
            self.optimizer = self.hparams.SGD(self.modules.parameters())

            if self.checkpointer is not None:
                self.checkpointer.add_recoverable("optimizer", self.optimizer)

            self.switched = True

    def on_fit_start(self):
        """Initialize the right optimizer on the training start"""
        super().on_fit_start()

        # if the model is resumed from stage two, reinitialize the optimizer
        current_epoch = self.hparams.epoch_counter.current
        current_optimizer = self.optimizer
        if current_epoch > self.hparams.stage_one_epochs:
            del self.optimizer
            self.optimizer = self.hparams.SGD(self.modules.parameters())

            # Load latest checkpoint to resume training if interrupted
            if self.checkpointer is not None:

                # do not reload the weights if training is interrupted
                # right before stage 2
                group = current_optimizer.param_groups[0]
                if "momentum" not in group:
                    return

                self.checkpointer.recover_if_possible(
                    device=torch.device(self.device)
                )

    def on_evaluate_start(self, max_key=None, min_key=None):
        """perform checkpoint averge if needed"""
        super().on_evaluate_start()

        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model", device=self.device
        )

        self.hparams.model.load_state_dict(ckpt, strict=True)
        self.hparams.model.eval()

In [6]:
# Trainer initialization
asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    # run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

In [7]:
asr_brain.tokenizer = hparams["tokenizer"]

In [8]:
pretrained_model = '../Inference/pretrained-model-src/kspon/asr.ckpt'
pretrain = Pretrainer(collect_in='model_local', loadables={'model': asr_brain}, paths={'model': pretrained_model})
# pretrain = Pretrainer(collect_in='model_local', loadables={'model': model}, paths={'model': pretrained_model})

In [9]:
pretrain.collect_files()

{'model': PosixPath('model_local/model.ckpt')}

In [10]:
pretrain.load_collected()

load_colleted name : model
load_colleted PARAMFILE_EXT : .ckpt
load_colleted paramfiles[name] : model_local/model.ckpt
Pretrainer _call_load_hooks name : model
DEFAULT_TRANSFER_HOOKS ----- : {<class 'torch.nn.modules.module.Module'>: <function torch_parameter_transfer at 0x7f26850f9940>, <class 'sentencepiece.SentencePieceProcessor'>: <function _load_spm at 0x7f26850f99d0>, <class 'speechbrain.processing.features.InputNormalization'>: <function InputNormalization._load at 0x7f264996f670>}
DEFAULT_LOAD_HOOKS ----- : {<class 'torch.nn.modules.module.Module'>: <function torch_recovery at 0x7f26850f9820>, <class 'torch.optim.optimizer.Optimizer'>: <function torch_recovery at 0x7f26850f9820>, <class 'torch.optim.lr_scheduler._LRScheduler'>: <function torch_recovery at 0x7f26850f9820>, <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>: <function torch_recovery at 0x7f26850f9820>, <class 'torch.cuda.amp.grad_scaler.GradScaler'>: <function torch_recovery at 0x7f26850f9820>, <class 'torch.op

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte