In [None]:
import os
import logging
import random
import gc
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses, metrics, callbacks
from tensorflow.keras.utils import to_categorical

from tqdm.auto import tqdm

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

# Check the TensorFlow version and GPU availability
print(f"TensorFlow version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

TensorFlow version: 2.19.0
Num GPUs Available: 1


In [None]:
class CFG:
    seed = 42
    debug = False
    print_freq = 100
    num_workers = 1

    # Project root
    PROJECT_ROOT_DIR = Path.cwd().parents[2]

    OUTPUT_DIR = PROJECT_ROOT_DIR / 'data/working/'
    SPECTROGRAMS_DIR = OUTPUT_DIR / 'birdclef25-mel-spectrograms/'
    
    train_datadir = PROJECT_ROOT_DIR / 'data/raw/train_audio'
    train_csv = PROJECT_ROOT_DIR / 'data/raw/train.csv'
    taxonomy_csv = PROJECT_ROOT_DIR / 'data/raw/taxonomy.csv'
    working_df_path = PROJECT_ROOT_DIR / 'configs/work_df_w_split_info.csv'

    # Parametric input dim (32, 64, or 256)
    input_dim = 32  # Default to 32 for low memory
    
    # Derived paths and shapes
    @property
    def spectrogram_npy(self):
        return self.SPECTROGRAMS_DIR / f'birdclef2025_melspec_5sec_{self.input_dim}_{self.input_dim}.npy'
    
    TARGET_SHAPE = None  # Will be set in init
    in_channels = 1  # Single channel
    LOAD_DATA = True
    num_classes = 4  # 4 high-level classes

    epochs = 10
    batch_size = 16
    
    n_fold = 5
    # selected_folds = [0, 1, 2, 3, 4]
    selected_folds = [0]

    lr = 5e-4 
    weight_decay = 1e-5
    min_lr = 1e-6
    
    aug_prob = 0.0  # No augmentation
    mixup_alpha = 0.0  # No MixUp
    
    def __init__(self):
        self.TARGET_SHAPE = (self.input_dim, self.input_dim)
    
    def update_debug_settings(self):
        if self.debug:
            self.epochs = 2
            self.selected_folds = [0]

cfg = CFG()

In [None]:
# Enable Memory Growth and Set Seed and memory growth
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"Memory growth enabled for {len(gpus)} GPU(s).")
    except RuntimeError as e:
        print(f"Error setting memory growth: {e}")

# Verify setup
print("TensorFlow version:", tf.__version__)
print("Num GPUs Available: ", len(gpus))

