## About this Code

+ Multi-label classification of spectrogram images with a CNN model
+ Try using TorchLibrosa's SpecAugmentation function

There are some open issues:
+ image size
  + Is resizing needed?
+ Is mono_to_color needed?

### References

+ [Training a winning model](https://www.kaggle.com/theoviel/training-a-winning-model/notebook?scriptVersionId=42814701)
+ [[PyTorch, Training] BirdCLEF2021 Starter](https://www.kaggle.com/hidehisaarai1213/pytorch-training-birdclef2021-starter)
+ [Cassava / resnext50_32x4d starter [training]](https://www.kaggle.com/yasufuminakama/cassava-resnext50-32x4d-starter-training)

Thank you for publishing these great notebooks :)

## Libraries

In [None]:
!pip install ../input/torchlibrosa/torchlibrosa-0.0.5-py3-none-any.whl

In [None]:
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import os
import math
import time
import random

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
import soundfile as sf
from pathlib import Path
from IPython.display import Audio, IFrame, display 

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from torchlibrosa.stft import LogmelFilterBank, Spectrogram
from torchlibrosa.augmentation import SpecAugmentation

import timm
import warnings 

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Directory that receives the training artefacts written by this notebook.
OUTPUT_DIR = './'
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Competition metadata tables.
train = pd.read_csv('../input/birdclef-2021/train_metadata.csv')
test = pd.read_csv('../input/birdclef-2021/test.csv')
# Keep only the label and audio file name columns.
train = train[['primary_label', 'filename']]
train

## CFG

In [None]:
# Class labels (values of the `primary_label` column). The position of a label
# in this list is its class index, so the order must not change.
TARGETS = [
        'acafly', 'acowoo', 'aldfly', 'ameavo', 'amecro',
        'amegfi', 'amekes', 'amepip', 'amered', 'amerob',
        'amewig', 'amtspa', 'andsol1', 'annhum', 'astfly',
        'azaspi1', 'babwar', 'baleag', 'balori', 'banana',
        'banswa', 'banwre1', 'barant1', 'barswa', 'batpig1',
        'bawswa1', 'bawwar', 'baywre1', 'bbwduc', 'bcnher',
        'belkin1', 'belvir', 'bewwre', 'bkbmag1', 'bkbplo',
        'bkbwar', 'bkcchi', 'bkhgro', 'bkmtou1', 'bknsti', 'blbgra1',
        'blbthr1', 'blcjay1', 'blctan1', 'blhpar1', 'blkpho',
        'blsspa1', 'blugrb1', 'blujay', 'bncfly', 'bnhcow', 'bobfly1',
        'bongul', 'botgra', 'brbmot1', 'brbsol1', 'brcvir1', 'brebla',
        'brncre', 'brnjay', 'brnthr', 'brratt1', 'brwhaw', 'brwpar1',
        'btbwar', 'btnwar', 'btywar', 'bucmot2', 'buggna', 'bugtan',
        'buhvir', 'bulori', 'burwar1', 'bushti', 'butsal1', 'buwtea',
        'cacgoo1', 'cacwre', 'calqua', 'caltow', 'cangoo', 'canwar',
        'carchi', 'carwre', 'casfin', 'caskin', 'caster1', 'casvir',
        'categr', 'ccbfin', 'cedwax', 'chbant1', 'chbchi', 'chbwre1',
        'chcant2', 'chispa', 'chswar', 'cinfly2', 'clanut', 'clcrob',
        'cliswa', 'cobtan1', 'cocwoo1', 'cogdov', 'colcha1', 'coltro1',
        'comgol', 'comgra', 'comloo', 'commer', 'compau', 'compot1',
        'comrav', 'comyel', 'coohaw', 'cotfly1', 'cowscj1', 'cregua1',
        'creoro1', 'crfpar', 'cubthr', 'daejun', 'dowwoo', 'ducfly', 'dusfly',
        'easblu', 'easkin', 'easmea', 'easpho', 'eastow', 'eawpew', 'eletro',
        'eucdov', 'eursta', 'fepowl', 'fiespa', 'flrtan1', 'foxspa', 'gadwal',
        'gamqua', 'gartro1', 'gbbgul', 'gbwwre1', 'gcrwar', 'gilwoo',
        'gnttow', 'gnwtea', 'gocfly1', 'gockin', 'gocspa', 'goftyr1',
        'gohque1', 'goowoo1', 'grasal1', 'grbani', 'grbher3', 'grcfly',
        'greegr', 'grekis', 'grepew', 'grethr1', 'gretin1', 'greyel',
        'grhcha1', 'grhowl', 'grnher', 'grnjay', 'grtgra', 'grycat',
        'gryhaw2', 'gwfgoo', 'haiwoo', 'heptan', 'hergul', 'herthr',
        'herwar', 'higmot1', 'hofwoo1', 'houfin', 'houspa', 'houwre',
        'hutvir', 'incdov', 'indbun', 'kebtou1', 'killde', 'labwoo', 'larspa',
        'laufal1', 'laugul', 'lazbun', 'leafly', 'leasan', 'lesgol', 'lesgre1',
        'lesvio1', 'linspa', 'linwoo1', 'littin1', 'lobdow', 'lobgna5', 'logshr',
        'lotduc', 'lotman1', 'lucwar', 'macwar', 'magwar', 'mallar3', 'marwre',
        'mastro1', 'meapar', 'melbla1', 'monoro1', 'mouchi', 'moudov', 'mouela1',
        'mouqua', 'mouwar', 'mutswa', 'naswar', 'norcar', 'norfli', 'normoc', 'norpar',
        'norsho', 'norwat', 'nrwswa', 'nutwoo', 'oaktit', 'obnthr1', 'ocbfly1',
        'oliwoo1', 'olsfly', 'orbeup1', 'orbspa1', 'orcpar', 'orcwar', 'orfpar',
        'osprey', 'ovenbi1', 'pabspi1', 'paltan1', 'palwar', 'pasfly', 'pavpig2',
        'phivir', 'pibgre', 'pilwoo', 'pinsis', 'pirfly1', 'plawre1', 'plaxen1',
        'plsvir', 'plupig2', 'prowar', 'purfin', 'purgal2', 'putfru1', 'pygnut',
        'rawwre1', 'rcatan1', 'rebnut', 'rebsap', 'rebwoo', 'redcro', 'reevir1',
        'rehbar1', 'relpar', 'reshaw', 'rethaw', 'rewbla', 'ribgul', 'rinkin1',
        'roahaw', 'robgro', 'rocpig', 'rotbec', 'royter1', 'rthhum', 'rtlhum',
        'ruboro1', 'rubpep1', 'rubrob', 'rubwre1', 'ruckin', 'rucspa1', 'rucwar',
        'rucwar1', 'rudpig', 'rudtur', 'rufhum', 'rugdov', 'rumfly1', 'runwre1',
        'rutjac1', 'saffin', 'sancra', 'sander', 'savspa', 'saypho', 'scamac1',
        'scatan', 'scbwre1', 'scptyr1', 'scrtan1', 'semplo', 'shicow', 'sibtan2',
        'sinwre1', 'sltred', 'smbani', 'snogoo', 'sobtyr1', 'socfly1', 'solsan',
        'sonspa', 'soulap1', 'sposan', 'spotow', 'spvear1', 'squcuc1', 'stbori',
        'stejay', 'sthant1', 'sthwoo1', 'strcuc1', 'strfly1', 'strsal1', 'stvhum2',
        'subfly', 'sumtan', 'swaspa', 'swathr', 'tenwar', 'thbeup1', 'thbkin',
        'thswar1', 'towsol', 'treswa', 'trogna1', 'trokin', 'tromoc', 'tropar',
        'tropew1', 'tuftit', 'tunswa', 'veery', 'verdin', 'vigswa', 'warvir',
        'wbwwre1', 'webwoo1', 'wegspa1', 'wesant1', 'wesblu', 'weskin', 'wesmea',
        'westan', 'wewpew', 'whbman1', 'whbnut', 'whcpar', 'whcsee1', 'whcspa',
        'whevir', 'whfpar1', 'whimbr', 'whiwre1', 'whtdov', 'whtspa', 'whwbec1',
        'whwdov', 'wilfly', 'willet1', 'wilsni1', 'wiltur', 'wlswar', 'wooduc',
        'woothr', 'wrenti', 'y00475', 'yebcha', 'yebela1', 'yebfly', 'yebori1',
        'yebsap', 'yebsee1', 'yefgra1', 'yegvir', 'yehbla', 'yehcar1', 'yelgro',
        'yelwar', 'yeofly1', 'yerwar', 'yeteup1', 'yetvir']

In [None]:
class CFG:
    """Central, class-level configuration shared by every cell of the notebook.

    Attributes are read directly as ``CFG.<name>``; the class is never
    instantiated.
    """
    seed = 29
    n_fold = 5
    trn_fold = [0]  # folds that are actually trained by main()
    target_col = 'primary_label'
    train_datadir = Path("../input/birdclef-2021/train_short_audio")
    period = 5  # clip length handed to the dataset, multiplied by the sample rate
    img_size = 224
    criterion = 'BCEWithLogitsLoss'
    model_name = 'tf_efficientnet_b3'
    target_size = len(TARGETS)
    # Audio cfg
    n_mels = 128
    fmin = 20
    fmax = 16000
    n_fft = 2048
    hop_length = 512
    sample_rate = 32000
    epochs = 10
    # scheduler/optimizer
    scheduler = 'CosineAnnealingWarmRestarts'
    T_0 = 10
    # The next four values are read by get_scheduler() for the
    # 'ReduceLROnPlateau' / 'CosineAnnealingLR' options; without them, choosing
    # either scheduler fails with AttributeError.
    T_max = 10
    factor = 0.2
    patience = 4
    eps = 1e-6
    lr = 1e-4
    min_lr = 1e-6
    weight_decay = 1e-6
    # train
    gradient_accumulation_steps = 1
    apex = False
    max_grad_norm = 1000
    print_freq = 100
    # model
    pretrained = True
    in_channels = 1
    # Split -- derived from n_fold/seed above so the values cannot drift apart.
    split = "StratifiedKFold"
    split_params = {
        "n_splits": n_fold,
        "shuffle": True,
        "random_state": seed
    }
    # DataLoader
    loader = {
        "train": {
            "batch_size": 64,
            "num_workers": 4,
            "shuffle": True,
            "pin_memory": True,
            "drop_last": True
        },
        "valid": {
            "batch_size": 64,
            "num_workers": 4,
            "shuffle": False,
            "pin_memory": True,
            "drop_last": False
        }
    }
    debug = True

# Debug mode: fewer epochs and a fixed 1000-row sample of the training table.
if CFG.debug:
    CFG.epochs = 5
    train = train.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)

## Utils

In [None]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Every seeding function takes the same single integer argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def init_logger(log_file='train.log'):
    """Return this module's logger, writing bare messages to console and file.

    The logger is looked up by module name, so repeated calls (for example
    re-running the notebook cell) return the same object. Handlers left over
    from an earlier call are removed and closed first; otherwise each call
    would stack two more handlers and every message would be emitted multiple
    times.

    Args:
        log_file: path of the log file to write to.

    Returns:
        The configured ``logging.Logger`` with exactly two handlers.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Drop handlers installed by a previous call and release their file handles.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger

# Module-wide logger (default log file) and a fixed seed from the configuration.
LOGGER = init_logger()
set_seed(seed=CFG.seed)

## CV split

In [None]:
# Assign each row a validation fold number, stratified by the target label.
folds = train.copy()
Fold = StratifiedKFold(**CFG.split_params)
for n, (tr_idx, val_idx) in enumerate(Fold.split(folds, folds[CFG.target_col])):
    # NOTE(review): val_idx comes from the splitter while .loc is label-based,
    # so this presumably relies on `folds` having a default 0..N-1 index -- confirm.
    folds.loc[val_idx, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
# check the proportion: number of rows per label (rows) and fold (columns)
fold_proportion = pd.pivot_table(folds, index=CFG.target_col, columns="fold", aggfunc=len)
print(fold_proportion.shape)

In [None]:
# Show the label-by-fold pivot table as the cell output.
fold_proportion

## Dataset

In [None]:
class WaveformDataset(Dataset):
    """Dataset yielding fixed-length raw waveforms with one-hot targets.

    Each item reads one audio file from ``datadir / <label> / <filename>``,
    pads or crops it to ``period`` seconds (at the file's own sample rate) and
    returns it together with a one-hot vector over ``TARGETS``.

    Args:
        df: table with a ``filename`` column and the ``CFG.target_col`` label column.
        datadir: root directory of the audio files.
        img_size: stored on the instance; not used by this class.
        waveform_transforms: optional callable applied to the fixed-length waveform.
        period: clip length in seconds.
        validation: when True the clip always starts at offset 0; otherwise the
            crop / padding offset is drawn at random.
    """

    def __init__(self,
                 df: pd.DataFrame,
                 datadir: Path,
                 img_size=224,
                 waveform_transforms=None,
                 period=20,
                 validation=False):
        self.df = df
        self.datadir = datadir
        self.img_size = img_size
        self.waveform_transforms = waveform_transforms
        self.period = period
        self.validation = validation
        # Integer class index per row, in row order.
        self.y = np.array([TARGETS.index(c) for c in df[CFG.target_col]])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        """Return ``{'waveforms': float32 array, 'targets': one-hot array}`` for row ``idx``."""
        # Dataset indices are positions, so index by position: label-based .loc
        # only works when the frame happens to carry a 0..N-1 index.
        sample = self.df.iloc[idx]
        wav_name = sample['filename']
        ebird_code = sample[CFG.target_col]

        y, sr = sf.read(self.datadir / ebird_code / wav_name)
        if y.ndim > 1:
            # Multi-channel files come back as (frames, channels); average the
            # channels so the 1-D pad/crop logic below still applies.
            y = y.mean(axis=1)

        len_y = len(y)
        effective_length = sr * self.period
        if len_y < effective_length:
            # Too short: place the clip inside a zero buffer of the target length.
            new_y = np.zeros(effective_length, dtype=y.dtype)
            if not self.validation:
                start = np.random.randint(effective_length - len_y)
            else:
                start = 0
            new_y[start:start + len_y] = y
            y = new_y.astype(np.float32)
        elif len_y > effective_length:
            # Too long: cut out a window of the target length.
            if not self.validation:
                start = np.random.randint(len_y - effective_length)
            else:
                start = 0
            y = y[start:start + effective_length].astype(np.float32)
        else:
            y = y.astype(np.float32)

        # Replace NaN/inf samples before and after the optional transform.
        y = np.nan_to_num(y)

        if self.waveform_transforms:
            y = self.waveform_transforms(y)

        y = np.nan_to_num(y)

        labels = np.zeros(len(TARGETS), dtype=float)
        labels[TARGETS.index(ebird_code)] = 1.0

        return {
            'waveforms': y,
            'targets': labels
        }
        

In [None]:
# Preview: load the first training clip, plot it and make it playable.
train_dataset = WaveformDataset(train,
                                CFG.train_datadir,
                                img_size=CFG.img_size,
                                waveform_transforms=None,
                                period=CFG.period,
                                validation=True)

data = train_dataset[0]
print(data['waveforms'].shape, data['targets'].shape)
plt.plot(data['waveforms'])
plt.show()
# Use the configured sample rate instead of a second hard-coded 32000.
Audio(data=data['waveforms'], rate=CFG.sample_rate)

## WaveformTransforms

In [None]:
# future work...

## Criterion

In [None]:
def get_criterion():
    """Build the loss module selected by ``CFG.criterion`` on the active device.

    Raises:
        NotImplementedError: for any value other than 'BCEWithLogitsLoss'.
    """
    if CFG.criterion != 'BCEWithLogitsLoss':
        raise NotImplementedError
    loss_module = nn.BCEWithLogitsLoss(reduction="mean")
    return loss_module.to(device)

## Scheduler

In [None]:
def get_scheduler(optimizer):
    """Create the learning-rate scheduler named by ``CFG.scheduler``.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        A ``ReduceLROnPlateau``, ``CosineAnnealingLR`` or
        ``CosineAnnealingWarmRestarts`` instance configured from ``CFG``.

    Raises:
        NotImplementedError: if ``CFG.scheduler`` names none of the above.
            Previously this surfaced as an UnboundLocalError at the return
            statement; it now matches get_criterion().
    """
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    else:
        raise NotImplementedError(f"Unsupported scheduler: {CFG.scheduler!r}")
    return scheduler

In [None]:
# check scheduler
# Throwaway model and optimizer, used only to trace the learning-rate schedule.
model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
scheduler = get_scheduler(optimizer)

from pylab import rcParams
lrs = []
for epoch in range(1, CFG.epochs+1):
    # Step the scheduler to an explicit epoch index, then record the learning rate.
    scheduler.step(epoch-1)
    lrs.append(optimizer.param_groups[0]["lr"])
# Wide, short figure for the learning-rate curve.
rcParams['figure.figsize'] = 20,3
plt.plot(lrs)

## Model

In [None]:
class CustomEfficientNet(nn.Module):
    """Raw waveform -> log-mel spectrogram -> timm CNN classifier.

    The spectrogram front end runs inside the model, so the network consumes
    waveforms directly. SpecAugmentation is applied in training mode only.

    Args:
        model_name: timm architecture name passed to ``timm.create_model``.
        pretrained: whether timm should load pretrained weights.
        in_channels: number of input channels of the CNN.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=False, in_channels=1):
        super().__init__()
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=CFG.n_fft, hop_length=CFG.hop_length,
                                                 win_length=CFG.n_fft, window="hann", center=True, pad_mode="reflect",
                                                 freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=CFG.sample_rate, n_fft=CFG.n_fft,
                                                 n_mels=CFG.n_mels, fmin=CFG.fmin, fmax=CFG.fmax, ref=1.0, amin=1e-10, top_db=None,
                                                 freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
                                               freq_drop_width=8, freq_stripes_num=2)
        # One batch-norm feature per mel bin (applied with mel bins moved to dim 1).
        self.bn0 = nn.BatchNorm2d(CFG.n_mels)

        # Honour the model_name argument; it was previously ignored in favour
        # of CFG.model_name.
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=in_channels)
        # Swap the classification head for one sized to the number of targets.
        # NOTE(review): assumes the timm model exposes its head as `.classifier`.
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, CFG.target_size)

    def forward(self, input):
        """
        Input: (batch_size, data_length)

        Returns the raw (pre-sigmoid) outputs of the classifier head.
        """
        x = self.spectrogram_extractor(input)# (batch_size, 1(channel), time_steps, freq_bins)
        x = self.logmel_extractor(x)# (batch_size, 1(channel), time_steps, mel_bins)
        # Move mel bins to dim 1 for bn0, then restore the layout.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        if self.training:
            x = self.spec_augmenter(x)

        x = x.transpose(2, 3)
        # (batch_size, channels, freq, frames)
        x = self.model(x)
        return x

