# 06: Adversarial Jamming: Smart Noise vs. Loud Noise

**Hypothesis**: A trained Universal Adversarial Perturbation (UAP) acts as a more efficient "jammer" than random white noise. It should degrade transcription accuracy at lower volumes (higher SNR) than random noise does.

**Scenario**: We simulate an "Over-the-Air" attack by digitally mixing the noise into the audio track, as if it were playing in the background of a room.

## Goals
1. Load the trained UAP (from Notebook 04).
2. Compare it against Gaussian (White) Noise.
3. Measure the transcription Word Error Rate (WER) across different noise "loudness" levels, expressed as Signal-to-Noise Ratio (SNR).

In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import glob
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from jiwer import wer

# Add src to path
sys.path.append(os.path.join(os.getcwd(), '..'))
from src.data.audio_loader import load_audio, get_audio_duration
from src.data.download_data import download_librispeech_sample

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Setup: Load Model and Data

In [None]:
# Victim model: the standard Whisper checkpoint that both jammers are evaluated against.
model_name = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
model.eval()  # switch the module to evaluation mode before running inference
print("Victim model loaded.")

In [None]:
# Load a clean audio sample
data_root = os.path.join(os.getcwd(), '..', 'data')
dataset_path = download_librispeech_sample(data_root)
# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# files[0] selects the same clip on every run and on every machine.
files = sorted(glob.glob(os.path.join(dataset_path, "**", "*.flac"), recursive=True))

if not files:
    raise RuntimeError("No audio files found! Run Notebook 01 first.")

sample_path = files[0]
clean_audio = load_audio(sample_path).to(device)

# Ensure shape is (1, samples): a 1-D waveform gets a leading dimension.
if clean_audio.ndim == 1:
    clean_audio = clean_audio.unsqueeze(0)

print(f"Loaded audio: {sample_path}")

## 2. Load the Universal Adversarial Perturbation (UAP)
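We attempt to load `uap.pt`. If it doesn't exist (because you haven't run Notebook 04 yet), we generate a random one for demonstration purposes. Note that with this fallback both jammers are plain Gaussian noise, so the comparison below only becomes meaningful once a trained UAP is available.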

In [None]:
uap_path = os.path.join(os.getcwd(), 'uap.pt')

if os.path.exists(uap_path):
    print("Loading trained UAP...")
    # NOTE: torch.load unpickles the file, so only load a uap.pt you produced yourself.
    # Assumes the file holds a single tensor -- confirm against the notebook that saves it.
    uap_noise = torch.load(uap_path, map_location=device)
    # A perturbation saved straight from training may still require grad;
    # detach it so audio mixed with it can later be converted to NumPy.
    uap_noise = uap_noise.detach()
    # Use the same (1, samples) layout as the clean audio, since the
    # length-matching step indexes the sample axis as shape[1].
    if uap_noise.ndim == 1:
        uap_noise = uap_noise.unsqueeze(0)
else:
    print("WARNING: UAP file not found. Generating RANDOM UAP for demo purposes.")
    # Stand-in perturbation: 16000 * 10 Gaussian samples (10 seconds at 16 kHz),
    # scaled by 0.05. In reality, this should be your trained vector.
    uap_noise = torch.randn(1, 16000 * 10).to(device) * 0.05

# Prepare UAP: Tile it to match audio length if necessary
def match_length(noise, target_len):
    """Crop or tile ``noise`` so it has exactly ``target_len`` samples.

    Args:
        noise: Tensor shaped ``(channels, samples)``; a 1-D ``(samples,)``
            waveform is also accepted and is returned as 1-D.
        target_len: Number of samples the result must have.

    Returns:
        ``noise`` truncated to ``target_len`` samples when it is long enough,
        otherwise ``noise`` repeated end to end and then truncated.

    Raises:
        ValueError: If ``noise`` has no samples but ``target_len`` is positive.
    """
    if noise.ndim == 1:
        # Treat a bare waveform as a single-channel signal, then restore its rank.
        return match_length(noise.unsqueeze(0), target_len).squeeze(0)

    noise_len = noise.shape[1]
    if noise_len >= target_len:
        return noise[:, :target_len]
    if noise_len == 0:
        raise ValueError("Cannot tile an empty noise tensor to a non-zero length.")

    # Ceiling division: the fewest whole copies that cover target_len.
    repeat_times = -(-target_len // noise_len)
    return noise.repeat(1, repeat_times)[:, :target_len]

# Crop or tile the perturbation so it covers exactly as many samples as the clean clip.
uap_aligned = match_length(noise=uap_noise, target_len=clean_audio.size(1))
print(f"UAP Shape aligned: {uap_aligned.shape}")

## 3. Define Mixing Function (SNR)
We mix noise at a specific Signal-to-Noise Ratio (dB). 
- **High SNR (e.g., 40dB)** = Quiet Noise (Hard to hear)
- **Low SNR (e.g., 10dB)** = Loud Noise (Very obvious)

We want to see the UAP break the model at **High SNR**.

In [None]:
def set_snr(clean, noise, target_snr_db):
    """
    Rescale ``noise`` so that its power relative to ``clean`` matches the target SNR.

    From SNR_dB = 10 * log10(P_signal / P_noise), the required noise power is
    P_signal / 10^(SNR_dB / 10). The amplitude gain is the square root of the
    ratio between that required power and the current noise power.

    Args:
        clean: Reference signal tensor; its mean square is taken as the signal power.
        noise: Noise tensor to rescale.
        target_snr_db: Desired signal-to-noise ratio in decibels.

    Returns:
        The rescaled noise, or ``noise`` unchanged when its mean power is zero.
    """
    signal_power = clean.pow(2).mean()
    noise_power = noise.pow(2).mean()

    # Zero-power noise cannot be rescaled; hand it back untouched.
    if noise_power == 0:
        return noise

    wanted_noise_power = signal_power / (10 ** (target_snr_db / 10))
    gain = (wanted_noise_power / noise_power).sqrt()
    return gain * noise

def transcribe(audio_tensor):
    """Run Whisper inference on one waveform and return the decoded text.

    Uses the notebook-level ``processor``, ``model`` and ``device``.

    Args:
        audio_tensor: Waveform tensor, passed to the processor as 16 kHz audio.
            Singleton dimensions are squeezed away before feature extraction.

    Returns:
        The decoded transcription of the first sequence in the batch, with
        special tokens removed.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad, which
    # happens when the audio was mixed with a perturbation still tied to autograd.
    waveform = audio_tensor.detach().squeeze().cpu().numpy()
    input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features.to(device)
    predicted_ids = model.generate(input_features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription

## 4. Run Experiment: Smart vs. Loud Jammer
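We iterate through SNR levels from 50dB (quiet) to 0dB (loud). At each level, WER is computed against Whisper's transcription of the clean audio (the baseline), not against a ground-truth transcript.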

In [None]:
# Noise levels to sweep, from barely audible down to noise with the same power as the speech.
snr_levels = [50, 40, 30, 20, 10, 5, 0] # dB
uap_wers = []
white_wers = []

# Reference transcript: what the model outputs for the untouched audio.
baseline_text = transcribe(clean_audio)
print(f"Baseline Text: {baseline_text}\n")

for snr in snr_levels:
    print(f"--- Testing SNR: {snr} dB ---")

    # Jammer A: freshly drawn Gaussian noise, rescaled to the current SNR.
    white_noise_raw = torch.randn_like(clean_audio)
    white_noise_scaled = set_snr(clean_audio, white_noise_raw, snr)
    # Keep the mixture inside the [-1, 1] amplitude range.
    audio_white = (clean_audio + white_noise_scaled).clamp(min=-1.0, max=1.0)

    trans_white = transcribe(audio_white)
    wer_white = wer(baseline_text, trans_white)
    white_wers.append(wer_white)

    # Jammer B: the length-matched UAP (aligned in an earlier cell), rescaled to the same SNR.
    uap_scaled = set_snr(clean_audio, uap_aligned, snr)
    audio_uap = (clean_audio + uap_scaled).clamp(min=-1.0, max=1.0)

    trans_uap = transcribe(audio_uap)
    wer_uap = wer(baseline_text, trans_uap)
    uap_wers.append(wer_uap)

    print(f"White Noise WER: {wer_white:.2f} | UAP WER: {wer_uap:.2f}")
    # Show the jammed transcript when more than 80% of the words are wrong.
    if wer_uap > 0.8:
        print(f"  -> UAP JAMMED at {snr}dB: '{trans_uap}'")

## 5. Visualization

In [None]:
# Plot WER against SNR for both jammers on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(snr_levels, white_wers, linestyle='--', marker='o', label='White Noise (Random)')
ax.plot(snr_levels, uap_wers, color='red', linewidth=2, marker='x', label='UAP (Smart Jammer)')

ax.set_title('Jamming Efficiency: UAP vs. Random Noise')
ax.set_xlabel('SNR (dB) - Higher is Quieter Noise')
ax.set_ylabel('Word Error Rate (WER)')
ax.invert_xaxis()  # read left to right as quiet (50 dB) to loud (0 dB)
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()