# Overvew
- stride mask inference を評価したい
- 評価するフレームワークの作成

# CONST

In [1]:
NB = '048'
DEBUG = False
PERIOD = 5
# PATH_EVENT = './../data_ignore/event/nb017_event_rms/nb017_event_rms.csv'
DATASET = '32khz'
DIR_MODEL = './../data_ignore/model'
DIR_NOCALL = '/home/knikaido/work/Cornell-Birdcall-Identification/data/nocall'
# PATH_MODEL = './../data_ignore/model/resnest50/resnest50_fast_1s1x64d-d8fbf808.pth'
PATH_SECONDALY_LAVEL_CODE = './codes.pkl'
TARGET_SR = 32000

In [2]:
settings_str = """
globals:
  seed: 1213
  device: cuda
  num_epochs: 45
  output_dir: /kaggle/training_output/
  use_fold: 0
  target_sr: 32000

dataset:
  name: SpectrogramDataset
  params:
    img_size: 224
    melspectrogram_parameters:
      n_mels: 128
      fmin: 20
      fmax: 16000
    
split:
  name: StratifiedKFold
  params:
    n_splits: 5
    random_state: 42
    shuffle: True

loader:
  train:
    batch_size: 5
    shuffle: True
    num_workers: 10
    pin_memory: True
    drop_last: True
  val:
    batch_size: 5
    shuffle: False
    num_workers: 10
    pin_memory: True
    drop_last: False

model:
  name: resnest50_fast_1s1x64d
  params:
    pretrained: True
    n_classes: 264

loss:
  name: BCEWithLogitsLoss
  params: {}

optimizer:
  name: Adam
  params:
    lr: 0.001

scheduler:
  name: CosineAnnealingLR
  params:
    T_max: 10
"""

# Import everything I need :)

In [3]:
import os
import gc
import pickle
import typing as tp
import logging
import cv2 
import warnings
import math
import time
import yaml
import random
import librosa
import numpy as np
import pandas as pd
from glob import glob
import soundfile as sf
from pathlib import Path
from scipy.io import wavfile
import matplotlib.pyplot as plt
from fastprogress import progress_bar
from contextlib import contextmanager
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
# from torchvision import models
import torch.nn.functional as F
import torch.utils.data as data
# import resnest.torch as resnest_torch
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair

import sys
sys.path.insert(0, './../src/util/')
# from const import BIRD_CODE, INV_BIRD_CODE

In [4]:
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}


HASBIRD_CODE = {
    'nocall': 0, 'bird': 1
}

INV_HASBIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

# My function

In [5]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore

@contextmanager
def timer(name: str, logger):
    t0 = time.time()
    msg = f"[{name}] start"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)
    yield

    msg = f"[{name}] done in {time.time() - t0:.2f} s"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)
        
def get_logger(out_file=None):
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    logger.handlers = []
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger

In [6]:
def plot_spec(melspec):
    plt.figure(figsize=(6,3), dpi= 80)
    melspec = librosa.power_to_db(melspec).astype(np.float32)
    plt.imshow(melspec)

In [7]:
def mono_to_color(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
    # Stack X as [X,X,X]
    X = np.stack([X, X, X], axis=-1)

    # Standardize
    mean = mean or X.mean()
    X = X - mean
    std = std or X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    norm_max = norm_max or _max
    norm_min = norm_min or _min
    if (_max - _min) > eps:
        # Normalize to [0, 255]
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Just zero
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V

In [8]:
class SpectrogramDataset(data.Dataset):
    def __init__(self, file_list, img_size=224, 
                 waveform_transforms=None, spectrogram_transforms=None, melspectrogram_parameters={}):
        
        self.file_list = file_list  # list of list: [file_path, ebird_code]
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms
        self.melspectrogram_parameters = melspectrogram_parameters

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx: int):
        wav_path, ebird_code = self.file_list[idx]

        y, sr = sf.read(wav_path)
#         sr, y = wavfile.read(wav_path)

        if self.waveform_transforms:
            y = self.waveform_transforms(y)
        else:
            len_y = len(y)
            effective_length = sr * PERIOD
            if len_y < effective_length:
                new_y = np.zeros(effective_length, dtype=y.dtype)
                start = np.random.randint(effective_length - len_y)
                new_y[start:start + len_y] = y
                y = new_y.astype(np.float32)
            elif len_y > effective_length:
                start = np.random.randint(len_y - effective_length)
                y = y[start:start + effective_length].astype(np.float32)
            else:
                y = y.astype(np.float32)

        melspec = librosa.feature.melspectrogram(y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        if self.spectrogram_transforms:
            melspec = self.spectrogram_transforms(melspec)
        else:
            pass

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

#         labels = np.zeros(len(BIRD_CODE), dtype="i")
        labels = np.zeros(2, dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return image, labels

class SpectrogramEventRmsDataset(data.Dataset):
    def __init__(self, file_list, img_size=224, 
                 waveform_transforms=None, spectrogram_transforms=None, melspectrogram_parameters={}):
        
        self.file_list = file_list  # list of list: [file_path, ebird_code]
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.spectrogram_transforms = spectrogram_transforms
        self.melspectrogram_parameters = melspectrogram_parameters
        self.df_rms = pd.read_csv(PATH_FEAT)
        self.sr_feat = 4

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx: int):
        wav_path, ebird_code = self.file_list[idx]

        y, sr = sf.read(wav_path)
        
        if self.waveform_transforms:
            y = self.waveform_transforms(y)
        else:
            len_y = len(y)
            effective_length = sr * PERIOD
            if len_y < effective_length:
                new_y = np.zeros(effective_length, dtype=y.dtype)
                start = np.random.randint(effective_length - len_y)
                new_y[start:start + len_y] = y
                y = new_y.astype(np.float32)
            elif len_y > effective_length:
                basename = os.path.basename(wav_path)
                rms = self.df_rms.query('filename == @basename').librosa_rms.values
                x_feat_sec = np.arange(0, len(rms))/self.sr_feat + 1/self.sr_feat
                event_mask = rms > 1.3*np.median(rms)
                
                silent = ~any(event_mask)
                if silent:
                    # event を検出できなかったらランダムにクロップ
                    start = np.random.randint(len_y - effective_length)
                    y = y[start:start + effective_length].astype(np.float32)
                else:
                    choice = random.choice(x_feat_sec[event_mask])
                    ed_sec = x_feat_sec[-1]
                    st_range_sec = 2.5001
                    ed_range_sec = ed_sec - 2.5001
                    range_in = (st_range_sec <= choice) & (choice <= ed_range_sec)
                    if range_in:
                        idxs = np.arange(len_y)
                        x_sec = idxs/sr
                        mask = (choice - 2.5) < x_sec
                        start = idxs[mask][0]
                        y = y[start:start + effective_length].astype(np.float32)
                    else:
                        # ランダムにクロップ
                        start = np.random.randint(len_y - effective_length)
                        y = y[start:start + effective_length].astype(np.float32)
                    
                # ----
            else:
                y = y.astype(np.float32)

        melspec = librosa.feature.melspectrogram(y, sr=sr, **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)

        if self.spectrogram_transforms:
            melspec = self.spectrogram_transforms(melspec)
        else:
            pass

        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)

#         labels = np.zeros(len(BIRD_CODE), dtype="i")
        labels = np.zeros(len(BIRD_CODE), dtype="f")
        labels[BIRD_CODE[ebird_code]] = 1

        return image, labels

In [125]:
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])
def normalize(image, mean, std):
    image = (image / 255.0).astype(np.float32)
    image = (image - mean) / std
    return np.moveaxis(image, 2, 0).astype(np.float32)

class TestDataset(data.Dataset):
    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters={}):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        SR = 32000
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        
        if site == "site_3":
            y = self.clip.astype(np.float32)
            len_y = len(y)
            start = 0
            end = SR * 5
            images = []
            while len_y > start:
                y_batch = y[start:end].astype(np.float32)
                if len(y_batch) != (SR * 5):
                    break
                start = end
                end = end + SR * 5
                
                melspec = librosa.feature.melspectrogram(y_batch,
                                                         sr=SR,
                                                         **self.melspectrogram_parameters)
                melspec = librosa.power_to_db(melspec).astype(np.float32)
                image = mono_to_color(melspec)