def set_seed(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

set_seed(cfg.seed)

Memory growth enabled for 1 GPU(s).
TensorFlow version: 2.19.0
Num GPUs Available:  1


In [33]:
def data_generator(df, cfg, spectrograms, is_train=True):
    for _, row in df.iterrows():
        samplename = row['samplename']
        spec = None

        if spectrograms and samplename in spectrograms:
            spec = spectrograms[samplename]
            # Add channel dim (single channel)
            spec = np.expand_dims(spec, axis=-1)
        else:
            # Fallback to zero array if missing
            spec = np.zeros((*cfg.TARGET_SHAPE, cfg.in_channels), dtype=np.float32)

        # One-hot encode the class (4 classes)
        target = to_categorical(row['y_species_encoded'], num_classes=cfg.num_classes)
        
        yield spec, target

In [34]:
def get_vit_model(cfg):
    input_shape = (*cfg.TARGET_SHAPE, cfg.in_channels)
    inputs = layers.Input(shape=input_shape)
    
    # Patch embedding
    patch_size = 4  # Example patch size; adjust based on dim (e.g., for 32: 8 patches per dim)
    num_patches = (cfg.input_dim // patch_size) ** 2
    projection_dim = 64  # Embedding dim
    
    patches = layers.Reshape((-1, patch_size * patch_size * cfg.in_channels))(inputs)
    patch_embeddings = layers.Dense(projection_dim)(patches)
    
    # Positional embeddings
    positions = tf.range(start=0, limit=num_patches, delta=1)
    position_embeddings = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)(positions)
    x = patch_embeddings + position_embeddings
    
    # Transformer encoder blocks (e.g., 2 blocks)
    num_heads = 4
    transformer_units = [projection_dim * 2, projection_dim]  # MLP sizes
    for _ in range(2):  # Number of blocks
        # Multi-head attention
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim // num_heads)(x1, x1)
        x2 = layers.Add()([attention_output, x])
        
        # MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = layers.Dense(transformer_units[0], activation='gelu')(x3)
        x3 = layers.Dense(transformer_units[1])(x3)
        x = layers.Add()([x3, x2])
    
    # Head
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(cfg.num_classes, activation='softmax')(x)
    
    model = models.Model(inputs, outputs)
    return model

In [None]:
def run_training(df, cfg):
    
    if cfg.debug:
        cfg.update_debug_settings()
        df = df.sample(min(2000, len(df)), random_state=cfg.seed).reset_index(drop=True)

    # Load precomputed spectrograms
    print("Loading pre-computed mel spectrograms from NPY file...")
    try:
        spectrograms = np.load(cfg.spectrogram_npy, allow_pickle=True).item()
        print(f"Loaded {len(spectrograms)} pre-computed mel spectrograms")
    except Exception as e:
        print(f"Error loading pre-computed spectrograms: {e}")
        spectrograms = {}
    
    # Prepare DF
    df['samplename'] = df.filename.map(lambda x: x.split('/')[0] + '-' + x.split('/')[-1].split('.')[0])
    
    # Stratified K-Fold on y_species_encoded
    skf = StratifiedKFold(n_splits=cfg.n_fold, shuffle=True, random_state=cfg.seed)
    
    best_scores = []
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['y_species_encoded'])):
        if fold not in cfg.selected_folds:
            continue
            
        print(f'\n{"="*30} Fold {fold} {"="*30}')
        
        train_df = df.iloc[train_idx].reset_index(drop=True)
        val_df = df.iloc[val_idx].reset_index(drop=True)
        
        # Create tf.data datasets
        train_ds = tf.data.Dataset.from_generator(
            lambda: data_generator(train_df, cfg, spectrograms, is_train=True),
            output_signature=(
                tf.TensorSpec(shape=(*cfg.TARGET_SHAPE, cfg.in_channels), dtype=tf.float32),
                tf.TensorSpec(shape=(cfg.num_classes,), dtype=tf.float32)
            )
        )
        
        val_ds = tf.data.Dataset.from_generator(
            lambda: data_generator(val_df, cfg, spectrograms, is_train=False),
            output_signature=(
                tf.TensorSpec(shape=(*cfg.TARGET_SHAPE, cfg.in_channels), dtype=tf.float32),
                tf.TensorSpec(shape=(cfg.num_classes,), dtype=tf.float32)
            )
        )
        
        # Shuffle train (no MixUp, simple shuffle)
        train_ds = train_ds.shuffle(buffer_size=cfg.batch_size * 128).batch(cfg.batch_size).prefetch(tf.data.AUTOTUNE)
        val_ds = val_ds.batch(cfg.batch_size).prefetch(tf.data.AUTOTUNE)

        # Model, Optimizer, Loss, Metrics
        model = get_vit_model(cfg)
        
        steps_per_epoch = len(train_df) // cfg.batch_size
        
        lr_schedule = optimizers.schedules.CosineDecay(
            cfg.lr,
            decay_steps=steps_per_epoch * cfg.epochs,
            alpha=cfg.min_lr / cfg.lr
        )
        
        optimizer = optimizers.AdamW(learning_rate=lr_schedule, weight_decay=cfg.weight_decay)
        
        loss = losses.CategoricalCrossentropy()
        
        auc_metric = metrics.AUC(multi_label=False, name='auc')
        
        model.compile(optimizer=optimizer, loss=loss, metrics=[auc_metric, 'accuracy'])

        # Callbacks
        model_checkpoint = callbacks.ModelCheckpoint(
            f"full_model_fold{fold}.keras", 
            monitor='val_auc', 
            mode='max', 
            save_best_only=True,
            verbose=1
        )
        
        history = model.fit(
            train_ds,
            epochs=cfg.epochs,
            validation_data=val_ds,
            callbacks=[model_checkpoint]
        )
        
        best_fold_auc = max(history.history['val_auc'])
        best_scores.append(best_fold_auc)
        print(f"\nBest AUC for fold {fold}: {best_fold_auc:.4f}")
        
        # Clear memory
        del model, train_ds, val_ds
        tf.keras.backend.clear_session()
        gc.collect()
    
    print("\n" + "="*60)
    print("Cross-Validation Results:")
    for i, score in enumerate(best_scores):
        print(f"Fold {cfg.selected_folds[i]}: {score:.4f}")
    print(f"Mean AUC: {np.mean(best_scores):.4f}")
    print("="*60)

