## Multi-label dataset

This notebook builds a multi-label PyTorch dataset whose extra labels are extracted from the `secondary_labels` or `background` metadata in `train.csv`. It also reports some simple statistics about the multi-labels.  

Since the test dataset is multi-labeled, training with multiple labels may help. However, the main problems in this competition are noise and not knowing the exact time range that corresponds to each label, so this may not be that helpful.

### References:  
Multi-label idea: https://www.kaggle.com/maxwell110/mluticlass-to-muiltilabel  
Dataset scripts: https://www.kaggle.com/ttahara/training-birdsong-baseline-resnest50-fast  

In [None]:
import os
import gc
import time
import shutil
import random
import warnings
import typing as tp
from pathlib import Path
from glob import glob
from datetime import datetime

import cv2
import librosa
import audioread
import soundfile as sf
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import StratifiedKFold

import torch

In [None]:
# Notebook-wide configuration.
VIS = True  # NOTE(review): not referenced in the visible cells; presumably a visualisation toggle - confirm

N_SPLITS = 5  # number of StratifiedKFold splits
FOLD = 0  # fold held out as the validation set
IMSIZE = 224  # default image height used by SpectrogramDataset (img_size)
WORKSPACE = '.'  # NOTE(review): not referenced in the visible cells
SEED = 42  # random_state passed to StratifiedKFold

In [None]:
# Directory layout: all inputs live under "<parent of the working dir>/input".
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT.joinpath("input")
RAW_DATA = INPUT_ROOT.joinpath("birdsong-recognition")
TRAIN_AUDIO_DIR = RAW_DATA.joinpath("train_audio")
TEST_AUDIO_DIR = RAW_DATA.joinpath("test_audio")
# The resampled training audio is spread over five directories suffixed 00..04.
TRAIN_RESAMPLED_AUDIO_DIRS = [
    INPUT_ROOT.joinpath("birdsong-resampled-train-audio-{:0>2}".format(part))
    for part in range(5)
]

In [None]:
# Load train_mod.csv from the first resampled-audio directory.
train = pd.read_csv(TRAIN_RESAMPLED_AUDIO_DIRS[0] / "train_mod.csv")

In [None]:
# Lookup tables between sci_name and ebird_code. The extra labels are
# resolved to ebird_codes through SPS_CODE; INV_SPS_CODE is the reverse map.
SPS_CODE = {
    sci_name: code
    for sci_name, code in zip(train["sci_name"], train["ebird_code"])
}
INV_SPS_CODE = dict(zip(SPS_CODE.values(), SPS_CODE.keys()))

In [None]:
# Maps every target ebird_code to its integer class index (0..263). The index
# is the position of that species in the label vector built by
# SpectrogramDataset.__getitem__.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

# Reverse lookup: class index -> ebird_code.
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
import re

# background
def get_sps_from_bg(s):
    """Return every parenthesised substring found in a ``background`` entry.

    The matches are later looked up as species names via ``map_sps_to_code``.

    Args:
        s: the ``background`` cell value. Non-string values (e.g. a NaN for a
            missing cell) are converted with ``str`` first, so they simply
            yield no matches.

    Returns:
        list[str]: the text inside each ``(...)`` pair, in order of
        appearance (non-greedy, so nested/adjacent pairs are split at the
        first closing parenthesis).
    """
    if not isinstance(s, str):
        s = str(s)
    # Raw string: '\(' in a normal literal is an invalid escape sequence.
    return re.findall(r'\((.*?)\)', s)

# secondary
def get_sps_from_sec(l):
    """Extract the leading name of each item in a ``secondary_labels`` entry.

    The entry is the string form of a list of quoted items, e.g.
    ``"['A b_Common name', 'C d_Other name']"`` (presumably
    ``<scientific name>_<common name>`` - confirm against the CSV).

    Args:
        l: the list written as a string, including the enclosing brackets.

    Returns:
        list[str]: for each item, the text before its first underscore.
        An empty list string (``"[]"``) yields ``['']``, which
        ``map_sps_to_code`` skips.
    """
    # Drop the enclosing brackets, then split the items on commas.
    # Raw string: '\s' in a normal literal is an invalid escape sequence.
    items = re.split(r',\s*', l[1:-1])
    # Strip the quote characters around each item and keep the part before '_'.
    return [item[1:-1].split('_', 1)[0] for item in items]

