In [1]:
import gc
import os
import sys
import math
import random
import warnings
import ast

import albumentations as A
import cv2
import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import timm
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata

from pathlib import Path
from typing import List

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping

import wandb
from pytorch_lightning.loggers import WandbLogger

from albumentations.pytorch import ToTensorV2
from albumentations.core.transforms_interface import ImageOnlyTransform
from sklearn import model_selection
from sklearn import metrics
from timm.models.layers import SelectAdaptivePool2d
from torch.optim.optimizer import Optimizer
from torchlibrosa.stft import LogmelFilterBank, Spectrogram
from torchlibrosa.augmentation import SpecAugmentation
import torchaudio

# from transformers import AdamW
# from transformers import get_cosine_schedule_with_warmup

In [2]:
# Add the directory two levels up to the import path so that the `src` package
# can be imported (presumably the repository root - confirm).
sys.path.append('../../')
import src.utils as utils  # project-local helpers; `utils.set_seed` is called before training
from config_ import CFG  # experiment configuration object read throughout this notebook

## Config

In [3]:
# Root folder holding the competition data. The default is the original
# author's location; set the BIRDCLEF_DATA_DIR environment variable to point
# somewhere else without editing the notebook.
DATA_DIR = Path(os.environ.get("BIRDCLEF_DATA_DIR", "/home/knikaido/work/BirdCLEF2021/data/"))
# Competition dataset folder; the metadata CSVs and audio are resolved under it.
MAIN_DATA_DIR = DATA_DIR / 'birdclef-2021'
# Checkpoints and trainer artifacts are written here.
OUTPUT_DIR = Path('./output/')

In [4]:
######################
# Data #
######################
# Training inputs, all resolved underneath MAIN_DATA_DIR.
# Audio root: clips are read from <train_datadir>/<primary_label>/<filename>.
train_datadir = MAIN_DATA_DIR.joinpath('train_short_audio')
# Clip-level metadata table.
train_csv = MAIN_DATA_DIR.joinpath('train_metadata.csv')
# Soundscape label table (not read anywhere in the visible cells).
train_soundscape = MAIN_DATA_DIR.joinpath('train_soundscape_labels.csv')

In [5]:
# Flip to True for a quick smoke run: the configured epoch count is
# overridden so that training stops after a single epoch.
DEBUG = False
if DEBUG:
    CFG.epochs = 1

## Dataset

In [6]:
class WaveformDataset(torchdata.Dataset):
    """Dataset that yields fixed-length raw waveforms with their target vectors.

    Every row of ``df`` must provide:

    * ``filename``       - audio file name,
    * ``primary_label``  - sub-folder of ``datadir`` that contains the file,
    * ``primary_labels`` - numeric target vector for the clip.

    Args:
        df: Metadata table, one row per clip. Rows are addressed by position,
            so the index does not need to be a fresh ``RangeIndex``.
        datadir: Root folder of the audio files.
        img_size: Stored for interface compatibility; not used by this class.
        period: Clip length in seconds that every sample is cropped/padded to.
        validation: When True the crop/pad offset is always 0 (deterministic);
            otherwise it is drawn from ``np.random``.
    """

    def __init__(self,
                 df: pd.DataFrame,
                 datadir: Path,
                 img_size=224,
                 period=20,
                 validation=False):
        self.df = df
        self.datadir = datadir
        self.img_size = img_size
        self.period = period
        self.validation = validation

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _crop_or_pad(y: np.ndarray, effective_length: int, validation: bool) -> np.ndarray:
        """Return ``y`` as float32 with exactly ``effective_length`` samples.

        Shorter signals are embedded in a zero buffer, longer ones are cropped.
        The offset is 0 when ``validation`` is true and random otherwise.
        NaN/inf values are replaced via ``np.nan_to_num``.
        """
        len_y = len(y)
        if len_y < effective_length:
            padded = np.zeros(effective_length, dtype=y.dtype)
            start = 0 if validation else np.random.randint(effective_length - len_y)
            padded[start:start + len_y] = y
            y = padded
        elif len_y > effective_length:
            start = 0 if validation else np.random.randint(len_y - effective_length)
            y = y[start:start + effective_length]
        return np.nan_to_num(y.astype(np.float32))

    def __getitem__(self, idx: int):
        """Load clip ``idx`` and return ``{"image": waveform, "targets": labels}``."""
        # Positional lookup: __len__ and the DataLoader sampler are positional,
        # so this must not depend on the DataFrame's index labels.
        sample = self.df.iloc[idx]
        wav_name = sample["filename"]
        ebird_code = sample["primary_label"]

        y, sr = sf.read(self.datadir / ebird_code / wav_name)

        # Cast to int so a float `period` from the config still gives a valid
        # buffer size / slice bound.
        effective_length = int(sr * self.period)
        y = self._crop_or_pad(y, effective_length, self.validation)

        labels = np.asarray(sample["primary_labels"], dtype=np.float32)

        return {
            "image": y,
            "targets": labels
        }

