<a href="https://colab.research.google.com/" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Detector2 — Dynamite Fishing Detection System

This notebook implements a rule-based prototype for detecting dynamite fishing events
in `.wav` files based on the hierarchical framework detailed in *"Acoustic Signatures
of Underwater Explosions: A Technical Report."*

The detection system uses a three-tiered approach:
- **Tier 1:** Initial Event Detection (Candidate Identification)
- **Tier 2:** Primary Classification (Explosion Verification)
- **Tier 3:** Contextual Filtering (Confuser Rejection) — *not yet implemented in this prototype; only a placeholder setting exists in the configuration*

### Data input modes
| Mode | Cell to run | Description |
|------|------------|-------------|
| **Google Drive (mounted)** | *Configuration (COLAB ONLY)* | Mount Drive and point to a folder |
| **Google Drive link** | *Configuration (GOOGLE DRIVE LINK)* | Paste a shared Drive link; files are downloaded via `gdown` |
| **Local** | *Configuration (LOCAL ONLY)* | Point to a local folder |

After running **one** of the configuration cells above, always run the **Configuration (COMMON — ALWAYS RUN)** cell.

In [None]:
# @title Imports
import os
import sys
import glob
import numpy as np
import librosa
import scipy.signal
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [None]:
# @title Configuration (COLAB ONLY) — Mount Google Drive

# Colab-only cell: use it when the recordings sit on a Google Drive that can be mounted.

from google.colab import drive
drive.mount('/content/drive')

# ---------------------------------------------
# Drive locations: recordings in, results CSV out
# ---------------------------------------------
input_audio_dir = "/content/drive/Shareddrives/MAR FUTURA/Perú/Spanish bombs/Bombs RS"
output_csv_filepath = f"{input_audio_dir}/explosion_results.csv"

# ---------------------------------------------
# Optional: copy the WAVs to the VM disk first
# ---------------------------------------------
use_local_staging = False  # Set True to rsync WAVs to local SSD first
run_root = "/content/run_detector2"
staged_audio_dir = f"{run_root}/audio"

Path(run_root).mkdir(parents=True, exist_ok=True)

if use_local_staging:
    import subprocess

    Path(staged_audio_dir).mkdir(parents=True, exist_ok=True)

    # Both paths are normalised to end in exactly one '/'.
    _rsync_source = input_audio_dir.rstrip('/') + '/'
    _rsync_target = staged_audio_dir.rstrip('/') + '/'

    # The include/exclude filters keep directories and *.wav / *.WAV files only.
    _rsync_command = [
        'rsync', '-a', '--info=progress2', '--prune-empty-dirs',
        '--include=*/', '--include=*.wav', '--include=*.WAV', '--exclude=*',
        _rsync_source,
        _rsync_target,
    ]

    print('Staging WAV files to local disk...')
    subprocess.run(_rsync_command, check=True)

    # From here on the detector reads the staged copy; the CSV path still
    # points at the Drive folder computed above.
    input_audio_dir = staged_audio_dir
    print('Staging complete.')

for _label, _value in (
    ('input_audio_dir:', input_audio_dir),
    ('output_csv_filepath:', output_csv_filepath),
):
    print(_label, _value)

In [None]:
# @title Configuration (GOOGLE DRIVE LINK) — Download from a shared link

# Run this cell when you have a Google Drive shared link to a folder or ZIP of WAV files.
# Supports:
#   - A shared FOLDER link  (gdown downloads all files in the folder)
#   - A shared FILE link    (single WAV or a ZIP archive that will be extracted)

!pip install -q gdown
import gdown
import zipfile

# -----------------------------
# PASTE YOUR GOOGLE DRIVE LINK HERE
# -----------------------------
gdrive_link = "https://drive.google.com/drive/folders/1haxHTVn_l9sfK-p2Pb0P1Npa2_K1teZ6"

# -----------------------------
# LOCAL PATHS
# -----------------------------
run_root = "/content/run_detector2"
download_dir = f"{run_root}/download"
input_audio_dir = f"{run_root}/audio"
output_csv_filepath = f"{run_root}/explosion_results.csv"

Path(download_dir).mkdir(parents=True, exist_ok=True)
Path(input_audio_dir).mkdir(parents=True, exist_ok=True)

# Detect if the link is a folder or a single file
is_folder = '/folders/' in gdrive_link

if is_folder:
    print('Detected FOLDER link. Downloading all files...')
    gdown.download_folder(gdrive_link, output=input_audio_dir, quiet=False)
else:
    print('Detected FILE link. Downloading...')
    downloaded_file = gdown.download(gdrive_link, output=download_dir + '/', fuzzy=True, quiet=False)
    # If it is a ZIP, extract it
    if downloaded_file and downloaded_file.endswith('.zip'):
        print('Extracting ZIP archive...')
        with zipfile.ZipFile(downloaded_file, 'r') as zf:
            zf.extractall(input_audio_dir)
        print('Extraction complete.')
    elif downloaded_file and downloaded_file.lower().endswith('.wav'):
        import shutil
        shutil.move(downloaded_file, os.path.join(input_audio_dir, os.path.basename(downloaded_file)))
    else:
        print(f'Downloaded file: {downloaded_file}')
        print('If this is not a WAV or ZIP, please adjust the code.')

