# **Imports**

In [None]:
import ast
import gc
import logging
import math
import os
import random
import time
import warnings
from pathlib import Path

import cv2
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Suppress every Python warning and only let ERROR-level log records through,
# which keeps the notebook output short (at the cost of hiding deprecation and
# numerical warnings).
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

# **Checkpoints**

In [None]:
def save_checkpoint(model, optimizer, epoch, loss, best_auc, path="checkpoint.pth"):
    """Write a resumable training snapshot to ``path`` with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch index recorded in the snapshot.
        loss: Loss value recorded in the snapshot.
        best_auc: Best AUC seen so far, recorded in the snapshot.
        path: Destination file; defaults to ``"checkpoint.pth"``.
    """
    snapshot = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
        best_auc=best_auc,
    )
    torch.save(snapshot, path)

    
def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from a snapshot file.

    Args:
        model: Module that receives ``model_state_dict`` (loaded in place).
        optimizer: Optimizer that receives ``optimizer_state_dict`` (in place).
        path: Snapshot file to read.

    Returns:
        ``(model, optimizer, start_epoch, loss, best_auc)`` where
        ``start_epoch`` is the stored epoch plus one and ``best_auc`` falls
        back to ``0.0`` when the snapshot has no such key.
    """
    # weights_only=False unpickles arbitrary objects: only load trusted files.
    state = torch.load(path, weights_only=False)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    resume_epoch = state['epoch'] + 1
    return model, optimizer, resume_epoch, state['loss'], state.get('best_auc', 0.0)

# **Configs**

In [None]:
class CFG:
    """Central configuration for the BiLSTM training notebook.

    All settings are class attributes; the module-level ``cfg`` instance
    created right after the class is what the rest of the notebook reads
    and mutates.
    """

    # --- Run control -------------------------------------------------------
    seed = 42
    debug = False  # True shrinks the run via update_debug_settings()
    apex = False
    print_freq = 100
    num_workers = 2  # DataLoader worker processes

    # --- Paths -------------------------------------------------------------
    OUTPUT_DIR = '/home/huynhw/koa_scratch/kaggle635/kaggleoutput'

    # NOTE(review): this is a Kaggle path while every other path is under
    # /home/huynhw; it is only used when spectrograms are generated from audio.
    train_datadir = '/kaggle/input/birdclef-2025/train_audio'
    train_csv = '/home/huynhw/koa_scratch/kaggle635/train.csv'
    test_soundscapes = '/home/huynhw/koa_scratch/kaggle635/test_soundscapes'
    submission_csv = '/home/huynhw/koa_scratch/kaggle635/sample_submission.csv'
    taxonomy_csv = '/home/huynhw/koa_scratch/kaggle635/taxonomy.csv'

    spectrogram_npy = '/home/huynhw/koa_scratch/kaggle635/birdclef2025_melspec_5sec_256_256.npy'

    # --- Model -------------------------------------------------------------
    model_name = 'BiLSTM'
    pretrained = True
    in_channels = 1

    # --- Checkpoint / resume -----------------------------------------------
    resume = True
    checkpoint_path = "checkpoint.pth"

    # --- Audio / spectrogram -----------------------------------------------
    LOAD_DATA = True  # True: use the pre-computed spectrogram file
    FS = 32000
    TARGET_DURATION = 5.0
    TARGET_SHAPE = (256, 256)

    N_FFT = 1280
    HOP_LENGTH = 640
    N_MELS = 128
    FMIN = 50
    FMAX = 14000

    # --- Training ----------------------------------------------------------
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    batch_size = 16
    criterion = 'BCEWithLogitsLoss'

    n_fold = 5
    selected_folds = [0, 1, 2, 3, 4]

    optimizer = 'AdamW'
    lr = 5e-4
    weight_decay = 1e-5

    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = epochs  # evaluated once here; later changes to epochs do not propagate

    # --- Augmentation / model input ----------------------------------------
    aug_prob = 0.5
    mixup_alpha = 0.5
    input_dim = 256  # LSTM input feature size

    def update_debug_settings(self):
        """Shrink the run to 2 epochs on fold 0 when ``debug`` is set.

        ``T_max`` is updated together with ``epochs``; otherwise the cosine
        schedule would still be laid out for the original epoch count.
        """
        if self.debug:
            self.epochs = 2
            self.selected_folds = [0]
            self.T_max = self.epochs

cfg = CFG()

# **Utilities**

In [None]:
def set_seed(seed=42):
    """
    Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Also exports the seed as PYTHONHASHSEED in the process environment.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, when this cell runs, using the configured seed.
set_seed(cfg.seed)

# **Preprocessing**


