In [1]:
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, HubertModel

def extract_hubert_embeddings(audio_path, layer=-1):
    """
    Extract HuBERT embeddings from an audio file.
    
    Parameters:
    audio_path (str): Path to the audio file
    layer (int): Which transformer layer to extract embeddings from (-1 for last layer)
    
    Returns:
    torch.Tensor: HuBERT embeddings
    """
    # Load the audio file
    waveform, sample_rate = torchaudio.load(audio_path)
    
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # Load HuBERT model and feature extractor
    model = HubertModel.from_pretrained("facebook/hubert-large-ll60k")
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-large-ll60k")
    
    # Resample if necessary (HuBERT expects 16kHz)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
    
    # Prepare inputs
    inputs = feature_extractor(
        waveform.squeeze().numpy(),
        sampling_rate=16000,
        return_tensors="pt",
        padding=True
    )
    
    # Extract features
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        
    # Get hidden states from specified layer
    # -1 means last layer, -2 second to last, etc.
    hidden_states = outputs.hidden_states[layer]
    
    return hidden_states

def get_mean_embeddings(hidden_states):
    """
    Calculate mean embeddings across time dimension.
    
    Parameters:
    hidden_states (torch.Tensor): HuBERT hidden states
    
    Returns:
    torch.Tensor: Mean embeddings
    """
    # Average across time dimension (dim=1)
    mean_embeddings = torch.mean(hidden_states, dim=1)
    return mean_embeddings


  from .autonotebook import tqdm as notebook_tqdm


## SPLIT AUDIO FILE

In [2]:
import os
from pydub import AudioSegment
from pathlib import Path
import logging
from typing import List, Optional
from dataclasses import dataclass
from tqdm import tqdm
import math

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class AudioChunkInfo:
    """Class to hold information about an audio chunk"""
    chunk_number: int
    start_time: float  # in seconds
    end_time: float    # in seconds
    duration: float    # in seconds
    file_path: str

class MP3Splitter:
    """Class to handle MP3 file splitting"""
    
    def __init__(self, 
                 chunk_duration: int = 60, 
                 output_dir: Optional[str] = None,
                 min_chunk_duration: int = 30):
        """
        Initialize MP3Splitter
        
        Args:
            chunk_duration (int): Duration of each chunk in seconds (default: 60)
            output_dir (str): Directory to save chunks (default: None, uses input file directory)
            min_chunk_duration (int): Minimum duration for the last chunk in seconds (default: 30)
        """
        self.chunk_duration = chunk_duration * 1000  # Convert to milliseconds
        self.min_chunk_duration = min_chunk_duration * 1000  # Convert to milliseconds
        self.output_dir = output_dir
    
    def _create_output_dir(self, input_file: str) -> str:
        """Create output directory if it doesn't exist"""
        if self.output_dir is None:
            # Use input file directory and create a subdirectory with the file name
            base_path = Path(input_file).parent
            file_name = Path(input_file).stem
            output_dir = base_path / f"{file_name}_chunks"
        else:
            output_dir = Path(self.output_dir)
        
        output_dir.mkdir(parents=True, exist_ok=True)
        return str(output_dir)
    
    def _get_output_path(self, output_dir: str, chunk_number: int) -> str:
        """Generate output file path for a chunk"""
        return str(Path(output_dir) / f"chunk_{chunk_number:03d}.mp3")
    
    def split_audio(self, input_file: str) -> List[AudioChunkInfo]:
        """
        Split MP3 file into chunks
        
        Args:
            input_file (str): Path to input MP3 file
            
        Returns:
            List[AudioChunkInfo]: Information about each chunk
        """
        # Validate input file
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"Input file not found: {input_file}")
        
        # Create output directory
        output_dir = self._create_output_dir(input_file)
        logger.info(f"Chunks will be saved in: {output_dir}")
        
        try:
            # Load audio file
            logger.info("Loading audio file...")
            audio = AudioSegment.from_mp3(input_file)
            
            # Calculate number of chunks
            total_duration = len(audio)
            num_chunks = math.ceil(total_duration / self.chunk_duration)
            
            chunk_infos = []
            
            # Split audio into chunks with progress bar
            logger.info("Splitting audio into chunks...")
            for i in tqdm(range(num_chunks), desc="Processing chunks"):
                # Calculate start and end times
                start_time = i * self.chunk_duration
                end_time = min((i + 1) * self.chunk_duration, total_duration)
                
                # Skip if last chunk is too short
                if i == num_chunks - 1 and (end_time - start_time) < self.min_chunk_duration:
                    # Extend the previous chunk instead
                    if chunk_infos:
                        prev_chunk = chunk_infos[-1]
                        os.remove(prev_chunk.file_path)  # Remove the previous chunk
                        
                        # Create new extended chunk
                        extended_chunk = audio[prev_chunk.start_time * 1000:end_time]
                        extended_chunk.export(prev_chunk.file_path, format="mp3")
                        
                        # Update previous chunk info
                        chunk_infos[-1] = AudioChunkInfo(
                            chunk_number=prev_chunk.chunk_number,
                            start_time=prev_chunk.start_time,
                            end_time=end_time / 1000,
                            duration=(end_time - prev_chunk.start_time * 1000) / 1000,
                            file_path=prev_chunk.file_path
                        )
                    break
                
                # Extract chunk
                chunk = audio[start_time:end_time]
                
                # Generate output path
                output_path = self._get_output_path(output_dir, i)
                
                # Export chunk
                chunk.export(output_path, format="mp3")
                
                # Store chunk information
                chunk_info = AudioChunkInfo(
                    chunk_number=i,
                    start_time=start_time / 1000,  # Convert to seconds
                    end_time=end_time / 1000,      # Convert to seconds
                    duration=(end_time - start_time) / 1000,  # Convert to seconds
                    file_path=output_path
                )
                chunk_infos.append(chunk_info)
            
            logger.info(f"Successfully created {len(chunk_infos)} chunks")
            return chunk_infos
            
        except Exception as e:
            logger.error(f"Error splitting audio: {str(e)}")
            raise

