In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os 
import librosa
import librosa.display
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
import torch.optim as optim

import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, F1Score
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    EarlyStopping,
)

from birdclef.io import dump_json_file
from birdclef.data_modeling.loader import BirdclefDataset

In [2]:
# Project root. Set the KAGGLE_BIRDCLEF_PATH environment variable to point the
# notebook at another checkout; the fallback is the original local path, so
# existing setups keep working unchanged.
KAGGLE_PATH = os.environ.get(
    "KAGGLE_BIRDCLEF_PATH", "/Users/taa/Documents/kaggle/kaggle_birdclef25"
)
# Which pre-computed spectrogram flavour to load.
SPECTOGRAM_STYLE = "mel"  # "mel" or "stft"

# Spectrogram size per style — presumably (frequency bins, time frames);
# TODO confirm against the preprocessing code.
INPUT_DATA_MODEL = {
    "stft": (1025, 626),
    "mel": (128, 626),
}

# Fail fast on a typo in the style name instead of failing later on a missing folder.
assert SPECTOGRAM_STYLE in INPUT_DATA_MODEL, f"unknown spectrogram style: {SPECTOGRAM_STYLE!r}"

In [3]:
# Folder holding the pre-processed spectrograms for the selected style.
DATA_PATH = os.path.join(
    KAGGLE_PATH,
    "data",
    "processed",
    SPECTOGRAM_STYLE,
)

# Sanity check: count the entries sitting directly under DATA_PATH.
exisiting_files = os.listdir(DATA_PATH)
print("Number of FILES/FOLDERS", len(exisiting_files))

Number of FILES/FOLDERS 3


In [4]:
# Training split; no label maps are passed in, and the dataset's own
# label2id / id2label attributes are reused for the validation split below.
train_ds = BirdclefDataset(data_path=os.path.join(DATA_PATH, "train"), set_type="train")

# Validation split receives the training label maps (presumably so class ids
# agree between the two splits — confirm in BirdclefDataset).
val_ds = BirdclefDataset(
    data_path=os.path.join(DATA_PATH, "val"),
    set_type="val",
    label2id=train_ds.label2id,
    id2label=train_ds.id2label,
)

# Options shared by both loaders; only shuffling differs between them.
loader_kwargs = dict(batch_size=128, num_workers=14, persistent_workers=True)

train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [5]:
# Pull one batch from the training loader to inspect what the model will receive.
ex_sample, ex_label = next(iter(train_dl))

In [6]:
# Display the shapes of the example batch: (inputs, labels).
ex_sample.shape, ex_label.shape

(torch.Size([128, 1, 128, 626]), torch.Size([128]))

# Model

In [7]:
import torch
import torch.nn as nn

In [8]:
class SpecAugment(nn.Module):
    """Zero out one random frequency band and one random time span per sample.

    The input is indexed as ``x[sample, channel, freq, time]``; only channel 0
    is masked. Masking is applied in training mode only; in eval mode the input
    is returned untouched.

    Args:
        time_mask: exclusive upper bound on the masked width along the last axis.
        freq_mask: exclusive upper bound on the masked width along axis 2.
    """

    def __init__(self, time_mask=30, freq_mask=13):
        super().__init__()
        self.time_mask = time_mask
        self.freq_mask = freq_mask

    @staticmethod
    def _draw(high):
        """Return a random int in ``[0, high)``; 0 when ``high`` is not positive."""
        if high <= 0:
            # torch.randint(0, 0) raises, so a disabled mask just means width 0.
            return 0
        return torch.randint(0, high, (1,)).item()

    def forward(self, x):
        """Return a masked copy of ``x`` (or ``x`` itself outside training mode)."""
        if not self.training:
            return x

        # Mask a copy: writing into ``x`` would silently corrupt the caller's batch.
        x = x.clone()
        n_freq, n_time = x.size(2), x.size(3)

        for i in range(x.size(0)):
            # Mask widths.
            t = self._draw(self.time_mask)
            f = self._draw(self.freq_mask)

            # Mask start offsets, chosen so the mask fits when it can.
            t0 = self._draw(max(1, n_time - t))
            f0 = self._draw(max(1, n_freq - f))

            x[i, 0, f0:f0 + f, :] = 0
            x[i, 0, :, t0:t0 + t] = 0

        return x