# List downloaded WAVs
wav_files = glob.glob(os.path.join(input_audio_dir, '**', '*.[wW][aA][vV]'), recursive=True)
print(f'WAV files found: {len(wav_files)}')
for f in wav_files[:5]:
    print(' ', f)

print('\ninput_audio_dir:', input_audio_dir)
print('output_csv_filepath:', output_csv_filepath)

In [None]:
# @title Configuration (LOCAL ONLY)

# Run this cell only if you're running locally.

# -----------------------------
# PATHS (Local)
# -----------------------------
# Folder that contains the WAV recordings to analyse.
input_audio_dir = "/path/to/your/wav/folder"

# The results CSV is placed inside the audio folder, derived from
# input_audio_dir the same way the Colab cell does it, so only one path
# has to be edited. Assign a different string here to write elsewhere.
output_csv_filepath = f"{input_audio_dir}/explosion_results.csv"

print('input_audio_dir:', input_audio_dir)
print('output_csv_filepath:', output_csv_filepath)

In [None]:
# @title Configuration (COMMON — ALWAYS RUN)

# Shared settings. Execute this after exactly ONE of the path-configuration cells above.

# Every tunable threshold used by the detector lives in this dictionary.
DETECTOR_CONFIG = {
    # --- Tier 1: initial event detection ---
    "energy_threshold_db": 20,  # dB above the ambient (median) level
    # Frequency bands in Hz. Note: at the 22050 Hz analysis rate below the
    # spectrum ends at 11025 Hz, so the last band is effectively 5000-11025 Hz.
    "broadband_check_bands": [(100, 1000), (1000, 5000), (5000, 15000)],
    "broadband_energy_threshold": 0.15,  # minimum share of total energy per band

    # --- Tier 2: primary classification ---
    "shockwave_rise_time_ms": 5,  # longest accepted rise time (ms)
    "bubble_pulse_search_window_ms": 500,  # how far after the shockwave to look (ms)
    "bubble_pulse_min_period_ms": 20,  # shortest shockwave-to-bubble delay (ms)
    "bubble_pulse_max_period_ms": 200,  # longest shockwave-to-bubble delay (ms)
    "bubble_pulse_amplitude_ratio": 0.1,  # minimum bubble/shockwave amplitude ratio
    "peak_prominence_factor": 0.3,  # scales the prominence required of secondary peaks

    # --- Tier 3: contextual filtering (for future use) ---
    "repetition_interval_s": 30,  # time window for repetition checks

    # --- General ---
    "sample_rate": 22050,  # audio is resampled to this rate when loaded
    "frame_length_ms": 50,  # RMS analysis frame length
    "min_event_duration_ms": 10,  # candidates shorter than this are dropped
}

# Filename pattern for recordings (extension matched in any letter case)
dataset_fileglob = '*.[wW][aA][vV]'

# ---------------------------------------------
# Batch selection
# ---------------------------------------------
# The sorted file list is sliced as wav_files_all[file_start:file_end].
#   file_start = 0,    file_end = 1000   -> files 1 to 1000
#   file_start = 1000, file_end = 2000   -> files 1001 to 2000
#   file_start = 0,    file_end = None   -> ALL files
file_start = 0     # 0-indexed start (inclusive)
file_end = None    # 0-indexed end (exclusive). Set to None to process all remaining files.

# ---------------------------------------------
# The path cell must have defined these names
# ---------------------------------------------
_required = ['input_audio_dir', 'output_csv_filepath']
# A name counts as missing when it is undefined or holds an empty/falsy value.
_missing = [name for name in _required if not globals().get(name)]
if _missing:
    raise RuntimeError(
        'Missing required path variables from one of the configuration cells: '
        + ', '.join(_missing)
    )

# Make sure the folder that will receive the CSV exists.
Path(output_csv_filepath).parent.mkdir(parents=True, exist_ok=True)

# Recursive search below input_audio_dir; sorting keeps batch slices stable.
_search_pattern = os.path.join(input_audio_dir, '**', dataset_fileglob)
wav_files_all = sorted(glob.glob(_search_pattern, recursive=True))
print(f'input_audio_dir: {input_audio_dir}')
print(f'output_csv_filepath: {output_csv_filepath}')
print(f'Total WAV files found: {len(wav_files_all)}')

# Restrict to the requested batch
wav_files = wav_files_all[file_start:file_end]
_range_end = "end" if file_end is None else file_end
print(f'File range: [{file_start} : {_range_end}]  ->  {len(wav_files)} files selected')

if not wav_files:
    raise RuntimeError(
        f'No WAV files in the selected range [{file_start}:{file_end}]. '
        f'Total files available: {len(wav_files_all)}. '
        'Check input_audio_dir, dataset_fileglob, or file_start/file_end.'
    )

print('First file:', os.path.basename(wav_files[0]))
print('Last file: ', os.path.basename(wav_files[-1]))

In [None]:
# @title Detection Logic (all tiers)

