# Audio Processor

In [1]:
import torch
import torchaudio

from abc import ABC, abstractmethod
from typing import Dict
from transformers import BertModel, BertTokenizer


class AudioProcessor:
    """Stateless helpers that turn a raw waveform into a log-Mel spectrogram.

    All configuration lives in class-level constants, and the torchaudio
    transforms are built once when the class body executes, so every call to
    :meth:`preprocess` reuses the same transform objects.
    """

    SAMPLE_RATE = 16_000  # target sample rate in Hz
    N_MELS = 64
    WINDOW_SIZE = 0.02  # 20 ms
    HOP_LENGTH = 0.01  # 10 ms
    N_FFT = 512
    EPSILON = 1e-6  # guards divisions and the log against zero
    MIN_FREQUENCY = 85
    MAX_FREQUENCY = 3000

    # wiener filter parameters
    WIENER_N_FFT = 512
    WIENER_HOP_LENGTH = 128
    WIENER_WIN_LENGTH = 512
    NOISE_FRAME_COUNT = 5  # first 5 frames for noise estimation

    mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=N_FFT,
        win_length=int(WINDOW_SIZE * SAMPLE_RATE),
        hop_length=int(HOP_LENGTH * SAMPLE_RATE),
        n_mels=N_MELS,
        center=True,
        power=2.0,
        f_min=MIN_FREQUENCY,
        f_max=MAX_FREQUENCY,
    )
    log_transform = torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80)

    @staticmethod
    def wiener_filter(waveform: torch.Tensor) -> torch.Tensor:
        """Reduce stationary noise with a magnitude-domain gain.

        The noise magnitude is estimated per frequency bin as the mean of the
        first ``NOISE_FRAME_COUNT`` STFT frames. Each bin is then scaled by
        ``max(|X| - noise, 0) / (|X| + EPSILON)`` and the signal is
        reconstructed with the inverse STFT.

        Parameters
        ----------
        waveform : torch.Tensor
            Input audio tensor of shape (1, T).

        Returns
        -------
        torch.Tensor
            Denoised waveform of shape (1, T).
        """
        n_samples = waveform.size(-1)
        window = torch.hann_window(AudioProcessor.WIENER_WIN_LENGTH).to(waveform.device)

        stft = torch.stft(
            waveform.squeeze(0),
            n_fft=AudioProcessor.WIENER_N_FFT,
            hop_length=AudioProcessor.WIENER_HOP_LENGTH,
            win_length=AudioProcessor.WIENER_WIN_LENGTH,
            window=window,
            return_complex=True,
        )

        # estimate noise from first few frames
        magnitude = torch.abs(stft)
        noise_estimate = magnitude[:, : AudioProcessor.NOISE_FRAME_COUNT].mean(dim=1, keepdim=True)

        # gain in [0, 1): bins at or below the noise estimate are zeroed
        gain = (magnitude - noise_estimate).clamp(min=0) / (magnitude + AudioProcessor.EPSILON)

        # reconstruct waveform; `length` restores the original sample count,
        # otherwise the inverse STFT drops the samples past the last full hop
        # and the output would be shorter than the input
        enhanced_waveform = torch.istft(
            stft * gain,
            n_fft=AudioProcessor.WIENER_N_FFT,
            hop_length=AudioProcessor.WIENER_HOP_LENGTH,
            win_length=AudioProcessor.WIENER_WIN_LENGTH,
            window=window,
            length=n_samples,
        )

        return enhanced_waveform.unsqueeze(0)

    @staticmethod
    def preprocess(waveform: torch.Tensor, original_sample_rate: int, apply_wiener: bool = True) -> torch.Tensor:
        """Preprocess the input audio waveform for feature extraction.

        This method performs the following steps:
        1. Promotes a 1D waveform to shape (1, samples).
        2. Converts a multi-channel waveform to mono by averaging across channels.
        3. Peak-normalizes the waveform to the range [-1, 1] (skipped for an
           all-zero waveform, which would otherwise turn into NaNs).
        4. Resamples the audio waveform to ``SAMPLE_RATE`` (16 kHz) if needed.
        5. Applies :meth:`wiener_filter` to the waveform, if specified.
        6. Computes a log-Mel spectrogram of the waveform.

        Parameters
        ----------
        waveform : torch.Tensor
            A 1D tensor of shape (samples,) or a 2D tensor of shape (channels, samples).

        original_sample_rate : int
            The original sample rate of the waveform.

        apply_wiener : bool, optional, default=True
            Whether to apply the Wiener filter to the waveform. If set to False, the Wiener filter is skipped.

        Returns
        -------
        torch.Tensor
            The log-Mel spectrogram (in dB) of the mono waveform.
            Shape: (1, N_MELS, time_steps).
        """

        # a 1D tensor has no channel axis; without this, the mono conversion
        # below would average over samples instead of channels
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # convert to mono
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # normalize waveform to [-1, 1]; a silent clip has peak 0 and is left untouched
        peak = waveform.abs().max()
        if peak > 0:
            waveform = waveform / peak

        # resample to 16kHz if needed
        if original_sample_rate != AudioProcessor.SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sample_rate, new_freq=AudioProcessor.SAMPLE_RATE
            )
            waveform = resampler(waveform)

        if apply_wiener:
            waveform = AudioProcessor.wiener_filter(waveform)

        # compute log-mel spectrogram
        mel_spec = AudioProcessor.mel_spectrogram_transform(waveform)
        log_mel_spec = AudioProcessor.log_transform(mel_spec + AudioProcessor.EPSILON)  # add epsilon to avoid log(0)

        return log_mel_spec