In [9]:
class BirdCLEFCNN(nn.Module):
    """Small convolutional classifier for single-channel 2-D inputs.

    Three Conv-BatchNorm-SiLU stages (1 -> 32 -> 64 -> 128 channels, the last
    two with stride 2) are followed by global average pooling, dropout and a
    linear layer producing ``num_classes`` logits.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.SiLU(),
        )

        # Plain element-wise dropout: it runs on the pooled (N, 128) feature
        # vector, and Dropout2d on a 2-D tensor is deprecated in PyTorch.
        self.dropout = nn.Dropout(0.3)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a ``(N, 1, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.conv_block(x)
        # (N, 128, 1, 1) -> (N, 128)
        x = self.pool(x).flatten(1)
        x = self.dropout(x)
        return self.fc(x)

In [10]:
# Bare CNN backbone with one output logit per entry in the training label map.
model = BirdCLEFCNN(num_classes=len(train_ds.label2id))

# Loss

In [11]:
class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross-entropy.

    Each sample's cross-entropy ``ce`` is scaled by ``alpha * (1 - p) ** gamma``
    with ``p = exp(-ce)``, and the scaled values are averaged over the batch.

    Args:
        alpha: constant multiplier applied to every sample.
        gamma: exponent of the ``(1 - p)`` modulating factor.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        # Keep one loss value per sample so each can be re-weighted individually.
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        """Return the mean focal loss for logits ``inputs`` and class ids ``targets``."""
        per_sample_ce = self.ce(inputs, targets)
        # exp(-ce) is the probability the model assigned to the true class.
        prob_true = torch.exp(-per_sample_ce)
        weight = self.alpha * (1 - prob_true) ** self.gamma
        focal = weight * per_sample_ce
        return focal.mean()

In [12]:
class LightModel(pl.LightningModule):
    """Lightning wrapper: SpecAugment -> backbone, trained with focal loss.

    Args:
        model: backbone module mapping an input batch to per-class logits.
        n_classes: number of classes, used to configure the validation metrics.
        learning_rate: learning rate handed to Adam.
    """

    def __init__(self, model, n_classes: int, learning_rate=1e-3):
        super().__init__()
        self.spec_augment = SpecAugment()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = FocalLoss()
        self.n_classes = n_classes

    def forward(self, x):
        """Apply SpecAugment, then return the backbone's output."""
        x = self.spec_augment(x)
        return self.model(x)

    def configure_optimizers(self):
        """Adam with weight decay, plus ReduceLROnPlateau stepped per epoch on ``train_loss``."""
        opt = optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=1e-4,
        )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=5, min_lr=1e-6)
        return [opt], [dict(scheduler=lr_scheduler, interval="epoch", monitor="train_loss")]

    def training_step(self, batch, batch_idx):
        """Compute and log the focal loss for one ``(inputs, labels)`` training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y.long())
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, val_batch, batch_idx: int) -> torch.Tensor:
        """Log the validation loss and feed softmax scores to the metrics.

        ``val_batch`` is unpacked as an ``(inputs, labels)`` pair.
        """
        x, y = val_batch
        y_hat = self(x)
        val_loss = self.criterion(y_hat, y.long())
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)

        final_preds = torch.nn.functional.softmax(y_hat, dim=-1)
        self.update_metrics(final_preds, y.long())

        return val_loss

    def on_validation_start(self) -> None:
        """Create fresh weighted multiclass metrics on the module's device."""
        metric_types = (
            ("accuracy", Accuracy),
            ("f1_score", F1Score),
            ("precision", Precision),
            ("recall", Recall),
        )
        # All four metrics share the same configuration, so build them in one place.
        self.eval_metrics = {
            name: metric_cls(
                task="multiclass",
                num_classes=self.n_classes,
                average="weighted",
            ).to(self.device)
            for name, metric_cls in metric_types
        }

    def on_validation_end(self) -> None:
        """Write the aggregated metrics to ``<log_dir>/val_metrics.json`` and reset them.

        The same file name is used on every validation run, so the file always
        holds the latest values. Nothing is written when no logger (or no log
        directory) is available.
        """
        val_metrics = self.compute_metrics()
        # The logger may be None (Trainer(logger=False)) or may not expose a log_dir.
        log_dir = getattr(self.logger, "log_dir", None)
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            dump_json_file(val_metrics, os.path.join(log_dir, "val_metrics.json"))
        self.reset_metrics()

    def update_metrics(
        self,
        pred: torch.Tensor,
        gt: torch.Tensor,
    ):
        """Accumulate one batch of predictions and targets into every metric."""
        for metric in self.eval_metrics.values():
            metric.update(preds=pred, target=gt)

    def compute_metrics(self) -> dict:
        """Return ``{metric name: float}`` computed from everything accumulated so far."""
        return {
            met_name: met.compute().item() for met_name, met in self.eval_metrics.items()
        }

    def reset_metrics(self) -> None:
        """Clear the accumulated state of every metric."""
        for metric in self.eval_metrics.values():
            metric.reset()

In [13]:
# TensorBoard runs are written below the project root.
logger = TensorBoardLogger(
    f"{KAGGLE_PATH}/lightning_logs/",
    name="birdclef_specaugment",
    default_hp_metric=False,
)

# Log the learning rate at every step.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Stop once val_loss has failed to improve by at least 0.01 for 5 validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.01, patience=5, verbose=False, mode="min"
)

# Keep only the single best checkpoint according to val_loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
    filename="{epoch}-{val_loss:.2f}",
)

callbacks = [early_stop_callback, checkpoint_callback, lr_monitor]

In [14]:
# `model` is rebound from the bare CNN to its Lightning wrapper below. On a
# re-run of this cell, unwrap it first; otherwise the wrapper would be nested
# inside another wrapper and SpecAugment would be applied twice.
backbone = model.model if isinstance(model, pl.LightningModule) else model
model = LightModel(model=backbone, n_classes=len(train_ds.id2label), learning_rate=1e-3)

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True,
    check_val_every_n_epoch=2,  # validation (and hence early stopping) runs every 2nd epoch
    callbacks=callbacks,
    logger=logger,
)
trainer.fit(model, train_dl, val_dl)

# To persist and later restore the trained weights:
#   torch.save(model.state_dict(), "birdclef_cnn.pth")
#   n_classes = len(train_ds.label2id)
#   model = LightModel(model=BirdCLEFCNN(num_classes=n_classes), n_classes=n_classes)
#   model.load_state_dict(torch.load("birdclef_cnn.pth"))
#   model.eval()  # switch to evaluation mode before inference

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name         | Type        | Params | Mode 
-----------------------------------------------------
0 | spec_augment | SpecAugment | 0      | train
1 | model        | BirdCLEFCNN | 119 K  | train
2 | criterion    | FocalLoss   | 0      | train
-----------------------------------------------------
119 K     Trainable params
0         Non-trainable params
119 K     Total params
0.478     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  4.29it/s]



Epoch 6:  47%|████▋     | 318/674 [02:49<03:09,  1.87it/s, v_num=0, train_loss_step=3.660, train_loss_epoch=3.550, val_loss=3.470]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined