In [8]:
import os
import time
import logging
from pathlib import Path
from dataclasses import dataclass
import ast # For safely evaluating string representations of lists

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split # Added train_test_split
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR, OneCycleLR
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import timm
import torchvision.transforms as T

In [9]:
@dataclass
class Config:
    """Paths and hyper-parameters for the spectrogram-classifier training run.

    Every field has a default. Field order defines the positional signature of the
    generated __init__, so new fields are best appended rather than inserted.
    """
    # --- Data locations ---
    train_csv: str = '/kaggle/input/birdclef-2025/train.csv'  # per-recording metadata; read in the main cell and split into train/val
    taxonomy_csv: str = '/kaggle/input/birdclef-2025/taxonomy.csv'  # its `primary_label` column defines the class list and class-index order
    spectrogram_npy: str = '/kaggle/input/falcon-birdclef-cnn-preprocessed-dataset/falcon_birdclef_cnn_preprocessed_dataset.npy'  # pickled dict {sample_key: array}, loaded via np.load(..., allow_pickle=True).item()
    train_datadir: str = '/kaggle/input/birdclef-2025/train_audio'  # NOTE(review): not referenced by any code visible in this notebook
    LOAD_DATA: bool = True  # False skips loading spectrogram_npy (the spectrogram dict stays empty)
    
    val_split_ratio: float = 0.2 # Fraction of rows held out for validation (passed to train_test_split as test_size)
    
    # --- Run control ---
    seed: int = 42  # seeds set_seed(), the train/val split and the debug subsample
    debug: bool = False  # True: epochs forced to 2 and train.csv subsampled to at most 1000 rows
    batch_size: int = 32  # training batch size; the validation loader uses twice this
    num_workers: int = 0  # DataLoader worker processes (0 = load in the main process)
    epochs: int = 10  # number of training epochs (overridden to 2 in debug mode)
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'  # evaluated once, when this class body executes
    
    # --- Model ---
    model_name: str = 'vit_base_patch16_224'  # timm model identifier passed to timm.create_model
    img_size: int = 224  # spectrograms are resized to img_size x img_size; also forwarded to timm for some architectures
    pretrained: bool = True  # ask timm for pretrained weights
    in_channels: int = 1  # timm `in_chans`; also the channel count of the all-zero placeholder for missing spectrograms
    
    # --- Optimization ---
    optimizer: str = 'AdamW'  # one of 'AdamW', 'Adam', 'SGD'
    lr: float = 1e-4  # base learning rate; also max_lr when scheduler == 'OneCycleLR'
    scheduler: str = 'CosineAnnealingLR'  # 'CosineAnnealingLR', 'ReduceLROnPlateau', 'StepLR', 'OneCycleLR' or 'none'
    T_max: int = 10  # CosineAnnealingLR period; overwritten with `epochs` right after the config instance is created
    min_lr: float = 1e-6  # eta_min for CosineAnnealingLR / min_lr for ReduceLROnPlateau
    weight_decay: float = 1e-5  # passed to every optimizer

    criterion: str = 'BCEWithLogitsLoss'  # the only loss make_loss() supports

cfg = Config()
if cfg.debug:
    # Short smoke-test run; everything downstream reads cfg.epochs.
    cfg.epochs = 2
    # For a very small debug subset a larger hold-out can help, e.g.:
    # cfg.val_split_ratio = 0.5
