# BirdCLEF+ 2025: Audio Spectrogram Transformer (AST) Model

This notebook implements the Audio Transformer component of our ensemble approach for the BirdCLEF+ 2025 competition. We leverage the pretrained Audio Spectrogram Transformer (AST) model, which has shown excellent performance on audio classification tasks.

## Strategy
- Use the pretrained AST model as a foundation
- Fine-tune the model on the BirdCLEF+ 2025 dataset
- Implement cross-validation to ensure robust performance
- Add appropriate audio augmentations to handle limited data
- Generate predictions that can be ensembled with other models

## Setup Google Colab Environment

First, we'll set up the Colab environment by mounting Google Drive and installing required packages.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install required packages
!pip install -q transformers timm torchaudio librosa matplotlib scikit-learn

# Set up paths for Google Drive
import os
DATA_DIR = '/content/drive/MyDrive/birdclef-2025-data'
MODEL_SAVE_DIR = '/content/drive/MyDrive/fp-561-models'
    
# Create model save directory if it doesn't exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    
print(f"Data directory exists: {os.path.exists(DATA_DIR)}")
print(f"Model save directory: {MODEL_SAVE_DIR}")

## Import Required Libraries

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import librosa
import librosa.display
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import transformers
from transformers import ASTModel, ASTForAudioClassification, AutoFeatureExtractor
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.metrics import roc_auc_score
import warnings
import random
from tqdm.notebook import tqdm
import glob

# Silence library warnings so the notebook output stays readable
warnings.filterwarnings('ignore')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

## Set Configuration Parameters

In [None]:
# Locations of the competition inputs, all resolved relative to DATA_DIR
TRAIN_AUDIO_DIR, TRAIN_CSV_PATH, TAXONOMY_PATH, SOUNDSCAPE_DIR = (
    os.path.join(DATA_DIR, entry)
    for entry in ('train_audio', 'train.csv', 'taxonomy.csv', 'train_soundscapes')
)

# Hyper-parameters and run settings shared by every cell below
CONFIG = dict(
    seed=42,
    n_folds=5,
    model_name='MIT/ast-finetuned-audioset-10-10-0.4593',  # Pretrained AST model
    num_epochs=20,
    batch_size=8,  # Smaller batch size to fit in memory
    learning_rate=1e-5,  # Low LR for fine-tuning
    weight_decay=1e-4,
    max_audio_len=10,  # seconds
    sample_rate=32000,  # BirdCLEF dataset sample rate
    early_stopping_patience=5,
    target_sample_rate=16000,  # AST model expects 16kHz
    use_mixup=True,
    mixup_alpha=0.2,
    audio_max_length=10,
    model_save_dir=MODEL_SAVE_DIR,  # Save models to Google Drive
)