def _tier1_energy_threshold(audio_data, sr, config):
    """
    Tier 1, Rule 1: locate stretches whose short-term energy stands far above the baseline.

    The signal is cut into half-overlapping RMS frames, converted to dB, and
    compared with ``median + config["energy_threshold_db"]``. Runs of
    consecutive loud frames become candidates.

    Args:
        audio_data (np.array): Audio samples.
        sr (int): Sample rate of ``audio_data``.
        config (dict): Uses "frame_length_ms", "energy_threshold_db" and
            "min_event_duration_ms".

    Returns:
        dict: Always has 'passed', 'reason', 'baseline_db', 'threshold_db'.
        'candidates' (list of dicts with 'start_idx', 'end_idx',
        'duration_ms', 'max_energy_db') is present unless no frame exceeded
        the threshold.
    """
    frame_length = int(config["frame_length_ms"] * sr / 1000)
    hop_length = frame_length // 2

    # Frame-wise RMS in dB; the small offset keeps log10 away from zero.
    rms = librosa.feature.rms(y=audio_data, frame_length=frame_length, hop_length=hop_length)[0]
    rms_db = 20 * np.log10(rms + 1e-8)

    # The median frame level stands in for the ambient noise floor.
    baseline_db = np.median(rms_db)
    threshold_db = baseline_db + config["energy_threshold_db"]

    loud_frames = np.flatnonzero(rms_db > threshold_db)
    if loud_frames.size == 0:
        return {
            'passed': False,
            'reason': 'No high-energy transient found',
            'baseline_db': baseline_db,
            'threshold_db': threshold_db
        }

    # Cut the sorted frame indices wherever two neighbours are not adjacent,
    # giving one array per run of consecutive loud frames.
    run_breaks = np.flatnonzero(np.diff(loud_frames) != 1) + 1
    frame_runs = np.split(loud_frames, run_breaks)

    total_samples = len(audio_data)
    min_duration_ms = config["min_event_duration_ms"]

    candidates = []
    for run in frame_runs:
        first_frame, last_frame = run[0], run[-1]

        # Pad one full frame on each side, clipped to the signal bounds.
        start_idx = max(0, first_frame * hop_length - frame_length)
        end_idx = min(total_samples, (last_frame + 1) * hop_length + frame_length)

        duration_ms = (end_idx - start_idx) / sr * 1000
        if duration_ms < min_duration_ms:
            continue

        candidates.append({
            'start_idx': start_idx,
            'end_idx': end_idx,
            'duration_ms': duration_ms,
            'max_energy_db': np.max(rms_db[first_frame:last_frame+1])
        })

    return {
        'passed': len(candidates) > 0,
        'candidates': candidates,
        'baseline_db': baseline_db,
        'threshold_db': threshold_db,
        'reason': f'Found {len(candidates)} energy candidates' if candidates else 'No candidates meet minimum duration'
    }


def _tier1_broadband_check(audio_data, sr, candidate, config):
    """
    Tier 1, Rule 2: Broadband Check
    Verifies that significant energy is present across multiple frequency bands.

    Args:
        audio_data (np.array): Full audio signal.
        sr (int): Sample rate of ``audio_data``.
        candidate (dict): Must provide 'start_idx' and 'end_idx' (sample
            indices into ``audio_data``).
        config (dict): Uses "broadband_check_bands" (list of (low_hz, high_hz)
            pairs, bounds inclusive) and "broadband_energy_threshold" (minimum
            share of the total spectral energy a band needs to count).

    Returns:
        bool: True when at least two bands reach the energy share threshold.
        An empty segment, or one with no (or non-finite) spectral energy,
        returns False instead of raising or dividing by zero.
    """
    # Extract candidate segment
    segment = audio_data[candidate['start_idx']:candidate['end_idx']]

    # np.fft cannot transform zero samples; nothing to check in that case.
    if len(segment) == 0:
        return False

    # One-sided power spectrum of the real signal. rfft yields exactly the
    # non-negative frequencies (including the Nyquist bin for even lengths).
    power = np.abs(np.fft.rfft(segment)) ** 2
    freqs = np.fft.rfftfreq(len(segment), 1 / sr)

    total_energy = np.sum(power)

    # A silent (all-zero) or corrupt (NaN/inf) segment has no meaningful
    # energy distribution; treat it as "not broadband".
    if not np.isfinite(total_energy) or total_energy <= 0:
        return False

    min_share = config["broadband_energy_threshold"]

    # Count the bands whose share of the total energy reaches the threshold.
    bands_passed = 0
    for low_freq, high_freq in config["broadband_check_bands"]:
        band_mask = (freqs >= low_freq) & (freqs <= high_freq)
        band_energy_ratio = np.sum(power[band_mask]) / total_energy
        if band_energy_ratio >= min_share:
            bands_passed += 1

    # Require energy in at least 2 of the configured frequency bands
    return bands_passed >= 2


