In [None]:
# Print the GPU(s) visible to this runtime before doing any work.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive. Later cells read API credentials from
# it and write model checkpoints under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install -q kaggle
!mkdir -p .kaggle
!cp "./drive/My Drive/Study/config/kaggle.json" .kaggle/
!chmod 600 .kaggle/kaggle.json
!mv .kaggle /root

In [None]:
# 6分くらい
%%time
!mkdir -p birdclef-2021

!kaggle competitions download -c birdclef-2021 -f train_metadata.csv
!kaggle datasets download takamichitoda/birdclef-split-audio-frequency-500400

!unzip birdclef-split-audio-frequency-500400.zip -d birdclef-2021 > /dev/null
!unzip train_metadata.csv.zip -d birdclef-2021 > /dev/null

!rm birdclef-split-audio-frequency-500400.zip train_metadata.csv.zip

In [None]:
# Install the third-party packages imported by the cells below
# (timm, torchaudio, evaluations, wandb). Versions are not pinned.
!pip install timm torchaudio evaluations wandb

In [None]:
# Read the Weights & Biases API key from a text file on Drive.
# The loop keeps overwriting wandb_key, so the LAST line of the file is used
# (with its newline removed); an empty file leaves wandb_key undefined.
with open("./drive/My Drive/Study/config/wandb.txt", "r") as f:
    for line in f:
        wandb_key = line.replace("\n", "")

# IPython substitutes {wandb_key} into the shell command before running it.
!wandb login {wandb_key}

In [None]:
import os
import librosa
import psutil
import torch.nn as nn
import random

import numpy as np
import pandas as pd
import soundfile as sf

import matplotlib.pyplot as plt

import albumentations as A
from torchvision import transforms

from sklearn.model_selection import StratifiedKFold

import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

import wandb
import timm
from tqdm.notebook import tqdm as tqdm_notebook

import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import f1_score
from sklearn.metrics import average_precision_score
from evaluations.kaggle_2020 import row_wise_micro_averaged_f1_score

# Device that the model and every batch are moved to in the training and
# validation loops. Hard-coded to CUDA, so a GPU runtime is assumed.
device = torch.device("cuda")

In [None]:
# Experiment settings, read everywhere as plain class attributes
# (e.g. config.SAMPLE_RATE). Every non-dunder attribute is also copied into
# the wandb run config by the training cell.
class config:
    # Experiment id / name; both are embedded in the wandb run name.
    EXP_NUM = "0000"
    EXP_NAME = "baseline"
    # data setting
    # INPUT_ROOT holds train_metadata.csv and one sub-folder of audio per label.
    INPUT_ROOT = "/content/birdclef-2021"
    # WORK_ROOT is not referenced by the visible cells.
    WORK_ROOT = "/content"
    # Checkpoints are written to OUTPUT_ROOT/tmp.
    OUTPUT_ROOT = "/content/drive/MyDrive/Study/BirdCLEF/output"
    # Tag embedded in the wandb run name.
    LABEL_FREQ = "500-400"
    # audio setting
    # These are passed to torchaudio's MelSpectrogram inside the model.
    SAMPLE_RATE = 32000
    FMIN = 20
    FMAX = 16000
    N_FFT = 2048
    SPEC_HEIGHT = 128  # used as n_mels
    # Clip length fed to the model is PERIOD * SAMPLE_RATE samples.
    PERIOD = 20
    HOP_LENGTH = 512
    # ML setting
    SEED = 416
    BATCH_SIZE = 64
    # timm model name; must also be a key of MODEL_HEADER_INFO.
    MODEL_NAME = "resnet18"
    # NOTE(review): "LEAENING" is a typo for "LEARNING", but the training cell
    # reads config.LEAENING_RATE, so the name has to stay as it is.
    LEAENING_RATE = 1e-3
    # Multiplied by the number of batches per epoch to get the cosine
    # scheduler's T_max.
    T_MAX = 5
    NUM_EPOCHS = 30
    # Gradient-accumulation interval used by train_loop.
    N_ACCUMULATE = 1
    # DATA_N_LIMIT is not referenced by the visible cells.
    DATA_N_LIMIT = 100
    # infer setting
    # Cut-off applied to the sigmoid outputs when turning them into labels.
    THRESHOLD = 0.5

