## Huge thanks to the starter notebook here! (https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast)
### This notebook draws extensively on that kernel for data I/O and hyperparameters.
### While using it to develop a baseline, I found it necessary to use mixed precision, support different models, and write my own training loop.
### So here is this kernel for anyone who wants to start with more custom control over training.

In [None]:
!pip install ../input/pytorch-pfn-extras/pytorch-pfn-extras-0.2.1/ > /dev/null
!pip install ../input/resnest50-fast-package/resnest-0.0.6b20200701/resnest/ > /dev/null
!pip install efficientnet_pytorch > /dev/null
!pip install timm > /dev/null
!pip install pretrainedmodels > /dev/null

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_sharing_strategy('file_system')

import os
import cv2
import math 
import timm 
import shutil
# import neptune
import argparse 
import numpy as np
import pandas as pd
import pretrainedmodels
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image
from time import time
from tqdm.autonotebook import tqdm
from warnings import filterwarnings
# from sync_batchnorm import convert_model
from efficientnet_pytorch import EfficientNet
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset

from pathlib import Path
import librosa
import audioread
import soundfile as sf
import resnest.torch as resnest_torch

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

In [None]:
%%bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

In [None]:
# Experiment configuration. Later cells read every setting from this dict by key.
args = dict(
    exp='starter_development',   # experiment name; logs go to experiments/<exp>
    model='efficientnet-b5',     # backbone name passed to get_model, e.g. 'resnest101'
    batch_size=32,
    fold=0,                      # index of the validation fold in the 5-fold split
    epochs=40,
    cycle_epochs=10,             # only used by the commented-out cosine scheduler
    img_size=224,                # target image height for the resized spectrogram
    margin=0.2,
    s=16.,                       # default logit scale of normalized_Linear
    fp16=1,                      # 1 -> apex mixed precision (turned into a bool later)
    accum_steps=1,               # batches per optimizer step (gradient accumulation)
    pretrained='none',
    cache_folder=None,
    clear_cache=0,
    lr=6e-4,
    n_cpus=4,                    # DataLoader worker processes
    verbose=True,
    nowarnings=True,
)

In [None]:
### Warnings and preliminary settings
# (Neptune experiment initialisation used to live here; it is disabled.)

# Normalise the fp16 flag to a real bool, then apply global settings.
args['fp16'] = (args['fp16'] == 1)
if args['nowarnings']:
    filterwarnings('ignore')
if args['fp16']:
    # apex is installed by the %%bash cell above.
    from apex import amp

# Fresh logging folder experiments/<exp>; a previous run with the same
# experiment name is deleted first.
if not os.path.isdir('experiments'):
    os.mkdir('experiments')
logging_folder = os.path.join('experiments', args['exp'])
if os.path.isdir(logging_folder):
    shutil.rmtree(logging_folder)
os.mkdir(logging_folder)

# Disabled preprocessing cache, kept for reference:
# if args['cache_folder'] is None:
#     cache_folder = os.path.join('data/cache', str(args['img_size']))
# else:
#     cache_folder = args['cache_folder']
# if args['clear_cache'] != 0 and os.path.exists(cache_folder):
#     shutil.rmtree(cache_folder)
#     print('Cache folder emptied')
# if not os.path.exists(cache_folder):
#     os.mkdir(cache_folder)
#     print('First time caching')
# print('Caching preprocessed data@', cache_folder)

writer = SummaryWriter(logging_folder)
for message in ('For visualization, run:', 'tensorboard --logdir=' + logging_folder):
    print(message)

In [None]:
# Kaggle input locations. The resampled training audio is spread over five
# datasets named birdsong-resampled-train-audio-00 ... -04.
root = '/kaggle/input'
raw_root = os.path.join(root, "birdsong-recognition")
train_raw_dir = os.path.join(raw_root, 'train_audio')
train_resampled_dirs = [
    os.path.join(root, f"birdsong-resampled-train-audio-{part:0>2}")
    for part in range(5)
]
test_raw_dir = os.path.join(raw_root, 'test_audio')

# set(train_df.columns) - set(train_orig_df.columns)
# Unique columns: {'resampled_channels', 'resampled_filename', 'resampled_sampling_rate'}
# Resampling rate is invariably 32000, while resample channels is invariably 1 (mono)

