In [None]:
import nemo, nemo_asr
from nemo_asr.helpers import post_process_predictions
import numpy as np
import pyaudio as pa
import time

In [None]:
# the checkpoints are available from NGC: https://ngc.nvidia.com/catalog/models/nvidia:quartznet15x5
# All three paths are relative to the directory the notebook is run from;
# adjust them if the checkpoint archive was unpacked somewhere else.
# Model definition (YAML) that is parsed below for labels, sample rate and module parameters.
MODEL_YAML = '../NeMo-CHECKPOINTS/0.8.2/quartznet15x5/quartznet15x5.yaml'
# Encoder weights, restored into the JasperEncoder module.
CHECKPOINT_ENCODER = '../NeMo-CHECKPOINTS/0.8.2/quartznet15x5/JasperEncoder-STEP-247400.pt'
# Decoder weights, restored into the JasperDecoderForCTC module.
CHECKPOINT_DECODER = '../NeMo-CHECKPOINTS/0.8.2/quartznet15x5/JasperDecoderForCTC-STEP-247400.pt'

In [None]:
from ruamel.yaml import YAML

# Parse the model definition with the "safe" loader (plain Python containers only).
yaml = YAML(typ="safe")
with open(MODEL_YAML) as model_yaml_file:
    jasper_model_definition = yaml.load(model_yaml_file)

# Output alphabet of the acoustic model, as listed in the YAML.
labels = jasper_model_definition['labels']

# Override three preprocessing settings for streaming inference.
# NOTE(review): presumably this disables random dither and length padding and
# switches to a fixed normalisation so that every chunk is featurised the same
# way - confirm against the AudioPreprocessing documentation.
jasper_model_definition['AudioPreprocessing'].update({
    'dither': 0,
    'pad_to': 0,
    'normalize': 'fixed',
})

In [None]:
# Factory that owns the inference session: PyTorch backend, modules placed on the GPU.
neural_factory = nemo.core.NeuralModuleFactory(
    backend=nemo.core.Backend.PyTorch,
    placement=nemo.core.DeviceType.GPU,
)

In [None]:
from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import NeuralType, BatchTag, TimeTag, AxisType
import torch

