In [2]:
import os
import gc
import warnings
import logging
import time
import math
import cv2
from pathlib import Path
import numpy as np
import pandas as pd
import librosa
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm.auto import tqdm
from glob import glob
import random
import itertools
from typing import Union
import concurrent.futures

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

# ==========================================
# 1. Configuration
# ==========================================
class CFG:
    seed = 42
    num_workers = 2
    
    # 路径设置
    train_datadir = '/kaggle/input/birdclef-2025/train_audio'
    train_csv = '/kaggle/input/birdclef-2025/train.csv'
    test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'
    submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'
    taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'
    
    # [关键] 修改这里为你上传的模型路径
    # 假设你把 best_model.pth 上传到了一个叫 birdclef-my-model 的数据集里
    model_files = [
        '/kaggle/input/sed-baseline/pytorch/default/1/best_model.pth' 
    ]
 
    # 模型参数 (必须与训练时一致)
    model_name = 'efficientnet_b0'  
    pretrained = False
    in_channels = 1
    
    # 音频参数 (必须与训练时一致)
    SR = 32000
    target_duration = 5 # 5秒切片
    
    # MelSpectrogram 参数
    n_fft = 1024
    hop_length = 512
    n_mels = 128
    f_min = 50
    f_max = 14000
    target_shape = (256, 256)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

cfg = CFG()

print(f"Using device: {cfg.device}")
print(f"Loading taxonomy data...")
# 如果本地没有 taxonomy (比如在测试只有 sample_submission 的时候)，做个兼容
if os.path.exists(cfg.taxonomy_csv):
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    species_ids = taxonomy_df['primary_label'].tolist()
else:
    # Fallback: read from sample_submission
    ss = pd.read_csv(cfg.submission_csv)
    species_ids = [c for c in ss.columns if c != 'row_id']

num_classes = len(species_ids)
print(f"Number of classes: {num_classes}")

# ==========================================
# 2. Utilities
# ==========================================
def set_seed(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(cfg.seed)

# ==========================================
# 3. Model Definition (Matching your training)
# ==========================================
class BirdCLEFModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        # 骨干网络
        self.backbone = timm.create_model(
            cfg.model_name,
            pretrained=False, # 推理时无法联网，设为 False
            in_chans=cfg.in_channels
        )

        # 替换分类头
        if 'efficientnet' in cfg.model_name:
            backbone_out = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        else:
            backbone_out = self.backbone.num_features
            self.backbone.reset_classifier(0, '')
            
        self.pooling = nn.AdaptiveAvgPool2d(1)
        
        # 这里的 num_classes 需要从全局变量或 cfg 获取
        global num_classes
        self.classifier = nn.Linear(backbone_out, num_classes)

    def forward(self, x):
        # x: (batch, 1, freq, time)
        x = self.backbone(x)
        
        # 处理 timm 输出可能的 dict
        if isinstance(x, dict):
            x = x['features']
            
        # Global Average Pooling
        if len(x.shape) == 4:
            x = self.pooling(x)
            x = x.view(x.size(0), -1)
            
        logits = self.classifier(x)
        return logits

# ==========================================
# 4. Feature Extraction (Librosa)
# ==========================================
class LogMelFeatureExtractor:
    def __init__(self, cfg):
        self.cfg = cfg
    
    def __call__(self, audio_data):
        # audio_data: numpy array (samples,)
        mel_spec = librosa.feature.melspectrogram(
            y=audio_data, 
            sr=self.cfg.SR, 
            n_fft=self.cfg.n_fft, 
            hop_length=self.cfg.hop_length, 
            n_mels=self.cfg.n_mels, 
            fmin=self.cfg.f_min, 
            fmax=self.cfg.f_max, 
            power=2.0
        )
        # Log Scale
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        # Min-Max Normalization to [0, 1]
        mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)
        
        # Resize to fixed shape (256, 256)
        if mel_spec_norm.shape != self.cfg.target_shape:
            mel_spec_norm = cv2.resize(mel_spec_norm, self.cfg.target_shape, interpolation=cv2.INTER_LINEAR)
            
        return mel_spec_norm.astype(np.float32)