class AudioMetadata:
    """Class to handle audio metadata operations"""
    
    @staticmethod
    def save_chunk_metadata(chunk_infos: List[AudioChunkInfo], output_dir: str):
        """Save metadata for all chunks"""
        metadata_path = os.path.join(output_dir, "chunks_metadata.txt")
        
        with open(metadata_path, 'w') as f:
            f.write("Chunk Information:\n")
            f.write("-" * 50 + "\n")
            
            for chunk in chunk_infos:
                f.write(f"Chunk {chunk.chunk_number:03d}:\n")
                f.write(f"  Start Time: {chunk.start_time:.2f} seconds\n")
                f.write(f"  End Time: {chunk.end_time:.2f} seconds\n")
                f.write(f"  Duration: {chunk.duration:.2f} seconds\n")
                f.write(f"  File: {os.path.basename(chunk.file_path)}\n")
                f.write("-" * 50 + "\n")

In [3]:
input_file = "/home/snp2453/slt/speech_retrieval/data/raw_audio/long_news.mp3"
output_dir = "/home/snp2453/slt/speech_retrieval/data/processed_audio"
chunk_duration = 15
min_chunk_duration = 5

try:
    # Initialize splitter
    splitter = MP3Splitter(
        chunk_duration=chunk_duration,
        output_dir=output_dir,
        min_chunk_duration=min_chunk_duration
    )
    
    # Split audio
    chunk_infos = splitter.split_audio(input_file)
    
    # Save metadata
    if chunk_infos:
        output_dir = os.path.dirname(chunk_infos[0].file_path)
        AudioMetadata.save_chunk_metadata(chunk_infos, output_dir)
        
    logger.info("Processing completed successfully!")
    
except Exception as e:
    logger.error(f"Error: {str(e)}")
    raise

INFO:__main__:Chunks will be saved in: /home/snp2453/slt/speech_retrieval/data/processed_audio
INFO:__main__:Loading audio file...
INFO:__main__:Splitting audio into chunks...
Processing chunks: 100%|██████████| 225/225 [00:53<00:00,  4.22it/s]
INFO:__main__:Successfully created 225 chunks
INFO:__main__:Processing completed successfully!


### Chunking embeddings

In [2]:
import os
path = "/home/snp2453/slt/speech_retrieval/data/processed_audio"
files = os.listdir(path)
files.sort()

