In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Print the nvidia-smi report for the GPU attached to this runtime.
!nvidia-smi

In [None]:
!pip install -r /path/to/requirements.txt

In [None]:
%%time
!cp /path/to/train_audio_sep_5sec_ogg.zip /content
!unzip -q train_audio_sep_5sec_ogg.zip -d /content/BirdCLEF2024
!rm train_audio_sep_5sec_ogg.zip

In [None]:
!cp /path/to/unlabeld_data_back.zip /content
!unzip -q unlabeld_data_back.zip -d /content
!rm unlabeld_data_back.zip

In [None]:
!cp /path/to/birdclef2024-additional-cleaned.zip /content
!unzip -q birdclef2024-additional-cleaned.zip -d /content/birdclef2024-additional-cleaned
!rm birdclef2024-additional-cleaned.zip

In [None]:
import cv2
import audioread
import logging
import gc
import os
import json
import sys
import glob
import time

import random
import shutil
import warnings
warnings.simplefilter('ignore')

import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as torchdata
from torchvision import transforms as TT
from contextlib import contextmanager
from joblib import Parallel, delayed
from pathlib import Path
from typing import Optional
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn import metrics
from sklearn.metrics import mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

from albumentations.core.transforms_interface import ImageOnlyTransform
import torchlibrosa
from torchlibrosa.stft import LogmelFilterBank, Spectrogram
from torchlibrosa.augmentation import SpecAugmentation
import colorednoise
from tqdm import tqdm

import albumentations as A
import albumentations.pytorch.transforms as T
import audiomentations as AA

import matplotlib.pyplot as plt
import wandb

import transformers
from torch.cuda.amp import autocast, GradScaler

sys.path.append('/path/to/modules') # Need to change
from utils import filter_data, upsample_data, downsample_data, set_seed
from augmentations import OneOf, Compose, NoiseInjection, GaussianNoise, PinkNoise, RandomVolume, Normalize, cutmix, mixup
from losses import BCEFocalLoss, BCELoss, loss_fn, mixup_criterion, cutmix_criterion
from models import CustomModel
from metrics import AverageMeter, MetricMeter
from dataset import WaveformDataset

## Config