In [None]:
# Smoke test: push one batch of four clips through a model built without
# pretrained weights and compute the loss once.
model = CustomEfficientNet(model_name=CFG.model_name, pretrained=False, in_channels=1)
train_dataset = WaveformDataset(train,
                                CFG.train_datadir,
                                img_size=CFG.img_size,
                                waveform_transforms=None,
                                period=CFG.period,
                                validation=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)
# Only the first batch is needed, hence the break.
for data in train_loader:
    print(data['waveforms'].shape)
    output = model(data['waveforms'])
    target = data['targets']
    break
    
criterion = get_criterion()
loss = criterion(output, target).item()
loss

## Helper functions

In [None]:
class AverageMeter(object):
    """Keep the most recent value of a metric and its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Report elapsed time since ``since`` and the projected remaining time.

    Args:
        since: start timestamp as returned by ``time.time()``.
        percent: fraction of the work completed so far (must be non-zero).

    Returns:
        A string of the form ``'<elapsed> (remain <remaining>)'``.
    """
    elapsed = time.time() - since
    # Linear extrapolation: total = elapsed / fraction done.
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

# https://www.kaggle.com/theoviel/training-a-winning-model/notebook?scriptVersionId=42814701
# Identity matrix: row i is the one-hot encoding of class index i.
ONE_HOT = np.eye(CFG.target_size)
def f1(truth, pred, threshold=0.5, avg="samples"):
    """Compute the F1 score of thresholded predictions.

    Args:
        truth: 1-D array of class indices, or an already one-hot encoded array.
        pred: array of scores; entries above ``threshold`` count as positive.
        threshold: decision threshold applied to ``pred``.
        avg: averaging mode forwarded to ``sklearn.metrics.f1_score``.
    """
    truth_matrix = ONE_HOT[truth] if truth.ndim == 1 else truth
    binary_pred = (pred > threshold).astype(int)
    return f1_score(truth_matrix, binary_pred, average=avg)

