In [None]:
import json
import os
from typing import List, Tuple, Union

import numpy as np
import torch
import whisperx

# Inference configuration: Whisper runs on the GPU ('cuda') in half precision.
device = 'cuda'
compute_type = 'float16'

# Local directory passed as `download_root` for the Whisper weights.
model_dir = 'whisper_large-v3-turbo'
model = whisperx.load_model(
    'large-v3-turbo', device, compute_type=compute_type, download_root=model_dir
)

# Silero VAD, fetched via torch.hub; `chunk_audio` uses it to find speech regions.
# NOTE(review): trust_repo=True loads code from the GitHub repo without prompting;
# consider pinning a release tag in `repo_or_dir` for reproducibility.
vad_model, vad_utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    trust_repo=True,
    source='github',
)  # type: ignore
# Only get_speech_timestamps is used in this notebook; read_audio and
# collect_chunks are unpacked but never called.
(get_speech_timestamps, _, read_audio, _, collect_chunks) = vad_utils


In [None]:
from whisperx.audio import N_SAMPLES, log_mel_spectrogram


def chunk_audio(
    wav: Union[np.ndarray, torch.Tensor],
    sr: int = 16000,
    silence_ms: int = 1000,  # ≥ 1 s = split
    speech_pad_ms: int = 0,
) -> List[Tuple[int, int]]:
    """Find speech regions in a waveform with the Silero VAD.

    Args:
        wav: mono waveform sampled at ``sr``. ``filter_non_ro`` passes the
            NumPy array returned by ``whisperx.load_audio``, so NumPy input
            must be accepted; presumably Silero's ``get_speech_timestamps``
            converts it to a tensor internally — TODO confirm against the
            silero-vad version in use.
        sr: sample rate of ``wav`` in Hz, forwarded to the VAD.
        silence_ms: minimum pause (ms) that splits two speech regions.
        speech_pad_ms: padding (ms) the VAD adds around each detected region;
            the default 0 keeps the raw VAD boundaries.

    Returns:
        ``(start, end)`` pairs taken from the VAD's ``'start'``/``'end'``
        fields; ``filter_non_ro`` uses them as sample indices into ``wav``.
    """
    speech_ts = get_speech_timestamps(
        wav,
        vad_model,
        sampling_rate=sr,
        min_silence_duration_ms=silence_ms,
        speech_pad_ms=speech_pad_ms,
    )
    return [(ts['start'], ts['end']) for ts in speech_ts]


def _detect_chunk_language(audio_chunk, n_mels):
    """Return Whisper's top ``(language_code, probability)`` for one audio chunk.

    Only the first ``N_SAMPLES`` samples are inspected (presumably Whisper's
    30 s window); shorter chunks are padded up to ``N_SAMPLES`` through the
    ``padding`` argument of ``log_mel_spectrogram``.
    """
    with torch.no_grad():
        segment = log_mel_spectrogram(
            audio_chunk[:N_SAMPLES],
            n_mels=n_mels,
            padding=max(N_SAMPLES - audio_chunk.shape[0], 0),
        )
        encoder_output = model.model.encode(segment)
        results = model.model.model.detect_language(encoder_output)
    lang_token, lang_prob = results[0][0]
    # Presumably tokens look like '<|ro|>'; strip two delimiter chars per side.
    return lang_token[2:-2], lang_prob


def filter_non_ro(
    path: str,
    short_chunk_sec: float = 2.0,
    non_ro_prob_threshold: float = 0.5,
):
    """Write a per-sample keep/drop mask for ``path`` that marks non-Romanian speech.

    The audio is split into speech chunks with ``chunk_audio`` (Silero VAD),
    and Whisper's language detector is run on each chunk. A chunk is dropped
    when the top detected language is not ``'ro'`` and either:

    - it is longer than ``short_chunk_sec``, or
    - its language probability is at least ``non_ro_prob_threshold``. Short
      chunks are only dropped when the detector is confident.

    Samples outside every speech chunk are kept. The mask has one bit per
    sample of the 16 kHz audio, is packed with ``np.packbits`` and is written
    to ``<stem>.mask`` next to ``path``.

    Args:
        path: audio/video file readable by ``whisperx.load_audio``.
        short_chunk_sec: chunks at most this long (seconds) use the
            confidence threshold.
        non_ro_prob_threshold: minimum probability of a non-Romanian top
            language for a short chunk to be dropped.
    """
    sr = 16000

    wav = whisperx.load_audio(path, sr=sr)
    keep = np.ones(wav.shape[0], dtype=bool)

    # The mel-bin count depends only on the loaded model, so read it once.
    model_n_mels = model.model.feat_kwargs.get('feature_size')
    n_mels = model_n_mels if model_n_mels is not None else 80

    # Pass sr explicitly so the VAD and this function always agree on it.
    for start, end in chunk_audio(wav, sr=sr):
        lang, lang_prob = _detect_chunk_language(wav[start:end], n_mels)

        if lang == 'ro':
            continue
        is_short = (end - start) / sr <= short_chunk_sec
        if is_short and lang_prob < non_ro_prob_threshold:
            # Too short and not confident enough to discard.
            continue
        keep[start:end] = False

    mask_path = os.path.splitext(path)[0] + '.mask'
    np.packbits(keep).tofile(mask_path)


def transcribe(path: str, batch_size: int = 4):
    """Transcribe ``path`` as Romanian, with masked-out regions silenced.

    Reads the ``<stem>.mask`` bitmask written by ``filter_non_ro`` and zeroes
    every sample whose bit is 0. It then runs the whisperx pipeline with the
    language forced to ``'ro'`` and writes the result dict to ``<stem>.json``.
    """
    stem = os.path.splitext(path)[0]
    samples = whisperx.load_audio(path, sr=16000)

    packed_bits = np.fromfile(stem + '.mask', dtype=np.uint8)
    # packbits pads to whole bytes, so trim the unpacked bits to the sample count.
    keep = np.unpackbits(packed_bits)[: samples.shape[0]].astype(bool)
    # Zero the samples instead of cutting them so the array length (and thus
    # every timestamp) stays identical to the source file.
    samples[~keep] = 0.0

    transcript = model.transcribe(samples, batch_size=batch_size, language='ro')
    with open(stem + '.json', 'w', encoding='utf-8') as fh:
        json.dump(transcript, fh, ensure_ascii=False, indent=2)

In [None]:
# Stage 1: build a language mask and a Romanian transcript for every `.seg` video.
for root, _, files in os.walk('.'):
    for file in files:
        if not file.endswith(('.seg.mp4', '.seg.mkv')):
            continue
        video_path = os.path.join(root, file)
        filter_non_ro(video_path)  # writes <stem>.mask
        transcribe(video_path)  # reads <stem>.mask, writes <stem>.json

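# Align

Force-align each transcript with `whisperx.align` to get word-level timestamps. The same regions that were masked during transcription are silenced here too.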


In [None]:
import gc

import torch

# Release the Whisper model before the alignment model is loaded.
# NOTE: after this cell, filter_non_ro / transcribe can no longer be called,
# because they read the global `model`.
# The guard keeps the cell re-runnable (a bare `del model` raises NameError
# the second time).
if 'model' in globals():
    del model

gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()

In [None]:
# Romanian alignment model used by `align`; `model_dir` is a local directory
# for its weights. `metadata` is passed alongside the model to whisperx.align.
model_a_dir = 'whisperx-align_ro'
model_a, metadata = whisperx.load_align_model(
    language_code='ro', device=device, model_dir=model_a_dir
)

In [None]:
def align(path):
    """Word-align the Romanian transcript of ``path`` and save ``<stem>.align.json``.

    The audio is loaded at 16 kHz, and samples whose ``<stem>.mask`` bit is 0
    are zeroed. The segments are read from ``<stem>.json``. ``whisperx.align``
    is then run with the module-level ``model_a``, ``metadata`` and
    ``device``, without character-level alignments.
    """
    stem = os.path.splitext(path)[0]

    samples = whisperx.load_audio(path, sr=16000)
    bits = np.unpackbits(np.fromfile(stem + '.mask', dtype=np.uint8))
    # Trim the byte padding added by packbits, then silence the dropped samples.
    samples[~bits[: samples.shape[0]].astype(bool)] = 0.0

    with open(stem + '.json', encoding='utf-8') as fh:
        segments = json.load(fh)['segments']

    aligned = whisperx.align(
        segments,
        model_a,
        metadata,
        samples,
        device,
        return_char_alignments=False,
    )

    with open(stem + '.align.json', 'w', encoding='utf-8') as fh:
        json.dump(aligned, fh, ensure_ascii=False, indent=2)

In [None]:
# Align every transcribed file that has a mask but no `.align.json` yet.
for root, _, files in os.walk('.'):
    for file in files:
        if not file.endswith('.mask'):
            continue

        mask_path = os.path.join(root, file)
        # Derive sibling paths from the stem. str.replace would also rewrite a
        # '.mask' substring elsewhere in the path (e.g. 'a.masked.mask').
        stem = os.path.splitext(mask_path)[0]

        if os.path.exists(stem + '.align.json'):
            print(f'skipping {mask_path}')
            continue

        # The mask sits next to its source video; find which container exists.
        video_path = next(
            (stem + ext for ext in ('.mp4', '.mkv') if os.path.exists(stem + ext)),
            None,
        )
        if video_path is None:
            print(f'no video found for {mask_path}')
            continue

        align(video_path)
        print(mask_path)

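## Filter hallucinations

Keep only segments that pass the duration, speaking-rate and repetition checks. Kept segments get a `tokens` entry, and segments without one are excluded from the statistics below.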


In [None]:
import json  # noqa: F811
import os
import re

from transformers.models.t5 import T5Tokenizer

# Romanian mT5 tokenizer. `filter_garbage` stores its output under
# segment['tokens'] for every segment it keeps.
tokenizer = T5Tokenizer.from_pretrained(
    'dumitrescustefan/mt5-base-romanian', legacy=False
)


def is_hallucination(text, min_repeats=3, dominance_threshold=0.5, max_phrase_len=64):
    """Detect the looping-phrase pattern typical of Whisper hallucinations.

    The text is lower-cased and split into word tokens. For every phrase
    length from 1 to ``max_phrase_len``, the tokens are scanned for a phrase
    repeated back-to-back at least ``min_repeats`` times. If such a run covers
    at least ``dominance_threshold`` of all tokens, the text is flagged.
    Non-dominant runs are skipped over as a whole.

    Returns:
        True if some single consecutive run dominates the text, else False.
    """
    words = re.findall(r'\b\w+\b', text.lower())
    n_words = len(words)

    def run_length(pos, width):
        """Count back-to-back copies of words[pos:pos + width] starting at pos."""
        pattern = words[pos : pos + width]
        reps = 1
        nxt = pos + width
        while nxt + width <= n_words and words[nxt : nxt + width] == pattern:
            reps += 1
            nxt += width
        return reps

    for width in range(1, max_phrase_len + 1):
        # A run of min_repeats copies must still fit after this position.
        last_start = n_words - width * min_repeats
        pos = 0
        while pos <= last_start:
            reps = run_length(pos, width)
            if reps < min_repeats:
                pos += 1
                continue
            covered = reps * width
            if covered / n_words >= dominance_threshold:
                return True
            pos += covered  # jump past the non-dominant run
    return False


# Legacy cedilla forms (ş ţ Ş Ţ) -> correct Romanian comma-below forms (ș ț Ș Ț).
_RO_CEDILLA_TO_COMMA_BELOW = str.maketrans(
    {'ţ': 'ț', 'ş': 'ș', 'Ţ': 'Ț', 'Ş': 'Ș'}
)


def filter_garbage(
    segments,
    min_duration=2.0,
    min_word_duration=0.1,
    max_word_duration=2.0,
    max_length=256,
):
    """Tag usable aligned segments in place by attaching tokenizer output.

    Nothing is removed from ``segments``. A segment gets a ``'tokens'`` entry
    (the tokenizer's encoding dict) only if all of the following hold:

    - it has at least one aligned word;
    - it lasts at least ``min_duration`` seconds;
    - its average time per word is within
      [``min_word_duration``, ``max_word_duration``] seconds;
    - its normalized text is not flagged by ``is_hallucination``.

    Segments that fail a check get no ``'tokens'`` entry. Later cells treat
    a missing or empty ``'tokens'`` as "discarded".

    Args:
        segments: list of dicts with ``'start'``, ``'end'``, ``'text'`` and
            ``'words'`` keys (presumably ``whisperx.align`` output).
        min_duration: shortest segment kept, in seconds.
        min_word_duration: lower bound on seconds per word (too-fast speech).
        max_word_duration: upper bound on seconds per word (too-slow speech).
        max_length: truncation length passed to the tokenizer.
    """
    for segment in segments:
        words = segment['words']
        if not words:
            continue

        duration = segment['end'] - segment['start']
        if duration < min_duration:
            continue

        word_duration = duration / len(words)
        if not min_word_duration <= word_duration <= max_word_duration:
            continue

        # Only the tokenized text is normalized; segment['text'] is left as-is.
        text = segment['text'].strip().translate(_RO_CEDILLA_TO_COMMA_BELOW)
        if is_hallucination(text):
            continue

        segment['tokens'] = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
        ).data


# Write a `.align.tok.json` next to every `.align.json`, with usable segments
# tagged by filter_garbage.
for root, _, files in os.walk('.'):
    for file in files:
        if not file.endswith('.align.json'):
            continue

        path = os.path.join(root, file)
        out_path = path.replace('.align.json', '.align.tok.json')

        with open(path, encoding='utf-8') as f:
            js = json.load(f)
        segments = js['segments']

        filter_garbage(segments)  # mutates the segment dicts inside `js`

        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(js, f, ensure_ascii=False, indent=2)

        print(out_path)

./digi/scraped/2019/12/05/segment_0.align.tok.json
./digi/scraped/2019/12/05/segment_1.align.tok.json
./digi/scraped/2019/12/10/segment_0.align.tok.json
./digi/scraped/2019/12/03/segment_0.align.tok.json
./digi/scraped/2019/12/23/segment_0.align.tok.json
./digi/scraped/2019/12/13/segment_0.align.tok.json
./digi/scraped/2019/12/13/segment_2.align.tok.json
./digi/scraped/2019/12/13/segment_1.align.tok.json
./digi/scraped/2019/12/18/segment_0.align.tok.json
./digi/scraped/2019/12/20/segment_0.align.tok.json
./digi/scraped/2019/12/20/segment_2.align.tok.json
./digi/scraped/2019/12/20/segment_1.align.tok.json
./digi/scraped/2019/12/16/segment_0.align.tok.json
./digi/scraped/2019/12/16/segment_1.align.tok.json
./digi/scraped/2019/12/11/segment_0.align.tok.json
./digi/scraped/2019/12/27/segment_0.align.tok.json
./digi/scraped/2019/12/27/segment_1.align.tok.json
./digi/scraped/2019/12/02/segment_0.align.tok.json
./digi/scraped/2019/12/02/segment_2.align.tok.json
./digi/scraped/2019/12/02/segme

In [7]:
import json
import os

# Total duration, in hours, of the segments that survived filter_garbage
# (those with a non-empty 'tokens' entry) across all `.align.tok.json` files.
# Named total_seconds rather than `sum` so the builtin is not shadowed.
total_seconds = 0.0

for root, _, files in os.walk('.'):
    for file in files:
        if not file.endswith('.align.tok.json'):
            continue

        with open(os.path.join(root, file), encoding='utf-8') as f:
            js = json.load(f)

        for segment in js.get('segments', []):
            if not segment.get('tokens'):
                continue
            total_seconds += segment['end'] - segment['start']

print(total_seconds / 60 / 60)

1222.6621613889363
