In [None]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor
import torchaudio.functional as F
import soundfile
import librosa
import time

from IPython.display import Audio

In [None]:
# Hub id of the Spanish wav2vec2 XLSR-53 checkpoint used as the baseline
# (the "-with-lm" suffix suggests it ships a decoder language model).
model_id = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"

# AutoProcessor picks whichever processor class was saved with the checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# Load the CTC model and move it to the GPU in a single expression.
model = AutoModelForCTC.from_pretrained(model_id).cuda()

In [None]:
# Podcast episode used as benchmark input (absolute path on the author's machine).
audio_path = "/data/podcasts/El hilo/La-ultima-clinica-de-aborto-en-la-frontera-sur-de-Texas.mp3"

# Length of the benchmark clip, and the sample rate the audio is decoded at.
# 16 kHz is the rate the clip is fed to the model at, so decoding straight to
# it avoids a lossy intermediate resample to librosa's 22.05 kHz default.
clip_seconds = 60
load_sr = 16_000

# Decode only the part of the file that is actually used.
wav, sr = librosa.load(audio_path, sr=load_sr, duration=clip_seconds)

# Kept as a guard so the clip length is explicit even if `duration` changes.
wav_clip = wav[: sr * clip_seconds]
Audio(wav_clip, rate=sr)

In [None]:
# Bring the clip to the 16 kHz rate the model is fed at.
target_sr = 16_000
resampled_wav = F.resample(torch.tensor(wav_clip), sr, target_sr).numpy()

# Pass the sampling rate explicitly so the feature extractor can check it
# against its own configuration instead of silently assuming it matches.
input_values = processor(
    resampled_wav, sampling_rate=target_sr, return_tensors="pt"
).input_values
# Follow whatever device the model lives on rather than hard-coding CUDA.
input_values = input_values.to(model.device)

# perf_counter is monotonic and high-resolution, which makes it the right
# clock for measuring an elapsed interval.
start_time = time.perf_counter()
with torch.no_grad():
    logits = model(input_values).logits
# The timed region covers both the forward pass and the decoding step;
# copying the logits to host memory happens before decoding starts.
transcription = processor.batch_decode(logits.cpu().numpy()).text
print(time.perf_counter() - start_time)
print(transcription)

### Notes

- On an RTX 3090 it takes less than 1 sec (~830 msec) to transcribe 1 min of audio. Pretty good!
- On a Titan X: 1.92 sec.

Testing a model fine-tuned on 9 hours of the MLS dataset, without an LM. Fine-tuning took around 3 hours on a single RTX 3090.
Greedy decoding without an LM takes less time, as expected: ~1.38 sec for 1 min of audio on a Titan X.

In [None]:
from pathlib import Path

# Checkpoint directory of the local fine-tuning run (absolute path on the
# author's machine; adjust when running elsewhere).
ckpt_path = Path('/home/taras/git-repos/one-lang/audio-representation-learning/wav2vec2-large-xlsr-53-spanish-mls/checkpoint-1200/')

# Load the fine-tuned CTC model and place it on the GPU in one step.
mls_spanish_model = AutoModelForCTC.from_pretrained(ckpt_path).cuda()

In [None]:
from transformers import AutoTokenizer, AutoFeatureExtractor, Wav2Vec2Processor

# The tokenizer is read from the run directory (the checkpoint folder's
# parent), while the feature extractor comes from the checkpoint folder itself.
# NOTE(review): presumably the tokenizer files are only saved at the run root - confirm.
tokenizer = AutoTokenizer.from_pretrained(ckpt_path.parent)
feature_extractor = AutoFeatureExtractor.from_pretrained(ckpt_path)

# Combine both pieces into a single processor object for the fine-tuned model.
mls_spanish_processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

In [None]:
# Reuse the already-preprocessed clip, making sure it sits on the same device
# as the fine-tuned model instead of relying on where an earlier cell left it.
mls_input_values = input_values.to(mls_spanish_model.device)

# perf_counter is monotonic and high-resolution, so it is the appropriate
# clock for timing an interval.
start_time = time.perf_counter()
with torch.no_grad():
    logits = mls_spanish_model(mls_input_values).logits
# Greedy decoding: take the highest-scoring token id per frame and let the
# processor turn the id sequences into text.
predicted_ids = logits.argmax(-1)
transcription = mls_spanish_processor.batch_decode(predicted_ids.cpu().numpy())
print(time.perf_counter() - start_time)
print(transcription)