In [None]:
import sys
import torch
import speechbrain as sb
from torch.utils.data import DataLoader
from hyperpyyaml import load_hyperpyyaml

In [None]:
# Hyperparameter YAML used for inference (the file name suggests a 5k-token
# conformer-medium model decoded without a language model — confirm in the YAML).
hparams_file = "hparams/5k_conformer_medium_infer_no_lm.yaml"

In [None]:
# Load and instantiate every object declared in the YAML (the cells below read
# "modules", "Adam", "checkpointer" and "tokenizer" from it). Decode the file
# explicitly as UTF-8 so parsing does not depend on the platform locale.
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin)

In [None]:
class ASR(sb.core.Brain):
    """SpeechBrain ``Brain`` used here for inference only.

    Loads the average of the checkpoints selected by a metric key and decodes
    single audio files into words with ``hparams`` ``tokenizer``.
    """

    def compute_forward(self, batch):
        """Run the model from a waveform batch to decoded token hypotheses.

        Arguments
        ---------
        batch : torch.Tensor
            Batch of waveforms; ``transcribe_file`` passes a single utterance
            with a leading batch dimension of 1 (presumably ``(batch, time)``
            — TODO confirm against ``compute_features``).

        Returns
        -------
        hyps
            Token-id hypotheses, as the first value returned by
            ``hparams.test_search``.
        """
        wavs = batch.to(self.device)
        # Each utterance is used in full, so every relative length is 1.0.
        # One entry per batch item, created on the Brain's own device instead
        # of a hardcoded "cuda" that mismatches ``wavs`` on CPU/other GPUs.
        wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # compute features
        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

        # Encoder only: there are no target tokens at inference time, so the
        # teacher-forced ``Transformer(src, tokens_bos, ...)`` training call
        # cannot be used. The beam search drives the decoder itself.
        src = self.modules.CNN(feats)
        enc_out = self.modules.Transformer.encode(src, wav_lens)

        hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
        return hyps

    def on_evaluate_start(self, max_key=None, min_key=None):
        """Recover checkpoints and load the average of the selected ones.

        Arguments
        ---------
        max_key : str, optional
            Checkpoint metadata key where higher is better (e.g. "ACC").
        min_key : str, optional
            Checkpoint metadata key where lower is better (e.g. "WER").

        Raises
        ------
        ValueError
            If the checkpointer finds no checkpoints to average.
        """
        # Forward the keys so the base-class recovery uses the same selection
        # criterion as the averaging below.
        super().on_evaluate_start(max_key=max_key, min_key=min_key)

        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        if not ckpts:
            # Fail with a clear message instead of an obscure error from
            # inside the averaging helper.
            raise ValueError(
                f"No checkpoints found in {self.checkpointer.checkpoints_dir}"
            )
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model", device=self.device
        )

        self.hparams.model.load_state_dict(ckpt, strict=True)
        self.hparams.model.eval()

    ### for inference
    def transcribe_file(self, data_file, max_key, min_key=None):
        """Transcribe one audio file with the averaged selected checkpoints.

        Arguments
        ---------
        data_file : str
            Path to an audio file readable by ``sb.dataio.dataio.read_audio``.
        max_key : str or None
            Checkpoint key where higher is better (e.g. "ACC") used to select
            the checkpoints to load.
        min_key : str, optional
            Checkpoint key where lower is better (e.g. "WER").

        Returns
        -------
        list of list of str
            The decoded words for each utterance in the batch (here exactly one).
        """
        sig = sb.dataio.dataio.read_audio(data_file)

        self.on_evaluate_start(max_key=max_key, min_key=min_key)
        # on_evaluate_start only switches ``hparams.model`` to eval mode, and
        # other modules (e.g. ``normalize``) may not be part of it. Put every
        # module in eval mode so train-only behavior such as dropout is off.
        self.modules.eval()

        with torch.no_grad():
            # A batch containing the single utterance.
            batch = sig.unsqueeze(dim=0)
            predicted_tokens = self.compute_forward(batch)

        # We go from tokens to words.
        tokenizer = self.hparams.tokenizer
        predicted_words = [
            tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in predicted_tokens
        ]
        return predicted_words

In [None]:
# Build the Brain. ``run_opts`` selects the device explicitly: a "device" entry
# in the YAML is kept if present, otherwise the first GPU when available, else
# the CPU. Without it the Brain would fall back to its own default device
# (presumably CPU — see the SpeechBrain ``Brain`` docs), ignoring the GPU.
asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    run_opts={
        "device": hparams.get(
            "device", "cuda:0" if torch.cuda.is_available() else "cpu"
        )
    },
    checkpointer=hparams["checkpointer"],
)

In [None]:
# One evaluation utterance (machine-specific absolute path — adjust locally).
audio_file = '/data/KsponSpeech/eval_clean_wav/KsponSpeech_E02998.wav'

asr_brain.transcribe_file(
    audio_file,  # plain path to a single audio file
    max_key="ACC",  # select checkpoints by "ACC", where higher is better
)

In [None]:
!ls ckpt