# Set random seeds
def set_seed(seed=42):
    """Seed every random number generator used by the notebook.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and
            PyTorch (CPU and CUDA). It is also exported as PYTHONHASHSEED.

    Side effects:
        Switches cuDNN to deterministic mode and disables its autotuner.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All four seeders take the seed as their only positional argument
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
# Seed all RNGs once, up front, with the value configured in CONFIG
set_seed(CONFIG['seed'])

In [None]:
# Confirm that the metadata files and the audio folder are reachable
print(f"Train CSV exists: {os.path.exists(TRAIN_CSV_PATH)}")
print(f"Taxonomy CSV exists: {os.path.exists(TAXONOMY_PATH)}")
print(f"Train audio directory exists: {os.path.exists(TRAIN_AUDIO_DIR)}")

# Count the .ogg recordings anywhere below the audio folder
if os.path.exists(TRAIN_AUDIO_DIR):
    audio_files = [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(TRAIN_AUDIO_DIR)
        for name in names
        if name.endswith('.ogg')
    ]
    print(f"Found {len(audio_files)} audio files in {TRAIN_AUDIO_DIR}")

## Load and Process Data

In [None]:
# Read the training metadata and the taxonomy table
df_train, df_taxonomy = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV_PATH, TAXONOMY_PATH))

# Size of the metadata table and number of distinct primary labels
print(f"Training data shape: {df_train.shape}")
print(f"Number of unique species: {df_train['primary_label'].nunique()}")

# Recordings per primary label: smallest, largest and average class
class_counts = df_train['primary_label'].value_counts()
count_min, count_max, count_mean = class_counts.min(), class_counts.max(), class_counts.mean()
print(f"Class distribution - min: {count_min}, max: {count_max}, mean: {count_mean:.2f}")

# Show the first rows as the cell's rich output
df_train.head()

## Create Label Encoder

In [None]:
# Get all unique species labels (primary + secondary)
def get_all_species_labels(df=None):
    """Build the label vocabulary from primary and secondary labels.

    Args:
        df: DataFrame with a ``primary_label`` column and a
            ``secondary_labels`` column holding stringified Python lists.
            Defaults to the notebook-level ``df_train``.

    Returns:
        Tuple ``(all_labels, label_to_idx, idx_to_label)`` where
        ``all_labels`` is the sorted list of unique labels and the two dicts
        map label -> index and index -> label.
    """
    import ast  # local import: keeps this cell independent of the import cell

    if df is None:
        df = df_train

    primary_labels = df['primary_label'].unique().tolist()

    # Secondary labels are stored as text; parse them with literal_eval so
    # that metadata can never execute code (the previous eval() could).
    secondary_labels = []
    for raw in df['secondary_labels'].fillna('[]'):
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            # Best-effort: a malformed entry contributes no labels
            continue
        if not isinstance(parsed, (list, tuple, set)):
            continue
        # An empty string is not a species; it presumably encodes "no
        # secondary label" (e.g. "['']") and must not become a class.
        secondary_labels.extend(label for label in parsed if label != '')

    # Combine all unique labels in a stable, sorted order
    all_labels = sorted(set(primary_labels + secondary_labels))

    # Forward and reverse lookup tables
    label_to_idx = {label: idx for idx, label in enumerate(all_labels)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}

    return all_labels, label_to_idx, idx_to_label

# Build the label vocabulary once; the lookup tables and num_classes are
# notebook-level globals read by the dataset, model and submission code below.
all_labels, label_to_idx, idx_to_label = get_all_species_labels()
num_classes = len(all_labels)

print(f"Total number of unique species: {num_classes}")

## Audio Processing and Dataset

In [None]:
# Load the feature extractor that matches the pretrained AST checkpoint named in
# CONFIG; it is a notebook-level global used by the dataset and by inference.
feature_extractor = AutoFeatureExtractor.from_pretrained(CONFIG['model_name'])

# Audio augmentation functions
def random_power(audio, power_min=0.8, power_max=1.2):
    """Apply a random power-law distortion to a waveform tensor.

    Args:
        audio: Tensor of audio samples (may contain negative values).
        power_min: Lower bound of the random exponent.
        power_max: Upper bound of the random exponent.

    Returns:
        Tensor with the same shape and sign pattern as ``audio`` whose
        magnitudes are raised to an exponent drawn uniformly from
        ``[power_min, power_max]``.
    """
    power = random.uniform(power_min, power_max)
    # Raising a negative sample to a fractional power yields NaN, so the
    # exponent is applied to the magnitude and the sign is restored afterwards.
    return torch.sign(audio) * torch.abs(audio) ** power

def add_white_noise(audio, noise_factor_min=0.001, noise_factor_max=0.01):
    """Add Gaussian noise of random amplitude to a waveform tensor.

    The amplitude is drawn uniformly from
    ``[noise_factor_min, noise_factor_max]`` and scales standard-normal noise
    of the same shape as ``audio``.
    """
    scale = random.uniform(noise_factor_min, noise_factor_max)
    return audio + torch.randn_like(audio) * scale

def random_time_shift(audio, shift_factor=0.2):
    """Cyclically shift a waveform by a random fraction of its length.

    The fraction is drawn uniformly from ``[-shift_factor, shift_factor]``;
    samples pushed past one end wrap around to the other.
    """
    fraction = random.uniform(-shift_factor, shift_factor)
    offset = int(fraction * len(audio))
    return audio.roll(offset)

def mixup(audio1, audio2, labels1, labels2, alpha=0.2):
    """Blend two samples and their label vectors with a Beta-distributed weight.

    The weight is drawn from ``Beta(alpha, alpha)`` and folded to be at least
    0.5, so the first sample always dominates the mixture.

    Returns:
        Tuple ``(mixed_audio, mixed_labels)``.
    """
    dominant = np.random.beta(alpha, alpha)
    dominant = max(dominant, 1 - dominant)
    minor = 1 - dominant

    return (
        dominant * audio1 + minor * audio2,
        dominant * labels1 + minor * labels2,
    )

# Custom dataset class for audio
class BirdCLEFDataset(Dataset):
    """Dataset yielding AST model inputs and multi-hot targets per recording.

    Each item loads one audio file, crops or pads it to a fixed duration,
    optionally augments it (training only), and runs it through the
    notebook-level ``feature_extractor``.

    Args:
        df: Metadata frame with ``filename`` and ``primary_label`` columns and
            optionally a ``secondary_labels`` column of stringified lists.
        label_to_idx: Mapping from label to class index.
        audio_dir: Directory that ``filename`` is relative to.
        is_train: Enables random cropping and waveform augmentation.
    """

    def __init__(self, df, label_to_idx, audio_dir, is_train=True):
        self.df = df.reset_index(drop=True)
        self.audio_dir = audio_dir
        self.label_to_idx = label_to_idx
        self.is_train = is_train
        self.target_sample_rate = CONFIG['target_sample_rate']
        self.max_audio_len = CONFIG['audio_max_length']
        self.num_classes = len(label_to_idx)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _parse_secondary_labels(raw):
        """Decode a stringified label list into a list of labels.

        Uses ``ast.literal_eval`` so metadata text is never executed.
        Anything that is not a non-trivial string, does not parse, or is not
        a list-like literal yields an empty list. Empty-string labels and
        non str/int entries are dropped.
        """
        import ast  # local import: keeps the class independent of the import cell

        if not isinstance(raw, str) or len(raw) <= 2:
            return []
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            return []
        if not isinstance(parsed, (list, tuple, set)):
            return []
        return [label for label in parsed if isinstance(label, (str, int)) and label != '']

    @staticmethod
    def _crop_or_pad(audio, max_samples, random_crop):
        """Return exactly ``max_samples`` samples from ``audio``.

        Longer clips are cropped (at a random offset when ``random_crop`` is
        true, otherwise from the start); shorter clips are zero-padded at
        the end.
        """
        n_samples = len(audio)
        if n_samples > max_samples:
            # randint's upper bound is exclusive: +1 makes the final valid
            # start position (n_samples - max_samples) reachable.
            start = np.random.randint(0, n_samples - max_samples + 1) if random_crop else 0
            return audio[start:start + max_samples]
        return np.pad(audio, (0, max_samples - n_samples), mode='constant')

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = os.path.join(self.audio_dir, row['filename'])
        max_samples = int(self.target_sample_rate * self.max_audio_len)

        # Load audio, resampled to the rate the model expects
        try:
            audio, _ = librosa.load(audio_path, sr=self.target_sample_rate, mono=True)
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Use silent audio as fallback so one bad file does not stop training
            audio = np.zeros(max_samples, dtype=np.float32)

        audio = self._crop_or_pad(audio, max_samples, random_crop=self.is_train)
        audio_tensor = torch.tensor(audio, dtype=torch.float32)

        # Apply augmentations during training (each one independently gated)
        if self.is_train:
            if random.random() > 0.5:
                audio_tensor = random_power(audio_tensor)
            if random.random() > 0.7:
                audio_tensor = add_white_noise(audio_tensor)
            if random.random() > 0.5:
                audio_tensor = random_time_shift(audio_tensor)

        # Multi-hot target: primary label plus any known secondary labels.
        # NOTE(review): an unknown primary label silently maps to class 0.
        primary_idx = self.label_to_idx.get(row['primary_label'], 0)

        target = torch.zeros(self.num_classes)
        target[primary_idx] = 1.0

        for label in self._parse_secondary_labels(row.get('secondary_labels')):
            if label in self.label_to_idx:
                target[self.label_to_idx[label]] = 1.0

        # Process with AST feature extractor
        has_max_length = hasattr(feature_extractor, 'max_length')
        audio_inputs = feature_extractor(
            audio_tensor,
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            max_length=feature_extractor.max_length if has_max_length else None,
            padding="max_length" if has_max_length else "do_not_pad",
            truncation=True
        )

        return {
            'input_values': audio_inputs.input_values.squeeze(),
            'attention_mask': audio_inputs.attention_mask.squeeze() if hasattr(audio_inputs, 'attention_mask') else None,
            'target': target,
            'primary_idx': primary_idx
        }

## Data Loaders & Sampling

In [None]:
# Create collate function that handles batching
def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    Tensor fields are stacked along a new leading dimension. The attention
    mask is stacked only when the first item carries one; otherwise the
    batch's mask is ``None``.
    """
    def _stack(key):
        return torch.stack([sample[key] for sample in batch])

    has_mask = batch[0]['attention_mask'] is not None

    return {
        'input_values': _stack('input_values'),
        'attention_mask': _stack('attention_mask') if has_mask else None,
        'target': _stack('target'),
        'primary_idx': torch.tensor([sample['primary_idx'] for sample in batch]),
    }