In [36]:
print("\nLoading combined train data...")
full_df = pd.read_csv(cfg.working_df_path)

# Combine train and val as "combined train"
combined_train_df = full_df[full_df['split'].isin(['train', 'val'])].reset_index(drop=True)
combined_train_df.head()


Loading combined train data...


Unnamed: 0,primary_label,rating,filename,target,filepath,samplename,class,y_species_encoded,split
0,1139490,0.0,1139490/CSA36389.ogg,0,/pub/ddlin/projects/mids/DATASCI207_Bird_Sound...,1139490-CSA36389,Insecta,2,train
1,1192948,0.0,1192948/CSA36358.ogg,1,/pub/ddlin/projects/mids/DATASCI207_Bird_Sound...,1192948-CSA36358,Insecta,2,train
2,1192948,0.0,1192948/CSA36366.ogg,1,/pub/ddlin/projects/mids/DATASCI207_Bird_Sound...,1192948-CSA36366,Insecta,2,train
3,1192948,0.0,1192948/CSA36373.ogg,1,/pub/ddlin/projects/mids/DATASCI207_Bird_Sound...,1192948-CSA36373,Insecta,2,val
4,1194042,0.0,1194042/CSA18783.ogg,2,/pub/ddlin/projects/mids/DATASCI207_Bird_Sound...,1194042-CSA18783,Insecta,2,val


In [37]:
print("\nStarting training...")
run_training(combined_train_df, cfg)

print("\nTraining complete!")


Starting training...
Loading pre-computed mel spectrograms from NPY file...
Loaded 28564 pre-computed mel spectrograms

