In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.backends.cudnn.benchmark = True
torch.multiprocessing.set_sharing_strategy('file_system')

import os
import cv2
import math 
import shutil
# import neptune
import argparse 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image
from time import time
from tqdm.autonotebook import tqdm
from warnings import filterwarnings
# from sync_batchnorm import convert_model
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset

from pathlib import Path
import librosa
import audioread
import soundfile as sf

pd.options.display.max_rows = 500
pd.options.display.max_columns = 500

In [None]:
# Run configuration shared by every cell below.
args = dict(
    exp='starter_development',
    batch_size=32,
    fold=0,
    img_size=224,
    lr=6e-4,
    n_cpus=4,
    verbose=True,
    nowarnings=True,

    test_sample_rate=32000,
    # Whether to override the optimal thresholds found via CV
    override_threshold=None,
)

# Load the serialized TorchScript module for this experiment and move it to the GPU.
# NOTE(review): the "_4" suffix presumably identifies a checkpoint/fold - confirm.
model = torch.jit.load(
    f"../input/modularized-resnest-training-with-mixed-precision/experiments/{args['exp']}/{args['exp']}_4.pt"
).cuda()

In [None]:
### Warnings and preliminary settings

# Silence every Python warning when the run configuration asks for it.
if args['nowarnings']:
    filterwarnings(action='ignore')

In [None]:
# Locations of the competition data and of the resampled training audio.
root = '/kaggle/input'
raw_root = os.path.join(root, "birdsong-recognition")
train_raw_dir = os.path.join(raw_root, 'train_audio')
test_raw_dir = os.path.join(raw_root, 'test_audio')

# The resampled training audio is spread over five datasets suffixed 00..04.
train_resampled_dirs = [
    os.path.join(root, f"birdsong-resampled-train-audio-{part:0>2}")
    for part in range(5)
]

# Compared with the original train.csv, the resampled metadata adds the columns
# {'resampled_channels', 'resampled_filename', 'resampled_sampling_rate'}.
# The resampling rate is always 32000 and the channel count is always 1 (mono).

In [None]:
# Training metadata: the resampled variant (with the extra resampled_* columns)
# and the original competition file.
train_df = pd.read_csv(os.path.join(train_resampled_dirs[0], "train_mod.csv"))
train_orig_df = pd.read_csv(os.path.join(raw_root, 'train.csv'))

# When the competition test audio is not mounted, fall back to the small
# 'birdcall-check' stand-in dataset so the inference code can still be exercised.
if not os.path.exists(test_raw_dir):
    print('No test found (should be the case)')
    test_raw_dir = os.path.join(root, 'birdcall-check', 'test_audio')

# test.csv is expected beside the test_audio/ directory (this is the layout the
# fallback branch has always used). The previous code looked for it *inside*
# test_audio/ on the real-test path; that location is kept only as a fallback.
test_csv_path = os.path.join(os.path.dirname(test_raw_dir), 'test.csv')
legacy_test_csv_path = os.path.join(test_raw_dir, 'test.csv')
if not os.path.exists(test_csv_path) and os.path.exists(legacy_test_csv_path):
    test_csv_path = legacy_test_csv_path
test_df = pd.read_csv(test_csv_path)

In [None]:
# Mapping from ebird code (species identifier string) to the integer class index
# used for the label vector and for the model's output columns.
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}
# Reverse lookup: class index -> ebird code, used to turn predicted indices
# back into species names when writing the submission.
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [None]:
# Length, in seconds, of each audio window turned into one spectrogram image
# (random training/validation crops as well as test-time segments).
PERIOD = 5

def mono_to_color(
    X, mean=None, std=None,
    norm_max=None, norm_min=None, eps=1e-6
):
    """Turn a single-channel array into a 3-channel uint8 image.

    The input is replicated along a new last axis, standardized, and then
    min-max scaled to the range [0, 255].

    Args:
        X: numeric array; callers in this file pass a 2-D mel spectrogram.
        mean: value subtracted before scaling. Defaults to the sample mean.
        std: divisor used for standardization. Defaults to the standard
            deviation of the centered sample.
        norm_max: upper clipping bound applied to the standardized values.
            Defaults to their maximum.
        norm_min: lower clipping bound applied to the standardized values.
            Defaults to their minimum.
        eps: small constant guarding the divisions and the "flat input" check.

    Returns:
        ``np.uint8`` array of shape ``X.shape + (3,)``. An input with no
        dynamic range (or an empty clipping range) yields all zeros.
    """
    # Replicate the single channel three times along a new trailing axis.
    X = np.stack([X, X, X], axis=-1)

    # Standardize. `is None` is used instead of `x or default` so that an
    # explicitly supplied 0 / 0.0 is honoured rather than silently replaced.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        # Clip to the requested range, then rescale to [0, 255].
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Flat input (or degenerate clipping range): nothing to scale.
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V

