This notebook reuses code from kkiller's excellent BirdCLEF 2021 competition notebook:
https://www.kaggle.com/kneroma/clean-fast-simple-bird-identifier-training-colab

**Please upvote it as well!**

Unlike the original, this notebook uses PyTorch Lightning to simplify the training code.

Another note: this notebook trains the model on all 152 available classes. However, only 21 classes are scored, so I may restrict the model to those classes later.

## Install and import packages

In [None]:
# Install ResNeSt; its `resnest.torch` module is imported below and provides the backbone.
!pip install -q resnest

In [None]:
import numpy as np
import pandas as pd
import json
import joblib
from pathlib import Path
from ast import literal_eval
from tqdm.notebook import tqdm

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torchmetrics import F1

import resnest.torch as resnest_torch

In [None]:
# ---- Reproducibility -------------------------------------------------------
SEED = 42

# ---- Inputs from the precomputed mel-spectrogram dataset -------------------
MEL_PATHS = sorted(
    Path("../input").glob("birdclef-2022-melspectrogram-compute/rich_train_metadata.csv")
)
TRAIN_LABEL_PATHS = sorted(
    Path("../input").glob("birdclef-2022-melspectrogram-compute/LABEL_IDS.json")
)

# ---- Labels and audio ------------------------------------------------------
N_CLASSES = 152  # every class in the training metadata (only a subset is scored)
SR = 32000  # sampling rate; the dataset derives audio_length = DURATION * SR
DURATION = 7

# Maximum number of entries kept along the first axis of each cached .npy array.
MAX_READ_SAMPLES = 5

# ---- Cross-validation ------------------------------------------------------
USE_FOLD = 0

# ---- Data loading ----------------------------------------------------------
TRAIN_BATCH_SIZE, TRAIN_NUM_WORKERS = 100, 2
VAL_BATCH_SIZE, VAL_NUM_WORKERS = 128, 2

# ---- Optimisation ----------------------------------------------------------
EPOCHS = 15

## Process dataset

In [None]:
def get_df(mel_paths=MEL_PATHS, train_label_paths=TRAIN_LABEL_PATHS):
    """
    Load the training metadata and the label -> id mapping.

    Arguments:
        mel_paths {iterable of Path} -- metadata CSV files (first column is the
            index). Each row gets an ``impath`` pointing at
            ``<csv dir>/audio_images/<filename>.npy``.
        train_label_paths {iterable of Path} -- JSON files whose top-level
            objects are merged into one dict (later files win on key clashes).

    Returns:
        (dict, pandas.DataFrame or None) -- the merged label mapping and the
        concatenated metadata with ``secondary_labels`` parsed by
        ``ast.literal_eval`` and an added ``impath`` column. The frame is
        None when ``mel_paths`` is empty.
    """
    frames = []
    for file_path in mel_paths:
        file_path = Path(file_path)
        temp = pd.read_csv(str(file_path), index_col=0)
        image_dir = file_path.parent / "audio_images"
        temp["impath"] = [image_dir / "{}.npy".format(name) for name in temp["filename"]]
        # Parse each file's own column exactly once. Parsing the accumulated
        # frame inside the loop would re-run literal_eval on already-parsed
        # lists as soon as a second CSV is added, and raise ValueError.
        temp["secondary_labels"] = temp["secondary_labels"].apply(literal_eval)
        frames.append(temp)

    # pd.concat replaces DataFrame.append (removed in pandas 2.0); the original
    # per-file indices are kept, as before.
    df = pd.concat(frames) if frames else None

    label_ids = {}
    for file_path in train_label_paths:
        with open(str(file_path)) as f:
            label_ids.update(json.load(f))

    return label_ids, df

