In [None]:
# Copyright 2019 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Kaldi TRTIS Inference Demo

## Overview


This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.

This repository contains a TensorRT Inference Server (TRTIS, since renamed Triton Inference Server) custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. The integration makes Kaldi ASR inference easier to deploy by providing a gRPC streaming server, dynamic sequence batching, and multi-instance support. A client connects to the gRPC server, streams audio to it in chunks, and receives the inferred text in response. More information about the TensorRT Inference Server can be found in its [documentation](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).



### Learning objectives

This notebook demonstrates how to run inference with the Kaldi TRTIS backend server using a Python gRPC client in an offline setting: we stream pre-recorded .wav files to the inference server and receive the transcriptions back.

## Content
1. [Prerequisites](#1)
1. [Setup](#2)
1. [Audio helper classes](#3)
1. [Inference](#4)


<a id="1"></a>
## 1. Prerequisites

### 1.1 Docker containers
Follow the steps in the [README](README.md) to build the Kaldi server and client containers.

### 1.2 Hardware
This notebook can be executed on any CUDA-enabled NVIDIA GPU, although for efficient mixed precision inference, a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) is recommended (Volta, Turing or newer architectures).

In [None]:
!nvidia-smi \
    --query-gpu=index,name,driver_version,memory.total,uuid,mig.mode.current \
    --format=csv

### 1.3 Data download and preprocessing

The script `scripts/docker/launch_download.sh` downloads the LibriSpeech test dataset along with the Kaldi ASR models.

In [None]:
!ls /Kaldi/data
!ls /Kaldi/data/data/LibriSpeech

Within the Docker container, the output of the two `ls` commands above (the data and model directory, then the LibriSpeech directory) should look like this:

```
data  datasets	models
BOOKS.TXT     LICENSE.TXT  SPEAKERS.TXT  test-other
CHAPTERS.TXT  README.TXT   test-clean
```
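A minimal sketch for checking the layout programmatically, assuming the listing above is the complete set of expected entries. `missing_entries` and the two `EXPECTED_*` lists are hypothetical helpers written for this illustration, not part of the repository; an empty result means every expected name is present.

```python
import os

# Hypothetical expected entries, copied from the directory listing above
EXPECTED_DATA_ROOT = ["data", "datasets", "models"]
EXPECTED_LIBRISPEECH = ["BOOKS.TXT", "CHAPTERS.TXT", "LICENSE.TXT",
                        "README.TXT", "SPEAKERS.TXT", "test-clean", "test-other"]


def missing_entries(directory, expected):
    """Return the expected names that are not present in `directory`."""
    if not os.path.isdir(directory):
        # The directory itself is missing, so every entry is missing
        return list(expected)
    present = set(os.listdir(directory))
    return [name for name in expected if name not in present]


# Example use inside the container:
# print(missing_entries("/Kaldi/data", EXPECTED_DATA_ROOT))
# print(missing_entries("/Kaldi/data/data/LibriSpeech", EXPECTED_LIBRISPEECH))
```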

<a id="2"></a>
## 2. Setup
### 2.1 Import libraries and parameters

In [None]:
import os
import queue
import subprocess
from functools import partial
from io import StringIO

import grpc
import IPython.display as ipd
try:
    import librosa  # optional: only needed to resample or trim audio
except ImportError:
    librosa = None
import numpy as np
import pandas as pd
import soundfile
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

### 2.2 Set Client Variables

In [None]:
# Optional path to an input file (not used in this notebook)
FILE = None

# Inference server URL from LoadBalancer.
URL = "172.42.42.1:8001"

# Name of model
MODEL_NAME = "kaldi_online"

# Version of model
MODEL_VERSION = "1"

# Inference Batch Size
BATCH_SIZE = 1

# Inference / Server verbosity
VERBOSE = False

# Unique ID for a request
CORRELATION_ID = "1101"

# Kaldi expects samples in the range (2^15-1) x [-1, 1]; audio normalized
# to [-1, 1] is multiplied by this factor
WAV_SCALE_FACTOR = 2**15 - 1

### 2.3 Checking server status

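Triton provides several endpoints for querying the health status and metadata of both the server and individual models. The cell below uses the raw gRPC stub to check server liveness, server metadata, model readiness, model metadata, and model configuration.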

In [None]:
# Create gRPC stub for communicating with the server
channel = grpc.insecure_channel(URL)
grpc_stub = grpcclient.service_pb2_grpc.GRPCInferenceServiceStub(channel)

# Check Server Status
request = grpcclient.service_pb2.ServerLiveRequest()
response = grpc_stub.ServerLive(request)
print("server {}".format(response))

# Check Server Metadata
request = grpcclient.service_pb2.ServerMetadataRequest()
response = grpc_stub.ServerMetadata(request)
print("server metadata:\n{}".format(response))

# Check Model Status
request = grpcclient.service_pb2.ModelReadyRequest(
    name=MODEL_NAME, version=MODEL_VERSION
)
response = grpc_stub.ModelReady(request)
print("model {}".format(response))


# Check Model Metadata
request = grpcclient.service_pb2.ModelMetadataRequest(
    name=MODEL_NAME, version=MODEL_VERSION
)
response = grpc_stub.ModelMetadata(request)
print("model metadata:\n{}".format(response))


# Check Model Configuration - output may be verbose
request = grpcclient.service_pb2.ModelConfigRequest(
    name=MODEL_NAME, version=MODEL_VERSION
)
response = grpc_stub.ModelConfig(request)
print("model config:\n{}".format(response))
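The same liveness and readiness checks can also be done through the higher-level `tritonclient.grpc.InferenceServerClient`, which exposes `is_server_live()`, `is_server_ready()` and `is_model_ready(model_name, model_version)`. The sketch below wraps those three calls. `check_triton_health` is a hypothetical helper; it only assumes that the object it receives provides those methods, so it can be exercised without a running server.

```python
def check_triton_health(client, model_name, model_version=""):
    """Collect basic liveness/readiness flags from a Triton client.

    `client` is assumed to be a tritonclient.grpc.InferenceServerClient,
    or any object exposing the same three methods.
    """
    return {
        "server_live": client.is_server_live(),
        "server_ready": client.is_server_ready(),
        "model_ready": client.is_model_ready(model_name, model_version),
    }


# Usage against the running server (requires tritonclient):
# import tritonclient.grpc as grpcclient
# with grpcclient.InferenceServerClient(url=URL) as client:
#     print(check_triton_health(client, MODEL_NAME, MODEL_VERSION))
```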

<a id="3"></a>
## 3. Audio helper classes

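Next, we define some helper classes and functions for preprocessing audio from files. The `AudioSegment` class below reads audio data from .wav files and resamples it to the rate required by the Kaldi ASR model, 16000 Hz by default. Resampling and silence trimming rely on `librosa`; since LibriSpeech audio is already sampled at 16 kHz, neither is needed in this notebook.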
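Because resampling is only triggered when a file's rate differs from the target, it can be useful to check the rate up front. A minimal sketch using Python's standard `wave` module, which assumes an uncompressed integer PCM .wav file; `wav_sample_rate` and `needs_resampling` are hypothetical helper names.

```python
import wave

KALDI_SAMPLE_RATE = 16000  # rate expected by the model, per the text above


def wav_sample_rate(path):
    """Read the sample rate from a PCM .wav header."""
    with wave.open(path, "rb") as w:
        return w.getframerate()


def needs_resampling(path, target_sr=KALDI_SAMPLE_RATE):
    """True if the file must be resampled before it is sent to the server."""
    return wav_sample_rate(path) != target_sr
```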

Note: For historical reasons, Kaldi expects waveforms in the range (2^15-1) x [-1, 1], that is, the range of 16-bit integer samples, rather than the usual DSP range [-1, 1]. We therefore scale the audio signal by a factor of 2^15-1.
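The scaling can be written out as a small standalone function. This sketch mirrors the conversion performed inside `AudioSegment`: signed integer samples are first normalized by 2^(bits-1)-1, float samples are assumed to already lie in [-1, 1], and both are then multiplied by `WAV_SCALE_FACTOR`. `to_kaldi_range` is a hypothetical name. For 16-bit input the two steps cancel, so an int16 sample of 32767 maps back to approximately 32767.0.

```python
import numpy as np

WAV_SCALE_FACTOR = 2**15 - 1  # same value as in the client variables


def to_kaldi_range(samples):
    """Convert int or float samples to float32 in (2^15-1) x [-1, 1]."""
    samples = np.asarray(samples)
    x = samples.astype(np.float32)
    if np.issubdtype(samples.dtype, np.signedinteger):
        # Normalize signed integers to [-1, 1], as AudioSegment does
        x *= 1.0 / (2 ** (np.iinfo(samples.dtype).bits - 1) - 1)
    elif not np.issubdtype(samples.dtype, np.floating):
        # Unsigned and other types are rejected, matching AudioSegment
        raise TypeError("Unsupported sample type: %s" % samples.dtype)
    return WAV_SCALE_FACTOR * x
```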

### Loading data

We load and play a .wav file from the LibriSpeech dataset. LibriSpeech is organized by speaker and then by chapter: each `<speaker>/<chapter>` directory holds the utterance audio files together with a `<speaker>-<chapter>.trans.txt` transcript file.

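The next cell selects one utterance by its speaker (`DIR`), chapter (`SUBDIR`) and utterance (`FILE_ID`) IDs. The helper function `libri_generator`, defined further below, yields (file path, transcript) pairs for an entire LibriSpeech data directory.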
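Each line of a LibriSpeech transcript file is an utterance ID of the form `<speaker>-<chapter>-<utterance>` followed by its uppercase transcript. A minimal sketch of that convention; `parse_transcript_line` and `utterance_wav_path` are hypothetical helpers, and the `.wav` extension assumes the original FLAC audio was converted by the download step, as the paths in this notebook suggest.

```python
import os


def parse_transcript_line(line):
    """Split '<speaker>-<chapter>-<utt> WORDS ...' into (utterance ID, lowercase text)."""
    utt_id, _, text = line.strip().partition(" ")
    return utt_id, text.lower()


def utterance_wav_path(root, utt_id):
    """Build <root>/<speaker>/<chapter>/<utt_id>.wav from an utterance ID."""
    speaker, chapter, _ = utt_id.split("-")
    return os.path.join(root, speaker, chapter, utt_id + ".wav")
```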

In [None]:
DIR = "1089"
SUBDIR = "134686"
FILE_ID = "0000"
FILE_NAME = "/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s-%s.wav" % (
    DIR,
    SUBDIR,
    DIR,
    SUBDIR,
    FILE_ID,
)
TRANSRIPTION_FILE = "/Kaldi/data/data/LibriSpeech/test-clean/%s/%s/%s-%s.trans.txt" % (
    DIR,
    SUBDIR,
    DIR,
    SUBDIR,
)

In [None]:
class AudioSegment(object):
    """Monaural audio segment abstraction.
    :param samples: Audio samples [num_samples] or [num_samples x num_channels].
    :type samples: ndarray (int or float)
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or signed int.
    """

    def __init__(self, samples, sample_rate, target_sr=16000, trim=False, trim_db=60):
        """Create audio segment from samples.
        Samples are converted to float32 internally: ints are scaled to [-1, 1],
        then everything is multiplied by WAV_SCALE_FACTOR (Kaldi's range).
        """
        samples = self._convert_samples_to_float32(samples)
        # Mix multi-channel audio down to mono before any resampling
        if samples.ndim >= 2:
            samples = np.mean(samples, axis=1)
        if target_sr is not None and target_sr != sample_rate:
            if librosa is None:
                raise ImportError("librosa is required to resample audio")
            samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
            sample_rate = target_sr
        if trim:
            if librosa is None:
                raise ImportError("librosa is required to trim audio")
            samples, _ = librosa.effects.trim(samples, top_db=trim_db)
        self._samples = samples
        self._sample_rate = sample_rate

    @staticmethod
    def _convert_samples_to_float32(samples):
        """Convert sample type to float32 in Kaldi's range.
        Signed integers are first scaled to [-1, 1]; float samples are assumed
        to be in [-1, 1] already. Both are then multiplied by WAV_SCALE_FACTOR.
        """
        float32_samples = samples.astype("float32")
        if np.issubdtype(samples.dtype, np.signedinteger):
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= 1.0 / ((2 ** (bits - 1)) - 1)
        elif not np.issubdtype(samples.dtype, np.floating):
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return WAV_SCALE_FACTOR * float32_samples

    @classmethod
    def from_file(
        cls,
        filename,
        target_sr=16000,
        offset=0,
        duration=0,
        min_duration=0,
        trim=False,
    ):
        """
        Load a file supported by soundfile and return it as an AudioSegment.
        :param filename: path of the file to load
        :param target_sr: the desired sample rate
        :param offset: offset in seconds when loading audio
        :param duration: duration in seconds to load (0 loads to the end)
        :param min_duration: zero-pad the audio to at least this many seconds
        :param trim: if true, trim leading/trailing silence (requires librosa)
        :return: AudioSegment
        """
        dtype_options = {"PCM_16": "int16", "PCM_32": "int32", "FLOAT": "float32"}
        with soundfile.SoundFile(filename, "r") as f:
            dtype = dtype_options.get(f.subtype, "float32")
            sample_rate = f.samplerate
            if offset > 0:
                f.seek(int(offset * sample_rate))
            if duration > 0:
                samples = f.read(int(duration * sample_rate), dtype=dtype)
            else:
                samples = f.read(dtype=dtype)

        # Pad at the file's own sample rate; resampling happens in __init__
        num_zero_pad = int(sample_rate * min_duration) - samples.shape[0]
        if num_zero_pad > 0:
            pad_width = [(0, num_zero_pad)] + [(0, 0)] * (samples.ndim - 1)
            samples = np.pad(samples, pad_width, mode="constant")

        return cls(samples, sample_rate, target_sr=target_sr, trim=trim)

    @property
    def samples(self):
        return self._samples.copy()

    @property
    def sample_rate(self):
        return self._sample_rate


# read one chunk of up to chunk_size frames from an open soundfile.SoundFile;
# returns (samples, end), where end is True once the file is exhausted
def get_audio_chunk_from_soundfile(sf, chunk_size):
    dtype_options = {"PCM_16": "int16", "PCM_32": "int32", "FLOAT": "float32"}
    dtype = dtype_options.get(sf.subtype, "float32")
    audio_signal = sf.read(chunk_size, dtype=dtype)
    end = False
    # zero-pad the last (short) chunk to chunk_size frames along the time axis
    if len(audio_signal) < chunk_size:
        end = True
        pad_width = [(0, chunk_size - len(audio_signal))]
        pad_width += [(0, 0)] * (audio_signal.ndim - 1)
        audio_signal = np.pad(audio_signal, pad_width, mode="constant")
    return audio_signal, end


# generator that yields (samples, sample_rate, is_first_chunk, is_last_chunk)
# for consecutive fixed-duration chunks of an audio file
def audio_generator_from_file(input_filename, target_sr, chunk_duration):
    with soundfile.SoundFile(input_filename, "r") as sf:
        # round() guards against float error, e.g. 0.51 s * 16000 Hz = 8160
        chunk_size = int(round(chunk_duration * sf.samplerate))
        start = True
        end = False

        while not end:
            audio_signal, end = get_audio_chunk_from_soundfile(sf, chunk_size)
            audio_segment = AudioSegment(audio_signal, sf.samplerate, target_sr)
            yield audio_segment.samples, target_sr, start, end
            start = False


# generator that yields (wav file path, lowercase transcript) pairs for every
# utterance under a LibriSpeech split directory, e.g. .../LibriSpeech/test-clean
def libri_generator(DATASET_ROOT):
    for speaker in os.listdir(DATASET_ROOT):
        speaker_dir = os.path.join(DATASET_ROOT, speaker)
        if not os.path.isdir(speaker_dir):
            continue
        for chapter in os.listdir(speaker_dir):
            chapter_dir = os.path.join(speaker_dir, chapter)
            if not os.path.isdir(chapter_dir):
                continue
            transcription_file = os.path.join(
                chapter_dir, "%s-%s.trans.txt" % (speaker, chapter)
            )
            with open(transcription_file, "r") as f:
                for line in f:
                    # each line: "<utterance-id> <TRANSCRIPT>"
                    file_key, _, transcript = line.strip().partition(" ")
                    file_path = os.path.join(chapter_dir, file_key + ".wav")
                    yield file_path, transcript.lower()


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


# Define the callback function. Note the last two parameters should be
# result and error. InferenceServerClient provides the result of an
# inference as a grpcclient.InferResult in `result`. For a successful
# inference, `error` is None; otherwise it is a
# tritonclient.utils.InferenceServerException holding the error details.
def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result.as_numpy("TEXT"))


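class Streamer:
    """Iterate over the responses collected so far in a UserData queue.

    Yields the first TEXT element of each response and re-raises any
    InferenceServerException that the callback stored instead of a result.
    """

    def __init__(self, q):
        self.q = q

    def __iter__(self):
        while not self.q._completed_requests.empty():
            item = self.q._completed_requests.get_nowait()
            if isinstance(item, InferenceServerException):
                raise item
            yield item.flatten()[0]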

<a id="4"></a>
## 4. Offline Inference

We first create an inference client that connects to the Kaldi TRTIS server via a gRPC connection.

The server expects audio in chunks of at most `input.WAV_DATA.dims` samples, as set in the model configuration (default: 8160). At the default 16000 Hz sampling rate, this corresponds to 510 ms of audio per chunk. The last chunk of a stream may be shorter than this maximum.

Next, we split the selected audio file into chunks (each 510 ms long, containing 8160 samples) and stream them sequentially to the Kaldi server. The server processes each chunk as soon as it arrives. The callback function collects the response for each request, and the results are printed once the stream is closed.
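The chunking and sequence flags can be sketched without a server. The generator below splits one utterance into fixed-size buffers of `CHUNK_SAMPLES` (8160 samples, i.e. 8160 / 16000 Hz = 0.51 s), zero-pads the final buffer, and reports how many samples are valid. Conceptually these values play the roles of `WAV_DATA`, `WAV_DATA_DIM`, `sequence_start` and `sequence_end` in the inference cell. `iter_chunks` is a hypothetical helper, not the notebook's `audio_generator_from_file`, and it assumes the audio is already at 16 kHz and in Kaldi's range.

```python
import numpy as np

CHUNK_SAMPLES = 8160  # default WAV_DATA size: 510 ms at 16 kHz


def iter_chunks(samples, chunk_samples=CHUNK_SAMPLES):
    """Yield (padded_chunk, valid_len, is_start, is_end) for one utterance."""
    samples = np.asarray(samples, dtype=np.float32)
    n = len(samples)
    num_chunks = max(1, -(-n // chunk_samples))  # ceiling division, at least 1
    for i in range(num_chunks):
        chunk = samples[i * chunk_samples:(i + 1) * chunk_samples]
        # Fixed-size buffer; only the first len(chunk) samples are real audio
        padded = np.zeros(chunk_samples, dtype=np.float32)
        padded[:len(chunk)] = chunk
        yield padded, len(chunk), i == 0, i == num_chunks - 1
```

A 20000-sample utterance, for example, yields three chunks with 8160, 8160 and 3680 valid samples; only the first is flagged as the sequence start and only the last as the sequence end.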

In [None]:
user_data = UserData()
sequence_id = 1
count = 0
result_list = []


with grpcclient.InferenceServerClient(url=URL, verbose=VERBOSE) as triton_client:
    try:
        # Establish stream
        triton_client.start_stream(callback=partial(callback, user_data))

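        for value_data in audio_generator_from_file(FILE_NAME, 16000, 0.51):
            # Create the tensors for WAV_DATA and WAV_DATA_DIM
            chunk = value_data[0]
            # WAV_DATA_DIM holds the number of valid samples in this chunk
            dim = np.full(shape=[1, 1], fill_value=len(chunk), dtype=np.int32)

            # WAV_DATA is a fixed-size [BATCH_SIZE, 8160] buffer; short chunks are zero-padded
            wav_data = np.zeros(shape=[BATCH_SIZE, 8160], dtype=np.float32)
            wav_data[:, :len(chunk)] = chunk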

            inputs = []
            inputs.append(grpcclient.InferInput("WAV_DATA", wav_data.shape, "FP32"))
            inputs[-1].set_data_from_numpy(wav_data)

            inputs.append(grpcclient.InferInput("WAV_DATA_DIM", dim.shape, "INT32"))
            inputs[-1].set_data_from_numpy(dim)

            outputs = []
            outputs.append(grpcclient.InferRequestedOutput("TEXT"))

            # Issue the asynchronous sequence inference.
            result = triton_client.async_stream_infer(
                model_name=MODEL_NAME,
                model_version=MODEL_VERSION,
                inputs=inputs,
                outputs=outputs,
                request_id="{}_{}".format(sequence_id, count),
                sequence_id=sequence_id,
                sequence_start=value_data[2],
                sequence_end=value_data[3],
            )
            count = count + 1

    except InferenceServerException as error:
        print("InferenceServerException: {}".format(error))
        
        
for item in Streamer(user_data):
    print(item.decode('utf-8').lower())

## Ground Truth

In [None]:
# Look up the reference transcript for the selected utterance
utt_id = "%s-%s-%s" % (DIR, SUBDIR, FILE_ID)
with open(TRANSRIPTION_FILE, "r") as f:
    for line in f:
        if line.startswith(utt_id + " "):
            transcript = line.strip().split(" ", 1)[1].lower()
            break

print(transcript)
ipd.Audio(FILE_NAME)
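To compare the streamed result with this reference, a word error rate (WER) can be computed from the word-level edit distance. A minimal sketch with a hypothetical `word_error_rate` helper; it only lowercases and splits on whitespace, with no further text normalization, so punctuation or formatting differences count as errors.

```python
def word_error_rate(reference, hypothesis):
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    # prev[j] = edit distance between the first i reference words and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        cur = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            cur[j] = min(prev[j] + 1,         # deletion
                         cur[j - 1] + 1,      # insertion
                         prev[j - 1] + cost)  # substitution or match
        prev = cur
    return prev[len(hyp)] / max(len(ref), 1)
```

For example, `word_error_rate(transcript, hypothesis)`, where `hypothesis` is assumed to be the final text printed by the streaming cell above.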

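## Run Inference in Parallel
___

The `kaldi-asr-parallel-client` command below streams an evaluation dataset (2620 utterances in this run) to the server over many parallel channels and reports latency percentiles. The `Configuration` section of the output echoes the settings used: 5 iterations (`-i 5`), 2000 parallel channels (`-c 2000`) and the server URL (`-u`).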

```bash
mkdir /Kaldi/data/results
kaldi-asr-parallel-client \
    -i 5 \
    -c 2000 \
    -o \
    -u 172.42.42.1:8001

==================================================
============= Triton Kaldi ASR Client ============
==================================================

Configuration:

Number of iterations            : 5
Number of parallel channels     : 2000
Server URL                      : 172.42.42.1:8001
Print text outputs              : No
Print partial text outputs      : No
Online - Realtime I/O           : Yes

Loading eval dataset...done
Loaded dataset with 2620 utterances, frequency 16000hz, total audio 19452.5 seconds
Opening GRPC contexts...done
Streaming utterances...
..................................................................................................done
Waiting for all results...done
Latencies:      90%             95%             99%             Avg
                0.109           0.115           0.13            0.0893
```
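The latency table reports the 90th, 95th and 99th percentiles and the mean. Given a list of per-request latencies in seconds, a comparable summary can be computed as below. `summarize_latencies` is a hypothetical helper; the exact definition of latency and the percentile interpolation used by `kaldi-asr-parallel-client` are not stated here, so its numbers may differ slightly from the client's.

```python
import numpy as np


def summarize_latencies(latencies_s):
    """Return 90th/95th/99th percentiles and the mean of latencies (seconds)."""
    lat = np.asarray(latencies_s, dtype=np.float64)
    # np.percentile uses linear interpolation between samples by default
    p90, p95, p99 = np.percentile(lat, [90, 95, 99])
    return {"p90": p90, "p95": p95, "p99": p99, "avg": lat.mean()}
```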