In [None]:
# Metadata tables: the resampled training CSV (shipped with the first
# resampled dataset) and the original competition training CSV.
train_df = pd.read_csv(os.path.join(train_resampled_dirs[0], "train_mod.csv"))
train_orig_df = pd.read_csv(os.path.join(raw_root, 'train.csv'))

# Use the competition test audio when present; otherwise fall back to the
# public 'birdcall-check' dataset (presumably the hidden test set only
# appears when the notebook is re-run for submission).
if os.path.exists(test_raw_dir):
    test_df = pd.read_csv(os.path.join(test_raw_dir, 'test.csv'))
else:
    print('No test found (should be the case)')
    check_root = os.path.join(root, 'birdcall-check')
    test_raw_dir = os.path.join(check_root, 'test_audio')
    test_df = pd.read_csv(os.path.join(check_root, 'test.csv'))

In [None]:
# Model portion

# Swish utilities
from efficientnet_pytorch.utils import Swish, MemoryEfficientSwish
from timm.models.layers.activations import Mish
from timm.models.layers.activations import Swish as Swish_timm 

### Transform model to (memory efficient) swish ###
def to_swish(model):
    """Replace ReLU / Swish-style activations inside `model` with MemoryEfficientSwish.

    The module tree is modified in place. A submodule that matches is swapped
    out and not descended into; every other submodule is searched in turn.
    `model` itself is never replaced. Returns None.
    """
    swappable = (nn.ReLU, Swish, MemoryEfficientSwish, Swish_timm)
    pending = [model]
    while pending:
        parent = pending.pop()
        for name, sub in list(parent.named_children()):
            if isinstance(sub, swappable):
                setattr(parent, name, MemoryEfficientSwish())
            else:
                pending.append(sub)
            
def to_traceswish(model):
    """Swap every MemoryEfficientSwish inside `model` for the plain efficientnet Swish.

    The training loop calls this on a fresh copy right before torch.jit.trace
    (presumably because MemoryEfficientSwish does not trace cleanly).
    Modifies the module tree in place and returns None.
    """
    pending = [model]
    while pending:
        parent = pending.pop()
        for name, sub in list(parent.named_children()):
            if isinstance(sub, MemoryEfficientSwish):
                setattr(parent, name, Swish())
            else:
                pending.append(sub)
            
class normalized_Linear(nn.Module):
    """Cosine-similarity classification layer without bias.

    Inputs and weight rows are L2-normalised along dim 1 (assumes input is
    (batch, in_features) — TODO confirm for other backbones). Each output is
    therefore the cosine between a feature vector and a class weight vector,
    multiplied by the fixed scale ``s`` (default ``args['s']``, read once
    when the class is defined).
    """

    def __init__(self, in_features, out_features, init_s=args['s']):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialised storage; Xavier init fills it a few lines below.
        raw_weight = torch.empty(out_features, in_features)
        self.weight = nn.Parameter(raw_weight)
        self.s = init_s
        # Important: the weights are otherwise uninitialised memory.
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input):
        unit_input = F.normalize(input)
        unit_weight = F.normalize(self.weight)
        cosine = F.linear(unit_input, unit_weight)
        return self.s * cosine