In [None]:
class CFG:
    """
    Experiment configuration shared by every cell of the notebook.

    All values are class attributes and are evaluated once, when the class body runs.
    """
    ######################
    # Globals #
    ######################
    EXP_ID = 'v0'
    save_dir = '/path/to/save/' + EXP_ID # Need to change
    csv_path = '/path/to/train_metadata.csv' # Need to change
    # Read the wandb key from the WANDB_API_KEY environment variable instead of pasting the
    # secret into the notebook; the default stays the empty string.
    wandb_key = os.environ.get('WANDB_API_KEY', '')
    seed = 71

    epochs = 20
    last_epochs = 0
    # Epochs with index below this value are trained with mixup/cutmix by the main loop.
    cutmix_and_mixup_epochs = 17 - last_epochs
    folds = [0, 1, 2] # [0, 1, 2, 3, 4]
    secondary_ratio = 0
    upsample_thr = 15
    downsample_thr = 500
    N_FOLDS = 5
    LR = 1e-3
    ETA_MIN = 1e-7
    WEIGHT_DECAY = 1e-6
    train_bs = 32 # 32
    valid_bs = 32 # 64
    base_model_name = "tf_efficientnet_b0_ns"
    EARLY_STOPPING = True
    DEBUG = False # True
    # In debug mode a single epoch is run.
    if DEBUG:
        epochs = 1
    EVALUATION = 'AUC'
    apex = True

    pretrained = True
    in_channels = 1
    target_columns = ['asbfly', 'ashdro1', 'ashpri1', 'ashwoo2', 'asikoe2',
                      'asiope1', 'aspfly1', 'aspswi1', 'barfly1', 'barswa',
                      'bcnher', 'bkcbul1', 'bkrfla1', 'bkskit1', 'bkwsti',
                      'bladro1', 'blaeag1', 'blakit1', 'blhori1', 'blnmon1',
                      'blrwar1', 'bncwoo3', 'brakit1', 'brasta1', 'brcful1',
                      'brfowl1', 'brnhao1', 'brnshr', 'brodro1', 'brwjac1',
                      'brwowl1', 'btbeat1', 'bwfshr1', 'categr', 'chbeat1',
                      'cohcuc1', 'comfla1', 'comgre', 'comior1', 'comkin1',
                      'commoo3', 'commyn', 'compea', 'comros', 'comsan',
                      'comtai1', 'copbar1', 'crbsun2', 'cregos1', 'crfbar1',
                      'crseag1', 'dafbab1', 'darter2', 'eaywag1', 'emedov2',
                      'eucdov', 'eurbla2', 'eurcoo', 'forwag1', 'gargan',
                      'gloibi', 'goflea1', 'graher1', 'grbeat1', 'grecou1',
                      'greegr', 'grefla1', 'grehor1', 'grejun2', 'grenig1',
                      'grewar3', 'grnsan', 'grnwar1', 'grtdro1', 'gryfra',
                      'grynig2', 'grywag', 'gybpri1', 'gyhcaf1', 'heswoo1',
                      'hoopoe', 'houcro1', 'houspa', 'inbrob1', 'indpit1',
                      'indrob1', 'indrol2', 'indtit1', 'ingori1', 'inpher1',
                      'insbab1', 'insowl1', 'integr', 'isbduc1', 'jerbus2',
                      'junbab2', 'junmyn1', 'junowl1', 'kenplo1', 'kerlau2',
                      'labcro1', 'laudov1', 'lblwar1', 'lesyel1', 'lewduc1',
                      'lirplo', 'litegr', 'litgre1', 'litspi1', 'litswi1',
                      'lobsun2', 'maghor2', 'malpar1', 'maltro1', 'malwoo1',
                      'marsan', 'mawthr1', 'moipig1', 'nilfly2', 'niwpig1',
                      'nutman', 'orihob2', 'oripip1', 'pabflo1', 'paisto1',
                      'piebus1', 'piekin1', 'placuc3', 'plaflo1', 'plapri1',
                      'plhpar1', 'pomgrp2', 'purher1', 'pursun3', 'pursun4',
                      'purswa3', 'putbab1', 'redspu1', 'rerswa1', 'revbul',
                      'rewbul', 'rewlap1', 'rocpig', 'rorpar', 'rossta2',
                      'rufbab3', 'ruftre2', 'rufwoo2', 'rutfly6', 'sbeowl1',
                      'scamin3', 'shikra1', 'smamin1', 'sohmyn1', 'spepic1',
                      'spodov', 'spoowl1', 'sqtbul1', 'stbkin1', 'sttwoo1',
                      'thbwar1', 'tibfly3', 'tilwar1', 'vefnut1', 'vehpar1',
                      'wbbfly1', 'wemhar1', 'whbbul2', 'whbsho3', 'whbtre1',
                      'whbwag1', 'whbwat1', 'whbwoo2', 'whcbar1', 'whiter2',
                      'whrmun', 'whtkin2', 'woosan', 'wynlau1', 'yebbab1',
                      'yebbul3','zitcis1'
                     ]
    num_classes = len(target_columns) #182

    # Number of batches whose gradients are accumulated per optimizer step (at least 1).
    n_accumulate  = max(1, 32//train_bs)
    period = 5
    downsample = 2
    use_first = True
    reshape_factor = 200
    sample_rate = 32000

class AudioParams:
    """
    Audio settings mirrored from CFG so audio code can read them from one place.
    """
    duration = CFG.period  # copied from CFG.period
    sr = CFG.sample_rate  # copied from CFG.sample_rate

In [None]:
# An empty key is passed as None so wandb falls back to its own credential lookup
# (environment variable / stored login / prompt) instead of receiving ''.
wandb.login(key=CFG.wandb_key or None)

## Build a dataframe of all audio clip files

In [None]:
def _clip_to_source_filename(clip_path):
    """Keep the last two path components, cut at the first underscore, append '.ogg'."""
    tail = '/'.join(clip_path.split('/')[-2:])
    return tail.split('_')[0] + '.ogg'


# One row per extracted clip; 'filename' is derived from the clip path by the helper above.
files_sep_audio = glob.glob('/content/BirdCLEF2024/train_audio_sep_5sec_ogg/*/*.ogg') # Need to change
df_sep_audio = pd.DataFrame({'clip_file_path': files_sep_audio})
df_sep_audio['filename'] = df_sep_audio['clip_file_path'].map(_clip_to_source_filename)
df_sep_audio

In [None]:
import pickle

# One row per additional "cleaned" clip; 'filename' is the last two components of its path.
paths_clean = glob.glob('/content/birdclef2024-additional-cleaned/*/*.ogg')
df_clean = pd.DataFrame({'clip_file_path': paths_clean}).assign(
    filename=lambda frame: frame['clip_file_path'].map(lambda p: '/'.join(p.split('/')[-2:]))
)
df_clean

In [None]:
# Append the additional cleaned clips to the clip index. Repeated clip paths are dropped so
# the cell is idempotent: re-running it would otherwise append df_clean a second time.
df_sep_audio = (
    pd.concat([df_sep_audio, df_clean], axis=0)
    .drop_duplicates(subset=['clip_file_path'])
    .reset_index(drop=True)
)

## Build the sample dataframe

In [None]:
import ast

# Load the metadata CSV and build one space-separated label string per row:
# the primary label followed by the parsed secondary labels.
train_2024 = pd.read_csv(CFG.csv_path)
train_2024['new_target'] = train_2024['primary_label'] + ' ' + train_2024['secondary_labels'].map(lambda x: ' '.join(ast.literal_eval(x)))
train_2024['len_new_target'] = train_2024['new_target'].map(lambda x: len(x.split()))
train_2024['filename_base'] = train_2024['filename'].map(lambda x: x.split('/')[-1])
train_2024['birdclef'] = '24'
train_2024['xc_id'] = train_2024['filename_base'].map(lambda x: x.split('.')[0])
# NOTE(review): filter_data comes from the project's utils module; the 'cv' column selected
# below is presumably added by it - confirm.
train_2024 = filter_data(train_2024)
train_2024 = train_2024.drop_duplicates(subset=['filename_base']).reset_index(drop=True)
# Keep only the rows whose filename appears in the clip index.
id_unique = df_sep_audio['filename'].unique().tolist()
train_2024 = train_2024[train_2024['filename'].isin(id_unique)].reset_index(drop=True)
train_2024 = train_2024[['filename', 'author','primary_label','new_target',
                       'xc_id', 'birdclef', 'cv']]
train_2024["rating_0_1"] = 1

* Remove duplicate audio files [(reference)](https://www.kaggle.com/code/robbynevels/bc24-duplicate-audio-files/)

In [None]:
# Pairs of duplicate recordings, written as '<label dir>/<file>.ogg' (see the reference linked
# above this cell). The second element of each pair is the one removed from train_2024 below.
dupes = [
    ('asbfly/XC724266.ogg', 'asbfly/XC724148.ogg'),
    ('barswa/XC575749.ogg', 'barswa/XC575747.ogg'),
    ('bcnher/XC669544.ogg', 'bcnher/XC669542.ogg'),
    ('bkskit1/XC350251.ogg', 'bkskit1/XC350249.ogg'),
    ('blhori1/XC417215.ogg', 'blhori1/XC417133.ogg'),
    ('blhori1/XC743616.ogg', 'blhori1/XC537503.ogg'),
    ('blrwar1/XC662286.ogg', 'blrwar1/XC662285.ogg'),
    ('brakit1/XC743675.ogg', 'brakit1/XC537471.ogg'),
    ('brcful1/XC197746.ogg', 'brcful1/XC157971.ogg'),
    ('brnshr/XC510751.ogg', 'brnshr/XC510750.ogg'),
    ('btbeat1/XC665307.ogg', 'btbeat1/XC513403.ogg'),
    ('btbeat1/XC743618.ogg', 'btbeat1/XC683300.ogg'),
    ('btbeat1/XC743619.ogg', 'btbeat1/XC683300.ogg'),
    ('btbeat1/XC743619.ogg', 'btbeat1/XC743618.ogg'),
    ('categr/XC787914.ogg', 'categr/XC438523.ogg'),
    ('cohcuc1/XC253418.ogg', 'cohcuc1/XC241127.ogg'),
    ('cohcuc1/XC423422.ogg', 'cohcuc1/XC423419.ogg'),
    ('comgre/XC202776.ogg', 'comgre/XC192404.ogg'),
    ('comgre/XC602468.ogg', 'comgre/XC175341.ogg'),
    ('comgre/XC64628.ogg', 'comgre/XC58586.ogg'),
    ('comior1/XC305930.ogg', 'comior1/XC303819.ogg'),
    ('comkin1/XC207123.ogg', 'comior1/XC207062.ogg'),
    ('comkin1/XC691421.ogg', 'comkin1/XC690633.ogg'),
    ('commyn/XC577887.ogg', 'commyn/XC577886.ogg'),
    ('commyn/XC652903.ogg', 'commyn/XC652901.ogg'),
    ('compea/XC665320.ogg', 'compea/XC644022.ogg'),
    ('comsan/XC385909.ogg', 'comsan/XC385908.ogg'),
    ('comsan/XC643721.ogg', 'comsan/XC642698.ogg'),
    ('comsan/XC667807.ogg', 'comsan/XC667806.ogg'),
    ('comtai1/XC126749.ogg', 'comtai1/XC122978.ogg'),
    ('comtai1/XC305210.ogg', 'comtai1/XC304811.ogg'),
    ('comtai1/XC542375.ogg', 'comtai1/XC540351.ogg'),
    ('comtai1/XC542379.ogg', 'comtai1/XC540352.ogg'),
    ('crfbar1/XC615780.ogg', 'crfbar1/XC615778.ogg'),
    ('dafbab1/XC188307.ogg', 'dafbab1/XC187059.ogg'),
    ('dafbab1/XC188308.ogg', 'dafbab1/XC187068.ogg'),
    ('dafbab1/XC188309.ogg', 'dafbab1/XC187069.ogg'),
    ('dafbab1/XC197745.ogg', 'dafbab1/XC157972.ogg'),
    ('eaywag1/XC527600.ogg', 'eaywag1/XC527598.ogg'),
    ('eucdov/XC355153.ogg', 'eucdov/XC355152.ogg'),
    ('eucdov/XC360303.ogg', 'eucdov/XC347428.ogg'),
    ('eucdov/XC365606.ogg', 'eucdov/XC124694.ogg'),
    ('eucdov/XC371039.ogg', 'eucdov/XC368596.ogg'),
    ('eucdov/XC747422.ogg', 'eucdov/XC747408.ogg'),
    ('eucdov/XC789608.ogg', 'eucdov/XC788267.ogg'),
    ('goflea1/XC163901.ogg', 'bladro1/XC163901.ogg'),
    ('goflea1/XC208794.ogg', 'bladro1/XC208794.ogg'),
    ('goflea1/XC208795.ogg', 'bladro1/XC208795.ogg'),
    ('goflea1/XC209203.ogg', 'bladro1/XC209203.ogg'),
    ('goflea1/XC209549.ogg', 'bladro1/XC209549.ogg'),
    ('goflea1/XC209564.ogg', 'bladro1/XC209564.ogg'),
    ('graher1/XC357552.ogg', 'graher1/XC357551.ogg'),
    ('graher1/XC590235.ogg', 'graher1/XC590144.ogg'),
    ('grbeat1/XC304004.ogg', 'grbeat1/XC303999.ogg'),
    ('grecou1/XC365426.ogg', 'grecou1/XC365425.ogg'),
    ('greegr/XC247286.ogg', 'categr/XC197438.ogg'),
    ('grewar3/XC743681.ogg', 'grewar3/XC537475.ogg'),
    ('grnwar1/XC197744.ogg', 'grnwar1/XC157973.ogg'),
    ('grtdro1/XC651708.ogg', 'grtdro1/XC613192.ogg'),
    ('grywag/XC459760.ogg', 'grywag/XC457124.ogg'),
    ('grywag/XC575903.ogg', 'grywag/XC575901.ogg'),
    ('grywag/XC650696.ogg', 'grywag/XC592019.ogg'),
    ('grywag/XC690448.ogg', 'grywag/XC655063.ogg'),
    ('grywag/XC745653.ogg', 'grywag/XC745650.ogg'),
    ('grywag/XC812496.ogg', 'grywag/XC812495.ogg'),
    ('heswoo1/XC357155.ogg', 'heswoo1/XC357149.ogg'),
    ('heswoo1/XC744698.ogg', 'heswoo1/XC665715.ogg'),
    ('hoopoe/XC631301.ogg', 'hoopoe/XC365530.ogg'),
    ('hoopoe/XC631304.ogg', 'hoopoe/XC252584.ogg'),
    ('houcro1/XC744704.ogg', 'houcro1/XC683047.ogg'),
    ('houspa/XC326675.ogg', 'houspa/XC326674.ogg'),
    ('inbrob1/XC744708.ogg', 'inbrob1/XC744706.ogg'),
    ('insowl1/XC305214.ogg', 'insowl1/XC301142.ogg'),
    ('junbab2/XC282587.ogg', 'junbab2/XC282586.ogg'),
    ('labcro1/XC267645.ogg', 'labcro1/XC265731.ogg'),
    ('labcro1/XC345836.ogg', 'labcro1/XC312582.ogg'),
    ('labcro1/XC37773.ogg', 'labcro1/XC19736.ogg'),
    ('labcro1/XC447036.ogg', 'houcro1/XC447036.ogg'),
    ('labcro1/XC823514.ogg', 'gybpri1/XC823527.ogg'),
    ('laudov1/XC185511.ogg', 'grewar3/XC185505.ogg'),
    ('laudov1/XC405375.ogg', 'laudov1/XC405374.ogg'),
    ('laudov1/XC514027.ogg', 'eucdov/XC514027.ogg'),
    ('lblwar1/XC197743.ogg', 'lblwar1/XC157974.ogg'),
    ('lewduc1/XC261506.ogg', 'lewduc1/XC254813.ogg'),
    ('litegr/XC403621.ogg', 'bcnher/XC403621.ogg'),
    ('litegr/XC535540.ogg', 'litegr/XC448898.ogg'),
    ('litegr/XC535552.ogg', 'litegr/XC447850.ogg'),
    ('litgre1/XC630775.ogg', 'litgre1/XC630560.ogg'),
    ('litgre1/XC776082.ogg', 'litgre1/XC663244.ogg'),
    ('litspi1/XC674522.ogg', 'comtai1/XC674522.ogg'),
    ('litspi1/XC722435.ogg', 'litspi1/XC721636.ogg'),
    ('litspi1/XC722436.ogg', 'litspi1/XC721637.ogg'),
    ('litswi1/XC443070.ogg', 'litswi1/XC440301.ogg'),
    ('lobsun2/XC197742.ogg', 'lobsun2/XC157975.ogg'),
    ('maghor2/XC197740.ogg', 'maghor2/XC157978.ogg'),
    ('maghor2/XC786588.ogg', 'maghor2/XC786587.ogg'),
    ('malpar1/XC197770.ogg', 'malpar1/XC157976.ogg'),
    ('marsan/XC383290.ogg', 'marsan/XC383288.ogg'),
    ('marsan/XC733175.ogg', 'marsan/XC716673.ogg'),
    ('mawthr1/XC455222.ogg', 'mawthr1/XC455211.ogg'),
    ('orihob2/XC557991.ogg', 'orihob2/XC557293.ogg'),
    ('piebus1/XC165050.ogg', 'piebus1/XC122395.ogg'),
    ('piebus1/XC814459.ogg', 'piebus1/XC792272.ogg'),
    ('placuc3/XC490344.ogg', 'placuc3/XC486683.ogg'),
    ('placuc3/XC572952.ogg', 'placuc3/XC572950.ogg'),
    ('plaflo1/XC615781.ogg', 'plaflo1/XC614946.ogg'),
    ('purher1/XC467373.ogg', 'graher1/XC467373.ogg'),
    ('purher1/XC827209.ogg', 'purher1/XC827207.ogg'),
    ('pursun3/XC268375.ogg', 'comtai1/XC241382.ogg'),
    ('pursun4/XC514853.ogg', 'pursun4/XC514852.ogg'),
    ('putbab1/XC574864.ogg', 'brcful1/XC574864.ogg'),
    ('rewbul/XC306398.ogg', 'bkcbul1/XC306398.ogg'),
    ('rewbul/XC713308.ogg', 'asbfly/XC713467.ogg'),
    ('rewlap1/XC733007.ogg', 'rewlap1/XC732874.ogg'),
    ('rorpar/XC199488.ogg', 'rorpar/XC199339.ogg'),
    ('rorpar/XC402325.ogg', 'comior1/XC402326.ogg'),
    ('rorpar/XC516404.ogg', 'rorpar/XC516402.ogg'),
    ('sbeowl1/XC522123.ogg', 'brfowl1/XC522123.ogg'),
    ('sohmyn1/XC744700.ogg', 'sohmyn1/XC743682.ogg'),
    ('spepic1/XC804432.ogg', 'spepic1/XC804431.ogg'),
    ('spodov/XC163930.ogg', 'bladro1/XC163901.ogg'),
    ('spodov/XC163930.ogg', 'goflea1/XC163901.ogg'),
    ('spoowl1/XC591485.ogg', 'spoowl1/XC591177.ogg'),
    ('stbkin1/XC266782.ogg', 'stbkin1/XC266682.ogg'),
    ('stbkin1/XC360661.ogg', 'stbkin1/XC199815.ogg'),
    ('stbkin1/XC406140.ogg', 'stbkin1/XC406138.ogg'),
    ('vefnut1/XC197738.ogg', 'vefnut1/XC157979.ogg'),
    ('vefnut1/XC293526.ogg', 'vefnut1/XC289785.ogg'),
    ('wemhar1/XC581045.ogg', 'comsan/XC581045.ogg'),
    ('wemhar1/XC590355.ogg', 'wemhar1/XC590354.ogg'),
    ('whbbul2/XC335671.ogg', 'whbbul2/XC335670.ogg'),
    ('whbsho3/XC856465.ogg', 'whbsho3/XC856463.ogg'),
    ('whbsho3/XC856468.ogg', 'whbsho3/XC856463.ogg'),
    ('whbsho3/XC856468.ogg', 'whbsho3/XC856465.ogg'),
    ('whbwat1/XC840073.ogg', 'whbwat1/XC840071.ogg'),
    ('whbwoo2/XC239509.ogg', 'rufwoo2/XC239509.ogg'),
    ('whcbar1/XC659329.ogg', 'insowl1/XC659329.ogg'),
    ('whiter2/XC265271.ogg', 'whiter2/XC265267.ogg'),
    ('whtkin2/XC197737.ogg', 'whtkin2/XC157981.ogg'),
    ('whtkin2/XC430267.ogg', 'whtkin2/XC430256.ogg'),
    ('whtkin2/XC503389.ogg', 'comior1/XC503389.ogg'),
    ('whtkin2/XC540094.ogg', 'whtkin2/XC540087.ogg'),
    ('woosan/XC184466.ogg', 'marsan/XC184466.ogg'),
    ('woosan/XC545316.ogg', 'woosan/XC476064.ogg'),
    ('woosan/XC587076.ogg', 'woosan/XC578599.ogg'),
    ('woosan/XC742927.ogg', 'woosan/XC740798.ogg'),
    ('woosan/XC825766.ogg', 'grnsan/XC825765.ogg'),
    ('zitcis1/XC303866.ogg', 'zitcis1/XC302781.ogg'),
]

# Drop the second file of every duplicate pair. An exact membership test replaces the former
# '|'-joined regex, in which '.' acted as a wildcard and any substring hit counted as a match.
id_drop = [dup[1] for dup in dupes]
train_2024 = train_2024[~train_2024['filename'].isin(id_drop)].reset_index(drop=True)
train_2024

In [None]:
# Assign a validation fold to every row: stratified on primary_label, grouped by author.
# The fold count comes from CFG.N_FOLDS so it stays in sync with the cells that loop over folds.
gkf = StratifiedGroupKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=0)
train_2024['fold'] = -1
# split() yields positional indices, so write through iloc instead of relying on the
# index labels happening to be 0..n-1.
fold_col = train_2024.columns.get_loc('fold')
for fold, (train_idx, val_idx) in enumerate(gkf.split(train_2024, y=train_2024.primary_label.tolist(), groups=train_2024.author.tolist())):
    train_2024.iloc[val_idx, fold_col] = fold

In [None]:
# Collect the labels that are absent from at least one fold, then send every row carrying
# such a label to fold -1.
bird_names = set(train_2024.primary_label)
unseen_birds = []
for fold in range(CFG.N_FOLDS):
    in_fold = train_2024['fold'] == fold
    bird_names_fold = set(train_2024.loc[in_fold, 'primary_label'])
    unseen_bird = list(bird_names - bird_names_fold)
    print(f'fold{fold}:', unseen_bird)
    unseen_birds += unseen_bird
unseen_birds = np.unique(unseen_birds)
is_unseen = train_2024['primary_label'].isin(unseen_birds)
train_2024.loc[is_unseen, 'fold'] = -1

# Rebuild the helper columns from new_target and filename.
train_2024['len_new_target'] = [len(target.split()) for target in train_2024['new_target']]
train_2024['filename_base'] = [name.split('/')[-1] for name in train_2024['filename']]
train_2024['birdclef'] = '24'
train_2024['xc_id'] = [base.split('.')[0] for base in train_2024['filename_base']]

In [None]:
import pickle

# NOTE(review): pickle.load can execute arbitrary code - only load archives you trust.
# The loaded object is used like a DataFrame below (assumed; confirm against the archive).
with open('/content/birdclef2024-additional-cleaned/additional_cleaned.pkl', 'rb') as pkl_file:
    train_audio_meta_cleaned = pickle.load(pkl_file)

# The additional data has only a primary label; the remaining columns get fixed values.
train_audio_meta_cleaned['new_target'] = train_audio_meta_cleaned['primary_label']
for column, constant in [('cv', True), ('rating_0_1', 1), ('fold', -1)]:
    train_audio_meta_cleaned[column] = constant

In [None]:
# Stack both metadata tables; ignore_index already yields a fresh 0..n-1 index.
train = pd.concat([train_2024, train_audio_meta_cleaned], ignore_index=True)

In [None]:
# Display the combined training table.
train

In [None]:
# Positions within CFG.target_columns of every label listed in unseen_birds.
unseen_index = [CFG.target_columns.index(name) for name in unseen_birds]

## Train

In [None]:
set_seed(CFG.seed)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Files passed to AddBackgroundNoise as the background sound source.
bg_noise_files = glob.glob("/content/unlabeld_data_back/*")

# Training pipeline, in application order; the validation pipeline only normalizes.
train_steps = [
    AA.AddBackgroundNoise(bg_noise_files, min_snr_in_db=3.0, max_snr_in_db=30.0, p=0.5),
    AA.Gain(min_gain_in_db=-12, max_gain_in_db=12, p=0.2),
    OneOf(
        [
            NoiseInjection(p=1, max_noise_level=0.04),
            GaussianNoise(p=1, min_snr=5, max_snr=20),
            PinkNoise(p=1, min_snr=5, max_snr=20),
        ],
        p=0.3,
    ),
    RandomVolume(p=0.3, limit=4),
    Normalize(p=1),
]
transforms = {
    'train': Compose(train_steps),
    'valid': Compose([Normalize(p=1)]),
}

In [None]:
def train_fn(model, data_loader, device, optimizer):
    """Run one training pass over `data_loader` with AMP and gradient accumulation.

    Args:
        model: network to train; it is switched to train mode.
        data_loader: yields dict batches with the keys 'image', 'targets' and 'rating'.
        device: device the batch tensors are moved to.
        optimizer: optimizer stepped every CFG.n_accumulate batches.

    Returns:
        Tuple ``(scores.avg, losses.avg)`` from the MetricMeter and AverageMeter.
    """
    model.train()
    scaler = GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    scores = MetricMeter(indices_ignore=unseen_index)
    num_steps = len(data_loader)
    tk0 = tqdm(enumerate(data_loader), total=num_steps)

    # Start from clean gradients so nothing left over from an earlier pass is applied.
    optimizer.zero_grad()
    for step, data in tk0:
        inputs = data['image'].to(device)
        targets = data['targets'].to(device)
        rating = data['rating'].to(device)
        with autocast(enabled=CFG.apex):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets, rating)

        # Log the undivided batch loss; only the backward pass uses the 1/n_accumulate scaling.
        batch_loss = loss.item()
        scaler.scale(loss / CFG.n_accumulate).backward()

        # Also step on the final batch so a partial accumulation window is not carried
        # into the next call.
        if (step + 1) % CFG.n_accumulate == 0 or (step + 1) == num_steps:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        losses.update(batch_loss, inputs.size(0))
        scores.update(targets, outputs)
        tk0.set_postfix(loss=losses.avg)
    return scores.avg, losses.avg


