In [1]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import glob
from transformers import WhisperProcessor

from src.models.whisper_wrapper import WhisperASRWithAttack
from src.attacks.pgd import PGDAttack

import src.models as models
import src.attacks as attacks
import src.data as data_loader


# Run on Apple's MPS backend when PyTorch reports it as available; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [2]:
# Build the dataset listing through the project's data module.
audio_ds = data_loader.load_dataset()

# Load the first entry; the loader returns a pair that is unpacked as (audio, audio_tensor).
first_entry = audio_ds[0]
audio, audio_tensor = data_loader.load_audio_tensor(first_entry)

Found 2620 audio files.
Sample: /Users/victorhugogermano/Development/soundfinal/data/LibriSpeech/test-clean/61/70970/61-70970-0040.flac


In [None]:
# 2. Initialize Model & Processor
# The wrapper is the model under attack; the processor is used further down
# (in decode_output) to turn predicted token ids back into text.
whisper_checkpoint = "openai/whisper-base"
wrapper = WhisperASRWithAttack(device=device)
processor = WhisperProcessor.from_pretrained(whisper_checkpoint)

def decode_output(logits):
    """Greedy-decode a batch of logits and return the text of the first item.

    The highest-scoring token id is taken along the last axis, the ids are
    decoded with the notebook-level ``processor`` (special tokens skipped),
    and only the first decoded string of the batch is returned.
    """
    token_ids = logits.argmax(dim=-1)
    decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded_batch[0]

# Baseline transcription of the unmodified audio, kept as the reference
# against which the adversarial transcription is compared later.
clean_waveform = audio_tensor.cpu().numpy().squeeze()
with torch.no_grad():
    # Call the `transcribe` method exposed by the wrapped model.
    result = wrapper.model.transcribe(clean_waveform, language='en', fp16=False)
transcription_clean = result['text'].strip()

print(f"Original Transcription: '{transcription_clean}'")

