In [None]:
import logging
import os
import re
import subprocess
import sys
import tempfile
from itertools import islice
from pathlib import Path
from types import SimpleNamespace
from typing import Optional

import torch
import torchaudio
import librosa
import numpy as np
from openai import OpenAI

from nemo.collections.asr.models.msdd_models import NeuralDiarizer

from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from deepmultilingualpunctuation import PunctuationModel

from utils import (
    create_config,
    process_language_arg,
    get_realigned_ws_mapping_with_punctuation,
    get_sentences_speaker_mapping,
    get_speaker_aware_transcript,
    get_words_speaker_mapping,
    langs_to_iso,
    process_language_arg,
    punct_model_langs,
    write_srt,
)

# Notebook parameters (stand-ins for the original script's argparse inputs)
AUDIO_PATH = "../data/audio/call_1.mp3"
# Gates the Demucs vocal-isolation step in a later cell.
STREAMING = True
# NOTE(review): SUPPRESS_NUMERALS and `mtypes` below are not referenced by any cell in this notebook.
SUPPRESS_NUMERALS = False
# Used for language validation and the summary print; transcription itself goes through the OpenAI API.
MODEL_NAME = "medium"
BATCH_SIZE = 8  # batch size passed to generate_emissions in the alignment cell
LANGUAGE = None
TEMP_PATH = "temp_outputs"
ALIGN_SR = 16000  # sample rate (Hz) the audio is loaded at for alignment

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# precision map: device -> compute type
mtypes = dict(cpu="int8", cuda="float16")

# Normalize/validate LANGUAGE the same way the original script did.
# NOTE(review): the transcription cell reads a separate lowercase `language` variable.
LANGUAGE = process_language_arg(LANGUAGE, MODEL_NAME)

print(f"device={DEVICE}  model={MODEL_NAME}  batch_size={BATCH_SIZE}  language={LANGUAGE}")

In [None]:
# Language hint for the OpenAI transcription call (which passes `language or None`, i.e. None
# lets the service auto-detect). Honour the LANGUAGE config value when it is set; otherwise keep
# the previous hard-coded Slovak default.
language = LANGUAGE or "sk"

In [None]:
if STREAMING:
    # Isolate vocals from the rest of the audio with Demucs (--two-stems=vocals).
    # Run it with the kernel's own interpreter so the same environment is used, and pass an
    # argument list (no shell) so paths containing quotes or spaces are passed verbatim.
    demucs_cmd = [
        sys.executable, "-m", "demucs.separate",
        "-n", "htdemucs",
        "--two-stems=vocals",
        AUDIO_PATH,
        "-o", TEMP_PATH,
    ]
    return_code = subprocess.run(demucs_cmd).returncode

    # Expected location of the vocals stem: <TEMP_PATH>/htdemucs/<input stem>/vocals.wav
    vocals_path = os.path.join(
        TEMP_PATH,
        "htdemucs",
        os.path.splitext(os.path.basename(AUDIO_PATH))[0],
        "vocals.wav",
    )

    # Also fall back when the stem is missing despite a zero exit code.
    if return_code != 0 or not os.path.exists(vocals_path):
        logging.warning(
            "Source splitting failed, using original audio file. "
            "Set STREAMING=False to skip this step."
        )
        vocal_target = AUDIO_PATH
    else:
        vocal_target = vocals_path
else:
    vocal_target = AUDIO_PATH

print("vocal_target:", vocal_target)


In [None]:
# OpenAI API client used by the transcription call in this cell.
# NOTE(review): no credentials are passed explicitly; presumably the SDK reads them from the
# environment (e.g. OPENAI_API_KEY) — never hardcode a key here.
client = OpenAI()

def _as_audio_file(vocal_target) -> str:
    """
    Accepts either a file path (str) or a NumPy float waveform (mono).
    Returns a filesystem path to a temporary WAV if needed.
    """
    if isinstance(vocal_target, str) and os.path.exists(vocal_target):
        return vocal_target

    # If it's a numpy array, write a temp WAV (16 kHz mono PCM16)
    if isinstance(vocal_target, np.ndarray):
        import soundfile as sf  # pip install soundfile
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        # assume float32/-1..1; resample if you need a specific rate
        sf.write(tmp.name, vocal_target, 16000, subtype="PCM_16")
        tmp.close()
        return tmp.name

    raise ValueError("vocal_target must be a path or a NumPy waveform.")

audio_path = _as_audio_file(vocal_target)

