In [1]:
import os
import logging
import random
import gc
import time
import cv2
import math
import warnings
from pathlib import Path
import glob
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import librosa
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import torchaudio.transforms as AT
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from scipy import signal
import timm

import time
from datetime import datetime
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

In [2]:
class CFG:
    """
    Central configuration for BirdCLEF 2025 inference.

    Settings are read as class attributes (``CFG.FS``, ``CFG.threshold`` ...). A few are
    (re)assigned at runtime by other cells: ``debug`` by get_audio_files,
    ``class_thresholds`` by run_pipeline and ``num_classes`` by BirdCLEFModel.
    """
    # Basic paths and metadata
    test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'
    train_soundscapes = '/kaggle/input/birdclef-2025/train_soundscapes'
    submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
    model_path = '/kaggle/input/bestmodel/pytorch/default/1'
    # Do not fetch ImageNet weights: load_models() restores every backbone parameter from
    # the checkpoint via a strict load_state_dict, so pretrained weights would be discarded
    # anyway, and the download needs internet access that offline inference runs may lack.
    pretrained = False

    # Debug settings
    debug = False
    debug_start_num = 0  # Start index for debug sample subset
    debug_num = 8        # Number of samples to process in debug mode

    # Audio parameters
    FS = 32000           # Sampling rate (Hz)
    WINDOW_SIZE = 5      # Window size in seconds (one submission row per window)

    # Mel spectrogram parameters
    # NOTE(review): N_FFT=1034 is not a power of two; presumably it must match the
    # value used at training time, so do not "fix" it without retraining.
    N_FFT = 1034
    HOP_LENGTH = 64
    N_MELS = 136
    FMIN = 20
    FMAX = 16000         # Nyquist frequency for FS=32000
    TARGET_SHAPE = (256, 256)  # Final spectrogram shape (HxW)

    # Model parameters
    model_name = 'efficientnet_b0'
    in_channels = 1
    device = 'cpu'  # Force to run on CPU

    # Inference parameters
    batch_size = 16        # Not read by predict_on_audio (segments are scored one at a time)
    use_tta = False        # Disable test-time augmentation for speed
    tta_count = 3          # Number of TTA repetitions
    threshold = 0.7        # Default decision threshold

    # Fold selection
    use_specific_folds = False  # Use all available model folds by default
    folds = [0, 1]              # Only used if use_specific_folds = True

    # Prediction smoothing
    apply_smoothing = True
    smoothing_window = 5                              # Keep equal to len(smoothing_weights)
    smoothing_weights = [0.15, 0.2, 0.3, 0.2, 0.15]   # Symmetric smoothing kernel

    # Audio preprocessing options
    apply_noise_reduction = True
    apply_normalization = True
    noise_reduction_strength = 0.1   # Blend fraction of the median-filtered signal

    # Memory management
    clear_cache_frequency = 5  # Clear memory every N files to prevent memory leaks

    # SpecAugment options (only applied inside the TTA loop of predict_on_audio)
    use_spec_augment = True
    time_mask_param = 30       # Exclusive upper bound of the time-mask width
    freq_mask_param = 20       # Exclusive upper bound of the frequency-mask height
    time_mask_count = 1        # Number of time masks to apply
    freq_mask_count = 1        # Number of frequency masks to apply

    # Spectrogram contrast enhancement
    apply_spec_contrast = True
    contrast_factor = 0.15     # Intensity of contrast boost

    # Threshold adjustment
    class_thresholds = None    # Optional per-class thresholding (set at runtime)

    # Prediction blending (temporal smoothing with neighbors)
    prediction_blend = [0.7, 0.3]  # 70% current, 30% mean of the two neighboring windows

    # Additional smoothing from reference implementation
    apply_secondary_smoothing = True
    secondary_smoothing_weights = [0.2, 0.6, 0.2]  # Weights: [prev, current, next]