In [None]:
def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable of batches shaped as dicts with 'waveforms' and 'targets'.
        model: network to optimise; switched to train mode here.
        criterion: loss callable taking ``(predictions, labels)``.
        optimizer: optimizer stepped every ``CFG.gradient_accumulation_steps`` batches.
        epoch: zero-based epoch index, used only in the progress output.
        scheduler: accepted for interface compatibility; not used in this function.
        device: device the batch tensors are moved to.

    Returns:
        The mean training loss over the epoch, weighted by batch size.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    # Start from clean gradients: an unfinished accumulation window from the
    # previous epoch must not leak into this epoch's first update.
    optimizer.zero_grad()
    num_batches = len(train_loader)
    start = end = time.time()

    for step, data in enumerate(train_loader):
        waveforms = data['waveforms']
        labels = data['targets']
        # measure data loading time
        data_time.update(time.time() - end)

        waveforms = waveforms.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        y_preds = model(waveforms)
        loss = criterion(y_preds, labels)
        # record the unscaled loss
        losses.update(loss.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        if CFG.apex:
            # NOTE(review): `amp` is not imported anywhere in the visible file;
            # enabling CFG.apex requires providing it.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (num_batches-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  .format(
                   epoch+1, step, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/num_batches),
                   grad_norm=grad_norm,
                   ))
    return losses.avg

