# Data prep

In [2]:
import re
import requests
import string
import os
import subprocess
import multiprocessing as mp
import math
import json
import pathlib
from pathlib import Path
from tqdm import tqdm
from dataclasses import dataclass
from typing import Any, Dict
from pathlib import Path
from itertools import compress
from datetime import timedelta
from collections import Counter


SAMPLING_RATE = 16000

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import datasets
import torch
import jiwer
import numpy as np
import pandas as pd
# from deepcut import tokenize  # Consume too much memory when using with CUDA
from pythainlp.tokenize import word_tokenize as tokenize
from sklearn.model_selection import train_test_split
from transformers import WhisperProcessor, pipeline, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperForConditionalGeneration
from transformers.pipelines.pt_utils import KeyDataset
from datasets import Dataset, load_from_disk
from datasets.features import Audio

  from .autonotebook import tqdm as notebook_tqdm
2023-11-24 12:03:24.012861: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-24 12:03:24.016285: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-24 12:03:24.064177: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-24 12:03:24.064205: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-24 12:03:24.065540: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515

In [3]:
def print_gpu_info():
    """Print the `nvidia-smi` report, or a short notice when no GPU is usable.

    The notice 'Not connected to a GPU' is printed when the `nvidia-smi`
    executable cannot be launched (e.g. it is not installed), when it exits
    with a non-zero status, or when its output contains the word 'failed'.
    """
    try:
        gpu_info = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, text=True)
    except OSError:
        # Raised when the executable is missing or cannot be started; treat it
        # the same as "no GPU" instead of aborting the notebook.
        print('Not connected to a GPU')
        return

    if gpu_info.returncode != 0 or 'failed' in gpu_info.stdout:
        print('Not connected to a GPU')
    else:
        print(gpu_info.stdout)


print_gpu_info()

Fri Nov 24 12:03:26 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    45W / 400W |      3MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

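*The Common Voice data (Thai, corpus v15.0) can be downloaded with the command below. The signed URL is valid for only 12 hours (`X-Goog-Expires=43200`), so generate a fresh link from the Common Voice website before running it.*  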
```
!wget https://storage.googleapis.com/common-voice-prod-prod-datasets/cv-corpus-15.0-2023-09-08/cv-corpus-15.0-2023-09-08-th.tar.gz\?X-Goog-Algorithm\=GOOG4-RSA-SHA256\&X-Goog-Credential\=gke-prod%40moz-fx-common-voice-prod.iam.gserviceaccount.com%2F20231120%2Fauto%2Fstorage%2Fgoog4_request\&X-Goog-Date\=20231120T211937Z\&X-Goog-Expires\=43200\&X-Goog-SignedHeaders\=host\&X-Goog-Signature\=7b63c1ccdb27c7a2f2b1b5e59422ab38668543f242283238b92d39552aa12a2686ba413b29107e71c8fa75d850decf8d5f9e1f5c0f6b72da42154cf478ebe296f8445d1744267a3ad40391433517c9ad8735b26cfe5c53e777feffac2a71d54ee7ce47cb1c580449340a84d066271a57a2beba416de0d7e897ad7bd99f13e68e0d8a1a2cc1c2dbf2341740fd167e1d6572d84b23c9daee4139dd35cc8f827db052a05021ca1c25549baa18c823ed1c25347cd10972451718ac13c73b656bbc69134ebbcce7206ad38c6e3611ac59881e8a630abbdf7390b689bb74d7fe35cb80366742d76cf5a6eb462e6da408dd2bb05a97cd8b89a4110479d62f9dc6f84c4e
```

In [4]:
# Device config
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Half precision on GPU, full precision on CPU. In the visible cells this value is
# only referenced from a commented-out `torch_dtype=` argument when loading the model.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Data config
# NOTE(review): same value as SAMPLING_RATE from the import cell; the functions visible
# in this notebook read SAMPLING_RATE, not this constant — consider keeping only one.
AUDIO_SAMPLING_RATE = 16_000
# Passed to WhisperForConditionalGeneration.from_pretrained.
# Presumably a local fine-tuned checkpoint directory, resolved relative to the working directory — confirm.
MODEL_PATH_OR_URL = "20231124-model-backup/checkpoint-1000"


# model config
CHUNK_LENGTH = 20  # passed to the ASR pipeline as chunk_length_s
NUM_BEAMS = 2  # passed to from_pretrained as num_beams
BATCH_SIZE = 8  # batch_size used when calling the pipeline in video_transcribe

# resource config
# parents[2] is the third ancestor of the current working directory, so the notebook has to
# be started from a folder three levels below the project root (IndexError if cwd is shallower).
PRJ_ROOT = Path.cwd().parents[2]
DATA_PATH = PRJ_ROOT / "notebooks" / "whisper-v3" / "data"
# VAD chunk cache lives next to the notebook's working directory.
VAD_CACHE_PATH = Path.cwd() / "vad-caches"

In [5]:
# Recursively gather every .mp4 under the original-videos folder, in sorted path order.
target_videos = sorted(
    (DATA_PATH / "Opp Day SRTs" / "Original Videos").rglob("*.mp4")
)

In [6]:
# Show the first discovered video as a quick check that the glob found files.
target_videos[0]

PosixPath('/home/jupyter/set-speechtotext-poc/notebooks/whisper-v3/data/Opp Day SRTs/Original Videos/[TH] Oppday Q2_2023 IP บมจ. อินเตอร์ ฟาร์มา.mp4')

In [7]:
def remove_punct(s: str) -> str:
    """Return `s` with every character of `string.punctuation` deleted.

    Whitespace and all other characters (including Thai text) are kept as is.
    """
    punct_class = "[" + re.escape(string.punctuation) + "]"
    return re.compile(punct_class).sub("", s)

remove_punct("ไหน?ลองซิ! ...")

'ไหนลองซิ '

In [8]:
def en_preprocess(text: str) -> str:
    """Strip speaker labels and the filler word "uh" from English text.

    Removes "Q:" / "A:" labels and the standalone word "uh". Both patterns are
    anchored on word boundaries so that the same letters inside longer words
    (for example "huh" or "FAQ:") are left untouched. Matching is
    case-sensitive and the surrounding whitespace is not collapsed.
    """
    return re.sub(r"\b[QA]:|\buh\b", "", text)

def th_preprocess(text: str) -> str:
    """Delete Thai polite particles and the hesitation filler from `text`.

    Any run of one or more of ค่ะ, ครับ, นะครับ, นะคะ or เอ่อ is removed.
    """
    particles = re.compile("((ค่ะ)|((นะ)?ครับ)|(นะคะ)|เอ่อ)+")
    # The thaispellcheck autocorrect step stays disabled:
    # text = thaispellcheck.check(text, autocorrect=True)
    return particles.sub("", text)

def pred_postprocess(listtext: list[str]) -> list[str]:
    """Clean a list of predicted tokens.

    Drops every single-space token and passes the remaining tokens to
    `remove_longest_repeating_words`. The caller's list is left unmodified.

    NOTE(review): `remove_longest_repeating_words` is not defined in the
    visible cells of this notebook; it has to be defined or imported before
    this function is called.
    """
    tokens = [tok for tok in listtext if tok != " "]
    return remove_longest_repeating_words(tokens)

def tokenize_text(text: str, pred = False) -> list:
    """Split `text` into tokens using a language-dependent pipeline.

    - Non-string input yields an empty list.
    - Text detected as anything other than English is cleaned with
      `th_preprocess` and tokenized with `tokenize` (the pythainlp
      `word_tokenize` imported at the top of the notebook, which replaced
      deepcut there).
    - English text is normalized with `EnglishTextNormalizer`, cleaned with
      `en_preprocess` and split on single spaces.

    When `pred` is true the tokens are additionally passed through
    `pred_postprocess`.

    NOTE(review): `detect` and `EnglishTextNormalizer` are not imported in the
    visible cells; they must be in scope before this is called. Because language
    detection failures are swallowed (best effort), a missing `detect` silently
    routes every input down the English branch.
    """
    if not isinstance(text, str):
        return []

    try:
        lang = detect(text)
    except Exception:
        # Best effort: treat the text as English when detection fails.
        lang = 'en'

    if lang != 'en':
        splited_text = tokenize(th_preprocess(text))
    else:
        eng_normalizer = EnglishTextNormalizer()
        splited_text = en_preprocess(eng_normalizer(text)).split(" ")

    return pred_postprocess(splited_text) if pred else splited_text

def perform_vad(src: pathlib.Path, dest_dir: pathlib.Path, use_cache=True) -> list:
    """
    Perform Voice Activity Detection (VAD) on the given source file and save
    the resulting audio chunks below the destination directory.

    Chunks are written to `dest_dir/<src.stem>/<src.stem>_<idx>.wav` and a
    manifest describing them to `dest_dir/<src.stem>/<src.stem>_chunk.json`.
    When `use_cache` is true and that manifest already exists, it is loaded
    and returned without loading or running the VAD model.

    Parameters:
    - src (pathlib.Path): Path to the source audio/video file.
    - dest_dir (pathlib.Path): Directory under which the per-file chunk folder
      is created (missing parent directories are created as well).
    - use_cache (bool): Reuse an existing manifest instead of recomputing.

    Returns:
    - list[dict]: One dict per detected speech segment with the keys
      "start" and "end" (sample offsets at SAMPLING_RATE), "idx",
      "text" (empty string) and "fname" (path of the saved chunk).

    Raises:
    - RuntimeError: If the silero-vad model cannot be loaded or `src` cannot
      be read.
    """
    vad_pth = dest_dir / src.stem
    tempchunk = vad_pth / (src.stem + "_chunk.json")

    # A cache hit requires the manifest itself, not just the folder: the folder
    # can exist without a manifest if an earlier run was interrupted.
    if use_cache and tempchunk.is_file():
        with open(tempchunk, "r", encoding="utf-8") as fp:
            return json.load(fp)

    # Creates dest_dir too when it does not exist yet.
    vad_pth.mkdir(parents=True, exist_ok=True)

    # Try to load the VAD model and utilities
    try:
        smodel, utils = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            onnx=False,
        )
        (
            get_speech_timestamps,
            save_audio,
            read_audio,
            VADIterator,
            collect_chunks,
        ) = utils
    except Exception as e:
        raise RuntimeError(f"Failed to load silero-vad model and utilities. Error: {e}") from e

    # Try to read the audio
    try:
        wav = read_audio(src, sampling_rate=SAMPLING_RATE)
    except Exception as e:
        raise RuntimeError(f"Failed to read audio from {src}. Error: {e}") from e

    # Get speech timestamps (sample offsets, since return_seconds=False)
    st = get_speech_timestamps(
        wav,
        smodel,
        threshold=0.65,
        sampling_rate=SAMPLING_RATE,
        min_speech_duration_ms=500,
        min_silence_duration_ms=100,
        window_size_samples=1536,
        return_seconds=False,
    )

    total_samples = list(wav.size())[0]
    # Padding added around each detected segment, converted from ms to samples.
    lead_pad = int(120 * SAMPLING_RATE / 1000)
    tail_pad = int(60 * SAMPLING_RATE / 1000)

    chunklist = []
    for i, s in enumerate(st):
        fname = vad_pth / (src.stem + f"_{i:05d}.wav")

        # Clamp the padded segment to the bounds of the recording.
        start = max(s["start"] - lead_pad, 0)
        end = min(s["end"] + tail_pad, total_samples - 1)

        chunklist.append(
            {"start": start, "end": end, "idx": i, "text": "", "fname": str(fname)}
        )

    # Save audio chunks to the destination directory
    for c in chunklist:
        save_audio(
            c["fname"],
            collect_chunks([c], wav),
            sampling_rate=SAMPLING_RATE,
        )

    # Write the manifest last, so that its presence implies every chunk file
    # was saved and the cache is complete.
    with open(tempchunk, "w", encoding="utf-8") as fp:
        json.dump(chunklist, fp)

    return chunklist

In [9]:
def video_transcribe(
    fp: Path, pipe, vad_cache_dir=VAD_CACHE_PATH, use_cache=True,
    language="th",
) -> list[str]:
    """Transcribe one video file chunk by chunk.

    The file is split into speech chunks with `perform_vad`; the chunk files
    are fed to `pipe` in ascending "idx" order with the given `language` and
    the "transcribe" task, BATCH_SIZE files at a time.

    Returns the "text" field of each pipeline result, in chunk order.
    """
    vad_chunks = perform_vad(fp, vad_cache_dir, use_cache)
    ordered = sorted(vad_chunks, key=lambda chunk: chunk["idx"])

    wav_paths = []
    for chunk in ordered:
        wav_paths.append(chunk["fname"])

    decode_options = {"language": language, "task":"transcribe"}
    outputs = pipe(wav_paths, generate_kwargs=decode_options, batch_size=BATCH_SIZE)
    return [out["text"] for out in outputs]

In [10]:
# Run VAD on the first video; use_cache defaults to True, so an existing cache
# under VAD_CACHE_PATH is reused instead of recomputing.
chunks = perform_vad(target_videos[0], VAD_CACHE_PATH)

In [11]:
# Inspect the first chunk record returned by perform_vad.
chunks[0]

{'start': 201888,
 'end': 219552,
 'idx': 0,
 'text': '',
 'fname': '/home/jupyter/set-speechtotext-poc/notebooks/whisper-v3/local-version/vad-caches/[TH] Oppday Q2_2023 IP บมจ. อินเตอร์ ฟาร์มา/[TH] Oppday Q2_2023 IP บมจ. อินเตอร์ ฟาร์มา_00000.wav'}

## Preprocessor prep

In [12]:
# Tokenizer and feature extractor come from the base openai/whisper-large-v3 repo,
# configured for Thai transcription.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-large-v3",
    language="Thai",
    task="transcribe",
  )

# Model weights are loaded from MODEL_PATH_OR_URL (set in the config cell).
# The torch_dtype override is commented out, so the dtype is left to from_pretrained.
# NOTE(review): num_beams is forwarded as a from_pretrained keyword — confirm that it
# actually takes effect for generation inside the pipeline below.
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_PATH_OR_URL,
    # torch_dtype=torch_dtype,
    num_beams=NUM_BEAMS,
)

# model.config.forced_decoder_ids = None
# model.config.suppress_tokens = []
# model.enable_input_require_grads()
# model.config.dropout=0.0

# ASR pipeline that combines the loaded model with the processor's tokenizer and
# feature extractor, placed on `device`, with chunk_length_s set to CHUNK_LENGTH.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model, tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=device,
    chunk_length_s=CHUNK_LENGTH,
  )

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.16s/it]


In [13]:
# Transcribe the first video end to end (VAD chunks -> ASR pipeline); one string per chunk.
txts = video_transcribe(target_videos[0], pipe)

# Metrics

In [14]:
# Regex removed from the raw text before tokenizing in hack_wer / wer:
# an optional นะ followed by คะ or ครับ, or the fillers เอ่อ / อ่า.
# NOTE(review): ค่ะ (with the tone mark), which th_preprocess strips, is not matched here,
# and because the substitution runs on the raw string the same character sequences
# inside longer words are removed too — confirm both are intended.
CLEAN_PATTERNS = "((นะ)?(คะ|ครับ)|เอ่อ|อ่า)"
# Tokens discarded after tokenization: the empty string and a single space.
REMOVE_TOKENS = {"", " "}

def hack_wer(
    hypothesis: str,
    reference: str,
    debug=False,
) -> float:
    """
    Word error rate for unsegmented Thai text, computed by jiwer.

    Both strings are stripped of CLEAN_PATTERNS, tokenized into words,
    filtered against REMOVE_TOKENS and re-joined with single spaces, so that
    the ordinary whitespace-based `jiwer.wer` can score them.

    With `debug` set, the reference tokens and then the hypothesis tokens
    are printed before scoring.
    """
    def _to_words(raw: str) -> list:
        words = []
        for token in tokenize(re.sub(CLEAN_PATTERNS, "", raw)):
            if token not in REMOVE_TOKENS:
                words.append(token)
        return words

    refs = _to_words(reference)
    hyps = _to_words(hypothesis)

    if debug:
        print(refs)
        print(hyps)

    return jiwer.wer(" ".join(refs), " ".join(hyps))


def isd_np(preds: list[str], actuals: list[str], debug=True) -> int:
    """Token-level edit distance between `preds` and `actuals`.

    Counts the minimum number of insertions, substitutions and deletions
    needed to turn one token list into the other, using a dynamic-programming
    table with one row per reference token and one column per predicted token
    (plus the empty-prefix row and column).

    Parameters:
    - preds: predicted tokens.
    - actuals: reference tokens.
    - debug: when true (the default), print the table row by row.

    Returns the distance as a NumPy integer scalar.
    """
    n_rows, n_cols = len(actuals) + 1, len(preds) + 1

    # int64 rather than int16, so distances above 32767 cannot overflow.
    dp = np.zeros((n_rows, n_cols), dtype="int64")
    dp[0, :] = np.arange(n_cols)  # empty reference: insert every predicted token
    dp[:, 0] = np.arange(n_rows)  # empty prediction: delete every reference token

    for row in range(1, n_rows):
        for col in range(1, n_cols):
            # Matching tokens cost nothing on the diagonal only; stepping from
            # the cell above or to the left is always one edit.
            substitution_cost = 0 if preds[col - 1] == actuals[row - 1] else 1
            dp[row, col] = min(
                dp[row - 1, col] + 1,
                dp[row, col - 1] + 1,
                dp[row - 1, col - 1] + substitution_cost,
            )

    if debug: print(*dp, sep="\n")

    return dp[-1, -1]


def wer(pred: str, actual: str, debug=False, **kwargs) -> float:
    """Word error rate of `pred` against `actual`, using `isd_np`.

    Both strings are stripped of CLEAN_PATTERNS, tokenized, and filtered
    against REMOVE_TOKENS; the edit distance between the token lists is then
    divided by the number of reference tokens.

    Parameters:
    - pred: hypothesis text.
    - actual: reference text.
    - debug: when true, print both token lists and let `isd_np` print its
      table. Optional; defaults to False.
    - **kwargs: forwarded to `isd_np`.

    Returns 0.0 when both token lists are empty and `inf` when only the
    reference is empty, since the ratio is undefined without reference words.
    """
    refs = tokenize(re.sub(CLEAN_PATTERNS, "", actual))
    hyps = tokenize(re.sub(CLEAN_PATTERNS, "", pred))

    actuals = [r for r in refs if r not in REMOVE_TOKENS]
    preds = [h for h in hyps if h not in REMOVE_TOKENS]
    if debug: print(f"{preds}\n{actuals}")

    if not actuals:
        # Nothing to divide by: perfect only if the hypothesis is empty as well.
        return 0.0 if not preds else float("inf")

    err = isd_np(preds, actuals, debug=debug, **kwargs)
    return err / len(actuals)

In [15]:
# Sanity-check both WER implementations on the same hypothesis/reference pair.
# With debug=True each call prints its token lists first, and wer additionally
# prints the edit-distance table via isd_np.
print(hack_wer("สวัสดีครับอิอิ ผมไม่เด็กแล้วนะครับ จริงๆนะ", "สวัสดีครับอุอุ ผมโตแล้วครับ จริงๆนะ", debug=True))
print(wer("สวัสดีครับอิอิ ผมไม่เด็กแล้วนะครับ จริงๆนะ", "สวัสดีครับอุอุ ผมโตแล้วครับ จริงๆนะ", debug=True))

['สวัสดี', 'อุ', 'อุ', 'ผม', 'โต', 'แล้ว', 'จริงๆ', 'นะ']
['สวัสดี', 'อิอิ', 'ผม', 'ไม่', 'เด็ก', 'แล้ว', 'จริงๆ', 'นะ']
0.5
['สวัสดี', 'อิอิ', 'ผม', 'ไม่', 'เด็ก', 'แล้ว', 'จริงๆ', 'นะ']
['สวัสดี', 'อุ', 'อุ', 'ผม', 'โต', 'แล้ว', 'จริงๆ', 'นะ']
[0 1 2 3 4 5 6 7 8]
[1 0 1 2 3 4 5 6 7]
[2 1 1 2 3 4 5 6 7]
[3 2 2 2 3 4 5 6 7]
[4 3 3 2 3 4 5 6 7]
[5 4 4 3 3 4 5 6 7]
[6 5 5 4 4 4 4 5 6]
[7 6 6 5 5 5 5 4 5]
[8 7 7 6 6 6 6 5 4]
0.5


In [16]:
# Persist the transcription, one chunk per line. The encoding is set explicitly
# because the text is Thai and the platform default encoding may not be UTF-8.
with open("oppday-finetune-preds.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(txts))

# Inference

In [24]:
# preds = pipe([str(p) for p in common_voice_test.full_path.tolist()], generate_kwargs={"language":"<|th|>", "task":"transcribe"}, batch_size=64)

# Eval