def _tier2_shockwave_detection(segment, sr, config):
    """
    Tier 2, Rule 3: test whether the segment's largest peak has a shockwave-like rise.

    The rise time is measured on the rectified waveform, from the last
    sample below 10% of the peak to the first later sample above 90% of it,
    looking back at most 100 ms from the peak.

    Args:
        segment (np.array): Candidate audio samples.
        sr (int): Sample rate of ``segment``.
        config (dict): Uses "shockwave_rise_time_ms" (maximum accepted rise time).

    Returns:
        dict: {'valid': False, 'reason': ...} on rejection; otherwise 'valid'
        is True and 'peak_idx', 'peak_amplitude', 'rise_time_ms', 'reason'
        are included.
    """
    rectified = np.abs(segment)

    # Largest absolute sample = potential shockwave
    peak_idx = np.argmax(rectified)
    peak_amplitude = rectified[peak_idx]

    # Rectified samples from (at most) 100 ms before the peak up to the peak itself
    lookback_samples = int(0.1 * sr)
    window_start = max(0, peak_idx - lookback_samples)
    pre_peak = rectified[window_start:peak_idx+1]

    # Last sample still under 10% of the peak
    under_10 = np.flatnonzero(pre_peak < 0.1 * peak_amplitude)
    if under_10.size == 0:
        return {
            'valid': False,
            'reason': 'Cannot find 10% amplitude point for rise time calculation'
        }
    idx_10_percent = under_10[-1]

    # First sample after that point exceeding 90% of the peak
    over_90 = np.flatnonzero(pre_peak[idx_10_percent:] > 0.9 * peak_amplitude)
    if over_90.size == 0:
        return {
            'valid': False,
            'reason': 'Cannot find 90% amplitude point for rise time calculation'
        }

    # over_90 is relative to idx_10_percent, so its first entry already is
    # the number of samples between the 10% and 90% points.
    rise_time_samples = over_90[0]
    rise_time_ms = (rise_time_samples / sr) * 1000

    max_rise_ms = config["shockwave_rise_time_ms"]
    if rise_time_ms > max_rise_ms:
        return {
            'valid': False,
            'reason': f'Rise time too slow ({rise_time_ms:.1f}ms > {max_rise_ms}ms threshold)'
        }

    return {
        'valid': True,
        'peak_idx': peak_idx,
        'peak_amplitude': peak_amplitude,
        'rise_time_ms': rise_time_ms,
        'reason': f'Valid shockwave detected (rise time: {rise_time_ms:.1f}ms)'
    }


def _tier2_bubble_pulse_search(segment, sr, shockwave_result, config):
    """
    Tier 2, Rule 4: look for a secondary peak (bubble pulse) following the shockwave.

    The rectified segment is searched from "bubble_pulse_min_period_ms" to
    "bubble_pulse_search_window_ms" after the shockwave peak (clipped to the
    segment end). Among the peaks meeting the height and prominence
    requirements, the tallest one is reported.

    Args:
        segment (np.array): Candidate audio samples.
        sr (int): Sample rate of ``segment``.
        shockwave_result (dict): Needs 'peak_idx' and 'peak_amplitude'.
        config (dict): Uses "bubble_pulse_min_period_ms",
            "bubble_pulse_search_window_ms", "bubble_pulse_amplitude_ratio"
            and "peak_prominence_factor".

    Returns:
        dict: {'found': False, 'reason': ...} when nothing qualifies;
        otherwise 'found' is True with 'bubble_idx' (index into ``segment``),
        'bubble_amplitude', 'period_ms', 'amplitude_ratio' and 'reason'.
    """
    shock_idx = shockwave_result['peak_idx']
    shock_amplitude = shockwave_result['peak_amplitude']
    n_samples = len(segment)

    # Search window, expressed as sample indices into the segment
    min_delay = int(config["bubble_pulse_min_period_ms"] * sr / 1000)
    max_delay = int(config["bubble_pulse_search_window_ms"] * sr / 1000)
    window_lo = shock_idx + min_delay
    window_hi = min(n_samples, shock_idx + max_delay)

    if window_lo >= window_hi or window_lo >= n_samples:
        return {
            'found': False,
            'reason': 'Search window too short or extends beyond segment'
        }

    # Peak requirements are both scaled from the shockwave amplitude.
    amplitude_ratio = config["bubble_pulse_amplitude_ratio"]
    min_height = amplitude_ratio * shock_amplitude
    min_prominence = config["peak_prominence_factor"] * min_height

    peaks, properties = scipy.signal.find_peaks(
        np.abs(segment[window_lo:window_hi]),
        height=min_height,
        prominence=min_prominence
    )

    if peaks.size == 0:
        return {
            'found': False,
            'reason': f'No secondary peaks found above {amplitude_ratio*100:.0f}% of shockwave amplitude'
        }

    # Take the tallest qualifying peak as the bubble pulse.
    heights = properties['peak_heights']
    tallest = np.argmax(heights)
    bubble_amplitude = heights[tallest]

    # Peak positions are relative to the window; shift back to segment indices.
    bubble_idx_absolute = window_lo + peaks[tallest]

    # Delay between shockwave and bubble pulse
    period_ms = ((bubble_idx_absolute - shock_idx) / sr) * 1000

    return {
        'found': True,
        'bubble_idx': bubble_idx_absolute,
        'bubble_amplitude': bubble_amplitude,
        'period_ms': period_ms,
        'amplitude_ratio': bubble_amplitude / shock_amplitude,
        'reason': f'Bubble pulse found at {period_ms:.1f}ms after shockwave'
    }