## Model

In [9]:

class BirdModel(nn.Module):
    """Waveform classifier: mel-spectrogram front end followed by a timm backbone.

    The front end parameters (sample rate, FFT size, hop length, mel range and
    number of mel bins) are read from ``CFG``.

    Args:
        base_model_name: Model name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        num_classes: Size of the backbone's classification output.
        in_channels: Forwarded to timm as ``in_chans``. ``forward`` always
            produces a single-channel spectrogram.
    """

    def __init__(self, base_model_name: str, pretrained=False, num_classes=24, in_channels=1):
        super().__init__()
        # NOTE(review): the spectrogram is a magnitude (power=1.0) while
        # AmplitudeToDB is left at its default `stype`; confirm the dB scaling
        # is as intended.
        self.melspectrogramer = torchaudio.transforms.MelSpectrogram(
            sample_rate=CFG.sample_rate,
            n_fft=CFG.n_fft,
            win_length=CFG.n_fft,
            hop_length=CFG.hop_length,
            f_min=CFG.fmin,
            f_max=CFG.fmax,
            n_mels=CFG.n_mels,
            power=1.0)
        self.amplituder = torchaudio.transforms.AmplitudeToDB()

        # The complete timm model (including its own classification head) is
        # used as the encoder. No `fc`/`classifier` attribute is required any
        # more, so backbones with a differently named head also work.
        self.encoder = timm.create_model(
            base_model_name,
            pretrained=pretrained,
            in_chans=in_channels,
            num_classes=num_classes)

    def forward(self, input):
        """Map a batch of waveforms to the backbone's output (logits).

        ``input`` is whatever the mel front end accepts; with the dataset in
        this notebook that is a float tensor of shape (batch, samples).
        """
        # Waveform -> mel spectrogram -> dB scale.
        x = self.melspectrogramer(input)
        x = self.amplituder(x)
        # Insert a channel axis at position 1 for the 2-D backbone.
        x = x.unsqueeze(1)
        # Per-sample standardisation over all non-batch axes; the extra
        # factor 2 on the std is kept from the original recipe.
        x_mean = x.mean(keepdim=True, dim=(1, 2, 3))
        x_std = x.std(keepdim=True, dim=(1, 2, 3))
        x = (x - x_mean) / (x_std * 2 + 1e-6)

        x = self.encoder(x)

        return x

## Training Utils

In [10]:
# Custom optimizer
# Registry of optimizer classes that are not part of ``torch.optim``. Names
# found here take precedence over the attribute of the same name in torch.optim.
__OPTIMIZERS__ = {}


def _resolve_optimizer_class(name: str):
    """Return the class registered under ``name``, else ``torch.optim.<name>``.

    Raises:
        AttributeError: if the name is neither registered nor in ``torch.optim``.
    """
    registered = __OPTIMIZERS__.get(name)
    if registered is not None:
        return registered
    return getattr(optim, name)


def get_optimizer(model: nn.Module):
    """Build the optimizer described by ``CFG`` for ``model``'s parameters.

    ``CFG.optimizer_name`` selects the class and ``CFG.optimizer_params`` is
    passed as keyword arguments. The special name ``"SAM"`` wraps the class
    named by ``CFG.base_optimizer``.
    """
    optimizer_name = CFG.optimizer_name
    if optimizer_name == "SAM":
        # NOTE(review): `SAM` is not defined or imported in the visible part of
        # this notebook; it has to be in scope before this branch is used.
        base_optimizer = _resolve_optimizer_class(CFG.base_optimizer)
        return SAM(model.parameters(), base_optimizer, **CFG.optimizer_params)

    optimizer_cls = _resolve_optimizer_class(optimizer_name)
    return optimizer_cls(model.parameters(), **CFG.optimizer_params)