class SpectrogramDataset(Dataset):
    """Dataset yielding a mel-spectrogram image and one-hot label per audio file.

    Each item is read from disk, cut or padded to ``PERIOD`` seconds, converted
    to a log-scaled mel spectrogram and rendered as a 3-channel image.
    """

    def __init__(
        self, file_list, img_size=224,
        waveform_transforms=None, spectrogram_transforms=None, melspectrogram_parameters=None
    ):
        """
        Args:
            file_list: list of ``[file_path, ebird_code]`` pairs.
            img_size: height of the produced image; the width is scaled to
                keep the spectrogram's aspect ratio.
            waveform_transforms: optional callable applied to the raw signal.
                When given, it replaces the built-in crop/pad step.
            spectrogram_transforms: optional callable applied to the dB-scaled
                mel spectrogram.
            melspectrogram_parameters: keyword arguments forwarded to
                ``librosa.feature.melspectrogram``; ``None`` means no extras.
        """
        self.file_list = file_list  # list of list: [file_path, ebird_code]
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms
        # A fresh dict per instance: a `{}` default argument would be one
        # object shared by every dataset created without explicit parameters.
        self.melspectrogram_parameters = (
            {} if melspectrogram_parameters is None else melspectrogram_parameters
        )

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx: int):
        """Return ``{'input': image, 'label': labels}`` for the file at ``idx``.

        ``image`` is a float32 channel-first array scaled to [0, 1]; ``labels``
        is a float vector of length ``len(BIRD_CODE)`` with a single 1.
        """
        wav_path, ebird_code = self.file_list[idx]
        signal, sample_rate = sf.read(wav_path)

        if self.waveform_transforms:
            signal = self.waveform_transforms(signal)
        else:
            len_signal = len(signal)
            effective_length = sample_rate * PERIOD
            if len_signal < effective_length:
                # Shorter than PERIOD seconds: place the signal at a random
                # offset inside a zero-filled window.
                padded_signal = np.zeros(effective_length, dtype=signal.dtype)
                start = np.random.randint(effective_length - len_signal)
                padded_signal[start:start + len_signal] = signal
                signal = padded_signal.astype(np.float32)
            elif len_signal > effective_length:
                # Longer than PERIOD seconds: take a random crop.
                start = np.random.randint(len_signal - effective_length)
                signal = signal[start:start + effective_length].astype(np.float32)
            else:
                signal = signal.astype(np.float32)

        # The audio is passed by keyword: recent librosa releases no longer
        # accept it positionally, and older releases accept the keyword too.
        melspec = librosa.feature.melspectrogram(y=signal, sr=sample_rate,
                                                 **self.melspectrogram_parameters)
        # Power spectrogram -> decibel scale.
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        if self.spectrogram_transforms:
            melspec = self.spectrogram_transforms(melspec)

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # Resize to the target height, preserving the aspect ratio.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        # Channel-last (H, W, C) -> channel-first (C, H, W).
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return {'input':image, 'label':labels}

In [None]:
# Index every resampled wav that is actually present on disk.
wav_records = []
for audio_dir in train_resampled_dirs:
    if not os.path.exists(audio_dir):
        continue
    for ebird_name in os.listdir(audio_dir):
        ebird_dir = os.path.join(audio_dir, ebird_name)
        if not os.path.isdir(ebird_dir):
            continue
        for wav_name in os.listdir(ebird_dir):
            wav_records.append([ebird_name, wav_name, os.path.join(ebird_dir, wav_name)])

train_wav_path_exist = pd.DataFrame(
    wav_records, columns=["ebird_code", "resampled_filename", "file_path"])