class add_tail(nn.Module):
    """Wrap a feature-extractor backbone with a cosine-similarity classifier head.

    Args:
        backbone: module whose original classifier has been removed; presumably
            maps an input batch to (batch, num_features) features — confirm
            per backbone.
        num_features: width of the features produced by ``backbone``.
        num_classes: number of output logits. Defaults to 264, the number of
            bird species in this competition (size of BIRD_CODE).
    """

    def __init__(self, backbone, num_features, num_classes=264):
        super().__init__()
        self.backbone = backbone
        self.fc = normalized_Linear(num_features, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        return self.fc(features)

# Plain model loading
def get_model(args):
    """Build a pretrained backbone topped with a cosine classifier head.

    ``args['model']`` selects where the backbone comes from:
      * ``efficientnet-b<k>`` -> efficientnet_pytorch ``EfficientNet.from_pretrained``
      * a name in ``pretrainedmodels.model_names`` -> pretrainedmodels (ImageNet weights)
      * a ``resnest*`` function exported by ``resnest.torch`` -> that function
      * anything else -> ``timm.create_model(..., pretrained=True)``
    An unknown name only prints a hint; the selected loader then raises.

    The backbone's final classifier is replaced by an empty ``nn.Sequential``
    (an identity), the backbone is wrapped by ``add_tail`` and every
    ReLU/Swish activation is converted to MemoryEfficientSwish.

    Raises:
        NotImplementedError: unknown ``efficientnet-*`` variant, or a timm model
            without a top-level ``nn.Linear`` classifier.
    """
    model_name = args['model']
    efficientnet_names = ['efficientnet-b' + str(i) for i in range(8)]
    resnest_names = ['resnest' + str(depth) for depth in [50, 101, 200, 269]]
    total_names = efficientnet_names + pretrainedmodels.model_names \
        + timm.list_models() + resnest_names
    if model_name not in total_names:
        print('Nope! Available models are:', total_names)

    # Load from pretrained first.
    # Match efficientnet_pytorch names by prefix (not substring) so that timm
    # names such as 'efficientnet_b0' / 'tf_efficientnet_b0' reach timm.
    if model_name.startswith('efficientnet-'):
        try:
            # num_classes MUST be a keyword: the second positional parameter of
            # from_pretrained is `advprop` (0.6.x) or `weights_path` (0.7.x).
            # The head is discarded below, so its size does not matter.
            backbone = EfficientNet.from_pretrained(model_name, num_classes=10)
        except (ValueError, KeyError) as err:
            print('efficientnet-bx x~[0-7] please')
            raise NotImplementedError(model_name) from err
        num_features = backbone._fc.weight.shape[1]
        backbone._fc = nn.Sequential()
    elif model_name in pretrainedmodels.model_names:
        backbone = pretrainedmodels.__dict__[model_name](pretrained='imagenet')
        num_features = backbone.last_linear.weight.shape[1]
        backbone.last_linear = nn.Sequential()
    elif 'resnest' in model_name and hasattr(resnest_torch, model_name):
        # Only names the resnest package actually exports; timm variants such
        # as 'resnest50d' fall through to timm below.
        backbone = getattr(resnest_torch, model_name)(pretrained=True)
        num_features = backbone.fc.weight.shape[1]
        backbone.fc = nn.Sequential()
    else:
        backbone = timm.create_model(model_name, pretrained=True)
        # The classifier is assumed to be the last top-level nn.Linear child.
        for child_name, child in list(backbone.named_children())[::-1]:
            if isinstance(child, nn.Linear):
                num_features = child.weight.shape[1]
                setattr(backbone, child_name, nn.Sequential())
                break
        else:
            raise NotImplementedError(
                f'No top-level nn.Linear classifier found on timm model {model_name!r}')
    model = add_tail(backbone, num_features)
    to_swish(model)
    return model

# Build the configured network and move it to the GPU (nn.Module.cuda
# returns the module itself).
model = get_model(args)
model = model.cuda()

In [None]:
# eBird species code -> class index (0..263, 264 classes). The index is the
# position of the species in the one-hot label vector built by
# SpectrogramDataset and in the 264-way output of the model head.
# The codes appear to be listed alphabetically. Never change an existing
# index: trained checkpoints depend on this exact mapping.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}
# Reverse lookup: class index -> eBird code (presumably used to turn model
# output indices back into species codes).
INV_BIRD_CODE = dict(zip(BIRD_CODE.values(), BIRD_CODE.keys()))

In [None]:
# Number of seconds for each crop
PERIOD = 5


def mono_to_color(
    X, mean=None, std=None,
    norm_max=None, norm_min=None, eps=1e-6
):
    """Turn a single-channel spectrogram into a 3-channel uint8 image.

    The array is replicated into three identical channels (new last axis),
    standardised with ``mean``/``std``, clipped to ``[norm_min, norm_max]``
    and linearly rescaled to ``[0, 255]``.

    Args:
        X: array to convert (in this notebook, a dB-scaled mel spectrogram).
        mean, std: standardisation statistics; ``None`` means "compute from X".
        norm_max, norm_min: clipping bounds applied after standardisation;
            ``None`` means "use the max / min of the standardised array".
        eps: added to ``std`` before dividing, and the minimum value range
            that is considered non-degenerate.

    Returns:
        ``np.uint8`` array of shape ``X.shape + (3,)``. It is all zeros when the
        standardised data or the clipping range spans no more than ``eps``.
    """
    X = np.stack([X, X, X], axis=-1)

    # `is None` rather than `or`: an explicit 0 / 0.0 (e.g. mean=0 or
    # norm_min=0) is a legitimate value and must not be replaced by a
    # statistic of the data.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min
    # The second condition guards the division below when explicit bounds
    # coincide (with default bounds it is identical to the first).
    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        # Clip to [norm_min, norm_max], then rescale to [0, 255].
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Constant input: nothing to normalise.
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V

class SpectrogramDataset(Dataset):
    """Map-style dataset yielding fixed-length log-mel spectrogram images.

    Each item of ``file_list`` is ``[wav_path, ebird_code]``. ``__getitem__``:
      1. reads the audio with soundfile;
      2. applies ``waveform_transforms`` if given, otherwise takes a random
         ``PERIOD``-second window (zero-padding shorter clips);
      3. computes a mel power spectrogram and converts it to dB;
      4. applies ``spectrogram_transforms`` if given;
      5. turns it into a 3-channel image resized to height ``img_size``
         (aspect ratio kept) and scaled to [0, 1].

    Returns ``{'input': float32 (3, img_size, W), 'label': float32 one-hot of
    length len(BIRD_CODE)}``.
    """

    def __init__(
        self, file_list, img_size=224,
        waveform_transforms=None, spectrogram_transforms=None, melspectrogram_parameters=None
    ):
        self.file_list = file_list  # list of list: [file_path, ebird_code]
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms
        # Fresh dict per instance instead of a shared mutable default argument.
        self.melspectrogram_parameters = (
            {} if melspectrogram_parameters is None else melspectrogram_parameters)

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _crop_or_pad(signal, effective_length):
        """Return exactly ``effective_length`` float32 samples taken from ``signal``.

        A shorter signal is placed at a random offset inside a zero buffer; a
        longer one is randomly cropped. Every valid offset, including the one
        flush with the end, can be drawn.
        """
        len_signal = len(signal)
        if len_signal < effective_length:
            padded_signal = np.zeros(effective_length, dtype=signal.dtype)
            # np.random.randint's upper bound is exclusive, hence the +1.
            start = np.random.randint(effective_length - len_signal + 1)
            padded_signal[start:start + len_signal] = signal
            return padded_signal.astype(np.float32)
        if len_signal > effective_length:
            start = np.random.randint(len_signal - effective_length + 1)
            return signal[start:start + effective_length].astype(np.float32)
        return signal.astype(np.float32)

    def __getitem__(self, idx: int):
        wav_path, ebird_code = self.file_list[idx]
        # signal: audio samples; sample_rate: sampling rate stored in the file.
        signal, sample_rate = sf.read(wav_path)

        if self.waveform_transforms:
            signal = self.waveform_transforms(signal)
        else:
            signal = self._crop_or_pad(signal, sample_rate * PERIOD)

        # `y=` by keyword: newer librosa releases make it keyword-only.
        melspec = librosa.feature.melspectrogram(
            y=signal, sr=sample_rate, **self.melspectrogram_parameters)
        # Power spectrogram -> decibels.
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        if self.spectrogram_transforms:
            melspec = self.spectrogram_transforms(melspec)

        # (H, W) -> (H, W, 3) uint8, then resize to height img_size keeping
        # the aspect ratio (cv2 takes dsize as (width, height)).
        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        # HWC -> CHW, scaled to [0, 1].
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return {'input': image, 'label': labels}

In [None]:
# Index every resampled wav present on disk, laid out as
# <dataset dir>/<ebird_code>/<wav file>.
tmp_list = []
for audio_dir in train_resampled_dirs:
    if not os.path.exists(audio_dir):
        continue
    for ebird_name in os.listdir(audio_dir):
        ebird_dir = os.path.join(audio_dir, ebird_name)
        if os.path.isdir(ebird_dir):
            for wav_name in os.listdir(ebird_dir):
                wav_path = os.path.join(ebird_dir, wav_name)
                tmp_list.append([ebird_name, wav_name, wav_path])

train_wav_path_exist = pd.DataFrame(
    tmp_list, columns=["ebird_code", "resampled_filename", "file_path"])

# Inner join: keep only metadata rows whose audio file was found above.
train_full_df = pd.merge(
    train_df, train_wav_path_exist, on=["ebird_code", "resampled_filename"], how="inner")

print(train_df.shape)
print(train_wav_path_exist.shape)
print(train_full_df.shape)

# Create splits: 5 folds stratified by species so every fold keeps roughly
# the same class balance.
skf = StratifiedKFold(n_splits=5, random_state=2020, shuffle=True)

train_full_df["fold"] = -1
fold_col = train_full_df.columns.get_loc("fold")
for fold_id, (train_index, val_index) in enumerate(skf.split(train_full_df, train_full_df["ebird_code"])):
    # val_index is positional: write into the named "fold" column instead of
    # assuming it is the last column (it is not if train_df already had one).
    train_full_df.iloc[val_index, fold_col] = fold_id

use_fold = args['fold']
train_file_list = train_full_df.query("fold != @use_fold")[["file_path", "ebird_code"]].values.tolist()
val_file_list = train_full_df.query("fold == @use_fold")[["file_path", "ebird_code"]].values.tolist()
print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

# Check the proportion: recordings per species per fold, to confirm the
# stratification works.
fold_proportion = pd.pivot_table(train_full_df, index="ebird_code", columns="fold", values="xc_id", aggfunc=len)
fold_proportion

In [None]:
# Datasets and loaders for the selected fold.
# Mel-spectrogram settings forwarded to librosa.feature.melspectrogram:
# 128 mel bands between 20 Hz and 16 kHz.
melspectrogram_parameters = dict(n_mels=128, fmin=20, fmax=16000)

trainset, valset = (
    SpectrogramDataset(file_list, img_size=args['img_size'],
                       melspectrogram_parameters=melspectrogram_parameters)
    for file_list in (train_file_list, val_file_list)
)

# Shared DataLoader settings; only the training loader shuffles.
loader_kwargs = dict(pin_memory=False, batch_size=args['batch_size'],
                     num_workers=args['n_cpus'])
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
valloader = DataLoader(valset, shuffle=False, **loader_kwargs)

In [None]:
# Optimizer and scheduler.
# AdamW plus one OneCycle LR schedule spanning the whole run. The training
# loop calls scheduler.step() once per batch, hence epochs * batches per epoch.
steps_per_epoch = len(trainloader)
op = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    op, max_lr=args['lr'], total_steps=args['epochs'] * steps_per_epoch)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(op, len(trainloader)*args['cycle_epochs'])
if args['fp16']:
    # apex mixed precision at opt level "O2" (see apex docs for its semantics).
    model, op = amp.initialize(model, op, opt_level='O2', verbosity=0)

