# IndianBatsModel - Real-World Inference Pipeline

This notebook runs the trained model on **raw, uncurated audio files**.
Unlike the test notebook, this pipeline handles:
1.  **Detection**: Finding potential bat calls in long recordings (ignoring silence/noise).
2.  **Classification**: Predicting the species for each detected call.
3.  **Filtering**: Labelling low-confidence predictions as "Uncertain" and low-energy segments as "Noise/LowEnergy" instead of reporting a species.

**Use this for:** Field recordings, or any files where you don't know whether (or where) bat calls occur.

In [None]:
# 1. Setup Environment
# Clone the project repository into the current working directory. The import
# cell below expects it at /kaggle/working/IndianBatsModel, so this assumes the
# notebook's working directory is /kaggle/working -- NOTE(review): confirm when
# running outside Kaggle.
# NOTE(review): git clone refuses to overwrite an existing folder, so re-running
# this cell reports an error but keeps the previous clone.
!git clone https://github.com/Quarkisinproton/IndianBatsModel.git
# Install the Python dependencies used below. torch / torchvision are not listed
# and are presumably preinstalled in the runtime -- TODO confirm.
!pip install librosa pyyaml pandas matplotlib scikit-learn tqdm python-docx scipy

In [None]:
# 2. Import Modules
import sys
import os
import torch
import numpy as np
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.signal import butter, filtfilt
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from docx import Document
from docx.shared import Inches
from datetime import timedelta

# Root directory for the cloned repo and all outputs of this notebook
WORK_DIR = '/kaggle/working'

# Make both the repository root and its src/ folder importable (repo root first).
REPO_DIR = os.path.join(WORK_DIR, 'IndianBatsModel')
SRC_DIR = os.path.join(REPO_DIR, 'src')
for _import_root in (REPO_DIR, SRC_DIR):
    if _import_root not in sys.path:
        sys.path.append(_import_root)
del _import_root

# Import project modules. Fail fast on ImportError: without these the model
# cannot be built and later cells would only fail with a confusing NameError.
try:
    from MainShitz.models.cnn_with_features import CNNWithFeatures
    from MainShitz.data_prep.wombat_to_spectrograms import make_mel_spectrogram
    from MainShitz.data_prep.extract_end_frequency import compute_end_frequency
except ImportError as e:
    print(f"Import Error: {e}")
    print(f"Check that the repository was cloned to {REPO_DIR}.")
    raise
else:
    print("Imports successful!")

In [None]:
# 3. Configuration

# --- INPUTS ---
# Path to your trained model. It is used as-is when the file exists; otherwise
# the first .pth file found under /kaggle/input (sorted by path) is used.
MODEL_PATH = '/kaggle/working/models/bat_fused_best.pth'

# Path to folder containing RAW .wav files
INPUT_AUDIO_DIR = '/kaggle/input/your-test-audio-folder'

# --- DETECTION SETTINGS ---
SAMPLE_RATE = 250000  # Typical for bat recorders. NOTE(review): not read below; audio is loaded at its native rate
MIN_FREQ = 15000      # Lower band-pass edge (15kHz) for detection, removes wind/low-frequency noise
ENERGY_THRESH = 0.02  # RMS threshold (on the band-limited signal) to trigger detection
MIN_DURATION = 0.01   # Minimum call duration (seconds)
PAD_DURATION = 0.05   # Padding around detected call (seconds)
CONFIDENCE_THRESH = 70.0  # Predictions below this confidence (%) are reported as "Uncertain"

# --- MODEL SETTINGS ---
# Must be in the same order as the trained model's output classes!
CLASS_NAMES = ['pip-ceylonicusbat-species', 'pip-tenuisbat-species']
# Derived from CLASS_NAMES so the two can never disagree.
NUM_CLASSES = len(CLASS_NAMES)