# Keep only the metadata rows whose audio file was found, attaching its path.
train_full_df = pd.merge(
    train_df, train_wav_path_exist, on=["ebird_code", "resampled_filename"], how="inner")

print(train_df.shape)
print(train_wav_path_exist.shape)
print(train_full_df.shape)

# Create splits, stratified by species so each fold sees every class.
skf = StratifiedKFold(n_splits=5, random_state=2020, shuffle=True)

train_full_df["fold"] = -1
# Locate the column by name: a positional -1 would hit the wrong column if the
# metadata already carried a "fold" column somewhere other than last.
fold_col = train_full_df.columns.get_loc("fold")
for fold_id, (train_index, val_index) in enumerate(skf.split(train_full_df, train_full_df["ebird_code"])):
    train_full_df.iloc[val_index, fold_col] = fold_id

use_fold = args['fold']
train_file_list = train_full_df.query("fold != @use_fold")[["file_path", "ebird_code"]].values.tolist()
val_file_list = train_full_df.query("fold == @use_fold")[["file_path", "ebird_code"]].values.tolist()
print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

# Recording counts per species and fold, to check that stratification worked.
fold_proportion = pd.pivot_table(train_full_df, index="ebird_code", columns="fold", values="xc_id", aggfunc=len)
fold_proportion

In [None]:
# Mel-spectrogram settings handed to the datasets (and reused at test time).
melspectrogram_parameters = dict(n_mels=128, fmin=20, fmax=16000)

# Datasets for the selected fold.
trainset = SpectrogramDataset(
    train_file_list,
    img_size=args['img_size'],
    melspectrogram_parameters=melspectrogram_parameters,
)
valset = SpectrogramDataset(
    val_file_list,
    img_size=args['img_size'],
    melspectrogram_parameters=melspectrogram_parameters,
)

# Loaders differ only in shuffling: training is shuffled, validation is not.
trainloader = DataLoader(
    trainset,
    batch_size=args['batch_size'],
    shuffle=True,
    num_workers=args['n_cpus'],
    pin_memory=False,
)
valloader = DataLoader(
    valset,
    batch_size=args['batch_size'],
    shuffle=False,
    num_workers=args['n_cpus'],
    pin_memory=False,
)

In [None]:
'''Criteria'''
import math
from sklearn.metrics import roc_auc_score, roc_curve, auc
    
def metric(preds, labels):
    """Placeholder validation metric: the negated loss, as a Python float.

    Both arguments are expected to be PyTorch tensors (logits and targets).
    """
    loss = criterion(preds, labels)
    return -loss.item()

def criterion(preds, labels, reduction='mean'):
    """Binary cross-entropy on raw logits.

    ``labels`` is cast to the type of ``preds`` first; ``reduction`` is passed
    straight through to ``F.binary_cross_entropy_with_logits``.
    """
    targets = labels.type_as(preds)
    loss = F.binary_cross_entropy_with_logits(preds, targets, reduction=reduction)
    return loss

In [None]:
# Local validation: run the model over the held-out fold, collecting the raw
# logits and the targets, then score them.
model.eval()
label_batches, pred_batches = [], []
val_iter = tqdm(valloader) if args['verbose'] else valloader
with torch.no_grad():
    for batch in val_iter:
        # Predictions are moved back to the CPU batch by batch so that only one
        # batch of outputs is held on the GPU at a time.
        pred_batches.append(model(batch['input'].cuda()).cpu())
        label_batches.append(batch['label'])

labels, preds = torch.cat(label_batches), torch.cat(pred_batches)
val_loss_value, val_metric_value = criterion(preds, labels).item(), metric(preds, labels)
# NumPy copies for the sklearn-based analysis below; logits become probabilities.
labels_, preds_ = labels.numpy(), preds.sigmoid().numpy()

In [None]:
# NumPy versions of the validation targets and of the sigmoid probabilities
# (the same conversion that ends the validation cell).
labels_ = labels.numpy()
preds_ = preds.sigmoid().numpy()

In [None]:
from sklearn.metrics import *

def binarize(preds, threshold):
    """Return a copy of ``preds`` with values below ``threshold`` set to 0 and
    values at or above it set to 1. The input array is left untouched."""
    below = preds < threshold
    at_or_above = preds >= threshold
    result = preds.copy()
    result[below] = 0
    result[at_or_above] = 1
    return result

