In [None]:
import time
# Wall-clock timestamp taken when the notebook starts; timeSince() below
# formats the time elapsed since a timestamp like this one.
start = time.time()

import os
import math
import glob
import random
import numpy as np 
import pandas as pd 
from collections import Counter, deque
import matplotlib.pyplot as plt

import warnings
warnings.simplefilter(action='ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio

def timeSince(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<min>m <sec>s'``.

    Args:
        since: A timestamp previously obtained from ``time.time()``.

    Returns:
        A string such as ``'2m  5s'``; fractional seconds are truncated by
        the ``%d`` conversion.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %2ds' % (minutes, seconds)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

def audio_padding(waveform, sample_rate, clip_seconds = 5):
    """Zero-pad a short waveform to ``clip_seconds`` seconds at a random position.

    Args:
        waveform: 2-D tensor whose second dimension is the sample axis
            (``waveform.size(1)`` is read as the audio length).
        sample_rate: Samples per second of ``waveform``.
        clip_seconds: Target clip duration in seconds.

    Returns:
        A tensor with ``int(sample_rate * clip_seconds)`` samples along
        dimension 1. Shorter input is surrounded by zeros with a random left
        offset; input that is already long enough is returned truncated to
        the clip length instead of raising.
    """
    audio_length = waveform.size(1)
    # F.pad needs integer pad widths, so coerce in case clip_seconds is a float.
    clip_length = int(sample_rate * clip_seconds)

    max_offset = clip_length - audio_length
    if max_offset <= 0:
        # Nothing to pad. np.random.randint(high) raises ValueError for
        # high <= 0, so handle full-length / over-length input explicitly.
        return waveform[:, :clip_length]

    # Left padding is drawn from [0, max_offset); the rest goes on the right.
    offset = np.random.randint(max_offset)
    clip = F.pad(waveform, (offset, max_offset - offset), "constant")
    return clip

def clip_cut(waveform, sample_rate = 44100, clip_seconds = 5):
    """Split a waveform into consecutive, non-overlapping fixed-length clips.

    Args:
        waveform: 2-D tensor whose second dimension is the sample axis.
        sample_rate: Samples per second of ``waveform``.
        clip_seconds: Duration of each clip in seconds.

    Returns:
        A tensor of stacked clips (clip index first). Audio shorter than one
        clip is padded by ``audio_padding`` and returned as a single clip.
        For longer audio, any trailing remainder shorter than a full clip is
        discarded.
    """
    clip_length = sample_rate * clip_seconds
    total_samples = waveform.size(1)
    num_clips = total_samples // clip_length

    # Too short for even one clip: pad it up to a single full-length clip.
    if num_clips == 0:
        return torch.stack([audio_padding(waveform, sample_rate, clip_seconds)])

    segments = [
        waveform[:, k * clip_length:(k + 1) * clip_length]
        for k in range(num_clips)
    ]
    return torch.stack(segments)

# Target sample rate in Hz: the inference loop resamples every recording
# whose native rate differs from this value.
sample_rate = 44100

In [None]:
# Use the competition dataset when its test_audio folder is present;
# otherwise fall back to the '../input/birdcall-check' dataset.
if os.path.exists('../input/birdsong-recognition/test_audio'):
    BASE_TEST_DIR = '../input/birdsong-recognition'
else:
    BASE_TEST_DIR = '../input/birdcall-check'
TEST_FOLDER = f'{BASE_TEST_DIR}/test_audio'
print(TEST_FOLDER)

test_df = pd.read_csv(f'{BASE_TEST_DIR}/test.csv')
# One row per distinct (site, audio file) pair, so each file is processed once.
file_df = test_df[['site','audio_id']].drop_duplicates()

In [None]:
## Species codes the model can predict (264 entries).
## A code's position in this list is its integer label (see label2class /
## class2label below), so the order presumably has to match the label order
## used when the model was trained -- confirm before editing.
classes = ['aldfly', 'ameavo', 'amebit', 'amecro', 'amegfi', 'amekes',
       'amepip', 'amered', 'amerob', 'amewig', 'amewoo', 'amtspa',
       'annhum', 'astfly', 'baisan', 'baleag', 'balori', 'banswa',
       'barswa', 'bawwar', 'belkin1', 'belspa2', 'bewwre', 'bkbcuc',
       'bkbmag1', 'bkbwar', 'bkcchi', 'bkchum', 'bkhgro', 'bkpwar',
       'bktspa', 'blkpho', 'blugrb1', 'blujay', 'bnhcow', 'boboli',
       'bongul', 'brdowl', 'brebla', 'brespa', 'brncre', 'brnthr',
       'brthum', 'brwhaw', 'btbwar', 'btnwar', 'btywar', 'buffle',
       'buggna', 'buhvir', 'bulori', 'bushti', 'buwtea', 'buwwar',
       'cacwre', 'calgul', 'calqua', 'camwar', 'cangoo', 'canwar',
       'canwre', 'carwre', 'casfin', 'caster1', 'casvir', 'cedwax',
       'chispa', 'chiswi', 'chswar', 'chukar', 'clanut', 'cliswa',
       'comgol', 'comgra', 'comloo', 'commer', 'comnig', 'comrav',
       'comred', 'comter', 'comyel', 'coohaw', 'coshum', 'cowscj1',
       'daejun', 'doccor', 'dowwoo', 'dusfly', 'eargre', 'easblu',
       'easkin', 'easmea', 'easpho', 'eastow', 'eawpew', 'eucdov',
       'eursta', 'evegro', 'fiespa', 'fiscro', 'foxspa', 'gadwal',
       'gcrfin', 'gnttow', 'gnwtea', 'gockin', 'gocspa', 'goleag',
       'grbher3', 'grcfly', 'greegr', 'greroa', 'greyel', 'grhowl',
       'grnher', 'grtgra', 'grycat', 'gryfly', 'haiwoo', 'hamfly',
       'hergul', 'herthr', 'hoomer', 'hoowar', 'horgre', 'horlar',
       'houfin', 'houspa', 'houwre', 'indbun', 'juntit1', 'killde',
       'labwoo', 'larspa', 'lazbun', 'leabit', 'leafly', 'leasan',
       'lecthr', 'lesgol', 'lesnig', 'lesyel', 'lewwoo', 'linspa',
       'lobcur', 'lobdow', 'logshr', 'lotduc', 'louwat', 'macwar',
       'magwar', 'mallar3', 'marwre', 'merlin', 'moublu', 'mouchi',
       'moudov', 'norcar', 'norfli', 'norhar2', 'normoc', 'norpar',
       'norpin', 'norsho', 'norwat', 'nrwswa', 'nutwoo', 'olsfly',
       'orcwar', 'osprey', 'ovenbi1', 'palwar', 'pasfly', 'pecsan',
       'perfal', 'phaino', 'pibgre', 'pilwoo', 'pingro', 'pinjay',
       'pinsis', 'pinwar', 'plsvir', 'prawar', 'purfin', 'pygnut',
       'rebmer', 'rebnut', 'rebsap', 'rebwoo', 'redcro', 'redhea',
       'reevir1', 'renpha', 'reshaw', 'rethaw', 'rewbla', 'ribgul',
       'rinduc', 'robgro', 'rocpig', 'rocwre', 'rthhum', 'ruckin',
       'rudduc', 'rufgro', 'rufhum', 'rusbla', 'sagspa1', 'sagthr',
       'savspa', 'saypho', 'scatan', 'scoori', 'semplo', 'semsan',
       'sheowl', 'shshaw', 'snobun', 'snogoo', 'solsan', 'sonspa', 'sora',
       'sposan', 'spotow', 'stejay', 'swahaw', 'swaspa', 'swathr',
       'treswa', 'truswa', 'tuftit', 'tunswa', 'veery', 'vesspa',
       'vigswa', 'warvir', 'wesblu', 'wesgre', 'weskin', 'wesmea',
       'wessan', 'westan', 'wewpew', 'whbnut', 'whcspa', 'whfibi',
       'whtspa', 'whtswi', 'wilfly', 'wilsni1', 'wiltur', 'winwre3',
       'wlswar', 'wooduc', 'wooscj2', 'woothr', 'y00475', 'yebfly',
       'yebsap', 'yehbla', 'yelwar', 'yerwar', 'yetvir']

# Two-way lookup between integer label ids (list positions) and species codes.
label2class = dict(enumerate(classes))
class2label = {name: idx for idx, name in label2class.items()}

len(classes)

In [None]:
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero the bias, if any.

    Args:
        layer: A module exposing a ``weight`` tensor and, optionally, a
            ``bias`` attribute (which may be ``None``).
    """
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with bias=False expose bias as None; those are left alone.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)

def init_bn(bn):
    """Reset a batch-norm layer's affine parameters: bias to 0, weight to 1.

    Args:
        bn: A module with ``bias`` and ``weight`` parameters.
    """
    for param, value in ((bn.bias, 0.), (bn.weight, 1.0)):
        param.data.fill_(value)

class ConvBlock(nn.Module):
    """Two 3x3 conv + batch-norm + ReLU layers followed by 2-D pooling."""

    def __init__(self, in_channels: int, out_channels: int):
        """Build the two convolutions and their batch-norm layers.

        Args:
            in_channels: Channel count of the input feature map.
            out_channels: Channel count produced by both convolutions.
        """
        super().__init__()

        # Bias is omitted because each convolution is followed by batch norm.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               padding=(1, 1),
                               bias=False)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               padding=(1, 1),
                               bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise conv weights (init_layer) and batch-norm params (init_bn)."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply both conv stages, then pool.

        Args:
            input: 4-D feature map accepted by ``nn.Conv2d``.
            pool_size: Kernel size passed to the pooling function.
            pool_type: ``'max'``, ``'avg'`` or ``'avg+max'`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: If ``pool_type`` is not one of the supported values.
        """
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError(
                f"Unsupported pool_type {pool_type!r}; expected 'max', 'avg' or 'avg+max'"
            )

        return x


class AttBlock(nn.Module):
    """Attention-pooling head producing frame-level and clip-level scores."""

    def __init__(self, in_channels: int, out_channels: int):
        """Create the attention and classification 1x1 convolutions.

        Args:
            in_channels: Feature dimension of the input sequence.
            out_channels: Number of output classes.
        """
        super().__init__()

        self.att = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.cla = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        """Initialise both convolutions with init_layer (attention first)."""
        for conv in (self.att, self.cla):
            init_layer(conv)

    def forward(self, x):
        """Score every frame and pool the scores over time.

        Args:
            x: Tensor laid out as (batch, features, time).

        Returns:
            ``(framewise, clipwise)``: per-frame sigmoid scores, and their
            attention-weighted sum over the last (time) axis.
        """
        # Clamp the attention logits to [-10, 10] before the softmax over time.
        att_logits = torch.clamp(self.att(x), -10, 10)
        pooling_weights = torch.softmax(att_logits, dim=-1)

        # Frame-level prediction.
        framewise = torch.sigmoid(self.cla(x))

        # Clip-level prediction: weighted sum across frames -> (batch, classes).
        clipwise = (pooling_weights * framewise).sum(dim=-1)

        return framewise, clipwise

In [None]:
class TALNet(nn.Module):
    """Log-mel front end -> 5 conv blocks -> bidirectional GRU -> attention head."""

    def __init__(self, classes_num = 264):
        """Build the network.

        Args:
            classes_num: Number of output classes of the attention head.
        """
        super().__init__()

        # Waveform -> mel spectrogram -> decibel scale.
        mel = torchaudio.transforms.MelSpectrogram(sample_rate = 44100, win_length = 1024, hop_length = 320, 
                                                   n_fft=2048, f_min=50, f_max=14000, n_mels=64)
        to_db = torchaudio.transforms.AmplitudeToDB(top_db = 80)
        self.preprocess = nn.Sequential(mel, to_db)

        # Batch norm over the 64 mel bins (applied with the freq axis moved to dim 1).
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=32)
        self.conv_block2 = ConvBlock(in_channels=32, out_channels=64)
        self.conv_block3 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block4 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block5 = ConvBlock(in_channels=256, out_channels=512)

        self.biGRU = nn.GRU(1024, 512, num_layers = 2, batch_first = True, dropout = 0.2, bidirectional = True)

        self.att_block = AttBlock(1024, classes_num)

    def cnn_feature_extractor(self, x):
        """Run the five conv blocks, each followed by dropout (p=0.2).

        The first three blocks max-pool with (2, 2); the last two use (2, 1),
        which leaves the last axis at its current length.
        """
        stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 1)),
            (self.conv_block5, (2, 1)),
        )
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='max')
            x = F.dropout(x, p=0.2, training=self.training)
        return x


    def forward(self, input):
        """Compute frame-level and clip-level class scores for raw audio.

        Args:
            input: Batch of waveforms. make_prediction passes
                (batch, 1, samples); the original note here said
                (batch_size, data_length) -- TODO confirm the intended shape.

        Returns:
            ``(framewise, clipwise)`` as produced by the attention block.
        """
        x = self.preprocess(input) # (batch, 1, freq, time)

        # Swap dims 1 and 2 so bn0 normalises per mel bin, then swap back.
        x = self.bn0(x.transpose(1, 2)).transpose(1, 2)

        x = self.cnn_feature_extractor(x)
        # Merge the channel and frequency axes -> (batch, feature, time).
        x = torch.flatten(x, start_dim = 1, end_dim = 2)

        x = F.dropout(x, p=0.5, training=self.training)
        # The GRU is batch_first, so feed it (batch, time, feature).
        x, _ = self.biGRU(x.transpose(1, 2))
        # Back to (batch, feature, time) for the attention head.
        x = F.dropout(x.transpose(1, 2), p=0.5, training=self.training)

        framewise, clipwise = self.att_block(x)
        return framewise, clipwise
    