In [None]:
def audio2melspec(audio_data, cfg):
    """Convert a waveform to a min-max normalised log-mel spectrogram.

    Args:
        audio_data: 1-D NumPy waveform; NaN samples are tolerated.
        cfg: Config providing ``FS``, ``N_FFT``, ``HOP_LENGTH``, ``N_MELS``,
            ``FMIN`` and ``FMAX``.

    Returns:
        The dB-scaled mel spectrogram rescaled to ``[0, 1]``.
    """
    nan_mask = np.isnan(audio_data)
    if nan_mask.any():
        # Replace NaNs with the mean of the valid samples. For an all-NaN
        # signal that mean is itself NaN (the NaNs would survive), so fall
        # back to silence instead.
        fill_value = 0.0 if nan_mask.all() else np.nanmean(audio_data)
        audio_data = np.nan_to_num(audio_data, nan=fill_value)

    mel_spec = librosa.feature.melspectrogram(
        y=audio_data,
        sr=cfg.FS,
        n_fft=cfg.N_FFT,
        hop_length=cfg.HOP_LENGTH,
        n_mels=cfg.N_MELS,
        fmin=cfg.FMIN,
        fmax=cfg.FMAX,
        power=2.0
    )

    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Min-max scale; the epsilon keeps a constant spectrogram from dividing by 0.
    db_min = mel_spec_db.min()
    db_max = mel_spec_db.max()
    mel_spec_norm = (mel_spec_db - db_min) / (db_max - db_min + 1e-8)

    return mel_spec_norm

def process_audio_file(audio_path, cfg):
    """Load one audio file and return its centre-crop mel spectrogram.

    The clip is resampled to ``cfg.FS``, repeated until it is at least
    ``cfg.TARGET_DURATION`` seconds long, cropped around its centre, converted
    with ``audio2melspec`` and resized to ``cfg.TARGET_SHAPE``.

    Args:
        audio_path: Path of the audio file.
        cfg: Config providing ``FS``, ``TARGET_DURATION``, ``TARGET_SHAPE`` and
            the mel parameters used by ``audio2melspec``.

    Returns:
        A ``float32`` array of shape ``cfg.TARGET_SHAPE``, or ``None`` when
        anything fails (the error is printed, not raised).
    """
    try:
        audio_data, _ = librosa.load(audio_path, sr=cfg.FS)

        target_samples = int(cfg.TARGET_DURATION * cfg.FS)

        if len(audio_data) == 0:
            # Would otherwise surface as an opaque division-by-zero below.
            raise ValueError("audio file contains no samples")

        if len(audio_data) < target_samples:
            # Loop short clips until they cover the target duration.
            n_copy = math.ceil(target_samples / len(audio_data))
            audio_data = np.concatenate([audio_data] * n_copy)

        # Extract center 5 seconds
        start_idx = max(0, int(len(audio_data) / 2 - target_samples / 2))
        end_idx = min(len(audio_data), start_idx + target_samples)
        center_audio = audio_data[start_idx:end_idx]

        # Defensive zero-padding in case the crop came out short.
        if len(center_audio) < target_samples:
            center_audio = np.pad(center_audio,
                                 (0, target_samples - len(center_audio)),
                                 mode='constant')

        mel_spec = audio2melspec(center_audio, cfg)

        target_rows, target_cols = cfg.TARGET_SHAPE
        if mel_spec.shape != (target_rows, target_cols):
            # cv2.resize expects dsize as (width, height) == (cols, rows);
            # passing TARGET_SHAPE straight through would transpose any
            # non-square target.
            mel_spec = cv2.resize(mel_spec, (target_cols, target_rows),
                                  interpolation=cv2.INTER_LINEAR)

        return mel_spec.astype(np.float32)

    except Exception as e:
        # Best effort by design: callers treat None as "could not process".
        print(f"Error processing {audio_path}: {e}")
        return None

def generate_spectrograms(df, cfg):
    """Build a ``{samplename: mel spectrogram}`` dict for every row of ``df``.

    Args:
        df: DataFrame with ``samplename`` and ``filepath`` columns.
        cfg: Config passed to ``process_audio_file``; when ``cfg.debug`` is
            set only the first 1000 rows are processed.

    Returns:
        Dict mapping sample names to spectrogram arrays. Files that could not
        be processed are omitted and counted in the printed summary.
    """
    print("Generating mel spectrograms from audio files...")
    start_time = time.time()

    all_bird_data = {}
    errors = []

    # Count rows by position: the DataFrame index need not be 0..n-1, so using
    # the index label for the debug cap could stop too early or never.
    for position, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
        if cfg.debug and position >= 1000:
            break

        # Resolved up front so the error path cannot fail on a missing column.
        filepath = row.get('filepath', '<missing filepath>')

        try:
            samplename = row['samplename']
            mel_spec = process_audio_file(row['filepath'], cfg)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
            errors.append((filepath, str(e)))
            continue

        if mel_spec is None:
            # process_audio_file reports its own failures by returning None;
            # record them so the failure count below is accurate.
            errors.append((filepath, "process_audio_file returned None"))
            continue

        all_bird_data[samplename] = mel_spec

    end_time = time.time()
    print(f"Processing completed in {end_time - start_time:.2f} seconds")
    print(f"Successfully processed {len(all_bird_data)} files out of {len(df)}")
    print(f"Failed to process {len(errors)} files")

    return all_bird_data

# **Dataset Prep and Data Augmentations**

