# Transcribe WAV file using Wav2Vec2
Rolando Coto-Solano (Rolando.A.Coto.Solano@dartmouth.edu)<br>
Dartmouth College. Last update: 20250601

The program takes three main inputs:

* `audioFileName`: The WAV, MP3 or MP4 file that you wish to transcribe. It needs to be in the `audiofiles-to-transcribe` folder of the sandbox you will use.<br>
* `currentSandbox`: The name of the sandbox you are using. The default sandboxes are `sandbox-user` and `all-wavs`, but you can use whichever one you specified during the installation.<br>
* `installationFolder`: The folder where the ASR sandboxes are contained. The default value is `202506-ood-asr`, but you should use the one you specified during the installation.<br>

The program also takes the following inputs:

* `modelCheckpointToUse`: The name of the folder that has the checkpoint for the transcription model. The default is `checkpoint-1200`, but you should check which one you saved by going to the `wav2vec2-model` folder.<br>
* `minDurationOfFile`: Minimum duration of a segment that should be transcribed. The default is 100 milliseconds.<br>
* `maxWavDuration`: The maximum duration for a recording to be allowed into the Wav2Vec2 data. Wav2Vec2's CUDA memory can crash when processing long files. The default is 15 seconds; this is the maximum duration where I can guarantee that the Colab memory won't crash.<br>

The program takes the audio file. It then (1) splits the large audio file into smaller chunks using Silero-VAD, and (2) runs each of these through the transcription model. It then (3) saves these results into a TSV file in the folder `tsv-outputs`. This file can then be imported into ELAN.

## (1) File preparation

You need to run this for every new file you process.

In [None]:
#=================================================
# If the computer tells you to "restart session",
# please restart it and run this box again.
#=================================================

# Pin NumPy to version 1.25.0 for this session.
!pip install numpy==1.25.0

In [None]:
# Name of the audio file to transcribe.
# The file should be in the folder audiofiles-to-transcribe (inside of your sandbox)
audioFileName = "kia-orana-rehearsal.mp4"

# Environment variables: they select the Google Drive folders that are read below
currentSandbox = "sandbox-user"    # Please type sandbox-user or all-wavs
installationFolder = "202506-ood-asr"

# Model variables
modelCheckpointToUse = "checkpoint-1200"       # You can type "full-cim-model" to use the pretrained model
                                              # Or you can type something like "checkpoint-1200" to use the model you trained
# Minimum duration (in milliseconds) of segments that the computer should transcribe
minDurationOfFile = 100 #ms

# Maximum permissible duration of segments. Wav2Vec2's CUDA memory might crash when processing long files
maxWavDuration = 15    # Seconds

# Use GPU for processing? If you select "no", then the system will use the slower CPU processing.
# Any value other than "no" selects the GPU ("cuda").
useGPU = "yes"

In [None]:
# ================================================================
# Mount the Google Drive onto the virtual computer
# ================================================================

from google.colab import drive
# force_remount=True requests a fresh mount even if the Drive is already mounted.
drive.mount('/content/drive/', force_remount=True)

## (2) Model preparation

You only need to run this once per session. It should take about 2 minutes.

In [None]:
# Install Silero-VAD and load the voice-activity-detection model that is
# used later to find the speech regions of the audio file.
!pip install silero-vad
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
vadmodel = load_silero_vad()

In [None]:
#=============================================================
# Determine type of processing
#=============================================================

# The GPU ("cuda") is the default; only the exact answer "no" switches to the CPU.
typeProcessor = "cuda"
if (useGPU == "no"): typeProcessor = "cpu"

#=============================================================
# Copy the ASR model for CIM from Google Drive to the local disk
#=============================================================