def train_mixup_cutmix_fn(model, data_loader, device, optimizer):
    """Run one training pass where every batch is augmented with mixup or cutmix.

    Each batch draws one uniform random number: below 0.5 it uses mixup with
    mixup_criterion, otherwise cutmix with cutmix_criterion.

    Args:
        model: network to train; it is switched to train mode.
        data_loader: yields dict batches with the keys 'image', 'targets' and 'rating'.
        device: device the batch tensors are moved to.
        optimizer: optimizer stepped every CFG.n_accumulate batches.

    Returns:
        Tuple ``(scores.avg, losses.avg)`` from the MetricMeter and AverageMeter.
    """
    model.train()
    scaler = GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    scores = MetricMeter(indices_ignore=unseen_index)
    num_steps = len(data_loader)
    tk0 = tqdm(enumerate(data_loader), total=num_steps)

    # Start from clean gradients so nothing left over from an earlier pass is applied.
    optimizer.zero_grad()
    for step, data in tk0:
        inputs = data['image'].to(device)
        targets = data['targets'].to(device)
        rating = data['rating'].to(device)

        # Pick the augmentation together with its matching loss.
        if np.random.rand() < 0.5:
            augment, criterion = mixup, mixup_criterion
        else:
            augment, criterion = cutmix, cutmix_criterion

        # 0.4 is forwarded unchanged to the augmentation (presumably its alpha; confirm).
        inputs, new_targets, new_rating = augment(inputs, targets, rating, 0.4)
        with autocast(enabled=CFG.apex):
            outputs = model(inputs)
            loss = criterion(outputs, new_targets, new_rating)

        # Log the undivided batch loss; only the backward pass uses the 1/n_accumulate scaling.
        batch_loss = loss.item()
        scaler.scale(loss / CFG.n_accumulate).backward()

        # Also step on the final batch so a partial accumulation window is not carried
        # into the next call.
        if (step + 1) % CFG.n_accumulate == 0 or (step + 1) == num_steps:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        losses.update(batch_loss, inputs.size(0))
        # Scored against the first element of the mixed-target structure
        # (presumably the un-permuted targets; confirm against augmentations.py).
        scores.update(new_targets[0], outputs)
        tk0.set_postfix(loss=losses.avg)
    return scores.avg, losses.avg