In [3]:
##############################################################################
# Memory Management Function
def clear_memory():
    """
    Release memory that is no longer referenced.

    Runs a full garbage-collection pass and, only on machines where CUDA is
    available, also returns PyTorch's cached-but-unused GPU memory to the driver.
    Intended to be called periodically during long inference loops.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
        
##############################################################################
# Load and Process Label Information
print("Loading label data...")

# Class column names from the sample submission (everything after 'row_id').
# run_pipeline indexes each prediction vector by position in this list, so it is
# assumed to match the model's output order — TODO confirm against the training label encoding.
primary_labels = pd.read_csv(CFG.submission_csv).columns[1:].to_list()

# Taxonomy metadata; estimate_class_thresholds looks up a 'family' column in it when present.
taxonomy = pd.read_csv(CFG.taxonomy_csv)


Loading label data...


In [4]:
class BirdCLEFModel(nn.Module):
    """
    timm backbone followed by a linear head with one logit per taxonomy class.

    The number of classes is the row count of ``cfg.taxonomy_csv``; it is also written
    back to ``cfg.num_classes``. When ``cfg`` has ``mixup_alpha > 0``, a training-mode
    forward pass that receives ``targets`` applies mixup and returns ``(logits, loss)``;
    every other call returns only the logits.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
        # Side effect: the shared config object records the head size.
        cfg.num_classes = len(taxonomy_df)

        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=cfg.pretrained,
            in_chans=cfg.in_channels,
            drop_rate=0.2,
            drop_path_rate=0.2
        )

        # Strip the backbone's own classification head so it returns feature vectors/maps.
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'resnet' in cfg.model_name:
            backbone_out = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            backbone_out = self.backbone.get_classifier().in_features
            # global_pool='' presumably leaves the features unpooled; self.pooling handles that.
            self.backbone.reset_classifier(0, '')

        # Only applied when the backbone output is 4-D (N, C, H, W).
        self.pooling = nn.AdaptiveAvgPool2d(1)

        self.feat_dim = backbone_out

        self.classifier = nn.Linear(backbone_out, cfg.num_classes)

        self.mixup_enabled = hasattr(cfg, 'mixup_alpha') and cfg.mixup_alpha > 0
        if self.mixup_enabled:
            self.mixup_alpha = cfg.mixup_alpha

    def mixup_data(self, x, targets):
        """
        Mix each sample with a randomly chosen partner from the same batch.

        :param x: Input batch (batch on dim 0).
        :param targets: Target batch aligned with ``x`` on dim 0.
        :return: ``(mixed_x, targets_a, targets_b, lam)``. ``mixed_x = lam * x + (1 - lam) * x[perm]``,
            ``targets_a`` are the original targets, ``targets_b = targets[perm]`` and
            ``lam ~ Beta(mixup_alpha, mixup_alpha)`` as a Python float.
        """
        lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
        perm = torch.randperm(x.size(0), device=x.device)
        mixed_x = lam * x + (1.0 - lam) * x[perm]
        return mixed_x, targets, targets[perm], lam

    @staticmethod
    def mixup_criterion(criterion, pred, targets_a, targets_b, lam):
        """
        Mixup loss: the ``lam``-weighted combination of the loss against both target sets.

        :param criterion: Loss callable ``criterion(pred, target)``.
        :param pred: Model outputs (logits).
        :param targets_a: Targets of the unpermuted samples.
        :param targets_b: Targets of the permuted partner samples.
        :param lam: Mixing coefficient in [0, 1].
        :return: ``lam * criterion(pred, targets_a) + (1 - lam) * criterion(pred, targets_b)``.
        """
        return lam * criterion(pred, targets_a) + (1.0 - lam) * criterion(pred, targets_b)

    def forward(self, x, targets=None):
        """
        Compute class logits (and the mixup loss during mixup training).

        :param x: Input batch for the backbone; assumed (N, in_channels, H, W) spectrograms — TODO confirm.
        :param targets: Optional targets; only used in training mode with mixup enabled.
        :return: ``logits``, or ``(logits, loss)`` when mixup is active.
        """
        if self.training and self.mixup_enabled and targets is not None:
            mixed_x, targets_a, targets_b, lam = self.mixup_data(x, targets)
            x = mixed_x
        else:
            targets_a, targets_b, lam = None, None, None

        features = self.backbone(x)

        # Some backbones may return a dict of outputs; take the 'features' entry.
        if isinstance(features, dict):
            features = features['features']

        if len(features.shape) == 4:
            features = self.pooling(features)
            features = features.view(features.size(0), -1)

        logits = self.classifier(features)

        if self.training and self.mixup_enabled and targets is not None:
            loss = self.mixup_criterion(F.binary_cross_entropy_with_logits,
                                        logits, targets_a, targets_b, lam)
            return logits, loss

        return logits

In [5]:
##############################################################################
# Estimate Class Occurrence Frequency and Set Dynamic Thresholds
def estimate_class_thresholds(labels=None, taxonomy_df=None, default_threshold=None):
    """
    Build a per-class decision threshold from each species' taxonomic family.

    Species in Trochilidae / Tyrannidae get ``default - 0.05`` (favor recall), species in
    Thraupidae / Parulidae get ``default + 0.05`` (favor precision), and every other
    species gets the default. If the taxonomy is missing, has no 'primary_label' column,
    or has no 'family' column, every species gets the default. When a label appears in
    several taxonomy rows, the first row wins.

    Args:
        labels (list, optional): Species to score. Defaults to the global ``primary_labels``.
        taxonomy_df (pd.DataFrame, optional): Taxonomy table. Defaults to the global ``taxonomy``.
        default_threshold (float, optional): Base threshold. Defaults to ``CFG.threshold``.

    Returns:
        thresholds (dict): A dictionary mapping each species to its threshold.
    """
    print("Estimating optimal thresholds for each class...")

    # Globals are only looked up when the caller did not supply a value.
    if labels is None:
        labels = primary_labels
    if taxonomy_df is None:
        taxonomy_df = taxonomy
    if default_threshold is None:
        default_threshold = CFG.threshold

    if taxonomy_df is None or 'primary_label' not in taxonomy_df.columns:
        return {species: default_threshold for species in labels}

    # One lookup table instead of filtering the whole taxonomy once per species.
    if 'family' in taxonomy_df.columns:
        family_of = (
            taxonomy_df.drop_duplicates('primary_label')
            .set_index('primary_label')['family']
            .to_dict()
        )
    else:
        family_of = {}

    # Threshold offsets per family; unknown / missing families keep the default.
    family_offsets = {
        'Trochilidae': -0.05,  # Hummingbirds: lower threshold for recall
        'Tyrannidae': -0.05,   # Tyrant flycatchers
        'Thraupidae': 0.05,    # Tanagers: higher threshold to avoid confusion
        'Parulidae': 0.05,     # Wood-warblers
    }

    return {
        species: default_threshold + family_offsets.get(family_of.get(species), 0.0)
        for species in labels
    }