# ==========================================
# 5. Audio Loading & Slicing
# ==========================================
def load_and_slice_audio(path, cfg):
    """
    读取音频并切分为 5秒的片段
    """
    try:
        audio, _ = librosa.load(path, sr=cfg.SR)
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return [], []

    # 计算需要的切片数
    chunk_len = int(cfg.target_duration * cfg.SR)
    # math.ceil 向上取整
    num_chunks = math.ceil(len(audio) / chunk_len)
    
    # 补齐长度
    target_len = num_chunks * chunk_len
    if len(audio) < target_len:
        audio = np.pad(audio, (0, target_len - len(audio)))
        
    segments = []
    end_seconds = []
    
    for i in range(num_chunks):
        seg = audio[i*chunk_len : (i+1)*chunk_len]
        segments.append(seg)
        end_seconds.append((i+1) * cfg.target_duration)
        
    return segments, end_seconds

# ==========================================
# 6. Model Loading
# ==========================================
def load_models(cfg):
    models = []
    model_files = cfg.model_files
    
    if not model_files:
        print(f"Warning: No model files found!")
        return models
    
    print(f"Found a total of {len(model_files)} model files.")
    
    for model_path in model_files:
        if not os.path.exists(model_path):
            print(f"Path does not exist: {model_path}")
            continue
            
        try:
            print(f"Loading model: {model_path}")
            # [关键] weights_only=False 避免 CFG 类报错
            checkpoint = torch.load(model_path, map_location=torch.device(cfg.device), weights_only=False)
            
            # 初始化模型
            model = BirdCLEFModel(cfg)
            
            # 加载权重
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            
            model.load_state_dict(state_dict)
            model = model.to(cfg.device)
            model.eval()
            
            # 半精度加速 (可选，如果 GPU 支持)
            # model.half() 
            
            models.append(model)
        except Exception as e:
            print(f"Error loading model {model_path}: {e}")
            import traceback
            traceback.print_exc()
    
    return models

# ==========================================
# 7. Inference Function (Per File)
# ==========================================
def predict_on_file(audio_path, models, cfg, feature_extractor):
    """
    处理单个音频文件
    """
    audio_path = str(audio_path)
    row_ids = []
    predictions = []
    soundscape_id = Path(audio_path).stem.split('.')[0] # 只要文件名部分

    # 1. 加载并切片音频
    segments, seconds = load_and_slice_audio(audio_path, cfg)
    
    if len(segments) == 0:
        return [], []

    # 2. 预处理 (Feature Extraction)
    # 批量处理以加速 (但在 CPU 上可能还是串行比较稳妥)
    batch_imgs = []
    for seg in segments:
        spec = feature_extractor(seg) # (256, 256)
        batch_imgs.append(spec)
    
    # 转为 Tensor: (Batch, 1, 256, 256)
    batch_tensor = torch.tensor(np.array(batch_imgs), dtype=torch.float32).unsqueeze(1).to(cfg.device)
    
    # 3. 推理
    file_preds = []
    
    # 这里的 models 是一个列表（哪怕只有一个模型），方便以后做 Ensemble
    for model in models:
        with torch.no_grad():
            # (Batch, Num_Classes)
            logits = model(batch_tensor)
            probs = torch.sigmoid(logits).cpu().numpy()
            file_preds.append(probs)
    
    # 如果有多个模型，取平均
    final_preds = np.mean(file_preds, axis=0) # (Batch, Num_Classes)
    
    # 4. 生成 Row IDs
    for i, sec in enumerate(seconds):
        row_id = f"{soundscape_id}_{int(sec)}"
        row_ids.append(row_id)
        predictions.append(final_preds[i])
        
    return row_ids, predictions

