In [None]:
import ast
import json
import os

import numpy as np
import pandas as pd
import torch
import torchaudio
import torchvision.models as models
from matplotlib import pyplot as plt
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchaudio import transforms
from torchvision.models.resnet import ResNet, BasicBlock

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

In [None]:
# Directory that receives the per-chunk spectrogram tensors and the label array.
out_path = "./train/"
# exist_ok keeps the cell re-runnable without a try/except around mkdir.
os.makedirs(out_path, exist_ok = True)

# Species codes that the competition scores; used to filter the training metadata
# and to build the submission rows.
with open("../input/birdclef-2022/scored_birds.json") as f:
    scored_birds = json.load(f)

test_df = pd.read_csv("../input/birdclef-2022/test.csv")
train_metadata_df = pd.read_csv("../input/birdclef-2022/train_metadata.csv")

In [None]:
# Keep every recording that mentions a scored bird, either as the primary label
# or anywhere in its secondary labels.
scored_set = set(scored_birds)
has_scored_primary = train_metadata_df["primary_label"].isin(scored_birds)
# secondary_labels is stored as the text of a Python list; literal_eval parses it
# without executing arbitrary code the way eval would.
has_scored_secondary = train_metadata_df["secondary_labels"].apply(
    lambda labels: not scored_set.isdisjoint(ast.literal_eval(labels))
)
filter_df = train_metadata_df[has_scored_primary | has_scored_secondary]

# The class list is taken from the full metadata, before the subsampling below.
bird_label = train_metadata_df["primary_label"].unique()
print(bird_label)