def find_optimal_f1_thresh(preds, labels, thresholds=None, default_thresh=0.5):
    """Pick the decision threshold that maximizes F1 over a coarse grid.

    Args:
        preds: 1-D array of predicted probabilities for one class.
        labels: 1-D array of binary ground-truth labels for the same class.
        thresholds: iterable of candidate thresholds. Defaults to the
            heuristic grid between 0.15 and 0.75 that was previously
            hard-coded.
        default_thresh: threshold returned when no candidate reaches an F1
            above zero (for example when the class has no positives).

    Returns:
        ``(best_thresh, best_f1)``. ``best_f1`` is 0 when nothing scored.
    """
    if thresholds is None:
        thresholds = (.15, .2, .225, .25, .275, .3, .325, .35, .375, .4,
                      .425, .45, .475, .5, .55, .6, .65, .7, .75)

    # The fallback must not be 0: every probability satisfies `p >= 0`, so a
    # class whose F1 is 0 everywhere would then be predicted on every clip.
    best_thresh, best_f1 = default_thresh, 0
    for thresh in thresholds:
        f1 = f1_score(labels, binarize(preds, thresh))
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh
    return best_thresh, best_f1

In [None]:
# Per-class analysis: find the F1-optimal threshold for every class, then
# summarize average precision, thresholds and F1 across classes.
n_classes = labels_.shape[-1]
cls_thresholds, cls_f1 = [], []
for class_idx in tqdm(range(n_classes)):
    best_thresh, best_f1 = find_optimal_f1_thresh(preds_[:, class_idx], labels_[:, class_idx])
    cls_thresholds.append(best_thresh)
    cls_f1.append(best_f1)

# Mean area under PR curve averaged over classes
ap_scores = [average_precision_score(labels_[:, i], preds_[:, i]) for i in range(n_classes)]
fig, ax = plt.subplots()
ax.hist(ap_scores)
ax.set(title="Per-class average precision", xlabel="average precision", ylabel="classes")
plt.show()
print(np.mean(ap_scores))

# Both distributions share one axes; transparency and a legend keep them apart.
fig, ax = plt.subplots()
ax.hist(cls_thresholds, alpha=0.6, label="best threshold")
ax.hist(cls_f1, alpha=0.6, label="best F1")
ax.set(title="Per-class best threshold and F1", xlabel="value", ylabel="classes")
ax.legend()
plt.show()
print(np.mean(cls_thresholds), np.mean(cls_f1))

In [None]:
# A configured override replaces every per-class threshold with one global value.
override_threshold = args['override_threshold']
if override_threshold is not None:
    cls_thresholds = np.add(np.zeros_like(cls_thresholds), override_threshold)

In [None]:
### Inference ###

# Write the sample submission first as a placeholder; it is overwritten by the
# real predictions if everything below runs to completion.
sub = pd.read_csv("../input/birdsong-recognition/sample_submission.csv")
sub.to_csv(path_or_buf="submission.csv", index=False)

