# BirdCLEF 2023 data preprocessing

## Importing libraries

In [1]:
import gc
import os
import time
import copy
import math
from collections import defaultdict

import numpy as np
import pandas as pd

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
from torch import optim

from tqdm import tqdm
import timm
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from torchvision import transforms, utils

## Config

In [2]:
CONFIG = {
    "DATA_DIR": "./birdclef-2023/",
    "TRAIN_DIR": os.path.join("./birdclef-2023/train_audio"),
    "TEST_DIR": os.path.join("./birdclef-2023/test_soundscapes"),
    "SAMPLE_RATE": 32_000,
    "NUM_CLASSES": 264,
    "N_FOLDS": 5,
    "DURATION": 5,
    "MODEL_NAME": "tf_efficientnet_b4_ns",
    "EMBEDDING_SIZE": 768,
    "DEVICE": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "TRAIN_BATCH_SIZE": 16,
    "VALID_BATCH_SIZE": 16,
    "LEARNING_RATE": 1e-3,
    "WEIGHT_DECAY": 1e-6,
    "NUM_EPOCHS": 10,
    "N_MELS": 224,
    "N_FFT": 1024,
    "HOP_LENGTH": 512,
}

## Loading data

In [3]:
train_metadata = pd.read_csv(os.path.join(CONFIG["DATA_DIR"], "train_metadata.csv"))