def map_sps_to_code(l):
    """Translate species names into ebird_codes using ``SPS_CODE``.

    Empty strings are ignored. Names missing from ``SPS_CODE`` (species that
    are not classification targets, hence have no ebird_code) are reported
    on stdout and dropped.

    Args:
        l: iterable of species-name strings.

    Returns:
        list: the ebird_code of every known name, in input order.
    """
    codes = []
    for s in l:
        if s == '':
            continue
        if s not in SPS_CODE:
            # not one of the classification targets
            print(f"{s} doesn't exist.")
            continue
        codes.append(SPS_CODE[s])
    return codes

# Derive two list-valued columns of ebird_codes: first extract the species
# names from each metadata column, then translate them with map_sps_to_code
# (names without an ebird_code are printed and dropped).
train['labels_bg'] = train['background'].apply(get_sps_from_bg).apply(map_sps_to_code)
train['labels_sec'] = train['secondary_labels'].apply(get_sps_from_sec).apply(map_sps_to_code)

In [None]:
# Show the primary label next to the two derived label lists.
train[['ebird_code', 'labels_sec', 'labels_bg']]

In [None]:
# Summary counts: how many recordings carry extra labels, and how many on average.
print(f"Total Entries                  : {train.shape[0]}")
print(f"Entries with secondary_labels  : {(train['labels_sec'].str.len() > 0).sum()}")
print(f"Entries with background        : {(train['labels_bg'].str.len() > 0).sum()}")
print(f"Average secondary_labels length: {train['labels_sec'].map(len).mean()}")

In [None]:
# Print every row whose background-derived labels differ (as a set) from its
# secondary labels; in most rows the two agree, but some bg lists are longer.
for bg_codes, sec_codes in zip(train.labels_bg, train.labels_sec):
    if set(bg_codes) != set(sec_codes):
        print(bg_codes, sec_codes)

In [None]:
PERIOD = 5 # 5s

def mono_to_color(X, mean=None, std=None, eps=1e-6):
    """Convert a single-channel array into a 3-channel uint8 image.

    The input is replicated three times along a new last axis, standardised,
    and then min-max scaled to the range 0..255.

    Args:
        X: numeric array (assumed to be a 2-D spectrogram - confirm with caller).
        mean: value(s) subtracted during standardisation; defaults to the
            mean of the stacked array. An explicit ``0.0`` or an array that
            broadcasts against the stacked input is honoured.
        std: value(s) used as divisor during standardisation; defaults to the
            standard deviation of the stacked array. Same rules as ``mean``.
        eps: added to ``std`` to avoid division by zero, and used as the
            threshold below which the value range counts as constant.

    Returns:
        np.ndarray: uint8 array of shape ``X.shape + (3,)``. If the
        standardised values span no more than ``eps`` the result is all zeros.
    """
    X = np.stack([X, X, X], axis=-1)

    # Standardize. Test against None explicitly: `mean or X.mean()` would
    # discard a legitimate 0.0 and raise on array-valued statistics.
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    # Normalization to 0..255; a (near-)constant input maps to all zeros.
    mn = X.min()
    mx = X.max()
    if mx - mn > eps:
        X = (X - mn) / (mx - mn)
        X = (X * 255).astype(np.uint8)
    else:
        X = np.zeros_like(X, dtype=np.uint8)

    return X


