### Setup

#### colab setup

In [None]:
%%bash
pip install -q \
    torch>=2.9.1 \
    torchaudio>=2.9.1 \
    transformers>=4.57.0 \

In [1]:
# check cuda version colab and set as variable
import torch
print(torch.cuda.is_available())
print(torch.version.cuda)
cuda_version = torch.version.cuda


True
12.8


In [None]:
# colab only setup uncomment to install dependencies
%%bash
cuda_version=$(python -c "import torch; print(torch.version.cuda)")
pip install -q \
    k2==1.24.4.dev20251118+cuda${cuda_version}.torch2.9.1 -f https://k2-fsa.github.io/k2/cuda.html \
    soundfile==0.31.1 \
    librosa==0.11.0 \
    pyarabic==0.6.10

#### Imports

In [5]:
import torch
import k2
import soundfile as sf
import librosa

In [6]:
import k2.version
k2.version.version.main()

Collecting environment information...

k2 version: 1.24.4
Build type: Release
Git SHA1: 30c3039fbe89f245d5dba3c47e99abc3a638275f
Git date: Tue Nov 18 07:41:31 2025
Cuda used to build k2: 12.8
cuDNN used to build k2: 
Python version used to build k2: 3.12
OS used to build k2: AlmaLinux release 8.10 (Cerulean Leopard)
CMake version: 4.1.2
GCC version: 13.3.1
CMAKE_CUDA_FLAGS: -Wno-deprecated-gpu-targets -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -line

## English Exp

In [5]:
import torch, k2, soundfile as sf, numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCTC


model_name = "facebook/wav2vec2-base-960h"
processor  = Wav2Vec2Processor.from_pretrained(model_name)
model      = AutoModelForCTC.from_pretrained(model_name).cuda().eval()

vocab = list(processor.tokenizer.get_vocab().keys())
id2tok = {v: k for k, v in processor.tokenizer.get_vocab().items()}
blank_id = processor.tokenizer.pad_token_id

if blank_id is None:
        blank_id = processor.tokenizer.word_delimiter_token_id

vocab_size = len(vocab)
print("Model loaded:", model_name, "vocab size:", vocab_size)

def get_log_probs(path):
    wav, sr = sf.read(path)
    if sr != 16000:
        import torchaudio
        wav = torchaudio.functional.resample(torch.tensor(wav).float(), sr, 16000).numpy()
        sr = 16000
    inputs = processor(wav, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.cuda()).logits[0]
    log_probs = torch.log_softmax(logits, dim=-1).cpu()
    return log_probs

def get_logits(audio_path):
    """Load audio and get model logits."""
    wav, sr = sf.read(audio_path)
    
    # Resample if needed
    if sr != 16000:
        import librosa
        wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
        sr = 16000
    
    # Handle stereo
    if len(wav.shape) > 1:
        wav = wav[:, 0]
    
    # Get logits
    inputs = processor(wav, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values.cuda()).logits[0].cpu().numpy()
    
    return logits

  from .autonotebook import tqdm as notebook_tqdm
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded: facebook/wav2vec2-base-960h vocab size: 32


In [6]:
import torch, k2, soundfile as sf, numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCTC


model_name = "facebook/wav2vec2-base-960h"
processor  = Wav2Vec2Processor.from_pretrained(model_name)
model      = AutoModelForCTC.from_pretrained(model_name).cuda().eval()

vocab = list(processor.tokenizer.get_vocab().keys())
id2tok = {v: k for k, v in processor.tokenizer.get_vocab().items()}
blank_id = processor.tokenizer.pad_token_id

def build_pattern_fsa(pattern, token2id, wildcard_ids):
    """
    pattern: list of characters, '.' for wildcard
    token2id: dict mapping from char -> token id
    wildcard_ids: allowed token ids for wildcard positions
    """
    arcs = []
    state = 0
    for i, ch in enumerate(pattern):
        if ch == '.':
            for wid in wildcard_ids:
                arcs.append(f"{state} {state+1} {wid} {wid} 0.0")
        else:
            if ch not in token2id:
                continue
            tid = token2id[ch]
            arcs.append(f"{state} {state+1} {tid} {tid} 0.0")
        state += 1
    arcs.append(f"{state} 0.0")
    txt = "\n".join(arcs)
    fsa = k2.Fsa.from_str(txt, acceptor=False, openfst=True)
    return k2.arc_sort(fsa)

def wildcard_decode_k2_old(log_probs, pattern, wildcard_set):
    """
    log_probs: (T, V) log probabilities
    pattern: list of characters (with '.')
    wildcard_set: list of allowed tokens for '.'
    """
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    pattern_fsa = build_pattern_fsa(pattern, processor.tokenizer.get_vocab(), wildcard_set)
    decoding_graph = k2.arc_sort(k2.compose(ctc_topo, pattern_fsa))
    lattice = k2.intersect_dense_pruned(
       decoding_graph, dense,
       search_beam=20.0, output_beam=8.0,
       min_active_states=30, max_active_states=10000)
    # lattice = k2.intersect_dense(
    #     decoding_graph, dense, output_beam=10.0)
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    return "".join(id2tok[i] for i in hyp_ids)