def get_scheduler(optimizer):
    """Build the LR scheduler named in ``CFG`` for ``optimizer``.

    Returns ``None`` when ``CFG.scheduler_name`` is ``None``; otherwise the
    class of that name from ``torch.optim.lr_scheduler`` is instantiated with
    ``CFG.scheduler_params`` as keyword arguments.
    """
    name = CFG.scheduler_name
    if name is None:
        return None

    scheduler_cls = getattr(optim.lr_scheduler, name)
    return scheduler_cls(optimizer, **CFG.scheduler_params)

## Train

In [11]:
# environment
# Seed through the project helper before any data splitting or model creation.
utils.set_seed(CFG.seed)

# validation
# Look the splitter class up by name in sklearn.model_selection, then build it
# from the keyword arguments given in the config.
splitter_cls = getattr(model_selection, CFG.split)
splitter = splitter_cls(**CFG.split_params)

In [12]:
# Load the clip-level metadata, reusing the `train_csv` path defined in the
# Data cell instead of rebuilding it here.
train = pd.read_csv(train_csv)
# `secondary_labels` is stored as the text form of a Python list; parse each
# entry into a real list.
train['secondary_labels'] = train['secondary_labels'].map(ast.literal_eval)
train

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,date,filename,license,rating,time,url
0,acafly,[amegfi],"['begging call', 'call', 'juvenile']",35.3860,-84.1250,Empidonax virescens,Acadian Flycatcher,Mike Nelson,2012-08-12,XC109605.ogg,Creative Commons Attribution-NonCommercial-Sha...,2.5,09:30,https://www.xeno-canto.org/109605
1,acafly,[],['call'],9.1334,-79.6501,Empidonax virescens,Acadian Flycatcher,Allen T. Chartier,2000-12-26,XC11209.ogg,Creative Commons Attribution-NonCommercial-Sha...,3.0,?,https://www.xeno-canto.org/11209
2,acafly,[],['call'],5.7813,-75.7452,Empidonax virescens,Acadian Flycatcher,Sergio Chaparro-Herrera,2012-01-10,XC127032.ogg,Creative Commons Attribution-NonCommercial-Sha...,3.0,15:20,https://www.xeno-canto.org/127032
3,acafly,[whwbec1],['call'],4.6717,-75.6283,Empidonax virescens,Acadian Flycatcher,Oscar Humberto Marin-Gomez,2009-06-19,XC129974.ogg,Creative Commons Attribution-NonCommercial-Sha...,3.5,07:50,https://www.xeno-canto.org/129974
4,acafly,[whwbec1],['call'],4.6717,-75.6283,Empidonax virescens,Acadian Flycatcher,Oscar Humberto Marin-Gomez,2009-06-19,XC129981.ogg,Creative Commons Attribution-NonCommercial-Sha...,3.5,07:50,https://www.xeno-canto.org/129981
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
62869,yetvir,[],"['adult', 'male', 'song']",30.2150,-97.6505,Vireo flavifrons,Yellow-throated Vireo,Caleb Helsel,2020-07-10,XC591680.ogg,Creative Commons Attribution-NonCommercial-Sha...,1.0,08:30,https://www.xeno-canto.org/591680
62870,yetvir,[],"['life stage uncertain', 'sex uncertain', 'song']",42.3005,-72.5877,Vireo flavifrons,Yellow-throated Vireo,Christopher McPherson,2019-05-31,XC600085.ogg,Creative Commons Attribution-NonCommercial-Sha...,5.0,09:30,https://www.xeno-canto.org/600085
62871,yetvir,"[amered, eawpew, norcar, reevir1]","['adult', 'male', 'song']",42.3005,-72.5877,Vireo flavifrons,Yellow-throated Vireo,Christopher McPherson,2020-06-02,XC602701.ogg,Creative Commons Attribution-NonCommercial-Sha...,4.5,08:30,https://www.xeno-canto.org/602701
62872,yetvir,[],['uncertain'],32.2357,-99.8811,Vireo flavifrons,Yellow-throated Vireo,Brad Banner,2019-04-27,XC614733.ogg,Creative Commons Attribution-NonCommercial-Sha...,4.0,17:30,https://www.xeno-canto.org/614733


In [13]:
# Build a one-hot target vector for every clip from its primary species code.
target_columns = list(CFG.target_columns)
num_targets = len(target_columns)

# Resolve each code to its (first) position once, instead of scanning the
# target list with list.index() for every row.
code_to_index = {}
for position, code in enumerate(target_columns):
    code_to_index.setdefault(code, position)

