In [118]:
import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Optional, Tuple

import librosa
import numpy as np
import torch
from scipy.spatial.distance import cosine
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

warnings.filterwarnings('ignore')


In [119]:
# Report GPU availability. The device-name/capability queries raise when no CUDA
# device is present, so only run them when one is; this keeps the notebook
# runnable top-to-bottom on CPU-only machines.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name())
    print(torch.cuda.get_device_capability())


True
NVIDIA GeForce RTX 5080
(12, 0)


In [120]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

Device: cuda


In [121]:
def load_audio(audio_path: str, sr: int = 16000) -> np.ndarray:
    """Read ``audio_path`` with librosa, resampled to ``sr`` Hz (16 kHz by default).

    Only the sample array is returned; the sample rate that ``librosa.load``
    reports alongside it is discarded.
    """
    waveform = librosa.load(audio_path, sr=sr)[0]
    return waveform

@lru_cache(maxsize=None)
def _load_phoneme_model(model_name: str, device: str):
    """Load the processor/model pair for ``model_name`` on ``device`` once and reuse it.

    ``from_pretrained`` reloads the checkpoint on every call, so caching here keeps
    repeated ``get_phoneme_embeddings`` calls from re-reading the model weights.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
    model.eval()  # inference mode: disables dropout
    return processor, model


def get_phoneme_embeddings(
    audio: np.ndarray,
    *,
    model_name: str = "bookbot/wav2vec2-ljspeech-gruut",
    sampling_rate: int = 16000,
    device: Optional[str] = None,
) -> Tuple[torch.Tensor, str]:
    """Extract frame-level embeddings and the predicted phoneme string from audio.

    Args:
        audio: Waveform samples, e.g. as returned by ``load_audio``.
        model_name: Hugging Face checkpoint with a CTC phoneme head.
        sampling_rate: Sample rate of ``audio``; forwarded to the processor, whose
            feature extractor rejects rates the model was not trained on.
        device: Torch device to run on. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(embeddings, phonemes)``. ``embeddings`` is the last hidden layer with
        the batch dimension squeezed out, shape ``(frames, hidden_size)``. It is
        always returned on the CPU, so callers can use ``.numpy()`` directly.
        ``phonemes`` is the greedy (argmax) CTC transcription.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    processor, model = _load_phoneme_model(model_name, device)

    inputs = processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
    # The model lives on `device`, so every input tensor has to be moved there too.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # output of the last layer
        predicted_ids = torch.argmax(outputs.logits, dim=-1)

    # Decode on the CPU; batch_decode yields one string per batch item.
    phonemes = processor.batch_decode(predicted_ids.cpu())[0]
    return hidden_states.squeeze(0).cpu(), phonemes

In [122]:
# def get_phoneme_embeddings(audio: np.ndarray) -> Tuple[torch.Tensor, str]:
#     """Extract phoneme-level embeddings and predicted phonemes from audio."""
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
#     processor = Wav2Vec2Processor.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
#     model = Wav2Vec2ForCTC.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
#     model.to(device)  # Move the model to the GPU
#     model.eval()
    
#     # Prepare inputs and move them to the GPU
#     inputs = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
#     inputs = {k: v.to(device) for k, v in inputs.items()}  # Move every input tensor to the GPU
    
#     with torch.inference_mode():  # Cheaper than no_grad for pure inference
#         # Get hidden states (embeddings)
#         outputs = model(**inputs, output_hidden_states=True)
#         hidden_states = outputs.hidden_states[-1]  # Last layer, still on the GPU
        
#         # Get phoneme predictions
#         logits = outputs.logits
#         predicted_ids = torch.argmax(logits, dim=-1)
#         # predicted_ids_cpu = predicted_ids.cpu()  # Move back to the CPU for decoding
#         # phonemes = processor.batch_decode(predicted_ids_cpu)[0]

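#     return hidden_states.squeeze(0)  # NOTE(review): returns only the embeddings, not the (embeddings, phonemes) Tuple the signature promises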

In [123]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# processor = Wav2Vec2Processor.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
# model = Wav2Vec2ForCTC.from_pretrained("bookbot/wav2vec2-ljspeech-gruut").to(device).eval()

# def get_phoneme_embeddings_fast(audio: np.ndarray):
#     inputs = processor(audio, sampling_rate=16000, return_tensors="pt", padding=False)
#     input_values = inputs.input_values.to(device, non_blocking=True)
#     attention_mask = inputs.get("attention_mask")
#     if attention_mask is not None:
#         attention_mask = attention_mask.to(device, non_blocking=True)

#     with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=torch.cuda.is_available()), \
#          torch.inference_mode():
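#         out = model(input_values, attention_mask=attention_mask)  # hidden_states not requested
#         last = out.last_hidden_state.squeeze(0)                   # [T, H]  NOTE(review): Wav2Vec2ForCTC output has no `last_hidden_state`; needs output_hidden_states=True and `.hidden_states[-1]`
#         ids = torch.argmax(out.logits, dim=-1).cpu().numpy()      # decode on the CPU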
#     phonemes = processor.batch_decode(ids)[0]
#     return last, phonemes


In [127]:
# Reference pronunciation of the target word.
ref_audio_path = "./audio_files/word_february.mp3"

# Waveform -> (frame embeddings, greedy phoneme transcription).
ref_audio = load_audio(audio_path=ref_audio_path)
ref_embeddings, ref_phonemes = get_phoneme_embeddings(audio=ref_audio)

# Show the recognized phoneme sequence as the cell output.
ref_phonemes

'f ɛ b j u ɛ ɹ i'