# ==========================================
# 8. Main Inference Loop
# ==========================================
def run_inference(cfg, models):
    # 1. 查找测试文件
    test_files = sorted(list(Path(cfg.test_soundscapes).glob('*.ogg')))
    
    # Commit 阶段如果没有文件，用训练集的前几个充数，或者直接返回
    if len(test_files) == 0:
        print("No test files found (Commit Stage). Using dummy files if available or skipping.")
        # 这里为了防止空报错，通常会生成 dummy submission
        return [], []
    
    print(f"Found {len(test_files)} test soundscapes")

    all_row_ids = []
    all_predictions = []
    
    feature_extractor = LogMelFeatureExtractor(cfg)

    # 2. 并行处理
    # max_workers 根据 CPU 核心数调整，音频处理是 IO 密集型 + CPU 密集型
    with concurrent.futures.ThreadPoolExecutor(max_workers=cfg.num_workers) as executor:
        results = list(
            tqdm(
                executor.map(
                    predict_on_file,
                    test_files,
                    itertools.repeat(models),
                    itertools.repeat(cfg),
                    itertools.repeat(feature_extractor)
                ),
                total=len(test_files),
                desc="Inferencing"
            )
        )

    # 3. 汇总结果
    for rids, preds in results:
        all_row_ids.extend(rids)
        all_predictions.extend(preds)
    
    return all_row_ids, all_predictions

# ==========================================
# 9. Submission & Smoothing
# ==========================================
def create_submission(row_ids, predictions, species_ids, cfg):
    print("Creating submission dataframe...")
    
    if len(row_ids) == 0:
        # Dummy submission creation
        print("Generating dummy submission structure.")
        sample_sub = pd.read_csv(cfg.submission_csv)
        sample_sub.to_csv("submission.csv", index=False)
        return

    # 构建 DataFrame
    # predictions 是一个 list of arrays
    preds_np = np.array(predictions)
    
    submission_df = pd.DataFrame(preds_np, columns=species_ids)
    submission_df.insert(0, 'row_id', row_ids)

    # 对齐列名 (以防万一)
    if os.path.exists(cfg.submission_csv):
        sample_sub = pd.read_csv(cfg.submission_csv)
        # 补全缺失列
        missing_cols = set(sample_sub.columns) - set(submission_df.columns)
        for col in missing_cols:
            submission_df[col] = 0.0
        # 排序
        submission_df = submission_df[sample_sub.columns]
    
    return submission_df

def smooth_submission(submission_path):
    """
    后处理平滑：如果前一秒和后一秒都有鸟叫，中间大概率也是。
    (0.8 * curr) + (0.2 * neighbor)
    """
    print("Smoothing submission predictions...")
    if not os.path.exists(submission_path): return

    sub = pd.read_csv(submission_path)
    if len(sub) == 0: return

    cols = sub.columns[1:]
    # 提取 soundscape 文件名作为 group key
    groups = sub['row_id'].str.rsplit('_', n=1).str[0].values
    unique_groups = np.unique(groups)
    
    for group in unique_groups:
        idx = np.where(groups == group)[0]
        # 如果该文件只有一个切片，无法平滑
        if len(idx) <= 1: continue

        sub_group = sub.iloc[idx].copy()
        predictions = sub_group[cols].values.astype(float)
        new_predictions = predictions.copy()
        
        # 首尾特殊处理
        new_predictions[0] = (predictions[0] * 0.8) + (predictions[1] * 0.2)
        new_predictions[-1] = (predictions[-1] * 0.8) + (predictions[-2] * 0.2)
        
        # 中间部分处理
        # curr * 0.6 + prev * 0.2 + next * 0.2
        for i in range(1, predictions.shape[0]-1):
            new_predictions[i] = (predictions[i-1] * 0.2) + (predictions[i] * 0.6) + (predictions[i+1] * 0.2)
        
        # 回写
        sub.iloc[idx, 1:] = new_predictions
    
    sub.to_csv(submission_path, index=False)
    print(f"Smoothed submission saved to {submission_path}")