def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` without updating it.

    Args:
        valid_loader: iterable of batches shaped as dicts with 'waveforms' and 'targets'.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable taking ``(predictions, labels)``.
        device: device the batch tensors are moved to.

    Returns:
        ``(mean_loss, preds)`` where ``mean_loss`` is weighted by batch size and
        ``preds`` is an array of sigmoid outputs with one row per sample and
        ``CFG.target_size`` columns (empty when the loader yields nothing).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to evaluation mode
    model.eval()
    # Collect per-batch outputs and concatenate once at the end; re-allocating
    # the full array on every batch copies all earlier predictions each time.
    # The empty leading chunk fixes the result's dtype and column count.
    pred_chunks = [np.empty((0, CFG.target_size))]
    num_batches = len(valid_loader)
    start = end = time.time()
    for step, data in enumerate(valid_loader):
        waveforms = data['waveforms']
        labels = data['targets']
        # measure data loading time
        data_time.update(time.time() - end)
        waveforms = waveforms.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute predictions and loss without building a graph
        with torch.no_grad():
            y_preds = model(waveforms)
            loss = criterion(y_preds, labels)
        pred_chunks.append(torch.sigmoid(y_preds).cpu().numpy())
        losses.update(loss.item(), batch_size)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (num_batches-1):
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(
                   step, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/num_batches),
                   ))
    preds = np.concatenate(pred_chunks)
    return losses.avg, preds

## Train loop

In [None]:
def train_loop(folds: pd.DataFrame, fold_num: int = 0):
    """Train and validate one cross-validation fold.

    Rows whose ``fold`` column equals ``fold_num`` form the validation set and
    all other rows the training set. After every epoch the model is validated,
    F1 scores are logged, the scheduler is stepped, and the weights together
    with the validation predictions are saved whenever the validation loss
    improves.

    Args:
        folds: table with ``filename``, the target column and a ``fold`` column.
        fold_num: index of the fold held out for validation.

    Returns:
        The validation rows of this fold, re-indexed from zero.
    """
    LOGGER.info(f"========== fold: {fold_num} training ==========")
    ### dataset
    is_valid_row = folds["fold"] == fold_num
    train_folds = folds[~is_valid_row].reset_index(drop=True)
    valid_folds = folds[is_valid_row].reset_index(drop=True)

    train_dataset = WaveformDataset(train_folds,
                                    CFG.train_datadir,
                                    img_size=CFG.img_size,
                                    waveform_transforms=None,
                                    period=CFG.period,
                                    validation=False)
    valid_dataset = WaveformDataset(valid_folds,
                                    CFG.train_datadir,
                                    img_size=CFG.img_size,
                                    waveform_transforms=None,
                                    period=CFG.period,
                                    validation=True)
    ### dataloader
    train_loader = DataLoader(train_dataset, **CFG.loader['train'])
    valid_loader = DataLoader(valid_dataset, **CFG.loader['valid'])

    ### model
    model = CustomEfficientNet(model_name=CFG.model_name, pretrained=CFG.pretrained, in_channels=CFG.in_channels)
    model.to(device)
    ### optimizer
    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    ### get scheduler
    scheduler = get_scheduler(optimizer)
    if CFG.apex:
        # NOTE(review): `amp` is not imported anywhere in the visible file;
        # enabling CFG.apex requires providing it.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
    ### criterion
    criterion = get_criterion()

    # ====================================================
    # loop
    # ====================================================
    best_loss = np.inf
    # Build the checkpoint path once; os.path.join also copes with an
    # OUTPUT_DIR that lacks a trailing separator.
    best_model_path = os.path.join(OUTPUT_DIR, f'{CFG.model_name}_fold{fold_num}_best.pth')

    for epoch in range(CFG.epochs):
        start_time = time.time()
        # train
        avg_loss = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        # eval
        avg_val_loss, preds = valid_fn(valid_loader, model, criterion, device)
        # scoring
        print(f'pred max{np.amax(preds)}')
        micro_f1 = f1(valid_dataset.y, preds, avg="micro")
        samples_f1 = f1(valid_dataset.y, preds)
        LOGGER.info(f'micro_f1{micro_f1},samples_f1{samples_f1}')

        # ReduceLROnPlateau needs the monitored metric; the cosine schedulers do not.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_loss:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'preds': preds},
                       best_model_path)

    # The checkpoint used to be re-loaded here into a variable that was never
    # read; the dead load (and its failure when nothing had been saved) is gone.
    return valid_folds

In [None]:
def main():
    """Run train_loop for every fold index listed in ``CFG.trn_fold``."""
    for fold in range(CFG.n_fold):
        if fold not in CFG.trn_fold:
            continue
        train_loop(folds, fold)


if __name__ == '__main__':
    main()