In [6]:
##############################################################################
# Audio File Loading Utility
def get_audio_files():
    """
    Choose which soundscapes to run inference on.

    If ``CFG.test_soundscapes`` exists and holds at least one ``*.ogg`` file, all of
    them are used and ``CFG.debug`` is forced to False. Otherwise a slice of
    ``CFG.train_soundscapes`` (from ``CFG.debug_start_num``, ``CFG.debug_num`` files)
    is used and ``CFG.debug`` is forced to True.

    Returns:
        tuple: (audio_paths, file_ids)
            - audio_paths: Sorted full paths of the selected audio files.
            - file_ids: File names without directory or extension, aligned with audio_paths.
    """
    if os.path.exists(CFG.test_soundscapes):
        test_paths = sorted(glob.glob(f'{CFG.test_soundscapes}/*.ogg'))
    else:
        test_paths = []

    if test_paths:
        # Official test soundscapes are present: full inference run.
        CFG.debug = False
        audio_paths = test_paths
    else:
        # Nothing to score (e.g. interactive development): debug on train soundscapes.
        print("No test files found. Using train soundscapes for debugging.")
        CFG.debug = True
        first = CFG.debug_start_num
        last = first + CFG.debug_num
        audio_paths = sorted(glob.glob(f'{CFG.train_soundscapes}/*.ogg'))[first:last]

    file_ids = []
    for path in audio_paths:
        stem, _extension = os.path.splitext(os.path.basename(path))
        file_ids.append(stem)

    print(f'Debug mode: {CFG.debug}')
    print(f'Number of soundscapes: {len(audio_paths)}')

    return audio_paths, file_ids


In [7]:
##############################################################################
# Model Discovery and Loading
def find_model_files(model_dir=None):
    """
    Find every PyTorch checkpoint (``*.pth``) below a directory, recursively.

    The result is sorted. Ensemble weights in predict_on_audio are assigned by model
    position, and raw glob order depends on the filesystem, so sorting keeps the
    ensemble reproducible across runs and machines.

    Args:
        model_dir (str or Path, optional): Root directory to search. Defaults to ``CFG.model_path``.

    Returns:
        list: Sorted full paths (as strings) of the discovered model files.
    """
    root = Path(CFG.model_path if model_dir is None else model_dir)
    return sorted(str(path) for path in root.glob('**/*.pth'))


In [8]:
def load_models():
    """
    Load every discovered checkpoint into a BirdCLEFModel for ensemble inference.

    - Collects ``.pth`` files via ``find_model_files()``.
    - If ``CFG.use_specific_folds`` is True, keeps only files whose path contains
      ``fold<k>`` for some k in ``CFG.folds``. ``fold1`` does not match ``fold10``.
      Each file is kept at most once, grouped in ``CFG.folds`` order.
    - Accepts checkpoints that store the weights under ``'model_state_dict'`` and also
      bare state dicts.
    - Moves each model to ``CFG.device`` and switches it to eval mode.
    Files that fail to load are reported and skipped.

    Returns:
        list: A list of loaded PyTorch model instances ready for inference.
    """
    import re  # only needed for the fold-name filter below

    models = []
    model_files = find_model_files()

    if not model_files:
        print(f"Warning: No model files found under {CFG.model_path}!")
        return models

    print(f"Found a total of {len(model_files)} model files.")

    # Optionally filter by specific folds (e.g., only fold0 and fold1)
    if CFG.use_specific_folds:
        filtered_files = []
        for fold in CFG.folds:
            # Negative lookahead: "fold1" must not also select "fold10" / "fold11".
            fold_pattern = re.compile(rf"fold{re.escape(str(fold))}(?!\d)")
            for file_path in model_files:
                if fold_pattern.search(file_path) and file_path not in filtered_files:
                    filtered_files.append(file_path)
        model_files = filtered_files
        print(f"Using {len(model_files)} model files for the specified folds ({CFG.folds}).")

    # Load each model from its checkpoint
    for model_path in model_files:
        try:
            print(f"Loading model: {model_path}")
            # NOTE(review): torch.load unpickles the file. That is fine for our own
            # checkpoints; never point this at untrusted files.
            checkpoint = torch.load(model_path, map_location='cpu')  # Always load on CPU first
            if isinstance(checkpoint, dict):
                state_dict = checkpoint.get('model_state_dict', checkpoint)
            else:
                state_dict = checkpoint
            model = BirdCLEFModel(CFG)
            model.load_state_dict(state_dict)
            model = model.to(CFG.device)  # Move to configured device
            model.eval()  # Set to evaluation mode
            models.append(model)
        except Exception as e:
            print(f"Error loading model {model_path}: {e}")

    return models


In [9]:
##############################################################################
# Audio Processing Functions

def reduce_noise(audio_data):
    """
    Lightly denoise a waveform by mixing in a median-filtered copy of itself.

    The result is ``(1 - s) * audio + s * medfilt(audio, 5)`` with
    ``s = CFG.noise_reduction_strength``. When ``CFG.apply_noise_reduction`` is False
    the input object is returned untouched.

    Args:
        audio_data (np.ndarray): Input 1D audio waveform.

    Returns:
        np.ndarray: Blended waveform, or the original input when disabled.
    """
    if CFG.apply_noise_reduction:
        strength = CFG.noise_reduction_strength
        median_filtered = signal.medfilt(audio_data, 5)  # 5-sample median kernel
        return strength * median_filtered + (1 - strength) * audio_data
    return audio_data