# Send the (possibly vocal-isolated) audio to OpenAI speech-to-text.
# A None language lets the service auto-detect.
with open(audio_path, "rb") as audio_file:
    resp = client.audio.transcriptions.create(
        model="gpt-4o-transcribe",  # or "gpt-4o-mini-transcribe"
        file=audio_file,
        language=language or None,
    )

# Both attributes are read defensively: a missing/empty text becomes "", a missing language None.
# NOTE(review): confirm whether this model/response format returns `language` at all.
full_transcript = getattr(resp, "text", "") or ""
detected_language: Optional[str] = getattr(resp, "language", None)

print("Detected language:", detected_language if detected_language else "(unknown)")
print("Transcript preview:", full_transcript[:200])

In [None]:
# Re-load the transcribed audio as mono at ALIGN_SR, as float32 for the alignment model.
audio_waveform, sr = librosa.load(audio_path, mono=True, sr=ALIGN_SR)
audio_waveform = audio_waveform.astype(dtype=np.float32)

# Pipeline language: service-detected > requested > English, reduced to its primary
# subtag and lower-cased (e.g. "pt-BR" -> "pt").
info = SimpleNamespace(
    language=(detected_language or language or "en").split("-")[0].lower()
)


In [None]:
# Fail fast, before loading the alignment model and computing emissions: an empty transcript
# has nothing to align, and a language missing from langs_to_iso would otherwise surface only
# as a bare KeyError after the expensive work is done.
if not full_transcript.strip():
    raise ValueError("Transcript is empty; nothing to align. Check the transcription cell output.")
if info.language not in langs_to_iso:
    raise KeyError(
        f"Language {info.language!r} has no entry in langs_to_iso; cannot align the transcript."
    )

# Half precision on GPU, full precision on CPU.
dtype = torch.float16 if DEVICE == "cuda" else torch.float32

alignment_model, alignment_tokenizer = load_alignment_model(DEVICE, dtype=dtype)

# Prefer the model's own `.device`; only fall back to its first parameter's device when absent.
align_dev = getattr(alignment_model, "device", None)
if align_dev is None:
    align_dev = next(alignment_model.parameters()).device

emissions, stride = generate_emissions(
    alignment_model,
    torch.from_numpy(audio_waveform).to(device=align_dev, dtype=dtype),
    batch_size=BATCH_SIZE,
)

# The steps below only need the emissions and tokenizer; release the model's memory.
del alignment_model
torch.cuda.empty_cache()

tokens_starred, text_starred = preprocess_text(
    full_transcript,
    romanize=True,
    language=langs_to_iso[info.language],
)

segments, scores, blank_token = get_alignments(
    emissions,
    tokens_starred,
    alignment_tokenizer,
)

spans = get_spans(tokens_starred, segments, blank_token)

word_timestamps = postprocess_results(text_starred, spans, stride, scores)
print("Aligned words:", len(word_timestamps))

In [None]:
os.makedirs(TEMP_PATH, exist_ok=True)

# NeMo diarizes TEMP_PATH/mono_file.wav (presumably referenced by the config create_config builds).
mono_wav = os.path.join(TEMP_PATH, "mono_file.wav")
# Always overwrite: skipping when the file exists would silently diarize a stale
# mono_file.wav left behind by an earlier run on a different AUDIO_PATH.
torchaudio.save(
    mono_wav,
    torch.from_numpy(audio_waveform).unsqueeze(0).float(),
    ALIGN_SR,  # audio_waveform was loaded at ALIGN_SR
)

# Drop a previous prediction so a failed diarization cannot be mistaken for a fresh one by the
# cell that reads pred_rttms/mono_file.rttm.
stale_rttm = os.path.join(TEMP_PATH, "pred_rttms", "mono_file.rttm")
if os.path.exists(stale_rttm):
    os.remove(stale_rttm)

cfg = create_config(TEMP_PATH)
msdd = NeuralDiarizer(cfg=cfg).to(DEVICE)
msdd.diarize()
del msdd
torch.cuda.empty_cache()


In [None]:
# Preconditions: this cell consumes state produced by the earlier cells.
assert "word_timestamps" in globals(), "word_timestamps not defined (run alignment cell)."
assert "info" in globals(), "info (from transcription) not defined."
assert "audio_waveform" in globals(), "audio_waveform not defined."
assert "audio_path" in globals(), "audio_path not set."

# Ensure temp dir + mono wav exist. Resolve the same TEMP_PATH the diarization cell wrote to,
# rather than repeating a hard-coded "temp_outputs" that could drift from the config cell.
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, TEMP_PATH)
mono_wav = os.path.join(temp_path, "mono_file.wav")
assert os.path.exists(mono_wav), f"Expected mono wav at {mono_wav}"