class AudioDataLayer(DataLayerNM):
    """
    Data layer that serves exactly one in-memory audio buffer per inference run.

    The layer acts as its own iterator: after ``set_signal`` it yields a single
    ``(audio_signal, a_sig_length)`` pair and is then exhausted until the next
    ``set_signal`` call re-arms it.
    """

    @staticmethod
    def create_ports():
        """Declare the ports: no inputs, a (batch, time) signal and a (batch,) length."""
        outputs = {
            "audio_signal": NeuralType({0: AxisType(BatchTag),
                                        1: AxisType(TimeTag)}),
            "a_sig_length": NeuralType({0: AxisType(BatchTag)}),
        }
        return {}, outputs

    def __init__(self, **kwargs):
        DataLayerNM.__init__(self, **kwargs)
        # True while there is a batch that has not been handed out yet.
        self.output = True

    def __len__(self):
        return 1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the pending batch once, then signal exhaustion."""
        if self.output:
            # Disarm first so the same buffer is never served twice.
            self.output = False
            audio = torch.as_tensor(self.signal, dtype=torch.float32)
            length = torch.as_tensor(self.signal_shape, dtype=torch.int64)
            return audio, length
        raise StopIteration

    def set_signal(self, signal):
        """
        Store ``signal`` as the next batch and re-arm the iterator.

        The samples are cast to float32, divided by 32768 and reshaped to a
        single-row batch; the stored length is the total number of samples.
        (Presumably the input holds 16-bit PCM sample values, so the division
        maps them to roughly [-1, 1).)
        """
        scaled = signal.astype(np.float32) / 32768.
        self.signal = scaled.reshape(1, -1)
        self.signal_shape = np.array([self.signal.size], dtype=np.int64)
        self.output = True

    @property
    def dataset(self):
        # No dataset object: batches come from ``data_iterator`` instead.
        return None

    @property
    def data_iterator(self):
        # The layer iterates over itself (see __iter__ / __next__).
        return self
    
# Instantiate necessary neural modules
# Source of audio: a single buffer pushed in through data_layer.set_signal().
data_layer = AudioDataLayer()

# Feature extraction, configured from the (overridden) YAML section.
data_preprocessor = nemo_asr.AudioPreprocessing(
    factory=neural_factory,
    **jasper_model_definition["AudioPreprocessing"])

# Encoder input width must match the number of features the preprocessor emits.
jasper_encoder = nemo_asr.JasperEncoder(
    feat_in=jasper_model_definition["AudioPreprocessing"]["features"],
    **jasper_model_definition["JasperEncoder"])

# Decoder input width is the filter count of the last encoder block.
# num_classes is taken from len(labels) at the time this cell runs, so `labels`
# must still hold exactly the alphabet from the YAML here.
jasper_decoder = nemo_asr.JasperDecoderForCTC(
    feat_in=jasper_model_definition["JasperEncoder"]["jasper"][-1]["filters"],
    num_classes=len(labels))

greedy_decoder = nemo_asr.GreedyCTCDecoder()

# Load the pretrained weights into the encoder and decoder.
jasper_encoder.restore_from(CHECKPOINT_ENCODER)
jasper_decoder.restore_from(CHECKPOINT_DECODER)

# Define inference DAG
# data layer -> preprocessor -> encoder -> CTC decoder -> greedy decoder
audio_signal, audio_signal_len = data_layer()
processed_signal, processed_signal_len = data_preprocessor(
    input_signal=audio_signal,
    length=audio_signal_len)
encoded, encoded_len = jasper_encoder(audio_signal=processed_signal,
                                      length=processed_signal_len)
# log_probs is the tensor requested by infer_signal() further down.
log_probs = jasper_decoder(encoder_output=encoded)
predictions = greedy_decoder(log_probs=log_probs)

In [None]:
def infer_signal(self, signal):
    """
    Run the inference graph on one audio buffer.

    Pushes ``signal`` into the module-level ``data_layer``, evaluates the
    module-level ``log_probs`` tensor through ``self.infer`` and returns the
    first evaluated batch of the first (and only) requested tensor.
    """
    data_layer.set_signal(signal)
    evaluated = self.infer([log_probs], verbose=False)
    # [0] -> the only requested tensor, [0] -> its first evaluated batch
    return evaluated[0][0]

# Attach the function to this particular factory instance as a bound method.
neural_factory.infer_signal = infer_signal.__get__(neural_factory)

In [None]:
def softmax(x):
    '''
    Softmax over the last axis for NumPy arrays.

    The per-row maximum is subtracted before exponentiating, so large inputs
    do not overflow. The result has the same shape as the input.
    '''
    row_max = np.max(x, axis=-1, keepdims=True)
    exps = np.exp(x - row_max)
    return exps / exps.sum(axis=-1, keepdims=True)


class FrameASR:
    '''
    Frame-by-frame (streaming) recognizer on top of an inference callable.

    Every incoming frame is appended to a rolling buffer laid out as
    [overlap | frame | overlap]; the whole buffer is sent through
    ``neural_factory.infer_signal`` and only the central part of the output is
    decoded greedily, which gives the model context on both sides of the frame.
    '''

    def __init__(self, neural_factory, jasper_model_definition,
                 frame_len=2, frame_overlap=2.5, 
                 timestep_duration=0.02, offset=5):
        '''
        Args:
          neural_factory: object providing ``infer_signal(signal)``; the value
            it returns must support ``.cpu().numpy()``
          jasper_model_definition: model definition dict; the keys 'labels'
            and 'sample_rate' are read
          frame_len: frame's duration, seconds
          frame_overlap: duration of overlaps before and after current frame, seconds
          timestep_duration: time per step at model's output, seconds
          offset: number of trailing characters dropped from every decoded frame
        '''
        # Copy the labels before adding the blank symbol: the list in the
        # model definition is shared with the rest of the notebook, and
        # appending to it in place would grow it on every instantiation.
        self.vocab = list(jasper_model_definition['labels']) + ['_']
        self.neural_factory = neural_factory

        self.sr = jasper_model_definition['sample_rate']
        self.frame_len = frame_len
        self.n_frame_len = int(frame_len * self.sr)
        self.frame_overlap = frame_overlap
        self.n_frame_overlap = int(frame_overlap * self.sr)
        # Output steps trimmed from each end of the logits before decoding.
        # NOTE(review): the extra "- 2" looks like a safety margin - confirm.
        self.n_timesteps_overlap = int(frame_overlap / timestep_duration) - 2
        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len, dtype=np.float32)
        self.offset = offset
        self.reset()

    def _decode(self, frame, offset=0):
        '''
        Shift ``frame`` into the rolling buffer, run inference on the buffer
        and return the per-timestep characters of the central region with the
        last ``offset`` characters removed.

        ``frame`` must contain exactly ``n_frame_len`` samples.
        '''
        assert len(frame)==self.n_frame_len
        # Slide the buffer left by one frame and place the new frame at the end.
        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]
        self.buffer[-self.n_frame_len:] = frame
        # Use the factory handed to __init__ rather than a module-level global.
        logits = self.neural_factory.infer_signal(self.buffer).cpu().numpy()[0]
        decoded = self._greedy_decoder(
            logits[self.n_timesteps_overlap:-self.n_timesteps_overlap], 
            self.vocab
        )
        return decoded[:len(decoded)-offset]

    def transcribe(self, frame=None, merge=True):
        '''
        Transcribe one frame of audio.

        Args:
          frame: 1-D array of samples; ``None`` means a frame of silence.
            Frames shorter than ``n_frame_len`` are zero-padded at the end;
            longer frames fail the length assertion in ``_decode``.
          merge: when True, collapse repeated characters and drop blanks
            ('_'); when False, return the raw per-timestep string.

        Returns:
          The decoded text for this frame.
        '''
        if frame is None:
            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)
        if len(frame) < self.n_frame_len:
            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')
        unmerged = self._decode(frame, self.offset)
        if not merge:
            return unmerged
        return self.greedy_merge(unmerged)

    def _calibrate_offset(self, wav_file, max_offset=10, n_calib_inter=10):
        '''
        Calibrate offset for frame-by-frame decoding

        Decodes consecutive frames of ``wav_file`` (read with
        ``scipy.io.wavfile.read``) and, for each pair of neighbouring frames,
        looks for the longest non-blank overlap of at most ``max_offset``
        characters between the end of the previous output and the start of the
        current one. The most frequent overlap length becomes ``self.offset``;
        if no overlap is found at all, ``self.offset`` is left unchanged.
        '''
        # Local imports: these names are needed only here and were never
        # imported at file level.
        from collections import defaultdict
        from scipy.io import wavfile

        _, signal = wavfile.read(wav_file)

        # warmup: fill the rolling buffer with real audio before measuring
        n_warmup = 1 + int(np.ceil(2.0 * self.frame_overlap / self.frame_len))
        for i in range(n_warmup):
            decoded = self._decode(signal[self.n_frame_len*i:self.n_frame_len*(i+1)], offset=0)

        i = n_warmup

        offsets = defaultdict(int)
        while i < n_warmup + n_calib_inter and (i+1)*self.n_frame_len < signal.shape[0]:
            decoded_prev = decoded
            decoded = self._decode(signal[self.n_frame_len*i:self.n_frame_len*(i+1)], offset=0)
            # Try the longest candidate first and stop at the first match.
            for offset in range(max_offset, 0, -1):
                if decoded[:offset] == decoded_prev[-offset:] and decoded[:offset] != '_'*offset:
                    offsets[offset] += 1
                    break
            i += 1
        # max() on an empty mapping raises ValueError; keep the current offset
        # when the calibration audio produced no usable overlap.
        if offsets:
            self.offset = max(offsets, key=offsets.get)

    def reset(self):
        '''
        Reset frame_history and decoder's state
        '''
        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)
        self.prev_char = ''

    @staticmethod
    def _greedy_decoder(logits, vocab):
        '''Map every timestep (row of ``logits``) to its highest-scoring symbol.'''
        return ''.join(vocab[np.argmax(logits[i])] for i in range(logits.shape[0]))

    def greedy_merge(self, s):
        '''
        Collapse runs of identical characters and drop blanks ('_').

        ``self.prev_char`` is carried over between calls so that a character
        continuing across a frame boundary is not emitted twice.
        '''
        merged = []
        for char in s:
            if char != self.prev_char:
                self.prev_char = char
                if char != '_':
                    merged.append(char)
        return ''.join(merged)

In [None]:
# Streaming configuration for the microphone capture below.
FRAME_LEN = 0.5                      # seconds of audio handed to the recognizer per chunk
CHANNELS = 1                         # mono capture
RATE = 16000                         # capture sample rate, Hz
CHUNK_SIZE = int(FRAME_LEN*RATE)     # samples per chunk

# Recognizer with 2 s of context on each side of every chunk; the last 10
# decoded characters of each chunk are dropped (see FrameASR's `offset`).
asr = FrameASR(
    neural_factory,
    jasper_model_definition,
    frame_len=FRAME_LEN,
    frame_overlap=2.0,
    offset=10,
)
asr.reset()

In [None]:
p = pa.PyAudio()

signal = np.zeros(CHUNK_SIZE)
# Number of silent chunks still to wait before a word separator is printed.
empty_counter = 0

def callback(in_data, frame_count, time_info, status):
    """
    PyAudio stream callback: transcribe one captured chunk and print the text.

    After three consecutive chunks without text a single space is printed, so
    pauses in speech separate the output. The captured bytes are returned
    unchanged together with ``pa.paContinue`` to keep the stream running.
    """
    global empty_counter
    samples = np.frombuffer(in_data, dtype=np.int16)
    text = asr.transcribe(samples)
    if len(text):
        # flush so the text shows up immediately although it is printed
        # from the audio callback rather than from the cell's main flow
        print(text, end='', flush=True)
        empty_counter = 3
    elif empty_counter > 0:
        empty_counter -= 1
        if empty_counter == 0:
            print(' ', end='', flush=True)

    return (in_data, pa.paContinue)

stream = None
try:
    print('Available audio input devices:')
    for i in range(p.get_device_count()):
        dev = p.get_device_info_by_index(i)
        if dev.get('maxInputChannels'):
            print(i, dev.get('name'))
    print('Please type input device ID:')
    dev_idx = int(input())

    stream = p.open(format=pa.paInt16,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    input_device_index=dev_idx,
                    stream_callback=callback,
                    frames_per_buffer=CHUNK_SIZE)

    print('Listening...')

    stream.start_stream()

    # Audio is processed in the callback; this loop only keeps the cell alive.
    while stream.is_active():
        time.sleep(0.1)
finally:
    # Always release the audio device, including when the cell is interrupted
    # (the usual way to stop listening) or when opening the stream fails.
    if stream is not None:
        stream.stop_stream()
        stream.close()
    p.terminate()