def normalize_audio(audio_data):
    """
    Center a waveform on zero and peak-scale it into [-1, 1].

    When ``CFG.apply_normalization`` is False the input is returned as-is. A signal
    that is all zeros after centering is returned centered but not scaled, which
    avoids a division by zero.

    Args:
        audio_data (np.ndarray): Input 1D audio waveform.

    Returns:
        np.ndarray: Normalized audio signal.
    """
    if CFG.apply_normalization:
        centered = audio_data - np.mean(audio_data)  # remove DC offset
        peak = np.max(np.abs(centered))
        return centered / peak if peak > 0 else centered
    return audio_data


def audio2melspec(audio_data, cfg=CFG):
    """
    Convert raw audio to a normalized Mel-spectrogram with optional enhancements.

    Pipeline:
    - NaN samples are replaced by the mean of the finite samples (0.0 if all are NaN)
    - audio shorter than ``cfg.FS * cfg.WINDOW_SIZE`` is zero-padded at the end
    - noise reduction and normalization (each switchable through CFG)
    - power Mel-spectrogram -> dB relative to the max -> min-max scaled to [0, 1]
    - optional contrast enhancement
    - resized to ``cfg.TARGET_SHAPE`` given as (height, width)

    Args:
        audio_data (np.ndarray): Raw audio waveform.
        cfg (CFG): Configuration object containing processing parameters.

    Returns:
        np.ndarray: Preprocessed Mel-spectrogram (float32) of shape ``cfg.TARGET_SHAPE``.
    """
    # Replace NaNs with the mean of the finite samples. If every sample is NaN,
    # nanmean is itself NaN and would fix nothing, so fall back to silence.
    if np.isnan(audio_data).any():
        mean_signal = np.nanmean(audio_data)
        if np.isnan(mean_signal):
            mean_signal = 0.0
        audio_data = np.nan_to_num(audio_data, nan=mean_signal)

    # Pad short audio to required window size
    required_length = cfg.FS * cfg.WINDOW_SIZE
    if len(audio_data) < required_length:
        audio_data = np.pad(
            audio_data,
            (0, required_length - len(audio_data)),
            mode='constant'
        )

    # Apply noise reduction and normalization
    audio_data = reduce_noise(audio_data)
    audio_data = normalize_audio(audio_data)

    # Compute Mel-spectrogram
    mel_spec = librosa.feature.melspectrogram(
        y=audio_data,
        sr=cfg.FS,
        n_fft=cfg.N_FFT,
        hop_length=cfg.HOP_LENGTH,
        n_mels=cfg.N_MELS,
        fmin=cfg.FMIN,
        fmax=cfg.FMAX,
        power=2.0  # Use power spectrogram
    )

    # Convert to log scale (dB) relative to the loudest bin
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize to [0, 1]; epsilon guards against a constant spectrogram
    mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)

    # Optional: Apply contrast enhancement to highlight features
    if cfg.apply_spec_contrast:
        mel_spec_norm = enhance_spectrogram_contrast(mel_spec_norm, cfg.contrast_factor)

    # Resize to the fixed target shape. cv2.resize takes dsize as (width, height),
    # the reverse of numpy's (rows, cols), so swap before calling it.
    target_height, target_width = cfg.TARGET_SHAPE
    if mel_spec_norm.shape != (target_height, target_width):
        mel_spec_norm = cv2.resize(mel_spec_norm, (target_width, target_height),
                                   interpolation=cv2.INTER_LINEAR)

    return mel_spec_norm.astype(np.float32)

# Contrast strength: values roughly between 0.05 and 0.2 are a sensible range to try.
def enhance_spectrogram_contrast(spec, factor=0.15):
    """
    Stretch a spectrogram's values away from its global mean.

    Each value becomes ``mean + (value - mean) * (1 + factor)``, and the result is
    clipped to [0, 1]. A factor of 0 leaves in-range inputs unchanged.

    :param spec: Input spectrogram (2D array), expected in [0, 1].
    :param factor: Relative amount of extra contrast.
    :return: Contrast-enhanced spectrogram, clipped to [0, 1].
    """
    center = np.mean(spec)
    stretched = (spec - center) * (1 + factor) + center
    return np.clip(stretched, 0, 1)

def apply_spec_augment(spec):
    """
    SpecAugment-style masking: zero random frequency bands, then random time spans.

    Applies ``CFG.freq_mask_count`` masks along axis 0, each with a width drawn from
    ``[0, CFG.freq_mask_param)``, followed by ``CFG.time_mask_count`` masks along
    axis 1, each with a width drawn from ``[0, CFG.time_mask_param)``. Uses the global
    NumPy RNG. Returns the input unchanged when ``CFG.use_spec_augment`` is False;
    otherwise it works on a copy.

    :param spec: Input mel-spectrogram of shape [frequency, time].
    :return: Augmented spectrogram.
    """
    if not CFG.use_spec_augment:
        return spec

    masked = spec.copy()
    mask_plan = (
        (0, CFG.freq_mask_count, CFG.freq_mask_param),  # frequency bands
        (1, CFG.time_mask_count, CFG.time_mask_param),  # time spans
    )
    for axis, n_masks, max_width in mask_plan:
        for _ in range(n_masks):
            width = np.random.randint(0, max_width)
            start = np.random.randint(0, masked.shape[axis] - width)
            if axis == 0:
                masked[start:start + width, :] = 0
            else:
                masked[:, start:start + width] = 0
    return masked

