In [None]:
import glob
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
from IPython.display import Audio, display
from transformers import WhisperProcessor

# Add the project root (parent of the notebook's working directory) to the import path so `src` resolves
sys.path.append(os.path.join(os.getcwd(), '..'))

from src.attacks.pgd import PGDAttack, compute_snr
from src.data.audio_loader import load_audio
from src.data.download_data import download_librispeech_sample
from src.models.whisper_wrapper import WhisperASRWithAttack

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Seed torch (all devices) and numpy so any randomness downstream — e.g. a random PGD
# start, if PGDAttack uses one (unverified) — repeats across Restart & Run All.
# NOTE(review): bit-exact determinism on CUDA may additionally require deterministic kernels.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# 1. Load Data
# Data lives in <project_root>/data; the helper returns the directory of the downloaded sample.
data_root = os.path.join(os.getcwd(), '..', 'data')
dataset_path = download_librispeech_sample(data_root)

# Find a file. glob() returns paths in arbitrary, filesystem-dependent order, so sort the
# result to make the choice of `target_file` reproducible across machines and re-runs.
files = sorted(glob.glob(os.path.join(dataset_path, "**", "*.flac"), recursive=True))
if not files:
    # The download was just attempted above, so point at where we looked rather than
    # telling the user to "run download first".
    raise RuntimeError(
        f"No .flac audio files found under {dataset_path!r}; "
        "check that the LibriSpeech sample download succeeded."
    )

target_file = files[0]
print(f"Attacking file: {target_file}")

# Load audio and move it to the same device as the model.
audio = load_audio(target_file).to(device)
print(f"Audio shape: {audio.shape}")

In [None]:
# 2. Initialize Model & Processor
# Attack-capable Whisper wrapper, placed on the device selected above.
wrapper = WhisperASRWithAttack(device=device)

# Processor used to decode predicted token IDs back into text.
# NOTE(review): this checkpoint is chosen independently of the wrapper — confirm the wrapper
# loads the same "openai/whisper-base" weights, otherwise decoding may use the wrong tokenizer.
WHISPER_CHECKPOINT = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT)

def decode_output(logits, proc=None):
    """Greedy-decode model logits into text and return the first sequence in the batch.

    Args:
        logits: Tensor of per-token scores; the argmax is taken over the last dimension.
            Presumably shaped (batch, seq_len, vocab) — TODO confirm against the wrapper output.
        proc: Any object exposing ``batch_decode(ids, skip_special_tokens=...)``, e.g. a
            ``WhisperProcessor``. Defaults to the notebook-level ``processor``.

    Returns:
        The decoded string for the first batch item, with special tokens stripped.

    NOTE(review): a per-position argmax is not autoregressive generation. If the wrapper's
    logits come from a teacher-forced forward pass, this reads off next-token predictions,
    which can differ from ``model.generate()`` output — confirm this is the intended metric.
    """
    if proc is None:
        # Fall back to the processor created in this cell (keeps old call sites working).
        proc = processor
    pred_ids = torch.argmax(logits, dim=-1)
    return proc.batch_decode(pred_ids, skip_special_tokens=True)[0]

# Baseline: transcribe the clean, unperturbed audio (inference only, so no gradients).
with torch.no_grad():
    res_clean = wrapper(audio)
    clean_logits = res_clean.logits
    transcription_clean = decode_output(clean_logits)

print(f"Original Transcription: '{transcription_clean}'")

In [None]:
# 3. Perform PGD Attack
# PGD hyperparameters. NOTE(review): presumably an L-inf budget in waveform amplitude units —
# confirm against PGDAttack. For scale, 20*log10(0.02) ≈ -34 dB relative to amplitude 1.0.
EPSILON = 0.02   # perturbation budget
ALPHA = 0.002    # per-iteration step size
NUM_ITER = 30    # number of PGD iterations
attacker = PGDAttack(wrapper, epsilon=EPSILON, alpha=ALPHA, num_iter=NUM_ITER)

print("Running PGD...")
adv_audio = attacker.generate(audio)

# 4. Evaluate
# detach() guards against generate() returning a tensor that still tracks gradients,
# in which case .numpy() would raise.
clean_np = audio.detach().cpu().numpy()
adv_np = adv_audio.detach().cpu().numpy()
snr = compute_snr(clean_np, adv_np)

with torch.no_grad():
    res_adv = wrapper(adv_audio)
    transcription_adv = decode_output(res_adv.logits)

print(f"Adversarial Transcription: '{transcription_adv}'")
print(f"SNR: {snr:.2f} dB")

# 5. Play Audio (Optional)
# Take the playback rate from the processor's feature extractor instead of hard-coding 16000,
# so it stays consistent with the sampling rate the model expects.
print("Adversarial Audio:")
display(Audio(adv_np, rate=processor.feature_extractor.sampling_rate))