def wildcard_decode_k2(logits, pattern, wildcard_set):
    """
    logits: (T, V) raw logits from model
    pattern: list of characters (with '.')
    wildcard_set: list of allowed tokens for '.'
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Convert logits to log probabilities
    log_probs = torch.log_softmax(logits, dim=-1)
    
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    pattern_fsa = build_pattern_fsa(pattern, processor.tokenizer.get_vocab(), wildcard_set)
    decoding_graph = k2.arc_sort(k2.compose(ctc_topo, pattern_fsa))
    lattice = k2.intersect_dense_pruned(
       decoding_graph, dense,
       search_beam=20.0, output_beam=8.0,
       min_active_states=30, max_active_states=10000)
    
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    
    result = "".join(id2tok[i] for i in hyp_ids)
    
    # Replace word delimiter with space
    word_delim = processor.tokenizer.word_delimiter_token
    if word_delim:
        result = result.replace(word_delim, " ")
    
    return result.strip()

#Regular CTC Decode:
def ctc_decode_k2_old(log_probs, search_beam=20.0, output_beam=8.0):
    """
    CTC decoding using k2 with same beam settings (no pattern constraints).
    This is equivalent to your constrained decoding but without the pattern FSA.
    """
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    
    # No pattern FSA - just CTC topology
    lattice = k2.intersect_dense_pruned(
        ctc_topo, dense,
        search_beam=search_beam, 
        output_beam=output_beam,
        min_active_states=30, 
        max_active_states=10000
    )
    
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    return "".join(id2tok[i] for i in hyp_ids)

def ctc_decode_k2(logits, search_beam=20.0, output_beam=8.0):
    """
    CTC decoding using k2 with beam settings (no pattern constraints).
    Takes raw logits as input.
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Convert logits to log probabilities
    log_probs = torch.log_softmax(logits, dim=-1)
    
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    
    # No pattern FSA - just CTC topology
    lattice = k2.intersect_dense_pruned(
        ctc_topo, dense,
        search_beam=search_beam, 
        output_beam=output_beam,
        min_active_states=30, 
        max_active_states=10000
    )
    
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    
    result = "".join(id2tok[i] for i in hyp_ids)
    
    # Replace word delimiter with space
    word_delim = processor.tokenizer.word_delimiter_token
    if word_delim:
        result = result.replace(word_delim, " ")
    
    return result.strip()

def ctc_decode_greedy(logits):
    """
    Greedy CTC decoding: argmax at each frame, then collapse repeats and remove blanks.
    Takes raw logits as input (not log_probs).
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Get the most probable token at each frame (argmax on logits)
    greedy_ids = logits.argmax(dim=-1)  # Shape: (T,)
    
    # Collapse repeats and remove blanks
    output = []
    prev_id = None
    
    for token_id in greedy_ids.tolist():
        if token_id == blank_id:
            prev_id = None  # Reset on blank
            continue
        if token_id != prev_id:  # Only add if different from previous
            output.append(token_id)
            prev_id = token_id
    
    result = "".join(id2tok[i] for i in output)
    
    # Replace word delimiter with space
    word_delim = processor.tokenizer.word_delimiter_token
    if word_delim:
        result = result.replace(word_delim, " ")
    
    return result.strip()


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:

#test
audio_path = "samples/test_english.wav"
reference  = "THE BIRCH CANOE SLID ON THE SMOOTH PLANK GLUE THE SHEET TO THE DARK BLUE BACKGROUND IT IS EASY TO TELL THE DEPTH OF THE WELL THESE DAYS A CHICKEN LEG IS A RARE DISH RICE IS OFTEN SERVED IN ROUND BOWLS THE JUSE OF LEMON MAKES FINE PUNCH THE BOX WAS TONL  BESIDE THE PARK TRUNK THE HOX ARE SED CHOPPED CORN AND GARBAGE FOUR HOURS A STEADY WORK FACED US A LARGE SIDE IN STOCKINGS IS HARD TO SELL" 

# Mask all vowels
reference = reference.replace(' ', '|')
pattern_vowels = ["." if ch in "AEIOUaeiou" else ch for ch in reference]
# Every other char
pattern_everyother = [ch if i % 2 == 0 else '.' for i, ch in enumerate(reference)]

log_probs = get_log_probs(audio_path)
logits = get_logits(audio_path)
# wildcard set = all lowercase letters
wildcard_ids = [processor.tokenizer.get_vocab()[ch] for ch in processor.tokenizer.get_vocab().keys()
                if ch.isalpha() and len(ch) == 1]


# Add the space token to wildcard_ids
space_token = '|'
if space_token in processor.tokenizer.get_vocab():
    wildcard_ids.append(processor.tokenizer.get_vocab()[space_token])
    print(f"Added space token '{space_token}' to wildcard set")


    
regular_output = ctc_decode_k2(logits, search_beam=20.0, output_beam=8.0)
greedy_output = ctc_decode_greedy(logits)
hyp_vowels = wildcard_decode_k2(logits, pattern_vowels, wildcard_ids)
hyp_everyother = wildcard_decode_k2(logits, pattern_everyother, wildcard_ids)

print("Reference :", reference)
print("CTC prediction ", regular_output)
print("CTC greedy prediction ", greedy_output)
print("Pattern (vowels masked)     :", "".join(pattern_vowels))
print("Prediction (vowels filled)  :", hyp_vowels)
print("Pattern (every other masked):", "".join(pattern_everyother))
print("Prediction (filled)         :", hyp_everyother)


Added space token '|' to wildcard set
Reference : THE|BIRCH|CANOE|SLID|ON|THE|SMOOTH|PLANK|GLUE|THE|SHEET|TO|THE|DARK|BLUE|BACKGROUND|IT|IS|EASY|TO|TELL|THE|DEPTH|OF|THE|WELL|THESE|DAYS|A|CHICKEN|LEG|IS|A|RARE|DISH|RICE|IS|OFTEN|SERVED|IN|ROUND|BOWLS|THE|JUSE|OF|LEMON|MAKES|FINE|PUNCH|THE|BOX|WAS|TONL||BESIDE|THE|PARK|TRUNK|THE|HOX|ARE|SED|CHOPPED|CORN|AND|GARBAGE|FOUR|HOURS|A|STEADY|WORK|FACED|US|A|LARGE|SIDE|IN|STOCKINGS|IS|HARD|TO|SELL
CTC prediction  THE BIRCH CANOE SLIT ON THE SMOOTH PLANK GLE THE HEE TO THE DARK BLUE BACKGROUND IT IS EASY TO TELL THE DEPTH OF THE WELL THESE DAYS A CICK A MEG IS A RARE DISH RICE IS OXEN SERVED IN ROUND BULL THE JUSE OF LONDONS MAKES FINE PUNCH THE BOX WAS TONL  BESIDE THE PARK TRUK THE HOX ARE SED CHOPPED CORN AND GARBAGE FOUR HOURS A STEADY WORK FACED US E LARGE SIDE AN STOCKINGS IS HARD TO SELL
CTC greedy prediction  THE BIRCH CANOE SLIT ON THE SMOOTH PLANK GLE THE HEE TO THE DARK BLUE BACKGROUND IT IS EASY TO TELL THE DEPTH OF THE WELL THESE DA

## Diacritization exps

### Utils

In [3]:
import torch, k2, soundfile as sf, numpy as np
from jiwer import wer
from pyarabic import araby

diacritics = araby.DIACRITICS
arabic_letters = [ch for ch in araby.LETTERS if ch not in araby.DIACRITICS]
precomposed = ['\u0627', '\u064A', '\u0648']
wildcard_token = '.'


def build_pattern_fsa(pattern, wildcard_ids, token2id):
    """
    pattern: list of characters, '.' for wildcard
    token2id: dict mapping from char -> token id
    wildcard_ids: allowed token ids for wildcard positions
    """
    arcs = []
    state = 0
    for i, ch in enumerate(pattern):
        if ch == '.':
            for wid in wildcard_ids:
                arcs.append(f"{state} {state+1} {wid} {wid} 0.0")
        else:
            if ch not in token2id:
                continue
            tid = token2id[ch]
            arcs.append(f"{state} {state+1} {tid} {tid} 0.0")
        state += 1
    arcs.append(f"{state} 0.0")
    txt = "\n".join(arcs)
    fsa = k2.Fsa.from_str(txt, acceptor=False, openfst=True)
    return k2.arc_sort(fsa)

# WFS Decoding
def wildcard_decode_k2(logits, pattern, wildcard_set, token2id, word_delimiter_token="|"):
    """
    logits: (T, V) raw logits from model
    pattern: list of characters (with '.')
    wildcard_set: list of allowed tokens for '.'
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Convert logits to log probabilities
    log_probs = torch.log_softmax(logits, dim=-1)
    
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    pattern_fsa = build_pattern_fsa(pattern, wildcard_set, token2id)
    decoding_graph = k2.arc_sort(k2.compose(ctc_topo, pattern_fsa))
    lattice = k2.intersect_dense_pruned(
       decoding_graph, dense,
       search_beam=20.0, output_beam=8.0,
       min_active_states=30, max_active_states=10000)
    
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    
    result = "".join(id2tok[i] for i in hyp_ids)
    
    # Replace word delimiter with space
    if word_delimiter_token:
        result = result.replace(word_delimiter_token, " ")
    
    
    return result.strip()