# Data loader with mixup
class MixupDataLoader:
    """DataLoader wrapper that applies batch-level mixup to some batches.

    Args:
        dataset: Dataset whose items are compatible with ``collate_fn``.
        batch_size: Number of items per batch.
        shuffle: Whether the underlying DataLoader shuffles.
        num_workers: Worker processes for the underlying DataLoader.
        mixup_alpha: Parameter of the ``Beta(alpha, alpha)`` mixing weight.
        use_mixup: ``True``/``False`` to force mixup on or off. ``None``
            (default) defers to ``CONFIG['use_mixup']``, read per batch.
        mixup_prob: Probability that an individual batch is mixed.
    """

    def __init__(self, dataset, batch_size, shuffle=True, num_workers=2, mixup_alpha=0.2,
                 use_mixup=None, mixup_prob=0.5):
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle,
            num_workers=num_workers, collate_fn=collate_fn
        )
        self.mixup_alpha = mixup_alpha
        self.use_mixup = use_mixup
        self.mixup_prob = mixup_prob

    def _mixup_enabled(self):
        """Resolve the on/off switch, falling back to the global CONFIG."""
        if self.use_mixup is None:
            return CONFIG['use_mixup']
        return self.use_mixup

    def __iter__(self):
        for batch in self.dataloader:
            # The random draw only happens when mixup is enabled; with the
            # default probability the threshold is exactly 0.5.
            if self._mixup_enabled() and np.random.random() > 1.0 - self.mixup_prob:
                # Pair every sample with another sample of the same batch
                batch_size = len(batch['target'])
                shuffled_idx = torch.randperm(batch_size)

                # Mixing weight, folded so the original sample dominates
                lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
                lam = max(lam, 1 - lam)

                # Mix inputs and targets with the same weight.
                # 'primary_idx' is left untouched and refers to the dominant sample.
                batch['input_values'] = (
                    lam * batch['input_values'] + (1 - lam) * batch['input_values'][shuffled_idx]
                )
                batch['target'] = lam * batch['target'] + (1 - lam) * batch['target'][shuffled_idx]

            yield batch

    def __len__(self):
        return len(self.dataloader)