# Auto-detect models in /kaggle/input. Sorting makes the choice deterministic
# (os.walk order depends on the filesystem).
found_models = sorted(
    os.path.join(root, name)
    for root, _dirs, files in os.walk('/kaggle/input')
    for name in files
    if name.endswith('.pth')
)
# Only fall back to an auto-detected model when the configured one is missing,
# so an explicitly configured MODEL_PATH is never silently overridden.
if found_models and not os.path.exists(MODEL_PATH):
    MODEL_PATH = found_models[0]
    print(f"Auto-detected model: {MODEL_PATH}")
    if len(found_models) > 1:
        print(f"  Other candidates: {found_models[1:]}")

In [None]:
# 4. Define Detection & Processing Functions

def bandpass_filter(y, sr, low=15000, high=120000, order=4):
    """Apply a zero-phase Butterworth band-pass to remove rumble and out-of-band noise.

    Args:
        y: 1-D audio signal. ``None`` or empty input is returned unchanged.
        sr: Sample rate in Hz.
        low: Lower cut-off in Hz (clamped to [1, 0.9 * Nyquist]).
        high: Upper cut-off in Hz (clamped to at most 0.99 * Nyquist).
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, or ``y`` unchanged when it is ``None``/empty, too
        short for forward-backward filtering, or the clamped band is empty.
    """
    from scipy.signal import sosfiltfilt

    if y is None or len(y) == 0:
        return y
    nyq = 0.5 * sr
    low = max(1.0, min(low, nyq * 0.9))
    high = min(high, nyq * 0.99)
    if high <= low:
        return y
    # Second-order sections: the (b, a) transfer-function form is numerically
    # fragile for high-order band-passes (this one is 2 * order poles).
    sos = butter(order, [low / nyq, high / nyq], btype='band', output='sos')
    # sosfiltfilt pads both ends and raises ValueError when the input is not
    # longer than that pad; such tiny inputs are returned unfiltered instead.
    if len(y) <= 3 * (2 * len(sos) + 1):
        return y
    return sosfiltfilt(sos, y)


def save_spectrogram_like_example(
    y_seg,
    sr,
    out_path,
    *,
    fmin=15000,
    fmax=120000,
    n_fft=2048,
    hop_length=256,
    n_mels=256,
    cmap='viridis',
    db_floor=-80.0,
    dpi=150,
):
    """Save a clean, axis-less mel-spectrogram PNG of ``y_seg`` to ``out_path``.

    The power spectrogram is converted to dB relative to its maximum and clipped
    to ``[db_floor, 0]`` before being drawn with ``cmap``.

    Args:
        y_seg: 1-D audio segment.
        sr: Sample rate in Hz.
        out_path: Destination PNG path (its directory must exist).
        fmin, fmax: Mel band in Hz; ``fmax`` is capped just below Nyquist and
            ``fmin`` falls back to 0 if the band would otherwise be empty.
        n_fft, hop_length, n_mels: Mel-spectrogram parameters.
        cmap, db_floor, dpi: Rendering options.

    Returns:
        ``True`` if the image was written, ``False`` if the segment is ``None``
        or shorter than ``n_fft`` samples.
    """
    if y_seg is None or len(y_seg) < n_fft:
        return False

    # Cap fmax just below Nyquist; mel filters cannot extend past it.
    nyq = (sr * 0.5) - 1
    fmax = min(float(fmax), float(nyq))
    if fmax <= fmin:
        fmax = float(nyq)
    if fmin >= fmax:
        # Requested band lies entirely above Nyquist (low sample rate): use the
        # full available band rather than an inverted [fmin, fmax] range.
        fmin = 0.0

    S = librosa.feature.melspectrogram(
        y=y_seg,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        power=2.0,
    )
    S_db = np.clip(librosa.power_to_db(S, ref=np.max), db_floor, 0.0)

    fig, ax = plt.subplots(figsize=(10, 3))
    try:
        ax.imshow(S_db, origin='lower', aspect='auto', cmap=cmap, vmin=db_floor, vmax=0.0, interpolation='nearest')
        ax.axis('off')
        ax.margins(0)
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=dpi)
    finally:
        # Release the figure even when saving fails, so a long loop over many
        # segments does not accumulate open figures.
        plt.close(fig)
    return True