In [None]:
class TestDataset(Dataset):
    """Dataset over the test rows that belong to one already-decoded audio clip.

    Rows with ``site == "site_3"`` yield a stack of images covering the whole
    clip in consecutive ``PERIOD``-second windows; every other row yields a
    single image for the ``PERIOD`` seconds ending at the row's ``seconds``.
    """

    def __init__(self, df, clip, img_size, melspectrogram_parameters):
        """
        Args:
            df: rows to predict for this clip; accessed with ``df.loc[idx, :]``
                and expected to provide ``site``, ``row_id`` and ``seconds``.
            clip: 1-D waveform array of the clip.
            img_size: height of the produced images.
            melspectrogram_parameters: keyword arguments forwarded to
                ``librosa.feature.melspectrogram``.
        """
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters

    def __len__(self):
        """Number of rows to predict for this clip."""
        return len(self.df)

    def _to_image(self, y, sr):
        """Convert a waveform segment into a float32 channel-first image in [0, 1]."""
        # The audio is passed by keyword: recent librosa releases no longer
        # accept it positionally, and older releases accept the keyword too.
        melspec = librosa.feature.melspectrogram(y=y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # Resize to the target height, preserving the aspect ratio.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        # Channel-last (H, W, C) -> channel-first (C, H, W).
        image = np.moveaxis(image, 2, 0)
        return (image / 255.0).astype(np.float32)

    def __getitem__(self, idx: int):
        """Return ``{'input': image(s), 'row_id': ..., 'site': ...}`` for row ``idx``."""
        SR = args['test_sample_rate']
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id

        if site == "site_3":
            y = self.clip.astype(np.float32)
            window = SR * PERIOD
            # Only complete windows are used; a trailing partial window is
            # dropped, and a clip shorter than one window yields no images.
            images = [
                self._to_image(y[start:start + window], SR)
                for start in range(0, len(y) - window + 1, window)
            ]
            images = np.asarray(images)
            return {'input':images, 'row_id':row_id, 'site':site}
        else:
            # `seconds` marks the end of the window to classify.
            end_seconds = int(sample.seconds)
            start_seconds = int(end_seconds - PERIOD)

            start_index = SR * start_seconds
            end_index = SR * end_seconds

            y = self.clip[start_index:end_index].astype(np.float32)
            image = self._to_image(y, SR)

            return {'input':image, 'row_id':row_id, 'site':site}

In [None]:
def prediction_for_clip(test_df, clip):
    """Predict the species present for every test row of a single audio clip.

    Args:
        test_df: test rows of one clip, indexed 0..n-1.
        clip: 1-D waveform array of that clip.

    Returns:
        dict mapping ``row_id`` to a space-separated string of ebird codes, or
        to ``"nocall"`` when no class reaches its threshold. For whole-clip
        rows (sites other than site_1/site_2) a class is reported if any
        window of the clip reaches its threshold.
    """
    dataset = TestDataset(df=test_df,
                          clip=clip,
                          img_size=args['img_size'],
                          melspectrogram_parameters=melspectrogram_parameters)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args['n_cpus'])

    # One threshold per class; broadcasts against (n_windows, n_classes).
    thresholds = np.asarray(cls_thresholds)
    # Number of windows sent through the model at once.
    chunk_size = 32

    model.eval()
    prediction_dict = {}
    with torch.no_grad():
        for batch in loader:
            row_id, site = batch['row_id'][0], batch['site'][0]
            windows = batch['input']
            if site not in {"site_1", "site_2"}:
                # Whole-clip rows arrive as (1, n_windows, C, H, W); drop the
                # loader's batch axis so that axis 0 enumerates the windows.
                windows = windows.squeeze(0)

            detected = set()
            # Only one chunk at a time is moved to the GPU, which bounds memory
            # use for long clips. A clip with no complete window has size 0
            # here and therefore skips the loop.
            for start in range(0, windows.size(0), chunk_size):
                chunk = windows[start:start + chunk_size].cuda()
                proba = model(chunk).sigmoid().cpu().numpy()
                events = proba >= thresholds
                detected.update(np.argwhere(events.any(axis=0)).reshape(-1).tolist())

            if detected:
                # Sorted so the output string is deterministic for every site.
                prediction_dict[row_id] = " ".join(INV_BIRD_CODE[idx] for idx in sorted(detected))
            else:
                prediction_dict[row_id] = "nocall"
    return prediction_dict

In [None]:
# Decode each test clip once and predict all of its rows together.
unique_audio_id = test_df.audio_id.unique()

row_ids, birds = [], []
for audio_id in tqdm(unique_audio_id):
    clip, sample_rate = librosa.load(os.path.join(test_raw_dir, f"{audio_id}.mp3"),
                                     sr=args['test_sample_rate'],
                                     mono=True,
                                     res_type="kaiser_fast")

    # A boolean mask instead of a string-built query: it keeps working when an
    # id contains a quote character or was parsed as a non-string value.
    test_df_for_audio_id = test_df[test_df.audio_id == audio_id].reset_index(drop=True)
    prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                          clip=clip)
    row_ids.extend(prediction_dict.keys())
    birds.extend(prediction_dict.values())

# Built once from flat lists, so an empty test set gives an empty frame with
# the right columns instead of an error from concatenating nothing.
prediction_df = pd.DataFrame({
    "row_id": row_ids,
    "birds": birds
})

In [None]:
# Persist the final predictions, replacing the placeholder file, then display them.
prediction_df.to_csv(path_or_buf="submission.csv", index=False)
prediction_df