labels_total = []
for ebird_code in train['primary_label']:
    if ebird_code not in code_to_index:
        # Same exception type that list.index() raised for an unknown code.
        raise ValueError(f"{ebird_code!r} is not in CFG.target_columns")
    labels = np.zeros(num_targets, dtype=float)
    labels[code_to_index[ebird_code]] = 1.0
    labels_total.append(labels)

# One array per row; WaveformDataset reads this column as the training target.
train['primary_labels'] = labels_total

In [14]:
# Learner class (pytorch-lightning)
class Learner(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with a BCE-with-logits loss.

    Batches are dicts with an ``"image"`` entry (model input) and a
    ``"targets"`` entry (multi-hot float targets).

    Args:
        model: Network that maps ``batch["image"]`` to raw logits.
        num_train_steps: Total number of optimisation steps. Stored only; it is
            referenced by the commented-out warm-up scheduler below.
        num_warmup_steps: Warm-up step count, stored for the same reason.
    """

    def __init__(self, model, num_train_steps, num_warmup_steps):
        super().__init__()
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss()
        self.num_train_steps = num_train_steps
        self.num_warmup_steps = num_warmup_steps

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        output = self.model(batch['image'])
        loss = self.criterion(output, batch["targets"])

        self.log('Loss/train', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
#         sch = self.lr_schedulers()
#         sch.step()
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and sample-averaged F1 at two thresholds."""
        output = self.model(batch['image'])
        loss = self.criterion(output, batch["targets"])

        # The model emits raw logits (the loss applies the sigmoid itself), so
        # convert to probabilities before comparing with the 0.3 / 0.5 thresholds.
        clipwise_output_np = to_np(torch.sigmoid(output))
        targets_np = to_np(batch["targets"])
        f1_score_3 = metrics.f1_score(targets_np > 0.5, clipwise_output_np > 0.3, average="samples")
        f1_score_5 = metrics.f1_score(targets_np > 0.5, clipwise_output_np > 0.5, average="samples")

        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('F1_03/val', f1_score_3, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log('F1_05/val', f1_score_5, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_epoch_end(self, outputs):
        """Print the mean of the per-batch validation losses for this epoch."""
        if not outputs:
            # Nothing was validated; torch.stack would fail on an empty list.
            return
        avg_loss = torch.stack(list(outputs)).mean()
        print(f'epoch = {self.current_epoch}, loss = {avg_loss}')

    def configure_optimizers(self):
        """Return the optimizer, plus the scheduler when one is configured."""
        optimizer = get_optimizer(self.model)
        scheduler = get_scheduler(optimizer)
#         optimizer = AdamW(self.model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
#         scheduler = get_cosine_schedule_with_warmup(optimizer, 
#                                                     num_warmup_steps=self.num_warmup_steps, 
#                                                     num_training_steps=self.num_train_steps)
        if scheduler is None:
            # get_scheduler returns None when no scheduler is configured; do not
            # hand Lightning an "lr_scheduler": None entry.
            return optimizer
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [15]:
def to_np(input):
    """Convert a tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU before
    the conversion.
    """
    detached = input.detach()
    on_host = detached.cpu()
    return on_host.numpy()

In [None]:
# Experiment-level run name. Defined before the loop so that it exists even
# when no fold is selected; the summary cell after training reads it as well.
RUN_NAME = f'exp{str(CFG.exp_num)}'

for i, (trn_idx, val_idx) in enumerate(splitter.split(train, y=train["primary_label"])):
    # Only train the folds selected in the config.
    if i not in CFG.folds:
        continue

    # Fresh 0..n-1 index for each split.
    trn_df = train.loc[trn_idx, :].reset_index(drop=True)
    val_df = train.loc[val_idx, :].reset_index(drop=True)

    # One DataLoader per phase; the validation dataset uses deterministic crops.
    loaders = {
        phase: torchdata.DataLoader(
            WaveformDataset(
                df_,
                train_datadir,
                img_size=CFG.img_size,
                period=CFG.period,
                validation=(phase == "valid")
            ),
            **CFG.loader_params[phase])  # type: ignore
        for phase, df_ in zip(["train", "valid"], [trn_df, val_df])
    }

    num_train_steps = int(len(loaders['train']) * CFG.epochs)
    num_warmup_steps = int(num_train_steps / 10)
    model = BirdModel(
        base_model_name=CFG.base_model_name,
        pretrained=CFG.pretrained,
        num_classes=CFG.num_classes,
        in_channels=CFG.in_channels)
    model_name = model.__class__.__name__

    learner = Learner(model, num_train_steps, num_warmup_steps)

    # loggers
    wandb.init(project='BirdCLEF-Experiment', entity='sqrt4kaido', group=RUN_NAME, job_type=RUN_NAME + f'-fold-{i}')
    wandb.run.name = RUN_NAME + f'-fold-{i}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb.watch(model)

    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor='Loss/val',
        mode='min',
        dirpath=OUTPUT_DIR,
        verbose=False,
        # `{epoch}` is left as a template for Lightning to fill in when the
        # checkpoint is written. Interpolating `learner.current_epoch` here
        # would freeze the value it has at construction time.
        filename=f'{model_name}-{{epoch}}-{i}')
    callbacks.append(checkpoint_callback)

#     early_stop_callback = EarlyStopping(
#         monitor='Loss/val',
#         min_delta=0.00,
#         patience=20,
#         verbose=True,
#         mode='min')
#     callbacks.append(early_stop_callback)

    loggers = []
    loggers.append(WandbLogger())

    trainer = pl.Trainer(
        logger=loggers,
        callbacks=callbacks,
        max_epochs=CFG.epochs,
        default_root_dir=OUTPUT_DIR,
        gpus=1,
#         fast_dev_run=DEBUG,
        deterministic=True,
        benchmark=False,
        )

    # Dataloaders are passed positionally because the keyword for the training
    # loader differs between Lightning releases.
    trainer.fit(learner, loaders['train'], loaders['valid'])

    # Keep one final-state checkpoint per fold rather than overwriting a single
    # "last.ckpt" on every iteration.
    trainer.save_checkpoint(OUTPUT_DIR / f"last-fold-{i}.ckpt")

    # Close this fold's W&B run before the next fold opens its own.
    wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33msqrt4kaido[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | BirdModel         | 4.5 M 
1 | criterion | BCEWithLogitsLoss | 0     
------------------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params
18.062    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

epoch = 0, loss = 0.6993155479431152


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

epoch = 0, loss = 0.009357216767966747


Validating: 0it [00:00, ?it/s]

epoch = 1, loss = 0.006860276218503714


Validating: 0it [00:00, ?it/s]

epoch = 2, loss = 0.005821107886731625


Validating: 0it [00:00, ?it/s]

epoch = 3, loss = 0.005257656332105398


Validating: 0it [00:00, ?it/s]

epoch = 4, loss = 0.004868531133979559


Validating: 0it [00:00, ?it/s]

epoch = 5, loss = 0.004503840580582619


Validating: 0it [00:00, ?it/s]

epoch = 6, loss = 0.004305958282202482


Validating: 0it [00:00, ?it/s]

epoch = 7, loss = 0.0042017120867967606


Validating: 0it [00:00, ?it/s]

epoch = 8, loss = 0.004145312123000622


Validating: 0it [00:00, ?it/s]

epoch = 9, loss = 0.004056253004819155


Validating: 0it [00:00, ?it/s]

epoch = 10, loss = 0.004073841497302055


Validating: 0it [00:00, ?it/s]

epoch = 11, loss = 0.004077691584825516


Validating: 0it [00:00, ?it/s]

epoch = 12, loss = 0.004174037370830774


Validating: 0it [00:00, ?it/s]

epoch = 13, loss = 0.004239286761730909


Validating: 0it [00:00, ?it/s]

epoch = 14, loss = 0.004560326691716909


Validating: 0it [00:00, ?it/s]

epoch = 15, loss = 0.004778638016432524


Validating: 0it [00:00, ?it/s]

epoch = 16, loss = 0.004836465697735548


Validating: 0it [00:00, ?it/s]

epoch = 17, loss = 0.004931189119815826


Validating: 0it [00:00, ?it/s]

epoch = 18, loss = 0.004958209581673145


Validating: 0it [00:00, ?it/s]

epoch = 19, loss = 0.005080194212496281


Validating: 0it [00:00, ?it/s]

epoch = 20, loss = 0.005131194833666086


In [None]:
# Open a separate W&B run named "summary" in the same group as the fold runs.
# NOTE(review): RUN_NAME is assigned in the training cell above, so that cell
# must have been executed first. The project name here ('BirdCLEF-2021')
# differs from the one used for the fold runs ('BirdCLEF-Experiment') -
# confirm that this is intended.
wandb.init(project='BirdCLEF-2021', entity='sqrt4kaido', group=RUN_NAME, job_type='summary')
wandb.run.name = 'summary'
# wandb.log({'CV_score': oofs_score})
# Attach the experiment configuration file to the run, then close it.
wandb.save('./config_.py')
wandb.finish()