# The training loop steps non-OneCycle schedulers once per epoch, so the cosine
# period must follow the (possibly debug-adjusted) epoch count.
cfg.T_max = cfg.epochs
# basicConfig() is a silent no-op when the root logger already has handlers, which is
# common inside notebook kernels (this notebook's output shows none of its INFO logs).
# force=True (Python 3.8+) replaces those handlers so the messages below are shown.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [10]:
def set_seed(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's global RNG, and torch on the CPU and on all
    CUDA devices. It also exports PYTHONHASHSEED and puts cuDNN into deterministic,
    non-benchmarking mode.

    Note: setting PYTHONHASHSEED at runtime does not change hashing in the current
    interpreter. It only affects processes started afterwards.
    """
    import random

    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(cfg.seed)

# Build the module-level `spectrograms` dict: sample_key -> precomputed spectrogram.
# It stays empty when loading is disabled or the .npy file is absent.
if not cfg.LOAD_DATA:
    logging.info("LOAD_DATA is False. Skipping spectrogram loading.")
    spectrograms = {}
else:
    logging.info("Loading spectrograms...")
    spec_path = Path(cfg.spectrogram_npy)
    if not spec_path.exists():
        logging.error(f"Spectrogram file not found: {spec_path}")
        spectrograms = {}
    else:
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only acceptable
        # because this file is a trusted, self-produced Kaggle dataset.
        # .item() unwraps the 0-d object array that np.save stores for a dict.
        spectrograms = np.load(spec_path, allow_pickle=True).item()
        logging.info(f"Loaded {len(spectrograms)} spectrograms from {spec_path}")

In [11]:
class SpectrogramDataset(Dataset):
    """Serve (spectrogram, multi-hot target) pairs for rows of a BirdCLEF metadata frame.

    Each row is matched to a precomputed spectrogram through its `sample_key`, which is
    derived from `filename` when the column is absent. Rows without a spectrogram get an
    all-zero placeholder of shape (in_channels, img_size, img_size).

    Attributes:
        label_to_idx: primary_label -> class index, in taxonomy CSV order.
        num_classes: number of distinct taxonomy labels (length of every target vector).
        missing_specs_count: number of rows whose spectrogram is absent from `specs`,
            computed once at construction.
    """

    def __init__(self, df: pd.DataFrame, cfg_obj: Config, specs: dict, mode: str = 'train'):
        """
        Args:
            df: metadata rows. Needs `primary_label`, plus `filename` unless `sample_key`
                already exists. `secondary_labels` (a stringified Python list) is optional.
            cfg_obj: provides taxonomy_csv, img_size and in_channels.
            specs: mapping sample_key -> spectrogram array (2-D, or presumably channel-first 3-D).
            mode: stored as-is; this class does not use it.

        Raises:
            FileNotFoundError: if cfg_obj.taxonomy_csv does not exist.
        """
        self.df = df.copy()
        self.specs = specs
        self.cfg = cfg_obj
        self.mode = mode

        if 'sample_key' not in self.df.columns:
            # NOTE(review): only a '.wav' suffix is stripped; other extensions (e.g. '.ogg')
            # remain in the key. Confirm this matches how the spectrogram dict was keyed.
            self.df['sample_key'] = (
                self.df.filename
                .str.replace('/', '_', regex=False)
                .str.replace('.wav', '', regex=False)
            )

        # Count missing spectrograms eagerly so callers can inspect the number right
        # after construction. A counter bumped only in __getitem__ has three problems:
        # it reads 0 until data is iterated, it keeps growing every epoch, and it never
        # sees increments made inside DataLoader worker processes.
        self.missing_specs_count = sum(
            self.specs.get(key) is None for key in self.df['sample_key']
        )
        self._missing_lookups = 0  # throttles the per-item warning in __getitem__
        self._resize = None  # resize transform, created on first use and then reused

        try:
            taxonomy = pd.read_csv(self.cfg.taxonomy_csv)
        except FileNotFoundError:
            logging.error(f"Taxonomy CSV not found at {self.cfg.taxonomy_csv}.")
            self.label_to_idx = {}
            self.num_classes = 0
            raise

        labels = taxonomy['primary_label'].unique().tolist()
        self.label_to_idx = {lbl: idx for idx, lbl in enumerate(labels)}
        self.num_classes = len(labels)
        if self.num_classes == 0:
            logging.warning("No classes found in taxonomy.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return (spec, target) for row `idx`.

        spec: float32 tensor (C, img_size, img_size); a 2-D array gains a leading channel axis.
        target: float32 multi-hot vector of length num_classes, with the primary label and
            any parseable secondary labels known to the taxonomy set to 1.
        """
        row = self.df.iloc[idx]
        key = row['sample_key']
        spec_data = self.specs.get(key)

        if spec_data is None:
            spec_data = np.zeros((self.cfg.in_channels, self.cfg.img_size, self.cfg.img_size), dtype=np.float32)
            self._missing_lookups += 1
            # Log the first few misses, then only every 100th, to avoid flooding the output.
            if self._missing_lookups < 10 or self._missing_lookups % 100 == 0:
                logging.warning(f"Spectrogram for key '{key}' not found. Using zero tensor. Total missing: {self.missing_specs_count}")
        elif spec_data.ndim == 2:
            spec_data = np.expand_dims(spec_data, axis=0)

        spec = torch.tensor(spec_data, dtype=torch.float32)

        if spec.shape[1:] != (self.cfg.img_size, self.cfg.img_size):
            if self._resize is None:
                self._resize = T.Resize([self.cfg.img_size, self.cfg.img_size], antialias=True)
            spec = self._resize(spec)

        target = np.zeros(self.num_classes, dtype=np.float32)
        primary_label = row['primary_label']
        if primary_label in self.label_to_idx:
            target[self.label_to_idx[primary_label]] = 1.0

        # secondary_labels is stored as the string form of a Python list, e.g. "['a', 'b']".
        sec_labels_str = row.get('secondary_labels', '[]')
        if isinstance(sec_labels_str, str) and sec_labels_str and sec_labels_str != '[]':
            try:
                secondary_labels_list = ast.literal_eval(sec_labels_str)
                for s_label in secondary_labels_list:
                    if s_label in self.label_to_idx:
                        target[self.label_to_idx[s_label]] = 1.0
            except (ValueError, SyntaxError) as e:
                logging.error(f"Error parsing secondary_labels '{sec_labels_str}': {e}")

        return spec, torch.tensor(target, dtype=torch.float32)

def collate_specs(batch):
    """Turn a list of (spectrogram, target) pairs into two batched tensors.

    Each output gains a new leading batch dimension via torch.stack, so all
    spectrograms (and all targets) in the batch must share one shape.
    """
    spec_seq, target_seq = zip(*batch)
    return torch.stack(spec_seq), torch.stack(target_seq)

In [12]:
class CLEFClassifier(nn.Module):
    """A timm backbone with its classifier removed, followed by a linear multi-label head.

    forward() returns raw logits (no sigmoid), matching BCEWithLogitsLoss.
    """

    # Substrings of timm model names whose constructors accept `img_size` (their
    # patch/position embeddings depend on input resolution). ConvNeXt is deliberately
    # excluded: timm's ConvNeXt constructor has no `img_size` parameter, so forwarding it
    # fails. Being fully convolutional, it does not need the value anyway.
    _IMG_SIZE_ARCH_KEYWORDS = ('vit', 'swin')

    def __init__(self, cfg_obj: Config, num_classes: int):
        """
        Args:
            cfg_obj: provides model_name, pretrained, in_channels and img_size.
            num_classes: width of the output logit vector.
        """
        super().__init__()
        self.cfg = cfg_obj
        self.num_classes = num_classes

        model_kwargs = {
            'pretrained': self.cfg.pretrained,
            'in_chans': self.cfg.in_channels,
            'num_classes': 0,  # strip timm's classifier; self.head below replaces it
        }
        arch = self.cfg.model_name.lower()
        if any(keyword in arch for keyword in self._IMG_SIZE_ARCH_KEYWORDS):
            model_kwargs['img_size'] = self.cfg.img_size

        self.encoder = timm.create_model(self.cfg.model_name, **model_kwargs)
        # Only used if the encoder returns an un-pooled feature map (batch, C, H, W).
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(self.encoder.num_features, self.num_classes)

    def forward(self, x):
        """Map an input batch to (batch, num_classes) logits."""
        feats = self.encoder(x)
        if feats.dim() == 4:
            # flatten() also copes with non-contiguous pooled outputs, unlike view().
            feats = self.pool(feats).flatten(1)
        return self.head(feats)

In [13]:
def make_optimizer(model, cfg_obj: Config):
    """Create the optimizer named by ``cfg_obj.optimizer`` over all model parameters.

    Supported names are 'AdamW', 'Adam' and 'SGD' (SGD uses momentum 0.9). Each one
    receives ``cfg_obj.lr`` and ``cfg_obj.weight_decay``.

    Raises:
        ValueError: for any other optimizer name.
    """
    wanted = cfg_obj.optimizer
    builders = (
        ('AdamW', lambda params: AdamW(params, lr=cfg_obj.lr, weight_decay=cfg_obj.weight_decay)),
        ('Adam', lambda params: Adam(params, lr=cfg_obj.lr, weight_decay=cfg_obj.weight_decay)),
        ('SGD', lambda params: SGD(params, lr=cfg_obj.lr, weight_decay=cfg_obj.weight_decay, momentum=0.9)),
    )
    for opt_name, build in builders:
        if wanted == opt_name:
            return build(model.parameters())
    raise ValueError(f'Unknown optimizer: {cfg_obj.optimizer}')

def make_scheduler(opt, cfg_obj: Config, steps_per_epoch_for_onecycle: int = 0):
    """Build the learning-rate scheduler named by ``cfg_obj.scheduler``.

    Args:
        opt: the optimizer to schedule.
        cfg_obj: provides scheduler, T_max, min_lr, epochs and lr.
        steps_per_epoch_for_onecycle: batches per epoch; required (> 0) for OneCycleLR.

    Returns:
        The scheduler, or None when scheduler is None or 'none' (any case).

    Raises:
        ValueError: for an unknown scheduler name, or OneCycleLR without a
            positive steps_per_epoch_for_onecycle.
    """
    name = cfg_obj.scheduler
    if name == 'CosineAnnealingLR':
        return CosineAnnealingLR(opt, T_max=cfg_obj.T_max, eta_min=cfg_obj.min_lr)
    if name == 'ReduceLROnPlateau':
        # mode='max': the training loop steps this scheduler with validation AUC.
        # `verbose` was dropped because PyTorch deprecated it (2.2+) and it only warns.
        return ReduceLROnPlateau(opt, mode='max', factor=0.1, patience=2, min_lr=cfg_obj.min_lr)
    if name == 'StepLR':
        # epochs // 3 is 0 for epochs < 3 (e.g. the 2-epoch debug run), and StepLR
        # takes last_epoch % step_size on every step(). Clamp to at least 1.
        return StepLR(opt, step_size=max(1, cfg_obj.epochs // 3), gamma=0.1)
    if name == 'OneCycleLR':
        if steps_per_epoch_for_onecycle <= 0:
            raise ValueError("steps_per_epoch_for_onecycle must be provided for OneCycleLR")
        return OneCycleLR(opt, max_lr=cfg_obj.lr,
                          steps_per_epoch=steps_per_epoch_for_onecycle,
                          epochs=cfg_obj.epochs,
                          pct_start=0.3, div_factor=25, final_div_factor=1e4)
    if name is None or name.lower() == 'none':
        return None
    raise ValueError(f'Unknown scheduler: {cfg_obj.scheduler}')

def make_loss(cfg_obj: Config):
    """Instantiate the training criterion named by ``cfg_obj.criterion``.

    Only 'BCEWithLogitsLoss' is supported: it applies an independent sigmoid per
    class, which suits the multi-hot targets built by the dataset.

    Raises:
        ValueError: for any other criterion name.
    """
    if cfg_obj.criterion != 'BCEWithLogitsLoss':
        raise ValueError(f'Unknown criterion: {cfg_obj.criterion}')
    return nn.BCEWithLogitsLoss()

def epoch_step(model, loader, opt, loss_fn, device, scheduler=None, train=True, cfg_obj: Config = None):
    model.train() if train else model.eval()
    epoch_losses, all_targets, all_predictions = [], [], []

    progress_bar_desc = 'Train' if train else 'Valid'
    if loader is None or len(loader) == 0: # Handle cases with no validation loader
        if not train: # Only return if it's a validation step with no loader
            return 0.0, 0.0 # Return dummy values for loss and AUC
        # If it's a training step with no loader, something is wrong, but proceed cautiously
        # Or raise an error: raise ValueError("Training loader cannot be None or empty")

    loader_pbar = tqdm(loader, desc=progress_bar_desc, leave=False) if loader else []


    for batch in loader_pbar:
        specs, targets_batch = batch
        specs, targets_batch = specs.to(device), targets_batch.to(device)

        if train:
            opt.zero_grad()
            preds = model(specs)
            loss = loss_fn(preds, targets_batch)
            loss.backward()
            opt.step()
            if scheduler and isinstance(scheduler, OneCycleLR):
                scheduler.step()
        else:
            with torch.no_grad():
                preds = model(specs)
                loss = loss_fn(preds, targets_batch)

        epoch_losses.append(loss.item())
        all_targets.append(targets_batch.detach().cpu().numpy())
        all_predictions.append(torch.sigmoid(preds).detach().cpu().numpy())
        if hasattr(loader_pbar, 'set_postfix'): # Check if loader_pbar is a tqdm object
             loader_pbar.set_postfix(loss=loss.item())
    
    if not epoch_losses: # If no batches were processed (empty loader)
        return 0.0, 0.0


    all_targets_np = np.vstack(all_targets)
    all_predictions_np = np.vstack(all_predictions)
    avg_epoch_loss = np.mean(epoch_losses)
    class_auc_scores = []

    for i in range(all_targets_np.shape[1]):
        if np.sum(all_targets_np[:, i]) > 0 and len(np.unique(all_targets_np[:, i])) > 1:
            try:
                class_auc = roc_auc_score(all_targets_np[:, i], all_predictions_np[:, i])
                class_auc_scores.append(class_auc)
            except ValueError:
                class_auc_scores.append(np.nan)
        else:
            class_auc_scores.append(np.nan)
            
    mean_auc = np.nanmean(class_auc_scores) if len(class_auc_scores) > 0 else 0.0
    if np.isnan(mean_auc): mean_auc = 0.0
    return avg_epoch_loss, mean_auc

def train_final_model(df: pd.DataFrame, cfg_obj: Config, loaded_spectrograms: dict):
    logging.info("Starting training on a single train/validation split.")

    if 'primary_label' not in df.columns:
        logging.error("Column 'primary_label' not found for train/val split stratification.")
        return None
    if not Path(cfg_obj.taxonomy_csv).exists():
        logging.error(f"Taxonomy CSV {cfg_obj.taxonomy_csv} not found.")
        return None
    try:
        temp_taxonomy = pd.read_csv(cfg_obj.taxonomy_csv)
        num_classes = len(temp_taxonomy['primary_label'].unique())
        if num_classes == 0:
            logging.error("No classes in taxonomy.")
            return None
    except Exception as e:
        logging.error(f"Error reading taxonomy: {e}.")
        return None

    # Create train/validation split
    train_df, val_df = train_test_split(
        df,
        test_size=cfg_obj.val_split_ratio,
        random_state=cfg_obj.seed,
        stratify=df['primary_label'] if 'primary_label' in df else None # Stratify if possible
    )
    logging.info(f"Training data: {len(train_df)} samples, Validation data: {len(val_df)} samples")

    train_ds = SpectrogramDataset(train_df, cfg_obj, loaded_spectrograms, 'train')
    val_ds = SpectrogramDataset(val_df, cfg_obj, loaded_spectrograms, 'valid')

    if train_ds.missing_specs_count > 0:
        logging.warning(f"Train DS: {train_ds.missing_specs_count}/{len(train_ds)} missing specs.")
    if val_ds.missing_specs_count > 0:
        logging.warning(f"Valid DS: {val_ds.missing_specs_count}/{len(val_ds)} missing specs.")
    if train_ds.num_classes == 0:
        logging.error("Train Dataset has 0 classes. Aborting.")
        return None

    tr_loader = DataLoader(train_ds, batch_size=cfg_obj.batch_size, shuffle=True, num_workers=cfg_obj.num_workers, collate_fn=collate_specs, pin_memory=(cfg_obj.device.startswith('cuda')), persistent_workers=(cfg_obj.num_workers > 0))
    v_loader  = DataLoader(val_ds, batch_size=cfg_obj.batch_size * 2, shuffle=False, num_workers=cfg_obj.num_workers, collate_fn=collate_specs, pin_memory=(cfg_obj.device.startswith('cuda')), persistent_workers=(cfg_obj.num_workers > 0))

    model = CLEFClassifier(cfg_obj, num_classes=train_ds.num_classes).to(cfg_obj.device)
    
    if torch.cuda.device_count() > 1 and cfg_obj.device.startswith('cuda'):
        logging.info(f"Using {torch.cuda.device_count()} GPUs via nn.DataParallel.")
        model = nn.DataParallel(model)

    opt = make_optimizer(model, cfg_obj)
    sch = make_scheduler(opt, cfg_obj, steps_per_epoch_for_onecycle=len(tr_loader))
    loss_fn = make_loss(cfg_obj)
    
    best_val_auc = 0.0
    best_epoch = -1
    model_save_path = "final_model_best_auc.pt"
    
    epochs_pbar = tqdm(range(cfg_obj.epochs), desc="Training Epochs")

    for epoch in epochs_pbar:
        train_loss, train_auc = epoch_step(model, tr_loader, opt, loss_fn, cfg_obj.device, scheduler=(sch if cfg_obj.scheduler == 'OneCycleLR' else None), train=True, cfg_obj=cfg_obj)
        valid_loss, valid_auc = epoch_step(model, v_loader, None, loss_fn, cfg_obj.device, train=False, cfg_obj=cfg_obj)
        epochs_pbar.set_postfix(TrainLoss=f"{train_loss:.4f}", TrainAUC=f"{train_auc:.4f}", ValidLoss=f"{valid_loss:.4f}", ValidAUC=f"{valid_auc:.4f}")

        if sch and not isinstance(sch, OneCycleLR):
            if isinstance(sch, ReduceLROnPlateau): sch.step(valid_auc)
            else: sch.step()
        
        if valid_auc > best_val_auc:
            best_val_auc = valid_auc
            best_epoch = epoch + 1
            current_model_state_dict = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(current_model_state_dict, model_save_path)
            logging.info(f"  Epoch {epoch+1}: New best Valid AUC: {valid_auc:.4f}. Model saved to {model_save_path}")
    
    logging.info(f"Training complete. Best Valid AUC: {best_val_auc:.4f} at epoch {best_epoch}")
    
    best_model_info = {
        'model_path': model_save_path,
        'best_val_auc': best_val_auc,
        'best_epoch': best_epoch
    }
    
    del model, opt, sch, tr_loader, v_loader, train_ds, val_ds
    if cfg_obj.device == 'cuda': torch.cuda.empty_cache()
        
    return best_model_info

In [None]:
if __name__ == '__main__':
    import dataclasses  # only needed to derive a non-pretrained config for the export step

    start_time = time.time()
    logging.info(f"Using device: {cfg.device}")
    if cfg.device.startswith("cuda"):
        logging.info(f"CUDA available. Device count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            logging.info(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    try:
        main_df = pd.read_csv(cfg.train_csv)
        logging.info(f"Loaded train_csv: {len(main_df)} rows.")
    except FileNotFoundError:
        logging.error(f"Train CSV {cfg.train_csv} not found.")
        main_df = None

    # An empty spectrogram dict is fatal only when LOAD_DATA asked for spectrograms.
    # The previous guard treated an empty dict as falsy even with LOAD_DATA=False, and
    # no elif branch matched, so training was skipped without any log message. Its
    # exit() branch was also unreachable.
    loaded_specs = globals().get('spectrograms') or {}
    specs_ready = bool(loaded_specs) or not cfg.LOAD_DATA

    if main_df is None or main_df.empty:
        logging.error("Main DataFrame empty. Training aborted.")
    elif not specs_ready:
        logging.error("Spectrograms empty. Training aborted.")
    else:
        if not loaded_specs:
            logging.warning("LOAD_DATA is False: every sample will use an all-zero placeholder spectrogram.")

        if cfg.debug:
            logging.info("DEBUG mode: Using small subset.")
            debug_sample_size = min(1000, len(main_df))
            if len(main_df) > debug_sample_size:
                main_df = main_df.sample(n=debug_sample_size, random_state=cfg.seed, replace=False).reset_index(drop=True)
            # Check the subset that is actually trained on, not the full frame.
            if 'primary_label' in main_df.columns and (main_df['primary_label'].value_counts() < 2).any():
                logging.warning("DEBUG: Some classes have less than 2 samples. Stratification in train_test_split might be problematic or disabled.")
            logging.info(f"Debug DataFrame size: {len(main_df)}.")

        trained_model_info = train_final_model(main_df, cfg, loaded_specs)

        if trained_model_info and Path(trained_model_info['model_path']).exists():
            logging.info(f"\n--- Final Model Preparation ---")
            best_model_path = trained_model_info['model_path']
            logging.info(f"Preparing final model from: {best_model_path}")

            try:
                taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
                num_classes_final = len(taxonomy_df['primary_label'].unique())
            except Exception as e:
                logging.error(f"Cannot read taxonomy for final model: {e}")
                num_classes_final = 0

            if num_classes_final > 0:
                # The checkpoint overwrites every weight, so skip fetching pretrained ones.
                export_cfg = dataclasses.replace(cfg, pretrained=False)
                final_model_base = CLEFClassifier(export_cfg, num_classes=num_classes_final)
                try:
                    # The checkpoint is a plain state_dict of tensors written by this notebook.
                    state_dict_to_load = torch.load(best_model_path, map_location='cpu', weights_only=True)
                    final_model_base.load_state_dict(state_dict_to_load)

                    # Saved from CPU so the file loads without a GPU or map_location.
                    final_model_save_path = 'trained_model_final.pt'
                    torch.save(final_model_base.state_dict(), final_model_save_path)
                    logging.info(f"Successfully saved final trained model to {final_model_save_path} (from epoch {trained_model_info['best_epoch']}, Val AUC: {trained_model_info['best_val_auc']:.4f})")
                except Exception as e:
                    logging.error(f"Error loading/saving the final model: {e}")
            else:
                logging.error("Cannot instantiate final model (num_classes=0).")
        else:
            logging.warning("Training did not yield a saved model or model_info is missing.")

    end_time = time.time()
    logging.info(f"Total execution time: {(end_time - start_time)/60:.2f} minutes")

Training Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Train:   0%|          | 0/715 [00:00<?, ?it/s]

Valid:   0%|          | 0/90 [00:00<?, ?it/s]

Train:   0%|          | 0/715 [00:00<?, ?it/s]

Valid:   0%|          | 0/90 [00:00<?, ?it/s]

Train:   0%|          | 0/715 [00:00<?, ?it/s]