def apply_tta(spec, tta_idx):
    """
    Return one deterministic test-time-augmentation view of a mel-spectrogram.

    Views by index: 0 identity, 1 time reversal (flip axis 1), 2 frequency
    inversion (flip axis 0), 3 both flips, 4 roll +3 bins along axis 0, 5 roll -3
    bins along axis 0. Any other index yields the identity. The input is never
    modified; a new array is always returned.

    :param spec: Input mel-spectrogram.
    :param tta_idx: Index indicating which TTA method to apply.
    :return: Augmented spectrogram.
    """
    views = {
        1: lambda s: np.flip(s, axis=1).copy(),
        2: lambda s: np.flip(s, axis=0).copy(),
        3: lambda s: np.flip(s, axis=(0, 1)).copy(),
        4: lambda s: np.roll(s, shift=3, axis=0),   # slight upward pitch shift
        5: lambda s: np.roll(s, shift=-3, axis=0),  # slight downward pitch shift
    }
    working = spec.copy()
    make_view = views.get(tta_idx)
    return working if make_view is None else make_view(working)

##############################################################################
# Main inference function with enhancement and ensemble support
def _ensemble_probabilities(mel_spec, models):
    """Run every model on one mel-spectrogram and return the position-weighted mean sigmoid probabilities."""
    mel_spec_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    mel_spec_tensor = mel_spec_tensor.to(CFG.device)

    segment_preds = []
    with torch.no_grad():
        for model in models:
            outputs = model(mel_spec_tensor)
            segment_preds.append(torch.sigmoid(outputs).cpu().numpy().squeeze())

    if len(segment_preds) == 1:
        return segment_preds[0]
    # Linearly increasing weights (0.8 -> 1.2) by model position in the list.
    weights = np.linspace(0.8, 1.2, len(segment_preds))
    weights = weights / weights.sum()
    return np.average(segment_preds, axis=0, weights=weights)


def _blend_with_neighbors(segment_predictions):
    """Blend each interior segment with the mean of its two neighbors using CFG.prediction_blend; edges are kept raw."""
    last = len(segment_predictions) - 1
    blended = []
    for i, pred in enumerate(segment_predictions):
        if 0 < i < last:
            blended.append(
                CFG.prediction_blend[0] * pred +
                CFG.prediction_blend[1] * 0.5 * (segment_predictions[i - 1] + segment_predictions[i + 1])
            )
        else:
            blended.append(pred)
    return blended


def predict_on_audio(audio_path, models):
    """
    Predict per-window class probabilities for one soundscape.

    The audio is cut into consecutive ``CFG.WINDOW_SIZE``-second windows; a trailing
    partial window is dropped. Each window gets row id ``<soundscape_id>_<end second>``.
    For each window the models' sigmoid outputs are averaged, optionally over TTA views,
    and the window is then blended with its neighbors.

    If a window fails, it reuses the previous window's prediction, or zeros for the
    first window. If ``TERMINATE_TIME`` (a global epoch timestamp, if defined) is
    exceeded, processing stops and only the windows already scored are returned, so the
    two lists always have equal length.

    :param audio_path: Path to the audio file.
    :param models: List of PyTorch models used for ensemble prediction.
    :return: Tuple (row_ids, predictions), one entry per processed window; both empty
        if the file cannot be processed.
    """
    row_ids = []
    soundscape_id = os.path.splitext(os.path.basename(audio_path))[0]
    # NOTE(review): TERMINATE_TIME is not defined in this part of the notebook; presumably
    # a later cell sets it. Without it there is simply no deadline, instead of a NameError.
    deadline = globals().get("TERMINATE_TIME", float("inf"))

    try:
        print(f"Processing {soundscape_id}")
        audio_data, _ = librosa.load(audio_path, sr=CFG.FS)
        samples_per_window = CFG.FS * CFG.WINDOW_SIZE
        total_segments = int(len(audio_data) / samples_per_window)

        segment_predictions = []  # Raw per-window predictions before neighbor blending

        for segment_idx in range(total_segments):
            if time.time() > deadline:
                # break (not return) so the windows already scored are kept; row_ids
                # and predictions stay the same length and run_pipeline keeps the file.
                print("Time limit reached, stopping processing early")
                break

            start_sample = segment_idx * samples_per_window
            segment_audio = audio_data[start_sample:start_sample + samples_per_window]

            end_time_sec = (segment_idx + 1) * CFG.WINDOW_SIZE
            row_ids.append(f"{soundscape_id}_{end_time_sec}")

            try:
                # Deterministic, so compute once even when several TTA views are scored.
                mel_spec = audio2melspec(segment_audio)
                if CFG.use_tta:
                    tta_preds = []
                    for tta_idx in range(CFG.tta_count):
                        view = mel_spec
                        # Optionally apply SpecAugment during inference
                        if np.random.random() < 0.5 and CFG.use_spec_augment:
                            view = apply_spec_augment(view)
                        view = apply_tta(view, tta_idx)
                        tta_preds.append(_ensemble_probabilities(view, models))
                    # Average predictions across all TTA outputs
                    final_preds = np.mean(tta_preds, axis=0)
                else:
                    final_preds = _ensemble_probabilities(mel_spec, models)
            except Exception as e:
                print(f"Error processing segment {segment_idx}: {e}")
                # Fallback: previous window's prediction, or zeros for the first window.
                if segment_predictions:
                    final_preds = segment_predictions[-1]
                else:
                    final_preds = np.zeros(len(primary_labels))

            segment_predictions.append(final_preds)

        predictions = _blend_with_neighbors(segment_predictions)

    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        row_ids = []
        predictions = []

    return row_ids, predictions