class SpectrogramDataset(torch.utils.data.Dataset):
    """Dataset yielding (mel-spectrogram image, multi-hot label) pairs.

    Each entry of ``file_list`` is ``[file_path, ebird_code, labels_sec]``:
    the audio file to read, its primary label, and a list of additional
    ebird_codes that are also set in the label vector.
    """

    def __init__(self, file_list, img_size=IMSIZE, waveform_transforms=None, spectrogram_transforms=None):
        """Store the file list and the optional transforms.

        Args:
            file_list: sequence of ``[file_path, ebird_code, labels_sec]``.
            img_size: height (in pixels) the spectrogram image is resized to;
                the width is scaled by the same factor.
            waveform_transforms: optional callable applied to the raw
                waveform. When given, it replaces the built-in
                pad / random-crop to ``PERIOD`` seconds.
            spectrogram_transforms: optional callable applied to the
                dB-scaled mel spectrogram.
        """
        self.file_list = file_list  # list of list: [file_path, ebird_code, labels_sec]
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms

    def __len__(self):
        """Return the number of entries in ``file_list``."""
        return len(self.file_list)

    def __getitem__(self, idx: int):
        """Load one recording and return ``(image, labels)``.

        Returns:
            image: float32 array in channel-first layout with values in
                [0, 1], resized so that its height equals ``img_size``.
            labels: float32 vector of length ``len(BIRD_CODE)`` with 1 at the
                primary label and at every secondary label.
        """
        wav_path, ebird_code, labels_sec = self.file_list[idx]

        y, sr = sf.read(wav_path)  # waveform and its sample rate

        if self.waveform_transforms:
            y = self.waveform_transforms(y)
        else:
            # NOTE(review): the padding below assumes 1-D (mono) audio - confirm.
            len_y = len(y)
            effective_length = sr * PERIOD
            if len_y < effective_length:
                # Too short: place the clip at a random offset in a zero buffer.
                new_y = np.zeros(effective_length, dtype=y.dtype)
                start = np.random.randint(effective_length - len_y)
                new_y[start:start + len_y] = y
                y = new_y.astype(np.float32)
            elif len_y > effective_length:
                # Too long: take a random PERIOD-second window.
                start = np.random.randint(len_y - effective_length)
                y = y[start:start + effective_length].astype(np.float32)
            else:
                y = y.astype(np.float32)

        # Pass the audio by keyword: librosa >= 0.10 rejects it positionally.
        melspec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmin=20, fmax=16000)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        if self.spectrogram_transforms:
            melspec = self.spectrogram_transforms(melspec)

        image = mono_to_color(melspec)

        # cv2.resize takes (width, height); keep the aspect ratio.
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)  # HWC -> CHW
        image = (image / 255.0).astype(np.float32)

        # Multi-hot label: primary species plus every secondary species.
        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1
        for sec in labels_sec:
            labels[BIRD_CODE[sec]] = 1

        return image, labels

In [None]:
# Index every file found under the resampled-audio directories that exist:
# one row per file, keyed by its species directory name and file name.
wav_rows = []
for dataset_dir in TRAIN_RESAMPLED_AUDIO_DIRS:
    if dataset_dir.exists():
        for species_dir in dataset_dir.iterdir():
            # plain files at this level (e.g. a csv) are not species folders
            if not species_dir.is_file():
                wav_rows.extend(
                    [species_dir.name, wav.name, wav.as_posix()]
                    for wav in species_dir.iterdir()
                )

train_wav_path_exist = pd.DataFrame(
    wav_rows, columns=["ebird_code", "resampled_filename", "file_path"]
)

del wav_rows

# Keep only the metadata rows whose audio file is actually present on disk.
train_all = train.merge(
    train_wav_path_exist, on=["ebird_code", "resampled_filename"], how="inner"
)

print(train.shape, train_wav_path_exist.shape, train_all.shape, sep="\n")

In [None]:
# Assign every row to exactly one validation fold, stratified by the primary
# label (ebird_code).
skf = StratifiedKFold(n_splits=N_SPLITS, random_state=SEED, shuffle=True)

train_all["fold"] = -1
for fold_id, (train_index, val_index) in enumerate(skf.split(train_all, train_all["ebird_code"])):
    # Address the column by name instead of assuming "fold" is the last
    # column, so the fold ids can never be written into another column.
    train_all.iloc[val_index, train_all.columns.get_loc("fold")] = fold_id

# check the proportion: number of recordings per species in each fold
fold_proportion = pd.pivot_table(train_all, index="ebird_code", columns="fold", values="xc_id", aggfunc=len)
print(fold_proportion.shape, fold_proportion)

In [None]:
# Split into training / validation entries; each entry is
# [file_path, ebird_code, labels_sec], the layout SpectrogramDataset unpacks.
use_fold = FOLD
train_file_list = train_all.loc[
    train_all["fold"] != use_fold, ["file_path", "ebird_code", "labels_sec"]
].values.tolist()
val_file_list = train_all.loc[
    train_all["fold"] == use_fold, ["file_path", "ebird_code", "labels_sec"]
].values.tolist()

print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

In [None]:
# Sanity check: draw the first three training samples and print the indices
# of the labels that are set for each of them.
train_dataset = SpectrogramDataset(train_file_list)

for sample_idx in range(3):
    image, target = train_dataset[sample_idx]
    print(np.argwhere(target))
    plt.imshow(np.transpose(image, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()