Loading weights:   0%|          | 0/245 [00:00<?, ?it/s]

  return _VF.stft(  # type: ignore[attr-defined]


Original Transcription: ''


In [5]:
from transformers import WhisperFeatureExtractor
def test_whisper_dimensions(model_path="openai/whisper-base", seed=0):
    """Compare a hand-rolled log-mel pipeline with the HF Whisper feature extractor.

    The manual pipeline replicates the steps the notebook attributes to the
    wrapper code: pad/crop to 30 s, STFT, drop the last frame, project onto the
    extractor's mel filter bank, then log-scale. Shapes are printed at every
    stage and summary statistics (mean/max/min) of the manual result are
    printed next to those of the extractor's own output. This is a summary
    comparison only, not an element-wise one.

    Args:
        model_path: Checkpoint name or path passed to
            ``WhisperFeatureExtractor.from_pretrained``.
        seed: Seed for the synthetic noise signal, so repeated runs print the
            same numbers. Pass ``None`` for fresh entropy on every call.

    Returns:
        None. All results are printed. If the feature extractor cannot be
        loaded, the error is printed and the function returns early.
    """
    print("--- Testing Whisper Feature Extractor Dimensions ---")
    try:
        feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    except Exception as e:
        print(f"Error loading feature extractor: {e}")
        return

    mel_filters = feature_extractor.mel_filters
    # The recorded run of this cell printed (201, 80), i.e. (n_freq, n_mels).
    print(f"feature_extractor.mel_filters shape: {mel_filters.shape}")

    # Pipeline constants.
    sr = 16000
    seconds = 30
    n_samples = 480000  # fixed input length: 30 s at 16 kHz
    n_fft = 400
    hop_length = 160

    # Simulate audio: seeded Gaussian noise of exactly `seconds` seconds.
    rng = np.random.default_rng(seed)
    audio = rng.standard_normal(sr * seconds).astype(np.float32)
    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
    print(f"Audio Tensor Shape: {audio_tensor.shape}")

    # 1. Pad/crop to 30 s (a no-op for the synthetic signal; kept because the
    #    manual pipeline is meant to mirror the wrapper step by step).
    current_len = audio_tensor.shape[1]
    if current_len < n_samples:
        audio_tensor = torch.nn.functional.pad(audio_tensor, (0, n_samples - current_len))
    else:
        audio_tensor = audio_tensor[:, :n_samples]

    # 2. Power spectrogram.
    stft = torch.stft(
        audio_tensor,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.hann_window(n_fft),
        center=True,
        return_complex=True,
    )
    magnitudes = stft.abs() ** 2
    # With center=True and 480000 samples this is (1, 201, 3001).
    print(f"STFT Magnitudes Shape (center=True): {magnitudes.shape}")

    # The wrapper code drops the final frame: magnitudes[:, :, :-1].
    magnitudes = magnitudes[:, :, :-1]
    print(f"Magnitudes after slicing :-1: {magnitudes.shape}")  # Should be (1, 201, 3000)

    # 3. Mel projection. Wrapper logic:
    #    mels = torch.matmul(magnitudes.transpose(1, 2), self.mel_filters).transpose(1, 2)
    mel_filters_tensor = torch.from_numpy(mel_filters).float()
    frames_first = magnitudes.transpose(1, 2)  # (1, 3000, 201)
    print(f"Look at matmul: (1, 3000, 201) @ {mel_filters_tensor.shape}")

    # Explicit sentinel: stays None when neither filter orientation multiplies.
    mels = None
    try:
        mels = torch.matmul(frames_first, mel_filters_tensor).transpose(1, 2)
        print("Matmul Successful!")
        print(f"Mels Shape: {mels.shape}")
    except RuntimeError as e:
        print(f"Matmul Failed: {e}")
        print("Trying with Transpose of filters...")
        try:
            mels = torch.matmul(frames_first, mel_filters_tensor.T).transpose(1, 2)
            print("Matmul with .T Successful!")
            print(f"Mels Shape: {mels.shape}")
        except RuntimeError as e2:
            print(f"Matmul with .T Failed: {e2}")

    # 4. Reference output from the feature extractor itself.
    print("\n--- Comparing with HF execute ---")
    hf_out = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
    hf_mels = hf_out.input_features
    print(f"HF Output Shape: {hf_mels.shape}")

    if mels is not None:
        # Finish the manual pipeline: log10 with a floor, clamp the dynamic
        # range to 8 below the maximum, then shift/scale by 4.
        log_mels = torch.log10(torch.clamp(mels, min=1e-10))
        log_mels = torch.maximum(log_mels, log_mels.max() - 8.0)
        log_mels = (log_mels + 4.0) / 4.0

        print("\n--- Value Comparison ---")
        print(f"Manual Mean: {log_mels.mean().item():.4f}, Max: {log_mels.max().item():.4f}, Min: {log_mels.min().item():.4f}")
        print(f"HF Mean: {hf_mels.mean().item():.4f}, Max: {hf_mels.max().item():.4f}, Min: {hf_mels.min().item():.4f}")


# Run the diagnostic; it prints the shapes and summary statistics shown below.
test_whisper_dimensions()

--- Testing Whisper Feature Extractor Dimensions ---
feature_extractor.mel_filters shape: (201, 80)
Audio Tensor Shape: torch.Size([1, 480000])
STFT Magnitudes Shape (center=True): torch.Size([1, 201, 3001])
Magnitudes after slicing :-1: torch.Size([1, 201, 3000])
Look at matmul: (1, 3000, 201) @ torch.Size([201, 80])
Matmul Successful!
Mels Shape: torch.Size([1, 80, 3000])

--- Comparing with HF execute ---
HF Output Shape: torch.Size([1, 80, 3000])

--- Value Comparison ---
Manual Mean: 1.1119, Max: 1.4122, Min: -0.0208
HF Mean: 1.1119, Max: 1.4122, Min: -0.0208


In [4]:
# Inspect the filter bank held by the wrapper: its shape and where it lives.
wrapper_mel_bank = wrapper.mel_filters
print(f"Mel Filters Shape: {wrapper_mel_bank.shape}")
print(f"Device: {wrapper_mel_bank.device}")

# Shape of the waveform tensor that was loaded from the dataset.
print(f"Audio Tensor Shape: {audio_tensor.shape}")

Mel Filters Shape: torch.Size([201, 80])
Device: mps:0
Audio Tensor Shape: torch.Size([66640])


In [None]:
# 3. Perform PGD Attack
from IPython.display import Audio, display
from src.attacks.pgd import compute_snr

# epsilon = 0.02 is roughly -34 dB relative to a peak amplitude of 1.0
# (20 * log10(0.02) is about -34).
pgd_settings = dict(epsilon=0.02, alpha=0.002, num_iter=30)
playback_rate = 16000

attacker = PGDAttack(wrapper, **pgd_settings)

print("Running PGD...")
adv_audio = attacker.generate(audio)

# 4. Evaluate: signal-to-noise ratio of the perturbation, then the model's
#    transcription of the perturbed audio.
clean_samples = audio.cpu().numpy()
adv_samples = adv_audio.cpu().numpy()
snr = compute_snr(clean_samples, adv_samples)

with torch.no_grad():
    res_adv = wrapper(adv_audio)
transcription_adv = decode_output(res_adv.logits)

print(f"Adversarial Transcription: '{transcription_adv}'")
print(f"SNR: {snr:.2f} dB")

# 5. Play Audio (Optional)
print("Adversarial Audio:")
adv_player = Audio(adv_audio.cpu().numpy(), rate=playback_rate)
display(adv_player)