# The files at the top level of wav2vec2-model go to /content/wav2vec2-model,
# and the files of the chosen checkpoint go to /content/wav2vec2-model/checkpoint.
# NOTE: the "*.*" pattern only matches names that contain a dot.
!mkdir /content/wav2vec2-model
!cp /content/drive/MyDrive/{installationFolder}/{currentSandbox}/wav2vec2-model/*.* /content/wav2vec2-model
!mkdir /content/wav2vec2-model/checkpoint
!cp /content/drive/MyDrive/{installationFolder}/{currentSandbox}/wav2vec2-model/{modelCheckpointToUse}/*.* /content/wav2vec2-model/checkpoint

# Local paths used later when the model and the processor are loaded
pathCheckpoint = "/content/wav2vec2-model/checkpoint"
modelPath = "/content/wav2vec2-model"

#model = Wav2Vec2ForCTC.from_pretrained(pathCheckpoint).to(typeProcessor)
#processor = Wav2Vec2Processor.from_pretrained("wav2vec2-model")

In [None]:
import os

def findKenLMModel(folder_path):
    """Return the name of a KenLM model file found in ``folder_path``.

    A model file is a regular file whose name ends in '-correct.arpa'.
    Only the bare file name is returned (not the full path). If the folder
    has no such file, the string "-1" is returned instead.
    """
    candidates = (
        name
        for name in os.listdir(folder_path)
        if name.endswith('-correct.arpa')
        and os.path.isfile(os.path.join(folder_path, name))
    )
    # The first match wins; "-1" is the "nothing found" marker
    return next(candidates, "-1")

# Locate the KenLM language model that sits next to the acoustic model
filenameCorrectKenlmModel = findKenLMModel(modelPath + "/")
if filenameCorrectKenlmModel == "-1":
    # findKenLMModel answers "-1" when nothing was found. Stop here with a
    # clear message instead of building the non-existent path ".../-1".
    raise FileNotFoundError("No KenLM model ending in '-correct.arpa' was found in " + modelPath)
filenameCorrectKenlmModel = modelPath + "/" + filenameCorrectKenlmModel

## (3) Split audio file into bits

You need to run this again for every file you process.

In [None]:
def extractStartEndTimes(filepath):
    """Read the start and end times stored in a tab-separated file.

    Each usable line has at least three tab-separated columns; the first
    two are the start and end time. Lines with fewer columns, or whose
    first two columns are not numbers, are skipped.

    Returns a pair of lists ``(startTimes, endTimes)`` of floats, in file order.
    """
    startTimes, endTimes = [], []

    with open(filepath, 'r', encoding='utf-8') as handle:
        for rawLine in handle:
            fields = rawLine.strip().split('\t')
            if len(fields) >= 3:
                try:
                    begin, finish = float(fields[0]), float(fields[1])
                except ValueError:
                    # Not a pair of numbers: ignore the whole line
                    continue
                startTimes.append(begin)
                endTimes.append(finish)

    return startTimes, endTimes

In [None]:
def countDigits(number):
  """Count the decimal digits of a positive integer (0 or less gives 0)."""
  digits = 0
  remaining = number
  while remaining > 0:
    # Drop the last digit and tally it
    remaining //= 10
    digits += 1
  return digits

def addZerosInFrontOfNumber(number, total):
  """Left-pad ``number`` with zeros until it is as wide as ``total``.

  Both arguments are converted with str(); the result is always a string.
  If ``number`` is already as wide as ``total`` (or wider) it is returned
  without padding.
  """
  text = str(number)
  width = len(str(total))
  return text.rjust(width, "0")

In [None]:
def findSegmentIndex(timepoints, starttime, endtime):
    """Find which segment is delimited exactly by ``starttime`` and ``endtime``.

    ``timepoints`` is a comma-separated string of cut points. After sorting,
    segment 0 spans [0, points[0]) and segment i spans [points[i-1], points[i]).

    Returns the index of the segment whose two boundaries are equal to
    ``starttime`` and ``endtime``, or -1 when no segment matches.
    """
    boundaries = sorted(float(tp) for tp in timepoints.split(','))

    # The lower edge of each segment is the previous cut point (0 for the first)
    lowerEdges = [0.0] + boundaries[:-1]

    for index, (segStart, segEnd) in enumerate(zip(lowerEdges, boundaries)):
        if starttime == segStart and endtime == segEnd:
            return index

    return -1

In [None]:
from decimal import Decimal
import os

def fixIntervals(audio_path, file_path, maxDuration):
    """Cut over-long voice regions and write the resulting segment lists.

    ``file_path`` holds one ``start,end`` pair (seconds) per line; lines
    that do not have exactly two comma-separated fields are ignored. Every
    region longer than ``maxDuration`` is cut into consecutive pieces of at
    most ``maxDuration`` seconds, leaving a 0.001 s gap between pieces.

    Two files are written:

    * ``file_path`` with ".csv" replaced by ".tsv" (if the name has no
      ".csv" the input file itself is overwritten), with one
      ``start<TAB>end<TAB>segment-name`` line per piece, sorted by start;
    * ``sample.csv`` in the same folder, with a "path,sentence" header and
      one line per segment name with a blank sentence.

    Segment names are ``audio_path`` with "-<number>" inserted before
    ".wav"; the number is zero-padded to the width of the input line count.

    Args:
        audio_path: name of the audio file the regions belong to.
        file_path: path of the comma-separated regions file.
        maxDuration: maximum length of a piece, in seconds (must be > 0).

    Raises:
        ValueError: if ``maxDuration`` is not positive (the cutting loop
            would never end), or if a field is not a number.
    """
    if maxDuration <= 0:
        raise ValueError("maxDuration must be greater than zero")

    intervals = []
    totalLines = 0

    with open(file_path, 'r', encoding='utf-8') as regions:
        for line in regions:
            totalLines += 1
            parts = line.strip().split(",")
            if len(parts) != 2:
                continue
            start, end = float(parts[0]), float(parts[1])

            # Peel off pieces of maxDuration until the remainder fits
            while end - start > maxDuration:
                intervals.append((start, start + maxDuration))
                start += maxDuration + 0.001  # the next piece starts 1 ms later

            # Whatever is left is at most maxDuration long
            intervals.append((start, end))

    intervals.sort()

    padWidth = len(str(totalLines))
    filenames = []
    rows = []
    for lineCounter, (start, end) in enumerate(intervals, start=1):
        nameAudio = audio_path.replace(".wav", "-" + str(lineCounter).zfill(padWidth) + ".wav")
        filenames.append(nameAudio)
        # Times are written with exactly three decimals
        rows.append(str(round(Decimal(start), 3)) + "\t" + str(round(Decimal(end), 3)) + "\t" + nameAudio + "\n")

    file_path = file_path.replace(".csv", ".tsv")
    with open(file_path, "w", encoding="utf-8") as tsvFile:
        tsvFile.write("".join(rows))

    # The slice drops the final newline (or the header's newline if there are no segments)
    sampleCSV = ("path,sentence\n" + "".join(name + ", \n" for name in filenames))[:-1]

    # os.path.join keeps sample.csv next to the TSV even when file_path has
    # no folder part (plain concatenation would point at "/sample.csv")
    samplePath = os.path.join(os.path.dirname(file_path), "sample.csv")
    with open(samplePath, "w", encoding="utf-8") as sampleFile:
        sampleFile.write(sampleCSV)


def leaveOnlyLastThreeColsOfTSV(inPath, outPath):
    """Drop the first two columns of a tab-separated file.

    Every line of ``inPath`` is stripped and split on tabs. Lines with at
    least three columns are written to ``outPath`` with their first two
    columns removed and the remaining ones joined by commas; shorter lines
    are left out. Errors are reported with print() rather than raised.
    """
    try:
        # Read everything first so that the output is only created after
        # the input has been read successfully
        with open(inPath, 'r', encoding='utf-8') as source:
            rows = [raw.strip().split('\t') for raw in source]

        kept = [','.join(fields[2:]) + '\n' for fields in rows if len(fields) >= 3]

        with open(outPath, 'w', encoding='utf-8') as target:
            target.writelines(kept)

        print("File processed successfully from {} to {}.".format(inPath, outPath))

    except IOError as e:
        print("An IOError occurred: {}".format(e))
    except Exception as e:
        print("An unexpected error occurred: {}".format(e))


In [None]:
import glob

#=============================================================
# Erase previous files and get new audio file
#=============================================================

# Erase previous files
%cd /content/
%rm *.wav >/dev/null 2>&1
%rm *.WAV >/dev/null 2>&1
%rm *.mp3 >/dev/null 2>&1
%rm *.MP3 >/dev/null 2>&1
%rm *.mp4 >/dev/null 2>&1
%rm *.MP4 >/dev/null 2>&1
%rm *.csv >/dev/null 2>&1
%rm *.CSV >/dev/null 2>&1
%rm *.tsv >/dev/null 2>&1
%rm *.TSV >/dev/null 2>&1
%rm *.txt >/dev/null 2>&1
%rm *.TXT >/dev/null 2>&1

folderWithAudioFiles = "/content/drive/MyDrive/"+installationFolder+"/"+currentSandbox+"/audiofiles-to-transcribe/"
!cp {folderWithAudioFiles}{audioFileName} .

#=============================================================
# Get proper path to audio file
#=============================================================

# get filenames of wave files in the remote server
path1 = r'/content/*.wav'
path2 = r'/content/*.WAV'
path3 = r'/content/*.mp3'
path4 = r'/content/*.MP3'
path5 = r'/content/*.mp4'
path6 = r'/content/*.MP4'
path7 = r'/content/*.mov'
path8 = r'/content/*.MOV'

files = []
files = glob.glob(path1) + glob.glob(path2) + glob.glob(path3) + glob.glob(path4) + glob.glob(path5) + glob.glob(path6) + glob.glob(path7) + glob.glob(path8)

# get name of the first file
if (len(files) == 0):
  print("=== ERROR: THERE ARE NO AUDIO FILES IN THE REMOTE SERVER ===")
else:
  wavfile = ""
  annotfile = ""
  for i in range(0,1): wavfile = files[i].replace("/content/","")
  print("File to split: " + wavfile)
  print("Annotation file: " + annotfile)

fileExtensionOrig = wavfile[-3:]
fileExtension = fileExtensionOrig.lower()
origFilename = wavfile


#=============================================================
# Convert file from mp3/mp4 to wav
#=============================================================

if (fileExtension == "mp3" or fileExtension == "mp4"):
  wavfile = origFilename.replace(origFilename[-3:],"") + "wav"
  !ffmpeg -i $origFilename -acodec pcm_u8 $wavfile
  print(wavfile)


#=============================================================
# Downgrade WAV file to the right ASR format (e.g. 16K)
#=============================================================

!ffmpeg -y -i $wavfile -ac 1 -ar 16000 temp-$wavfile
!rm $wavfile
!mv temp-$wavfile $wavfile

#=============================================================
# Find voice regions
#=============================================================

wav = read_audio(wavfile)
speech_timestamps = get_speech_timestamps(
  wav,
  vadmodel,
  min_speech_duration_ms=minDurationOfFile,  # Ignore speech regions shorter than the user's minimum
  return_seconds=True,  # Return speech timestamps in seconds (default is samples)
)

#=============================================================
# Write voice regions into a file
#=============================================================

# One line per region: tier, speaker, start, end and a blank annotation.
# The lines are joined with newlines, without a newline after the last one.
linePrefix = "tiername\tspeakername\t"

output = "\n".join(
  linePrefix + str(t['start']) + "\t" + str(t['end']) + "\t "
  for t in speech_timestamps
)

with open("voice-regions.txt", "w") as file: file.write(output)

#=============================================================
# Make sure there aren't any regions that are longer
# than the memory limit
#=============================================================

# First keep only the start and end columns, then cut the regions that
# are longer than maxWavDuration
leaveOnlyLastThreeColsOfTSV("voice-regions.txt","temp-regions.txt")
fixIntervals(wavfile, "temp-regions.txt", maxWavDuration)

In [None]:
# Read back the start and end time of every segment listed in temp-regions.txt
timeStart, timeEnd = extractStartEndTimes('temp-regions.txt')

In [None]:
#============================================================
# Separate the big file into smaller wav files
#============================================================

%mkdir wavs

filenames = []
print(len(timeStart))


points = []
for start, end in zip(timeStart, timeEnd):
  i = len(points)-1
  #tempName = "out" + addZerosToNumber(i+1,len(timeStart)) + ".wav"
  #tempName2 = "out" + addZerosToNumber(i+2,len(timeStart)) + ".wav"
  if points and points[-1] == start:
    points.append(end)
    #filenames.append(tempName)
  else:
    points.append(start)
    points.append(end)
    #filenames.append(tempName2)

pointSeq = ','.join(str(t) for t in points)
inFileName = wavfile

zeros = countDigits(len(timeStart))
outFileName = "wavs/out" + f'-%0{zeros}d.wav'
!ffmpeg -y -i "$inFileName" -f segment -ac 1 -ar 16000 -async 1 -segment_times $pointSeq $outFileName

In [None]:
# Work out, for every voice region, the name of the wav chunk that holds it
outwavNames = []

for segStart, segEnd in zip(timeStart, timeEnd):
  segmentIndex = findSegmentIndex(pointSeq, segStart, segEnd)
  paddedIndex = addZerosInFrontOfNumber(str(segmentIndex), len(timeStart))
  outwavNames.append("/content/wavs/out-" + paddedIndex + ".wav")

print(outwavNames)

In [None]:
# files-and-times.txt: one "start<TAB>end<TAB>chunk path" line per segment,
# without a newline after the last line
outputName = "\n".join(
  str(timeStart[k]) + "\t" + str(timeEnd[k]) + "\t" + chunkName
  for k, chunkName in enumerate(outwavNames)
)

# sample.csv: header plus one "chunk path, <blank sentence>" line per
# segment; the slice drops the final newline
transcribeFile = ("path,sentence\n" + "".join(chunkName + ", \n" for chunkName in outwavNames))[:-1]

with open("sample.csv", 'w', encoding='utf-8') as csvOut:
  csvOut.write(transcribeFile)
with open("files-and-times.txt", 'w', encoding='utf-8') as timesOut:
  timesOut.write(outputName)

## (4) Install decoding packages

This should take about 2 minutes. You only need to run it once per session.

In [None]:
# Install the packages needed for decoding, pinned to fixed versions.
# KenLM is installed from the master branch archive of its GitHub repository.
!pip install datasets==2.15.0
!pip install transformers==4.28.0
!pip install pandas==1.5.3
!pip install pyctcdecode==0.3.0
!pip install librosa==0.11.0
!pip install jiwer==3.1.0
!pip install https://github.com/kpu/kenlm/archive/master.zip

In [None]:
from datetime import datetime
# Remember the date and time at which this cell ran, as a string
currentDateAndTime = datetime.now()
startTime = str(currentDateAndTime)

import os

import transformers
import datasets
import torch

import numpy
import numpy as np

import pandas
import pandas as pd

from datasets import load_dataset, load_metric
from datasets import Dataset
from datasets import ClassLabel
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from transformers import Wav2Vec2ForCTC
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import TrainingArguments
from transformers import Trainer
from transformers import Wav2Vec2ProcessorWithLM

import pyctcdecode
from pyctcdecode import build_ctcdecoder

import random
import re
import json

import torchaudio
import librosa

from jiwer import wer
import statistics

from multiprocessing import get_context
import kenlm

## (5) Decode the file and transcribe it

In [None]:
#============================================================================
# Load CSV files and prepare dataset
#============================================================================

# sample.csv lists the wav chunks to transcribe, with blank sentences
dataTest = pd.read_csv("sample.csv")

dataTest.head()

# Two copies: one is transformed below, the other keeps the original paths
common_voice_test = Dataset.from_pandas(dataTest)
common_voice_test_transcription = Dataset.from_pandas(dataTest)

# Punctuation to strip from the sentences. The curly quotes and the
# replacement character are written as \u escapes (understood by re) so the
# pattern cannot be damaged by a wrong file encoding; the raw string avoids
# invalid-escape warnings.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\u201c\%\u2018\u201d\ufffd]'

def remove_special_characters(batch):
    """Remove punctuation from the sentence, lowercase it and add a final space."""
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
    return batch

common_voice_test = common_voice_test.map(remove_special_characters)

def extract_all_chars(batch):
    """Join all the sentences of the batch and list the distinct characters in them."""
    all_text = " ".join(batch["sentence"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

#============================================================================
# Load model
#============================================================================

# The acoustic model comes from the checkpoint folder and is moved to the
# chosen device; the processor comes from the top-level model folder.
model = Wav2Vec2ForCTC.from_pretrained(pathCheckpoint).to(typeProcessor)
processor = Wav2Vec2Processor.from_pretrained(modelPath)

# Vocabulary sorted by token id, with lowercased tokens, used as the
# decoder's label list
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

# CTC decoder that uses the KenLM language model found earlier
tempDecoder = build_ctcdecoder(
	labels=list(sorted_vocab_dict.keys()),
	kenlm_model_path=filenameCorrectKenlmModel,
)

# Processor that combines the feature extractor, the tokenizer and the
# language-model decoder
processor_with_lm = Wav2Vec2ProcessorWithLM(
	feature_extractor=processor.feature_extractor,
	tokenizer=processor.tokenizer,
	decoder=tempDecoder
)

#============================================================================
# Preprocess data
#============================================================================

def speech_file_to_array_fn(batch):
    """Load the audio file of one example and store it in the example.

    Adds "speech" (the first row of the loaded audio, as a NumPy array;
    presumably the first channel), "sampling_rate" and "target_text"
    (a copy of "sentence").
    """
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array[0].numpy()
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["sentence"]
    return batch

# The original columns are dropped; only the ones added above remain
common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)

def resample(batch):
    """Resample ``batch["speech"]`` from 48 kHz to 16 kHz.

    The input is assumed to be sampled at 48 kHz. Returns the same batch
    with the resampled audio and "sampling_rate" set to 16000.
    """
    # The sampling rates must be passed by keyword: recent librosa versions
    # (including the 0.11.0 installed in this notebook) reject them as
    # positional arguments
    batch["speech"] = librosa.resample(np.asarray(batch["speech"]), orig_sr=48_000, target_sr=16_000)
    batch["sampling_rate"] = 16_000
    return batch

def prepare_dataset(batch):
    """Turn a batch of audio and text into model inputs and labels.

    Adds "input_values" (the processor's output for the audio) and
    "labels" (the processor's ids for "target_text").
    """
    # check that all files have the correct sampling rate
    # (the assertion only checks that every example in the batch has the
    # same rate; it does not compare it with the feature extractor's rate)
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True).input_values

    # Inside this context the processor is applied to the text, giving the label ids
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

# Processed in batches of 8 examples, using 4 processes
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)

#===============================================================================
# Transcribe files
#===============================================================================

prediction = []
predictionLM = []
reference = []
paths = []

# A single worker pool is shared by all the segments instead of forking a
# new one per segment; it is created after processor_with_lm exists.
# torch.no_grad() turns off gradient tracking, which inference does not
# need and which takes up extra memory.
with get_context("fork").Pool(processes=4) as pool, torch.no_grad():
    for i in range(0,len(common_voice_test)):

        input_dict = processor_with_lm(common_voice_test[i]["input_values"], sampling_rate=16_000, return_tensors="pt", padding=True)
        logits = model(input_dict.input_values.to(typeProcessor)).logits

        # Decode with the language model; one segment is decoded at a time,
        # so the first (and only) text is the transcription
        predictionlm = processor_with_lm.batch_decode(logits.cpu().detach().numpy(), pool).text[0]
        predictionLM.append(predictionlm)
        print(predictionlm)

        # Keep the chunk's file name (without folders) to match it with its times later
        path = common_voice_test_transcription[i]["path"]
        path = path.split("/")[-1]
        paths.append(path)

#===============================================================================
# Save transcriptions
#===============================================================================

tsvTimes = "files-and-times.txt"

# Each line has: start time, end time and path of the wav chunk
with open(tsvTimes, 'r', encoding='utf-8') as file:
    linesTSV = [line.strip() for line in file]

# Group the transcriptions by chunk file name, so that each line of the
# times file is matched with a direct lookup
predictionsByChunk = {}
for chunkName, chunkText in zip(paths, predictionLM):
    predictionsByChunk.setdefault(chunkName, []).append(chunkText)

outputRows = ["start\tend\ttranscription"]

for l in linesTSV:

    l = l.split("\t")
    tsvWaveChunk = os.path.basename(l[2])

    for chunkText in predictionsByChunk.get(tsvWaveChunk, []):
        outputRows.append(l[0] + "\t" + l[1] + "\t" + chunkText)

# Rows are separated by newlines, without a newline after the last one
outputTranscriptions = "\n".join(outputRows)

# Swap the extension whatever its letter case is; a plain replace of ".wav"
# would leave a name such as "x.WAV" untouched and the transcriptions
# would overwrite the audio file
tsvOutputputFilename = os.path.splitext(os.path.basename(wavfile))[0] + ".tsv"

# NOTE(review): the introduction says that results go to the tsv-outputs
# folder, but this cell writes to the current directory; confirm whether a
# later step copies the file there
with open(tsvOutputputFilename, 'w', encoding='utf-8') as file:
    file.write(outputTranscriptions)

print("\nThe results are stored in: \n" + tsvOutputputFilename)