def detect_events(y, sr, min_freq=15000, threshold=0.01, min_dur=0.01, pad=0.05):
    """Find segments of interest in ``y`` based on their high-frequency energy.

    The signal is band-passed, STFT bins below ``min_freq`` are zeroed, and
    frames whose RMS exceeds ``threshold`` are grouped into active runs. Runs of
    at least ``min_dur`` seconds are padded by ``pad`` seconds on each side
    (clamped to the file), and overlapping or touching padded runs are merged.

    Args:
        y: 1-D audio signal.
        sr: Sample rate in Hz.
        min_freq: Lower frequency edge in Hz used by both filtering steps.
        threshold: RMS level a frame must exceed to count as active.
        min_dur: Minimum active-run length in seconds (before padding).
        pad: Padding in seconds added to each side of an event.

    Returns:
        List of ``(start_s, end_s)`` tuples in ascending, non-overlapping order;
        empty when nothing is detected.
    """
    n_fft = 2048
    hop_length = 512

    # Denoise: band-pass around the bat band, keeping the upper edge below Nyquist.
    high_cut = min(120000, (sr * 0.5) - 1000) if sr else 120000
    y = bandpass_filter(y, sr, low=min_freq, high=high_cut)
    duration = librosa.get_duration(y=y, sr=sr)

    # Also zero STFT bins below min_freq (presumably to drop residual energy
    # passed by the band-pass roll-off). S is freshly allocated, so edit in place.
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    S[freqs < min_freq, :] = 0

    # Frame-wise RMS of the band-limited spectrogram, thresholded to active/inactive.
    rms = librosa.feature.rms(S=S, frame_length=n_fft, hop_length=hop_length)[0]
    times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop_length)
    is_active = rms > threshold

    # Group consecutive active frames into padded events.
    events = []
    run_start = None
    for t, active in zip(times, is_active):
        if active and run_start is None:
            run_start = t
        elif not active and run_start is not None:
            if (t - run_start) >= min_dur:
                events.append((max(0, run_start - pad), min(duration, t + pad)))
            run_start = None

    # The file ended mid-run: close it at the last frame and pad it like any other event.
    if run_start is not None:
        run_end = times[-1]
        if (run_end - run_start) >= min_dur:
            events.append((max(0, run_start - pad), min(duration, run_end + pad)))

    # Events are already in time order; merge any that overlap or touch.
    merged = []
    for ev_start, ev_end in events:
        if merged and ev_start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], ev_end))
        else:
            merged.append((ev_start, ev_end))
    return merged


def prepare_input(y, sr, start, end, low_freq=15000):
    """Extract one detected segment and build the model's inputs for it.

    Args:
        y: Full audio signal.
        sr: Sample rate in Hz.
        start: Segment start in seconds.
        end: Segment end in seconds.
        low_freq: Lower band-pass cut-off in Hz applied to the segment
            (default 15000, the value this pipeline was written with).

    Returns:
        ``(img_tensor, feat_tensor, pil_img, y_seg, energy_db_max)``:
        the normalized RGB spectrogram tensor, a 1-element float32 tensor with
        the end-frequency feature (0.0 when it cannot be computed), the
        colored spectrogram as a PIL image, the band-passed segment, and the
        maximum value of the spectrogram from ``make_mel_spectrogram``
        (presumably dB). All five are ``None`` when the segment is shorter than
        512 samples.
    """
    # Extract audio
    start_sample = int(start * sr)
    end_sample = int(end * sr)
    y_seg = y[start_sample:end_sample]

    if len(y_seg) < 512:
        return None, None, None, None, None  # Too short

    # Denoise segment: band-pass around the bat band, upper edge below Nyquist.
    high_cut = min(120000, (sr * 0.5) - 1000) if sr else 120000
    y_seg = bandpass_filter(y_seg, sr, low=low_freq, high=high_cut)

    # 1. Spectrogram
    S_db = make_mel_spectrogram(y_seg, sr)
    energy_db_max = float(S_db.max())

    # Min-max normalize to [0, 1] (like training); epsilon guards a flat spectrogram.
    S_min, S_max = S_db.min(), S_db.max()
    S_norm = (S_db - S_min) / (S_max - S_min + 1e-8)

    # Apply Magma colormap (matches training data), keep RGB as uint8, and flip
    # vertically so low frequencies end up at the bottom of the image.
    S_colored = plt.get_cmap('magma')(S_norm)
    S_img = np.flipud((S_colored[:, :, :3] * 255).astype(np.uint8))
    img = Image.fromarray(S_img)

    # ImageNet mean/std normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_tensor = transform(img)

    # 2. Features (End Frequency), computed by the project helper from the full
    # signal and the segment bounds.
    end_freq = compute_end_frequency(y, sr, start, end)
    # A NaN (or, defensively, None/inf) value would poison the model's numeric input.
    if end_freq is None or not np.isfinite(end_freq):
        end_freq = 0.0
    feat_tensor = torch.tensor([end_freq], dtype=torch.float32)

    return img_tensor, feat_tensor, img, y_seg, energy_db_max

In [None]:
# 5. Load Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Fail loudly: continuing with an untrained (and non-eval-mode) model would make
# the inference cell produce meaningless predictions without any warning.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

try:
    # Initialize model structure
    model = CNNWithFeatures(num_classes=NUM_CLASSES, numeric_feat_dim=1, pretrained=False)

    print(f"Loading {MODEL_PATH}...")
    # SECURITY: weights_only=False unpickles arbitrary objects (needed for
    # checkpoints that store a whole nn.Module). Only load files you trust.
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
    elif isinstance(checkpoint, dict):
        # Either a training checkpoint wrapping the weights, or a bare state_dict.
        model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
    else:
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

    model.to(device)
    model.eval()
    print("Model loaded!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

In [None]:
# 6. Run Inference Loop
import re

from IPython.display import display

# Output directory for spectrograms
SPECTROGRAM_DIR = os.path.join(WORK_DIR, 'inference_spectrograms')
os.makedirs(SPECTROGRAM_DIR, exist_ok=True)
print(f'Saving spectrograms to {SPECTROGRAM_DIR}...')

# Thresholds: a segment whose spectrogram peak (energy_db_max from
# prepare_input) is below this is labelled "Noise/LowEnergy".
LOW_ENERGY_THRESH_DB = -55.0

# Spectrogram style settings (matches your example)
SPEC_CMAP = 'viridis'
SPEC_DB_FLOOR = -80.0
SPEC_FMIN = 15000
SPEC_FMAX = 120000

# Word report
doc = Document()
doc.add_heading('Bat Species Inference Report', 0)

# One dict per processed segment, plus one 'No Detection' row per file without events.
results = []


def format_time(seconds):
    """Format a number of seconds as ``H:MM:SS[.ffffff]`` via ``timedelta``."""
    return str(timedelta(seconds=float(seconds)))


def safe_filename_part(text):
    """Replace characters that are unsafe in a file name with '_'.

    Labels such as "Noise/LowEnergy" contain a path separator; used verbatim,
    it would point the spectrogram path into a non-existent sub-directory and
    make saving fail for the rest of that audio file.
    """
    return re.sub(r'[^\w.-]+', '_', str(text))


# Find audio files (recursively)
audio_files = []
if os.path.exists(INPUT_AUDIO_DIR):
    for root, dirs, files in os.walk(INPUT_AUDIO_DIR):
        for f in files:
            if f.lower().endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, f))
else:
    print(f"Input directory {INPUT_AUDIO_DIR} does not exist.")

print(f"Found {len(audio_files)} files to process.")