def _tier2_physical_plausibility(bubble_result, config):
    """
    Tier 2, Rule 5: accept the bubble pulse only if its delay is within the configured range.

    Args:
        bubble_result (dict): Needs 'period_ms'.
        config (dict): Uses "bubble_pulse_min_period_ms" and
            "bubble_pulse_max_period_ms"; both bounds are inclusive.

    Returns:
        Truthy when min <= period_ms <= max, falsy otherwise.
    """
    period_ms = bubble_result['period_ms']
    shortest_ms = config["bubble_pulse_min_period_ms"]
    longest_ms = config["bubble_pulse_max_period_ms"]

    return (shortest_ms <= period_ms) and (period_ms <= longest_ms)


def analyze_audio_for_explosion(audio_path, config=None):
    """
    Run the tiered explosion detector on a single audio file.

    Tier 1 finds high-energy, broadband candidate segments. Tier 2 then
    examines the candidates in order and stops at the first one showing a
    fast-rising shockwave followed by a bubble pulse with a plausible delay.
    Progress is printed along the way.

    Args:
        audio_path (str): Path to the audio file
        config (dict): Detector settings; DETECTOR_CONFIG is used when None

    Returns:
        dict: Always contains 'is_explosion', 'reason' and 'file_path'.
        A positive result adds 'sample_rate', 'explosion_time_s', 'metrics'
        and 'detection_data'. Negative results add 'sample_rate' plus,
        depending on where the analysis stopped, 'tier1_result' or
        'candidates_analyzed'. Any exception raised during analysis is
        caught and reported as a negative result whose reason carries the
        error text (without 'sample_rate').
    """
    if config is None:
        config = DETECTOR_CONFIG

    def _no_explosion(reason, **extra):
        # Shared shape of every negative result; extras are appended in order.
        return {
            'is_explosion': False,
            'reason': reason,
            'file_path': audio_path,
            **extra
        }

    try:
        # Decode the file, resampled to the configured analysis rate
        audio_data, sr = librosa.load(audio_path, sr=config["sample_rate"])

        if len(audio_data) == 0:
            return _no_explosion('Empty audio file', sample_rate=sr)

        print(f"Loaded audio: {len(audio_data)/sr:.2f}s at {sr}Hz")

        # ==================== TIER 1: INITIAL EVENT DETECTION ====================
        print("\n--- TIER 1: Initial Event Detection ---")

        # Rule 1: energy well above the ambient level
        tier1_result = _tier1_energy_threshold(audio_data, sr, config)
        if not tier1_result['passed']:
            return _no_explosion(
                tier1_result['reason'],
                sample_rate=sr,
                tier1_result=tier1_result
            )

        candidates = tier1_result['candidates']
        print(f"Found {len(candidates)} high-energy candidates")

        # Rule 2: keep only candidates whose energy is spread over several bands
        broadband_candidates = [
            candidate for candidate in candidates
            if _tier1_broadband_check(audio_data, sr, candidate, config)
        ]

        if not broadband_candidates:
            return _no_explosion(
                'No broadband high-energy transients found',
                sample_rate=sr,
                tier1_result=tier1_result
            )

        print(f"Found {len(broadband_candidates)} broadband candidates")

        # ==================== TIER 2: PRIMARY CLASSIFICATION ====================
        print("\n--- TIER 2: Primary Classification ---")

        n_candidates = len(broadband_candidates)
        for number, candidate in enumerate(broadband_candidates, start=1):
            print(f"\nAnalyzing candidate {number}/{n_candidates}")

            # Samples belonging to this candidate
            start_idx = candidate['start_idx']
            segment = audio_data[start_idx:candidate['end_idx']]

            # Rule 3: fast-rising shockwave
            shockwave_result = _tier2_shockwave_detection(segment, sr, config)
            if not shockwave_result['valid']:
                print(f"  X Candidate {number}: {shockwave_result['reason']}")
                continue

            print(f"  V Valid shockwave detected (rise time: {shockwave_result['rise_time_ms']:.1f}ms)")

            # Rule 4: secondary (bubble) pulse after the shockwave
            bubble_result = _tier2_bubble_pulse_search(segment, sr, shockwave_result, config)
            if not bubble_result['found']:
                print(f"  X Candidate {number}: {bubble_result['reason']}")
                continue

            print(f"  V Bubble pulse found (period: {bubble_result['period_ms']:.1f}ms)")

            # Rule 5: delay must be within the configured range
            if not _tier2_physical_plausibility(bubble_result, config):
                print(f"  X Candidate {number}: Bubble period outside physically plausible range")
                continue

            print("  V Physical plausibility confirmed")

            # All rules satisfied: report this candidate and stop searching.
            shockwave_idx = shockwave_result['peak_idx']
            explosion_time = (start_idx + shockwave_idx) / sr

            return {
                'is_explosion': True,
                'reason': 'Shockwave and plausible bubble pulse detected',
                'file_path': audio_path,
                'sample_rate': sr,
                'explosion_time_s': explosion_time,
                'metrics': {
                    'shockwave_time_s': explosion_time,
                    'bubble_period_ms': bubble_result['period_ms'],
                    'rise_time_ms': shockwave_result['rise_time_ms'],
                    'shockwave_amplitude': shockwave_result['peak_amplitude'],
                    'bubble_amplitude': bubble_result['bubble_amplitude'],
                    'amplitude_ratio': bubble_result['amplitude_ratio']
                },
                'detection_data': {
                    'audio_segment': segment,
                    'segment_start_time': start_idx / sr,
                    'shockwave_idx': shockwave_idx,
                    'bubble_idx': bubble_result['bubble_idx'],
                    'sample_rate': sr
                }
            }

        # Every candidate failed at least one Tier 2 rule
        return _no_explosion(
            'No candidates passed all explosion verification tests',
            sample_rate=sr,
            candidates_analyzed=n_candidates
        )

    except Exception as e:
        # Best-effort batch behaviour: a failing file becomes a negative row
        # instead of aborting the whole run.
        return _no_explosion(f'Error analyzing audio: {str(e)}')