In [None]:
class BirdCLEFDatasetFromNPY(Dataset):
    """Dataset of mel spectrograms with multi-hot species targets.

    Spectrograms come from the ``spectrograms`` dict (keyed by sample name)
    when available; otherwise, if ``cfg.LOAD_DATA`` is false, they are
    computed from the audio file on access. Missing spectrograms are replaced
    by zeros.

    Args:
        df: Metadata with ``filename`` and ``primary_label`` columns and,
            optionally, ``secondary_labels``, ``filepath`` and ``samplename``.
        cfg: Config providing ``taxonomy_csv``, ``train_datadir``, ``debug``,
            ``seed``, ``LOAD_DATA``, ``TARGET_SHAPE`` and ``aug_prob``.
        spectrograms: Optional ``{samplename: 2-D array}`` mapping.
        mode: ``"train"`` enables augmentation and missing-data warnings.
    """

    def __init__(self, df, cfg, spectrograms=None, mode="train"):
        # Private copy: the helper columns added below must not leak into (or
        # be silently lost on) the caller's DataFrame slice.
        self.df = df.copy()
        self.cfg = cfg
        self.mode = mode

        self.spectrograms = spectrograms

        taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)
        self.species_ids = taxonomy_df['primary_label'].tolist()
        self.num_classes = len(self.species_ids)
        self.label_to_idx = {label: idx for idx, label in enumerate(self.species_ids)}

        if 'filepath' not in self.df.columns:
            self.df['filepath'] = self.cfg.train_datadir + '/' + self.df.filename

        if 'samplename' not in self.df.columns:
            # "<first path component>-<file stem>", e.g. "abc/XC1.ogg" -> "abc-XC1".
            self.df['samplename'] = self.df.filename.map(
                lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0]
            )

        if self.spectrograms:
            sample_names = set(self.df['samplename'])
            found_samples = sum(1 for name in sample_names if name in self.spectrograms)
            print(f"Found {found_samples} matching spectrograms for {mode} dataset out of {len(self.df)} samples")

        if cfg.debug:
            self.df = self.df.sample(min(1000, len(self.df)), random_state=cfg.seed).reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'melspec', 'target', 'filename'}`` for row ``idx``."""
        row = self.df.iloc[idx]
        samplename = row['samplename']
        spec = None

        if self.spectrograms and samplename in self.spectrograms:
            spec = self.spectrograms[samplename]
        elif not self.cfg.LOAD_DATA:
            spec = process_audio_file(row['filepath'], self.cfg)

        if spec is None:
            spec = np.zeros(self.cfg.TARGET_SHAPE, dtype=np.float32)
            if self.mode == "train":  # Only print warning during training
                print(f"Warning: Spectrogram for {samplename} not found and could not be generated")

        # torch.tensor copies, so in-place augmentation never touches the cache.
        spec = torch.tensor(spec, dtype=torch.float32).unsqueeze(0)  # Add channel dimension

        if self.mode == "train" and random.random() < self.cfg.aug_prob:
            spec = self.apply_spec_augmentations(spec)

        target = self.encode_label(row['primary_label'])

        if 'secondary_labels' in row:
            for label in self._parse_secondary_labels(row['secondary_labels']):
                if label in self.label_to_idx:
                    target[self.label_to_idx[label]] = 1.0

        return {
            'melspec': spec,
            'target': torch.tensor(target, dtype=torch.float32),
            'filename': row['filename']
        }

    @staticmethod
    def _parse_secondary_labels(raw):
        """Normalise a ``secondary_labels`` cell to a list of labels.

        Missing values (``None``/NaN/empty string) yield ``[]``. Strings are
        parsed with ``ast.literal_eval`` rather than ``eval`` so CSV content
        can never execute code; non-literal text raises instead.
        """
        if raw is None:
            return []
        if isinstance(raw, float) and math.isnan(raw):
            # NaN cells are distinct objects from np.nan, so an `in [...]`
            # membership test does not catch them.
            return []
        if isinstance(raw, str):
            text = raw.strip()
            if not text:
                return []
            raw = ast.literal_eval(text)
        if isinstance(raw, str):
            return [raw]
        if isinstance(raw, (list, tuple, set, frozenset, np.ndarray)):
            return list(raw)
        return []

    def apply_spec_augmentations(self, spec):
        """Apply random masking and gain/bias jitter to a ``[1, H, W]`` tensor.

        Masks are written into ``spec`` in place; the (possibly new) tensor is
        returned.
        """

        # Masking along the last axis (columns -> vertical stripes).
        # Presumably the time axis of the spectrogram — TODO confirm layout.
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                # Clamp so a small spectrogram cannot make randint's range negative.
                width = min(random.randint(5, 20), spec.shape[2])
                start = random.randint(0, spec.shape[2] - width)
                spec[0, :, start:start+width] = 0

        # Masking along the row axis (rows -> horizontal stripes).
        # Presumably the mel-frequency axis — TODO confirm layout.
        if random.random() < 0.5:
            num_masks = random.randint(1, 3)
            for _ in range(num_masks):
                height = min(random.randint(5, 20), spec.shape[1])
                start = random.randint(0, spec.shape[1] - height)
                spec[0, start:start+height, :] = 0

        # Random brightness/contrast
        if random.random() < 0.5:
            gain = random.uniform(0.8, 1.2)
            bias = random.uniform(-0.1, 0.1)
            spec = spec * gain + bias
            spec = torch.clamp(spec, 0, 1)

        return spec

    def encode_label(self, label):
        """Return a one-hot NumPy vector for ``label`` (all zeros if unknown)."""
        target = np.zeros(self.num_classes)
        if label in self.label_to_idx:
            target[self.label_to_idx[label]] = 1.0
        return target

In [None]:
def collate_fn(batch):
    """Merge per-sample dicts into one dict of batched values.

    ``None`` samples are dropped (an all-``None`` batch yields ``{}``).
    ``'target'`` tensors are always stacked; ``'melspec'`` tensors are stacked
    only when every spectrogram has the same shape, otherwise they stay a
    list. All other keys are returned as plain lists.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return {}

    # Column lists keyed by the first sample's fields.
    collated = {field: [] for field in samples[0].keys()}
    for sample in samples:
        for field, value in sample.items():
            collated[field].append(value)

    for field in collated:
        values = collated[field]
        if not isinstance(values[0], torch.Tensor):
            continue
        if field == 'target':
            collated[field] = torch.stack(values)
        elif field == 'melspec':
            distinct_shapes = {str(tensor.shape) for tensor in values}
            if len(distinct_shapes) == 1:
                collated[field] = torch.stack(values)

    return collated