# CTC Decoding
def ctc_decode_k2(logits, search_beam=20.0, output_beam=8.0, word_delimiter_token="|"):
    """
    CTC decoding using k2 with beam settings (no pattern constraints).
    Takes raw logits as input.
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Convert logits to log probabilities
    log_probs = torch.log_softmax(logits, dim=-1)
    
    T, V = log_probs.shape
    dense = k2.DenseFsaVec(log_probs.unsqueeze(0), torch.tensor([[0, 0, T]], dtype=torch.int32))
    ctc_topo = k2.arc_sort(k2.ctc_topo(V-1))
    
    # No pattern FSA - just CTC topology
    lattice = k2.intersect_dense_pruned(
        ctc_topo, dense,
        search_beam=search_beam, 
        output_beam=output_beam,
        min_active_states=30, 
        max_active_states=10000
    )
    
    best_path = k2.shortest_path(lattice, use_double_scores=False)
    aux = k2.get_aux_labels(best_path)[0]
    hyp_ids = [x for x in aux if x >= 0]
    
    result = "".join(id2tok[i] for i in hyp_ids)
    
    # Replace word delimiter with space
    if word_delimiter_token:
        result = result.replace(word_delimiter_token, " ")
    
    return result.strip()

def ctc_decode_greedy(logits, word_delimiter_token="|"):
    """
    Greedy CTC decoding: argmax at each frame, then collapse repeats and remove blanks.
    Takes raw logits as input (not log_probs).
    """
    # Convert to torch tensor if needed
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)
    
    # Get the most probable token at each frame (argmax on logits)
    greedy_ids = logits.argmax(dim=-1)  # Shape: (T,)
    
    # Collapse repeats and remove blanks
    output = []
    prev_id = None
    
    for token_id in greedy_ids.tolist():
        if token_id == blank_id:
            prev_id = None  # Reset on blank
            continue
        if token_id != prev_id:  # Only add if different from previous
            output.append(token_id)
            prev_id = token_id
    
    result = "".join(id2tok[i] for i in output)
    
    # Replace word delimiter with space
    if word_delimiter_token:
        result = result.replace(word_delimiter_token, " ")
    
    return result.strip()

def print_util(data):
    col_width = max(len(row[0]) for row in data) + 2
    for key, value in data:
        print(f"{key:<{col_width}}: {value}")
    print()  # extra newline for separation

def clean_text(text, remove_diacritics=False):
    if remove_diacritics:
        text = araby.strip_diacritics(text)
    # replace | and - with space
    text = text.replace('|', ' ').replace('-', ' ')
    # collapse extra spaces
    text = text.replace('  ', ' ')
    return text

def calculate_wer(hyp, ref):
    if isinstance(hyp, list):
        hyp = [clean_text(h) for h in hyp]
        ref = [clean_text(r) for r in ref]
    else:
        hyp = clean_text(hyp)
        ref = clean_text(ref)
    return wer(ref, hyp)

def construct_pattern(text):
    text = text.replace(' ', space_token)
    pattern = []
    for c in text:
        pattern.append(c)
        if c in arabic_letters:
            pattern.append(wildcard_token)
    return pattern

In [4]:
print(precomposed)

['ا', 'ي', 'و']


In [4]:
text = araby.strip_diacritics("صَائِرٌ خَبَرًا فَكُنْ خَبَرًا يَرُوقُ جَمِيلَا")
print(list(text))
pattern = construct_pattern(text)
for i,ch in enumerate(pattern):
    print(f"{ch} {i}")


['ص', 'ا', 'ئ', 'ر', ' ', 'خ', 'ب', 'ر', 'ا', ' ', 'ف', 'ك', 'ن', ' ', 'خ', 'ب', 'ر', 'ا', ' ', 'ي', 'ر', 'و', 'ق', ' ', 'ج', 'م', 'ي', 'ل', 'ا']


NameError: name 'space_token' is not defined

### Wav2Vec

In [7]:
# Load model directly
from transformers import AutoProcessor, AutoModelForCTC
wav2vec_processor = AutoProcessor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
wav2vec_model = AutoModelForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")

vocab = list(wav2vec_processor.tokenizer.get_vocab().keys())
id2tok = {v: k for k, v in wav2vec_processor.tokenizer.get_vocab().items()}
token2id = {v: k for k, v in enumerate(vocab)}
blank_id = wav2vec_processor.tokenizer.pad_token_id
space_token = wav2vec_processor.tokenizer.word_delimiter_token

diac_in_vocab = [diac for diac in diacritics if diac in vocab]

unconstrained_wildcard_ids = [i for i,ch in enumerate(vocab)]
constrained_wildcard_ids = [i for i,ch in enumerate(vocab) if ch in diac_in_vocab]
# constrained_wildcard_ids.append(wav2vec_processor.tokenizer.word_delimiter_token_id)
constrained_wildcard_ids.append(blank_id)

print(f"diacritics in vocab: {diac_in_vocab}")
print(f"vocab size: {len(vocab)}")
print(f"blank_id: {blank_id}")
print(f"token2id: {token2id}")
print(f"constrained_wildcard_ids: {constrained_wildcard_ids}")
print(f"unconstrained_wildcard_ids: {unconstrained_wildcard_ids}")

  from .autonotebook import tqdm as notebook_tqdm


diacritics in vocab: ['ً', 'ٌ', 'ٍ', 'َ', 'ُ', 'ِ', 'ّ', 'ْ']
vocab size: 51
blank_id: 0
token2id: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '|': 4, '-': 5, 'ء': 6, 'آ': 7, 'أ': 8, 'ؤ': 9, 'إ': 10, 'ئ': 11, 'ا': 12, 'ب': 13, 'ة': 14, 'ت': 15, 'ث': 16, 'ج': 17, 'ح': 18, 'خ': 19, 'د': 20, 'ذ': 21, 'ر': 22, 'ز': 23, 'س': 24, 'ش': 25, 'ص': 26, 'ض': 27, 'ط': 28, 'ظ': 29, 'ع': 30, 'غ': 31, 'ـ': 32, 'ف': 33, 'ق': 34, 'ك': 35, 'ل': 36, 'م': 37, 'ن': 38, 'ه': 39, 'و': 40, 'ى': 41, 'ي': 42, 'ً': 43, 'ٌ': 44, 'ٍ': 45, 'َ': 46, 'ُ': 47, 'ِ': 48, 'ّ': 49, 'ْ': 50}
constrained_wildcard_ids: [43, 44, 45, 46, 47, 48, 49, 50, 0]
unconstrained_wildcard_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]


#### Test single sample

In [14]:

arabic_audio_path = "samples/female_ab_00000.wav"
arabic_reference = open('samples/female_ab_00000.txt', 'r', encoding='utf-8').read()
arabic_reference_no_diac = araby.strip_diacritics(arabic_reference)


In [15]:

def get_logits(audio_path):
    wav, sr = sf.read(audio_path)

    if sr != 16000:
        import librosa
        wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)

    wav = wav[None, :]

    processed_input = wav2vec_processor(
        audio=wav,
        sampling_rate=16000,
        return_tensors="pt"
        )

    output = wav2vec_model(processed_input.input_values)
    logits = output.logits
    predicted_ids = torch.argmax(logits, dim=-1)

    transcription = wav2vec_processor.decode(predicted_ids[0])

    return logits, transcription

logits, transcription = get_logits(arabic_audio_path)
wer_sample = calculate_wer(arabic_reference, transcription)

print(f"hyp: {transcription}")
print(f"ref: {arabic_reference}")
print(f"wer: {wer_sample}")


hyp: فَمِنَ الرَّامِ السُّوَيْديِّ - أُسْكُرْ سَواهِنْ - الَّذِي شَرَكَ فِي أُولَمْبِيَادِ أَلْفُتِسُعْمِئَةٍ وَعِشْرونَ - حِينَ كَانَ يَبْلُغُ اثْنَيْنِ وَسَبرينَ عَاماً - وَمِئتَيْنِ وَوَاحِدُ وثَمانينَ يَوْمًا - إِلَى الفارسَتَيْنِ أَنْْيُوزْلَندِيَّةِ - جيولِي بروغْهَامْ - والأَسْتْرَاليَّةِ مارِي هَنَا - الَّتَيْنِ تُشَارِكَانِ فِي أُولَمْبِيَادِرِيُو - وَهُمَا فِي الْوَاحِدِ وَالسِّتّينَ مِنَ الْعُمُرِ
ref: فَمِنَ الرَّامِي السُّوِيدِيِّ أُوسْكَار سَوَاهِنْ الَّذِي شَارَكَ فِي أُولُمْبِيَادِ  أَلْفْ وَتِسْعُمِئَةٍ وَعِشْرُونْ حِينَ كَانَ يَبْلُغُ إَثْنَينْ وَ سَبْعينَ  عَامًا وَمِئَتَينْ وَوَاحِدْ وَثَمَانِينَ يَوْمًا إِلَى الْفَارِسَتَيْنِ النِّيُوزِيلَنْدِيَّةِ جُولِي برُوغْهَام وَالْأُسْتُرَالِيَّةِ مَارِي هَانَا اللَّتَيْنِ تُشَارِكَانِ فِي أُولُمْبِيَّادِ رِيُو وَهُمَا فِي الْوَاحِدِ وَالسِّتِينَ مِنَ الْعُمُر
wer: 0.717948717948718


In [37]:

#################### Greedy CTC ####################
greedy_output = ctc_decode_greedy(logits.squeeze(0))
data = [
    ("greedy_output", clean_text(greedy_output)),
    ("transcription",  clean_text(transcription)),
    ("reference",      clean_text(arabic_reference)),
    ("wer", calculate_wer(greedy_output, arabic_reference))
]
print_util(data)
############################################################

#################### CTC K2 ########################
regular_output = ctc_decode_k2(logits.squeeze(0), search_beam=20.0, output_beam=8.0)
data = [
    ("regular_output", clean_text(regular_output)),
    ("transcription",  clean_text(transcription)),
    ("reference",      clean_text(arabic_reference)),
    ("wer", calculate_wer(regular_output, arabic_reference))
]
print_util(data)
############################################################

#################### WFST constrained #############################
pattern_diacritics = construct_pattern(arabic_reference_no_diac)[:5]
hyp_diacritics_constrained, lattice, dense, ctc_topo, pattern_fsa,best_path = wildcard_decode_k2(logits.squeeze(0), pattern_diacritics, constrained_wildcard_ids, token2id)
data = [
    ("pattern", "".join(pattern_diacritics)),
    ("hyp_diacritics_constrained", clean_text(hyp_diacritics_constrained)),
    # ("transcription",  clean_text(transcription)),
    ("reference",      clean_text(arabic_reference)),
    ("wer", calculate_wer(hyp_diacritics_constrained, arabic_reference))
]
print_util(data)
############################################################

#################### WFST unconstrained #############################
# hyp_diacritics_unconstrained, lattice, dense, ctc_topo = wildcard_decode_k2(logits.squeeze(0), pattern_diacritics, unconstrained_wildcard_ids, token2id)
# data = [
#     ("pattern", "".join(pattern_diacritics)),
#     ("hyp_diacritics_unconstrained", clean_text(hyp_diacritics_unconstrained)),
#     # ("transcription",  clean_text(transcription)),
#     ("reference",      clean_text(arabic_reference)),
#     ("wer", calculate_wer(hyp_diacritics_unconstrained, arabic_reference))
# ]
# print_util(data)
############################################################


greedy_output  : فَمِنَ الرَّامِ السُّوَيْديِّ  أُسْكُرْ سَواهِنْ  الَّذِي شَرَكَ فِي أُولَمْبِيَادِ أَلْفُتِسُعْمِئَةٍ وَعِشْرونَ  حِينَ كَانَ يَبْلُغُ اثْنَيْنِ وَسَبرينَ عَاماً  وَمِئتَيْنِ وَوَاحِدُ وثَمانينَ يَوْمًا  إِلَى الفارسَتَيْنِ أَنْْيُوزْلَندِيَّةِ  جيولِي بروغْهَامْ  والأَسْتْرَاليَّةِ مارِي هَنَا  الَّتَيْنِ تُشَارِكَانِ فِي أُولَمْبِيَادِرِيُو  وَهُمَا فِي الْوَاحِدِ وَالسِّتّينَ مِنَ الْعُمُرِ
transcription  : فَمِنَ الرَّامِ السُّوَيْديِّ  أُسْكُرْ سَواهِنْ  الَّذِي شَرَكَ فِي أُولَمْبِيَادِ أَلْفُتِسُعْمِئَةٍ وَعِشْرونَ  حِينَ كَانَ يَبْلُغُ اثْنَيْنِ وَسَبرينَ عَاماً  وَمِئتَيْنِ وَوَاحِدُ وثَمانينَ يَوْمًا  إِلَى الفارسَتَيْنِ أَنْْيُوزْلَندِيَّةِ  جيولِي بروغْهَامْ  والأَسْتْرَاليَّةِ مارِي هَنَا  الَّتَيْنِ تُشَارِكَانِ فِي أُولَمْبِيَادِرِيُو  وَهُمَا فِي الْوَاحِدِ وَالسِّتّينَ مِنَ الْعُمُرِ
reference      : فَمِنَ الرَّامِي السُّوِيدِيِّ أُوسْكَار سَوَاهِنْ الَّذِي شَارَكَ فِي أُولُمْبِيَادِ أَلْفْ وَتِسْعُمِئَةٍ وَعِشْرُونْ حِينَ كَانَ يَبْلُغُ إَثْنَينْ وَ

In [46]:
print(dense)

num_axes: 2
device_type: kCpu
device_id: -1
row_splits1: [ 0 1230 ]
row_ids1: [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

#### Evaluate on ArVoice, ClArTTS, NADI-2025

In [7]:
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader

def collate_fn(batch):
    # For batch_size=1, return the single item directly (no extra dimension)
    if len(batch) == 1:
        wav2vec_logits = torch.from_numpy(np.array(batch[0]['wav2vec_logits'][0])).float()
        transcription = batch[0]['transcription'][0]
        text = batch[0]['text']
    else:
        wav2vec_logits = torch.from_numpy(np.array([item['wav2vec_logits'][0] for item in batch])).float()
        transcription = [item['transcription'][0] for item in batch]
        text = [item['text'][0] for item in batch]
    return wav2vec_logits, transcription, text

##### ClArTTS

In [8]:
# load MBZUAI/ClArTTS
clartts_dataset = load_from_disk("/home/rufael/Projects/diac-btc/data/clartts/wav2vec/test")
clartts_dataset

Dataset({
    features: ['text', 'file', 'audio', 'sampling_rate', 'duration', 'transcription', 'wav2vec_logits'],
    num_rows: 205
})

In [12]:
import soundfile as sf

audio = clartts_dataset[8]['audio']
sr = clartts_dataset[8]['sampling_rate']
sf.write('clartts_8.wav', np.array(audio), sr)
clartts_dataset[8]['text']

'عَلَى عُمَرَ بْنِ الْخَطَّابِ  رَضِيَ اللَّهُ عَنْهُ  فَقَالَ مَا هَذَا قَالُوا صَدَاقُ أُمِّ كُلْثُومٍ'

In [66]:
constrained_wildcard_ids

[43, 44, 45, 46, 47, 48, 49, 50, 0]

In [11]:
from collections import defaultdict
import json
from tqdm import tqdm

test_dataloader = DataLoader(clartts_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

results = {}
for i,batch in tqdm(enumerate(test_dataloader)):
    logits, transcription, text = batch
    
    ref_no_diac = araby.strip_diacritics(text)
    pattern = construct_pattern(ref_no_diac)

    hyp_diacritics_unconstrained = wildcard_decode_k2(logits.squeeze(0), pattern, unconstrained_wildcard_ids, token2id, word_delimiter_token=space_token)
    hyp_diacritics_constrained = wildcard_decode_k2(logits.squeeze(0), pattern, constrained_wildcard_ids, token2id, word_delimiter_token=space_token)
    if hyp_diacritics_unconstrained=="":
        print("here")
    results[f'{i}']={
        'hyp_constrained': hyp_diacritics_constrained,
        'hyp_unconstrained': hyp_diacritics_unconstrained,
        'text': text,
        'wav2vec_regular_output': transcription,
        'pattern': "".join(pattern),
        'unconstrained_wer': calculate_wer(hyp_diacritics_unconstrained, text),
        'constrained_wer': calculate_wer(hyp_diacritics_constrained, text),
        'baseline_wer': calculate_wer(transcription, text)
    }

# collect and dump the text, hyp_constrained, hyp_unconstrained to separate txt files
with open('.results/wav2vec/clartts/text.txt', 'w', encoding='utf-8') as f:
    for i in results:
        f.write(results[i]['text'] + '\n')

with open('.results/wav2vec/clartts/hyp_constrained.txt', 'w', encoding='utf-8') as f:
    for i in results:
        f.write(results[i]['hyp_constrained'] + '\n')

with open('.results/wav2vec/clartts/hyp_unconstrained.txt', 'w', encoding='utf-8') as f:
    for i in results:
        f.write(results[i]['hyp_unconstrained'] + '\n')

with open('.results/wav2vec/clartts/transcription.txt', 'w', encoding='utf-8') as f:
    for i in results:
        f.write(results[i]['wav2vec_regular_output'] + '\n')

# collect wer results
# results["wer_constrained"] = sum([results[i]['constrained_wer'] for i in results])/len(results)
# results["wer_unconstrained"] = sum([results[i]['unconstrained_wer'] for i in results])/len(results)
# results["wer_baseline"] = sum([results[i]['baseline_wer'] for i in results])/len(results)

# dump results to json
with open('.results/wav2vec/clartts/results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)

print(f"WER constrained: {results['wer_constrained']}")
print(f"WER unconstrained: {results['wer_unconstrained']}")
print(f"WER baseline: {results['wer_baseline']}")


13it [05:49, 33.64s/it]

here


34it [05:50,  2.71s/it]

here


94it [05:52, 24.89it/s]

here


102it [05:52, 29.25it/s]

here
here


127it [05:53, 29.32it/s]

here


139it [05:54, 28.51it/s]

here


149it [05:54, 27.92it/s]

here


160it [05:54, 29.03it/s]

here


187it [05:55, 26.52it/s]

here


205it [05:56,  1.74s/it]


KeyError: 'wer_constrained'

##### ArVoice

In [28]:
arvoice_dataset = load_from_disk("/home/rufael/Projects/diac-btc/data/arvoice/test")
arvoice_dataset

Dataset({
    features: ['file_name', 'transcription', 'speaker_id', 'source', 'wav2vec_logits'],
    num_rows: 248
})

In [None]:
test_dataloader = DataLoader(arvoice_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

for batch in test_dataloader:
    logits, transcription = batch


torch.Size([255, 51])
تحتل ضواحي مدينة أدرارا صدارة الأحداث في الرواية


##### NADI

### Nemo

In [9]:
arabic_audio_path = "samples/female_ab_00000.wav"
arabic_reference = open('samples/female_ab_00000.txt', 'r', encoding='utf-8').read()
arabic_reference_no_diac = araby.strip_diacritics(arabic_reference)

In [10]:
import nemo.collections.asr as nemo_asr
from nemo.collections.common.data.utils import move_data_to_device
from omegaconf import open_dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name="nvidia/stt_ar_fastconformer_hybrid_large_pcd_v1.0")
# Configure for greedy CTC
with open_dict(asr_model.cfg.decoding):
    asr_model.cfg.decoding.strategy = "greedy"
    asr_model.cfg.decoding.compute_timestamps = False # Optional
asr_model.change_decoding_strategy(decoder_type="ctc")

# test model
output = asr_model.transcribe([arabic_audio_path])

wer_val = calculate_wer(output[0].text, arabic_reference)
print(output[0].text)
print(arabic_reference)
print(f"WER: {wer_val}")


[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 16, 'shuffle': True, 'num_workers': 8, 'pin_memory': True, 'max_duration': 20, 'min_duration': 0.5, 'is_tarred': True, 'tarred_audio_filepaths': '???', 'shuffle_n': 2048, 'bucketing_strategy': 'fully_randomized', 'bucketing_batch_size': None}
     Reason: Missing mandatory value: train_ds.manifest_filepath
        full_key: train_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 16, 'shuffle': False, 'use_start_end_token': False, 'num_workers': 8, 'pin_memory': True}
     Reason: Missing mandatory value: validation_ds.manifest_filepath
        full_key: validation_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for conf

[NeMo I 2026-01-26 17:14:16 nemo_logging:393] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 16, 'shuffle': True, 'num_workers': 8, 'pin_memory': True, 'max_duration': 20, 'min_duration': 0.5, 'is_tarred': True, 'tarred_audio_filepaths': '???', 'shuffle_n': 2048, 'bucketing_strategy': 'fully_randomized', 'bucketing_batch_size': None}
     Reason: Missing mandatory value: train_ds.manifest_filepath
        full_key: train_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 16, 'shuffle': False, 'use_start_end_token': False, 'num_workers': 8, 'pin_memory': True}
     Reason: Missing mandatory value: validation_ds.manifest_filepath
        full_key: validation_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 17:14:16 nemo_logging:405] Skipped conversion for conf

[NeMo I 2026-01-26 17:14:16 nemo_logging:393] PADDING: 0
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] Model EncDecHybridRNNTCTCBPEModel was successfully restored from /home/rufael/.cache/huggingface/hub/models--nvidia--stt_ar_fastconformer_hybrid_large_pcd_v1.0/snapshots/7f32349d952f42a28dce979ba73270aa2bbdfa89/stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo.
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] No `decoding_cfg` passed when changing decoding strategy, using internal config
[NeMo I 2026-01-26 17:14:17 nemo_logging:393] Changed deco

[NeMo W 2026-01-26 17:14:17 nemo_logging:405] The following configuration keys are ignored by Lhotse dataloader: use_start_end_token
[NeMo W 2026-01-26 17:14:17 nemo_logging:405] You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). This will cause the tokenization to happen in the main (GPU) process,possibly impacting the training speed if your tokenizer is very large.If the impact is noticable, set pretokenize=False in dataloader config.(note: that will disable token-per-second filtering and 2D bucketing features)
Transcribing: 0it [00:00, ?it/s][NeMo W 2026-01-26 17:14:17 nemo_logging:405] CTC decoding strategy 'greedy' is slower than 'greedy_batch', which implements the same exact interface. Consider changing your strategy to 'greedy_batch' for a free performance improvement.
Transcribing: 1it [00:00,  4.21it/s]

فمن الرامي السويدي أوسكار سواهن الذي شارك في أولمبياد ألف وتسع مئة وعشرون حين كان يبلغ اثنان وسبعون عامًا و مئتان وواحد وثمانون يومًا. إلى الفارستين النيوزيلندية جولي بروغهام والأستترالية ماري هانا اللتين تشاركان في أولمبياد ريو وهما في واحد وستون من العمر
فَمِنَ الرَّامِي السُّوِيدِيِّ أُوسْكَار سَوَاهِنْ الَّذِي شَارَكَ فِي أُولُمْبِيَادِ  أَلْفْ وَتِسْعُمِئَةٍ وَعِشْرُونْ حِينَ كَانَ يَبْلُغُ إَثْنَينْ وَ سَبْعينَ  عَامًا وَمِئَتَينْ وَوَاحِدْ وَثَمَانِينَ يَوْمًا إِلَى الْفَارِسَتَيْنِ النِّيُوزِيلَنْدِيَّةِ جُولِي برُوغْهَام وَالْأُسْتُرَالِيَّةِ مَارِي هَانَا اللَّتَيْنِ تُشَارِكَانِ فِي أُولُمْبِيَّادِ رِيُو وَهُمَا فِي الْوَاحِدِ وَالسِّتِينَ مِنَ الْعُمُر
WER: 1.0238095238095237





In [15]:
asr_model.tokenizer.__dir__()

['chat_template',
 'tokenizer',
 'original_vocab_size',
 'vocab_size',
 'legacy',
 'ignore_extra_whitespaces',
 'extra_space_token',
 'special_token_to_id',
 'id_to_special_token',
 'trim_spm_separator_after_special_token',
 'spm_separator',
 'spm_separator_id',
 'removed_extra_spaces',
 'space_sensitive',
 'supports_capitalization',
 'supported_punctuation',
 '__module__',
 '__doc__',
 '__init__',
 'text_to_tokens',
 'text_to_ids',
 '_text_to_ids',
 '_text_to_ids_extra_space',
 'tokens_to_text',
 'ids_to_text',
 'token_to_id',
 'ids_to_tokens',
 'tokens_to_ids',
 'add_special_tokens',
 'pad_id',
 'bos_id',
 'eos_id',
 'sep_id',
 'cls_id',
 'mask_id',
 'unk_id',
 'additional_special_tokens_ids',
 'vocab',
 '__abstractmethods__',
 '_abc_impl',
 'apply_chat_template',
 'name',
 'unique_identifiers',
 'cls',
 'sep',
 'pad',
 'eod',
 'bos',
 'eos',
 'mask',
 '__dict__',
 '__weakref__',
 '__slots__',
 '__new__',
 '__repr__',
 '__hash__',
 '__str__',
 '__getattribute__',
 '__setattr__',
 '__

In [None]:

# Get logits for nemo
def get_logits(audio_path):
    wav, sr = sf.read(audio_path)
    
    if sr != 16000:
        import librosa
        wav = librosa.resample(wav, orig_sr=sr, target_sr=16000)
        sr = 16000

        # get length of wav
        length = len(wav)
    else:
        length = len(wav)

    # convert to tensor
    wav = torch.from_numpy(wav).unsqueeze(0)
    length = torch.tensor([length])

    wav = move_data_to_device(wav, device)
    length = move_data_to_device(length, device)
    # get logits from asr_model
    encoded, encoded_length = asr_model.forward(input_signal=wav, input_signal_length=length)
    logits = asr_model.ctc_decoder(encoder_output=encoded)

    hypotheses = asr_model.ctc_decoding.ctc_decoder_predictions_tensor(
            logits,
            encoded_length,
            return_hypotheses=True,
        )

    return hypotheses, logits.squeeze(0)

asr_model.change_decoding_strategy(asr_model.cfg.decoding, decoder_type="ctc")

hypotheses, logits = get_logits(arabic_audio_path)

In [1]:
import nemo.collections.asr as nemo_asr
asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(model_name="nvidia/stt_ar_fastconformer_hybrid_large_pc_v1.0")

  from .autonotebook import tqdm as notebook_tqdm
[NeMo W 2026-01-26 06:15:54 nemo_logging:405] Megatron num_microbatches_calculator not found, using Apex version.
OneLogger: Setting error_handling_strategy to DISABLE_QUIETLY_AND_REPORT_METRIC_ERROR for rank (rank=0) with OneLogger disabled. To override: explicitly set error_handling_strategy parameter.
No exporters were provided. This means that no telemetry data will be collected.
[NeMo W 2026-01-26 06:16:07 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 32, 'shuffle': True, 'num_workers': 8, 'pin_memory': True, 'max_duration': 20, 'min_duration': 0.5, 'is_tarred': True, 'tarred_audio_filepaths': '???', 'shuffle_n': 2048, 'bucketing_strategy': 'fully_randomized', 'bucketing_batch_size': None}
     Reason: Missing mandatory value: train_ds.manifest_filepath
        full_key: train_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 06:16:07 

[NeMo I 2026-01-26 06:16:07 nemo_logging:393] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2026-01-26 06:16:07 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 32, 'shuffle': True, 'num_workers': 8, 'pin_memory': True, 'max_duration': 20, 'min_duration': 0.5, 'is_tarred': True, 'tarred_audio_filepaths': '???', 'shuffle_n': 2048, 'bucketing_strategy': 'fully_randomized', 'bucketing_batch_size': None}
     Reason: Missing mandatory value: train_ds.manifest_filepath
        full_key: train_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 06:16:07 nemo_logging:405] Skipped conversion for config/subconfig:
    {'manifest_filepath': '???', 'sample_rate': 16000, 'batch_size': 16, 'shuffle': False, 'use_start_end_token': False, 'num_workers': 8, 'pin_memory': True}
     Reason: Missing mandatory value: validation_ds.manifest_filepath
        full_key: validation_ds.manifest_filepath
        object_type=dict.
[NeMo W 2026-01-26 06:16:07 nemo_logging:405] Skipped conversion for conf

[NeMo I 2026-01-26 06:16:08 nemo_logging:393] PADDING: 0
[NeMo I 2026-01-26 06:16:08 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 06:16:08 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 06:16:08 nemo_logging:393] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2026-01-26 06:16:09 nemo_logging:393] Model EncDecHybridRNNTCTCBPEModel was successfully restored from /home/rufael/.cache/huggingface/hub/models--nvidia--stt_ar_fastconformer_hybrid_large_pc_v1.0/snapshots/3e748b4d91935672f19fc4d6cba38fff9ef013c5/stt_ar_fastconformer_hybrid_large_pc_v1.0.nemo.


In [2]:
asr_model

EncDecHybridRNNTCTCBPEModel(
  (preprocessor): AudioToMelSpectrogramPreprocessor(
    (featurizer): FilterbankFeatures()
  )
  (encoder): ConformerEncoder(
    (pre_encode): ConvSubsampling(
      (out): Linear(in_features=2560, out_features=512, bias=True)
      (conv): MaskedConvSequential(
        (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (4): ReLU(inplace=True)
        (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        (6): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (7): ReLU(inplace=True)
      )
    )
    (pos_enc): RelPositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-16): 17 x ConformerLayer(
        (norm_feed_forward1): LayerNorm((512,), e