def plot_detection_details(audio_data, sr, metrics):
    """
    Creates a two-panel visualization showing the detected explosion signature.

    The top panel is the annotated waveform of the detected segment; the
    bottom panel is its spectrogram with the shockwave and the stretch after
    the bubble pulse marked. Both x axes are labelled in seconds from the
    start of the file.

    Args:
        audio_data (np.array): Full audio signal. Not used for drawing (only
            the stored segment is plotted); kept so existing callers work.
        sr (int): Sample rate of the stored segment
        metrics (dict): A positive result from analyze_audio_for_explosion;
            reads 'detection_data', 'metrics' and 'explosion_time_s'.

    Returns:
        matplotlib.figure.Figure: The created figure (not shown or closed here).
    """
    # Imported here so the function works even when only `import librosa`
    # was executed: librosa.display is a submodule and is not guaranteed to
    # be loaded by the top-level import on every librosa release.
    import librosa.display
    from matplotlib.ticker import FuncFormatter

    detection_data = metrics['detection_data']
    segment = detection_data['audio_segment']
    shockwave_idx = detection_data['shockwave_idx']
    bubble_idx = detection_data['bubble_idx']
    segment_start_time = detection_data['segment_start_time']

    # Absolute time (seconds from file start) of every segment sample
    segment_time = np.arange(len(segment)) / sr + segment_start_time
    peak_abs = np.max(np.abs(segment))

    # Create figure with 2 subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    fig.suptitle('Dynamite Explosion Detection Analysis', fontsize=16, fontweight='bold')

    # ==================== WAVEFORM PLOT ====================
    ax1.plot(segment_time, segment, 'b-', linewidth=1, alpha=0.7, label='Audio Waveform')

    # Mark shockwave peak
    shockwave_time = segment_time[shockwave_idx]
    ax1.axvline(shockwave_time, color='red', linewidth=2, linestyle='--', alpha=0.8)
    ax1.plot(shockwave_time, segment[shockwave_idx], 'ro', markersize=8, label='Shockwave Peak')
    ax1.text(shockwave_time, segment[shockwave_idx] + 0.1*peak_abs,
             'Shockwave Peak', ha='center', va='bottom', fontweight='bold', color='red')

    # Mark bubble pulse peak
    bubble_time = segment_time[bubble_idx]
    ax1.axvline(bubble_time, color='orange', linewidth=2, linestyle='--', alpha=0.8)
    ax1.plot(bubble_time, segment[bubble_idx], 'o', color='orange', markersize=8, label='Bubble Pulse')
    ax1.text(bubble_time, segment[bubble_idx] + 0.1*peak_abs,
             'First Bubble Pulse', ha='center', va='bottom', fontweight='bold', color='orange')

    # Add double-headed arrow for bubble period
    arrow_y = 0.8 * peak_abs
    ax1.annotate('', xy=(bubble_time, arrow_y), xytext=(shockwave_time, arrow_y),
                arrowprops=dict(arrowstyle='<->', color='purple', lw=2))
    bubble_period_ms = metrics["metrics"]["bubble_period_ms"]
    ax1.text((shockwave_time + bubble_time) / 2, arrow_y + 0.05*peak_abs,
             f'Bubble Period (tpuls)\n{bubble_period_ms:.1f}ms',
             ha='center', va='bottom', fontweight='bold', color='purple',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    ax1.set_xlabel('Time (seconds)', fontweight='bold')
    ax1.set_ylabel('Amplitude', fontweight='bold')
    ax1.set_title('Annotated Waveform - Explosion Event Detection', fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # ==================== SPECTROGRAM PLOT ====================
    # Compute spectrogram
    D = librosa.stft(segment)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

    # Plot spectrogram
    img = librosa.display.specshow(S_db, sr=sr, x_axis='time', y_axis='hz',
                                   ax=ax2, cmap='viridis')

    # The spectrogram image is drawn in segment-relative time (starting at
    # 0 s). Shifting only the axis limits would slide the view away from the
    # image, so the data stay in relative coordinates, the annotations below
    # are placed in the same coordinates, and just the tick LABELS are
    # converted to absolute time.
    ax2.xaxis.set_major_formatter(
        FuncFormatter(lambda rel_t, _pos: f'{rel_t + segment_start_time:.2f}')
    )
    ax2.set_xlabel('Time (seconds)', fontweight='bold')

    # Segment-relative positions of the annotated events
    shockwave_rel = shockwave_idx / sr
    bubble_rel = bubble_idx / sr
    segment_end_rel = (len(segment) - 1) / sr
    freq_top = ax2.get_ylim()[1]

    # Mark explosion time
    ax2.axvline(shockwave_rel, color='red', linewidth=3, linestyle='-', alpha=0.9)
    ax2.text(shockwave_rel + 0.01, freq_top * 0.9,
             'Broadband Shockwave Event', rotation=90, ha='left', va='top',
             fontweight='bold', color='red',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    # Add bracket for reverberant tail (bubble pulse to end of segment)
    tail_start = bubble_rel
    tail_end = segment_end_rel
    tail_y = freq_top * 0.7

    # Draw bracket
    bracket_height = freq_top * 0.05
    ax2.plot([tail_start, tail_start], [tail_y - bracket_height, tail_y + bracket_height],
             'k-', linewidth=2)
    ax2.plot([tail_end, tail_end], [tail_y - bracket_height, tail_y + bracket_height],
             'k-', linewidth=2)
    ax2.plot([tail_start, tail_end], [tail_y, tail_y], 'k-', linewidth=2)

    ax2.text((tail_start + tail_end) / 2, tail_y + bracket_height * 2,
             'Reverberant Tail', ha='center', va='bottom', fontweight='bold',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    ax2.set_title('Annotated Spectrogram - Frequency Domain Analysis', fontweight='bold')

    # Add colorbar
    cbar = plt.colorbar(img, ax=ax2, format='%+2.0f dB')
    cbar.set_label('Magnitude (dB)', fontweight='bold')

    # Add detection summary text box
    summary_text = (f"Detection Summary:\n"
                    f"  Explosion detected at {metrics['explosion_time_s']:.2f}s\n"
                    f"  Shockwave rise time: {metrics['metrics']['rise_time_ms']:.1f}ms\n"
                    f"  Bubble period: {metrics['metrics']['bubble_period_ms']:.1f}ms\n"
                    f"  Amplitude ratio: {metrics['metrics']['amplitude_ratio']:.2f}\n"
                    f"  Classification: POSITIVE DETECTION")

    props = dict(boxstyle='round', facecolor='lightgreen', alpha=0.8)
    ax1.text(0.02, 0.98, summary_text, transform=ax1.transAxes, fontsize=10,
             verticalalignment='top', bbox=props, fontweight='bold')

    plt.tight_layout()
    return fig


# Printed once the definitions above have been executed, so the cell's output
# shows that the detection functions are now available.
print('Detection logic loaded.')

In [None]:
# @title Run detector on all WAV files and write inference CSV

from pathlib import Path

# Column layout shared by the batch CSV and the main CSV. It is fixed up
# front because the metric columns would otherwise exist only in batches
# that contain a detection, and the main CSV is appended to WITHOUT a header:
# rows from different batches must always have the same fields in the same
# order or the file becomes unreadable.
_METRIC_COLUMNS = [
    'rise_time_ms',
    'bubble_period_ms',
    'shockwave_amplitude',
    'bubble_amplitude',
    'amplitude_ratio',
]
_RESULT_COLUMNS = [
    'filename',
    'filepath',
    'is_explosion',
    'reason',
    'explosion_time_s',
] + _METRIC_COLUMNS

# Sanity checks
audio_root = Path(str(input_audio_dir))
print('input_audio_dir exists:', audio_root.exists(), 'is_dir:', audio_root.is_dir())

if not audio_root.exists() or not audio_root.is_dir():
    raise RuntimeError(f'input_audio_dir is not a readable directory: {input_audio_dir}')

print(f'WAV files to analyze: {len(wav_files)}  (range [{file_start}:{file_end if file_end is not None else "end"}])')
if len(wav_files) == 0:
    raise RuntimeError('No WAV files found. Check input_audio_dir and dataset_fileglob.')

# Build batch-aware output CSV path so batches don't overwrite each other
_end_label = file_end if file_end is not None else len(wav_files_all)
batch_csv_filepath = str(
    Path(output_csv_filepath).with_name(
        Path(output_csv_filepath).stem + f'_{file_start}_{_end_label}.csv'
    )
)
print(f'Batch CSV: {batch_csv_filepath}')

# Run detection on every file
all_results = []
detections = []

for i, wav_path in enumerate(wav_files):
    # Position of this file in the full (unsliced) file list
    global_idx = file_start + i
    print(f'\n[{global_idx + 1}/{len(wav_files_all)}] {os.path.basename(wav_path)}')
    result = analyze_audio_for_explosion(wav_path, DETECTOR_CONFIG)
    all_results.append(result)

    status = 'EXPLOSION' if result['is_explosion'] else 'no detection'
    print(f'  -> {status}: {result["reason"]}')

    if result['is_explosion']:
        detections.append(result)

# Build results DataFrame: one row per analysed file
rows = []
for r in all_results:
    row = {
        'filename': os.path.basename(r['file_path']),
        'filepath': r['file_path'],
        'is_explosion': r['is_explosion'],
        'reason': r['reason'],
    }
    if r['is_explosion']:
        row['explosion_time_s'] = r.get('explosion_time_s')
        row.update({name: r['metrics'][name] for name in _METRIC_COLUMNS})
    rows.append(row)

# Passing `columns` pins the layout; fields a row lacks are left empty (NaN).
results_df = pd.DataFrame(rows, columns=_RESULT_COLUMNS)

# Write batch CSV
results_df.to_csv(batch_csv_filepath, index=False)

# Also append to the main CSV (create if first batch, append otherwise)
# NOTE(review): a main CSV created before the column layout was fixed may
# carry a narrower header; delete it and re-run if reading it fails.
if os.path.exists(output_csv_filepath):
    results_df.to_csv(output_csv_filepath, mode='a', header=False, index=False)
    print(f'Appended {len(results_df)} rows to: {output_csv_filepath}')
else:
    results_df.to_csv(output_csv_filepath, index=False)
    print(f'Created: {output_csv_filepath}')

print('\n' + '=' * 60)
print('DETECTION COMPLETE')
print('=' * 60)
print(f'Batch range: [{file_start} : {file_end if file_end is not None else "end"}]')
print(f'Files analyzed: {len(all_results)}')
print(f'Explosions detected: {len(detections)}')
print(f'Batch CSV: {batch_csv_filepath}')
print(f'Main CSV:  {output_csv_filepath}')

# Optionally copy CSV back to Drive (only if configured)
if 'output_csv_drive' in globals() and output_csv_drive:
    import subprocess
    print('Copying results CSV back to Drive...')
    subprocess.run(['cp', '-f', output_csv_filepath, output_csv_drive], check=True)
    print('Done. Wrote:', output_csv_drive)

In [None]:
# @title Plot detection details for each positive detection

# One figure per positive result collected by the run cell.
if detections:
    _n_detections = len(detections)
    for _number, result in enumerate(detections, start=1):
        _wav_name = os.path.basename(result["file_path"])
        print(f'\nPlotting detection {_number}/{_n_detections}: {_wav_name}')

        # Reload the recording at the analysis sample rate for the plot call.
        audio_data, sr = librosa.load(result['file_path'], sr=DETECTOR_CONFIG['sample_rate'])
        fig = plot_detection_details(audio_data, sr, result)
        plt.show()
        # Close explicitly so figures do not pile up across many detections.
        plt.close(fig)
else:
    print('No explosions detected. Nothing to plot.')

In [None]:
# @title Plot detections over time (detections/hour)

import re

# Filenames are expected to contain a timestamp like YYYYMMDD_HHMMSS
_dt_re = re.compile(r'(\d{8})_(\d{6})')


def extract_dt(fname: str):
    """Return the timestamp embedded in *fname*, or pd.NaT when none can be parsed."""
    m = _dt_re.search(str(fname))
    if not m:
        return pd.NaT
    return pd.to_datetime(m.group(1) + m.group(2), format='%Y%m%d%H%M%S', errors='coerce')


csv_path = output_csv_filepath
df = pd.read_csv(csv_path)
print('rows:', len(df))
print('columns:', list(df.columns))

# The run cell appends to the main CSV on every execution, so re-running a
# batch leaves several rows for the same recording. Keep only the most
# recent row per file so repeated runs do not inflate the counts.
if 'filepath' in df.columns:
    df = df.drop_duplicates(subset='filepath', keep='last')

# Only keep positive detections (explicit comparison: rows where the flag is
# missing or not a real boolean are dropped as well)
df = df[df['is_explosion'] == True].copy()

if df.empty:
    print('No positive detections to plot.')
else:
    df['dt'] = df['filename'].apply(extract_dt)
    plot_df = df.dropna(subset=['dt']).copy()

    if plot_df.empty:
        print('No rows had a parseable datetime in filename. Adjust extract_dt() regex/format.')
    else:
        plot_df = plot_df.set_index('dt').sort_index()

        # Number of positive files per hour and per day
        detections_per_hour = plot_df['filename'].resample('1h').count()
        daily = plot_df['filename'].resample('1D').count()

        # Both charts share the same layout; only data and labels differ.
        for _counts, _title, _xlabel, _ylabel in (
            (detections_per_hour, 'Explosion Detections per Hour', 'Time', 'Detections / hour'),
            (daily, 'Explosion Detections per Day', 'Date', 'Detections / day'),
        ):
            plt.figure(figsize=(12, 4))
            _counts.plot()
            plt.title(_title)
            plt.xlabel(_xlabel)
            plt.ylabel(_ylabel)
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()