# Build the network on the selected device and restore the trained weights.
model = TALNet().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU run still loads when only a CPU is available.
model.load_state_dict(torch.load('/kaggle/input/birdcall-identification-talnet-training3/birdcall_TALNet_model.pth',
                                 map_location=device))
# Inference only: puts dropout and batch-norm layers in evaluation mode.
model.eval()

In [None]:
def make_prediction(clips, threshold = 0.6):
    """Predict the species present in each clip.

    Args:
        clips: Iterable of clip tensors; each one is given a leading batch
            dimension and passed through the global ``model`` on ``device``.
        threshold: Minimum clip-level score for a class to be reported.

    Returns:
        One list per clip holding the species codes whose score reached
        ``threshold``, or ``['nocall']`` when no class did.
    """
    birds = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for clip in clips:
            input_tensor = clip.unsqueeze(0)
            _, output = model(input_tensor.to(device))

            # Indices of the classes at or above the threshold for this clip.
            hits = torch.nonzero(output.squeeze(0) >= threshold).flatten().tolist()

            pred_class = [label2class[l] for l in hits] if hits else ['nocall']
            birds.append(pred_class)

    return birds

In [None]:
row_id, birds = [], []
# One Resample transform per source rate, built on first use and then reused.
resamplers = {}

for _, row in file_df.iterrows():

    audio_path = f'{TEST_FOLDER}/{row.audio_id}.mp3'
    waveform, orig_freq = torchaudio.load(audio_path)

    if orig_freq != sample_rate:
        if orig_freq not in resamplers:
            resamplers[orig_freq] = torchaudio.transforms.Resample(orig_freq, sample_rate)
        waveform = resamplers[orig_freq](waveform)

    # Mix any multi-channel recording down to a single channel; the model's
    # first conv block takes exactly one input channel.
    if waveform.size(0) > 1:
        waveform = waveform.mean(0, keepdim=True)
    clips = clip_cut(waveform, sample_rate)
    preds = make_prediction(clips)

    if row.site in ['site_1','site_2']:
        # These sites get one row per clip, keyed by the clip's end time in seconds.
        for s, pred in enumerate(preds):
            birds.append(' '.join(pred))
            row_id.append(f'{row.site}_{row.audio_id}_{(s+1)*5}')
    else:
        # Other sites get one row per file: the union of all clip predictions.
        flat_preds = [p for pred in preds for p in pred if p != 'nocall']
        # sorted() keeps the label order reproducible; bare set order is not.
        bird = 'nocall' if len(flat_preds) == 0 else ' '.join(sorted(set(flat_preds)))
        birds.append(bird)
        row_id.append(f'{row.site}_{row.audio_id}')

sub_df = pd.DataFrame(data={'row_id': row_id, 'birds': birds})

In [None]:
# Attach predictions to every expected test row. Selecting only row_id/birds
# keeps this cell re-runnable: merging the already-merged frame would
# otherwise duplicate columns (birds_x / birds_y) and break sub_df.birds.
sub_df = test_df.merge(sub_df[['row_id', 'birds']], on='row_id', how='left')
# Rows of test_df for which no prediction was produced.
sub_df[sub_df.birds.isna()]

In [None]:
# Preview the two columns that are written to the submission file.
sub_df[['row_id','birds']]

In [None]:
# Rows without a prediction (NaN after the left merge) would be written as
# empty cells; report them as 'nocall', the label used elsewhere in this
# notebook for "no bird detected".
sub_df[['row_id','birds']].fillna({'birds': 'nocall'}).to_csv("submission.csv", index=False)