In [1]:
import pathlib
from argparse import ArgumentParser

import sentencepiece as spm

import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
import json
from torchaudio.models import Hypothesis, RNNTBeamSearch
from typing import List, Tuple
import math
from IPython.display import Audio


In [2]:
# SentencePiece model defining the token vocabulary the RNN-T checkpoint was trained with.
# NOTE(review): machine-specific absolute path; adjust for your environment.
SPM_MODEL_PATH = '/home/wonkyum/fc-asr/spm_unigram_1023.model'
sp_model = spm.SentencePieceProcessor(model_file=SPM_MODEL_PATH)

In [3]:
# Lightning checkpoint to restore (machine-specific absolute path).
checkpoint_path = (
    '/home/wonkyum/fc-asr/exp/checkpoints/'
    'epoch=12-step=185523.ckpt'
)

In [4]:
# Restore the trained Lightning module; sp_model is forwarded as a keyword argument.
rnnt_module = ConformerRNNTModule.load_from_checkpoint(
    checkpoint_path,
    sp_model=sp_model,
)
# Switch to evaluation mode (nn.Module.eval() works in place and returns self).
rnnt_module.eval()



In [5]:
# Index of the RNN-T blank symbol. NOTE(review): presumably equal to the
# SentencePiece vocabulary size (spm_unigram_1023) — confirm against training config.
BLANK_IDX = 1023

# Put the underlying RNN-T network on the GPU and wrap it in a beam-search decoder.
rnnt_module.model.to("cuda")
decoder = RNNTBeamSearch(rnnt_module.model, blank=BLANK_IDX)

In [6]:
def post_process_hypos(
    hypos: "List[Hypothesis]", sp_model: "spm.SentencePieceProcessor"
) -> List[Tuple[str, List[float], List[int]]]:
    """Convert raw beam-search hypotheses into ``(text, [probability], token_ids)`` tuples.

    Args:
        hypos: Hypotheses as returned by ``RNNTBeamSearch``. Element 0 of each
            hypothesis is its token-id list and element 3 is its score.
        sp_model: SentencePiece processor used to turn token ids back into text.

    Returns:
        One tuple per hypothesis, in input order:
          * the decoded text, with unk / eos / pad ids filtered out before decoding;
          * a one-element list holding ``exp(score)`` (presumably the score is a
            log-probability, so this is a probability — confirm);
          * the unfiltered token ids.
        The first token of every hypothesis is dropped from both the text and the
        ids (presumably the start/blank token the search seeds hypotheses with).
    """
    tokens_idx = 0
    score_idx = 3
    # Set for O(1) membership tests while filtering each token list.
    remove_ids = {
        sp_model.unk_id(),
        sp_model.eos_id(),
        sp_model.pad_id(),
    }

    nbest_batch = []
    for h in hypos:
        token_ids = h[tokens_idx][1:]
        text = sp_model.decode([t for t in token_ids if t not in remove_ids])
        # Score stays wrapped in a list to keep the historical return shape.
        nbest_batch.append((text, [math.exp(h[score_idx])], token_ids))

    return nbest_batch

In [7]:
def _piecewise_linear_log(x: torch.Tensor) -> torch.Tensor:
    """Scale ``x`` by the global ``_gain`` and compress it piecewise (log above e, /e below).

    Relies on the module-level ``_gain``, which is defined in a later cell and
    must exist before this function is called.

    The first statement rebinds ``x`` to the new tensor ``x * _gain``, so the
    caller's tensor is not modified by the in-place masked assignments below.

    NOTE(review): the second mask is re-evaluated on the tensor *after* the log
    assignment. Scaled entries in (e, e**e] therefore become log(x) / e rather
    than log(x), so the mapping is not continuous at e. This presumably has to
    stay identical to the training-time feature transform (the checkpoint and
    global_stats.json were produced with it) — confirm against transforms.py
    before "fixing" it, since a change here would shift features away from
    what the model saw during training.
    """
    x = x * _gain
    # Log-compress entries above e (mask taken on the scaled input).
    x[x > math.e] = torch.log(x[x > math.e])
    # Mask recomputed on the already-updated tensor — see NOTE in the docstring.
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


class FunctionalModule(torch.nn.Module):
    """Adapt a plain callable into a ``torch.nn.Module`` (e.g. for ``nn.Sequential``).

    ``forward`` simply applies the wrapped callable to its input.
    """

    def __init__(self, functional):
        super(FunctionalModule, self).__init__()
        self.functional = functional

    def forward(self, input):
        fn = self.functional
        return fn(input)