##############################################################################
# Enhanced smoothing function
def smooth_predictions(submission_df, weights=None):
    """
    Smooth predictions in the submission to enhance temporal consistency.

    Rows are grouped by soundscape (the ``row_id`` text before its last ``_``), and each
    group is smoothed independently with a centered weighted moving average. Near the
    start and end of a soundscape, and for soundscapes shorter than the kernel, the
    kernel is truncated to the rows that exist and renormalized. Single-row groups are
    left unchanged.

    :param submission_df: DataFrame with ``row_id`` as the first column followed by class
        probability columns; rows of a soundscape are assumed to be in time order.
    :param weights: Optional odd-length smoothing kernel; defaults to ``CFG.smoothing_weights``.
    :return: Smoothed copy of the predictions DataFrame (the input is not modified).
    :raises ValueError: If the kernel has even length (it has no center tap).
    """
    print("Smoothing prediction results...")
    if weights is None:
        weights = CFG.smoothing_weights
    kernel = np.asarray(weights, dtype=float)
    if kernel.size % 2 == 0:
        raise ValueError(f"Smoothing kernel must have odd length, got {kernel.size}")
    half_window = kernel.size // 2

    sub = submission_df.copy()
    cols = sub.columns[1:]
    # Soundscape id = row_id without its trailing "_<seconds>" suffix.
    groups = sub['row_id'].str.rsplit('_', n=1).str[0].to_numpy()

    for group in pd.unique(groups):
        group_mask = groups == group
        predictions = sub.loc[group_mask, cols].to_numpy()
        n_rows = predictions.shape[0]
        if n_rows <= 1:
            continue

        new_predictions = predictions.copy()
        for i in range(n_rows):
            # Rows [lo, hi) fall inside the kernel centered on row i; kernel tap for row j
            # is kernel[j - i + half_window]. np.average renormalizes truncated kernels.
            lo = max(0, i - half_window)
            hi = min(n_rows, i + half_window + 1)
            new_predictions[i] = np.average(
                predictions[lo:hi], axis=0,
                weights=kernel[lo - i + half_window:hi - i + half_window],
            )
        sub.loc[group_mask, cols] = new_predictions

    return sub

##############################################################################
# Apply thresholds to predictions
def apply_thresholds(submission_df, thresholds):
    """
    Turn class probabilities into 0.0/1.0 labels using per-class cut-offs.

    A value becomes 1.0 when it is >= its class threshold. Classes missing from
    ``thresholds`` use ``CFG.threshold``. Only the columns listed in the global
    ``primary_labels`` are converted; the input frame is left unchanged.

    :param submission_df: DataFrame containing prediction probabilities.
    :param thresholds: Dictionary of optimal thresholds per class.
    :return: Thresholded binary prediction DataFrame.
    """
    print("Applying class-specific thresholds...")
    binarized = submission_df.copy()
    for label in primary_labels:
        cutoff = thresholds.get(label, CFG.threshold)
        binarized[label] = binarized[label].ge(cutoff).astype(float)
    return binarized


In [10]:
##############################################################################
# Main Inference Pipeline
def _build_submission_frame(all_row_ids, all_predictions):
    """Assemble row ids and per-window probability vectors into a DataFrame laid out like the sample submission."""
    if all_predictions:
        prob_matrix = np.vstack(all_predictions)
    else:
        prob_matrix = np.zeros((0, len(primary_labels)))
    submission_df = pd.DataFrame(prob_matrix, columns=primary_labels)
    # Object dtype keeps the .str accessor usable even when there are no rows.
    submission_df.insert(0, 'row_id', np.array(all_row_ids, dtype=object))

    # Ensure all required columns are present (match sample submission)
    sample_sub = pd.read_csv(CFG.submission_csv)
    missing_cols = set(sample_sub.columns) - set(submission_df.columns)
    if missing_cols:
        print(f"Warning: Missing {len(missing_cols)} columns in submission")
        for col in missing_cols:
            submission_df[col] = 0.0

    # Reorder columns to match sample submission
    if 'row_id' in sample_sub.columns:
        submission_df = submission_df[sample_sub.columns]
    return submission_df


def _apply_secondary_smoothing(submission_df, weights):
    """Return a copy smoothed by a 3-tap [prev, current, next] average within each soundscape; edge rows fold the missing neighbor's weight into the current row."""
    w_prev, w_cur, w_next = weights
    smoothed_df = submission_df.copy()
    cols = smoothed_df.columns[1:]
    groups = smoothed_df['row_id'].str.rsplit('_', n=1).str[0].values

    for group in np.unique(groups):
        group_mask = groups == group
        predictions = smoothed_df.loc[group_mask, cols].values
        new_predictions = predictions.copy()
        if predictions.shape[0] > 1:
            # Interior rows, computed from the unsmoothed values.
            new_predictions[1:-1] = (w_prev * predictions[:-2] + w_cur * predictions[1:-1]
                                     + w_next * predictions[2:])
            # With the default [0.2, 0.6, 0.2] these are 0.8 * self + 0.2 * neighbor.
            new_predictions[0] = (w_prev + w_cur) * predictions[0] + w_next * predictions[1]
            new_predictions[-1] = (w_cur + w_next) * predictions[-1] + w_prev * predictions[-2]
        smoothed_df.loc[group_mask, cols] = new_predictions
    return smoothed_df


