<a href="https://colab.research.google.com/github/prabhakaran-s-code/genai-python/blob/main/speech_to_text_whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Abstract
This Python notebook converts speech to text for a wide variety of video and audio files, including speaker diarization (identifying who spoke when). Transcripts are written as plain-text (.txt) files and captions as WebVTT (.vtt) files; the raw Whisper output is also saved as JSON. The OpenAI Whisper model generates the captions/transcripts, and the NVIDIA NeMo toolkit performs speaker diarization.

In [None]:
# Install Whisper
!pip install git+https://github.com/openai/whisper.git

# Install NeMo
BRANCH = 'r1.23.0'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]

## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
#!pip install text-unidecode
!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html # Install TorchAudio

import json, datetime, json, subprocess, whisper, os, re
from nemo.collections.asr.models import ClusteringDiarizer
from omegaconf import OmegaConf

# Load Whisper model
model = whisper.load_model("base")

In [None]:
# initialize the intermediate/final output folder paths.
# NOTE: every path must end with '/' because later cells build file paths by
# plain string concatenation (e.g. txtpath + name + ".txt").
vttpath = "path/to/subtitles/"             # WebVTT caption files
txtpath = "path/to/transcripts/"           # plain-text transcripts
json_path = "path/to/json/"                # Whisper results saved as JSON
diarization_path = "path/to/diarization/"  # NeMo config, manifest and outputs
video_path = "path/to/videos/"             # downloaded media files

# exist_ok=True makes this idempotent when the cell is re-run and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
for _folder in (vttpath, txtpath, json_path, video_path, diarization_path):
  os.makedirs(_folder, exist_ok=True)

# Manifest entry for the NeMo diarizer. The main loop fills in
# 'audio_filepath' for each media file and writes this dict as a single JSON
# line of the diarizer's input manifest.
# NOTE(review): 'num_speakers' is fixed at 2; whether NeMo actually uses it
# presumably depends on the config's oracle_num_speakers setting — confirm.
meta = dict(
    audio_filepath='',
    offset=0,
    duration=None,
    label='infer',
    text='-',
    num_speakers=2,
    rttm_filepath=None,
    uem_filepath=None,
)

# Load Model Config for NeMo model.
# The YAML is fetched from the same NeMo branch that was pip-installed
# (BRANCH) rather than from `main`, so its keys match what the installed
# release's ClusteringDiarizer expects; `main` can drift ahead of the release.
import urllib.request

MODEL_CONFIG = os.path.join(diarization_path, 'diar_infer_telephonic.yaml')
if not os.path.exists(MODEL_CONFIG):
    config_url = (
        f"https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/"
        "examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml"
    )
    # Download under a temporary name and rename only on success, so a failed
    # download never leaves a partial file that the exists() check above
    # would accept on the next run.
    partial_path = MODEL_CONFIG + '.part'
    urllib.request.urlretrieve(config_url, partial_path)
    os.replace(partial_path, MODEL_CONFIG)

config = OmegaConf.load(MODEL_CONFIG)
# The main loop writes the manifest to this same path before each diarization.
config.diarizer.manifest_filepath = diarization_path + 'input_manifest.json'
# Directory for intermediate files and prediction outputs (pred_rttms/ is read back later).
config.diarizer.out_dir = diarization_path

In [None]:
# URLs of the video/audio files to transcribe and caption.
# Add one quoted URL per line, each followed by a comma, e.g.
#   "https://example.com/media/interview.mp4",
urls = [
]

In [None]:
# Function to convert seconds to HH:mm:ss.SSS format
def convert_seconds(total_seconds):
  """
  Converts seconds to a string in the format HH:mm:ss.SSS.

  The value is rounded once to the nearest whole millisecond and then split
  into fields with integer arithmetic. This avoids binary floating-point
  artefacts truncating the milliseconds (e.g. 1.001 * 1000 evaluates to just
  under 1001), and lets carries propagate (59.9996 -> 00:01:00.000).

  Args:
    total_seconds: The number of seconds to convert (non-negative int or float).

  Returns:
    A string in the format HH:mm:ss.SSS. Hours are zero-padded to at least
    two digits and are not wrapped at 24.
  """
  # int() also normalizes numpy scalars, whose round() does not return int.
  total_ms = int(round(total_seconds * 1000))
  whole_seconds, milliseconds = divmod(total_ms, 1000)
  whole_minutes, seconds = divmod(whole_seconds, 60)
  hours, minutes = divmod(whole_minutes, 60)

  return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"