## Define the Model

In [None]:
class ASTClassifier(nn.Module):
    """Pretrained AST encoder followed by a fresh multi-label head.

    Args:
        model_name: Hugging Face identifier of the pretrained AST checkpoint.
        num_classes: Number of output logits.
        freeze_base: When true, the encoder's parameters are excluded from
            gradient updates and only the head is trained.
    """

    def __init__(self, model_name, num_classes, freeze_base=True):
        super().__init__()

        # Load the pretrained AST model
        self.base_model = ASTModel.from_pretrained(model_name)

        # Width of the encoder's token embeddings
        hidden_size = self.base_model.config.hidden_size

        # Freeze the base model if specified
        if freeze_base:
            for param in self.base_model.parameters():
                param.requires_grad = False

        # New classification head replacing the AudioSet head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_classes)
        )

    def forward(self, input_values, attention_mask=None):
        """Return class logits for a batch of spectrogram inputs.

        Args:
            input_values: Batch of inputs as produced by the AST feature
                extractor.
            attention_mask: Accepted for interface compatibility with the
                data pipeline and ignored. ``ASTModel.forward`` has no
                ``attention_mask`` parameter, so forwarding it raises
                ``TypeError`` on releases with an explicit signature.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        outputs = self.base_model(input_values=input_values)

        # Use the first token of the sequence as the clip-level summary
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Pass through the classifier
        logits = self.classifier(pooled_output)

        return logits

## Training Functions

In [None]:
# Loss function
def get_loss_fn():
    """Create the multi-label training criterion.

    Returns:
        ``nn.BCEWithLogitsLoss`` averaging over all elements, so the model
        must output raw logits rather than probabilities.
    """
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    return criterion

# Optimizer
def get_optimizer(model):
    """Create an AdamW optimizer over the model's trainable parameters.

    Frozen parameters (``requires_grad == False``) are left out. Learning
    rate and weight decay come from the notebook-level ``CONFIG``.
    """
    trainable_params = (param for param in model.parameters() if param.requires_grad)
    return torch.optim.AdamW(
        trainable_params,
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay'],
    )

# Learning rate scheduler
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Create a warmup-then-cosine learning-rate schedule for ``optimizer``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Total number of scheduler steps.
    """
    schedule_kwargs = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
    }
    return transformers.get_cosine_schedule_with_warmup(optimizer, **schedule_kwargs)