def run_pipeline():
    """
    Execute the complete inference pipeline.

    Steps:
    1. Estimate per-class thresholds (if not already set).
    2. Load the model ensemble.
    3. Predict every soundscape.
    4. Assemble a frame aligned with the sample submission.
    5. Apply the configured temporal smoothing passes.
    6. Write ``submission.csv``.

    Files whose predictions are empty or mismatched are skipped with a message.

    :return: The final submission DataFrame, or None when no model could be loaded.
    """
    print(f"Device: {CFG.device}")
    print(f"TTA enabled: {CFG.use_tta} (variations: {CFG.tta_count if CFG.use_tta else 0})")

    # Estimate optimal thresholds for each class (if not already set)
    if CFG.class_thresholds is None:
        CFG.class_thresholds = estimate_class_thresholds()

    # Load model(s)
    models = load_models()
    if not models:
        print("No models found! Please check model paths.")
        return None

    print(f"Model usage: {'Single model' if len(models) == 1 else f'Ensemble of {len(models)} models'}")

    # Retrieve all audio files (file ids are not needed here)
    audio_paths, _ = get_audio_files()
    print(f"Found {len(audio_paths)} audio files")

    all_row_ids = []
    all_predictions = []

    # Process each audio file individually
    for file_idx, audio_path in enumerate(tqdm(audio_paths)):
        row_ids, predictions = predict_on_audio(audio_path, models)

        # Add results only if both are valid and lengths match
        if len(row_ids) > 0 and len(predictions) > 0 and len(row_ids) == len(predictions):
            all_row_ids.extend(row_ids)
            all_predictions.extend(predictions)
        else:
            print(f"Skipping results for {audio_path} due to length mismatch or empty results")

        # Periodically clear memory to maintain efficiency
        if (file_idx + 1) % CFG.clear_cache_frequency == 0:
            clear_memory()

    # Construct the submission DataFrame
    print("Creating submission dataframe...")
    submission_df = _build_submission_frame(all_row_ids, all_predictions)

    # Apply temporal smoothing (first stage)
    if CFG.apply_smoothing:
        submission_df = smooth_predictions(submission_df)

    # Optional second-stage smoothing, weights taken from the config
    if CFG.apply_secondary_smoothing:
        submission_df = _apply_secondary_smoothing(submission_df, CFG.secondary_smoothing_weights)

    # Probabilities are submitted as-is. Binarizing with apply_thresholds(submission_df,
    # CFG.class_thresholds) is deliberately skipped because evaluation is on probabilities.
    final_df = submission_df

    # Save to CSV
    submission_path = 'submission.csv'
    final_df.to_csv(submission_path, index=False)
    print(f"Submission saved to {submission_path}")

    # Display preview of submission
    print("\nSubmission head:")
    print(final_df.head(5))

    print("\nSubmission tail:")
    print(final_df.tail(5))

    return final_df
##############################################################################
# Visualization Functions

def create_mel_transform():
    """
    Build the torchaudio MelSpectrogram transform used for debug plots.

    Every spectral setting (sample rate, FFT/window size, hop length,
    frequency range, number of mel bands) is read from ``CFG``.

    :return: ``torchaudio.transforms.MelSpectrogram`` instance.
    """
    mel_kwargs = dict(
        sample_rate=CFG.FS,
        n_fft=CFG.N_FFT,
        win_length=CFG.N_FFT,      # window spans the whole FFT frame
        hop_length=CFG.HOP_LENGTH,
        center=True,
        pad_mode="reflect",
        f_min=CFG.FMIN,
        f_max=CFG.FMAX,
        n_mels=CFG.N_MELS,
        power=2.0,                 # power (magnitude squared) spectrogram
        norm="slaney",
        mel_scale="htk",
    )
    return AT.MelSpectrogram(**mel_kwargs)

def audio_to_mel_debug(filepath):
    """
    Convert an audio file to a log-scaled Mel spectrogram for debugging/visualization.

    The waveform is peak-normalized to [-1, 1] before the Mel transform.
    Silent (all-zero) or empty audio is left unscaled rather than divided by
    a zero peak, which would otherwise turn every value into NaN.

    :param filepath: Path to the audio file (.ogg, .wav, etc.)
    :return: Tensor of Mel power in decibels, as produced by the
             ``create_mel_transform()`` transform applied to the loaded
             waveform (presumably ``(channels, n_mels, frames)`` — TODO confirm).
    """
    waveform, _ = torchaudio.load(filepath, backend="soundfile")

    # Peak-normalize to [-1, 1]. Guard both the empty tensor (max() raises)
    # and the zero peak (0 / 0 -> NaN for silent clips).
    if waveform.numel() > 0:
        peak = waveform.abs().max()
        if peak > 0:
            waveform = waveform / peak

    mel_transform = create_mel_transform()
    melspec = mel_transform(waveform)

    # Convert power to decibels; the epsilon keeps log10 finite for zero bins.
    melspec = 10 * torch.log10(melspec + 1e-10)

    return melspec