In [None]:
def get_model(name, num_classes=N_CLASSES,
              resnest_weights="../input/resnest50/resnest50-528c19ca.pth"):
    """
    Build a backbone and replace its final layer with a fresh
    ``num_classes``-way ``nn.Linear``.
    Supports ResNeSt, ResNeXt-WSL, EfficientNet (timm / efficientnet_pytorch),
    torchvision ResNeXt/ResNet and ``pretrainedmodels`` architectures.

    Arguments:
        name {str} -- Name of the model to load

    Keyword Arguments:
        num_classes {int} -- Size of the new output layer (default: {N_CLASSES})
        resnest_weights {str} -- Local checkpoint loaded into ResNeSt models
            (default: the resnest50 weights under ../input/resnest50)

    Returns:
        torch model -- Backbone with a new, randomly initialised output layer
    """
    if "resnest" in name:
        # Downloading the weights fails with HTTP 403 here, so build the
        # architecture without weights and load a local checkpoint instead.
        model = getattr(resnest_torch, name)(pretrained=False)
        # map_location="cpu" allows loading even if the checkpoint was saved
        # from a device that is unavailable on this machine.
        model.load_state_dict(torch.load(resnest_weights, map_location="cpu"))
    elif "wsl" in name:
        model = torch.hub.load("facebookresearch/WSL-Images", name)
    elif name.startswith("resnext") or name.startswith("resnet"):
        model = torch.hub.load("pytorch/vision:v0.6.0", name, pretrained=True)
    # NOTE(review): `timm`, `EfficientNet` and `pretrainedmodels` are not
    # imported in the visible cells; the three branches below raise NameError
    # unless those packages are imported first.
    elif name.startswith("tf_efficientnet_b"):
        model = getattr(timm.models.efficientnet, name)(pretrained=True)
    elif "efficientnet-b" in name:
        model = EfficientNet.from_pretrained(name)
    else:
        model = pretrainedmodels.__dict__[name](pretrained='imagenet')

    # Replace whichever classification head attribute the architecture exposes.
    if hasattr(model, "fc"):
        nb_ft = model.fc.in_features
        model.fc = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "_fc"):
        nb_ft = model._fc.in_features
        model._fc = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "classifier"):
        nb_ft = model.classifier.in_features
        model.classifier = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "last_linear"):
        nb_ft = model.last_linear.in_features
        model.last_linear = nn.Linear(nb_ft, num_classes)

    return model

In [None]:
def load_data(df, n_jobs=4):
    """
    Load the cached array for every row of ``df`` in parallel.

    Each row's ``impath`` .npy file is loaded and truncated to its first
    ``MAX_READ_SAMPLES`` entries along axis 0.

    Arguments:
        df {pandas.DataFrame} -- metadata with ``filename`` and ``impath`` columns
        n_jobs {int} -- number of joblib workers (default: {4})

    Returns:
        dict -- filename -> loaded array. If several rows share a filename,
        the array from the last such row is kept.
    """
    def load_row(row):
        return row.filename, np.load(str(row.impath))[:MAX_READ_SAMPLES]

    mapper = joblib.delayed(load_row)
    tasks = [mapper(row) for row in df.itertuples(index=False)]
    # joblib returns results in task order, so dict() keeps the last duplicate.
    return dict(joblib.Parallel(n_jobs=n_jobs)(tqdm(tasks)))

In [None]:
 pl.seed_everything(SEED)
    
LABEL_IDS, df = get_df()
    
# We cache the train set to reduce training time
audio_image_store = load_data(df)
len(audio_image_store)

## Define and train model

In [None]:
class BirdClefDataset(Dataset):
    """
    Map-style dataset over the cached per-recording arrays.

    Each item is one entry (along axis 0) of the array cached for a row's
    ``filename``, replicated to 3 channels and divided by 255, paired with
    the row's integer ``label_id``. Training picks a random entry on every
    access; evaluation always picks the first entry, so validation metrics
    are deterministic.
    """
    def __init__(self, audio_image_store, meta, sr=SR, is_train=True, num_classes=N_CLASSES, duration=DURATION):
        self.audio_image_store = audio_image_store
        self.meta = meta.copy().reset_index(drop=True)
        self.sr = sr
        self.is_train = is_train
        self.num_classes = num_classes
        self.duration = duration
        self.audio_length = self.duration*self.sr

    @staticmethod
    def normalize(image):
        """Cast to float32, divide by 255 and stack into a (3, *image.shape) array."""
        # NOTE(review): the /255 scaling assumes 0-255 valued inputs — confirm
        # against the preprocessing notebook.
        image = image.astype("float32", copy=False) / 255.0
        image = np.stack([image, image, image])
        return image

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        row = self.meta.iloc[idx]
        stack = self.audio_image_store[row.filename]

        if self.is_train:
            # A random entry per access acts as a light augmentation.
            image = stack[np.random.choice(len(stack))]
        else:
            # Deterministic for evaluation: random picks would make the
            # validation loss/F1 fluctuate between epochs for no real reason.
            image = stack[0]
        image = self.normalize(image)

        # Class index target, as expected by F.cross_entropy.
        t = row.label_id

        return image, t