for audio_path in tqdm(audio_files):
    audio_name = os.path.basename(audio_path)
    try:
        # Load audio at its native sample rate, downmixed to mono
        y, sr = librosa.load(audio_path, sr=None, mono=True)

        # Detect events
        events = detect_events(y, sr, min_freq=MIN_FREQ, threshold=ENERGY_THRESH, min_dur=MIN_DURATION, pad=PAD_DURATION)

        if not events:
            results.append({
                'filename': audio_name,
                'start': '-', 'end': '-',
                'prediction': 'No Detection',
                'confidence': 0.0
            })
            continue

        # Process each event
        for start, end in events:
            img_t, feat_t, pil_img, y_seg, energy_db_max = prepare_input(y, sr, start, end)
            if img_t is None:
                print(f"  Skipping too-short segment [{format_time(start)} - {format_time(end)}]")
                continue

            # Inference
            with torch.no_grad():
                img_batch = img_t.unsqueeze(0).to(device)
                feat_batch = feat_t.unsqueeze(0).to(device)

                output = model(img_batch, feat_batch)
                probs = torch.nn.functional.softmax(output, dim=1)
                conf, pred_idx = torch.max(probs, 1)
            confidence_pct = conf.item() * 100
            pred_class = CLASS_NAMES[pred_idx.item()]

            # Reject obvious noise by low energy or low confidence
            if energy_db_max < LOW_ENERGY_THRESH_DB:
                final_pred = "Noise/LowEnergy"
                confidence_pct = 0.0
            else:
                final_pred = pred_class if confidence_pct >= CONFIDENCE_THRESH else "Uncertain"

            # Save styled spectrogram; the label is sanitized because
            # "Noise/LowEnergy" contains a path separator.
            spec_filename = f"{Path(audio_path).stem}_{start:.2f}_{end:.2f}_{safe_filename_part(final_pred)}.png"
            spec_path = os.path.join(SPECTROGRAM_DIR, spec_filename)
            _ok = save_spectrogram_like_example(
                y_seg,
                sr,
                spec_path,
                fmin=SPEC_FMIN,
                fmax=SPEC_FMAX,
                cmap=SPEC_CMAP,
                db_floor=SPEC_DB_FLOOR,
            )

            # Display inline if confident or marked noise
            if final_pred != "Uncertain":
                print(f"\nDetected: {final_pred} ({confidence_pct:.1f}%) at {start:.2f}-{end:.2f}s in {audio_name}")
                if _ok:
                    # Close the file handle once the image has been rendered.
                    with Image.open(spec_path) as spec_img:
                        display(spec_img)

            results.append({
                'filename': audio_name,
                'start': f"{start:.2f}",
                'end': f"{end:.2f}",
                'prediction': final_pred,
                'confidence': f"{confidence_pct:.1f}",
                'energy_db_max': f"{energy_db_max:.1f}"
            })

            # Add to Word report
            p = doc.add_paragraph()
            p.add_run(f"File: {audio_name}\n").bold = True
            p.add_run(f"Segment: {format_time(start)} - {format_time(end)}\n")
            p.add_run(f"Prediction: {final_pred}\n")
            p.add_run(f"Confidence: {confidence_pct:.1f}%\n")
            p.add_run(f"Max Energy (dB): {energy_db_max:.1f}\n")
            if _ok:
                doc.add_picture(spec_path, width=Inches(6))
            doc.add_paragraph('-' * 50)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")

# Display Results and save report
df = pd.DataFrame(results)
print("\nInference Results:")
print(df.to_string())

report_path = os.path.join(WORK_DIR, 'Inference_Report.docx')
doc.save(report_path)
print(f"Report saved to {report_path}")

In [None]:
# 7. Save Results for Manual Verification
# Since we don't know the true species for these files, we save the predictions
# to a CSV file. You can give this file to an expert for verification.

# Guard like the evaluation cell does, in case the inference cell was not run.
if 'df' in globals() and df is not None and not df.empty:
    output_csv_path = os.path.join(WORK_DIR, 'inference_results.csv')
    df.to_csv(output_csv_path, index=False)

    print(f"\nSUCCESS: Results saved to: {output_csv_path}")
    # Report the actual columns: 'energy_db_max' is absent when every file
    # produced only 'No Detection' rows.
    print(f"Columns: {', '.join(map(str, df.columns))}")
    print("-" * 50)
    print("Download this file from the 'Output' section of your Kaggle notebook.")
    print("-" * 50)

    # Preview again
    print(df.head())
