In [1]:
import torch
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from scipy.spatial.distance import cosine
import librosa
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class PhonemeExtractor:
    """Load a Wav2Vec2 CTC model and its processor once, then reuse them.

    Constructing the object loads the processor and model; the other methods
    only read audio and run inference, so one instance can serve many files.

    Attributes set by ``__init__``:
        device: ``torch.device`` the model was moved to.
        processor: result of ``Wav2Vec2Processor.from_pretrained(model_name)``.
        model: result of ``Wav2Vec2ForCTC.from_pretrained(model_name)``,
            moved to ``device`` and switched to eval mode.
    """

    # Default sample rate (Hz) used both when loading audio and when declaring
    # the rate to the processor, so the two cannot drift apart by accident.
    DEFAULT_SAMPLE_RATE = 16000

    def __init__(self, model_name: str = "bookbot/wav2vec2-ljspeech-gruut", device: Optional[str] = None):
        """
        Initialize model and processor once.

        Args:
            model_name: HuggingFace model name.
            device: device string such as 'cuda' or 'cpu'. When None, CUDA is
                used if ``torch.cuda.is_available()`` and CPU otherwise.
        """
        print(f"Loading model {model_name}...")

        # Auto-detect the device unless the caller forced one.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load processor and model a single time; every later call reuses them.
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded on {self.device}")

    def load_audio(self, audio_path: str, sr: int = DEFAULT_SAMPLE_RATE) -> np.ndarray:
        """Load an audio file via ``librosa.load``, requesting sample rate ``sr``.

        Args:
            audio_path: path of the audio file to read.
            sr: target sample rate in Hz handed to ``librosa.load``.

        Returns:
            The waveform array returned by ``librosa.load`` (the sample rate
            it also returns is discarded).
        """
        audio, _ = librosa.load(audio_path, sr=sr)
        return audio

    def get_phoneme_embeddings(
        self, audio: np.ndarray, sampling_rate: int = DEFAULT_SAMPLE_RATE
    ) -> Tuple[torch.Tensor, str]:
        """Extract last-layer hidden states and the decoded phoneme string.

        Args:
            audio: waveform samples for a single clip.
            sampling_rate: sample rate (Hz) that ``audio`` was loaded at. It is
                forwarded to the processor instead of being hard-coded, so it
                must match the ``sr`` used in ``load_audio``.
                NOTE(review): the processor may reject rates that differ from
                the one the model expects — confirm before using non-default
                values.

        Returns:
            ``(embeddings, phonemes)`` where ``embeddings`` is the final entry
            of the model's ``hidden_states`` with dim 0 squeezed and moved to
            CPU (presumably ``(frames, hidden_dim)`` for one clip — TODO
            confirm), and ``phonemes`` is the first string produced by
            ``processor.batch_decode`` on the arg-max token ids.

        Raises:
            ValueError: if ``audio`` is None or contains no samples.
        """
        # Fail early with a clear message rather than an opaque error from
        # inside the model when there is nothing to process.
        if audio is None or np.size(audio) == 0:
            raise ValueError("audio is empty; cannot extract phoneme embeddings")

        # Prepare model inputs.
        inputs = self.processor(audio, sampling_rate=sampling_rate, return_tensors="pt", padding=True)

        # Move every input tensor to the same device as the model.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            # One forward pass yields both hidden states and CTC logits.
            outputs = self.model(**inputs, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]  # last layer

            # Greedy decoding: most likely token id at each position.
            predicted_ids = torch.argmax(outputs.logits, dim=-1)
            phonemes = self.processor.batch_decode(predicted_ids)[0]

        return hidden_states.squeeze(0).cpu(), phonemes

    def process_file(self, audio_path: str, sr: int = DEFAULT_SAMPLE_RATE) -> Tuple[torch.Tensor, str]:
        """Load an audio file and extract its phonemes in a single call.

        Args:
            audio_path: path of the audio file to read.
            sr: sample rate (Hz) used for loading and declared to the processor.

        Returns:
            Same tuple as ``get_phoneme_embeddings``.
        """
        audio = self.load_audio(audio_path, sr=sr)
        return self.get_phoneme_embeddings(audio, sampling_rate=sr)

In [9]:
# Cell 1: load the model exactly once (run this cell a single time).
# With no `device` argument PhonemeExtractor picks CUDA when it is available
# and falls back to CPU otherwise, so this cell also runs on machines without
# a GPU. Pass device='cuda' or device='cpu' to force a specific device.
extractor = PhonemeExtractor()

Loading model bookbot/wav2vec2-ljspeech-gruut...
Model loaded on cuda


In [None]:
# Cell 2: reuse the already-loaded extractor as often as needed; the model is
# not reloaded per file.
ref_embeddings, ref_phonemes = extractor.process_file(
    "./audio_files/word_january.mp3"
)
print(ref_phonemes)

d͡ʒæ n j u ɛ ɹ i