In [None]:
# Function to write the complete transcribed text to a json, text and vtt file
def write_to_file_wh(transcript, filename):
  """Write a Whisper transcription result as JSON, plain text and WebVTT.

  Each output is named after ``filename``'s basename without its extension
  and written to ``json_path``, ``txtpath`` and ``vttpath`` respectively.

  Args:
    transcript: Result dict from ``model.transcribe``; its 'text' entry and
      its 'segments' list (items with 'start', 'end' and 'text') are used.
    filename: Path of the transcribed media file; only its basename is used.
  """
  base_name = os.path.splitext(os.path.basename(filename))[0]

  # Full transcription result as JSON.
  with open(json_path + base_name + ".json", 'w', encoding='utf-8') as f:
    f.write(json.dumps(transcript))

  # Explicit UTF-8: Whisper transcribes many languages and the platform's
  # default encoding may not be able to represent the text.
  with open(txtpath + base_name + ".txt", 'w', encoding='utf-8') as f:
    f.write(transcript['text'] + '\n')

  # WebVTT files must be UTF-8 encoded.
  with open(vttpath + base_name + ".vtt", 'w', encoding='utf-8') as f:
    f.write('WEBVTT' + '\n\n')
    for segment in transcript['segments']:
      text = segment['text'].strip()
      # Skip empty or whitespace-only segments, which would otherwise become
      # cues with no payload.
      if text:
        start = convert_seconds(segment['start'])
        end = convert_seconds(segment['end'])
        f.write(start + " --> " + end + '\n' + text + '\n\n')

In [None]:
def load_rttm_file(base_video_file_name):
    """Parse the predicted RTTM file for one media file into turn dicts.

    Prints and then reads
    ``<diarization_path>pred_rttms/<base_video_file_name>.rttm``. Only lines
    with exactly 10 whitespace-separated fields are kept. Each kept line
    becomes a dict with 'Type', 'File ID', 'Turn Onset' (float),
    'Turn Duration' (float) and 'Speaker Name' (the 8th field).
    """
    rttm_path = diarization_path + 'pred_rttms/' + base_video_file_name + '.rttm'
    print(rttm_path)
    with open(rttm_path, 'r') as rttm_file:
        rows = [line.split() for line in rttm_file]
    return [
        {
            'Type': cols[0],
            'File ID': cols[1],
            'Turn Onset': float(cols[3]),
            'Turn Duration': float(cols[4]),
            'Speaker Name': cols[7],
        }
        for cols in rows
        if len(cols) == 10
    ]

In [None]:
def get_speaker_label(turns, start_time, end_time):
    """Return the diarized speaker label for a transcript segment.

    Args:
        turns: List of turn dicts as produced by ``load_rttm_file``, each with
            float 'Turn Onset' and 'Turn Duration' and a 'Speaker Name'.
        start_time: Segment start, in the same time base as the turns.
        end_time: Segment end, in the same time base as the turns.

    Returns:
        The 'Speaker Name' of the turn overlapping [start_time, end_time] the
        most (first such turn on ties). If no turn overlaps by a positive
        amount (e.g. a zero-length segment), the first turn containing
        start_time or end_time is used. A segment starting at 0 with no such
        turn is attributed to the first turn's speaker. Otherwise 'Unknown'.
    """
    # Compare in float: truncating onset and duration to int separately made
    # turn ends up to ~2 s too early and collapsed sub-second turns to zero
    # length, so segments were matched to the wrong speaker.
    best_speaker = None
    best_overlap = 0.0
    for turn in turns:
        onset = turn['Turn Onset']
        offset = onset + turn['Turn Duration']
        overlap = min(end_time, offset) - max(start_time, onset)
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = turn['Speaker Name']
    if best_speaker is not None:
        return best_speaker

    # No positive overlap: fall back to a turn containing either endpoint.
    for turn in turns:
        onset = turn['Turn Onset']
        offset = onset + turn['Turn Duration']
        if onset <= start_time <= offset or onset <= end_time <= offset:
            return turn['Speaker Name']

    # Leading silence: attribute a segment at t=0 to the first speaker, but
    # without an IndexError when the RTTM produced no turns at all.
    if start_time == 0 and turns:
        return turns[0]['Speaker Name']

    # If no match found, return a default label
    return 'Unknown'

In [None]:
def load_transcript_json(json_transcript_path):
    """Read the JSON document at *json_transcript_path* and return it parsed."""
    with open(json_transcript_path, 'r') as handle:
        parsed = json.load(handle)
    return parsed