def valid_fn(model, data_loader, device):
    """Evaluate `model` on `data_loader` without gradient tracking.

    Args:
        model: network to evaluate; it is switched to eval mode.
        data_loader: yields dict batches with the keys 'path', 'image', 'targets', 'rating'.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(scores.avg, losses.avg, scores.y_true, scores.y_pred, paths)`` where
        `paths` collects the 'path' entries of every batch in iteration order.
    """
    model.eval()
    loss_meter = AverageMeter()
    score_meter = MetricMeter(indices_ignore=unseen_index)
    progress = tqdm(data_loader, total=len(data_loader))
    collected_paths = []
    with torch.no_grad():
        for batch in progress:
            batch_paths = batch['path']
            images = batch['image'].to(device)
            labels = batch['targets'].to(device)
            ratings = batch['rating'].to(device)

            model_out = model(images)
            batch_loss = loss_fn(model_out, labels, ratings)

            loss_meter.update(batch_loss.item(), images.size(0))
            score_meter.update(labels, model_out)
            collected_paths += batch_paths
            progress.set_postfix(loss=loss_meter.avg)
    return (
        score_meter.avg,
        loss_meter.avg,
        score_meter.y_true,
        score_meter.y_pred,
        collected_paths,
    )

def inference_fn(model, data_loader, device):
    """Collect model outputs and targets for every batch of `data_loader`.

    Args:
        model: network to run; it is switched to eval mode.
        data_loader: yields dict batches with the keys 'image' and 'targets'.
        device: device the input images are moved to.

    Returns:
        Tuple ``(final_output, final_target)`` of Python lists with one entry per sample.
    """
    model.eval()
    tk0 = tqdm(data_loader, total=len(data_loader))
    final_output = []
    final_target = []
    with torch.no_grad():
        for data in tk0:
            inputs = data['image'].to(device)
            output = model(inputs)
            final_output.extend(output.detach().cpu().numpy().tolist())
            # Targets are only converted to lists, so they are not sent to `device` first.
            final_target.extend(data['targets'].detach().cpu().numpy().tolist())
    return final_output, final_target