In [3]:
%%time
for i in files:
    full_path = os.path.join(path, i)
    name = i.split(".")[0]
    type = i.split(".")[1]
    if type != "mp3":
        continue
    hidden_states = extract_hubert_embeddings(full_path)
    mean_embeddings = get_mean_embeddings(hidden_states)
    torch.save(mean_embeddings, f"/home/snp2453/slt/speech_retrieval/data/embeddings/chunks_embedding/{name}.pt")

CPU times: user 43min 13s, sys: 5min 33s, total: 48min 47s
Wall time: 4min 44s


2 min 30 seconds vs 4 mins 44 seconds

In [9]:
hidden_states = extract_hubert_embeddings("/home/snp2453/slt/speech_retrieval/data/raw_audio/Russia_Google.mp3")
mean_embeddings = get_mean_embeddings(hidden_states)
torch.save(mean_embeddings, f"/home/snp2453/slt/speech_retrieval/data/embeddings/Russia_Google.pt")

### Similarity

In [5]:
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple, Union, Optional
from dataclasses import dataclass
import logging
from tqdm import tqdm
import json

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SimilarityResult:
    """Class to hold similarity computation results"""
    chunk_id: str
    similarity_score: float
    chunk_start_time: float
    chunk_end_time: float
    embedding_path: str