In [None]:
'''Criteria'''
import math
from sklearn.metrics import *
    
# Assumes both pytorch tensors
def metric(preds, labels):
    """Macro-averaged average precision (mAP) over classes.

    Args:
        preds: (n_samples, n_classes) tensor of logits; sigmoid is applied here.
        labels: (n_samples, n_classes) tensor of 0/1 targets.

    Returns:
        Mean of sklearn ``average_precision_score`` over the classes that have
        at least one positive label, or ``nan`` if no class has any.
        AP is undefined for a class without positives (sklearn warns and the
        value is meaningless). Such classes are skipped so they cannot drag
        the mean down, or turn it into nan, on small subsets.
    (A row-wise / per-sample AP is a possible alternative.)
    """
    # float(): avoid half-precision CPU ops if fp16 logits are passed in.
    preds = preds.detach().float().sigmoid().cpu().numpy()
    labels = labels.detach().float().cpu().numpy()
    present = np.flatnonzero(labels.sum(axis=0) > 0)
    if present.size == 0:
        return float('nan')
    return np.mean([average_precision_score(labels[:, i], preds[:, i]) for i in present])

# Loss function (plain, unweighted multi-label BCE).
def criterion(preds, labels, reduction='mean'):
    """Binary cross-entropy on logits for multi-label targets.

    ``labels`` is first cast to the tensor type of ``preds`` (dtype and
    CPU/CUDA), so integer or float targets both work. ``reduction`` is
    forwarded to ``F.binary_cross_entropy_with_logits``.
    """
    targets = labels.type_as(preds)
    return F.binary_cross_entropy_with_logits(preds, targets, reduction=reduction)

