# Speech Recognition with CitriNet

This notebook uses [CitriNet](https://arxiv.org/abs/2104.01721) from the open source project [NVIDIA/NeMo](https://github.com/NVIDIA/NeMo) to transcribe speech that you either record from your microphone or upload as an audio file (.mp3 or .wav).

For other deep-learning Colab notebooks, visit [tugstugi/dl-colab-notebooks](https://github.com/tugstugi/dl-colab-notebooks).


## Install NVIDIA/NeMo

In [None]:
#@title
import os
from os.path import exists, join, basename, splitext

!pip -q install wget youtube-dl wget #
!pip install -q nemo_toolkit[all]==1.2.0

# we need also Apex
if not exists('apex'):
  !git clone -q --depth 1 https://github.com/NVIDIA/apex
  !cd apex && pip install -q --no-cache-dir ./
  !pip install -q https://github.com/tugstugi/dl-colab-notebooks/archive/colab_utils.zip

from IPython.display import Audio, display, clear_output
import ipywidgets as widgets
import numpy as np
from scipy.io import wavfile
from dl_colab_notebooks.audio import record_audio, upload_audio

## Initialize CitriNet

In [None]:
#@title
import torch
import nemo.collections.asr as nemo_asr

# Download the pretrained English CitriNet-1024 CTC model and switch it to
# inference mode (eval() returns the model itself).
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="stt_en_citrinet_1024")
asr_model = asr_model.eval()

# This notebook only runs inference, so turn autograd off globally.
# The trailing ';' keeps the returned grad-mode object's repr out of the cell output.
torch.set_grad_enabled(False);

## Record or Upload Speech

In [None]:
#@markdown * Either record audio from microphone or upload audio from file (.mp3 or .wav)

# Rate passed to the recorder/uploader, the playback widget and the WAV writer.
SAMPLE_RATE = 16000
record_or_upload = "Record"  #@param ["Record", "Upload (.mp3 or .wav)"]
record_seconds = 10  #@param {type:"number", min:1, max:10, step:1}

def _recognize(audio):
  """Play back `audio`, save it to 'test.wav' as 16-bit PCM, and print its transcription.

  Args:
    audio: float waveform at SAMPLE_RATE, as returned by record_audio/upload_audio.
      Presumably nominally in [-1, 1] -- TODO confirm against dl_colab_notebooks.
  """
  display(Audio(audio, rate=SAMPLE_RATE, autoplay=True))

  # Clip before scaling: samples outside [-1, 1] would overflow int16 in the
  # unchecked astype() cast and wrap around into loud clicks.
  pcm = (32767 * np.clip(audio, -1.0, 1.0)).astype(np.int16)
  wavfile.write('test.wav', SAMPLE_RATE, pcm)

  print('\n')
  transcriptions = asr_model.transcribe(['test.wav'], batch_size=1)
  print('\n\n')
  # transcribe() returns one result per input path; print the text, not the list repr.
  for text in transcriptions:
    print(text)


def _record_audio(b):
  """Button callback: record `record_seconds` of microphone audio, then transcribe it.

  `b` is the clicked widget passed in by Button.on_click; it is not used.
  """
  clear_output()
  _recognize(record_audio(record_seconds, sample_rate=SAMPLE_RATE))

def _upload_audio(b):
  """Ask the user to upload an audio file, load it at SAMPLE_RATE, then transcribe it.

  `b` is unused; it only keeps the same signature as the button callback.
  """
  clear_output()
  uploaded = upload_audio(sample_rate=SAMPLE_RATE)
  _recognize(uploaded)

# Upload mode starts immediately; Record mode waits for the user to click a button.
if record_or_upload != "Record":
  _upload_audio("")
else:
  button = widgets.Button(description="Record Speech")
  button.on_click(_record_audio)
  display(button)