In [None]:
def add_speaker_labels(transcript_json, base_video_file_name):
  """Tag Whisper segments with speakers and write a speaker-attributed transcript.

  Loads the RTTM turns for ``base_video_file_name`` via ``load_rttm_file``
  and assigns a speaker to every segment. It then writes
  ``<txtpath><base_video_file_name>_speaker.txt``, where consecutive segments
  from the same speaker are merged into one paragraph headed by that
  speaker's label.

  Args:
    transcript_json: Whisper result dict whose 'segments' items have
      'start', 'end' and 'text'.
    base_video_file_name: Media file name without directory or extension;
      used to find the RTTM file and to name the output file.

  Returns:
    The same ``transcript_json`` object, mutated in place so that each
    segment has a 'speaker' key.
  """
  previous_speaker = ''
  turns = load_rttm_file(base_video_file_name)
  output_path = txtpath + base_video_file_name + "_speaker.txt"
  # Explicit UTF-8 so non-ASCII transcript text can always be written.
  with open(output_path, 'w', encoding='utf-8') as f:
    for index, segment in enumerate(transcript_json['segments']):
      # Match against the RTTM turns using the exact float timestamps;
      # int() truncation shifted segment boundaries by up to a second.
      speaker = get_speaker_label(turns, segment['start'], segment['end'])
      if speaker != 'Unknown':
        # Renumber '<prefix>_<n>' labels to 'speaker_<n+1>' (1-based for readers).
        number = int(speaker.split("_")[-1]) + 1
        speaker = "speaker_" + str(number)
      segment['speaker'] = speaker
      if previous_speaker == speaker:
        # Same speaker as the previous segment: continue the paragraph
        # (text appended unstripped).
        f.write(segment['text'])
      else:
        # Speaker change: blank line between paragraphs (none before the
        # first), then the label on its own line, then the stripped text.
        if index != 0:
          f.write('\n\n')
        f.write(speaker + '\n')
        f.write(segment['text'].strip())
      previous_speaker = speaker
  return transcript_json

# Example usage
#transcript_json_path = '/content/path/to/json/abc.json'

#transcript_data = load_transcript_json(transcript_json_path)
#updated_transcript = add_speaker_labels(transcript_data, 'abc')

#print(updated_transcript)

In [None]:
def extract_file_name(url):
  """Return the text after the last '/' in *url*, or None if there is none.

  None is returned when the URL contains no '/' at all or ends with '/'.
  Query strings and fragments are not stripped.
  """
  _, slash, tail = url.rpartition('/')
  if slash and tail:
    return tail
  return None

In [None]:
#main block
# For each URL: download the media, transcribe it with Whisper, write the
# transcript/caption files, diarize the audio with NeMo and write a
# speaker-labelled transcript.
import gc
for url in urls:
  media_name = extract_file_name(url)
  if media_name is None:
    # e.g. a URL ending in '/': there is no file name to save the media under.
    print(f"Skipping {url}: could not derive a file name from the URL")
    continue
  filename = video_path + media_name
  # Download into video_path (not a hard-coded folder) at exactly `filename`.
  # An argument list means no shell is involved, so '&', spaces or quotes in
  # the URL cannot split or alter the command. check=True stops on failure
  # instead of transcribing a missing or partial file.
  subprocess.run(['wget', '-O', filename, url], check=True)
  result = model.transcribe(filename)  # Transcribe the media file with Whisper
  write_to_file_wh(result, filename)  # Write the JSON, text and VTT outputs
  base_video_file_name = os.path.splitext(os.path.basename(filename))[0]
  # Put the mono WAV under diarization_path. If the download were itself a
  # .wav, writing next to it would make ffmpeg overwrite its own input. The
  # basename is unchanged, so the RTTM lookup by base_video_file_name is
  # unaffected (assuming NeMo names the RTTM after the audio basename, as
  # load_rttm_file already does).
  wav_file_path = diarization_path + base_video_file_name + ".wav"
  # -y (overwrite) is a global option, so it goes before the input. The
  # output is 16 kHz mono, the audio format NeMo's diarization docs ask for.
  subprocess.run(
      ['ffmpeg', '-y', '-i', filename, '-ac', '1', '-ar', '16000', wav_file_path],
      check=True,
  )
  meta['audio_filepath'] = wav_file_path
  with open(diarization_path + 'input_manifest.json', 'w') as fp:
    json.dump(meta, fp)
    fp.write('\n')
  sd_model = ClusteringDiarizer(cfg=config)
  sd_model.diarize()  # Extract speaker information from the audio file

  transcript_data = load_transcript_json(json_path + base_video_file_name + ".json")
  add_speaker_labels(transcript_data, base_video_file_name)
  # Unbind the diarizer first: gc.collect() cannot reclaim an object that is
  # still referenced by a name, so its models would otherwise stay alive
  # until the next iteration rebinds sd_model.
  del sd_model
  gc.collect()