In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are re-imported before each cell executes (handy while editing the local
# project modules imported further down).
%load_ext autoreload
%autoreload 2
# Optional, currently disabled: display the value of every expression in a
# cell rather than only the last one.
# from IPython.core.interactiveshell import InteractiveShell
# InteractiveShell.ast_node_interactivity='all'

In [65]:
import numpy as np
import pandas as pd
from pathlib import Path
import warnings

# Librosa Libraries
import librosa
import librosa.display
import IPython.display as ipd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import sys
sys.path.append('../easy_gold')

import utils
import model_utils
import datasets

In [16]:
# Root of the test fixture used in this notebook; `test.csv` and the
# `test_audio/` directory are both read from underneath it.
TEST_DIR = Path('../data/dummy_test_for_check/')

In [11]:
# Class-label lookup loaded through the project's pickle helper. Its length
# sets the model's number of outputs, and it is indexed by predicted label
# position when the submission strings are built.
# NOTE(review): pickle files can execute code on load -- only use trusted files.
classes = utils.load_pickle('../data/classes.pkl')

In [103]:
# classes

In [8]:

# Checkpoint to evaluate. NOTE(review): this points at a specific debug run
# directory -- update it when checking a different training run.
model_path = Path('../results/debug-20200811094825/model.pth')

In [12]:
# Build the `base_resnet50` model through the project helper, loading weights
# from `model_path`, with one output per entry in `classes`.
model = model_utils.load_pytorch_model(model_name='base_resnet50', path=model_path, n_class=len(classes))

In [17]:
# Test metadata: one row per requested prediction, with the columns
# site, row_id, seconds and audio_id (see the table rendered below).
test_df = pd.read_csv(TEST_DIR / 'test.csv')

In [18]:
# Display the test metadata; note that the site_3 rows have no `seconds` value.
test_df

Unnamed: 0,site,row_id,seconds,audio_id
0,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_5,5.0,41e6fe6504a34bf6846938ba78d13df1
1,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,10.0,41e6fe6504a34bf6846938ba78d13df1
2,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_15,15.0,41e6fe6504a34bf6846938ba78d13df1
3,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_20,20.0,41e6fe6504a34bf6846938ba78d13df1
4,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_25,25.0,41e6fe6504a34bf6846938ba78d13df1
...,...,...,...,...
71,site_3,site_3_9cc5d9646f344f1bbb52640a988fe902,,9cc5d9646f344f1bbb52640a988fe902
72,site_3,site_3_a56e20a518684688a9952add8a9d5213,,a56e20a518684688a9952add8a9d5213
73,site_3,site_3_96779836288745728306903d54e264dd,,96779836288745728306903d54e264dd
74,site_3,site_3_f77783ba4c6641bc918b034a18c23e53,,f77783ba4c6641bc918b034a18c23e53


In [19]:
# Directory holding one `<audio_id>.mp3` file per recording in the fixture.
test_audio_dir = TEST_DIR / 'test_audio'

In [20]:
# Sanity check: list the files that are actually present in the audio directory.
list(test_audio_dir.glob('*'))

[PosixPath('../data/dummy_test_for_check/test_audio/07ab324c602e4afab65ddbcc746c31b5.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/41e6fe6504a34bf6846938ba78d13df1.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/6ab74e177aa149468a39ca10beed6222.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/856b194b097441958697c2bcd1f63982.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/8680a8dd845d40f296246dbed0d37394.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/899616723a32409c996f6f3441646c2a.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/940d546e5eb745c9a74bce3f35efa1f9.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/96779836288745728306903d54e264dd.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/99af324c881246949408c0b1ae54271f.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/9cc5d9646f344f1bbb52640a988fe902.mp3'),
 PosixPath('../data/dummy_test_for_check/test_audio/a56e20a518684688a9952add8a9d

In [108]:
class TestDataset(Dataset):
    """Serve model-ready spectrogram images for the metadata rows of one clip.

    Every item corresponds to one row of ``df`` (looked up by position) and is
    returned as ``(image_or_images, row_id, site)``:

    * rows whose ``site`` is ``"site_3"`` cover the whole recording. The clip
      is cut into consecutive, non-overlapping windows of ``period`` seconds
      (a trailing partial window is dropped) and one image per window is
      returned, stacked along a new leading axis;
    * every other row covers the ``period`` seconds that end at the row's
      ``seconds`` value and yields a single image.

    An "image" is the output of ``datasets.audio_to_spec`` for the window,
    optionally normalised, repeated three times along a new leading axis.

    Parameters
    ----------
    df : pd.DataFrame
        Rows for a single audio clip. Must provide ``site`` and ``row_id``,
        plus ``seconds`` for rows that are not ``site_3``.
    clip : np.ndarray
        Audio samples of the clip.
    sample_rate : int
        Samples per second of ``clip``.
    spec_min, spec_max
        Normalisation switch: normalisation runs only when both are not
        ``None``. The values themselves are stored but not used by
        ``normalize``, which rescales by each spectrogram's own min/max.
    period : int, optional
        Window length in seconds (default 5).
    """

    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 sample_rate, spec_min, spec_max, period: int = 5):
        self.df = df
        self.clip = clip
        self.sample_rate = sample_rate
        self.period = period
        # Explicit None checks: a legitimate bound of 0 must not switch
        # normalisation off, and the flag should be a real boolean.
        self.do_norm = spec_min is not None and spec_max is not None
        self.spec_min = spec_min
        self.spec_max = spec_max

    def __len__(self):
        """Return the number of metadata rows."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``(image(s), row_id, site)`` for the row at position ``idx``."""
        # Positional lookup, so the frame does not need a freshly reset index.
        sample = self.df.iloc[idx]
        site = sample.site
        row_id = sample.row_id

        if site == "site_3":
            window = self.sample_rate * self.period
            y = self.clip.astype(np.float32)
            # Each full window is converted on its own; the clip itself is
            # never overwritten, so every window is cut from raw audio.
            images = [
                self._to_image(y[start:start + window])
                for start in range(0, len(y) - window + 1, window)
            ]
            return np.asarray(images), row_id, site

        end_seconds = int(sample.seconds)
        start_seconds = end_seconds - self.period

        start_index = self.sample_rate * start_seconds
        end_index = self.sample_rate * end_seconds

        y = self.clip[start_index:end_index].astype(np.float32)
        return self._to_image(y), row_id, site

    def _to_image(self, y):
        """Convert one audio window into a 3-channel spectrogram image."""
        spec = datasets.audio_to_spec(y, self.sample_rate)
        if self.do_norm:
            spec = self.normalize(spec)
        # Repeat the single-channel spectrogram along a new leading axis.
        return np.repeat(spec[None, :, :], 3, 0)

    def normalize(self, x):
        """Min-max scale ``x`` to [0, 1], then shift and scale by fixed constants.

        NOTE(review): 0.11754986 / 0.16654329 are presumably the mean and
        standard deviation used at training time -- confirm against the
        training pipeline.
        """
        return ((x - x.min()) / (x.max() - x.min() + 1e-8) - 0.11754986) / 0.16654329

In [113]:
def prediction_for_clip(test_df: pd.DataFrame, 
                        clip: np.ndarray, 
                        model, 
                        sample_rate,
                        threshold=0.5):
    """Predict the label string for every metadata row of a single clip.

    The rows are wrapped in a ``TestDataset`` and fed to ``model`` one row at
    a time. Rows from ``site_1``/``site_2`` carry one image and get one forward
    pass. All other rows carry a stack of window images, which is pushed
    through the model in chunks of 16; the labels detected in any window are
    merged. A label counts as detected when its model output is
    ``>= threshold``.

    Parameters
    ----------
    test_df : pd.DataFrame
        Rows belonging to ``clip`` (as expected by ``TestDataset``).
    clip : np.ndarray
        Audio samples of the clip.
    model
        Callable torch module; it is moved to the GPU when one is available
        and switched to eval mode.
    sample_rate
        Samples per second of ``clip``.
    threshold : float
        Minimum model output for a label to be reported.

    Returns
    -------
    dict
        ``row_id`` mapped to either ``"nocall"`` or the space-separated names
        taken from the notebook-level ``classes`` lookup.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    clip_dataset = TestDataset(df=test_df,
                               clip=clip,
                               sample_rate=sample_rate,
                               spec_min=-100,
                               spec_max=80)
    clip_loader = DataLoader(clip_dataset, batch_size=1, shuffle=False)

    model.to(device)
    model.eval()

    # Upper bound on how many window images go through the model at once.
    chunk_size = 16
    results = {}
    for image, row_ids, sites in tqdm(clip_loader):
        # The loader batches a single row, so unwrap the one-element batches.
        site, row_id = sites[0], row_ids[0]

        if site in ("site_1", "site_2"):
            image = image.to(device)
            with torch.no_grad():
                scores = model(image).detach().cpu().numpy().reshape(-1)
            labels = np.flatnonzero(scores >= threshold).tolist()
        else:
            # Drop the loader's batch axis; what remains is the window stack.
            windows = image.squeeze(0)
            detected = set()
            for offset in range(0, windows.size(0), chunk_size):
                chunk = windows[offset:offset + chunk_size]
                if chunk.ndim == 3:
                    chunk = chunk.unsqueeze(0)

                chunk = chunk.to(device)
                with torch.no_grad():
                    scores = model(chunk).detach().cpu().numpy()

                hits = scores >= threshold
                for row in range(len(hits)):
                    detected.update(np.argwhere(hits[row, :]).reshape(-1).tolist())
            labels = list(detected)

        if labels:
            results[row_id] = " ".join(classes[label] for label in labels)
        else:
            results[row_id] = "nocall"
    return results

In [114]:
def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               model,
               sample_rate=32000,
               threshold=0.5):
    """Run inference for every recording referenced by ``test_df``.

    For each distinct ``audio_id`` (in order of first appearance) the file
    ``<test_audio>/<audio_id>.mp3`` is loaded as mono audio at ``sample_rate``
    and handed, together with the metadata rows of that recording, to
    ``prediction_for_clip``.

    Parameters
    ----------
    test_df : pd.DataFrame
        Test metadata; must contain an ``audio_id`` column plus whatever
        ``prediction_for_clip`` needs.
    test_audio : Path
        Directory containing the ``.mp3`` files.
    model
        Model forwarded to ``prediction_for_clip``.
    sample_rate : int
        Sampling rate the audio is resampled to while loading.
    threshold : float
        Detection threshold forwarded to ``prediction_for_clip``.

    Returns
    -------
    pd.DataFrame
        Columns ``row_id`` and ``birds`` with a fresh 0..n-1 index. An empty
        ``test_df`` yields an empty frame with those two columns.
    """
    row_ids = []
    birds = []

    # Silence warnings only while this function runs, so the rest of the
    # notebook session keeps its normal warning behaviour.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for audio_id in test_df.audio_id.unique():
            clip, _ = librosa.load(test_audio / (audio_id + ".mp3"),
                                   sr=sample_rate,
                                   mono=True,
                                   res_type="kaiser_fast")

            # Boolean mask instead of a string-built query: safe for ids that
            # contain quotes or other special characters.
            rows_for_clip = test_df[test_df.audio_id == audio_id].reset_index(drop=True)
            prediction_dict = prediction_for_clip(rows_for_clip,
                                                  clip=clip,
                                                  model=model,
                                                  sample_rate=sample_rate,
                                                  threshold=threshold)
            row_ids.extend(prediction_dict.keys())
            birds.extend(prediction_dict.values())

    # Build the frame once; explicit columns keep the schema for empty input.
    return pd.DataFrame({"row_id": row_ids, "birds": birds},
                        columns=["row_id", "birds"])

In [115]:
# Run inference over the whole fixture. The detection threshold is raised to
# 0.8 here (the function default is 0.5), so fewer labels pass the cut.
submission = prediction(test_df=test_df,
                        test_audio=test_audio_dir,
                        model=model,
                        sample_rate=32000,
                        threshold=0.8)

100%|██████████| 5/5 [00:00<00:00, 63.78it/s]
100%|██████████| 7/7 [00:00<00:00, 65.03it/s]
100%|██████████| 7/7 [00:00<00:00, 65.00it/s]
100%|██████████| 6/6 [00:00<00:00, 66.22it/s]
100%|██████████| 7/7 [00:00<00:00, 65.56it/s]
100%|██████████| 1/1 [00:00<00:00, 58.82it/s]
100%|██████████| 9/9 [00:00<00:00, 63.47it/s]
100%|██████████| 14/14 [00:00<00:00, 65.49it/s]
100%|██████████| 5/5 [00:00<00:00, 46.91it/s]
100%|██████████| 10/10 [00:00<00:00, 64.78it/s]
100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
100%|██████████| 1/1 [00:00<00:00, 17.36it/s]
100%|██████████| 1/1 [00:00<00:00, 34.08it/s]
100%|██████████| 1/1 [00:00<00:00, 46.77it/s]
100%|██████████| 1/1 [00:00<00:00, 25.10it/s]


In [116]:
# Display the predictions: one row per row_id, "nocall" when nothing was detected.
submission

Unnamed: 0,row_id,birds
0,site_1_41e6fe6504a34bf6846938ba78d13df1_5,nocall
1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,nocall
2,site_1_41e6fe6504a34bf6846938ba78d13df1_15,nocall
3,site_1_41e6fe6504a34bf6846938ba78d13df1_20,nocall
4,site_1_41e6fe6504a34bf6846938ba78d13df1_25,bktspa wesmea
...,...,...
71,site_3_9cc5d9646f344f1bbb52640a988fe902,nocall
72,site_3_a56e20a518684688a9952add8a9d5213,nocall
73,site_3_96779836288745728306903d54e264dd,nocall
74,site_3_f77783ba4c6641bc918b034a18c23e53,nocall