class EmbeddingSimilarityCalculator:
    """Handles similarity computations between embeddings"""
    
    def __init__(self, embeddings_dir: str):
        """
        Initialize calculator
        
        Args:
            embeddings_dir (str): Directory containing stored embeddings
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    def load_embedding(self, embedding_path: Union[str, Path]) -> torch.Tensor:
        """
        Load embedding from file and process it
        
        Args:
            embedding_path (Union[str, Path]): Path to embedding file
            
        Returns:
            torch.Tensor: Processed embedding tensor
        """
        try:
            embedding = torch.load(embedding_path, map_location=self.device)
            
            # If embedding is 3D (batch_size, sequence_length, hidden_size),
            # take mean over sequence dimension
            if len(embedding.shape) == 3:
                embedding = torch.mean(embedding, dim=1)
            
            # Ensure embedding is 2D (batch_size, hidden_size)
            if len(embedding.shape) == 1:
                embedding = embedding.unsqueeze(0)
                
            # Normalize embedding
            embedding = F.normalize(embedding, p=2, dim=-1)
            
            return embedding
            
        except Exception as e:
            logger.error(f"Error loading embedding from {embedding_path}: {str(e)}")
            raise
    
    def compute_cosine_similarity(self, 
                                query_embedding: torch.Tensor,
                                chunk_embedding: torch.Tensor) -> float:
        """
        Compute cosine similarity between query and chunk embeddings
        
        Args:
            query_embedding (torch.Tensor): Query embedding
            chunk_embedding (torch.Tensor): Chunk embedding
            
        Returns:
            float: Cosine similarity score
        """
        with torch.no_grad():
            similarity = F.cosine_similarity(
                query_embedding,
                chunk_embedding,
                dim=-1
            )
        return similarity.item()
    
    def process_single_chunk(self,
                           query_embedding: torch.Tensor,
                           chunk_embedding_path: Union[str, Path],
                           chunk_metadata: Optional[Dict] = None) -> SimilarityResult:
        """
        Process a single chunk and compute similarity
        
        Args:
            query_embedding (torch.Tensor): Query embedding
            chunk_embedding_path (Union[str, Path]): Path to chunk embedding
            chunk_metadata (Optional[Dict]): Chunk metadata if available
            
        Returns:
            SimilarityResult: Similarity computation result
        """
        # Load and process chunk embedding
        chunk_embedding = self.load_embedding(chunk_embedding_path)
        
        # Compute similarity
        similarity_score = self.compute_cosine_similarity(
            query_embedding,
            chunk_embedding
        )
        
        # Get chunk information
        chunk_id = Path(chunk_embedding_path).stem
        start_time = chunk_metadata.get('start_time', 0.0) if chunk_metadata else 0.0
        end_time = chunk_metadata.get('end_time', 0.0) if chunk_metadata else 0.0
        
        return SimilarityResult(
            chunk_id=chunk_id,
            similarity_score=similarity_score,
            chunk_start_time=start_time,
            chunk_end_time=end_time,
            embedding_path=str(chunk_embedding_path)
        )
    
    def find_most_similar_chunks(self,
                               query_path: str,
                               top_k: int = 1,
                               metadata_path: Optional[str] = None) -> List[SimilarityResult]:
        """
        Find the most similar chunks to a query
        
        Args:
            query_path (str): Path to query embedding
            top_k (int): Number of top results to return
            metadata_path (Optional[str]): Path to chunks metadata file
            
        Returns:
            List[SimilarityResult]: Top-k most similar chunks
        """
        # Load query embedding
        query_embedding = self.load_embedding(query_path)
        
        # Load metadata if available
        chunk_metadata = {}
        if metadata_path and Path(metadata_path).exists():
            with open(metadata_path, 'r') as f:
                chunk_metadata = json.load(f)
        
        # Process all chunk embeddings
        results = []
        chunk_paths = sorted(self.embeddings_dir.glob("chunk_*.pt"))
        
        for chunk_path in tqdm(chunk_paths, desc="Processing chunks"):
            try:
                metadata = chunk_metadata.get(chunk_path.stem, {})
                result = self.process_single_chunk(
                    query_embedding,
                    chunk_path,
                    metadata
                )
                results.append(result)
            except Exception as e:
                logger.warning(f"Error processing chunk {chunk_path}: {str(e)}")
                continue
        
        # Sort by similarity score and get top-k
        results.sort(key=lambda x: x.similarity_score, reverse=True)
        return results[:top_k]

class SimilarityVisualizer:
    """Handles visualization of similarity results"""
    
    @staticmethod
    def save_results(results: List[SimilarityResult], 
                    output_path: str):
        """Save similarity results to file"""
        with open(output_path, 'w') as f:
            f.write("Similarity Results:\n")
            f.write("-" * 50 + "\n")
            
            for i, result in enumerate(results, 1):
                f.write(f"Rank {i}:\n")
                f.write(f"  Chunk ID: {result.chunk_id}\n")
                f.write(f"  Similarity Score: {result.similarity_score:.4f}\n")
                f.write(f"  Time Range: {result.chunk_start_time:.2f}s - "
                       f"{result.chunk_end_time:.2f}s\n")
                f.write(f"  Embedding: {result.embedding_path}\n")
                f.write("-" * 50 + "\n")
    
    @staticmethod
    def plot_similarities(results: List[SimilarityResult],
                        output_path: str):
        """Create visualization of similarity scores"""
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            
            # Prepare data
            chunk_ids = [r.chunk_id for r in results]
            scores = [r.similarity_score for r in results]
            
            # Create plot
            plt.figure(figsize=(12, 6))
            sns.barplot(x=range(len(chunk_ids)), y=scores)
            
            # Customize plot
            plt.title("Chunk Similarity Scores")
            plt.xlabel("Chunk ID")
            plt.ylabel("Cosine Similarity")
            plt.xticks(range(len(chunk_ids)), chunk_ids, rotation=45)
            
            # Save plot
            plt.tight_layout()
            plt.savefig(output_path)
            plt.close()
            
        except ImportError:
            logger.warning("matplotlib and seaborn required for visualization")

In [10]:
%%time
output_dir = "/home/snp2453/slt/speech_retrieval/results"
embeddings_dir = "/home/snp2453/slt/speech_retrieval/data/embeddings/chunks_embedding"
query_path = "/home/snp2453/slt/speech_retrieval/data/embeddings/Russia_Google.pt"
top_k = 20
metadata_path = "/home/snp2453/slt/speech_retrieval/chunk_metadata.json"
try:
    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Initialize calculator
    calculator = EmbeddingSimilarityCalculator(embeddings_dir)
    # Find most similar chunks
    results = calculator.find_most_similar_chunks(
    query_path,
    top_k=top_k,
    metadata_path=metadata_path
    )
    # Save results
    SimilarityVisualizer.save_results(
    results,
    output_dir / "similarity_results.txt"
    )
    # Create visualization
    SimilarityVisualizer.plot_similarities(
    results,
    output_dir / "similarity_plot.png"
    )
    logger.info("Processing completed successfully!")
except Exception as e:
    logger.error(f"Error: {str(e)}")
    raise

  embedding = torch.load(embedding_path, map_location=self.device)
Processing chunks:   0%|          | 0/225 [00:00<?, ?it/s]

Processing chunks: 100%|██████████| 225/225 [00:00<00:00, 1766.41it/s]
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
INFO:__main__:Processing completed successfully!


CPU times: user 440 ms, sys: 133 ms, total: 573 ms
Wall time: 342 ms