In [None]:
# main loop
# Trains one model per fold listed in CFG.folds. After every epoch it writes a resumable
# checkpoint, the latest weights, and (when the validation score improves) the best weights
# together with the validation predictions.
load_model = False
DEBUG = False
run_train = True
save_dir = CFG.save_dir
os.makedirs(save_dir, exist_ok=True)
if run_train:
    for fold in range(CFG.N_FOLDS):
        # Skip unselected folds before opening a wandb run; otherwise every skipped fold
        # leaves an empty run behind.
        if fold not in CFG.folds:
            continue
        run = wandb.init(project='BirdCLEF2024',
                    # The API key is excluded so the secret is not uploaded as run config.
                    config={k:v for k, v in dict(vars(CFG)).items() if '__' not in k and k != 'wandb_key'},
                    anonymous=None,
                    group=f'fold{fold}',
                    name=CFG.EXP_ID
                )
        print("=" * 100)
        print(f"Fold {fold} Training")
        print("=" * 100)
        # Rows with cv == False are always used for training; validation uses only cv rows.
        trn_df = train.query("fold!=@fold | ~cv").reset_index(drop=True)
        val_df = train.query("fold==@fold & cv").reset_index(drop=True)
        val_df = val_df.drop_duplicates(subset=['primary_label', 'author']).reset_index(drop=True)
        print('num classes:', trn_df['primary_label'].nunique())
        trn_df = downsample_data(trn_df, thr=CFG.downsample_thr, seed=CFG.seed)
        trn_df = upsample_data(trn_df, thr=CFG.upsample_thr, seed=CFG.seed)
        if DEBUG:
            trn_df = trn_df[:100]
            val_df = val_df[:100]

        train_dataset = WaveformDataset(df=trn_df,
                                        df_sep=df_sep_audio,
                                        target_columns=CFG.target_columns,
                                        transforms=transforms,
                                        duration=CFG.period,
                                        secondary_ratio=CFG.secondary_ratio,
                                        use_first=CFG.use_first,
                                        downsample=CFG.downsample,
                                        mode='train'
                                        )
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=CFG.train_bs, num_workers=8, pin_memory=True, shuffle=True
        )

        valid_dataset = WaveformDataset(df=val_df,
                                        df_sep=df_sep_audio,
                                        target_columns=CFG.target_columns,
                                        transforms=transforms,
                                        duration=CFG.period,
                                        secondary_ratio=CFG.secondary_ratio,
                                        use_first=CFG.use_first,
                                        downsample=CFG.downsample,
                                        mode='valid'
                                        )
        valid_dataloader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=CFG.valid_bs, num_workers=8, pin_memory=True, shuffle=False
        )

        # NOTE(review): in_chans is 3 here while CFG.in_channels is 1 - confirm which is intended.
        model = CustomModel(model_name=CFG.base_model_name,
                            pretrained=True,
                            num_classes=CFG.num_classes,
                            in_chans=3,
                            reshape_factor=CFG.reshape_factor
                            )
        # Move the model before the optimizer is built or restored, so restored optimizer
        # state lands on the same device as the parameters.
        model = model.to(device)

        # NOTE(review): transformers.AdamW is deprecated upstream; torch.optim.AdamW has a
        # different default eps, so swapping it would change training - confirm before replacing.
        optimizer = transformers.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG.epochs,
                                                                            T_mult=1,eta_min=CFG.ETA_MIN)

        min_loss = 999
        best_score = -np.inf
        resume_epoch = 0
        if load_model:
            checkpoint = torch.load(os.path.join(save_dir, f"checkpoint_{fold:02d}.bin"))
            if hasattr(model, "module"):  # when DataParallel is used
                model.module.load_state_dict(checkpoint["model"])
            else:
                model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            random.setstate(checkpoint["random"])
            np.random.set_state(checkpoint["np_random"])
            torch.set_rng_state(checkpoint["torch"])
            torch.random.set_rng_state(checkpoint["torch_random"])
            # CUDA RNG state exists only when the checkpoint was written on a GPU runtime.
            if torch.cuda.is_available() and "cuda_random" in checkpoint:
                torch.cuda.set_rng_state(checkpoint["cuda_random"])
                torch.cuda.set_rng_state_all(checkpoint["cuda_random_all"])
            resume_epoch = checkpoint['epoch'] + 1
            best_score = checkpoint['best_score']

        for epoch in range(resume_epoch, CFG.epochs, 1):
            print("Starting {} epoch...".format(epoch+1))
            print('learning rate:', optimizer.param_groups[0]['lr'])
            start_time = time.time()

            # The first CFG.cutmix_and_mixup_epochs epochs use mixup/cutmix, the rest plain training.
            if epoch < CFG.cutmix_and_mixup_epochs:
                train_avg, train_loss = train_mixup_cutmix_fn(model, train_dataloader, device, optimizer)
            else:
                train_avg, train_loss = train_fn(model, train_dataloader, device, optimizer)

            valid_avg, valid_loss, y_true, y_pred, valid_path = valid_fn(model, valid_dataloader, device)
            scheduler.step()
            elapsed = time.time() - start_time

            print(f'Epoch {epoch+1} - avg_train_loss: {train_loss:.5f}  avg_val_loss: {valid_loss:.5f}  time: {elapsed:.0f}s')
            print(f"Epoch {epoch+1} - train_score:{train_avg['score']:0.5f}  valid_score:{valid_avg['score']:0.5f}")
            wandb.log({"Train Loss": train_loss,
                        "Train score": train_avg['score'],
                        "Valid Loss": valid_loss,
                        "Valid score": valid_avg['score'],
                        })

            # Update best_score before the checkpoint is written, so a resumed run does not
            # compare against a value that is one epoch stale.
            previous_best = best_score
            improved = valid_avg['score'] > best_score
            if improved:
                best_score = valid_avg['score']

            model_to_save = model.module if hasattr(model, "module") else model # unwrap model.module when DataParallel is used
            checkpoint = {
                "model": model_to_save.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "random": random.getstate(),
                "np_random": np.random.get_state(), # needed when numpy.random is used
                "torch": torch.get_rng_state(),
                "torch_random": torch.random.get_rng_state(),
                "epoch": epoch,
                "best_score": best_score
            }
            # The CUDA RNG getters fail without a GPU, so they are recorded only when one is present.
            if torch.cuda.is_available():
                checkpoint["cuda_random"] = torch.cuda.get_rng_state() # needed when a GPU is used
                checkpoint["cuda_random_all"] = torch.cuda.get_rng_state_all() # needed when several GPUs are used
            torch.save(checkpoint, os.path.join(save_dir, f"checkpoint_{fold:02d}.bin"))
            torch.save(model_to_save.state_dict(), os.path.join(save_dir, f'last_fold-{fold}.bin'))
            if improved:
                print(f">>>>>>>> Model Improved From {previous_best} ----> {valid_avg['score']}")
                torch.save(model_to_save.state_dict(), os.path.join(save_dir, f'best_fold-{fold}.bin'))
                # Kept in its own variable so the fold's validation frame (val_df) is not shadowed.
                best_pred_df = pd.DataFrame({'path': valid_path,
                                             'true': y_true,
                                             'pred': y_pred
                                             })
                # One file per fold; 'best.csv' alone is overwritten by every later fold.
                best_pred_df.to_csv(os.path.join(save_dir, f'best_fold-{fold}.csv'))
                # Legacy file name, still written for existing consumers.
                best_pred_df.to_csv(os.path.join(save_dir,'best.csv'))

        # Close this fold's wandb run so the next fold starts a separate one.
        run.finish()