In [4]:
train_metadata.head()

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename
0,abethr1,[],['song'],4.3906,38.2788,Turdus tephronotus,African Bare-eyed Thrush,Rolf A. de By,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/128013,abethr1/XC128013.ogg
1,abethr1,[],['call'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/363501,abethr1/XC363501.ogg
2,abethr1,[],['song'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/363502,abethr1/XC363502.ogg
3,abethr1,[],['song'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,5.0,https://www.xeno-canto.org/363503,abethr1/XC363503.ogg
4,abethr1,[],"['call', 'song']",-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,4.5,https://www.xeno-canto.org/363504,abethr1/XC363504.ogg


In [5]:
encoder = LabelEncoder()
all_labels = sorted(train_metadata['primary_label'].values)
encoder.fit(all_labels)

train_metadata["primary_label"] = encoder.fit_transform(train_metadata["primary_label"])

In [6]:
stratified_kfold = StratifiedKFold(n_splits=CONFIG["N_FOLDS"])

for fold, (train, valid) in enumerate(stratified_kfold.split(X=train_metadata, y=train_metadata["primary_label"])):
    train_metadata.loc[valid, "kfold"] = fold



In [7]:
train_metadata

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename,kfold
0,0,[],['song'],4.3906,38.2788,Turdus tephronotus,African Bare-eyed Thrush,Rolf A. de By,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/128013,abethr1/XC128013.ogg,0.0
1,0,[],['call'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/363501,abethr1/XC363501.ogg,0.0
2,0,[],['song'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/363502,abethr1/XC363502.ogg,0.0
3,0,[],['song'],-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,5.0,https://www.xeno-canto.org/363503,abethr1/XC363503.ogg,1.0
4,0,[],"['call', 'song']",-2.9524,38.2921,Turdus tephronotus,African Bare-eyed Thrush,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,4.5,https://www.xeno-canto.org/363504,abethr1/XC363504.ogg,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
16936,263,[],[''],-1.2502,29.7971,Eurillas latirostris,Yellow-whiskered Greenbul,András Schmidt,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://xeno-canto.org/703472,yewgre1/XC703472.ogg,4.0
16937,263,[],[''],-1.2489,29.7923,Eurillas latirostris,Yellow-whiskered Greenbul,András Schmidt,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://xeno-canto.org/703485,yewgre1/XC703485.ogg,4.0
16938,263,[],[''],-1.2433,29.7844,Eurillas latirostris,Yellow-whiskered Greenbul,András Schmidt,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://xeno-canto.org/704433,yewgre1/XC704433.ogg,4.0
16939,263,[],[''],0.0452,36.3699,Eurillas latirostris,Yellow-whiskered Greenbul,Lars Lachmann,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://xeno-canto.org/752974,yewgre1/XC752974.ogg,4.0


## Dataset

In [8]:
class BirdCLEFDataset(Dataset):
    def __init__(self, df, sample_rate, duration, image_transforms=None):
        self.audio_paths = df["filename"].values
        self.labels = df['primary_label'].values
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_samples = self.sample_rate * self.duration
        self.image_transforms = image_transforms
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=512,
            n_mels=128
        )

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        label = torch.tensor(self.labels[idx])

        audio, sample_rate = torchaudio.load(os.path.join(CONFIG["TRAIN_DIR"], audio_path))

        audio = self.mono_audio(audio)
        audio = self.resample(audio, sample_rate)

        if audio.shape[0] > self.num_samples:
            audio = self.crop_audio(audio)

        if audio.shape[0] < self.num_samples:
            audio = self.pad_audio(audio)

        mel_spectrograms = self.mel_spectrogram(audio)
        image = torch.stack([mel_spectrograms, mel_spectrograms, mel_spectrograms])
        image = self.normalize(image)

        return image, label

    @staticmethod
    def mono_audio(audio):
        audio = torch.mean(audio, dim=0)
        return audio

    def resample(self, audio, sample_rate):
        if sample_rate != self.sample_rate:
            resampler = Resample(sample_rate, self.sample_rate)
            audio = resampler(audio)
        return audio

    @staticmethod
    def normalize(image):
        max_val = torch.max(torch.abs(image))
        if max_val > 0:
            image = image / max_val
        return image

    def crop_audio(self, audio):
        return audio[:self.num_samples]

    def pad_audio(self, audio):
        pad_length = self.num_samples - audio.shape[0]
        last_dim_padding = (0, pad_length)
        audio = nn.functional.pad(audio, last_dim_padding)
        return audio

## GEM Pooling

In [9]:
class GeM(torch.nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = torch.nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        return torch.nn.functional.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)

    def __repr__(self):
        return self.__class__.__name__ + \
            '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
            ', ' + 'eps=' + str(self.eps) + ')'

## Model

In [10]:
class BirdCLEFModel(nn.Module):
    def __init__(self, embedding_size, model_name, pretrained=True):
        super(BirdCLEFModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.embedding = nn.Linear(self.in_features, embedding_size)
        self.fc = nn.Linear(embedding_size, CONFIG["NUM_CLASSES"])

    def forward(self, image):
        features = self.model(image)
        pooled_features = self.pooling(features).flatten(1)
        embedding = self.embedding(pooled_features)
        output = self.fc(embedding)
        return output

In [11]:
model = BirdCLEFModel(embedding_size=CONFIG["EMBEDDING_SIZE"], model_name=CONFIG["MODEL_NAME"])
model.to(CONFIG["DEVICE"]);

  model = create_fn(


## Training

In [12]:
def criterion(outputs, targets):
    return nn.CrossEntropyLoss()(outputs, targets)

In [13]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()

    running_loss = 0.0

    bar = tqdm(dataloader, position=0)
    for idx, (image, label) in enumerate(bar):
        image = image.to(device)
        label = label.to(device)

        outputs = model(image)
        _, preds = torch.max(outputs, 1)

        loss = criterion(outputs, label)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        scheduler.step()

        running_loss += (loss.item())

        bar.set_description(f'Epoch [{epoch+1}/{CONFIG["NUM_EPOCHS"]}]')
        bar.set_postfix(loss=loss.item())

    gc.collect()

    return running_loss / len(dataloader)

In [14]:
@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()

    running_loss = 0

    LABELS = []
    PREDS = []

    bar = tqdm(dataloader, position=0)
    for images, labels in bar:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        loss = criterion(outputs, labels)

        running_loss += loss.item()

        PREDS.extend(preds.view(-1).cpu().detach().numpy())
        LABELS.extend(labels.view(-1).cpu().detach().numpy())

        bar.set_description(f'Epoch [{epoch+1}/{CONFIG["NUM_EPOCHS"]}]')
        bar.set_postfix(loss=loss.item())

    valid_f1 = f1_score(LABELS, PREDS, average='macro')

    return running_loss/len(dataloader), valid_f1

In [15]:
def prepare_loaders(df, fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    train_dataset = BirdCLEFDataset(df_train, sample_rate=CONFIG["SAMPLE_RATE"], duration=CONFIG["DURATION"])
    valid_dataset = BirdCLEFDataset(df_valid, sample_rate=CONFIG["SAMPLE_RATE"], duration=CONFIG["DURATION"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["TRAIN_BATCH_SIZE"],
                              num_workers=2, shuffle=True, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG["VALID_BATCH_SIZE"],
                              num_workers=2, shuffle=False, pin_memory=True)

    return train_loader, valid_loader

In [16]:
train_loader, valid_loader = prepare_loaders(train_metadata, fold=0)



In [17]:
optimizer = optim.Adam(model.parameters(), lr=CONFIG["LEARNING_RATE"],
                       weight_decay=CONFIG["WEIGHT_DECAY"])

In [18]:
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

In [19]:
best_valid_f1 = 0
for epoch in range(CONFIG["NUM_EPOCHS"]):
    train_loss = train_one_epoch(model, optimizer, scheduler, train_loader, CONFIG["DEVICE"], epoch)
    valid_loss, valid_f1 = valid_one_epoch(model, valid_loader, CONFIG["DEVICE"], epoch)
    if valid_f1 > best_valid_f1:
        print(f"Validation F1 Improved - {best_valid_f1} ---> {valid_f1}")
        torch.save(model.state_dict(), f'.models/model_0.bin')
        print(f"Saved model checkpoint at .models/model_0.bin")
        best_valid_f1 = valid_f1

Epoch [1/10]: 100%|██████████| 847/847 [03:56<00:00,  3.58it/s, loss=2.27]
Epoch [1/10]: 100%|██████████| 212/212 [00:56<00:00,  3.72it/s, loss=4.47] 


Validation F1 Improved - 0 ---> 0.052026190043652505
Saved model checkpoint at ./model_0.bin


Epoch [2/10]: 100%|██████████| 847/847 [04:12<00:00,  3.36it/s, loss=3.84]
Epoch [2/10]: 100%|██████████| 212/212 [00:57<00:00,  3.67it/s, loss=1.93]  


Validation F1 Improved - 0.052026190043652505 ---> 0.12624179538717217
Saved model checkpoint at ./model_0.bin


Epoch [3/10]: 100%|██████████| 847/847 [04:15<00:00,  3.32it/s, loss=2.1]  
Epoch [3/10]: 100%|██████████| 212/212 [01:00<00:00,  3.51it/s, loss=2.05]  


Validation F1 Improved - 0.12624179538717217 ---> 0.20213310865110704
Saved model checkpoint at ./model_0.bin


Epoch [4/10]: 100%|██████████| 847/847 [04:13<00:00,  3.35it/s, loss=2.06] 
Epoch [4/10]: 100%|██████████| 212/212 [00:54<00:00,  3.87it/s, loss=2.1]   


Validation F1 Improved - 0.20213310865110704 ---> 0.23084842906897984
Saved model checkpoint at ./model_0.bin


Epoch [5/10]: 100%|██████████| 847/847 [04:14<00:00,  3.32it/s, loss=1.37] 
Epoch [5/10]: 100%|██████████| 212/212 [00:58<00:00,  3.64it/s, loss=1.7]    


Validation F1 Improved - 0.23084842906897984 ---> 0.272302951451076
Saved model checkpoint at ./model_0.bin


Epoch [6/10]: 100%|██████████| 847/847 [04:14<00:00,  3.33it/s, loss=1.12] 
Epoch [6/10]: 100%|██████████| 212/212 [00:58<00:00,  3.65it/s, loss=3.78]  
Epoch [7/10]: 100%|██████████| 847/847 [04:13<00:00,  3.34it/s, loss=1.61]  
Epoch [7/10]: 100%|██████████| 212/212 [00:57<00:00,  3.69it/s, loss=2.85]   


Validation F1 Improved - 0.272302951451076 ---> 0.28619918243486697
Saved model checkpoint at ./model_0.bin


Epoch [8/10]: 100%|██████████| 847/847 [04:10<00:00,  3.38it/s, loss=1.01]  
Epoch [8/10]: 100%|██████████| 212/212 [00:55<00:00,  3.84it/s, loss=2.59] 
Epoch [9/10]: 100%|██████████| 847/847 [04:16<00:00,  3.30it/s, loss=0.367] 
Epoch [9/10]: 100%|██████████| 212/212 [00:56<00:00,  3.72it/s, loss=3.71]  


Validation F1 Improved - 0.28619918243486697 ---> 0.3104315388381563
Saved model checkpoint at ./model_0.bin


Epoch [10/10]: 100%|██████████| 847/847 [04:11<00:00,  3.37it/s, loss=0.351] 
Epoch [10/10]: 100%|██████████| 212/212 [00:55<00:00,  3.80it/s, loss=2.56] 


In [20]:
print(best_valid_f1)

0.3104315388381563
