In [1]:
import torch
import torchaudio

# Report the installed torch / torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

  from .autonotebook import tqdm as notebook_tqdm


1.13.1+cu116
0.13.1+cu116


In [3]:
import IPython

from torchaudio.io import StreamReader


In [4]:
# Pre-trained Emformer RNN-T pipeline bundle; the three components below are
# obtained from it in the order: feature extractor, decoder, token processor.
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

feature_extractor, decoder, token_processor = (
    bundle.get_streaming_feature_extractor(),
    bundle.get_decoder(),
    bundle.get_token_processor(),
)

100%|██████████| 3.81k/3.81k [00:00<00:00, 1.95MB/s]
100%|██████████| 293M/293M [00:14<00:00, 20.6MB/s] 
100%|██████████| 295k/295k [00:00<00:00, 1.02MB/s]


In [5]:
sample_rate = bundle.sample_rate
# Both lengths are scaled by the bundle's hop_length
# (presumably converting feature-frame counts into waveform samples -- TODO confirm).
segment_length = bundle.segment_length * bundle.hop_length
context_length = bundle.right_context_length * bundle.hop_length

print(
    f"Sample rate: {sample_rate}",
    f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)",
    f"Right context: {context_length} frames ({context_length / sample_rate} seconds)",
    sep="\n",
)

Sample rate: 16000
Main segment: 2560 frames (0.16 seconds)
Right context: 640 frames (0.04 seconds)


In [6]:
src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"

# Open the source and register one audio output stream that yields
# `segment_length` frames per chunk at the bundle's sample rate.
streamer = StreamReader(src)
streamer.add_basic_audio_stream(sample_rate=bundle.sample_rate, frames_per_chunk=segment_length)

# Show the source stream description followed by the configured output stream.
print(streamer.get_src_stream_info(0), streamer.get_out_stream_info(0), sep="\n")

StreamReaderSourceAudioStream(media_type='audio', codec='mp3', codec_long_name='MP3 (MPEG audio layer 3)', format='fltp', bit_rate=128000, num_frames=0, bits_per_sample=0, metadata={}, sample_rate=44100.0, num_channels=2)
StreamReaderOutputStream(source_index=0, filter_description='aresample=16000,aformat=sample_fmts=fltp')


In [7]:
class ContextCacher:
    """Cache the tail of each input chunk and prepend it to the next chunk.

    Args:
        segment_length (int): The size of main segment.
            If the incoming segment is shorter, then the segment is
            zero-padded on the right up to this size.
        context_length (int): The size of the context, i.e. the number of
            trailing elements cached from one chunk and prepended to the next.
            May be ``0``, in which case nothing is carried over between calls.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        # The first chunk has no predecessor, so its context is all zeros.
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        """Return ``chunk`` (padded if short) with the cached context prepended.

        Args:
            chunk (torch.Tensor): 1-D tensor holding the next main segment.

        Returns:
            torch.Tensor: Concatenation of the previously cached context and
            the (possibly padded) chunk. The tail of the padded chunk is
            cached for the next call.
        """
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        # Use an explicit start index: ``chunk[-0:]`` would select the whole
        # chunk instead of an empty tail when ``context_length`` is 0.
        # ``max(..., 0)`` keeps the whole chunk when it is shorter than the context.
        tail_start = max(chunk.size(0) - self.context_length, 0)
        self.context = chunk[tail_start:]
        return chunk_with_context

In [8]:
# Context cache sized to the segment/context lengths computed earlier.
cacher = ContextCacher(segment_length=segment_length, context_length=context_length)

# Decoder state and running hypothesis start out empty; run_inference updates them.
state = hypothesis = None

In [9]:
# Created once at notebook level (not inside run_inference) so that the
# function's loop consumes from this shared iterator; repeated calls
# presumably continue from where the previous call stopped.
stream_iterator = streamer.stream()


@torch.inference_mode()
def run_inference(num_iter=200):
    """Transcribe up to ``num_iter`` chunks from the shared stream iterator.

    Each chunk is passed through the context cacher, the feature extractor and
    the decoder; the transcript of the current hypothesis is printed without a
    trailing newline after every chunk. The notebook-level ``state`` and
    ``hypothesis`` variables are read and updated, so they carry over between
    calls.

    Args:
        num_iter (int): Number of chunks after which the loop stops.

    Returns:
        IPython.display.Audio or None: Audio built from the chunks consumed in
        this call, or ``None`` when the stream iterator yielded no chunk
        (e.g. it was already exhausted by earlier calls).
    """
    global state, hypothesis
    chunks = []
    for i, (chunk,) in enumerate(stream_iterator, start=1):
        # Only index 0 along the second dimension is fed to the model
        # (presumably the first audio channel of a (frames, channels) chunk).
        segment = cacher(chunk[:, 0])
        features, length = feature_extractor(segment)
        # Third positional argument is 10 (the beam width, per the torchaudio
        # RNNTBeamSearch.infer signature -- TODO confirm against installed version).
        hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        hypothesis = hypos[0]
        transcript = token_processor(hypothesis[0], lstrip=False)
        print(transcript, end="", flush=True)

        chunks.append(chunk)
        if i == num_iter:
            break

    if not chunks:
        # torch.cat raises on an empty list; nothing was consumed, so there is
        # no audio to return.
        return None

    return IPython.display.Audio(torch.cat(chunks).T.numpy(), rate=bundle.sample_rate)

In [10]:
# Run with the default num_iter=200; the transcript is printed as it is
# decoded and the returned value becomes the cell's output.
run_inference()

 forward great pirate's this is aver's recording all thects recordings are in the public dum for more information or please visit liberg recording by james christopher great pirite stories by various edited by josey embodies the romance of theed expression it is a sad but inevable comment on our civilization that so far as the sea is concerned