#                 height, width, _ = image.shape
                
                height, width, _ = image.shape
                image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
                image = normalize(image, mean=MEAN, std=STD)
                
                
#                 image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
#                 image = np.moveaxis(image, 2, 0)
#                 image = (image / 255.0).astype(np.float32)
                images.append(image)
            images = np.asarray(images)
            return images, row_id, site
        else:
            end_seconds = int(sample.seconds)
            start_seconds = int(end_seconds - 5)
            
            start_index = SR * start_seconds
            end_index = SR * end_seconds
            
            y = self.clip[start_index:end_index].astype(np.float32)

            melspec = librosa.feature.melspectrogram(y, sr=SR, **self.melspectrogram_parameters)
#             plot_spec(melspec)
            melspec = librosa.power_to_db(melspec).astype(np.float32)

            image = mono_to_color(melspec)
        
            height, width, _ = image.shape
            image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
            image = normalize(image, mean=MEAN, std=STD)
#             height, width, _ = image.shape
#             image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
#             image = np.moveaxis(image, 2, 0)
#             image = (image / 255.0).astype(np.float32)

            return image, row_id, site
        
def prediction_for_clip_(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        mel_params: dict, 
                        threshold=0.5):

    dataset = TestDataset(df=test_df, 
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in loader:
        site = site[0]
        row_id = row_id[0]
        if site in {"site_1", "site_2"}:
            image = image.to(device)

            with torch.no_grad():
                prediction = model(image)
                proba = F.sigmoid(prediction).detach().cpu().numpy().reshape(-1)
#             plt.figure(figsize=(6,3), dpi= 80)
#             plt.plot(proba)

            events = proba >= threshold
            labels = np.argwhere(events).reshape(-1).tolist()
#             print(proba)
            

        else:
            # to avoid prediction on large batch
            image = image.squeeze(0)
            batch_size = 16
            whole_size = image.size(0)
            if whole_size % batch_size == 0:
                n_iter = whole_size // batch_size
            else:
                n_iter = whole_size // batch_size + 1
                
            all_events = set()
            for batch_i in range(n_iter):
                batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
#                 if batch.ndim == 3:
#                     batch = batch.unsqueeze(0)

                batch = batch.to(device)
                with torch.no_grad():
                    prediction = model(batch)
                    proba = F.sigmoid(prediction).detach().cpu().numpy()              
                events = proba >= threshold
                for i in range(len(events)):
                    event = events[i, :]
                    labels = np.argwhere(event).reshape(-1).tolist()
                    for label in labels:
                        all_events.add(label)
                        
            labels = list(all_events)
        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
#             labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels))
#             label_string = " ".join(labels_str_list)
            prediction_dict[row_id] = 'bird'
    return prediction_dict


def prediction_for_clip(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        mel_params: dict, 
                        threshold=0.5):

    dataset = TestDataset(df=test_df, 
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in loader:
        site = site[0]
        row_id = row_id[0]
        if site in {"site_1", "site_2"}:
            image = image.to(device)

            with torch.no_grad():
                prediction = model(image)
                proba = F.sigmoid(prediction).detach().cpu().numpy().reshape(-1)
            print(prediction)

#             plt.figure(figsize=(6,3), dpi= 80)
#             plt.plot(proba)

            events = proba >= threshold
            labels = np.argwhere(events).reshape(-1).tolist()
#             print(labels)
            print(proba)
            

        else:
            # to avoid prediction on large batch
            image = image.squeeze(0)
            batch_size = 16
            whole_size = image.size(0)
            if whole_size % batch_size == 0:
                n_iter = whole_size // batch_size
            else:
                n_iter = whole_size // batch_size + 1
                
            all_events = set()
            for batch_i in range(n_iter):
                batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
#                 if batch.ndim == 3:
#                     batch = batch.unsqueeze(0)

                batch = batch.to(device)
                with torch.no_grad():
                    prediction = model(batch)
                    proba = F.sigmoid(prediction).detach().cpu().numpy()              
                events = proba >= threshold
                for i in range(len(events)):
                    event = events[i, :]
                    labels = np.argwhere(event).reshape(-1).tolist()
                    for label in labels:
                        all_events.add(label)
                        
            labels = list(all_events)
        if len(labels) != 0 and 1 in labels:
            prediction_dict[row_id] = 'bird'
#         if len(labels) == 0 or 0 in labels:
#             prediction_dict[row_id] = "nocall"
        else:
            prediction_dict[row_id] = "nocall"

#             labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels))
#             label_string = " ".join(labels_str_list)
            
    return prediction_dict

def prediction_mod(test_df: pd.DataFrame,
               mel_params: dict,
               model,
               threshold=0.5,
                  mask_th=0):
    
    unique_audio_id = test_df.audio_id.unique()

    warnings.filterwarnings("ignore")
    prediction_dfs = []
    for i in progress_bar(range(len(test_df))):
        audio_id = test_df['audio_id'].values[i]
        extension = '' if '.wav' in audio_id else '.mp3'
        clip, _ = librosa.load(test_df['file_path'].values[i],
                               sr=TARGET_SR,
                               mono=True,
                               res_type="kaiser_fast")
        
        test_df_for_audio_id = test_df.query(
            f"audio_id == '{audio_id}'").reset_index(drop=True)
        prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                              clip=clip,
                                              model=model,
                                              mel_params=mel_params,
                                              threshold=threshold)
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({
            "row_id": row_id,
            "birds": birds
        })
        prediction_dfs.append(prediction_df)
    
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    return prediction_df

In [88]:
def make_mask_images(melspec):

    len_mel = melspec.shape[1]
    hop_len = int(len_mel / int(MASK_NUM*2))
    mels = []
    
    start = 0
    end = start + int(len_mel / MASK_NUM)
    mels.append(melspec)

    while end < len_mel:
        mask_mel = np.zeros((melspec.shape[0], melspec.shape[1]), dtype=melspec.dtype)
        mask_mel[:,start:end] = melspec[:,start:end]
        start += hop_len
        end = start + int(len_mel / MASK_NUM)
        mels.append(mask_mel)
#         plot_spec(mask_mel)
    return mels

def signal_to_mel_with_mask(y, img_size):
    melspec = librosa.feature.melspectrogram(y,
                             sr=TARGET_SR,
                             fmin=20,
                             fmax=16000)
    mask_mels = make_mask_images(melspec)
    images = []
    for melspec in mask_mels:
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        image = mono_to_color(melspec)
        height, width, _ = image.shape
        image = cv2.resize(image, (int(width * img_size / height), img_size))
        image = np.moveaxis(image, 2, 0)
        image = (image / 255.0).astype(np.float32)
        images.append(image)
    return images

class TestDatasetStrideMask(data.Dataset):
    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters={}):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        self.melspectrogram_parameters = melspectrogram_parameters
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        SR = 32000
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id
        
        if site == "site_3":
            y = self.clip.astype(np.float32)
            len_y = len(y)