# Locate the RTTM produced by diarization.
rttm_path = os.path.join(temp_path, "pred_rttms", "mono_file.rttm")
if not os.path.exists(rttm_path):
    raise FileNotFoundError(f"RTTM not found at {rttm_path}. Re-run diarization and check create_config(out_dir).")

def _read_rttm_speaker_turns(path):
    """Parse SPEAKER records of an RTTM file into ``[start_ms, end_ms, speaker_id]`` turns.

    Fields are split on any whitespace, so the standard RTTM positions apply:
    3 = onset (s), 4 = duration (s), 7 = speaker name (expected to end in ``_<int>``).
    The previous ``split(" ")`` indices (5, 8, 11) only lined up with one particular
    multi-space padding of those fields; whitespace splitting yields the same values there
    and also works for single-space RTTM.

    Blank lines, ``#`` comments and non-SPEAKER record types are skipped. Times are
    truncated to integer milliseconds; zero/negative-length turns are dropped.
    """
    turns = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            fields = raw_line.split()
            if not fields or fields[0] != "SPEAKER":
                continue
            start_ms = int(float(fields[3]) * 1000)
            end_ms = start_ms + int(float(fields[4]) * 1000)
            speaker_id = int(fields[7].split("_")[-1])
            if end_ms > start_ms:
                turns.append([start_ms, end_ms, speaker_id])
    return turns


speaker_ts = _read_rttm_speaker_turns(rttm_path)

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

def chunked(seq, n):
    """Lazily yield consecutive lists of at most ``n`` items taken from ``seq``.

    The last list may be shorter. Nothing is yielded for an empty ``seq`` or ``n == 0``.
    """
    source = iter(seq)
    # iter(callable, sentinel): keep slicing until a slice comes back empty.
    yield from iter(lambda: list(islice(source, n)), [])

if info.language in punct_model_langs:
    punct_model = PunctuationModel(model="kredor/punctuate-all")
    words_list = [x["word"] for x in wsm]

    # Predict in bounded chunks; items are expected to look like (token, punct_label, score),
    # one per input word — the count is checked below.
    labeled_words = []
    for chunk in chunked(words_list, 220):
        labeled_words.extend(punct_model.predict(chunk))

    if len(labeled_words) != len(wsm):
        # zip() below would silently truncate and pair labels with the wrong words.
        logging.warning(
            f"Punctuation model returned {len(labeled_words)} labels for {len(wsm)} words; "
            "extra items on the longer side are ignored."
        )

    # Sets give exact single-character membership; with plain strings `in` is a substring
    # test, so the "" fallback label below would count as an ending punctuation.
    ending_puncts = frozenset(".?!")
    model_puncts = frozenset(".,;:!?")

    def is_acronym(x):
        """Return a match when ``x`` is a dotted acronym such as "U.S."."""
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled in zip(wsm, labeled_words):
        word = word_dict["word"]
        punct = labeled[1] if isinstance(labeled, (list, tuple)) and len(labeled) > 1 else ""
        # Append predicted sentence-ending punctuation unless the word already ends in
        # punctuation; dotted acronyms are the exception (e.g. "U.S." + "?").
        if (
            word
            and punct in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += punct
            # Only acronym + "." produces "..": drop just the duplicated dot. rstrip(".")
            # would also remove the acronym's own period ("U.S.." -> "U.S").
            if word.endswith(".."):
                word = word[:-1]
            word_dict["word"] = word
else:
    logging.warning(
        f"Punctuation restoration is not available for {info.language} language. Using the original punctuation."
    )

# Post-process the word->speaker mapping and group it into speaker sentences (utils helpers).
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

# Name outputs after the original input (AUDIO_PATH), not `audio_path`: when Demucs stemming
# succeeded, `audio_path` is <TEMP_PATH>/htdemucs/<name>/vocals.wav and the results would
# be buried there as vocals.txt / vocals.srt.
base = os.path.splitext(AUDIO_PATH)[0]
txt_out = Path(f"{base}.txt")
srt_out = Path(f"{base}.srt")

# utf-8-sig prepends a UTF-8 byte-order mark to each output file.
with open(txt_out, "w", encoding="utf-8-sig") as f:
    get_speaker_aware_transcript(ssm, f)

with open(srt_out, "w", encoding="utf-8-sig") as srt:
    write_srt(ssm, srt)

print(f"Done\n Transcript: {txt_out}\n - Subtitles:  {srt_out}")