# Add a reproducible one-in-five sample of the remaining recordings.
# DataFrame.append no longer exists in pandas 2.x, so the frames are joined with pd.concat.
# NOTE: this rebinds train_metadata_df, so re-running the cell without reloading
# the CSV filters the already-filtered frame.
train_rest_df = train_metadata_df[~train_metadata_df.index.isin(filter_df.index)]
train_metadata_df = pd.concat([filter_df, train_rest_df.sample(len(train_rest_df) // 5, random_state = 42)])

In [None]:
# Audio / spectrogram settings used by the preprocessing functions.
sample_rate = 32000               # also handed to the noise-reduction call
n_fft = 1024
hop_length = 512
n_mels = 256
f_min = 250
min_sec_proc = 5 * sample_rate    # length of one processing chunk, in samples (5 s worth)

# Mel transform applied to every chunk before the log10 / normalisation step.
mel_spectrogram = transforms.MelSpectrogram(
    sample_rate = sample_rate,
    n_fft = n_fft,
    hop_length = hop_length,
    f_min = f_min,
    n_mels = n_mels,
    power = 2.0,
    center = True,
    pad_mode = "reflect",
    onesided = True,
    norm = "slaney",
    mel_scale = "htk",
)

In [None]:
# !pip download noisereduce==2.0.0 -d /kaggle/working/noisereduce/

In [None]:
!pip install ../input/noisereduce-2-0-0/noisereduce-2.0.0-py3-none-any.whl

In [None]:
import noisereduce as nr
def normalize_std(spec, eps = 1e-6):
    """Standardise a tensor to zero mean and unit standard deviation.

    Mean and standard deviation are computed over all elements of ``spec``.

    Parameters
    ----------
    spec : torch.Tensor
        Tensor to normalise.
    eps : float
        Lower bound applied to the standard deviation. A constant input
        has a standard deviation of 0, which would otherwise turn the whole
        result into NaN; with the bound it becomes all zeros.

    Returns
    -------
    torch.Tensor
        ``(spec - mean) / max(std, eps)``.
    """
    std = torch.std(spec).clamp(min = eps)
    return (spec - torch.mean(spec)) / std

def audio_to_mel_label_train(filepath, min_sec_proc, reduce_noise = False, data_index = 0, label_list = None, bird_label = None, label_file = None):
    """Cut one recording into fixed-length chunks and save a log-mel spectrogram per chunk.

    Each chunk is written to ``out_path + str(index) + ".pt"`` and the recording's
    multi-hot label vector is appended to ``label_list`` once per saved chunk.
    Relies on the module-level ``out_path``, ``mel_spectrogram``, ``sample_rate``,
    ``normalize_std`` and (when ``reduce_noise`` is set) ``nr``.

    Parameters
    ----------
    filepath : str
        Audio file handed to ``torchaudio.load``.
    min_sec_proc : int
        Chunk length in samples.
    reduce_noise : bool
        When True, the waveform is passed through ``nr.reduce_noise`` first.
    data_index : int
        Index used for the first chunk saved by this call.
    label_list : list, optional
        Receives one label vector per saved chunk (mutated in place).
    bird_label : array-like, optional
        All class names; fixes the order of the multi-hot vector.
    label_file : iterable, optional
        Class names present in this recording.

    Returns
    -------
    int
        The next unused data index.
    """
    # Build fresh containers per call: mutable default arguments are shared
    # between calls, and label_list is appended to below.
    if label_list is None:
        label_list = []
    bird_label = np.asarray([] if bird_label is None else bird_label)
    label_file = [] if label_file is None else list(label_file)

    # Multi-hot target: 1.0 where the class name occurs in this recording's labels.
    label_file_all = np.isin(bird_label, label_file).astype(np.float64)

    # NOTE(review): the file's own sample rate is discarded; this assumes it
    # matches the global sample_rate used by mel_spectrogram — confirm.
    waveform, _ = torchaudio.load(filepath = filepath)
    if reduce_noise:
        waveform = torch.tensor(nr.reduce_noise(y = waveform, sr = sample_rate, win_length = mel_spectrogram.win_length, use_tqdm = False, n_fft = mel_spectrogram.n_fft, n_jobs = -1))
    len_wave = waveform.shape[1]
    if len_wave == 0:
        # Nothing to chunk; also avoids dividing by zero in the padding step.
        return data_index
    # Only the first channel is used.
    waveform = waveform[0, :].reshape(1, len_wave)

    if len_wave < min_sec_proc:
        # Short clip: loop the recording until it fills exactly one chunk.
        repeats = min_sec_proc // len_wave + 1
        waveform = waveform.repeat(1, repeats)[:, :min_sec_proc]
        len_wave = min_sec_proc

    # Whole chunks only; a trailing partial chunk is dropped.
    for i in range(len_wave // min_sec_proc):
        chunk = waveform[0, i * min_sec_proc:(i + 1) * min_sec_proc]
        # 1e-10 keeps log10 finite where the mel power is exactly zero.
        log_melspec = torch.log10(mel_spectrogram(chunk).unsqueeze(0) + 1e-10)
        log_melspec = normalize_std(log_melspec)

        torch.save(log_melspec, out_path + str(data_index) + ".pt")
        label_list.append(label_file_all)
        data_index += 1

    return data_index
    
def audio_to_mel_label_test(filepath, min_sec_proc, reduce_noise = False, mel_list = None, max_chunks = 12):
    """Convert one recording into a list of per-chunk log-mel spectrograms.

    Relies on the module-level ``mel_spectrogram``, ``sample_rate``,
    ``normalize_std`` and (when ``reduce_noise`` is set) ``nr``.

    Parameters
    ----------
    filepath : str
        Audio file handed to ``torchaudio.load``.
    min_sec_proc : int
        Chunk length in samples.
    reduce_noise : bool
        When True, the waveform is passed through ``nr.reduce_noise`` first.
    mel_list : list, optional
        List to append the spectrograms to; a new list is created when omitted.
    max_chunks : int
        Recordings longer than ``max_chunks * min_sec_proc`` samples are cut
        to that length before chunking.

    Returns
    -------
    list
        ``mel_list`` with one spectrogram appended per whole chunk. It gains
        nothing when the recording is shorter than one chunk.
    """
    # A mutable default list would be shared (and keep growing) across calls.
    if mel_list is None:
        mel_list = []

    waveform, _ = torchaudio.load(filepath = filepath)
    if reduce_noise:
        waveform = torch.tensor(nr.reduce_noise(y = waveform, sr = sample_rate, win_length = mel_spectrogram.win_length, use_tqdm = False, n_fft = mel_spectrogram.n_fft, n_jobs = -1))

    # First channel only, capped at max_chunks chunks.
    waveform = waveform[0, :max_chunks * min_sec_proc]
    len_wave = waveform.shape[0]

    for i in range(len_wave // min_sec_proc):
        chunk = waveform[i * min_sec_proc:(i + 1) * min_sec_proc]
        # 1e-10 keeps log10 finite where the mel power is exactly zero.
        log_melspec = torch.log10(mel_spectrogram(chunk).unsqueeze(0) + 1e-10)
        log_melspec = normalize_std(log_melspec)
        mel_list.append(log_melspec)

    # Return only after every chunk has been processed.
    return mel_list

def load_tensor(path, filename):
    """Load the object stored at ``<path><filename>.pt`` with ``torch.load``.

    ``path`` is used as a plain string prefix, so it must already end with a
    path separator.
    """
    tensor_file = f"{path}{str(filename)}.pt"
    return torch.load(tensor_file)

def get_X_y(path, idx, label_list):
    """Assemble one batch of saved tensors and their labels.

    Parameters
    ----------
    path : str
        Prefix forwarded to ``load_tensor``.
    idx : iterable
        Sample indices; each element must provide ``.item()``.
    label_list : indexable
        Maps a sample index to its label tensor.

    Returns
    -------
    tuple
        ``(batch_X, batch_y)``, each built with ``torch.stack`` in ``idx`` order.
    """
    sample_ids = [entry.item() for entry in idx]
    features = [load_tensor(path, sample_id) for sample_id in sample_ids]
    targets = [label_list[sample_id] for sample_id in sample_ids]
    return torch.stack(features), torch.stack(targets)

def plot_history(history):
    """Draw the training and validation loss curves on a new 10x10 figure.

    ``history`` is indexed as a 2-D array: column 0 is the x-axis (epoch),
    column 1 is plotted as "loss" and column 2 as "val_loss".
    """
    epochs = history[:, 0]
    plt.figure(figsize = (10, 10))
    for column, curve_name in ((1, "loss"), (2, "val_loss")):
        plt.plot(epochs, history[:, column], label = curve_name)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()


In [None]:
# Save spectrogram
# Every recording is chunked and written to out_path; label_list collects one
# multi-hot label vector per saved chunk, in the same order as the file indices.
data_index = 0
label_list = []
for primary_label, secondary_label, filename in zip(train_metadata_df["primary_label"], train_metadata_df["secondary_labels"], train_metadata_df["filename"]):
    # secondary_labels holds the text of a Python list; literal_eval parses it
    # without executing arbitrary code the way eval would.
    labels_in_file = [primary_label] + ast.literal_eval(secondary_label)
    data_index = audio_to_mel_label_train("../input/birdclef-2022/train_audio/" + filename, min_sec_proc, False, data_index, label_list, bird_label, labels_in_file)

# Stack once and reuse the array for both the on-disk copy and the in-memory tensor.
label_array = np.stack(label_list)
torch.save(label_array, out_path + "label_list.pt")
# From here on label_list is a 2-D tensor (chunks x classes) instead of a Python list.
label_list = torch.from_numpy(label_array).clone()
# label_list = torch.from_numpy(torch.load(out_path + "label_list.pt"))

In [None]:
# Build model
n_output = len(bird_label)
out_sigmoid = nn.Sigmoid()

# Disabled alternative, kept for reference only (previously a no-op string literal):
# class BirdResNet(ResNet):
#     def __init__(self):
#         super().__init__(BasicBlock, [3, 4, 6, 3], num_classes = n_output)
#         self.conv1 = nn.Conv2d(1, 64, kernel_size = 7, stride = 1, padding = 3, bias = False)

# `models` (torchvision.models) is imported in the notebook's import cell.
net = models.resnet18()
# The spectrograms have a single channel, so the stem convolution takes 1 input channel.
net.conv1 = torch.nn.Conv2d(1, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
# One logit per class; the input width is read from the existing head instead of hard-coding it.
net.fc = torch.nn.Linear(net.fc.in_features, n_output)
net = net.to(device)

In [None]:
# Data split
# 80 % of the chunk indices go to training, the remainder to validation.
data_len = label_list.shape[0]
n_train_samples = int(0.8 * data_len)
# A seeded generator makes the split identical on every run.
split_generator = torch.Generator().manual_seed(42)
train_idx, val_idx = torch.utils.data.random_split(np.arange(0, data_len), [n_train_samples, data_len - n_train_samples], generator = split_generator)

In [None]:
# Train loop
num_epochs = 20
lr = 0.0005
batch_size = 48
lr_decay_step = 5      # epochs between learning-rate decays
lr_decay_gamma = 0.7   # multiplicative decay factor
acc_threshold = 0.1    # sigmoid output above which a class counts as predicted

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr = lr)
# Rebinding the Python variable `lr` never reaches the optimizer, so the decay
# is applied through a scheduler that updates the optimizer's parameter groups.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = lr_decay_step, gamma = lr_decay_gamma)
history_rows = []

train_loader = DataLoader(train_idx, batch_size = batch_size, shuffle = True)
# Validation order does not affect the metrics, so no shuffling is needed.
val_loader = DataLoader(val_idx, batch_size = batch_size, shuffle = False)


def _count_exact_matches(logits, targets, threshold = acc_threshold):
    """Count samples whose thresholded predictions equal the targets for every class."""
    predicted = torch.sigmoid(logits.detach()) > threshold
    return (predicted == targets.bool()).all(dim = 1).sum().item()


for epoch in range(num_epochs):
    train_loss, val_loss = 0.0, 0.0
    train_acc, val_acc = 0, 0
    n_train, n_val = 0, 0

    net.train()
    for idx in train_loader:
        inputs, labels = get_X_y(out_path, idx, label_list)
        inputs = inputs.to(device)
        # Cast the targets to float32 so they share the logits' dtype.
        labels = labels.to(device = device, dtype = torch.float32)
        n_batch = len(labels)
        n_train += n_batch

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a smaller final batch is averaged correctly.
        train_loss += loss.item() * n_batch
        train_acc += _count_exact_matches(outputs, labels)

    net.eval()
    with torch.no_grad():
        for idx in val_loader:
            inputs_val, labels_val = get_X_y(out_path, idx, label_list)
            inputs_val = inputs_val.to(device)
            labels_val = labels_val.to(device = device, dtype = torch.float32)
            n_batch_val = len(labels_val)
            n_val += n_batch_val

            outputs_val = net(inputs_val)
            loss_val = criterion(outputs_val, labels_val)

            val_loss += loss_val.item() * n_batch_val
            val_acc += _count_exact_matches(outputs_val, labels_val)

    # max(..., 1) avoids a division by zero when a split is empty.
    train_loss /= max(n_train, 1)
    val_loss /= max(n_val, 1)
    train_acc /= max(n_train, 1)
    val_acc /= max(n_val, 1)
    print(f"Epoch [{epoch + 1} / {num_epochs}], Train loss: {train_loss:.5f}, Train accuracy: {train_acc:.5f}, Val loss: {val_loss:.5f}, Val accuracy: {val_acc:.5f}")
    history_rows.append([epoch + 1, train_loss, val_loss])

    # Multiplies the learning rate by lr_decay_gamma after every lr_decay_step epochs.
    scheduler.step()

# Columns: epoch, train loss, validation loss. reshape keeps the array 2-D even with no rows.
history = np.array(history_rows, dtype = float).reshape(-1, 3)

torch.save(net.state_dict(), "model.pt")
plot_history(history)

In [None]:
test_audio_dir = "../input/birdclef-2022/test_soundscapes/"
# Keep only .ogg files (the inference cell re-appends ".ogg") and strip just the
# extension, so names that contain extra dots are not truncated.
test_list = [
    os.path.splitext(name)[0]
    for name in sorted(os.listdir(test_audio_dir))
    if name.endswith(".ogg")
]
print(f"Number of test soundscapes: {len(test_list)}")

In [None]:
# Evaluate and submission
pred = {"row_id": [], "target": []}
binary_th = 0.1
n_chunks_expected = 12   # rows emitted per soundscape and scored bird
chunk_seconds = 5        # chunk i is reported with end time (i + 1) * chunk_seconds
net.eval()

# Resolve each scored bird's output column once instead of searching bird_label
# for every row. Birds absent from bird_label map to None and are predicted False.
label_to_column = {label: column for column, label in enumerate(bird_label)}
scored_columns = [label_to_column.get(bird) for bird in scored_birds]

for testfile in test_list:
    path = test_audio_dir + testfile + ".ogg"

    mel_list_test = audio_to_mel_label_test(path, min_sec_proc, mel_list = [])
    if mel_list_test:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs_test = out_sigmoid(net(torch.stack(mel_list_test).to(device))).cpu()
    else:
        # No whole chunk could be extracted; every row below falls back to False.
        outputs_test = torch.zeros((0, len(bird_label)))

    for i in range(n_chunks_expected):
        chunk_end_time = (i + 1) * chunk_seconds
        # Chunks beyond what the recording provided are reported as "not present"
        # explicitly rather than through a caught IndexError.
        chunk_available = i < outputs_test.shape[0]
        for bird, column in zip(scored_birds, scored_columns):
            if chunk_available and column is not None:
                is_present = bool(outputs_test[i, column] > binary_th)
            else:
                is_present = False

            row_id = testfile + "_" + bird + "_" + str(chunk_end_time)

            pred["row_id"].append(row_id)
            pred["target"].append(is_present)

results = pd.DataFrame(pred, columns = ["row_id", "target"])
print(results["target"])
results.to_csv("submission.csv", index = False)