class GlobalStatsNormalization(torch.nn.Module):
    """Normalize features with precomputed global statistics: ``(input - mean) * invstddev``.

    The statistics are read once from a JSON file containing ``"mean"`` and
    ``"invstddev"`` entries. NOTE(review): these are presumably per-feature
    vectors that broadcast over the last dimension of ``input`` (e.g. mel
    bins) — confirm against the stats file.

    The statistics are registered as non-persistent buffers, so ``.to()`` /
    ``.cuda()`` / dtype casts move them together with the module, while
    ``state_dict()`` gains no new keys.
    """

    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path, encoding="utf-8") as f:
            blob = json.load(f)

        # Buffers rather than plain attributes: plain tensor attributes are not
        # moved by Module.to(), which would cause a device mismatch in forward.
        self.register_buffer("mean", torch.tensor(blob["mean"]), persistent=False)
        self.register_buffer("invstddev", torch.tensor(blob["invstddev"]), persistent=False)

    def forward(self, input):
        return (input - self.mean) * self.invstddev

In [8]:
# Feature-extraction constants. NOTE(review): presumably these mirror the
# training recipe exactly — confirm before changing any of them.
_INT16_MAX = torch.iinfo(torch.int16).max

# 2 * (20 * log10(int16 max)) dB; converting back gives _gain == _INT16_MAX ** 2
# (~1.07e9), used by _piecewise_linear_log to rescale mel power values.
_decibel = 40 * math.log10(_INT16_MAX)
_gain = 10 ** (0.05 * _decibel)

# 80-bin mel front end: 400-sample FFT / 160-sample hop, i.e. 25 ms / 10 ms at 16 kHz.
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=400,
    hop_length=160,
    n_mels=80,
)



def run_decoder(
    waveform,
    sample_rate: int = 16000,
    beam_width: int = 20,
    device: str = "cuda",
    global_stats_path: str = './global_stats.json',
):
    """Transcribe ``waveform`` with the global RNN-T beam-search ``decoder``.

    Args:
        waveform: Audio tensor as returned by ``torchaudio.load``; only the first
            channel (``waveform[0]``) is used. NOTE(review): presumably shaped
            (channels, time) — confirm for other sources.
        sample_rate: Sample rate of ``waveform``. The mel front end is configured
            for 16 kHz, so any other rate is resampled first. The default keeps
            the previous behavior (no resampling).
        beam_width: Beam size handed to the decoder (previously hard-coded 20).
        device: Device holding the decoder's model (previously hard-coded "cuda").
        global_stats_path: JSON file with the feature normalization statistics.

    Returns:
        Text of the first hypothesis returned by the decoder (presumably the
        best-scoring one), or ``""`` if the decoder returns no hypotheses.

    The normalization pipeline (including re-reading the stats JSON) is
    rebuilt on every call.
    """
    target_sample_rate = 16000  # must match _spectrogram_transform's sample_rate
    extra_pipeline = torch.nn.Sequential(
        FunctionalModule(_piecewise_linear_log),
        GlobalStatsNormalization(global_stats_path),
    )
    # Pure inference: without no_grad the feature/decoder calls would build
    # autograd graphs and hold activations in memory for nothing.
    with torch.no_grad():
        audio = waveform[0].squeeze()
        if sample_rate != target_sample_rate:
            audio = torchaudio.functional.resample(audio, sample_rate, target_sample_rate)
        # (n_mels, frames) -> (frames, n_mels)
        mel_f = _spectrogram_transform(audio).transpose(1, 0)
        feats = extra_pipeline(mel_f)
        lengths = torch.tensor(feats.shape[0])
        hypotheses = decoder(feats.to(device), lengths.to(device), beam_width)
    result = post_process_hypos(hypotheses, sp_model)
    if not result:
        return ""
    return result[0][0]






In [9]:
# Load the test utterance; torchaudio.load yields the waveform and its sample rate.
SPEECH_PATH = '/home/wonkyum/speech.wav'
my_wave_form, samplerate = torchaudio.load(SPEECH_PATH)

# Inline audio player so the clip can be checked by ear.
Audio(my_wave_form.numpy(), rate=samplerate)

In [10]:
# Transcribe the clip loaded above. NOTE(review): the mel front end is set up
# for 16 kHz and `samplerate` is not forwarded here — confirm speech.wav is 16 kHz.
run_decoder(waveform=my_wave_form)

'spitch system is fun'