else:
    # df is only empty when no file was processed at all (files without calls
    # still contribute a 'No Detection' row).
    print("No results to save (run the inference cell first, or no files were processed). No CSV generated.")

In [None]:
# 8. Evaluate Accuracy against Ground Truth
# PERU3.txt does NOT include a filename column.
# So evaluation only makes sense if:
#   - you ran inference on the matching audio file (e.g. PERU3.wav), OR
#   - you select which `df['filename']` to evaluate.

from pathlib import Path

# ---- Choose which audio file in df to evaluate ----
# If you ran inference on multiple files, set this explicitly (exact basename):
EVAL_AUDIO_FILENAME = None  # e.g. "PERU3.wav"

# Predictions that do not claim any species; they are ignored when scoring.
NON_SPECIES_PREDICTIONS = ('No Detection', 'Noise/LowEnergy', 'Uncertain')

# Map GT species -> model class labels
SPECIES_MAP = {
    'Pipistrellus tenuis': 'pip-tenuisbat-species',
    'Pipistrellus ceylonicus': 'pip-ceylonicusbat-species',
}

# ---- Locate ground truth file ----
GT_CANDIDATES = [
    os.path.join(REPO_DIR, 'data', 'PERU3.txt'),
    os.path.join(WORK_DIR, 'PERU3.txt'),
    '/kaggle/input/indian-bats-data/PERU3.txt',
    '../data/PERU3.txt',
]

GT_FILE_PATH = next((p for p in GT_CANDIDATES if os.path.exists(p)), None)


def overlaps(a_start, a_end, b_start, b_end):
    """Return True if intervals (a_start, a_end) and (b_start, b_end) overlap by more than a point."""
    return max(a_start, b_start) < min(a_end, b_end)


def score_detections(det_df, gt_valid):
    """Score one file's species predictions against its ground-truth events.

    Every GT event is counted exactly once as:
      * TP            -- an overlapping species prediction equals its label,
      * misclassified -- overlapping species predictions exist, none correct,
      * FN            -- no overlapping species prediction.
    Every species prediction that overlaps no GT event is an FP.

    Args:
        det_df: detections with numeric 'start_s', 'end_s' and a 'prediction' label.
        gt_valid: GT events with numeric 'gt_start_s', 'gt_end_s' and label 'Species_Mapped'.

    Returns:
        ``(tp, fn, misclassified, fp)``.
    """
    claims = det_df[~det_df['prediction'].isin(NON_SPECIES_PREDICTIONS)]
    claim_spans = list(zip(
        claims['start_s'].astype(float),
        claims['end_s'].astype(float),
        claims['prediction'],
    ))

    tp = fn = misclassified = 0
    matched = set()  # positions in claim_spans overlapping at least one GT event
    for gt_start, gt_end, gt_label in zip(
        gt_valid['gt_start_s'].astype(float),
        gt_valid['gt_end_s'].astype(float),
        gt_valid['Species_Mapped'],
    ):
        hits = [
            (pos, pred)
            for pos, (det_start, det_end, pred) in enumerate(claim_spans)
            if overlaps(gt_start, gt_end, det_start, det_end)
        ]
        if not hits:
            fn += 1
            continue
        if any(pred == gt_label for _, pred in hits):
            tp += 1
        else:
            misclassified += 1
        matched.update(pos for pos, _ in hits)

    # A species claim is an FP exactly when it overlaps no GT event.
    fp = len(claim_spans) - len(matched)
    return tp, fn, misclassified, fp


if GT_FILE_PATH is None:
    print("Ground truth file (PERU3.txt) not found. Skipping evaluation.")
    print(f"Checked locations: {GT_CANDIDATES}")