Epoch 1/10
   1143/Unknown [1m19s[0m 9ms/step - accuracy: 0.9028 - auc: 0.9602 - loss: 0.3703
Epoch 1: val_auc improved from -inf to 0.97856, saving model to model_fold0.keras
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 12ms/step - accuracy: 0.9029 - auc: 0.9603 - loss: 0.3701 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.3041
Epoch 2/10


2025-07-27 16:16:52.396708: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:16:52.396733: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1134/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.8990 - auc: 0.9557 - loss: 0.4280

2025-07-27 16:16:58.784828: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:16:58.784860: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 2: val_auc did not improve from 0.97856
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.8996 - auc: 0.9560 - loss: 0.4256 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.2849
Epoch 3/10


2025-07-27 16:17:00.292067: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:00.292090: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1133/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 6ms/step - accuracy: 0.9069 - auc: 0.9600 - loss: 0.4053
Epoch 3: val_auc did not improve from 0.97856
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 7ms/step - accuracy: 0.9075 - auc: 0.9603 - loss: 0.4028 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.2714
Epoch 4/10


2025-07-27 16:17:08.999432: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:08.999473: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1137/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9100 - auc: 0.9588 - loss: 0.4075

2025-07-27 16:17:15.383039: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:15.383067: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 4: val_auc did not improve from 0.97856
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.9104 - auc: 0.9590 - loss: 0.4059 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.2727
Epoch 5/10


2025-07-27 16:17:16.874977: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:16.875001: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.9115 - auc: 0.9581 - loss: 0.4092

2025-07-27 16:17:23.266986: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:23.267013: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 5: val_auc did not improve from 0.97856
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.9115 - auc: 0.9581 - loss: 0.4090 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.2481
Epoch 6/10


2025-07-27 16:17:24.694804: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:24.694830: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1127/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9090 - auc: 0.9563 - loss: 0.4233
Epoch 6: val_auc did not improve from 0.97856
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.9099 - auc: 0.9568 - loss: 0.4192 - val_accuracy: 0.9678 - val_auc: 0.9786 - val_loss: 0.2375
Epoch 7/10


2025-07-27 16:17:32.527267: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:32.527304: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1129/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9101 - auc: 0.9535 - loss: 0.4337

2025-07-27 16:17:38.935332: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:38.935358: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 7: val_auc improved from 0.97856 to 0.97909, saving model to model_fold0.keras
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 7ms/step - accuracy: 0.9109 - auc: 0.9540 - loss: 0.4301 - val_accuracy: 0.9678 - val_auc: 0.9791 - val_loss: 0.2174
Epoch 8/10


2025-07-27 16:17:40.402958: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:40.402985: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1140/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9133 - auc: 0.9549 - loss: 0.4170

2025-07-27 16:17:46.784951: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:46.784988: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 8: val_auc improved from 0.97909 to 0.98031, saving model to model_fold0.keras
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.9135 - auc: 0.9550 - loss: 0.4161 - val_accuracy: 0.9678 - val_auc: 0.9803 - val_loss: 0.2032
Epoch 9/10


2025-07-27 16:17:48.246995: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:48.247026: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1130/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9112 - auc: 0.9478 - loss: 0.4545

2025-07-27 16:17:54.860972: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:54.861004: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 9: val_auc improved from 0.98031 to 0.98433, saving model to model_fold0.keras
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 7ms/step - accuracy: 0.9119 - auc: 0.9483 - loss: 0.4510 - val_accuracy: 0.9678 - val_auc: 0.9843 - val_loss: 0.1779
Epoch 10/10


2025-07-27 16:17:56.311175: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:17:56.311202: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044


[1m1135/1143[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9113 - auc: 0.9536 - loss: 0.4419

2025-07-27 16:18:02.685543: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:18:02.685580: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Epoch 10: val_auc improved from 0.98433 to 0.98462, saving model to model_fold0.keras
[1m1143/1143[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 6ms/step - accuracy: 0.9117 - auc: 0.9539 - loss: 0.4398 - val_accuracy: 0.9681 - val_auc: 0.9846 - val_loss: 0.1683


2025-07-27 16:18:04.166727: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 3478428190026981576
2025-07-27 16:18:04.166750: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7984707892478842044



Best AUC for fold 0: 0.9846

Cross-Validation Results:
Fold 0: 0.9846
Mean AUC: 0.9846

Training complete!


In [None]:
def evaluate_on_test(cfg, spectrograms, test_df):
    # Create test dataset
    test_ds = tf.data.Dataset.from_generator(
        lambda: data_generator(test_df, cfg, spectrograms, is_train=False),
        output_signature=(
            tf.TensorSpec(shape=(*cfg.TARGET_SHAPE, cfg.in_channels), dtype=tf.float32),
            tf.TensorSpec(shape=(cfg.num_classes,), dtype=tf.float32)
        )
    )
    test_ds = test_ds.batch(cfg.batch_size).prefetch(tf.data.AUTOTUNE)
    
    # Ground truth labels
    y_true = np.array([to_categorical(row['y_species_encoded'], num_classes=cfg.num_classes) for _, row in test_df.iterrows()])
    
    # Ensemble predictions from all folds
    y_pred_avg = np.zeros((len(test_df), cfg.num_classes))
    
    for fold in cfg.selected_folds:
        model_path = f"full_model_fold{fold}.keras"
        if os.path.exists(model_path):
            model = models.load_model(model_path)
            y_pred = model.predict(test_ds, verbose=1)
            y_pred_avg += y_pred / len(cfg.selected_folds)
            del model
            tf.keras.backend.clear_session()
            gc.collect()
        else:
            print(f"Model for fold {fold} not found at {model_path}")
    
    # Compute metrics
    test_loss = losses.CategoricalCrossentropy()(y_true, y_pred_avg).numpy()
    test_auc = metrics.AUC(multi_label=False)(y_true, y_pred_avg).numpy()
    test_acc = metrics.CategoricalAccuracy()(y_true, y_pred_avg).numpy()
    
    print(f"\nTest Loss: {test_loss:.4f}")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")

In [39]:
test_df = full_df[full_df['split'] == 'test'].reset_index(drop=True)

# Reload spectrograms if needed
spectrograms = np.load(cfg.spectrogram_npy, allow_pickle=True).item()

evaluate_on_test(cfg, spectrograms, test_df)

[1m358/358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 8ms/step

Test Loss: 0.1690
Test AUC: 0.9857
Test Accuracy: 0.9680
