In [1]:
import os
import logging
import random
import gc
import time
import cv2
import math
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
import librosa

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from scipy import signal
import timm

import time
from datetime import datetime
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

In [2]:
class CFG:
    seed = 42  # Random seed for reproducibility
    debug = False  # Debug mode for faster iteration
    apex = False  # Use NVIDIA Apex for mixed precision training
    print_freq = 100  # Logging frequency (batches)
    num_workers = 2  # Number of dataloader subprocesses

    # File and directory paths
    OUTPUT_DIR = '/kaggle/working/'  # Directory to save outputs and models
    train_datadir = '/kaggle/input/birdclef-2025/train_audio'  # Path to training audio
    train_csv = '/kaggle/input/birdclef-2025/train.csv'  # Training metadata CSV
    test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'  # Test soundscape audio
    submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'  # Sample submission format
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'  # Bird species taxonomy information

    spectrogram_npy = '/kaggle/input/1111111/birdclef2025_melspec_5sec_256_256.npy'  # Cached mel spectrograms (optional)

    # Model configuration
    model_name = 'efficientnet_b0'  # Backbone model architecture
    model_path = '/kaggle/working/'  # Path to save/load model checkpoints
    pretrained = True  # Use pretrained weights
    in_channels = 1  # Number of input channels (e.g., 1 for mel-spectrograms)

    # Audio preprocessing
    LOAD_DATA = True  # Whether to load and preprocess audio data
    FS = 32000  # Sampling rate in Hz
    TARGET_DURATION = 5.0  # Target duration of each audio clip (in seconds)
    TARGET_SHAPE = (256, 256)  # Target shape for input spectrograms (H, W)

    # Mel spectrogram parameters
    N_FFT = 1024  # FFT window size
    HOP_LENGTH = 512  # Hop length for STFT
    N_MELS = 128  # Number of mel bands
    FMIN = 50  # Minimum frequency for mel filter bank
    FMAX = 14000  # Maximum frequency for mel filter bank
    WINDOW_SIZE = 5  # Median filter window size for noise reduction
    N_MAX = 50 if debug else None  # Optional limit on number of audio samples (for debug mode)

    # Preprocessing flags
    apply_noise_reduction = True  # Apply median filter-based noise reduction
    apply_normalization = True  # Normalize waveform amplitude
    noise_reduction_strength = 0.1  # Blending factor for noise reduction
    apply_spec_contrast = True  # Apply contrast enhancement to mel-spectrogram
    contrast_factor = 0.15  # Contrast enhancement factor

    # Training configuration
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if available
    epochs = 10  # Total number of training epochs
    batch_size = 32  # Batch size for training
    criterion = 'BCEWithLogitsLoss'  # Loss function (binary classification)

    # Cross-validation settings
    n_fold = 5  # Total number of folds
    selected_folds = [0, 1, 2, 3, 4]  # Folds to train on

    # Optimizer and scheduler
    optimizer = 'AdamW'  # Optimizer choice
    lr = 5e-4  # Learning rate
    weight_decay = 1e-5  # Weight decay for regularization

    scheduler = 'CosineAnnealingLR'  # Learning rate scheduler
    min_lr = 1e-6  # Minimum learning rate
    T_max = epochs  # Max number of iterations for scheduler

    # Data augmentation
    aug_prob = 0.5  # Probability of applying augmentations
    mixup_alpha = 0.5  # Alpha value for mixup augmentation

    def update_debug_settings(self):
        """
        Update training settings for debug mode to speed up experiments.
        """
        if self.debug:
            self.epochs = 2  # Fewer epochs for faster debugging
            self.selected_folds = [0]  # Use only one fold during debugging