else:
    print(f"Loading ground truth from: {GT_FILE_PATH}")

    if 'df' not in globals() or df is None or df.empty:
        print("Results dataframe `df` is empty/not found. Run inference first.")
    else:
        df2 = df.copy()

        # Keep only rows that look like detections (start/end numeric)
        df2['start_s'] = pd.to_numeric(df2.get('start', np.nan), errors='coerce')
        df2['end_s'] = pd.to_numeric(df2.get('end', np.nan), errors='coerce')
        df2 = df2.dropna(subset=['start_s', 'end_s'])

        if df2.empty:
            print("No numeric detections in `df` to evaluate.")
        else:
            unique_files = sorted(df2['filename'].astype(str).unique().tolist())
            gt_stem = Path(GT_FILE_PATH).stem  # "PERU3"

            # ---- Decide which filename to evaluate ----
            selected_file = None
            if EVAL_AUDIO_FILENAME is not None:
                # user override
                if EVAL_AUDIO_FILENAME in unique_files:
                    selected_file = EVAL_AUDIO_FILENAME
                else:
                    print(f"EVAL_AUDIO_FILENAME={EVAL_AUDIO_FILENAME!r} not found in df filenames:")
                    print(unique_files)
            else:
                # auto-pick by matching stem
                stem_matches = [f for f in unique_files if gt_stem.lower() in Path(f).stem.lower()]
                if len(stem_matches) == 1:
                    selected_file = stem_matches[0]
                elif len(unique_files) == 1:
                    selected_file = unique_files[0]
                else:
                    print("Multiple audio files detected in results, and GT has no filename column.")
                    print(f"Ground truth stem: {gt_stem!r}")
                    print("Found these result filenames:")
                    print(unique_files)
                    print("\nSet `EVAL_AUDIO_FILENAME = 'PERU3.wav'` (or your filename) in this cell and rerun.")

            if selected_file is not None:
                print(f"\nEvaluating only this file: {selected_file}")
                det_df = df2[df2['filename'].astype(str) == selected_file].copy()

                # ---- Load GT (tab-separated) ----
                gt_df = pd.read_csv(GT_FILE_PATH, sep='\t')

                if 'Species' not in gt_df.columns:
                    raise ValueError("'Species' column not found in ground truth file")
                if 'Begin Time (s)' not in gt_df.columns or 'End Time (s)' not in gt_df.columns:
                    raise ValueError("Ground truth must have 'Begin Time (s)' and 'End Time (s)'")

                # GT rows whose species is not in SPECIES_MAP are dropped.
                gt_df['Species_Mapped'] = gt_df['Species'].astype(str).str.strip().map(SPECIES_MAP)
                gt_valid = gt_df.dropna(subset=['Species_Mapped']).copy()
                gt_valid['gt_start_s'] = pd.to_numeric(gt_valid['Begin Time (s)'], errors='coerce')
                gt_valid['gt_end_s'] = pd.to_numeric(gt_valid['End Time (s)'], errors='coerce')
                gt_valid = gt_valid.dropna(subset=['gt_start_s', 'gt_end_s'])

                print(f"GT rows (mapped): {len(gt_valid)}")
                print(f"Detections rows:   {len(det_df)}")

                tp, fn, misclassified, fp = score_detections(det_df, gt_valid)

                # Report.
                # NOTE(review): the precision denominator mixes per-GT-event counts
                # (tp, misclassified) with per-detection counts (fp).
                precision = tp / (tp + fp + misclassified) if (tp + fp + misclassified) > 0 else 0.0
                recall = tp / (tp + fn + misclassified) if (tp + fn + misclassified) > 0 else 0.0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

                print("\n" + "=" * 40)
                print("EVALUATION RESULTS")
                print("=" * 40)
                print(f"File evaluated: {selected_file}")
                print(f"True Positives (Correct): {tp}")
                print(f"False Negatives (Missed): {fn}")
                print(f"Misclassified:            {misclassified}")
                print(f"False Positives:          {fp}")
                print("-" * 40)
                print(f"Precision: {precision:.2%}")
                print(f"Recall:    {recall:.2%}")
                print(f"F1 Score:  {f1:.2%}")
                print("=" * 40)