# Evaluation function
def evaluate(model, dataloader, device):
    """Run the model over a dataloader and compute macro ROC-AUC.

    Args:
        model: Callable as ``model(input_values, attention_mask)`` returning
            logits.
        dataloader: Iterable of batches as produced by ``collate_fn``.
        device: Device the inputs are moved to.

    Returns:
        Tuple ``(mean_auc, all_preds, all_targets)``. ``mean_auc`` is the
        unweighted mean of the per-class ROC-AUC over the classes that can be
        scored (0 if none can); the arrays hold sigmoid probabilities and
        targets stacked over all batches.
    """
    model.eval()

    all_targets = []
    all_preds = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_values = batch['input_values'].to(device)
            attention_mask = batch['attention_mask'].to(device) if batch['attention_mask'] is not None else None
            targets = batch['target'].numpy()

            outputs = model(input_values, attention_mask)
            preds = torch.sigmoid(outputs).cpu().numpy()

            all_targets.append(targets)
            all_preds.append(preds)

    all_targets = np.vstack(all_targets)
    all_preds = np.vstack(all_preds)

    # Compute ROC-AUC per class and then macro average
    n_rows = all_targets.shape[0]
    aucs = []
    for i in range(all_targets.shape[1]):
        positives = np.sum(all_targets[:, i])
        # ROC-AUC needs both classes: skip columns with no positives AND
        # columns with no negatives (roc_auc_score raises on either).
        if 0 < positives < n_rows:
            aucs.append(roc_auc_score(all_targets[:, i], all_preds[:, i]))

    mean_auc = np.mean(aucs) if len(aucs) > 0 else 0
    return mean_auc, all_preds, all_targets