In [None]:
def set_seed(seed):
    """Seed every random number generator the notebook relies on.

    Covers Python's hash seed and `random`, NumPy, and PyTorch (CPU and CUDA),
    and switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, the current GPU, then every GPU.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Reproducible cuDNN kernels, no benchmark-based kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def arrange_wave_length(waveform, effective_length=None):
    """Crop or zero-pad a waveform to a fixed number of samples.

    Args:
        waveform: 2-D tensor shaped (channels, samples).
        effective_length: target number of samples. Defaults to
            ``config.PERIOD * config.SAMPLE_RATE`` when omitted.

    Returns:
        A (channels, effective_length) tensor. Longer inputs are cropped at a
        random offset (drawn from NumPy's global RNG), shorter inputs are
        right-padded with zeros, and inputs of the exact length are returned
        unchanged.
    """
    if effective_length is None:
        effective_length = config.PERIOD * config.SAMPLE_RATE
    input_length = waveform.shape[1]

    if input_length > effective_length:
        # randint's upper bound is exclusive; +1 makes the last valid start
        # offset (the window ending on the final sample) reachable.
        head_idx = np.random.randint(input_length - effective_length + 1)
        return waveform[:, head_idx:head_idx + effective_length]

    if input_length < effective_length:
        # Pad with the input's own channel count, dtype and device so that
        # multi-channel or non-default tensors concatenate cleanly.
        pad = torch.zeros(
            (waveform.shape[0], effective_length - input_length),
            dtype=waveform.dtype,
            device=waveform.device,
        )
        return torch.cat([waveform, pad], dim=1)

    return waveform


class BirdCLEFTrainDataset(torch.utils.data.Dataset):
    """Dataset yielding (fixed-length waveform, one-hot label) pairs.

    Audio is read from ``{config.INPUT_ROOT}/{label}/{fname}``. The label
    vocabulary comes from the notebook-level ``label_dic`` / ``n_labels``.

    Args:
        fnames: sequence of audio file names.
        labels: sequence of primary labels, aligned with ``fnames``.
        mode: free-form tag (e.g. "train" / "valid"); stored but not used by
            the dataset itself.
    """

    def __init__(self, fnames, labels, mode):
        self.fnames = fnames
        self.labels = labels
        self.mode = mode

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        label = self.labels[idx]

        audio_path = f"{config.INPUT_ROOT}/{label}/{fname}"

        # The file's sample rate is returned by the loader but not used here.
        waveform, _sample_rate = torchaudio.load(audio_path)
        waveform = arrange_wave_length(waveform)

        # Build the one-hot target directly instead of materialising an
        # n_labels x n_labels identity matrix for every sample.
        label_ohe = torch.zeros(n_labels)
        label_ohe[label_dic[label]] = 1.0

        return waveform, label_ohe

In [None]:
def birdclef_criterion(outputs, targets, device=None):
    """Binary cross-entropy (with logits) on the clip-level predictions.

    Args:
        outputs: dict returned by the model; only ``"clipwise_output"``
            (raw logits) is read.
        targets: float tensor of 0/1 labels with the same shape as the logits.
        device: unused. Kept, now optional, so that both the three-argument
            signature and the two-argument calls in the train/valid loops work.

    Returns:
        Scalar tensor: the mean loss over every element.
    """
    clipwise_output = outputs["clipwise_output"]
    loss = nn.BCEWithLogitsLoss(reduction="mean")(clipwise_output, targets)
    return loss

In [None]:
# Per-backbone settings, keyed by timm model name:
#   (end index used to slice the backbone's child modules,
#    number of channels fed to the 1x1 Conv1d heads).
MODEL_HEADER_INFO = {
    "resnet18": (-2, 512)
}

def interpolate_and_padding(x, frames_num):  # x: (batch, class_num, time)
    """Stretch per-segment outputs along time to exactly ``frames_num`` steps.

    Each time step is first duplicated ``frames_num // time`` times, then the
    result is bilinearly resized along time so that it has ``frames_num``
    steps. The class axis keeps its size throughout.

    Args:
        x: tensor shaped (batch, class_num, time).
        frames_num: number of time steps wanted in the output.

    Returns:
        Tensor shaped (batch, class_num, frames_num).
    """
    repeat_factor = frames_num // x.shape[2]

    # Work in (batch, time, class_num) layout.
    per_step = x.transpose(1, 2)

    # Integer upsampling: every time step repeated repeat_factor times in a row.
    stretched = per_step.repeat_interleave(repeat_factor, dim=1)

    # Resize the time axis to frames_num; the class axis is left at its size.
    n_classes = stretched.size(2)
    resized = F.interpolate(
        stretched.unsqueeze(1),
        size=(frames_num, n_classes),
        mode="bilinear",
        align_corners=True,
    )

    # Drop the dummy channel and go back to (batch, class_num, time).
    return resized.squeeze(1).transpose(1, 2)

class BirdCLEFNet(nn.Module):
    """CNN classifier that takes raw waveforms and pools over time with attention.

    The waveform is converted to a dB-scaled mel spectrogram inside the model,
    passed through a truncated timm backbone, averaged over the frequency
    axis, and fed to two 1x1 Conv1d heads: ``fc_a`` gives per-segment class
    logits and ``fc_b`` gives per-class attention weights over time.

    Args:
        model_name: timm model name; must be a key of ``MODEL_HEADER_INFO``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.n_label = n_labels

        # Front end: waveform -> mel spectrogram -> decibel scale.
        self.mel_spectrogram_extractor = MelSpectrogram(
            sample_rate=config.SAMPLE_RATE,
            n_fft=config.N_FFT,
            hop_length=config.HOP_LENGTH,
            f_min=config.FMIN,
            f_max=config.FMAX,
            n_mels=config.SPEC_HEIGHT,
        )
        self.amplitude_to_db = AmplitudeToDB()

        # Pretrained single-channel backbone, truncated at the configured index.
        backbone = timm.create_model(model_name, pretrained=True, in_chans=1)
        head_end, feature_channels = MODEL_HEADER_INFO[model_name]
        self.model_head = nn.Sequential(*list(backbone.children())[:head_end])

        # 1x1 convolutions over time: segment logits and attention scores.
        self.fc_a = nn.Conv1d(feature_channels, self.n_label, 1)
        self.fc_b = nn.Conv1d(feature_channels, self.n_label, 1)

    def forward(self, x):  # input x: waveform, (batch, channel, time)
        """Return clip-level and frame-level logits for a batch of waveforms.

        Returns:
            dict with ``"clipwise_output"`` (batch, n_class) and
            ``"segmentwise_output"`` (batch, n_class, spectrogram frames).
        """
        spec = self.mel_spectrogram_extractor(x)  # (batch, channel, Hz, time)
        spec = self.amplitude_to_db(spec)
        frames_num = spec.shape[3]

        features = F.relu(self.model_head(spec))  # (batch, unit, Hz, time)
        # Average over the frequency axis (dim=2), keeping time.
        pooled = features.mean(dim=2)  # (batch, unit, time)

        segment_logits = self.fc_a(pooled)  # (batch, n_class, time)
        # Softmax over time turns fc_b's scores into per-class attention weights.
        attention = self.fc_b(pooled).softmax(dim=2)  # (batch, n_class, time)

        # Attention-weighted sum over time gives one logit per class and clip.
        clipwise_output = (segment_logits * attention).sum(dim=2)
        segmentwise_output = interpolate_and_padding(segment_logits, frames_num)

        return {
            "clipwise_output": clipwise_output,
            "segmentwise_output": segmentwise_output,
        }

In [None]:
# Load the competition metadata and keep only the rows whose primary label
# has a matching entry in the downloaded data folder.
exist_labels = os.listdir(f"{config.INPUT_ROOT}")
train_metadata_df = pd.read_csv(f"{config.INPUT_ROOT}/train_metadata.csv")
print("original data:", len(train_metadata_df))

train_metadata_df = (
    train_metadata_df
    .query(f"primary_label in {exist_labels}")
    .reset_index(drop=True)
)
print("use data:", len(train_metadata_df))

filenames = train_metadata_df["filename"]
primary_labels = train_metadata_df["primary_label"]

# Label vocabulary in order of first appearance: name -> index and index -> name.
label_dic = {name: idx for idx, name in enumerate(primary_labels.unique())}
label_dic_inv = dict(enumerate(primary_labels.unique()))
n_labels = len(label_dic)

print("### labels ###")
print(label_dic)
print(label_dic_inv)

In [None]:
def train_loop(train_data_loader, model, optimizer, scheduler):
    """Train ``model`` for one pass over ``train_data_loader``.

    Gradients are accumulated over ``config.N_ACCUMULATE`` batches before each
    optimizer step. Any gradients left over at the end of the pass are applied
    as well, so no batch is silently dropped.

    Args:
        train_data_loader: sized iterable of (X, y) batches.
        model: module returning a dict understood by ``birdclef_criterion``.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per batch, or ``None``.

    Returns:
        (losses, lrs): the unscaled loss and the mean learning rate across
        parameter groups, one entry per batch.
    """
    losses, lrs = [], []
    n_accumulate = max(1, int(config.N_ACCUMULATE))
    n_batches = len(train_data_loader)

    model.train()
    optimizer.zero_grad()
    for n_iter, (X, y) in tqdm_notebook(enumerate(train_data_loader), total=n_batches):
        X, y = X.to(device), y.to(device)
        outputs = model(X)
        loss = birdclef_criterion(outputs, y, device)
        # Scale so the accumulated gradient is an average over the group.
        (loss / n_accumulate).backward()

        # Step once a full group of batches has been accumulated, and once
        # more on the final batch to flush a partial group.
        group_complete = (n_iter + 1) % n_accumulate == 0
        last_batch = (n_iter + 1) == n_batches
        if group_complete or last_batch:
            optimizer.step()
            optimizer.zero_grad()

        # The scheduler's T_max is expressed in batches, so it advances per batch.
        if scheduler is not None:
            scheduler.step()

        lrs.append(np.array([param_group["lr"] for param_group in optimizer.param_groups]).mean())
        losses.append(loss.item())

    return losses, lrs

In [None]:
def valid_loop(valid_data_loader, model):
    """Evaluate ``model`` on ``valid_data_loader`` without tracking gradients.

    Args:
        valid_data_loader: sized iterable of (X, y) batches.
        model: module returning a dict with a ``"clipwise_output"`` entry.

    Returns:
        (losses, valid_predicts): one loss value per batch, and the clip-level
        outputs of every batch concatenated along dim 0 (raw model outputs,
        still on the compute device).
    """
    losses = []
    predicts = []
    model.eval()
    # Both the forward pass and the loss run under no_grad.
    with torch.no_grad():
        for n_iter, (X, y) in tqdm_notebook(enumerate(valid_data_loader), total=len(valid_data_loader)):
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            loss = birdclef_criterion(outputs, y, device)
            losses.append(loss.item())
            predicts.append(outputs["clipwise_output"])
    valid_predicts = torch.cat(predicts, dim=0)
    return losses, valid_predicts

In [None]:
def output_to_label(clipwise_output, thr):
    """Turn per-class scores into space-separated label strings.

    Args:
        clipwise_output: iterable of rows, each an iterable of per-class
            scores indexed like ``label_dic_inv``.
        thr: a class is predicted when its score is strictly greater than this.

    Returns:
        One string per row: the predicted label names joined by single
        spaces, or ``"nocall"`` when no score exceeds ``thr``.
    """
    def _row_to_string(scores):
        names = [label_dic_inv[idx] for idx, score in enumerate(scores) if score > thr]
        return " ".join(names) if names else "nocall"

    return [_row_to_string(row) for row in clipwise_output]

In [None]:
# Stratified 5-fold split on the primary label; only fold 0 is trained (see
# the `break` at the bottom of the loop).
skf = StratifiedKFold(n_splits=5,  shuffle=True, random_state=config.SEED)

# Checkpoints are written to OUTPUT_ROOT/tmp. Create it up front so that
# torch.save does not fail on a missing directory.
checkpoint_dir = f"{config.OUTPUT_ROOT}/tmp"
os.makedirs(checkpoint_dir, exist_ok=True)

for fold, (train_index, valid_index) in enumerate(skf.split(filenames, primary_labels)):
    set_seed(config.SEED)
    print(f"### Fold-{fold} ###")

    # Prepare the datasets. skf.split yields positional indices, hence .iloc.
    train_primary_labels = primary_labels.iloc[train_index].values
    valid_primary_labels = primary_labels.iloc[valid_index].values
    train_filenames = filenames.iloc[train_index].values
    valid_filenames = filenames.iloc[valid_index].values
    train_dset = BirdCLEFTrainDataset(train_filenames, train_primary_labels, "train")
    train_data_loader = torch.utils.data.DataLoader(train_dset, batch_size=config.BATCH_SIZE, shuffle=True)
    valid_dset = BirdCLEFTrainDataset(valid_filenames, valid_primary_labels, "valid")
    valid_data_loader = torch.utils.data.DataLoader(valid_dset, batch_size=config.BATCH_SIZE, shuffle=False)

    # Model, optimizer and scheduler. The scheduler is stepped once per batch
    # inside train_loop, so T_max is expressed in batches (T_MAX epochs).
    model = BirdCLEFNet(config.MODEL_NAME)
    model.to(device)
    optimizer = Adam(model.parameters(), lr=config.LEAENING_RATE)
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_data_loader)*config.T_MAX, eta_min=0.0)

    # Experiment tracking: one wandb run per fold, with every config value logged.
    uniqe_exp_name = f"exp{config.EXP_NUM}_freq{config.LABEL_FREQ}_f{fold}_{config.EXP_NAME}"
    wandb.init(project="birdclef", entity='trtd56', name=uniqe_exp_name)
    wandb_config = wandb.config
    wandb_config.fold = fold
    for k, v in dict(vars(config)).items():
        # Skip dunder attributes such as __module__ / __dict__.
        if k[:2] == "__":
            continue
        wandb_config[k] = v
    wandb.watch(model)

    best_f1 = 0
    for epoch in range(config.NUM_EPOCHS):
        train_losses, lrs = train_loop(train_data_loader, model, optimizer, scheduler)
        valid_losses, valid_predicts = valid_loop(valid_data_loader, model)

        # The model returns logits; convert to probabilities on the CPU.
        valid_predicts = valid_predicts.sigmoid().cpu()

        predict_labels = output_to_label(valid_predicts, config.THRESHOLD)
        epoch_f1 = row_wise_micro_averaged_f1_score(valid_primary_labels, predict_labels)

        # Keep the weights of the best epoch so far (by validation F1).
        if best_f1 < epoch_f1:
            best_f1 = epoch_f1
            torch.save(model.state_dict(), f"{checkpoint_dir}/birdclefnet_f{fold}_thr05_best_model.bin")

        res_d = dict()
        res_d["t_loss"] = np.array(train_losses).mean()
        res_d["v_loss"] = np.array(valid_losses).mean()
        res_d["lr_avg"] = np.array(lrs).mean()
        res_d["epoch_f1"] = epoch_f1
        res_d["best_f1"] = best_f1

        wandb.log(res_d)
        # Always overwrite the "last" checkpoint, whatever the score.
        torch.save(model.state_dict(), f"{checkpoint_dir}/birdclefnet_f{fold}_last_model.bin")

    wandb.finish()
    break  # only Fold-0