In [None]:
'''training loop'''
from collections import deque

# Main loop. Each epoch: one training pass over `trainloader` (optional apex
# mixed precision and gradient accumulation), then a full validation pass.
# Whenever the validation metric equals the best seen so far, a TorchScript
# trace of the model is added to a rolling queue of the 5 most recent bests
# and every queued trace is (re)written into `logging_folder`.
global_steps = 0
ckpts_Q = deque(maxlen=5)
val_losses, val_metrics = [], []
for epoch in range(args['epochs']):
    init_time = time()

    i1 = trainloader
    if args['verbose']:
        i1 = tqdm(i1)

    # Training
    model.train()
    train_losses = []
    train_preds, train_labels = [], []
    for batch in i1:
        x, labels = batch['input'].cuda(), batch['label'].cuda()
        preds = model(x)
        loss = criterion(preds, labels)

        # Divide by accum_steps so accumulated gradients average over batches.
        if args['fp16']:
            with amp.scale_loss((loss / args['accum_steps']), op) as scaled_loss:
                scaled_loss.backward()
        else:
            (loss / args['accum_steps']).backward()

        if (global_steps + 1) % args['accum_steps'] == 0:
            op.step()
            op.zero_grad()

        global_steps += 1

        # Per-step logging. Predictions are only collected here; the train
        # mAP is computed once per epoch below. Recomputing it over the whole
        # growing history at every step made each epoch quadratic in the
        # number of batches.
        train_preds.append(preds.detach().cpu())
        train_labels.append(labels.cpu())
        train_losses.append(loss.item())
        lr_value = np.max([group['lr'] for group in op.param_groups])
        writer.add_scalar('lr', lr_value, global_steps)
        if args['verbose']:
            i1.set_postfix({'train_loss': round(np.mean(train_losses), 3), 'lr': lr_value})
        # OneCycleLR was sized as epochs * len(trainloader): one step per batch.
        scheduler.step()
    train_loss_value = np.mean(train_losses)
    # Epoch-level train mAP over all of this epoch's training predictions.
    train_metric_value = metric(torch.cat(train_preds), torch.cat(train_labels))

    # Validation
    model.eval()
    with torch.no_grad():
        l, p = [], []
        i1 = valloader
        if args['verbose']:
            i1 = tqdm(i1)
        for batch in i1:
            x, labels = batch['input'].cuda(), batch['label']
            preds = model(x)
            l.append(labels)
            p.append(preds.cpu())

        labels, preds = torch.cat(l), torch.cat(p)
        val_loss_value, val_metric_value = criterion(preds, labels).item(), metric(preds, labels)

        print(f"{epoch+1} loss:{round(train_loss_value, 4)} metric:{round(train_metric_value, 4)} val_loss:{round(val_loss_value, 4)} val_metric:{round(val_metric_value, 4)} time:{round(time()-init_time, 4)}")
        init_time = time()
        val_losses.append(val_loss_value)
        val_metrics.append(val_metric_value)

        writer.add_scalars('loss', {'train': train_loss_value, 'val': val_loss_value}, global_steps)
        writer.add_scalars('metric', {'train': train_metric_value, 'val': val_metric_value}, global_steps)

        if val_metric_value == max(val_metrics):
            print('New best')
            # Trace a freshly built (non-amp) copy carrying the current weights,
            # with MemoryEfficientSwish swapped for the traceable Swish.
            # `x` is the last validation batch, used as the example input.
            model_ = get_model(args).cuda()
            model_.load_state_dict(model.state_dict())
            to_traceswish(model_)
            model_.eval()
            ckpt = torch.jit.trace(model_, x).cpu()
            ckpts_Q.append(ckpt)
            for i, ckpt_ in enumerate(ckpts_Q):
                torch.jit.save(ckpt_, os.path.join(logging_folder, args['exp']+'_'+str(i)+'.pt'))