# **BiLSTM Model**

In [None]:
class GlobalAttentionPool(nn.Module):
    """
    Attention pooling with a single learnable query.

    The query attends over every position of the input sequence through
    multi-head attention, collapsing ``[batch, seq_len, input_dim]`` into
    ``[batch, input_dim]``.
    """
    def __init__(self, input_dim, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        # One query token shared by all samples.
        self.query = nn.Parameter(torch.randn(1, 1, input_dim))

    def forward(self, x):
        """Pool ``x`` of shape ``[batch, seq_len, input_dim]`` to ``[batch, input_dim]``."""
        # Broadcast the shared query across the batch without copying it.
        per_sample_query = self.query.expand(x.size(0), -1, -1)

        # The sequence serves as both keys and values; attention weights are unused.
        pooled, _weights = self.attention(per_sample_query, x, x)

        # Drop the length-1 query axis.
        return pooled.squeeze(1)
class BiLSTM(nn.Module):
    """Bidirectional LSTM classifier with attention pooling.

    Pipeline: (optional 4-D -> 3-D reshape) -> BiLSTM -> GlobalAttentionPool
    -> projection head -> linear classifier. Optional hyper-parameters
    (``lstm_hidden_size``, ``lstm_num_layers``, ``lstm_dropout``,
    ``dropout_rate``, ``mixup_alpha``) are read from ``cfg`` with defaults.

    Side effect: the constructor reads ``cfg.taxonomy_csv`` and stores the row
    count on ``cfg.num_classes``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # One output unit per taxonomy row.
        cfg.num_classes = len(pd.read_csv(cfg.taxonomy_csv))

        self.input_dim = cfg.input_dim

        # Recurrent encoder hyper-parameters.
        self.lstm_hidden_size = getattr(cfg, 'lstm_hidden_size', 640)
        self.lstm_num_layers = getattr(cfg, 'lstm_num_layers', 3)
        self.lstm_dropout = getattr(cfg, 'lstm_dropout', 0.4)

        # Dropout is only passed for stacked LSTMs (it applies between layers).
        inter_layer_dropout = self.lstm_dropout if self.lstm_num_layers > 1 else 0
        self.bilstm = nn.LSTM(
            input_size=self.input_dim,
            hidden_size=self.lstm_hidden_size,
            num_layers=self.lstm_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )

        # Forward and backward states are concatenated.
        encoder_width = 2 * self.lstm_hidden_size
        self.feat_dim = encoder_width

        self.pool = GlobalAttentionPool(encoder_width, num_heads=8)

        head_width = encoder_width // 2
        self.proj_head = nn.Sequential(
            nn.Linear(encoder_width, head_width),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(head_width),
            nn.Dropout(p=getattr(cfg, 'dropout_rate', 0.5)),
        )

        self.classifier = nn.Linear(head_width, cfg.num_classes)

        # Feature-level mixup is active only for a positive alpha.
        self.mixup_enabled = getattr(cfg, 'mixup_alpha', 0) > 0
        if self.mixup_enabled:
            self.mixup_alpha = cfg.mixup_alpha

    def forward(self, x, targets=None):
        """Compute class logits.

        Args:
            x: ``[B, seq_len, features]`` or ``[B, C, H, W]``; the latter is
                reshaped to ``[B, H, C*W]``.
            targets: When given, a 3-tuple is returned instead of bare logits.

        Returns:
            ``logits`` when ``targets`` is ``None``; otherwise
            ``(logits, lam, index)``, where ``lam``/``index`` are the mixup
            coefficient and permutation, or ``None`` when mixup was not applied.
        """
        n_samples = x.size(0)

        if x.dim() == 4:
            # [B, C, H, W] -> [B, H, C, W] -> [B, H, C*W]: the LSTM steps over H.
            # NOTE(review): for a spectrogram laid out as (mel, time) this
            # makes the sequence axis frequency rather than time — confirm.
            rows = x.size(2)
            x = x.permute(0, 2, 1, 3).contiguous().view(n_samples, rows, -1)

        encoded, _ = self.bilstm(x)                     # [B, seq_len, hidden_size*2]
        features = self.proj_head(self.pool(encoded))   # [B, hidden_size]

        if targets is None:
            return self.classifier(features)

        if not (self.training and self.mixup_enabled):
            return self.classifier(features), None, None

        # Mixup on projected features: blend each sample with a shuffled partner.
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        index = torch.randperm(n_samples).to(x.device)
        blended = lam * features + (1 - lam) * features[index, :]
        return self.classifier(blended), lam, index

# **Training Tools**

In [None]:
def get_optimizer(model, cfg):
    """Build the optimizer named by ``cfg.optimizer`` for ``model``.

    Supports ``'Adam'``, ``'AdamW'`` and ``'SGD'`` (momentum 0.9), each with
    ``cfg.lr`` and ``cfg.weight_decay``.

    Raises:
        NotImplementedError: For any other optimizer name.
    """
    name = cfg.optimizer
    if name not in ('Adam', 'AdamW', 'SGD'):
        raise NotImplementedError(f"Optimizer {name} not implemented")

    if name == 'SGD':
        return optim.SGD(
            model.parameters(),
            lr=cfg.lr,
            momentum=0.9,
            weight_decay=cfg.weight_decay,
        )

    optimizer_cls = optim.Adam if name == 'Adam' else optim.AdamW
    return optimizer_cls(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

def get_scheduler(optimizer, cfg):
    """Build the learning-rate scheduler named by ``cfg.scheduler``.

    Args:
        optimizer: Optimizer to schedule.
        cfg: Config providing ``scheduler`` plus ``T_max``/``min_lr``/``epochs``
            as needed by the selected scheduler.

    Returns:
        A ``CosineAnnealingLR``, ``ReduceLROnPlateau`` or ``StepLR`` instance,
        or ``None`` for ``'OneCycleLR'`` (it needs the number of steps per
        epoch, so the caller constructs it) and for any unrecognised name.
    """
    if cfg.scheduler == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.T_max,
            eta_min=cfg.min_lr
        )

    if cfg.scheduler == 'ReduceLROnPlateau':
        # `verbose` is deprecated in torch's schedulers, so it is not passed;
        # the current learning rate can be read from optimizer.param_groups.
        return lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=2,
            min_lr=cfg.min_lr
        )

    if cfg.scheduler == 'StepLR':
        # epochs // 3 is 0 for runs shorter than 3 epochs, which makes StepLR
        # take a modulo by zero on its first step; decay at least every epoch.
        return lr_scheduler.StepLR(
            optimizer,
            step_size=max(1, cfg.epochs // 3),
            gamma=0.5
        )

    return None

def get_criterion(cfg):
    """Return the loss module named by ``cfg.criterion``.

    Only ``'BCEWithLogitsLoss'`` is supported.

    Raises:
        NotImplementedError: For any other criterion name.
    """
    if cfg.criterion != 'BCEWithLogitsLoss':
        raise NotImplementedError(f"Criterion {cfg.criterion} not implemented")
    return nn.BCEWithLogitsLoss()

# **Training Loop and Training**

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device, scheduler=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable of batches shaped like ``collate_fn`` output:
            ``batch['melspec']`` is a stacked tensor, or a list of tensors when
            the spectrograms could not be stacked.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking ``(logits, targets)``.
        device: Device the inputs and targets are moved to.
        scheduler: Stepped once per batch only if it is a ``OneCycleLR``.

    Returns:
        ``(avg_loss, auc)``: mean batch loss and the AUC from
        ``calculate_auc`` over all samples seen.
    """
    model.train()
    losses = []
    all_targets = []
    all_outputs = []

    pbar = tqdm(loader, total=len(loader), desc="Training")

    for batch in pbar:

        if isinstance(batch['melspec'], list):
            # Fallback for unstackable spectrograms: one forward/backward per
            # sample, accumulating gradients for a single optimizer step.
            # NOTE(review): this feeds batches of size 1 in train mode, which
            # BatchNorm layers may reject — confirm this path is exercised.
            specs = batch['melspec']
            batch_outputs = []
            batch_losses = []

            # Zero once per batch. Zeroing inside the loop would discard every
            # sample's gradient except the last before optimizer.step().
            optimizer.zero_grad()

            for spec, sample_target in zip(specs, batch['target']):
                inputs = spec.unsqueeze(0).to(device)
                target = sample_target.unsqueeze(0).to(device)

                output = model(inputs)
                sample_loss = criterion(output, target)
                # Scale so the accumulated gradient is the batch mean.
                (sample_loss / len(specs)).backward()

                batch_outputs.append(output.detach().cpu())
                batch_losses.append(sample_loss.item())

            optimizer.step()
            outputs = torch.cat(batch_outputs, dim=0).numpy()
            loss = np.mean(batch_losses)
            targets = batch['target'].numpy()

        else:
            inputs = batch['melspec'].to(device)
            targets = batch['target'].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            if isinstance(outputs, tuple):
                # Models that return (outputs, loss) supply their own loss.
                outputs, loss = outputs
            else:
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            outputs = outputs.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()

        if scheduler is not None and isinstance(scheduler, lr_scheduler.OneCycleLR):
            scheduler.step()

        all_outputs.append(outputs)
        all_targets.append(targets)
        losses.append(loss.item() if torch.is_tensor(loss) else float(loss))

        pbar.set_postfix({
            'train_loss': np.mean(losses[-10:]) if losses else 0,
            'lr': optimizer.param_groups[0]['lr']
        })

    all_outputs = np.concatenate(all_outputs)
    all_targets = np.concatenate(all_targets)
    auc = calculate_auc(all_targets, all_outputs)
    avg_loss = np.mean(losses)

    return avg_loss, auc

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Batches whose ``'melspec'`` is a list are scored one sample at a time;
    stacked batches are scored in a single forward pass.

    Returns:
        ``(avg_loss, auc)``: mean batch loss and the AUC from
        ``calculate_auc`` over all samples.
    """
    model.eval()
    loss_log = []
    target_chunks = []
    output_chunks = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            specs = batch['melspec']

            if not isinstance(specs, list):
                # Regular stacked batch.
                features = specs.to(device)
                labels = batch['target'].to(device)

                logits = model(features)
                loss = criterion(logits, labels)

                outputs = logits.detach().cpu().numpy()
                targets = labels.detach().cpu().numpy()
            else:
                # Unstackable batch: run each sample on its own.
                sample_logits = []
                sample_losses = []

                for position, spec in enumerate(specs):
                    features = spec.unsqueeze(0).to(device)
                    label = batch['target'][position].unsqueeze(0).to(device)

                    logit = model(features)
                    sample_loss = criterion(logit, label)

                    sample_logits.append(logit.detach().cpu())
                    sample_losses.append(sample_loss.item())

                outputs = torch.cat(sample_logits, dim=0).numpy()
                loss = np.mean(sample_losses)
                targets = batch['target'].numpy()

            output_chunks.append(outputs)
            target_chunks.append(targets)
            loss_log.append(loss if isinstance(loss, float) else loss.item())

    stacked_outputs = np.concatenate(output_chunks)
    stacked_targets = np.concatenate(target_chunks)

    auc = calculate_auc(stacked_targets, stacked_outputs)
    avg_loss = np.mean(loss_log)

    return avg_loss, auc

def calculate_auc(targets, outputs):
    """Macro-averaged ROC AUC over the classes that can be scored.

    Args:
        targets: ``[n_samples, n_classes]`` array of 0/1 labels.
        outputs: ``[n_samples, n_classes]`` array of logits.

    Returns:
        Mean per-class AUC over classes that have at least one positive and
        one non-positive sample, or ``0.0`` when no class qualifies.
    """
    if targets.shape[0] == 0:
        return 0.0

    # Sigmoid of the logits.
    probs = 1 / (1 + np.exp(-outputs))

    has_positive = targets.sum(axis=0) > 0
    # ROC AUC is undefined when a column holds a single value, so a class
    # where every sample is positive must be skipped as well.
    has_both_classes = targets.min(axis=0) != targets.max(axis=0)

    aucs = [
        roc_auc_score(targets[:, i], probs[:, i])
        for i in range(targets.shape[1])
        if has_positive[i] and has_both_classes[i]
    ]

    return np.mean(aucs) if aucs else 0.0

In [None]:
def _fold_checkpoint_path(cfg, fold):
    """Return the resume-checkpoint path that belongs to ``fold`` only."""
    base = getattr(cfg, 'checkpoint_path', None)
    if not base:
        return f"checkpoint_fold{fold}.pth"
    root, ext = os.path.splitext(base)
    return f"{root}_fold{fold}{ext or '.pth'}"


def run_training(df, cfg):
    """Run stratified K-fold training and print per-fold and mean AUC.

    For each fold in ``cfg.selected_folds`` this builds the data loaders,
    model, optimizer and scheduler, optionally resumes from that fold's own
    checkpoint, trains for ``cfg.epochs`` epochs and saves the best model to
    ``model_fold{fold}.pth``.

    Args:
        df: Training metadata; must contain ``primary_label`` and ``filename``.
            ``filepath``/``samplename`` columns are added in place when
            spectrograms are generated on the fly.
        cfg: Run configuration. ``num_classes`` is set here, and ``LOAD_DATA``
            is switched off if the spectrogram file cannot be loaded.
    """
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)
    cfg.num_classes = taxonomy_df.shape[0]

    if cfg.debug:
        cfg.update_debug_settings()

    # load precomputed features or prepare on‑the‑fly
    spectrograms = None
    if cfg.LOAD_DATA:
        try:
            # allow_pickle=True unpickles arbitrary objects: trusted files only.
            spectrograms = np.load(cfg.spectrogram_npy, allow_pickle=True).item()
            print(f"Loaded {len(spectrograms)} pre-computed mel spectrograms")
        except Exception as e:
            print(f"Error loading NPY: {e}\nSwitching to on‑the‑fly.")
            cfg.LOAD_DATA = False

    if not cfg.LOAD_DATA:
        if 'filepath' not in df.columns:
            df['filepath'] = cfg.train_datadir + '/' + df.filename
        if 'samplename' not in df.columns:
            df['samplename'] = df.filename.map(
                lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0]
            )

    skf = StratifiedKFold(n_splits=cfg.n_fold, shuffle=True, random_state=cfg.seed)
    # (fold, best_auc) pairs, so the summary labels each score with the fold
    # that produced it regardless of the order of cfg.selected_folds.
    fold_scores = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['primary_label'])):
        if fold not in cfg.selected_folds:
            continue

        print(f'\n{"="*20} Fold {fold} {"="*20}')
        train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

        # dataloaders
        train_loader = DataLoader(
            BirdCLEFDatasetFromNPY(train_df, cfg, spectrograms, mode='train'),
            batch_size=cfg.batch_size, shuffle=True,
            num_workers=cfg.num_workers, pin_memory=True,
            collate_fn=collate_fn, drop_last=True
        )
        val_loader = DataLoader(
            BirdCLEFDatasetFromNPY(val_df, cfg, spectrograms, mode='valid'),
            batch_size=cfg.batch_size, shuffle=False,
            num_workers=cfg.num_workers, pin_memory=True,
            collate_fn=collate_fn
        )

        # model / optimizer / criterion / scheduler
        model     = BiLSTM(cfg).to(cfg.device)
        optimizer = get_optimizer(model, cfg)
        criterion = get_criterion(cfg)
        if cfg.scheduler == 'OneCycleLR':
            scheduler = lr_scheduler.OneCycleLR(
                optimizer, max_lr=cfg.lr,
                steps_per_epoch=len(train_loader),
                epochs=cfg.epochs, pct_start=0.1
            )
        else:
            scheduler = get_scheduler(optimizer, cfg)

        # Each fold gets its own checkpoint file. Sharing cfg.checkpoint_path
        # across folds would make a later fold resume from an earlier fold's
        # weights, which were trained on that later fold's validation samples,
        # and would skip its training entirely.
        ckpt_path = _fold_checkpoint_path(cfg, fold)

        # resume logic: load epoch, optimizer & optionally best_auc
        # NOTE(review): the scheduler state is not part of the resume
        # checkpoint, so the LR schedule restarts from its initial state.
        if cfg.resume and os.path.isfile(ckpt_path):
            model, optimizer, start_epoch, _, best_auc = load_checkpoint(model, optimizer, ckpt_path)
            print(f"→ Resumed Fold {fold} from epoch {start_epoch}")
        else:
            start_epoch = 0
            best_auc = 0.0

        # if we have no saved best_auc, do an initial validate pass
        if start_epoch >= cfg.epochs or best_auc == 0.0:
            # run one validation so best_auc isn't zero
            _, best_auc = validate(model, val_loader, criterion, cfg.device)
            print(f"→ Starting best AUC = {best_auc:.4f}")

        best_epoch = start_epoch

        for epoch in range(start_epoch, cfg.epochs):
            print(f"\nEpoch {epoch+1}/{cfg.epochs}")
            train_loss, train_auc = train_one_epoch(
                model, train_loader, optimizer, criterion,
                cfg.device,
                scheduler if isinstance(scheduler, lr_scheduler.OneCycleLR) else None
            )
            val_loss, val_auc = validate(model, val_loader, criterion, cfg.device)

            # step non‑OneCycle schedulers
            if scheduler is not None and not isinstance(scheduler, lr_scheduler.OneCycleLR):
                if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            print(f"Train Loss: {train_loss:.4f}, Train AUC: {train_auc:.4f}")
            print(f" Val  Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}")

            # save best
            if val_auc > best_auc:
                best_auc, best_epoch = val_auc, epoch + 1
                print(f"New best AUC: {best_auc:.4f} at epoch {best_epoch}")
                # 1) full "best model" bundle for this fold
                torch.save({
                    'model_state_dict':   model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': getattr(scheduler, 'state_dict', lambda: None)(),
                    'epoch':               epoch,
                    'val_auc':             val_auc,
                    'train_auc':           train_auc,
                    'cfg':                 cfg
                }, f"model_fold{fold}.pth")
                # 2) save lightweight checkpoint for resume
                save_checkpoint(model, optimizer, epoch, val_loss, best_auc, path=ckpt_path)

        fold_scores.append((fold, best_auc))
        print(f"Fold {fold} best → AUC {best_auc:.4f} @ epoch {best_epoch}")

        # cleanup
        del model, optimizer, scheduler, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()

    print("\n" + "="*60)
    print("Cross-Validation Results:")
    for fold, score in fold_scores:
        print(f"Fold {fold}: {score:.4f}")
    mean_auc = np.mean([score for _, score in fold_scores]) if fold_scores else float('nan')
    print(f"Mean AUC: {mean_auc:.4f}")
    print("="*60)

In [None]:
# Script entry point: load the metadata, report the data source, then train.
if __name__ == "__main__":
    import time

    print("\nLoading training data...")
    train_df = pd.read_csv(cfg.train_csv)
    taxonomy_df = pd.read_csv(cfg.taxonomy_csv)

    print("\nStarting training...")
    print(f"LOAD_DATA is set to {cfg.LOAD_DATA}")
    data_source_message = (
        "Using pre-computed mel spectrograms from NPY file"
        if cfg.LOAD_DATA
        else "Will generate spectrograms on-the-fly during training"
    )
    print(data_source_message)

    # Force resume mode and the checkpoint path handed to run_training.
    cfg.resume = True
    cfg.checkpoint_path = "checkpoint.pth"
    print(f"Resuming from checkpoint? {cfg.resume}")

    run_training(train_df, cfg)

    print("\nTraining complete!")


Loading training data...

Starting training...
LOAD_DATA is set to True
Using pre-computed mel spectrograms from NPY file
Resuming from checkpoint? True
Loaded 28564 pre-computed mel spectrograms

Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples


Validation:   0%|          | 0/358 [00:00<?, ?it/s]

→ Starting best AUC = 0.4829

Epoch 1/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0808, Train AUC: 0.5039
 Val  Loss: 0.0310, Val AUC: 0.5313
New best AUC: 0.5313 at epoch 1

Epoch 2/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0315, Train AUC: 0.5145
 Val  Loss: 0.0310, Val AUC: 0.5424
New best AUC: 0.5424 at epoch 2

Epoch 3/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1459104899e0>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x1459104899e0>Traceback (most recent call last):
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()

Traceback (most recent call last):
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    if w.is_alive():
     self._shutdown_workers()
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^ ^^^^^
  File "/opt/apps/software/lang/Anaconda3/2024.02-1/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid()

Train Loss: 0.0313, Train AUC: 0.5094
 Val  Loss: 0.0308, Val AUC: 0.5435
New best AUC: 0.5435 at epoch 3

Epoch 4/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0311, Train AUC: 0.5094
 Val  Loss: 0.0308, Val AUC: 0.5577
New best AUC: 0.5577 at epoch 4

Epoch 5/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0310, Train AUC: 0.5153
 Val  Loss: 0.0307, Val AUC: 0.5582
New best AUC: 0.5582 at epoch 5

Epoch 6/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0310, Train AUC: 0.5225
 Val  Loss: 0.0307, Val AUC: 0.5612
New best AUC: 0.5612 at epoch 6

Epoch 7/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0309, Train AUC: 0.5297
 Val  Loss: 0.0306, Val AUC: 0.5811
New best AUC: 0.5811 at epoch 7

Epoch 8/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0309, Train AUC: 0.5448
 Val  Loss: 0.0306, Val AUC: 0.5707

Epoch 9/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1459104899e0>
Traceback (most recent call last):
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/huynhw/.local/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/opt/apps/software/lang/Anaconda3/2024.02-1/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Train Loss: 0.0308, Train AUC: 0.5476
 Val  Loss: 0.0306, Val AUC: 0.5681

Epoch 10/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0308, Train AUC: 0.5427
 Val  Loss: 0.0306, Val AUC: 0.5659
Fold 0 best → AUC 0.5811 @ epoch 7

Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples
→ Resumed Fold 1 from epoch 7

Epoch 8/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0309, Train AUC: 0.5383
 Val  Loss: 0.0304, Val AUC: 0.5845
New best AUC: 0.5845 at epoch 8

Epoch 9/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0309, Train AUC: 0.5318
 Val  Loss: 0.0304, Val AUC: 0.5869
New best AUC: 0.5869 at epoch 9

Epoch 10/10


Training:   0%|          | 0/1428 [00:00<?, ?it/s]

Validation:   0%|          | 0/358 [00:00<?, ?it/s]

Train Loss: 0.0308, Train AUC: 0.5368
 Val  Loss: 0.0304, Val AUC: 0.5941
New best AUC: 0.5941 at epoch 10
Fold 1 best → AUC 0.5941 @ epoch 10

Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples
→ Resumed Fold 2 from epoch 10


Validation:   0%|          | 0/358 [00:00<?, ?it/s]

→ Starting best AUC = 0.6023
Fold 2 best → AUC 0.6023 @ epoch 10

Found 22851 matching spectrograms for train dataset out of 22851 samples
Found 5713 matching spectrograms for valid dataset out of 5713 samples
→ Resumed Fold 3 from epoch 10


Validation:   0%|          | 0/358 [00:00<?, ?it/s]

→ Starting best AUC = 0.6097
Fold 3 best → AUC 0.6097 @ epoch 10

Found 22852 matching spectrograms for train dataset out of 22852 samples
Found 5712 matching spectrograms for valid dataset out of 5712 samples
→ Resumed Fold 4 from epoch 10


Validation:   0%|          | 0/357 [00:00<?, ?it/s]

→ Starting best AUC = 0.5957
Fold 4 best → AUC 0.5957 @ epoch 10

Cross-Validation Results:
Fold 0: 0.5811
Fold 1: 0.5941
Fold 2: 0.6023
Fold 3: 0.6097
Fold 4: 0.5957
Mean AUC: 0.5966

Training complete!