class TextProcessor(ABC):
    """Interface for components that turn raw text into an embedding tensor.

    Subclasses must implement :meth:`process`; the class itself cannot be
    instantiated.
    """

    # Chosen once, when the class body runs: CUDA if available, otherwise CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    @abstractmethod
    def process(self, text: str) -> torch.Tensor:
        """Compute the embedding of ``text``.

        Parameters
        ----------
        text : str
            The text to embed.

        Returns
        -------
        torch.Tensor
            The embedding produced for ``text``.
        """


class BertProcessor(TextProcessor):
    """BERT-based implementation of the TextProcessor.

    This class uses a pre-trained BERT model to tokenize input text and extract contextual embeddings.
    Both the model and the tokenized inputs are moved to the class-level ``device``.

    Parameters
    ----------
    tokenizer : str, optional
        The name of the BERT tokenizer to use (default is "bert-base-uncased").
    model : str, optional
        The name of the BERT model to use (default is "bert-base-uncased").
    """

    def __init__(self, tokenizer: str = "bert-base-uncased", model: str = "bert-base-uncased"):
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.model = BertModel.from_pretrained(model).to(BertProcessor.device)

    def tokenize_text(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize input text using the BERT tokenizer.

        Over-long inputs are truncated to the tokenizer's maximum model
        length. Without truncation, a text with more tokens than the model
        supports would fail in the forward pass.

        Parameters
        ----------
        text : str
            The input string to tokenize.

        Returns
        -------
        Dict[str, torch.Tensor]
            A dict-like batch encoding (batch size 1) holding the tokenizer's
            output tensors, moved to ``BertProcessor.device``.
        """
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True)
        return encoded.to(BertProcessor.device)

    def embed_tokens(self, tokens: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Generate token embeddings from tokenized input using BERT.

        Parameters
        ----------
        tokens : Dict[str, torch.Tensor]
            A dictionary containing tokenized inputs as returned by the tokenizer.

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, sequence_length, hidden_size)
            representing contextual embeddings for each token.
        """
        # inference only: no autograd graph is recorded for the forward pass
        with torch.no_grad():
            outputs = self.model(**tokens)
            return outputs.last_hidden_state

    def process(self, text: str) -> torch.Tensor:
        """Process the input text by tokenizing and generating BERT embeddings.

        Parameters
        ----------
        text : str
            The input text to process.

        Returns
        -------
        torch.Tensor
            A tensor of contextual embeddings for the input text, as returned
            by :meth:`embed_tokens`.
        """
        tokens = self.tokenize_text(text)
        return self.embed_tokens(tokens)

2025-05-07 13:23:04.309105: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746624184.537167      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746624184.600102      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# DataLoader

In [2]:
import os
import random
from typing import List, Optional, Tuple

import torch.nn.functional as F
import torchaudio
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class MovieSubDataset(Dataset):
    """In-memory dataset of (log-Mel features, subtitle text, text embeddings).

    Every sample is loaded and fully preprocessed in the constructor, so
    indexing is a plain list lookup.

    Parameters
    ----------
    samples : List[Tuple[str, str]]
        Pairs of (wav_path, txt_path).
    text_processor : TextProcessor, optional
        Processor used to embed the subtitle text. When omitted, a single
        ``BertProcessor`` is created on first use and shared by all datasets
        that rely on the default.
    show_progress : bool, optional
        Whether to wrap the loading loop in a tqdm progress bar (default True).
    """

    # Created lazily instead of as a default argument: a default of
    # ``BertProcessor()`` would load the model as soon as this class is
    # defined, even if every caller passes its own processor.
    _default_text_processor = None

    def __init__(
        self,
        samples: List[Tuple[str, str]],
        text_processor: Optional[TextProcessor] = None,
        show_progress: bool = True,
    ):
        if text_processor is None:
            if MovieSubDataset._default_text_processor is None:
                MovieSubDataset._default_text_processor = BertProcessor()
            text_processor = MovieSubDataset._default_text_processor

        self.data = []
        self.text_processor = text_processor

        iterator = tqdm(samples, desc="Loading samples") if show_progress else samples

        for wav_path, txt_path in iterator:
            waveform, sample_rate = torchaudio.load(wav_path)
            with open(txt_path, "r", encoding="utf-8") as f:
                subtitle_text = f.read()

            features = AudioProcessor.preprocess(waveform, sample_rate)
            text_embeddings = self.text_processor.process(subtitle_text)

            self.data.append((features, subtitle_text, text_embeddings))

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (features, subtitle_text, text_embeddings) tuple at ``idx``."""
        return self.data[idx]


def collate_fn(batch):
    """Merge dataset items into one batch.

    Each item is a ``(features, subtitle, embeddings)`` tuple. Embeddings are
    zero-padded at the end of dimension 1 up to the longest one in the batch
    and then stacked; features are stacked as-is, so they must already share
    a shape.

    Returns
    -------
    tuple
        ``(stacked_features, subtitles_as_list, stacked_padded_embeddings)``.
    """
    feature_items, subtitle_items, embedding_items = zip(*batch)

    # longest sequence (dimension 1) among the embeddings of this batch
    target_len = max(item.shape[1] for item in embedding_items)

    padded_items = []
    for item in embedding_items:
        missing = target_len - item.shape[1]
        # pad spec runs from the last dim backwards: (0, 0) leaves the last
        # dim alone, (0, missing) appends zeros to the second-to-last dim
        padded_items.append(F.pad(item, (0, 0, 0, missing)))
    stacked_embeddings = torch.stack(padded_items)

    stacked_features = torch.stack(feature_items)
    return stacked_features, list(subtitle_items), stacked_embeddings


def _collect_sample_pairs(root_dir: str) -> List[Tuple[str, str]]:
    """Return the (wav_path, txt_path) pairs found one level below ``root_dir``, in sorted order."""
    pairs = []

    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # list, and therefore the seeded split, identical from run to run
    for movie_name in sorted(os.listdir(root_dir)):
        movie_path = os.path.join(root_dir, movie_name)
        if not os.path.isdir(movie_path):
            continue

        for fname in sorted(os.listdir(movie_path)):
            if not fname.endswith(".wav"):
                continue

            base = os.path.splitext(fname)[0]
            wav_path = os.path.join(movie_path, f"{base}.wav")
            txt_path = os.path.join(movie_path, f"{base}.txt")

            # keep only clips that have a matching transcript
            if os.path.exists(wav_path) and os.path.exists(txt_path):
                pairs.append((wav_path, txt_path))

    return pairs


def load_movie_subs(root_dir: str, batch_size: int = 64) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Load movie subs dataset, split into train/val/test sets, and return corresponding DataLoaders.

    Samples are ``<name>.wav`` / ``<name>.txt`` pairs located in the immediate
    subdirectories of ``root_dir``. They are split 70 % / 15 % / 15 % with
    fixed seeds.

    Parameters
    ----------
    root_dir : str
        Path to the dataset directory containing subdirectories for each movie.
    batch_size : int, optional
        Batch size for the DataLoaders. Default is 64.

    Returns
    -------
    Tuple[DataLoader, DataLoader, DataLoader]
        DataLoaders for training, validation, and test datasets. Only the
        training loader shuffles.

    Raises
    ------
    ValueError
        If no wav/txt pair is found under ``root_dir``.
    """
    all_samples = _collect_sample_pairs(root_dir)
    if not all_samples:
        raise ValueError(f"No .wav/.txt sample pairs found under {root_dir!r}")

    # shuffle and split
    random.seed(42)
    random.shuffle(all_samples)

    # 70 % train, then the remaining 30 % is halved into validation and test
    train_data, temp_data = train_test_split(all_samples, test_size=0.3, random_state=42)
    val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

    # create DataLoaders
    train_loader = DataLoader(MovieSubDataset(train_data), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(MovieSubDataset(val_data), batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(MovieSubDataset(test_data), batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader


def main():
    """Build the three DataLoaders and print a summary of the first training sample."""
    dataset_root = "/kaggle/input/movie2sub-dataset/dataset"
    train_loader, _validation_loader, _test_loader = load_movie_subs(dataset_root)

    # take only the first batch and inspect its first element
    first_batch = next(iter(train_loader))
    feature = first_batch[0][0]
    subtitle = first_batch[1][0]
    embedding = first_batch[2][0]

    print(f"Feature shape: {feature.numpy().shape}\n")
    print("Subtitle text:\n", subtitle)
    # embeddings may live on the GPU, so copy to CPU before converting to numpy
    print(f"Embedding shape: {embedding.cpu().numpy().shape}")


if __name__ == "__main__":
    main()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Loading samples: 100%|██████████| 702/702 [01:42<00:00,  6.86it/s]
Loading samples: 100%|██████████| 150/150 [00:22<00:00,  6.56it/s]
Loading samples: 100%|██████████| 151/151 [00:22<00:00,  6.70it/s]


Feature shape: (1, 64, 3001)

Subtitle text:
 i want to help you but i won't be locked in no goddamn trunk of no car
you think i wanted to spend 10000 on your ass huh
i know you helped me out
do you think i wanted to spend 10000 on your ass
yes or no
of course you didn't but look in the trunk
that's the only way i could help you right
so that's what i did
now look man all i'm asking you to do is get in the trunk hold this fuckin' shotgun point it at these buddhaheads when i open it all right
you're catching a nigger offguard with this shit
look here
i tell you what
when we get through fucking with these koreans
Embedding shape: (1, 147, 768)