# ==========================================
# 10. Main Execution
# ==========================================
def main():
    start_time = time.time()
    print("Starting BirdCLEF-2025 inference...")

    # 1. 加载模型
    models = load_models(cfg)
    
    if not models:
        print("No models found! Please check model paths.")
        # 生成一个空的以防止报错
        create_submission([], [], species_ids, cfg)
        return
    
    print(f"Model usage: {'Single model' if len(models) == 1 else f'Ensemble of {len(models)} models'}")

    # 2. 运行推理
    row_ids, predictions = run_inference(cfg, models)

    # 3. 生成 CSV
    submission_df = create_submission(row_ids, predictions, species_ids, cfg)
    
    if submission_df is not None:
        submission_path = 'submission.csv'
        submission_df.to_csv(submission_path, index=False)
        print(f"Submission saved to {submission_path}")

        # 4. 平滑结果
        smooth_submission(submission_path)
    
    end_time = time.time()
    print(f"Inference completed in {(end_time - start_time)/60:.2f} minutes")
    
    # 打印前几行看看
    if os.path.exists("submission.csv"):
        print(pd.read_csv("submission.csv").head())

if __name__ == "__main__":
    main()

Using device: cpu
Loading taxonomy data...
Number of classes: 206
Starting BirdCLEF-2025 inference...
Found a total of 1 model files.
Loading model: /kaggle/input/bird2025-sed-ckpt/sedmodel.pth
Error loading model /kaggle/input/bird2025-sed-ckpt/sedmodel.pth: Error(s) in loading state_dict for BirdCLEFModel:
	Missing key(s) in state_dict: "backbone.conv_stem.weight", "backbone.blocks.0.0.conv_dw.weight", "backbone.blocks.0.0.bn1.weight", "backbone.blocks.0.0.bn1.bias", "backbone.blocks.0.0.bn1.running_mean", "backbone.blocks.0.0.bn1.running_var", "backbone.blocks.0.0.se.conv_reduce.weight", "backbone.blocks.0.0.se.conv_reduce.bias", "backbone.blocks.0.0.se.conv_expand.weight", "backbone.blocks.0.0.se.conv_expand.bias", "backbone.blocks.0.0.conv_pw.weight", "backbone.blocks.0.0.bn2.weight", "backbone.blocks.0.0.bn2.bias", "backbone.blocks.0.0.bn2.running_mean", "backbone.blocks.0.0.bn2.running_var", "backbone.blocks.1.0.conv_pw.weight", "backbone.blocks.1.0.bn1.weight", "backbone.blocks

Traceback (most recent call last):
  File "/tmp/ipykernel_48/2018984968.py", line 237, in load_models
    model.load_state_dict(state_dict)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for BirdCLEFModel:
	Missing key(s) in state_dict: "backbone.conv_stem.weight", "backbone.blocks.0.0.conv_dw.weight", "backbone.blocks.0.0.bn1.weight", "backbone.blocks.0.0.bn1.bias", "backbone.blocks.0.0.bn1.running_mean", "backbone.blocks.0.0.bn1.running_var", "backbone.blocks.0.0.se.conv_reduce.weight", "backbone.blocks.0.0.se.conv_reduce.bias", "backbone.blocks.0.0.se.conv_expand.weight", "backbone.blocks.0.0.se.conv_expand.bias", "backbone.blocks.0.0.conv_pw.weight", "backbone.blocks.0.0.bn2.weight", "backbone.blocks.0.0.bn2.bias", "backbone.blocks.0.0.bn2.running_mean", "backbone.blocks.0.0.bn2.running_var", "backbone.blocks.1.0.conv_pw.weight", "backbone.blocks.1.0.bn1.