#             start = 0
#             end = SR * 5
            images = []
            effect_len = SR * PERIOD

            # 短い
            if(len_y < effect_len):
                for i in range(3):
                    new_y = np.zeros(effect_len, dtype=y.dtype)
                    start = np.random.randint(effect_len - len_y)
                    new_y[0:len(y)] = y
                    y = new_y.astype(np.float32)
                    image = signal_to_mel_with_mask(y, self.img_size)
                    images.extend(image)
            # 長い
            elif(len_y > effect_len):
                start = 0
                end = SR * PERIOD
                while len_y > start:
                    y_batch = y[start:end].astype(np.float32)
                    if len(y_batch) != (SR * PERIOD):
                        break
                    start = end
                    end = end + SR * PERIOD
                    image = signal_to_mel_with_mask(y_batch, self.img_size)
                    images.extend(image)
                    
                start = len(y) - SR * PERIOD
                end = len(y) 
                while 0 < start:
                    y_batch = y[start:end].astype(np.float32)
                    if len(y_batch) != (SR * PERIOD):
                        break
                    end = start
                    start = end - SR * PERIOD
                    image = signal_to_mel_with_mask(y_batch, self.img_size)
                    images.extend(image)
            # 同じ
            else:
                y = y.astype(np.float32)
                image = signal_to_mel_with_mask(y, self.img_size)
                images.extend(image)
            images = np.asarray(images)
            return images, row_id, site
        else:
            end_seconds = int(sample.seconds)
            start_seconds = int(end_seconds - 5)
            
            start_index = SR * start_seconds
            end_index = SR * end_seconds
            
            y = self.clip[start_index:end_index].astype(np.float32)
            image = signal_to_mel_with_mask(y, self.img_size)
            images = np.asarray(image)
            return images, row_id, site

def prediction_for_clip_stride_mask(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        mel_params: dict, 
                        threshold=0.5):

    dataset = TestDatasetStrideMask(df=test_df, 
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in loader:
        site = site[0]
        row_id = row_id[0]
        
        image = image.squeeze(0)
        batch_size = 16
        whole_size = image.size(0)
        if whole_size % batch_size == 0:
            n_iter = whole_size // batch_size
        else:
            n_iter = whole_size // batch_size + 1

        all_events = set()
        for batch_i in range(n_iter):
            batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
            if batch.ndim == 3:
                batch = batch.unsqueeze(0)

            batch = batch.to(device)
            with torch.no_grad():
                prediction = model(batch)
                proba = F.sigmoid(prediction).detach().cpu().numpy()              
            events = proba >= threshold
            for i in range(len(events)):
                event = events[i, :]
                labels = np.argwhere(event).reshape(-1).tolist()
                for label in labels:
                    all_events.add(label)

        labels = list(all_events)
        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels))
            label_string = " ".join(labels_str_list)
            prediction_dict[row_id] = label_string
    return prediction_dict


def prediction_stride_mask_mod(test_df: pd.DataFrame, 
                               mel_params: dict, 
                               model, 
                               threshold=0.5,
                              mask_th=0):
    
#     unique_audio_id = test_df.audio_id.unique()

    warnings.filterwarnings("ignore")
    prediction_dfs = []
#     for audio_id in unique_audio_id:
    for i in progress_bar(range(len(test_df))):
        audio_id = test_df['audio_id'].values[i]
        extension = '' if '.wav' in audio_id else '.mp3'
        clip, _ = librosa.load(test_df['file_path'].values[i],
                               sr=TARGET_SR,
                               mono=True,
                               res_type="kaiser_fast")
        
        test_df_for_audio_id = test_df.query(
            f"audio_id == '{audio_id}'").reset_index(drop=True)
        prediction_dict = prediction_for_clip_stride_mask(test_df_for_audio_id,
                                              clip=clip,
                                              model=model,
                                              mel_params=mel_params,
                                              threshold=threshold)
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({
            "row_id": row_id,
            "birds": birds
        })
        prediction_dfs.append(prediction_df)
    
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    return prediction_df

In [89]:
def prediction_for_clip_joint_stride_mask(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        mel_params: dict, 
                        threshold=0.5,
                        mask_th=0.5):

    dataset = TestDatasetStrideMask(df=test_df, 
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in loader:
        site = site[0]
        row_id = row_id[0]
        
        image = image.squeeze(0)
        batch_size = 16
        whole_size = image.size(0)
        if whole_size % batch_size == 0:
            n_iter = whole_size // batch_size
        else:
            n_iter = whole_size // batch_size + 1

        all_events = set()
        nocall_flag = False
        
        if site in {"site_1", "site_2"}:
            #一枚目(全体)チェック
            batch = image[:1]
            if batch.to('cpu').detach().numpy().copy().ndim == 3:
                batch = batch.unsqueeze(0)
            batch = batch.to(device)
            with torch.no_grad():
                prediction = model(batch)
                prediction = F.sigmoid(prediction)
                proba = prediction.detach().cpu().numpy()
            nocall_proba = proba[0]
            nocall_event = nocall_proba >= threshold
            nocall_label = np.argwhere(nocall_event).reshape(-1).tolist()
            if len(nocall_label) == 0:
                nocall_flag = True
    #         print(nocall_label)

        if nocall_flag == False:

            for batch_i in range(n_iter):
                batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
    #             print(batch.shape)
                if batch.to('cpu').detach().numpy().copy().ndim == 3:
                    batch = batch.unsqueeze(0)
                batch = batch.to(device)
                with torch.no_grad():
                    prediction = model(batch)
                    prediction = F.sigmoid(prediction)
                    proba = prediction.detach().cpu().numpy()
                    events = proba >= mask_th
                    for i in range(len(events)):
                        event = events[i, :]
                        labels = np.argwhere(event).reshape(-1).tolist()
                        for label in labels:
                            all_events.add(label)
        

        labels = list(all_events)
        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels))
            label_string = " ".join(labels_str_list)
            prediction_dict[row_id] = label_string
    return prediction_dict

        


def prediction_joint_stride_mask(test_df: pd.DataFrame, 
                               mel_params: dict, 
                               model, 
                               threshold=0.5,
                                mask_th=0.6):
    
#     unique_audio_id = test_df.audio_id.unique()

    warnings.filterwarnings("ignore")
    prediction_dfs = []
#     for audio_id in unique_audio_id:
    for i in progress_bar(range(len(test_df))):
        audio_id = test_df['audio_id'].values[i]
        extension = '' if '.wav' in audio_id else '.mp3'
        clip, _ = librosa.load(test_df['file_path'].values[i],
                               sr=TARGET_SR,
                               mono=True,
                               res_type="kaiser_fast")
        
        test_df_for_audio_id = test_df.query(
            f"audio_id == '{audio_id}'").reset_index(drop=True)
        prediction_dict = prediction_for_clip_joint_stride_mask(test_df_for_audio_id,
                                              clip=clip,
                                              model=model,
                                              mel_params=mel_params,
                                              threshold=threshold,
                                            mask_th=mask_th)
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({
            "row_id": row_id,
            "birds": birds
        })
        prediction_dfs.append(prediction_df)
    
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    return prediction_df