def plot_results(results, file_name, audio_dir):
    """
    Plot the Mel spectrogram and prediction heatmap for one audio file and save it.

    The figure is written to ``{file_name}_prediction.png`` in the current
    working directory and the figure is closed afterwards.

    :param results: Submission-style DataFrame whose first column is ``row_id``
                    (``<file_name>_<seconds>``) followed by one probability
                    column per species.
    :param file_name: Base name of the audio file (without extension).
    :param audio_dir: Directory containing ``<file_name>.ogg``.
    """
    path = os.path.join(audio_dir, file_name + ".ogg")

    # Compute Mel spectrogram
    specgram = audio_to_mel_debug(path)

    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    # Top panel: spectrogram of the first entry along the leading (channel) dim.
    axes[0].set_title(file_name)
    im = axes[0].imshow(specgram[0], origin="lower", aspect="auto")
    axes[0].set_ylabel("Mel bin")
    axes[0].set_xlabel("Frame")
    fig.colorbar(im, ax=axes[0])

    # Select this file's rows by exact prefix. `str.contains` would treat
    # file_name as a regex and match any row_id that merely contains it.
    row_prefix = results["row_id"].str.rsplit("_", n=1).str[0]
    file_results = results[row_prefix == file_name]

    # Species labels are taken from the plotted columns so tick labels are
    # guaranteed to line up with heatmap rows.
    species = list(file_results.columns[1:])
    time_segments = file_results.shape[0]

    if time_segments > 0:
        heatmap = axes[1].pcolor(file_results.iloc[:, 1:].values.T, edgecolors='k',
                                 linewidths=0.1, vmin=0, vmax=1, cmap='Blues')
        fig.colorbar(heatmap, ax=axes[1])

        # One tick per segment at its start time (0s, W s, 2W s, ...).
        step = CFG.WINDOW_SIZE
        axes[1].set_xticks(np.arange(0, time_segments, 1))
        axes[1].set_xticklabels(np.arange(0, time_segments * step, step))
    else:
        # pcolor cannot draw an empty array; leave a visible note instead.
        axes[1].set_title(f"No predictions found for {file_name}")

    axes[1].set_ylabel("Species")
    axes[1].set_xlabel("Seconds")

    # Show species labels only when the count is small enough to be legible.
    if len(species) <= 30:
        axes[1].set_yticks(np.arange(0.5, len(species), 1))
        axes[1].set_yticklabels(species)

    fig.tight_layout()
    fig.savefig(f'{file_name}_prediction.png')
    plt.close(fig)


In [11]:
##############################################################################
# Main Execution
if __name__ == "__main__":
    import traceback

    START = time.time()
    # Execution budget (~88 minutes). Not read in this cell; presumably
    # consumed as a global elsewhere (e.g. by run_pipeline) — TODO confirm.
    TERMINATE_TIME = START + 5300

    try:
        # Run the full inference pipeline
        results = run_pipeline()

        # Free memory after prediction so visualization has headroom
        clear_memory()

        # Debug-only visualizations
        if CFG.debug and results is not None:
            audio_paths, file_ids = get_audio_files()
            # Inside the debug branch the source directory is always the
            # train soundscapes.
            audio_dir = CFG.train_soundscapes

            print("\nGenerating visualizations...")
            for file_id in file_ids:
                # Isolate each plot so one bad file doesn't stop the rest.
                try:
                    plot_results(results, file_id, audio_dir)
                except Exception as plot_err:
                    print(f"Visualization failed for {file_id}: {plot_err}")
                else:
                    print(f"Visualization saved for {file_id}")

    except Exception as e:
        print(f"Error during execution: {e}")
        traceback.print_exc()

    finally:
        # Always report elapsed time, including after a failure.
        print(f"\nTotal execution time: {(time.time() - START) / 60:.2f} minutes")


Device: cpu
TTA enabled: False (variations: 0)
Estimating optimal thresholds for each class...
Found a total of 1 model files.
Loading model: /kaggle/input/bestmodel/pytorch/default/1/model_fold4_best_20250416_214622.pth


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

Model usage: Single model
No test files found. Using train soundscapes for debugging.
Debug mode: True
Number of soundscapes: 8
Found 8 audio files


  0%|          | 0/8 [00:00<?, ?it/s]

Processing H02_20230420_074000
Processing H02_20230420_112000
Processing H02_20230420_154500
Processing H02_20230420_164000
Processing H02_20230420_223500
Processing H02_20230421_093000
Processing H02_20230421_113500
Processing H02_20230421_170000
Creating submission dataframe...
Smoothing prediction results...
Submission saved to submission.csv

Submission head:
                   row_id   1139490   1192948   1194042    126247   1346504  \
0   H02_20230420_074000_5  0.000092  0.000076  0.000376  0.000064  0.000494   
1  H02_20230420_074000_10  0.000091  0.000080  0.000385  0.000060  0.000471   
2  H02_20230420_074000_15  0.000089  0.000085  0.000400  0.000058  0.000440   
3  H02_20230420_074000_20  0.000088  0.000085  0.000404  0.000057  0.000408   
4  H02_20230420_074000_25  0.000098  0.000086  0.000427  0.000064  0.000463   

     134933    135045   1462711   1462737  ...   yebfly1   yebsee1   yecspi2  \
0  0.004250  0.002491  0.000096  0.000053  ...  0.007989  0.005650  0.043839   