In [1]:
import sys
import torch
import speechbrain as sb
from torch.utils.data import DataLoader
from hyperpyyaml import load_hyperpyyaml

In [2]:
hparams_file = 'hparams/5k_conformer_medium_infer_no_lm.yaml'

In [3]:
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin)

In [9]:
class ASR(sb.core.Brain):
    def compute_forward(self, batch):
        """Forward computations from the waveform batches
        to the output probabilities."""
        
        batch = batch.to(self.device)
        wavs = batch
        wav_lens = torch.tensor(1., device='cuda')

        # compute features
        feats = self.hparams.compute_features(wavs)
        print(f'feats size ---- : {feats.size()}')
        print(f'feats ---- : {feats}')
        current_epoch = self.hparams.epoch_counter.current
        print(f'current_epoch ----- : {current_epoch}')
        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)
        print(f'feats ----- : {feats}')

    def on_evaluate_start(self, max_key=None, min_key=None):
        """perform checkpoint averge if needed"""
        super().on_evaluate_start()

        print(f'self.checkpointer checkpoints_dir ----- : {self.checkpointer.checkpoints_dir}')
        # print(f'self.checkpointer ----- : {dir(self.checkpointer.checkpoints_dir)}')
        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        print(f'ckpts ----- : {ckpts}')
        # print(f'self.hparams.model.load_state_dict ----- : {self.hparams.model.load_state_dict}')
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model", device=self.device
        )

        self.hparams.model.load_state_dict(ckpt, strict=True)
        self.hparams.model.eval()

    ### for inferrence
    def transcribe_file(
            self,
            data_file,
            max_key, # We load the model with the lowest WER
        ):
        
        sig = sb.dataio.dataio.read_audio(data_file)
        print(f'sig ----- : {sig}')

        self.on_evaluate_start(max_key=max_key) # We call the on_evaluate_start that will load the best model
        # self.modules.eval() # We set the model to eval mode (remove dropout etc)

        # Now we iterate over the dataset and we simply compute_forward and decode
        with torch.no_grad():

            transcripts = []
            # for batch in tqdm(testdata, dynamic_ncols=True):
            batch = sig.unsqueeze(dim=0)
            out = self.compute_forward(batch)
            predicted_tokens = out

                # We go from tokens to words.
            tokenizer = hparams["tokenizer"]
            predicted_words = [
                tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in predicted_tokens
            ]
                
            print(f'label : {batch.wrd}')
            print(f'hyp ----- : {predicted_words}')

In [10]:
asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    checkpointer=hparams["checkpointer"],
)

# adding objects to trainer:
# asr_brain.tokenizer = hparams["tokenizer"]

In [11]:
audio_file = '/data/KsponSpeech/test/eval_clean/KsponSpeech_E02998.wav'

asr_brain.transcribe_file(
    audio_file, # Must be obtained from the dataio_function
    max_key="ACC", # We load the model with the lowest WER
    # loader_kwargs=hparams["test_dataloader_opts"], # opts for the dataloading
)

sig ----- : tensor([-0.0010, -0.0016, -0.0012,  ...,  0.0003,  0.0003,  0.0007])
self.checkpointer checkpoints_dir ----- : /data/models/0513/ascending/save
ckpts ----- : []
self.hparams.model.load_state_dict ----- : <bound method Module.load_state_dict of ModuleList(
  (0): InputNormalization()
  (1): ConvolutionFrontEnd(
    (convblock_0): ConvBlock(
      (convs): Sequential(
        (conv_0): Conv2d(
          (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))
        )
        (norm_0): LayerNorm(
          (norm): LayerNorm((40, 64), eps=1e-05, elementwise_affine=True)
        )
        (act_0): LeakyReLU(negative_slope=0.01)
        (dropout_0): Dropout(p=0.1, inplace=False)
      )
    )
    (convblock_1): ConvBlock(
      (convs): Sequential(
        (conv_0): Conv2d(
          (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(2, 2))
        )
        (norm_0): LayerNorm(
          (norm): LayerNorm((20, 32), eps=1e-05, elementwise_affine=True)
        )
        (act_0):

ValueError: No state dicts to average.

In [None]:
!ls ckpt