In [None]:
class BirdClefModel(pl.LightningModule):
    """
    LightningModule wrapping a ``get_model`` backbone for single-label
    classification trained with cross-entropy, validated with macro-F1.
    """
    def __init__(self, name, n_classes):
        super().__init__()
        self.model = get_model(name, n_classes)
        # Stateful metric: accumulates predictions across the validation epoch.
        self.f1 = F1(num_classes=n_classes, average='macro')

    def forward(self, x):
        # The backbone expects a 4-D (batch, channels, height, width) input.
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D input, got shape {tuple(x.shape)}")
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Calling the metric updates its epoch-level state and returns the
        # score for this batch only.
        batch_f1 = self.f1(logits, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        # Log the metric object, not the batch value. Lightning then computes
        # macro-F1 over the whole epoch and resets the metric. Averaging
        # per-batch macro-F1 scores is not the epoch macro-F1, and the
        # checkpoint callback ranks on this value.
        self.log("val_f1_score", self.f1, on_epoch=True, prog_bar=True)
        return {'val_loss': loss, 'val_f1_score': batch_f1}

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Cosine decay from 1e-3 to 1e-5 over EPOCHS scheduler steps
        # (one step per epoch by default).
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, eta_min=1e-5, T_max=EPOCHS)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Only consulted by metric-driven schedulers such as
                # ReduceLROnPlateau; harmless for CosineAnnealingLR.
                "monitor": "val_loss",
            },
        }

In [None]:
# Select the train/validation split for USE_FOLD. val_set holds the index
# labels of that fold's rows and train_set holds every other index label.
# Both are later passed to df.iloc, which assumes df has a 0..n-1 index.
fold_bar = tqdm(df.reset_index().groupby("fold").index.apply(list).items(), total=df.fold.max()+1)

for fold, val_set in fold_bar:
    if fold != USE_FOLD:
        continue

    print(f"\n############################### [FOLD {fold}]")
    fold_bar.set_description(f"[FOLD {fold}]")
    train_set = np.setdiff1d(df.index, val_set)
    # Stop here. Continuing would rebind val_set to a later fold's indices,
    # whose rows are still in train_set: validation data would leak into
    # training, and the wrong fold would be evaluated.
    break
else:
    raise ValueError(f"USE_FOLD={USE_FOLD} does not match any value in df.fold")

In [None]:
model = BirdClefModel("resnest50", N_CLASSES)

# Training split: shuffled batches, pinned host memory for faster transfers.
train_data = BirdClefDataset(
    audio_image_store,
    meta=df.iloc[train_set].reset_index(drop=True),
    sr=SR,
    duration=DURATION,
    is_train=True,
)
train_dataloader = DataLoader(
    train_data,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=TRAIN_NUM_WORKERS,
    pin_memory=True,
)

# Validation split: fixed order.
val_data = BirdClefDataset(
    audio_image_store,
    meta=df.iloc[val_set].reset_index(drop=True),
    sr=SR,
    duration=DURATION,
    is_train=False,
)
val_dataloader = DataLoader(
    val_data,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=VAL_NUM_WORKERS,
)

In [None]:
# Keep a single checkpoint named "best": the epoch with the highest val_f1_score.
chk_callback = ModelCheckpoint(
    monitor='val_f1_score',
    mode='max',
    save_top_k=1,
    save_last=False,
    filename='best',
)

trainer = pl.Trainer(
    gpus=-1,  # -1: use every available GPU
    max_epochs=EPOCHS,
    callbacks=[chk_callback],
)
trainer.fit(model, train_dataloader, val_dataloader)