# Training function
def train_one_epoch(model, dataloader, optimizer, scheduler, loss_fn, device):
    """Train the model for one pass over ``dataloader``.

    The optimizer and the scheduler are both stepped once per batch.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()

    batch_losses = []

    for batch in tqdm(dataloader, desc="Training", leave=False):
        optimizer.zero_grad()

        mask = batch['attention_mask']
        inputs = batch['input_values'].to(device)
        mask = None if mask is None else mask.to(device)
        labels = batch['target'].to(device)

        loss = loss_fn(model(inputs, mask), labels)
        loss.backward()

        optimizer.step()
        scheduler.step()

        batch_losses.append(loss.item())

    return sum(batch_losses, 0.0) / len(dataloader)

## Cross-Validation and Model Training

In [None]:
# Stratify on the primary label so every fold sees a similar class mix
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])

# -1 marks "not assigned yet"; each row ends up in exactly one validation fold
df_train['fold'] = -1
fold_splits = skf.split(df_train, df_train['primary_label'])
for fold, (_, val_idx) in enumerate(fold_splits):
    df_train.loc[val_idx, 'fold'] = fold

print(f"Fold distribution:\n{df_train['fold'].value_counts()}")

## Train Model with Cross-Validation

In [None]:
def train_and_validate(fold=0):
    """Fine-tune an AST classifier on all folds except ``fold``.

    The model is validated on ``fold`` after every epoch. The weights with
    the best validation ROC-AUC are written to ``ast_fold_{fold}.pt`` in
    ``CONFIG['model_save_dir']``, and a full checkpoint (model, optimizer,
    scheduler) is written every 5 epochs. Training stops early after
    ``CONFIG['early_stopping_patience']`` epochs without improvement.

    Args:
        fold: Index of the fold held out for validation.

    Returns:
        Tuple ``(model, best_auc)`` where ``model`` carries the best weights
        found in this run.
    """
    # Create train and validation datasets
    train_df = df_train[df_train['fold'] != fold].reset_index(drop=True)
    valid_df = df_train[df_train['fold'] == fold].reset_index(drop=True)

    print(f"Fold {fold}: Train size: {len(train_df)}, Valid size: {len(valid_df)}")

    # Create datasets and dataloaders
    train_dataset = BirdCLEFDataset(train_df, label_to_idx, TRAIN_AUDIO_DIR, is_train=True)
    valid_dataset = BirdCLEFDataset(valid_df, label_to_idx, TRAIN_AUDIO_DIR, is_train=False)

    train_loader = MixupDataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=2,
        mixup_alpha=CONFIG['mixup_alpha']
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=2,
        collate_fn=collate_fn
    )

    # Create model
    model = ASTClassifier(
        model_name=CONFIG['model_name'],
        num_classes=num_classes,
        freeze_base=False  # Start with full fine-tuning
    ).to(DEVICE)

    # Create optimizer and loss function
    optimizer = get_optimizer(model)
    loss_fn = get_loss_fn()

    # Learning rate scheduler: stepped per batch, with 10% of all steps as warmup
    num_training_steps = len(train_loader) * CONFIG['num_epochs']
    num_warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)

    # Training loop.
    # best_auc starts at -inf so the first scored epoch is always saved; with a
    # 0.0 start, a run that never beat 0 saved nothing and the reload below
    # either crashed or picked up a stale file from an earlier run.
    best_auc = float('-inf')
    best_epoch = 0
    patience_counter = 0
    model_path = os.path.join(CONFIG['model_save_dir'], f"ast_fold_{fold}.pt")

    for epoch in range(CONFIG['num_epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")

        # Train
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, loss_fn, DEVICE)

        # Evaluate
        val_auc, val_preds, val_targets = evaluate(model, valid_loader, DEVICE)

        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val ROC-AUC: {val_auc:.4f}")

        # Save best model
        if val_auc > best_auc:
            best_auc = val_auc
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
            print(f"Saved best model with ROC-AUC: {best_auc:.4f} to {model_path}")
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} epochs. Best ROC-AUC: {best_auc:.4f} at epoch {best_epoch+1}")

        # Early stopping
        if patience_counter >= CONFIG['early_stopping_patience']:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Save checkpoint every 5 epochs to prevent loss of progress in case of Colab disconnection
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(CONFIG['model_save_dir'], f"ast_fold_{fold}_checkpoint_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_auc': best_auc
            }, checkpoint_path)
            print(f"Saved checkpoint at epoch {epoch+1}")

    if best_auc == float('-inf'):
        # Nothing was saved in this run (no epochs, or no comparable score):
        # return the in-memory model rather than loading an unrelated file.
        print(f"No best model was saved for fold {fold}; returning the current weights")
        return model, 0.0

    # Load best model; map_location keeps this working if the device changed
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))

    return model, best_auc

## Train Models for Multiple Folds

In [None]:
# Train a single fold per run; re-run this cell with another value to cover all folds
fold_to_train = 0  # Change this to train different folds

# Launches the full training loop (long-running); `model` carries the best weights
model, best_auc = train_and_validate(fold=fold_to_train)
print(f"Fold {fold_to_train} Best ROC-AUC: {best_auc:.4f}")

# Save the returned weights a second time under a separate "final" file name
final_model_path = os.path.join(CONFIG['model_save_dir'], f"ast_final_fold_{fold_to_train}.pt")
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")

## Inference and Prediction

In [None]:
def _split_into_segments(audio, max_len, hop_length):
    """Cut ``audio`` into windows of exactly ``max_len`` samples.

    Clips no longer than ``max_len`` give a single zero-padded window.
    Longer clips give windows every ``hop_length`` samples, plus one window
    aligned to the end of the clip when the regular grid does not reach it.
    """
    n_samples = len(audio)
    if n_samples <= max_len:
        return [np.pad(audio, (0, max_len - n_samples), mode='constant')]

    last_start = n_samples - max_len
    segments = [audio[start:start + max_len] for start in range(0, last_start + 1, hop_length)]

    # Cover the tail if the final regular window stops short of the end
    if last_start % hop_length != 0:
        segments.append(audio[-max_len:])
    return segments


def load_audio_for_inference(audio_path, target_sr=16000, max_length_seconds=10):
    """Load an audio file and split it into fixed-length inference windows.

    Args:
        audio_path: Path of the audio file.
        target_sr: Sample rate the audio is resampled to.
        max_length_seconds: Window duration in seconds (may be fractional).

    Returns:
        Tuple ``(segments, sample_rate)``. ``segments`` is a list of arrays
        of ``int(target_sr * max_length_seconds)`` samples with 50% overlap
        between consecutive windows. If the file cannot be loaded, a single
        all-zero window is returned.
    """
    max_len = int(target_sr * max_length_seconds)

    try:
        audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    except Exception as e:
        print(f"Error loading audio {audio_path}: {e}")
        # max_len is already an int; np.zeros(target_sr * seconds) raised a
        # TypeError here whenever the duration was fractional.
        return [np.zeros(max_len, dtype=np.float32)], target_sr

    hop_length = max(1, max_len // 2)  # 50% overlap, never a zero step
    return _split_into_segments(audio, max_len, hop_length), sr

def predict_audio(model, audio_path):
    """Predict per-class probabilities for one audio file.

    The file is split into windows by ``load_audio_for_inference``. Each
    window is run through the feature extractor and the model, and the
    sigmoid outputs are averaged over windows.

    Returns:
        1-D NumPy array with one probability per class.
    """
    model.eval()
    segments, sr = load_audio_for_inference(audio_path)

    # Padding settings depend only on the extractor, not on the segment
    extractor_has_max_length = hasattr(feature_extractor, 'max_length')

    segment_probs = []

    with torch.no_grad():
        for segment in segments:
            waveform = torch.tensor(segment, dtype=torch.float32)

            features = feature_extractor(
                waveform,
                sampling_rate=sr,
                return_tensors="pt",
                max_length=feature_extractor.max_length if extractor_has_max_length else None,
                padding="max_length" if extractor_has_max_length else "do_not_pad",
                truncation=True
            )

            inputs = features.input_values.to(DEVICE)
            mask = None
            if hasattr(features, 'attention_mask'):
                mask = features.attention_mask.to(DEVICE)

            logits = model(inputs, mask)
            segment_probs.append(torch.sigmoid(logits).cpu().numpy())

    # Mean over windows, then drop the singleton batch dimension
    return np.mean(segment_probs, axis=0)[0]

# Function to create predictions for Kaggle submission
def create_predictions(model, test_dir, output_file='submission.csv'):
    """Predict every ``.ogg`` file in ``test_dir`` and write a submission CSV.

    One file-level prediction is computed per recording and repeated for the
    twelve 5-second rows (end times 5..60) of that recording.

    Args:
        model: Trained model passed on to ``predict_audio``.
        test_dir: Directory containing the test ``.ogg`` files.
        output_file: File name of the CSV, created inside
            ``CONFIG['model_save_dir']``.

    Returns:
        The submission DataFrame with a ``row_id`` column followed by one
        column per class, ordered by class index.
    """
    # Sorted so the row order is reproducible (glob order is arbitrary)
    test_files = sorted(glob.glob(os.path.join(test_dir, '*.ogg')))

    print(f"Found {len(test_files)} test files")

    predictions = []

    for audio_file in tqdm(test_files, desc="Predicting"):
        # splitext strips only the extension, so dots inside the name survive
        file_id = os.path.splitext(os.path.basename(audio_file))[0]
        preds = list(predict_audio(model, audio_file))

        # Create a row for each 5-second segment
        for end_time in range(5, 65, 5):  # 5, 10, 15, ..., 60
            predictions.append([f"{file_id}_{end_time}"] + preds)

    # Order class columns by index so they line up with the model's outputs
    # regardless of the dict's insertion order.
    columns = ['row_id'] + [idx_to_label[idx] for idx in sorted(idx_to_label)]
    submission_df = pd.DataFrame(predictions, columns=columns)

    # Save to CSV
    submission_path = os.path.join(CONFIG['model_save_dir'], output_file)
    submission_df.to_csv(submission_path, index=False)
    print(f"Predictions saved to {submission_path}")

    return submission_df

## Analyze Model Performance

In [None]:
# Let's evaluate our model performance per species
def analyze_performance_per_species(fold=0):
    """Report per-species validation ROC-AUC for a trained fold model.

    Loads ``ast_fold_{fold}.pt`` from ``CONFIG['model_save_dir']``, evaluates
    it on the rows assigned to ``fold``, prints the ten best and ten worst
    species, and writes a histogram (PNG) and a per-species table (CSV) to
    the model directory.

    Args:
        fold: Validation fold to analyse.

    Returns:
        DataFrame with columns ``species`` and ``auc``, sorted by ascending
        AUC. Species that cannot be scored on this fold are omitted.
    """
    # Create validation dataset
    valid_df = df_train[df_train['fold'] == fold].reset_index(drop=True)
    valid_dataset = BirdCLEFDataset(valid_df, label_to_idx, TRAIN_AUDIO_DIR, is_train=False)
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=2,
        collate_fn=collate_fn
    )

    # Load trained model; map_location lets GPU-trained weights load on any device
    model = ASTClassifier(CONFIG['model_name'], num_classes).to(DEVICE)
    model_path = os.path.join(CONFIG['model_save_dir'], f"ast_fold_{fold}.pt")
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))

    # Get predictions
    _, val_preds, val_targets = evaluate(model, valid_loader, DEVICE)

    # Calculate AUC per species
    n_rows = val_targets.shape[0]
    species_aucs = {}
    for i, species_name in idx_to_label.items():
        positives = np.sum(val_targets[:, i])
        # ROC-AUC is undefined without both positives and negatives;
        # roc_auc_score raises in either case, so such species are skipped.
        if 0 < positives < n_rows:
            species_aucs[species_name] = roc_auc_score(val_targets[:, i], val_preds[:, i])

    # Convert to DataFrame for analysis
    species_perf_df = pd.DataFrame({
        'species': list(species_aucs.keys()),
        'auc': list(species_aucs.values())
    }).sort_values('auc')

    # Show best and worst performing species
    print("Best performing species:")
    print(species_perf_df.tail(10))

    print("\nWorst performing species:")
    print(species_perf_df.head(10))

    # Plot AUC distribution
    plt.figure(figsize=(12, 6))
    plt.hist(species_perf_df['auc'], bins=20)
    plt.xlabel('ROC AUC')
    plt.ylabel('Number of Species')
    plt.title('Distribution of ROC AUC scores across species')
    plt.savefig(os.path.join(CONFIG['model_save_dir'], f'species_auc_distribution_fold_{fold}.png'))
    plt.show()

    # Save species performance to CSV
    species_perf_path = os.path.join(CONFIG['model_save_dir'], f'species_performance_fold_{fold}.csv')
    species_perf_df.to_csv(species_perf_path, index=False)
    print(f"Species performance saved to {species_perf_path}")

    return species_perf_df

In [None]:
# After training, run this to analyze model performance
# species_performance = analyze_performance_per_species(fold=0)

## Prepare Model for Ensemble

We'll save some metadata about the model to help with ensembling later.

In [None]:
def save_model_info(fold=0):
    """Write a JSON sidecar describing the trained fold model.

    The file ``ast_model_info_fold_{fold}.json`` is created in
    ``CONFIG['model_save_dir']``. Top-level values that are not plain JSON
    types are stored as their string representation.

    Args:
        fold: Fold index recorded in the metadata and in the file name.

    Returns:
        The metadata dictionary as built in memory, before string conversion.
    """
    import json

    model_info = dict(
        model_type='AST',
        base_model=CONFIG['model_name'],
        fold=fold,
        num_classes=num_classes,
        label_to_idx=label_to_idx,
        idx_to_label=idx_to_label,
        input_sample_rate=CONFIG['target_sample_rate'],
        date_trained=pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S"),
    )

    # Only these top-level types are written as-is; anything else is stringified
    json_native = (dict, list, str, int, float, bool, type(None))
    model_info_serializable = {}
    for key, value in model_info.items():
        model_info_serializable[key] = value if isinstance(value, json_native) else str(value)

    info_path = os.path.join(CONFIG['model_save_dir'], f'ast_model_info_fold_{fold}.json')
    with open(info_path, 'w') as f:
        json.dump(model_info_serializable, f, indent=2)

    print(f"Model info saved for fold {fold}")
    return model_info

# Save model info after training
# model_info = save_model_info(fold=fold_to_train)

## Conclusion

This notebook implements the Audio Transformer component (using a pretrained AST model) for the BirdCLEF+ 2025 ensemble approach. Key features include:

1. Leveraging the pretrained Audio Spectrogram Transformer (AST) for audio classification
2. Comprehensive audio augmentation techniques including mixup, time shift, and noise addition
3. Multi-label classification capability for detecting multiple species in a single recording
4. Proper cross-validation strategy with stratified folds
5. Specialized inference handling for long audio files using overlapping windows
6. Integration with Google Drive for persistent storage of models and results

Next steps:
- Train models for all folds
- Combine with other models in the ensemble (CNN, CRNN, etc.)
- Further explore semi-supervised learning with the unlabeled soundscape data