In [90]:
def prediction_for_clip_stride_mask_thre9(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        mel_params: dict, 
                        threshold=0.5,
                        mask_th=0):

    dataset = TestDatasetStrideMask(df=test_df, 
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    prediction_dict = {}
    for image, row_id, site in loader:
        site = site[0]
        row_id = row_id[0]
        
        image = image.squeeze(0)
        batch_size = 16
        whole_size = image.size(0)
        if whole_size % batch_size == 0:
            n_iter = whole_size // batch_size
        else:
            n_iter = whole_size // batch_size + 1

        all_events = set()
        all_events_loose = set()
        for batch_i in range(n_iter):
            batch = image[batch_i * batch_size:(batch_i + 1) * batch_size]
            if batch.to('cpu').detach().numpy().copy().ndim == 3:
                batch = batch.unsqueeze(0)

            batch = batch.to(device)
            with torch.no_grad():
                prediction = model(batch)
                prediction = F.sigmoid(prediction)

                proba = prediction.detach().cpu().numpy()

            events = proba >= threshold
#             events_loose = proba >= threshold_loose
            events_loose = proba >= mask_th
            for i in range(len(events)):
                event = events[i, :]
                labels = np.argwhere(event).reshape(-1).tolist()
                for label in labels:
                    all_events.add(label)
            
            for i in range(len(events_loose)):
                event_loose = events_loose[i, :]
                labels_loose = np.argwhere(event_loose).reshape(-1).tolist()
                for label in labels_loose:
                    all_events_loose.add(label)

        labels = list(all_events)
        labels_loose = list(all_events_loose)
        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            labels_str_list = list(map(lambda x: INV_BIRD_CODE[x], labels_loose))
            label_string = " ".join(labels_str_list)
            prediction_dict[row_id] = label_string
    return prediction_dict

def prediction_stride_mask_thre9(test_df: pd.DataFrame, 
                               mel_params: dict, 
                               model, 
                               threshold=0.5,
                                mask_th=1):
    
#     unique_audio_id = test_df.audio_id.unique()

    warnings.filterwarnings("ignore")
    prediction_dfs = []
#     for audio_id in unique_audio_id:
    for i in progress_bar(range(len(test_df))):
        audio_id = test_df['audio_id'].values[i]
        extension = '' if '.wav' in audio_id else '.mp3'
        clip, _ = librosa.load(test_df['file_path'].values[i],
                               sr=TARGET_SR,
                               mono=True,
                               res_type="kaiser_fast")
        
        test_df_for_audio_id = test_df.query(
            f"audio_id == '{audio_id}'").reset_index(drop=True)
        prediction_dict = prediction_for_clip_stride_mask_thre9(test_df_for_audio_id,
                                              clip=clip,
                                              model=model,
                                              mel_params=mel_params,
                                              threshold=threshold,
                                            mask_th=mask_th)
        row_id = list(prediction_dict.keys())
        birds = list(prediction_dict.values())
        prediction_df = pd.DataFrame({
            "row_id": row_id,
            "birds": birds
        })
        prediction_dfs.append(prediction_df)
    
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    return prediction_df

In [91]:
def row_wise_f1_score_micro(y_true, y_pred):
    """ author @shonenkov """
    F1 = []
    for preds, trues in zip(y_pred, y_true):
        TP, FN, FP = 0, 0, 0
        preds = preds.split()
        trues = trues.split()
        for true in trues:
            if true in preds:
                TP += 1
            else:
                FN += 1
        for pred in preds:
            if pred not in trues:
                FP += 1
        F1.append(2*TP / (2*TP + FN + FP))
    return np.mean(F1)

In [92]:
def get_loaders_for_training(dataset_class, args_dataset, args_loader, train_file_list, valid_file_list):
    # # make dataset
    train_dataset = dataset_class(train_file_list, **args_dataset)
    val_dataset = dataset_class(valid_file_list, **args_dataset)
    # # make dataloader
    train_loader = data.DataLoader(train_dataset, **args_loader["train"])
    valid_loader = data.DataLoader(val_dataset, **args_loader["val"])
    
    return train_loader, valid_loader

# Preparation

settings

In [93]:
pd.set_option('display.max_columns', 100)
pd.options.display.max_rows = 500

In [94]:
settings = yaml.safe_load(settings_str)
if DEBUG:
    settings['split']['params']['n_splits'] = 2
    settings['globals']['num_epochs'] = 3

for k, v in settings.items():
    print("[{}]".format(k))
    print(v)
set_seed(settings["globals"]["seed"])

[globals]
{'seed': 1213, 'device': 'cuda', 'num_epochs': 45, 'output_dir': '/kaggle/training_output/', 'use_fold': 0, 'target_sr': 32000}
[dataset]
{'name': 'SpectrogramDataset', 'params': {'img_size': 224, 'melspectrogram_parameters': {'n_mels': 128, 'fmin': 20, 'fmax': 16000}}}
[split]
{'name': 'StratifiedKFold', 'params': {'n_splits': 5, 'random_state': 42, 'shuffle': True}}
[loader]
{'train': {'batch_size': 5, 'shuffle': True, 'num_workers': 10, 'pin_memory': True, 'drop_last': True}, 'val': {'batch_size': 5, 'shuffle': False, 'num_workers': 10, 'pin_memory': True, 'drop_last': False}}
[model]
{'name': 'resnest50_fast_1s1x64d', 'params': {'pretrained': True, 'n_classes': 264}}
[loss]
{'name': 'BCEWithLogitsLoss', 'params': {}}
[optimizer]
{'name': 'Adam', 'params': {'lr': 0.001}}
[scheduler]
{'name': 'CosineAnnealingLR', 'params': {'T_max': 10}}


<br>

data path

In [95]:
INPUT_ROOT = Path("/home/knikaido/work/Cornell-Birdcall-Identification/data")
RAW_DATA = INPUT_ROOT / "birdsong-recognition"
TRAIN_AUDIO_DIR = RAW_DATA / "train_audio"
TRAIN_RESAMPLED_AUDIO_DIRS = [
  INPUT_ROOT / "birdsong-resampled-train-audio-{:0>2}".format(i)  for i in range(5)
]
TEST_AUDIO_DIR = RAW_DATA / "test_audio"
input_ex_root = INPUT_ROOT

In [96]:
root = Path.cwd().parent
# input_ex_root = root / 'data_ignore/external_dataset' / DATASET
train_resampled_audio_dirs = [input_ex_root / "birdsong-resampled-train-audio-{:0>2}"
                              .format(i)  for i in range(5)]

In [97]:
train_resampled_audio_dirs

[PosixPath('/home/knikaido/work/Cornell-Birdcall-Identification/data/birdsong-resampled-train-audio-00'),
 PosixPath('/home/knikaido/work/Cornell-Birdcall-Identification/data/birdsong-resampled-train-audio-01'),
 PosixPath('/home/knikaido/work/Cornell-Birdcall-Identification/data/birdsong-resampled-train-audio-02'),
 PosixPath('/home/knikaido/work/Cornell-Birdcall-Identification/data/birdsong-resampled-train-audio-03'),
 PosixPath('/home/knikaido/work/Cornell-Birdcall-Identification/data/birdsong-resampled-train-audio-04')]

<br>

load data

In [99]:
train = pd.read_csv(train_resampled_audio_dirs[0] / "train_mod.csv")
with open('./codes.pkl', mode='rb') as f:
    scd_labels = pickle.load(f)

# Execution

<br>

train_all を作成  
train_all: リサンプル後のファイルパスなどを追加

In [100]:
tmp_list = []
for audio_d in train_resampled_audio_dirs:
    if not audio_d.exists():
        continue
    for ebird_d in audio_d.iterdir():
        if ebird_d.is_file():
            continue
        for i, wav_f in enumerate(ebird_d.iterdir()):
            bool_n_splits = i==settings['split']['params']['n_splits']
            if bool_n_splits and DEBUG: break  # if DEBUG=True: 1bird/n_splits file
            tmp_list.append([ebird_d.name, wav_f.name, wav_f.as_posix()])
            
train_wav_path_exist = pd.DataFrame(
    tmp_list, columns=["ebird_code", "resampled_filename", "file_path"])

del tmp_list

train_all = pd.merge(
    train, train_wav_path_exist, on=["ebird_code", "resampled_filename"], how="inner")

if DEBUG: print('----- debug mode -----')
print(train.shape)
print(train_wav_path_exist.shape)
print(train_all.shape)

(21375, 38)
(21375, 3)
(21375, 39)


In [101]:
train_all['secondary_labels_code'] = scd_labels

<br>

fold

In [102]:
skf = StratifiedKFold(**settings["split"]["params"])

train_all["fold"] = -1
for fold_id, (train_index, val_index) in enumerate(skf.split(train_all, train_all["ebird_code"])):
    train_all.iloc[val_index, -1] = fold_id
    
train_all['hasbird'] = 1

# skf = StratifiedKFold(n_splits=4, random_state=42,shuffle=True)

# train_all["fold"] = -1
# for fold_id, (train_index, val_index) in enumerate(skf.split(train_all, train_all["ebird_code"])):
#     train_all.iloc[val_index, -1] = fold_id
    
# use_fold = settings["globals"]["use_fold"]
# train_all_2class = train_all.query("fold == @use_fold")

# print("[fold {}] train: {}".format(use_fold, len(train_all_2class)))  
# train_all_2class.reset_index(inplace=True, drop=True)

# skf = StratifiedKFold(settings["split"]["params"]["n_splits"])

# train_all_2class["fold"] = -1
# for fold_id, (train_index, val_index) in enumerate(skf.split(train_all_2class, train_all_2class["ebird_code"])):
#     train_all_2class.iloc[val_index, -1] = fold_id
    
# # check the propotion
fold_proportion = pd.pivot_table(train_all, index="ebird_code", columns="fold", values="xc_id", aggfunc=len)
print(fold_proportion.shape)
fold_proportion.T

(264, 5)


ebird_code,aldfly,ameavo,amebit,amecro,amegfi,amekes,amepip,amered,amerob,amewig,amewoo,amtspa,annhum,astfly,baisan,baleag,balori,banswa,barswa,bawwar,belkin1,belspa2,bewwre,bkbcuc,bkbmag1,bkbwar,bkcchi,bkchum,bkhgro,bkpwar,bktspa,blkpho,blugrb1,blujay,bnhcow,boboli,bongul,brdowl,brebla,brespa,brncre,brnthr,brthum,brwhaw,btbwar,btnwar,btywar,buffle,buggna,buhvir,...,semplo,semsan,sheowl,shshaw,snobun,snogoo,solsan,sonspa,sora,sposan,spotow,stejay,swahaw,swaspa,swathr,treswa,truswa,tuftit,tunswa,veery,vesspa,vigswa,warvir,wesblu,wesgre,weskin,wesmea,wessan,westan,wewpew,whbnut,whcspa,whfibi,whtspa,whtswi,wilfly,wilsni1,wiltur,winwre3,wlswar,wooduc,wooscj2,woothr,y00475,yebfly,yebsap,yehbla,yelwar,yerwar,yetvir
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1
0,20,8,9,20,20,15,20,20,20,7,17,15,20,20,11,7,20,20,20,15,14,16,20,10,12,20,20,14,20,20,19,20,20,20,20,20,8,20,10,20,20,20,20,15,15,20,20,3,20,20,...,13,14,12,5,20,16,19,20,20,19,20,20,6,20,20,17,8,20,13,20,20,11,20,7,8,17,20,6,20,20,20,20,7,20,8,20,20,10,20,20,8,18,20,20,20,13,12,18,20,20
1,20,8,9,20,20,15,20,20,20,7,17,15,20,20,11,7,20,20,20,15,14,15,20,10,12,20,20,14,20,20,19,20,20,20,20,20,8,20,10,20,20,20,20,15,15,20,19,3,20,20,...,13,14,12,5,20,16,19,20,20,19,20,20,5,20,20,17,8,20,13,20,20,11,20,7,8,17,20,6,20,20,20,20,7,20,7,20,20,9,20,20,8,18,20,20,20,13,12,18,20,20
2,20,8,9,20,20,15,20,20,20,7,17,15,20,20,11,7,20,20,20,15,14,15,20,10,12,20,20,14,20,20,19,20,20,20,20,20,8,20,10,20,20,20,20,15,15,20,19,3,20,20,...,13,14,11,4,20,15,19,20,20,19,20,20,5,20,20,17,8,20,13,20,20,10,20,7,8,17,20,6,20,20,20,20,7,20,7,20,20,9,20,20,8,18,20,20,20,13,11,18,20,20
3,20,7,9,20,20,15,20,20,20,6,16,15,20,20,11,7,20,20,20,15,14,15,20,9,12,20,20,13,20,20,19,20,20,20,20,20,8,20,10,20,20,20,20,15,15,20,19,3,20,20,...,13,14,11,4,20,15,19,20,20,18,20,20,5,20,20,17,7,20,13,20,20,10,20,7,7,16,20,6,20,20,20,20,6,20,7,20,20,9,20,20,8,18,20,20,20,12,11,18,20,19
4,20,7,8,20,20,15,20,20,20,6,16,14,20,20,11,7,20,20,20,14,13,15,20,9,12,20,20,13,20,20,18,20,20,20,20,20,8,20,9,20,20,20,19,14,15,20,19,3,20,20,...,12,14,11,4,20,15,18,20,20,18,20,20,5,20,20,17,7,20,12,20,20,10,20,7,7,16,20,6,20,20,20,20,6,20,7,20,20,9,20,20,8,17,20,20,20,12,11,17,20,19


In [103]:
use_fold = settings["globals"]["use_fold"]
idx_train = train_all['fold']!=use_fold
idx_valid = train_all['fold']==use_fold
df_valid = train_all[idx_valid]

# use_fold = settings["globals"]["use_fold"]
# idx_train = train_all_2class['fold']!=use_fold
# idx_valid = train_all_2class['fold']==use_fold
# df_valid = train_all_2class[idx_valid]
train_file_list = train_all.query("fold != @use_fold")[["file_path", "hasbird"]].values.tolist()
valid_file_list = train_all.query("fold == @use_fold")[["file_path", "hasbird"]].values.tolist()

print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(val_file_list)))