In [3]:
def set_seed(seed=42):
    """
    Set seed for reproducibility
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
def process_audio_file(audio_path, cfg):
    """Process a single audio file to get the mel spectrogram"""
    try:
        audio_data, _ = librosa.load(audio_path, sr=cfg.FS)

        target_samples = int(cfg.TARGET_DURATION * cfg.FS)

        if len(audio_data) < target_samples:
            n_copy = math.ceil(target_samples / len(audio_data))
            if n_copy > 1:
                audio_data = np.concatenate([audio_data] * n_copy)

        # Extract center 5 seconds
        start_idx = max(0, int(len(audio_data) / 2 - target_samples / 2))
        end_idx = min(len(audio_data), start_idx + target_samples)
        center_audio = audio_data[start_idx:end_idx]

        if len(center_audio) < target_samples:
            center_audio = np.pad(center_audio, 
                                 (0, target_samples - len(center_audio)), 
                                 mode='constant')

        mel_spec = audio2melspec(center_audio, cfg)
        
        if mel_spec.shape != cfg.TARGET_SHAPE:
            mel_spec = cv2.resize(mel_spec, cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)

        return mel_spec.astype(np.float32)
        
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None
# Audio processing functions
def reduce_noise(audio_data):
    """
    Apply noise reduction to audio data.

    :param audio_data: Raw audio waveform
    :return: Denoised audio waveform
    """
    if not CFG.apply_noise_reduction:
        return audio_data

    # Apply median filter for basic noise suppression
    window_size = 5
    audio_denoised = signal.medfilt(audio_data, window_size)

    # Blend original and denoised signal based on reduction strength
    return (1 - CFG.noise_reduction_strength) * audio_data + CFG.noise_reduction_strength * audio_denoised


def normalize_audio(audio_data):
    """
    Normalize the audio waveform.

    :param audio_data: Raw audio waveform
    :return: Normalized audio waveform
    """
    if not CFG.apply_normalization:
        return audio_data

    # Remove DC offset
    audio_data = audio_data - np.mean(audio_data)

    # Normalize to max absolute amplitude of 1
    max_amplitude = np.max(np.abs(audio_data))
    if max_amplitude > 0:
        audio_data = audio_data / max_amplitude

    return audio_data


def enhance_spectrogram_contrast(spec, factor=0.15):
    """
    Enhance the contrast of the spectrogram to highlight features.

    :param spec: Input spectrogram (normalized)
    :param factor: Contrast enhancement factor (typically in range 0.05~0.2)
    :return: Contrast-enhanced spectrogram
    """
    mean = np.mean(spec)
    enhanced = mean + (spec - mean) * (1 + factor)
    return np.clip(enhanced, 0, 1)  # Ensure values remain in [0, 1]


#############################################################################
def audio2melspec(audio_data, cfg=CFG):
    """
    Convert raw audio data into a normalized and contrast-enhanced Mel spectrogram.

    :param audio_data: Raw 1D audio waveform
    :param cfg: Configuration object with processing parameters
    :return: Processed Mel spectrogram (float32, shape=cfg.TARGET_SHAPE)
    """

    # Handle potential NaN values (replacing with mean value)
    if np.isnan(audio_data).any():
        mean_signal = np.nanmean(audio_data)
        audio_data = np.nan_to_num(audio_data, nan=mean_signal)

    # Pad the signal if it's shorter than the required window length
    required_length = CFG.FS * CFG.WINDOW_SIZE
    if len(audio_data) < required_length:
        audio_data = np.pad(
            audio_data,
            (0, required_length - len(audio_data)),
            mode='constant'
        )

    # Apply noise reduction and normalization
    audio_data = reduce_noise(audio_data)
    audio_data = normalize_audio(audio_data)

    # Generate Mel spectrogram from the waveform
    mel_spec = librosa.feature.melspectrogram(
        y=audio_data,
        sr=cfg.FS,
        n_fft=cfg.N_FFT,
        hop_length=cfg.HOP_LENGTH,
        n_mels=cfg.N_MELS,
        fmin=cfg.FMIN,
        fmax=cfg.FMAX,
        power=2.0
    )

    # Convert to log-scale (dB)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Normalize to [0, 1]
    mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)

    # Apply contrast enhancement if enabled
    if CFG.apply_spec_contrast:
        mel_spec_norm = enhance_spectrogram_contrast(mel_spec_norm, CFG.contrast_factor)

    # Resize spectrogram to target shape for model input
    if mel_spec_norm.shape != cfg.TARGET_SHAPE:
        mel_spec_norm = cv2.resize(mel_spec_norm, cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)

    return mel_spec_norm.astype(np.float32)

In [5]:
class BirdCLEFDatasetFromNPY(Dataset):
    def __init__(self, df, cfg, spectrograms=None, mode="train"):
        self.df = df
        self.cfg = cfg
        self.mode = mode

        self.spectrograms = spectrograms
        
        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.species_ids = taxonomy_df['primary_label'].tolist()
        self.num_classes = len(self.species_ids)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}

        if 'filepath' not in self.df.columns:
            self.df['filepath'] = self.cfg.train_datadir + '/' + self.df.filename
        
        if 'samplename' not in self.df.columns:
            self.df['samplename'] = self.df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])

        sample_names = set(self.df['samplename'])
        if self.spectrograms:
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching spectrograms for {mode} dataset out of {len(self.df)} samples")
        
        if cfg.debug:
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        samplename = row['samplename']
        spec = None

        if self.spectrograms and samplename in self.spectrograms:
            spec = self.spectrograms[samplename]
        elif not self.cfg.LOAD_DATA:
            spec = process_audio_file(row['filepath'], self.cfg)

        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":  # Only print warning during training
                print(f"Warning: Spectrogram for {samplename} not found and could not be generated")

        spec = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        if self.mode == "train" and random.random() < self.cfg.aug_prob:
            spec = self.apply_spec_augmentations(spec)
        
        target = self.encode_label(row['primary_label'])
        
        if 'secondary_labels' in row and row['secondary_labels'] not in [[''], None, np.nan]:
            if isinstance(row['secondary_labels'], str):
                secondary_labels = eval(row['secondary_labels'])
            else:
                secondary_labels = row['secondary_labels']
            
            for label in secondary_labels:
                if label in self.label_to_idx:
                    target[self.label_to_idx[label]] = 1.0
        
        return {
            'melspec': spec, 
            'target': torch.tensor(target, dtype=torch.float32),
            'filename': row['filename']
        }
    
    def apply_spec_augmentations(self, spec):
        """Apply augmentations to spectrogram"""
    
        # Time masking (horizontal stripes)
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                width = random.randint(5, 20)
                start = random.randint(0, spec.shape[2] - width)
                spec[0, :, start:start+width] = 0
        
        # Frequency masking (vertical stripes)
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                height = random.randint(5, 20)
                start = random.randint(0, spec.shape[1] - height)
                spec[0, start:start+height, :] = 0
        
        # Random brightness/contrast
        if random.random() < 0.5:
            gain = random.uniform(0.8, 1.2)
            bias = random.uniform(-0.1, 0.1)
            spec = spec * gain + bias
            spec = torch.clamp(spec, 0, 1) 
            
        return spec
    
    def encode_label(self, label):
        """Encode label to one-hot vector"""
        target = np.zeros(self.num_classes)
        if label in self.label_to_idx:
            target[self.label_to_idx[label]] = 1.0
        return target

In [6]:
def collate_fn(batch):
    """Custom collate function to handle different sized spectrograms"""
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return {}
        
    result = {key: [] for key in batch[0].keys()}
    
    for item in batch:
        for key, value in item.items():
            result[key].append(value)
    
    for key in result:
        if key == 'target' and isinstance(result[key][0], torch.Tensor):
            result[key] = torch.stack(result[key])
        elif key == 'melspec' and isinstance(result[key][0], torch.Tensor):
            shapes = [t.shape for t in result[key]]
            if len(set(str(s) for s in shapes)) == 1:
                result[key] = torch.stack(result[key])
    
    return result

In [7]:
class BirdCLEFModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
        cfg.num_classes = len(taxonomy_df)
        
        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=cfg.pretrained,
            in_chans=cfg.in_channels,
            drop_rate=0.2,
            drop_path_rate=0.2
        )
        
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'resnet' in cfg.model_name:
            backbone_out = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        else:
            backbone_out = self.backbone.get_classifier().in_features
            self.backbone.reset_classifier(0, '')
        
        self.pooling = nn.AdaptiveAvgPool2d(1)
            
        self.feat_dim = backbone_out
        
        self.classifier = nn.Linear(backbone_out, cfg.num_classes)
        
        self.mixup_enabled = hasattr(cfg, 'mixup_alpha') and cfg.mixup_alpha > 0
        if self.mixup_enabled:
            self.mixup_alpha = cfg.mixup_alpha
            
    def forward(self, x, targets=None):
    
        if self.training and self.mixup_enabled and targets is not None:
            mixed_x, targets_a, targets_b, lam = self.mixup_data(x, targets)
            x = mixed_x
        else:
            targets_a, targets_b, lam = None, None, None
        
        features = self.backbone(x)
        
        if isinstance(features, dict):
            features = features['features']
            
        if len(features.shape) == 4:
            features = self.pooling(features)
            features = features.view(features.size(0), -1)
        
        logits = self.classifier(features)
        
        if self.training and self.mixup_enabled and targets is not None:
            loss = self.mixup_criterion(F.binary_cross_entropy_with_logits, 
                                       logits, targets_a, targets_b, lam)
            return logits, loss
            
        return logits
    
    def mixup_data(self, x, targets):
        """Applies mixup to the data batch"""
        batch_size = x.size(0)

        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

        indices = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[indices]
        
        return mixed_x, targets, targets[indices], lam
    
    def mixup_criterion(self, criterion, pred, y_a, y_b, lam):
        """Applies mixup to the loss function"""
        return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [8]:
def get_optimizer(model, cfg):
  
    if cfg.optimizer == 'Adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'AdamW':
        optimizer = optim.AdamW(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay
        )
    elif cfg.optimizer == 'SGD':
        optimizer = optim.SGD(
            model.parameters(),
            lr=cfg.lr,
            momentum=0.9,
            weight_decay=cfg.weight_decay
        )
    else:
        raise NotImplementedError(f"Optimizer {cfg.optimizer} not implemented")
        
    return optimizer

def get_scheduler(optimizer, cfg):
   
    if cfg.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.T_max,
            eta_min=cfg.min_lr
        )
    elif cfg.scheduler == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2,
            min_lr=cfg.min_lr,
            verbose=True
        )
    elif cfg.scheduler == 'StepLR':
        scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=cfg.epochs // 3,
            gamma=0.5
        )
    elif cfg.scheduler == 'OneCycleLR':
        scheduler = None  
    else:
        scheduler = None
        
    return scheduler

def get_criterion(cfg):
 
    if cfg.criterion == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss()
    else:
        raise NotImplementedError(f"Criterion {cfg.criterion} not implemented")
        
    return criterion

In [9]:
def train_one_epoch(model, loader, optimizer, criterion, device, scheduler=None):
    
    model.train()
    losses = []
    all_targets = []
    all_outputs = []
    
    pbar = tqdm(enumerate(loader), total=len(loader), desc="Training")
    
    for step, batch in pbar:
    
        if isinstance(batch['melspec'], list):
            batch_outputs = []
            batch_losses = []
            
            for i in range(len(batch['melspec'])):
                inputs = batch['melspec'][i].unsqueeze(0).to(device)
                target = batch['target'][i].unsqueeze(0).to(device)
                
                optimizer.zero_grad()
                output = model(inputs)
                loss = criterion(output, target)
                loss.backward()
                
                batch_outputs.append(output.detach().cpu())
                batch_losses.append(loss.item())
            
            optimizer.step()
            outputs = torch.cat(batch_outputs, dim=0).numpy()
            loss = np.mean(batch_losses)
            targets = batch['target'].numpy()
            
        else:
            inputs = batch['melspec'].to(device)
            targets = batch['target'].to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            if isinstance(outputs, tuple):
                outputs, loss = outputs  
            else:
                loss = criterion(outputs, targets)
                
            loss.backward()
            optimizer.step()
            
            outputs = outputs.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()
        
        if scheduler is not None and isinstance(scheduler, lr_scheduler.OneCycleLR):
            scheduler.step()
            
        all_outputs.append(outputs)
        all_targets.append(targets)
        losses.append(loss if isinstance(loss, float) else loss.item())
        
        pbar.set_postfix({
            'train_loss': np.mean(losses[-10:]) if losses else 0,
            'lr': optimizer.param_groups[0]['lr']
        })
    
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
    auc = calculate_auc(all_targets, all_outputs)
    avg_loss = np.mean(losses)
    
    return avg_loss, auc

def validate(model, loader, criterion, device):
   
    model.eval()
    losses = []
    all_targets = []
    all_outputs = []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            if isinstance(batch['melspec'], list):
                batch_outputs = []
                batch_losses = []
                
                for i in range(len(batch['melspec'])):
                    inputs = batch['melspec'][i].unsqueeze(0).to(device)
                    target = batch['target'][i].unsqueeze(0).to(device)
                    
                    output = model(inputs)
                    loss = criterion(output, target)
                    
                    batch_outputs.append(output.detach().cpu())
                    batch_losses.append(loss.item())
                
                outputs = torch.cat(batch_outputs, dim=0).numpy()
                loss = np.mean(batch_losses)
                targets = batch['target'].numpy()
                
            else:
                inputs = batch['melspec'].to(device)
                targets = batch['target'].to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                outputs = outputs.detach().cpu().numpy()
                targets = targets.detach().cpu().numpy()
            
            all_outputs.append(outputs)
            all_targets.append(targets)
            losses.append(loss if isinstance(loss, float) else loss.item())
    
    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
    
    auc = calculate_auc(all_targets, all_outputs)
    avg_loss = np.mean(losses)
    
    return avg_loss, auc

def calculate_auc(targets, outputs):
  
    num_classes = targets.shape[1]
    aucs = []
    
    probs = 1 / (1 + np.exp(-outputs))
    
    for i in range(num_classes):
        
        if np.sum(targets[:, i]) > 0:
            class_auc = roc_auc_score(targets[:, i], probs[:, i])
            aucs.append(class_auc)
    
    return np.mean(aucs) if aucs else 0.0

In [10]:
def run_training(df, cfg):
    """Training function that can either use pre-computed spectrograms or generate them on-the-fly"""

    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    species_ids = taxonomy_df['primary_label'].tolist()
    cfg.num_classes = len(species_ids)
    
    if cfg.debug:
        cfg.update_debug_settings()

    spectrograms = None
    if cfg.LOAD_DATA:
        print("Loading pre-computed mel spectrograms from NPY file...")
        try:
            spectrograms = np.load(cfg.spectrogram_npy, allow_pickle=True).item()
            print(f"Loaded {len(spectrograms)} pre-computed mel spectrograms")
        except Exception as e:
            print(f"Error loading pre-computed spectrograms: {e}")
            print("Will generate spectrograms on-the-fly instead.")
            cfg.LOAD_DATA = False
    
    if not cfg.LOAD_DATA:
        print("Will generate spectrograms on-the-fly during training.")
        if 'filepath' not in df.columns:
            df['filepath'] = cfg.train_datadir + '/' + df.filename
        if 'samplename' not in df.columns:
            df['samplename'] = df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])
        
    skf = StratifiedKFold(n_splits=cfg.n_fold, shuffle=True, random_state=cfg.seed)
    
    best_scores = []
    save_dir = cfg.model_path 
    os.makedirs(save_dir, exist_ok=True)
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['primary_label'])):
        if fold not in cfg.selected_folds:
            continue
            
        print(f'\n{"="*30} Fold {fold} {"="*30}')
        
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)
        
        print(f'Training set: {len(train_df)} samples')
        print(f'Validation set: {len(val_df)} samples')
        
        train_dataset = BirdCLEFDatasetFromNPY(train_df, cfg, spectrograms=spectrograms, mode='train')
        val_dataset = BirdCLEFDatasetFromNPY(val_df, cfg, spectrograms=spectrograms, mode='valid')
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=cfg.batch_size, 
            shuffle=True, 
            num_workers=cfg.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            drop_last=True
        )
        
        val_loader = DataLoader(
            val_dataset, 
            batch_size=cfg.batch_size, 
            shuffle=False, 
            num_workers=cfg.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )
        
        model = BirdCLEFModel(cfg).to(cfg.device)
        optimizer = get_optimizer(model, cfg)
        criterion = get_criterion(cfg)
        
        if cfg.scheduler == 'OneCycleLR':
            scheduler = lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=cfg.lr,
                steps_per_epoch=len(train_loader),
                epochs=cfg.epochs,
                pct_start=0.1
            )
        else:
            scheduler = get_scheduler(optimizer, cfg)
        
        best_auc = 0
        best_epoch = 0
        
        for epoch in range(cfg.epochs):
            print(f"\nEpoch {epoch+1}/{cfg.epochs}")
            
            train_loss, train_auc = train_one_epoch(
                model, 
                train_loader, 
                optimizer, 
                criterion, 
                cfg.device,
                scheduler if isinstance(scheduler, lr_scheduler.OneCycleLR) else None
            )
            
            val_loss, val_auc = validate(model, val_loader, criterion, cfg.device)

            if scheduler is not None and not isinstance(scheduler, lr_scheduler.OneCycleLR):
                if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")
            
            if val_auc > best_auc:
                best_auc = val_auc
                best_epoch = epoch + 1
                print(f"New best AUC: {best_auc:.4f} at epoch {best_epoch}")

                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                    'epoch': epoch,
                    'val_auc': val_auc,
                    'train_auc': train_auc,
                    'cfg': cfg
                }, f"model_fold{fold}.pth")
        
        best_scores.append(best_auc)
        print(f"\nBest AUC for fold {fold}: {best_auc:.4f} at epoch {best_epoch}")
        best_model_path = os.path.join(
            save_dir,
            f"model_fold{fold}_best_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
        )

        # 仅在验证AUC提升时保存
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'epoch': epoch,
            'val_auc': val_auc,
            'train_auc': train_auc,
            'cfg': cfg
        }, best_model_path)
        
        # Clear memory
        del model, optimizer, scheduler, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()
    
    print("\n" + "="*60)
    print("Cross-Validation Results:")
    for fold, score in enumerate(best_scores):
        print(f"Fold {cfg.selected_folds[fold]}: {score:.4f}")
    print(f"Mean AUC: {np.mean(best_scores):.4f}")
    print("="*60)

In [11]:
if __name__ == "__main__":
    cfg = CFG()
    set_seed(cfg.seed)
    print("\nLoading training data...")
    train_df = pd.read_csv(cfg.train_csv)
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)

    print("\nStarting training...")
    print(f"LOAD_DATA is set to {cfg.LOAD_DATA}")
    if cfg.LOAD_DATA:
        print("Using pre-computed mel spectrograms from NPY file")
    else:
        print("Will generate spectrograms on-the-fly during training")
    
    run_training(train_df, cfg)
    
    print("\nTraining complete!")


Loading training data...

Starting training...
LOAD_DATA is set to True
Using pre-computed mel spectrograms from NPY file
Loading pre-computed mel spectrograms from NPY file...
Loaded 28564 pre-computed mel spectrograms

Training set: 22851 samples
Validation set: 5713 samples
Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]


Epoch 1/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0367, Train AUC: 0.5860
Val Loss: 0.0257, Val AUC: 0.8236
New best AUC: 0.8236 at epoch 1

Epoch 2/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0238, Train AUC: 0.8111
Val Loss: 0.0216, Val AUC: 0.8871
New best AUC: 0.8871 at epoch 2

Epoch 3/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0198, Train AUC: 0.8952
Val Loss: 0.0186, Val AUC: 0.9188
New best AUC: 0.9188 at epoch 3

Epoch 4/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0171, Train AUC: 0.9310
Val Loss: 0.0171, Val AUC: 0.9335
New best AUC: 0.9335 at epoch 4

Epoch 5/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0149, Train AUC: 0.9532
Val Loss: 0.0162, Val AUC: 0.9400
New best AUC: 0.9400 at epoch 5

Epoch 6/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0129, Train AUC: 0.9715
Val Loss: 0.0160, Val AUC: 0.9431
New best AUC: 0.9431 at epoch 6

Epoch 7/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0111, Train AUC: 0.9825
Val Loss: 0.0158, Val AUC: 0.9424

Epoch 8/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0095, Train AUC: 0.9888
Val Loss: 0.0157, Val AUC: 0.9435
New best AUC: 0.9435 at epoch 8

Epoch 9/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0083, Train AUC: 0.9923
Val Loss: 0.0158, Val AUC: 0.9446
New best AUC: 0.9446 at epoch 9

Epoch 10/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0078, Train AUC: 0.9938
Val Loss: 0.0159, Val AUC: 0.9442

Best AUC for fold 0: 0.9446 at epoch 9

Training set: 22851 samples
Validation set: 5713 samples
Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples

Epoch 1/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0361, Train AUC: 0.5963
Val Loss: 0.0248, Val AUC: 0.8206
New best AUC: 0.8206 at epoch 1

Epoch 2/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0232, Train AUC: 0.8244
Val Loss: 0.0197, Val AUC: 0.9025
New best AUC: 0.9025 at epoch 2

Epoch 3/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0194, Train AUC: 0.8993
Val Loss: 0.0176, Val AUC: 0.9241
New best AUC: 0.9241 at epoch 3

Epoch 4/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0169, Train AUC: 0.9380
Val Loss: 0.0164, Val AUC: 0.9365
New best AUC: 0.9365 at epoch 4

Epoch 5/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0147, Train AUC: 0.9571
Val Loss: 0.0158, Val AUC: 0.9449
New best AUC: 0.9449 at epoch 5

Epoch 6/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0128, Train AUC: 0.9725
Val Loss: 0.0152, Val AUC: 0.9484
New best AUC: 0.9484 at epoch 6

Epoch 7/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0110, Train AUC: 0.9817
Val Loss: 0.0152, Val AUC: 0.9472

Epoch 8/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0094, Train AUC: 0.9889
Val Loss: 0.0152, Val AUC: 0.9499
New best AUC: 0.9499 at epoch 8

Epoch 9/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0083, Train AUC: 0.9923
Val Loss: 0.0151, Val AUC: 0.9493

Epoch 10/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0077, Train AUC: 0.9934
Val Loss: 0.0153, Val AUC: 0.9486

Best AUC for fold 1: 0.9499 at epoch 8

Training set: 22851 samples
Validation set: 5713 samples
Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples

Epoch 1/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0357, Train AUC: 0.6115
Val Loss: 0.0242, Val AUC: 0.8335
New best AUC: 0.8335 at epoch 1

Epoch 2/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0224, Train AUC: 0.8364
Val Loss: 0.0203, Val AUC: 0.8975
New best AUC: 0.8975 at epoch 2

Epoch 3/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0185, Train AUC: 0.9067
Val Loss: 0.0172, Val AUC: 0.9277
New best AUC: 0.9277 at epoch 3

Epoch 4/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0159, Train AUC: 0.9439
Val Loss: 0.0162, Val AUC: 0.9321
New best AUC: 0.9321 at epoch 4

Epoch 5/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0137, Train AUC: 0.9627
Val Loss: 0.0157, Val AUC: 0.9380
New best AUC: 0.9380 at epoch 5

Epoch 6/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0117, Train AUC: 0.9774
Val Loss: 0.0154, Val AUC: 0.9426
New best AUC: 0.9426 at epoch 6

Epoch 7/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0098, Train AUC: 0.9869
Val Loss: 0.0155, Val AUC: 0.9440
New best AUC: 0.9440 at epoch 7

Epoch 8/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0081, Train AUC: 0.9924
Val Loss: 0.0156, Val AUC: 0.9426

Epoch 9/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0070, Train AUC: 0.9946
Val Loss: 0.0159, Val AUC: 0.9415

Epoch 10/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0065, Train AUC: 0.9958
Val Loss: 0.0159, Val AUC: 0.9422

Best AUC for fold 2: 0.9440 at epoch 7

Training set: 22851 samples
Validation set: 5713 samples
Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples

Epoch 1/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0368, Train AUC: 0.5615
Val Loss: 0.0270, Val AUC: 0.7809
New best AUC: 0.7809 at epoch 1

Epoch 2/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0239, Train AUC: 0.7975
Val Loss: 0.0204, Val AUC: 0.8916
New best AUC: 0.8916 at epoch 2

Epoch 3/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0194, Train AUC: 0.8948
Val Loss: 0.0180, Val AUC: 0.9212
New best AUC: 0.9212 at epoch 3

Epoch 4/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0165, Train AUC: 0.9359
Val Loss: 0.0165, Val AUC: 0.9372
New best AUC: 0.9372 at epoch 4

Epoch 5/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0142, Train AUC: 0.9555
Val Loss: 0.0158, Val AUC: 0.9472
New best AUC: 0.9472 at epoch 5

Epoch 6/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0122, Train AUC: 0.9716
Val Loss: 0.0155, Val AUC: 0.9457

Epoch 7/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0104, Train AUC: 0.9839
Val Loss: 0.0155, Val AUC: 0.9489
New best AUC: 0.9489 at epoch 7

Epoch 8/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0088, Train AUC: 0.9899
Val Loss: 0.0155, Val AUC: 0.9495
New best AUC: 0.9495 at epoch 8

Epoch 9/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0077, Train AUC: 0.9931
Val Loss: 0.0157, Val AUC: 0.9496
New best AUC: 0.9496 at epoch 9

Epoch 10/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0071, Train AUC: 0.9945
Val Loss: 0.0156, Val AUC: 0.9493

Best AUC for fold 3: 0.9496 at epoch 9

Training set: 22852 samples
Validation set: 5712 samples
Found 22852 matching spectrograms for train dataset out of 22852 samples
Found 5712 matching spectrograms for valid dataset out of 5712 samples

Epoch 1/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0368, Train AUC: 0.5775
Val Loss: 0.0263, Val AUC: 0.7830
New best AUC: 0.7830 at epoch 1

Epoch 2/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0232, Train AUC: 0.8210
Val Loss: 0.0195, Val AUC: 0.8921
New best AUC: 0.8921 at epoch 2

Epoch 3/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0186, Train AUC: 0.8984
Val Loss: 0.0169, Val AUC: 0.9335
New best AUC: 0.9335 at epoch 3

Epoch 4/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0159, Train AUC: 0.9406
Val Loss: 0.0159, Val AUC: 0.9407
New best AUC: 0.9407 at epoch 4

Epoch 5/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0136, Train AUC: 0.9646
Val Loss: 0.0153, Val AUC: 0.9436
New best AUC: 0.9436 at epoch 5

Epoch 6/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0116, Train AUC: 0.9781
Val Loss: 0.0151, Val AUC: 0.9490
New best AUC: 0.9490 at epoch 6

Epoch 7/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0097, Train AUC: 0.9874
Val Loss: 0.0152, Val AUC: 0.9465

Epoch 8/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0081, Train AUC: 0.9921
Val Loss: 0.0153, Val AUC: 0.9481

Epoch 9/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0070, Train AUC: 0.9950
Val Loss: 0.0155, Val AUC: 0.9472

Epoch 10/10


Training:   0%|          | 0/714 [00:00<?, ?it/s]

Validation:   0%|          | 0/179 [00:00<?, ?it/s]

Train Loss: 0.0065, Train AUC: 0.9960
Val Loss: 0.0155, Val AUC: 0.9469

Best AUC for fold 4: 0.9490 at epoch 6

Cross-Validation Results:
Fold 0: 0.9446
Fold 1: 0.9499
Fold 2: 0.9440
Fold 3: 0.9496
Fold 4: 0.9490
Mean AUC: 0.9474

Training complete!