print("[fold {}] train: {}, val: {}".format(use_fold, len(train_file_list), len(valid_file_list)))

[fold 0] train: 17057, val: 4318
[fold 0] train: 17057, val: 4318


In [104]:
df_valid

Unnamed: 0,rating,playback_used,ebird_code,channels,date,pitch,duration,filename,speed,species,number_of_notes,title,secondary_labels,bird_seen,sci_name,location,latitude,sampling_rate,type,elevation,description,bitrate_of_mp3,file_type,volume,background,xc_id,url,country,author,primary_label,longitude,length,time,recordist,license,resampled_sampling_rate,resampled_filename,resampled_channels,file_path,secondary_labels_code,fold,hasbird
0,3.5,no,aldfly,1 (mono),2013-05-25,Not specified,25,XC134874.mp3,Not specified,Alder Flycatcher,Not specified,XC134874 Alder Flycatcher (Empidonax alnorum),"['Empidonax minimus_Least Flycatcher', 'Leioth...",yes,Empidonax alnorum,"Grey Cloud Dunes SNA, Washington, Minnesota",44.793,48000 (Hz),call,220 m,Two short segments of fairly quiet *pip* calls...,192000 (bps),mp3,Not specified,American Yellow Warbler (Setophaga aestiva); L...,134874,https://www.xeno-canto.org/134874,United States,Jonathon Jongsma,Empidonax alnorum_Alder Flycatcher,-92.962,Not specified,8:00,Jonathon Jongsma,Creative Commons Attribution-ShareAlike 3.0,32000,XC134874.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,"[136, 264, 116, 98]",0,1
4,4.0,no,aldfly,2 (stereo),2013-05-27,both,36,XC135457.mp3,level,Alder Flycatcher,1-3,XC135457 Alder Flycatcher (Empidonax alnorum),[],yes,Empidonax alnorum,"Carver's Gap Parking area, Roan Mountain Stat...",36.1065,44100 (Hz),"call, song",1700 m,img:http://www.flickr.com/photos/madbirder/888...,128000 (bps),mp3,level,,135457,https://www.xeno-canto.org/135457,United States,Mike Nelson,Empidonax alnorum_Alder Flycatcher,-82.1106,0-3(s),08:30,Mike Nelson,Creative Commons Attribution-NonCommercial-Sha...,32000,XC135457.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,[-1],0,1
10,4.5,no,aldfly,1 (mono),2013-06-23,Not specified,269,XC139577.mp3,Not specified,Alder Flycatcher,Not specified,XC139577 Alder Flycatcher (Empidonax alnorum),"['Geothlypis trichas_Common Yellowthroat', 'Co...",yes,Empidonax alnorum,"Big Bay State Park, Ashland, Wisconsin",46.8107,48000 (Hz),song,190 m,Singing from alder thicket at the edge of lago...,192000 (bps),mp3,Not specified,Common Yellowthroat (Geothlypis trichas); Amer...,139577,https://www.xeno-canto.org/139577,United States,Jonathon Jongsma,Empidonax alnorum_Alder Flycatcher,-90.6776,Not specified,06:57,Jonathon Jongsma,Creative Commons Attribution-ShareAlike 3.0,32000,XC139577.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,"[80, 3, 221, 74]",0,1
12,5.0,no,aldfly,2 (stereo),2008-07-02,Not specified,19,XC142065.mp3,Not specified,Alder Flycatcher,Not specified,XC142065 Alder Flycatcher (Empidonax alnorum),[],yes,Empidonax alnorum,"Searsport, ME",44.4584,44100 (Hz),call,10 m,,128000 (bps),mp3,Not specified,,142065,https://www.xeno-canto.org/142065,United States,Paul Driver,Empidonax alnorum_Alder Flycatcher,-68.9242,Not specified,06:00,Paul Driver,Creative Commons Attribution-NonCommercial-Sha...,32000,XC142065.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,[-1],0,1
18,4.0,no,aldfly,2 (stereo),2013-07-16,Not specified,46,XC144674.mp3,Not specified,Alder Flycatcher,Not specified,XC144674 Alder Flycatcher (Empidonax alnorum),[],yes,Empidonax alnorum,"Seawall, ME",44.228,44100 (Hz),"""weeoo"" call",0 m,,128000 (bps),mp3,Not specified,,144674,https://www.xeno-canto.org/144674,United States,Paul Driver,Empidonax alnorum_Alder Flycatcher,-68.312,Not specified,06:20,Paul Driver,Creative Commons Attribution-NonCommercial-Sha...,32000,XC144674.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,[-1],0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21352,5.0,no,yetvir,1 (mono),2015-07-13,Not specified,38,XC385850.mp3,Not specified,Yellow-throated Vireo,Not specified,XC385850 Yellow-throated Vireo (Vireo flavifrons),[],no,Vireo flavifrons,"Stony Creek Metropark, Macomb County, Michigan",42.7335,44100 (Hz),song,240 m,"Editing: High-pass filter, cutoff frequency 1k...",192000 (bps),mp3,Not specified,,385850,https://www.xeno-canto.org/385850,United States,Antonio Xeira,Vireo flavifrons_Yellow-throated Vireo,-83.0753,Not specified,11:30,Antonio Xeira,Creative Commons Attribution-NonCommercial-Sha...,32000,XC385850.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,[-1],0,1
21358,4.5,no,yetvir,1 (mono),2018-05-06,both,52,XC417989.mp3,both,Yellow-throated Vireo,1-3,XC417989 Yellow-throated Vireo (Vireo flavifrons),"['Setophaga cerulea_Cerulean Warbler', 'Baeolo...",no,Vireo flavifrons,Lawrence Woods SNP,40.5652,48000 (Hz),song,320 m,Natural vocalizations from the boardwalk.,320000 (bps),mp3,both,Cerulean Warbler (Setophaga cerulea); Tufted T...,417989,https://www.xeno-canto.org/417989,United States,Sue Riffe,Vireo flavifrons_Yellow-throated Vireo,-83.6255,0-3(s),12:50,Sue Riffe,Creative Commons Attribution-NonCommercial-Sha...,32000,XC417989.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,"[264, 231]",0,1
21360,4.5,no,yetvir,1 (mono),2018-05-06,both,7,XC417991.mp3,both,Yellow-throated Vireo,1-3,XC417991 Yellow-throated Vireo (Vireo flavifrons),['Vireo olivaceus_Red-eyed Vireo'],no,Vireo flavifrons,Lawrence Woods SNP,40.5652,48000 (Hz),song,320 m,Natural vocalizations from the boardwalk.,320000 (bps),mp3,both,Red-eyed Vireo (Vireo olivaceus),417991,https://www.xeno-canto.org/417991,United States,Sue Riffe,Vireo flavifrons_Yellow-throated Vireo,-83.6255,0-3(s),12:50,Sue Riffe,Creative Commons Attribution-NonCommercial-Sha...,32000,XC417991.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,[192],0,1
21367,3.5,no,yetvir,2 (stereo),2018-06-10,Not specified,100,XC467630.mp3,Not specified,Yellow-throated Vireo,Not specified,XC467630 Yellow-throated Vireo (Vireo flavifrons),"['Contopus virens_Eastern Wood-Pewee', 'Turdus...",yes,Vireo flavifrons,"George Montgomery Sanctuary, Brome-Missisquoi...",45.024,44100 (Hz),song,100 m,Habitat: mature deciduous forest.,128000 (bps),mp3,Not specified,Eastern Wood Pewee (Contopus virens); American...,467630,https://www.xeno-canto.org/467630,Canada,Martin St-Michel,Vireo flavifrons_Yellow-throated Vireo,-73.075,Not specified,08:21,Martin St-Michel,Creative Commons Attribution-NonCommercial-Sha...,32000,XC467630.wav,1 (mono),/home/knikaido/work/Cornell-Birdcall-Identific...,"[94, 8, 136, 45, 7]",0,1


In [61]:
import cloudpickle
# with open('./training_output0912_2/train_file_list.pkl', 'rb') as f:
#     valid_file_list = cloudpickle.load(f)

In [105]:
len(train_file_list)+len(valid_file_list)

21375

In [106]:
device = torch.device('cuda')

<br>

**<font color='orange'> -------------------- settings ------------------ </font>**

**<font color='orange'> ------------------------------------------------ </font>**

In [107]:
# # # get loader
train_loader, valid_loader = get_loaders_for_training(
    SpectrogramDataset,
    settings["dataset"]["params"], settings["loader"], train_file_list, valid_file_list)

In [108]:
def get_epoch_loss_score(model, device, valid_loder, loss_func):
    model.eval()
    epoch_valid_loss = 0
    y_pred_list = []
    y_true_list = []
    for batch_idx, (data, target) in enumerate(progress_bar(valid_loader)):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = loss_func(output, target)
        epoch_valid_loss += loss.item()*data.size(0)
        
#         out_numpy = output.detach().cpu().numpy()
        _y_pred = output.detach().cpu().numpy().argmax(axis=1)
        y_pred_list.append(_y_pred)
        _y_true = target.detach().cpu().numpy().argmax(axis=1)
        y_true_list.append(_y_true)
    
    loss = epoch_valid_loss / len(valid_loader.dataset)
    y_pred = np.concatenate(y_pred_list, axis=0)
    y_true = np.concatenate(y_true_list, axis=0)
    f_score = f1_score(y_true, y_pred, average='macro')
    del data
    return loss, f_score

In [109]:
# # # get optimizer
# optimizer = getattr(
#         torch.optim, settings["optimizer"]["name"]
#         )(model.parameters(), **settings["optimizer"]["params"])

# # # # get scheduler
# scheduler = getattr(
#     torch.optim.lr_scheduler, settings["scheduler"]["name"]
#     )(optimizer, **settings["scheduler"]["params"])

# # # get loss
loss_func = getattr(nn, settings["loss"]["name"])(**settings["loss"]["params"])

# Evaluate

In [110]:
melspectrogram_parameters = {
    "n_mels": 128,
    "fmin": 20,
    "fmax": 16000
}

logger = get_logger()
set_seed(1213)

2020-09-14 10:38:03,999 - INFO - logger set up


<br>

**nocall dataset**

In [111]:
file_path = sorted(glob(f'{DIR_NOCALL}/*'))
row_id_list = [os.path.basename(file_path[i]) for i in range(len(file_path))]

df_valid_nocall = pd.DataFrame()
df_valid_nocall['row_id'] = row_id_list
df_valid_nocall['site'] = 'site_1'
df_valid_nocall['seconds'] = 5
df_valid_nocall['audio_id'] = row_id_list
df_valid_nocall['file_path'] = file_path
df_valid_nocall['all_bird'] = 'nocall'

<br>

**site3 dataset(長い、複数羽)**

In [142]:
df_valid_site3 = pd.DataFrame()
for i in range(len(df_valid)):
    _df = df_valid['secondary_labels_code']
    lbl = _df.values[i]
    if (lbl != [-1]) & (lbl != [264]):
        if not 264 in lbl:
            df_valid_site3 = df_valid_site3.append(df_valid.iloc[[i], :])
df_valid_site3 = df_valid_site3.reset_index(drop=True)
df_valid_site3['audio_id'] = df_valid_site3['filename']
df_valid_site3['row_id'] = df_valid_site3['filename']
df_valid_site3['site'] = 'site_3'

In [143]:
all_bird_list = []
for i in range(len(df_valid_site3)):
    primary_label = df_valid_site3['ebird_code'][i]
    scd_list = df_valid_site3['secondary_labels_code'][i]
    all_bird = primary_label
    for scd in scd_list:
        bird_code = INV_BIRD_CODE[scd]
        all_bird = all_bird + ' ' + bird_code
    all_bird_list.append(all_bird)
df_valid_site3['all_bird'] =  'bird'
# df_valid_site3['all_bird'] = all_bird_list

<br>

**1bird dataset(短い、1羽)**

In [144]:
df_valid_1bird = pd.DataFrame()
for i in range(len(df_valid)):
    _df = df_valid['secondary_labels_code']
    lbl = _df.values[i]
    duration = df_valid['duration'].values[i]
    if (lbl != [-1]) & (lbl != [264]):
        pass
    elif duration < 7:
        df_valid_1bird = df_valid_1bird.append(df_valid.iloc[[i], :])

df_valid_1bird['audio_id'] = df_valid_1bird['filename']
df_valid_1bird['row_id'] = df_valid_1bird['filename']
df_valid_1bird['site'] = 'site_1'
df_valid_1bird['seconds'] = 5
df_valid_1bird['all_bird'] = 'bird'
# df_valid_1bird['all_bird'] = df_valid_1bird['ebird_code']

<br>

**some birds dataset(短い、複数羽)**

In [145]:
df_valid_some_bird = df_valid_site3[df_valid_site3['duration'] < 10]
df_valid_some_bird['site'] = 'site_1'
df_valid_some_bird['seconds'] = 5

# Model01

In [146]:
model_path = './../data_ignore/hydra_outputs/2020-09-06/19-56-52/try0/ResNeSt_fold0.pth'
model_config = {
    "base_model_name": "resnest50_fast_1s1x64d",
    "pretrained": False,
    "n_classes": 264,
    "trained_weights": model_path,
}
# model = get_model_v2(model_config)
import cloudpickle
with open('./training_output0912_2/model_ResNeSt_norm_best.pkl', 'rb') as f:
    model = cloudpickle.load(f)

model = model.to(device)
print(f'model name: {model.__class__.__name__}')

model name: ResNet


---

In [147]:
df_valid_nocall.head()

Unnamed: 0,row_id,site,seconds,audio_id,file_path,all_bird
0,BLKFR_122_128.wav,site_1,5,BLKFR_122_128.wav,/home/knikaido/work/Cornell-Birdcall-Identific...,nocall
1,BLKFR_149_157.wav,site_1,5,BLKFR_149_157.wav,/home/knikaido/work/Cornell-Birdcall-Identific...,nocall
2,BLKFR_156_163.wav,site_1,5,BLKFR_156_163.wav,/home/knikaido/work/Cornell-Birdcall-Identific...,nocall
3,BLKFR_162_168.wav,site_1,5,BLKFR_162_168.wav,/home/knikaido/work/Cornell-Birdcall-Identific...,nocall
4,BLKFR_183.2_188.2.wav,site_1,5,BLKFR_183.2_188.2.wav,/home/knikaido/work/Cornell-Birdcall-Identific...,nocall


In [148]:
display(df_valid_site3[['all_bird', 'duration', 'site']].head(3))
display(df_valid_1bird[['all_bird', 'duration', 'site']].head(3))
display(df_valid_some_bird[['all_bird', 'duration', 'site']].head(3))

Unnamed: 0,all_bird,duration,site
0,bird,269,site_3
1,bird,72,site_3
2,bird,86,site_3


Unnamed: 0,all_bird,duration,site
77,bird,6,site_1
136,bird,1,site_1
162,bird,6,site_1


Unnamed: 0,all_bird,duration,site
11,bird,7,site_1
12,bird,5,site_1
86,bird,9,site_1


In [149]:
len(df_valid_nocall), len(df_valid_1bird[::3]), len(df_valid_some_bird), len(df_valid_site3[::20]), 

(23, 122, 66, 48)

In [150]:
# df_valid_some_bird = df_valid_some_bird[df_valid_some_bird['row_id'] == 'XC417855.mp3']

In [151]:
# df_valid_nocall = df_valid_nocall[df_valid_nocall['row_id'] == 'BLKFR_122_128.wav']

In [160]:
def evaluate_model(model, prediction_method, threshold, mask_num, mask_th):
    df_dict = {
        'nocall': df_valid_nocall[:int(len(df_valid_nocall)/2)],
        '1bird': df_valid_1bird[::3],
        'some_bird': df_valid_some_bird,
#         'site3': df_valid_site3[::20],
    }
    
    
    global MASK_NUM
    MASK_NUM = mask_num 
#     if mask_num == None:
#         prediction_method = prediction_mod
#     else:
#         global MASK_NUM
#         prediction_method = prediction_joint_stride_mask
# #         prediction_method = prediction_stride_mask_mod
    
    df_score = pd.DataFrame()
    sub_dict = {}
    for key in df_dict.keys():
        df = df_dict[key][::1]
        sub = prediction_method(test_df=df, model=model,
                            mel_params=melspectrogram_parameters, threshold=threshold, mask_th=mask_th)
        sub_dict[key] = sub
        score = row_wise_f1_score_micro(df['all_bird'], sub['birds'])
        df_score[key] = [score]
    return df_score, sub_dict

In [161]:

# df_result_mask, sub_dict_mask = evaluate_model(model, prediction_joint_stride_mask, threshold=0.6, mask_num=3)
# df_result_mask3_thre9, sub_dict_mask3_thre9 = evaluate_model(model, prediction_stride_mask_mod, threshold=0.9, mask_num=3)
# df_result_mask3_thre9_thre6, sub_dict_mask3_thre9_thre6 = evaluate_model(model, prediction_stride_mask_thre9, threshold=0.6, mask_num=3)
# df_result_mask3_thre9_thre8, sub_dict_mask3_thre9_thre8 = evaluate_model(model, prediction_stride_mask_thre9, threshold=0.8, mask_num=3)
# df_result_mask3_thre85, sub_dict_mask3_thre85 = evaluate_model(model, prediction_stride_mask_mod, threshold=0.85, mask_num=3)
# df_result_mask, sub_dict_mask = evaluate_model(model, prediction_joint_stride_mask, threshold=0.6, mask_num=3)

In [162]:
df_result_original, sub_dict_original = evaluate_model(model, prediction_mod, threshold=0.9, mask_num=None, mask_th=0.6)
# sub_dict_original
# sub_dict_mask['some_bird']['all_bird'] = df_valid_some_bird['all_bird'].values
# sub_dict_mask['some_bird']['original_birds'] = sub_dict_original['some_bird']['birds'].values
# sub_dict_mask['some_bird']

tensor([[-0.1857,  0.1934]], device='cuda:0')
[0.45370814 0.5481965 ]
tensor([[-0.3212,  0.3223]], device='cuda:0')
[0.4203778  0.57989454]
tensor([[-1.9365,  1.9418]], device='cuda:0')
[0.12602738 0.8745511 ]
tensor([[-1.8314,  1.8409]], device='cuda:0')
[0.13807659 0.86305064]
tensor([[-0.9518,  0.9611]], device='cuda:0')
[0.27852282 0.72333497]
tensor([[-0.5248,  0.5254]], device='cuda:0')
[0.3717221  0.62840146]
tensor([[-1.9829,  1.9870]], device='cuda:0')
[0.12100946 0.8794201 ]
tensor([[-0.5539,  0.5631]], device='cuda:0')
[0.3649624  0.63715994]
tensor([[ 2.4820, -2.4590]], device='cuda:0')
[0.92286676 0.07878242]
tensor([[ 0.3749, -0.3514]], device='cuda:0')
[0.5926374  0.41304433]
tensor([[ 0.6871, -0.6459]], device='cuda:0')
[0.66533166 0.34390378]


tensor([[-16.4091,  16.1433]], device='cuda:0')
[7.475298e-08 9.999999e-01]
tensor([[-22.4622,  21.9928]], device='cuda:0')
[1.7570957e-10 1.0000000e+00]
tensor([[-23.0373,  22.5267]], device='cuda:0')
[9.885822e-11 1.000000e+00]
tensor([[-60.9792,  59.4817]], device='cuda:0')
[3.2891763e-27 1.0000000e+00]
tensor([[-43.7108,  42.6421]], device='cuda:0')
[1.0390601e-19 1.0000000e+00]
tensor([[-51.2971,  50.0084]], device='cuda:0')
[5.2719025e-23 1.0000000e+00]
tensor([[-42.0876,  41.0468]], device='cuda:0')
[5.2674913e-19 1.0000000e+00]
tensor([[-61.0819,  59.5857]], device='cuda:0')
[2.968034e-27 1.000000e+00]
tensor([[-42.3088,  41.2328]], device='cuda:0')
[4.2219697e-19 1.0000000e+00]
tensor([[-18.5139,  18.0575]], device='cuda:0')
[9.109687e-09 1.000000e+00]
tensor([[-23.2697,  22.7850]], device='cuda:0')
[7.8357376e-11 1.0000000e+00]
tensor([[-16.5364,  16.1786]], device='cuda:0')
[6.5819336e-08 9.9999988e-01]
tensor([[-20.3447,  19.8485]], device='cuda:0')
[1.4601166e-09 1.0000000

tensor([[-33.4170,  32.5705]], device='cuda:0')
[3.0703858e-15 1.0000000e+00]
tensor([[-17.7089,  17.3755]], device='cuda:0')
[2.0375804e-08 1.0000000e+00]
tensor([[-25.7765,  25.2225]], device='cuda:0')
[6.388731e-12 1.000000e+00]
tensor([[-39.3980,  38.3932]], device='cuda:0')
[7.7563055e-18 1.0000000e+00]
tensor([[-50.2226,  48.9989]], device='cuda:0')
[1.5438284e-22 1.0000000e+00]
tensor([[-48.9287,  47.7301]], device='cuda:0')
[5.630334e-22 1.000000e+00]
tensor([[-58.5557,  57.1886]], device='cuda:0')
[3.711605e-26 1.000000e+00]
tensor([[-36.8584,  35.9492]], device='cuda:0')
[9.830886e-17 1.000000e+00]
tensor([[-61.9960,  60.4685]], device='cuda:0')
[1.1898711e-27 1.0000000e+00]
tensor([[-64.6538,  63.0795]], device='cuda:0')
[8.340673e-29 1.000000e+00]
tensor([[-44.2121,  43.0587]], device='cuda:0')
[6.2939294e-20 1.0000000e+00]
tensor([[-55.7482,  54.3601]], device='cuda:0')
[6.149757e-25 1.000000e+00]
tensor([[-53.1258,  51.7358]], device='cuda:0')
[8.467356e-24 1.000000e+00]


In [None]:
df_result_mask, sub_dict_mask = evaluate_model(model, prediction_stride_mask_mod, threshold=0.85, mask_num=3, mask_th=1)

In [185]:
df_result_origin_mask, sub_dict_origin_mask = evaluate_model(model, prediction_joint_stride_mask, threshold=0.6, mask_num=2, mask_th=0.5)

In [575]:
df_result_mask_mask, sub_dict_mask_mask = evaluate_model(model, prediction_stride_mask_thre9, threshold=0.9, mask_num=3, mask_th=0.85)

In [156]:
_df = pd.concat([df_result_original,
#                  df_result_mask,
#                  df_result_origin_mask, 
#                  df_result_mask_mask, 
#                  df_result_mask3_thre9_thre8,
#                  df_result_mask3_thre85,
#                  df_result_thre6_mask2,
#                  df_result_thre8_mask2,
#                  df_result_thre5_mask2,
#                  df_result_thre8_mask3,
                ])
_df.index = ['original'
#              'mask',
#              'origin_mask',
#              'mask_mask'
            ]
_df

Unnamed: 0,nocall,1bird,some_bird
original,1.0,0.967213,0.984848


In [577]:
sub_dict_original.keys()

dict_keys(['nocall', '1bird', 'some_bird'])

In [139]:
sub_dict_original['nocall']['all_bird'] = df_valid_nocall['all_bird'][:int(len(df_valid_nocall)/2)].values
sub_dict_original['nocall']['original_birds'] = sub_dict_original['nocall']['birds'].values
# sub_dict_mask['some_bird']['mask_birds'] = sub_dict_mask['some_bird']['birds'].values
# sub_dict_mask['some_bird']['origin_mask_birds'] = sub_dict_origin_mask['some_bird']['birds'].values
# sub_dict_mask['some_bird']['mask_mask_birds'] = sub_dict_mask_mask['some_bird']['birds'].values
sub_dict_original['nocall']

Unnamed: 0,row_id,birds,all_bird,original_birds
0,BLKFR_122_128.wav,nocall,nocall,nocall
1,BLKFR_149_157.wav,nocall,nocall,nocall
2,BLKFR_156_163.wav,nocall,nocall,nocall
3,BLKFR_162_168.wav,nocall,nocall,nocall
4,BLKFR_183.2_188.2.wav,nocall,nocall,nocall
5,BLKFR_189.5_193.5.wav,nocall,nocall,nocall
6,BLKFR_197_203.wav,nocall,nocall,nocall
7,BLKFR_88.7_93.2.wav,nocall,nocall,nocall
8,ORANGE_10.1_13.5.wav,nocall,nocall,nocall
9,ORANGE_120.5_124.0.wav,nocall,nocall,nocall


In [140]:
sub_dict_original['some_bird']['all_bird'] = df_valid_some_bird['all_bird'].values
sub_dict_original['some_bird']['original_birds'] = sub_dict_original['some_bird']['birds'].values
# sub_dict_mask['some_bird']['mask_birds'] = sub_dict_mask['some_bird']['birds'].values
# sub_dict_mask['some_bird']['origin_mask_birds'] = sub_dict_origin_mask['some_bird']['birds'].values
# sub_dict_mask['some_bird']['mask_mask_birds'] = sub_dict_mask_mask['some_bird']['birds'].values
sub_dict_original['some_bird']

Unnamed: 0,row_id,birds,all_bird,original_birds
0,XC417855.mp3,bird,amebit marwre rewbla comyel amebit sonspa,bird
1,XC172881.mp3,bird,amecro easmea,bird
2,XC326542.mp3,bird,bkbcuc spotow sonspa indbun,bird
3,XC442868.mp3,bird,bkbmag1 bkbmag1,bird
4,XC183334.mp3,bird,bkbwar blujay,bird
5,XC112499.mp3,bird,blugrb1 amecro moudov,bird
6,XC188430.mp3,bird,bnhcow amecro,bird
7,XC31202.mp3,bird,boboli fiespa,bird
8,XC121045.mp3,bird,brncre bkcchi,bird
9,XC277930.mp3,bird,btywar bkhgro,bird


In [141]:
sub_dict_original['1bird']['all_bird'] = df_valid_1bird['all_bird'][::3].values
sub_dict_original['1bird']['original_birds'] = sub_dict_original['1bird']['birds'].values
# sub_dict_original['1bird']['mask_birds'] = sub_dict_mask['1bird']['birds'].values
# sub_dict_original['1bird']['origin_mask_birds'] = sub_dict_origin_mask['1bird']['birds'].values
# sub_dict_original['1bird']['mask_mask_birds'] = sub_dict_mask_mask['1bird']['birds'].values
sub_dict_original['1bird']

Unnamed: 0,row_id,birds,all_bird,original_birds
0,XC319017.mp3,bird,aldfly,bird
1,XC114558.mp3,bird,amecro,bird
2,XC413789.mp3,bird,amecro,bird
3,XC215135.mp3,bird,amepip,bird
4,XC390471.mp3,bird,amepip,bird
5,XC406876.mp3,bird,amepip,bird
6,XC305022.mp3,bird,amewig,bird
7,XC380017.mp3,bird,banswa,bird
8,XC120285.mp3,bird,belspa2,bird
9,XC240988.mp3,bird,bkbcuc,bird
