In [None]:
# Show the GPU(s) visible to this session.
!nvidia-smi

In [None]:
#!pip install perceiver-pytorch

# Notes

A few days ago, I shared this [inference kernel](https://www.kaggle.com/kneroma/clean-fast-simple-bird-identifier-inference). Its weights are static, though, because you cannot retrain the model from it. This notebook releases the training side: it is essentially my internal training pipeline, with some experimental ideas removed to keep things clear and straightforward. Feel free to add your own ideas on top of it.

To make training faster, we cached the training set into RAM. All training records have already been [converted into handy melspec images](https://www.kaggle.com/kneroma/kkiller-birdclef-2021). These images come from 7-second extracts (training on 7 seconds seems to be more effective than training on 5 seconds). Longer records are truncated into at most 10 random 7-second extracts.

**If you are interested in the melspecs of the whole records** (no truncation):
* https://www.kaggle.com/kneroma/kkiller-birdclef-mels-computer-d7-part1
* https://www.kaggle.com/kneroma/kkiller-birdclef-mels-computer-d7-part2
* https://www.kaggle.com/kneroma/kkiller-birdclef-mels-computer-d7-part3
* https://www.kaggle.com/kneroma/kkiller-birdclef-mels-computer-d7-part4

### Tips & suggestions
* You can choose from a wide set of models through the **get_model** interface: ["resnest*", "resnet*", "resnext*", "efficientnet*" ...]
* You can change the learning rate scheduler: OneCycle ? ReduceOnPlateau ?
* Add secondary labels
* Use train & test metadata (dates, positions (longitude, latitude), ...)
* Add melspecs augmentation

**For Colab training, you just have to uncomment the first cells**

# Versions

* **v1** : initial version
* **v3** : enable training on whole (no truncation) record melspecs

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# ! pip install --upgrade --force-reinstall --no-deps  kaggle > /dev/null
# ! mkdir ~/.kaggle
# ! cp "/content/drive/My Drive/Kaggle/kaggle.json" ~/.kaggle/
# ! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# %%time

# import os
# if not os.path.exists("/content/datasets/audio_images"):
#   !mkdir datasets
#   !kaggle datasets download -d kneroma/kkiller-birdclef-2021
#   !unzip /content//kkiller-birdclef-2021.zip -d datasets
import sys
# Put the bundled "timm-latest" copy FIRST on sys.path so it shadows any preinstalled
# timm; sys.path.append would give it the lowest priority and the old version would win.
if '../input/timm-latest' not in sys.path:
    sys.path.insert(0, '../input/timm-latest')

In [None]:
# Install the audio / model libraries used below (versions are not pinned).
!pip install -q pysndfx SoundFile audiomentations pretrainedmodels efficientnet_pytorch resnest

In [None]:
import numpy as np
import librosa as lb
import librosa.display as lbd
import soundfile as sf
from  soundfile import SoundFile
import pandas as pd
from  IPython.display import Audio
from pathlib import Path

import torch
from torch import nn, optim
from  torch.utils.data import Dataset, DataLoader
import timm

from resnest.torch import resnest50

from matplotlib import pyplot as plt

import os, random, gc
import re, time, json
from  ast import literal_eval


from IPython.display import Audio
from sklearn.metrics import label_ranking_average_precision_score

from tqdm.notebook import tqdm
import joblib

In [None]:
from efficientnet_pytorch import EfficientNet
#from perceiver_pytorch import Perceiver
import pretrainedmodels
import resnest.torch as resnest_torch

In [None]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and ask cuDNN for deterministic kernels.

    Args:
        seed: integer seed shared by every generator (also exported as PYTHONHASHSEED).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

In [None]:
# --- Audio / label configuration ---
NUM_CLASSES = 397
SR = 32_000  # sample rate (Hz) used when computing spectrograms from audio
DURATION = 5  # window length in seconds; NOTE(review): the notes above mention 7-second extracts — confirm which the cached melspecs use
N_MELS = 256
EPSILON_FP16 = 0.005  # NOTE(review): not referenced anywhere in this notebook
USE_FOCAL_LOSS = False  # True -> BCEFocalLoss, False -> nn.BCEWithLogitsLoss (chosen in one_fold)
SECONDARY_FACTOR = 1  # target value given to secondary labels; rebound per model from SECONDARY_FACTORS in the training loop
SECONDARY_FACTORS = [1, 1]  # one entry per MODEL_NAMES entry
MODEL_NAMES = [
      "rexnet_200",
      "rexnet_150",
] 
MAX_READ_SAMPLES = 9999999999 # Upper bound on melspec windows read per record; this huge value effectively means "no cap"

# # For colab
# DATA_ROOT = Path("/content/datasets/")
# TRAIN_IMAGES_ROOT = Path("/content/datasets/audio_images")
# TRAIN_LABELS_FILE = Path("/content/datasets/rich_train_metadata.csv")
# MODEL_ROOT = Path("/content/drive/My Drive/Kaggle/BirdClef2021/models")
# Augmented-copy swap used by BirdClefDataset: the first string is replaced in a
# spectrogram path by one randomly chosen alternative from the list.
AUG_REPLACEMENT = ['birdclef-mels-computer-', 
                   ['birdclef-mels-audiomentation-', 'birdclef-mels-audiomentation2-']]
AUG_RATE = 0.25  # compared against random.random() in BirdClefDataset to decide whether to swap

DATA_ROOT = Path("../input/birdclef-2021")
# TRAIN_IMAGES_ROOT = Path("../input/kkiller-birdclef-2021/audio_images")
# TRAIN_LABELS_FILE = Path("../input/kkiller-birdclef-2021/rich_train_metadata.csv")

# One metadata CSV and one label-id map per attached "birdclef-mels-computer-*" dataset.
MEL_PATHS = sorted(Path("../input").glob("birdclef-mels-computer-*/rich_train_metadata.csv"))
TRAIN_LABEL_PATHS = sorted(Path("../input").glob("birdclef-mels-computer-*/LABEL_IDS.json"))

MODEL_ROOT = Path(".")  # default checkpoint directory

In [None]:
# DataLoader settings for training.
TRAIN_BATCH_SIZE = 32
TRAIN_NUM_WORKERS = 4

# NOTE(review): the VAL_* settings are not used — validation runs on the train
# soundscapes through evaluate().
VAL_BATCH_SIZE = 32
VAL_NUM_WORKERS = 4

EPOCHS = 20
FOLDS = [1]  # NOTE(review): passed to train(), which does not use it

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device:", DEVICE)

In [None]:
def get_df(mel_paths=MEL_PATHS, train_label_paths=TRAIN_LABEL_PATHS):
  """Load and concatenate the melspec metadata CSVs and merge the label-id maps.

  Args:
    mel_paths: paths of ``rich_train_metadata.csv`` files. The spectrogram of each row
      is expected at ``<csv dir>/audio_images/<primary_label>/<filename>.npy``.
    train_label_paths: paths of ``LABEL_IDS.json`` files, merged in order
      (a key in a later file overrides the same key from an earlier one).

  Returns:
    (LABEL_IDS, df): the merged label -> id dict, and the concatenated metadata with an
    extra ``impath`` column (``Path``) and ``secondary_labels`` parsed into lists.
    The CSV index (``index_col=0``) is kept as-is, so it need not be 0..len-1.

  Raises:
    ValueError: if ``mel_paths`` is empty.
  """
  frames = []
  for file_path in mel_paths:
    file_path = Path(file_path)
    temp = pd.read_csv(str(file_path), index_col=0)
    temp["impath"] = [
      file_path.parent/"audio_images/{}/{}.npy".format(primary_label, filename)
      for primary_label, filename in zip(temp["primary_label"], temp["filename"])
    ]
    frames.append(temp)

  if not frames:
    raise ValueError("get_df: no metadata CSV found (mel_paths is empty)")

  # DataFrame.append was removed in pandas 2.0; a single concat is also linear, not quadratic.
  df = pd.concat(frames)
  df["secondary_labels"] = df["secondary_labels"].apply(literal_eval)

  LABEL_IDS = {}
  for file_path in train_label_paths:
    with open(str(file_path)) as f:
      LABEL_IDS.update(json.load(f))

  return LABEL_IDS, df

In [None]:
# df = pd.read_csv(TRAIN_LABELS_FILE, nrows=None)
# df["secondary_labels"] = df["secondary_labels"].apply(literal_eval)
# LABEL_IDS = {label: label_id for label_id,label in enumerate(sorted(df["primary_label"].unique()))}

# print(df.shape)
# df.head()

In [None]:
# Build the training table from all attached melspec datasets, plus the label-id map.
LABEL_IDS, df = get_df()

print(df.shape)
df.head()

In [None]:
# Number of training records per primary label.
df["primary_label"].value_counts()

In [None]:
# Range of the integer class ids (expected to span 0 .. NUM_CLASSES - 1).
df["label_id"].min(), df["label_id"].max()

In [None]:
class BCEFocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over every element.

    Positive targets are weighted by ``alpha * (1 - p) ** gamma`` and negative targets
    by ``p ** gamma`` (no ``1 - alpha`` factor), where ``p = sigmoid(preds)``.
    Soft targets in [0, 1] mix the two terms linearly.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        elementwise_bce = nn.functional.binary_cross_entropy_with_logits(preds, targets, reduction='none')
        p = preds.sigmoid()
        positive_weight = targets * self.alpha * (1. - p) ** self.gamma
        negative_weight = (1. - targets) * p ** self.gamma
        focal = positive_weight * elementwise_bce + negative_weight * elementwise_bce
        return focal.mean()
    
class BCEFocal2WayLoss(nn.Module):
    """Weighted sum of a clip-level focal loss and a frame-max auxiliary focal loss.

    ``forward`` expects ``input`` to be a dict with:
      * ``"logit"`` — clip-level logits, compared directly with ``target``;
      * ``"framewise_logit"`` — reduced with ``max`` over dim 1 before the auxiliary loss
        (presumably ``(batch, frames, classes)`` — TODO confirm against the model).

    Args:
        weights: ``(main_weight, aux_weight)`` multipliers for the two losses.
        class_weights: accepted for interface compatibility; NOTE(review): unused.
    """

    def __init__(self, weights=(1, 1), class_weights=None):
        super().__init__()

        self.focal = BCEFocalLoss()

        # Copy into a fresh list: the default is never shared between instances.
        self.weights = list(weights)

    def forward(self, input, target):
        input_ = input["logit"]
        target = target.float()

        framewise_output = input["framewise_logit"]
        clipwise_output_with_max, _ = framewise_output.max(dim=1)

        loss = self.focal(input_, target)
        aux_loss = self.focal(clipwise_output_with_max, target)

        return self.weights[0] * loss + self.weights[1] * aux_loss

In [None]:
def get_model(name, num_classes=NUM_CLASSES):
    """
    Loads a pretrained model and gives it a `num_classes`-way classification head.
    Supports (timm) ResNeSt, ResNet/ResNeXt, DenseNet, ReXNet, NFNet, VoVNet, DLA,
    Res2Net, RegNet, CoaT and EfficientNet/MixNet; ResNeXt-WSL via torch.hub;
    "efficientnet-b*" via efficientnet_pytorch; anything else via pretrainedmodels.

    Arguments:
        name {str} -- Name of the model to load

    Keyword Arguments:
        num_classes {int} -- Number of output classes (default: {NUM_CLASSES})

    Returns:
        torch model -- Pretrained model
    """
    if "resnest" in name:
        model = getattr(timm.models.resnest, name)(pretrained=True)
    elif "resnet" in name or "resnext" in name:
        # Matches every "resnet*"/"resnext*" name, so these always come from timm.
        model = getattr(timm.models.resnet, name)(pretrained=True)
    elif "densenet" in name:
        model = getattr(timm.models.densenet, name)(pretrained=True)
    elif "rexnet" in name:
        model = getattr(timm.models.rexnet, name)(pretrained=True, num_classes=num_classes)
    elif "nf_" in name:
        model = getattr(timm.models.nfnet, name)(pretrained=True, num_classes=num_classes)
    elif "vovnet" in name:
        model = getattr(timm.models.vovnet, name)(pretrained=True, num_classes=num_classes)
    elif name.startswith("dla"):
        model = getattr(timm.models.dla, name)(pretrained=True)
    elif "res2next" in name:
        model = getattr(timm.models.res2net, name)(pretrained=True)
    elif "regnet" in name:
        model = getattr(timm.models.regnet, name)(pretrained=True)
    elif "coat" in name:
        model = getattr(timm.models.coat, name)(pretrained=True)
    elif "wsl-image" in name:
        model = torch.hub.load("facebookresearch/WSL-Images", name)
    elif "efficientnet_" in name or "mixnet" in name:
        model = getattr(timm.models.efficientnet, name)(pretrained=True)
    elif "efficientnet-b" in name:
        model = EfficientNet.from_pretrained(name)
    else:
        model = pretrainedmodels.__dict__[name](pretrained='imagenet')

    # Swap the ImageNet head for a num_classes-way linear layer, whatever the attribute
    # name used by the model family. Models built with num_classes above (rexnet, nfnet,
    # vovnet) are left as they are when none of these attributes exists.
    if hasattr(model, "fc"):
        nb_ft = model.fc.in_features
        model.fc = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "_fc"):
        nb_ft = model._fc.in_features
        model._fc = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "classifier"):
        nb_ft = model.classifier.in_features
        model.classifier = nn.Linear(nb_ft, num_classes)
    elif hasattr(model, "last_linear"):
        nb_ft = model.last_linear.in_features
        model.last_linear = nn.Linear(nb_ft, num_classes)

    return model

In [None]:
def load_data(df):
    """Map each record's ``filename`` to its spectrogram path, as a ``str``.

    Only paths are stored; the arrays are read lazily by the dataset. If a filename
    appears more than once, the last row's path wins.

    Args:
        df: frame with ``filename`` and ``impath`` columns.

    Returns:
        dict filename -> str(impath).
    """
    # A plain mapping: a process pool only added pickling/spawn overhead for this.
    return dict(zip(df["filename"], df["impath"].astype(str)))

In [None]:
# Map every record to its spectrogram path. Only paths are kept (see load_data);
# the arrays are read from disk by BirdClefDataset on each access.

audio_image_store = load_data(df)
len(audio_image_store)

In [None]:
#print("shape:", next(iter(audio_image_store.values())).shape)
#lbd.specshow(next(iter(audio_image_store.values()))[0])

In [None]:
# Distribution of the number of mel-spectrogram windows stored per record.
# The store holds paths, so read each .npy as a memory map: only its header is parsed.
pd.Series([len(np.load(impath, mmap_mode="r")) for impath in tqdm(audio_image_store.values())]).value_counts()

In [None]:
class BirdClefDataset(Dataset):
    """Training dataset that samples one cached mel-spectrogram window per record.

    ``audio_image_store`` maps ``meta.filename`` to the path of a ``.npy`` stack of
    spectrogram windows; each ``__getitem__`` draws one window at random.

    Returns ``(image, target)``: ``image`` is the window scaled by 1/255 and repeated
    over 3 channels; ``target`` is a float32 vector of length ``num_classes`` holding
    0.0025 everywhere (label smoothing), ``SECONDARY_FACTOR`` on secondary labels
    (when > 0) and 0.995 on the primary label.
    """

    def __init__(self, audio_image_store, meta, sr=SR, is_train=True, num_classes=NUM_CLASSES, duration=DURATION, 
                 aug_replacement=AUG_REPLACEMENT, aug_rate=AUG_RATE):
        
        self.audio_image_store = audio_image_store
        self.meta = meta.copy().reset_index(drop=True)
        self.sr = sr
        self.is_train = is_train
        self.num_classes = num_classes
        self.duration = duration
        self.audio_length = self.duration*self.sr
        # [substring, [alternatives...]]: replacing the substring in a path by one of the
        # alternatives points to a pre-computed augmented copy of the same record.
        self.aug_replacement = aug_replacement
        # Probability that a training item is read from an augmented copy.
        self.aug_rate = aug_rate
    
    @staticmethod
    def normalize(image):
        """Scale a spectrogram by 1/255 to float32 and stack it into 3 identical channels."""
        image = image.astype("float32", copy=False) / 255.0
        image = np.stack([image, image, image])
        return image

    def __len__(self):
        return len(self.meta)
    
    def __getitem__(self, idx):
        row = self.meta.iloc[idx]
        impath = self.audio_image_store[row.filename]
        # Augment with probability aug_rate (aug_rate=0 disables it), training only.
        if self.is_train and random.random() < self.aug_rate:
            impath = impath.replace(self.aug_replacement[0], random.choice(self.aug_replacement[1]))
        # Memory-map the stack so only the sampled window is actually read from disk.
        windows = np.load(impath, mmap_mode="r")[:MAX_READ_SAMPLES]

        image = windows[np.random.choice(len(windows))]
        image = self.normalize(image)
        
        t = np.zeros(self.num_classes, dtype=np.float32) + 0.0025 # Label smoothing
        if SECONDARY_FACTOR > 0:
            for label in row.secondary_labels:
                t[LABEL_IDS[label]] = SECONDARY_FACTOR
        # Set last so the primary label wins if it is also listed as secondary.
        t[row.label_id] = 0.995
        
        return image, t
    
def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """Standardize ``X`` and min-max scale it to a uint8 image in [0, 255].

    Args:
        X: array to convert (any shape).
        eps: added to ``std`` before dividing, and minimum value range below which the
            input is treated as constant.
        mean, std: statistics to standardize with; computed from ``X`` when ``None``.
            An explicit ``0`` is honoured (it used to be silently replaced).

    Returns:
        uint8 array with the shape of ``X``; all zeros for (near-)constant input.
    """
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    _min, _max = X.min(), X.max()

    if (_max - _min) > eps:
        V = 255 * (X - _min) / (_max - _min)
        V = V.astype(np.uint8)
    else:
        V = np.zeros_like(X, dtype=np.uint8)

    return V

def crop_or_pad(y, length):
    """Return ``y`` truncated, or zero-padded at the end, to exactly ``length`` samples.

    Padding keeps the input dtype. An input that already has the right length is
    returned unchanged (same object).
    """
    if len(y) < length:
        # Previously appended `length - zeros(len(y))`: wrong values and wrong length.
        y = np.pad(y, (0, length - len(y)))
    elif len(y) > length:
        y = y[:length]
    return y
    
class MelSpecComputer:
    """Callable turning a waveform into a dB-scaled float32 mel spectrogram.

    Extra keyword arguments are forwarded to ``librosa.feature.melspectrogram``;
    ``n_fft`` defaults to ``sr // 10`` and ``hop_length`` to ``sr // 40``.
    """

    def __init__(self, sr, n_mels, fmin, fmax, **kwargs):
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax
        kwargs.setdefault("n_fft", self.sr//10)
        kwargs.setdefault("hop_length", self.sr//(10*4))
        self.kwargs = kwargs

    def __call__(self, y):
        # `y` must be passed by keyword: librosa >= 0.10 rejects it positionally.
        melspec = lb.feature.melspectrogram(
            y=y, sr=self.sr, n_mels=self.n_mels, fmin=self.fmin, fmax=self.fmax, **self.kwargs,
        )

        melspec = lb.power_to_db(melspec).astype(np.float32)
        return melspec
    
class TestBirdCLEFDataset(Dataset):
    """Dataset over whole audio files, cut into fixed windows and turned into images.

    Item ``idx`` reads ``data.loc[idx, "filepath"]``, optionally resamples it to ``sr``,
    splits it into ``duration``-second windows every ``step`` samples (default: no
    overlap), drops a trailing partial window, and returns a stacked array of
    3-channel mel-spectrogram images, one per window.
    """

    def __init__(self, data, sr=SR, n_mels=N_MELS, fmin=0, fmax=None, duration=DURATION, step=None, res_type="kaiser_fast", resample=True):
        
        self.data = data
        
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax or self.sr//2

        self.duration = duration
        self.audio_length = self.duration*self.sr
        self.step = step or self.audio_length
        
        self.res_type = res_type
        self.resample = resample

        self.mel_spec_computer = MelSpecComputer(sr=self.sr, n_mels=self.n_mels, fmin=self.fmin,
                                                 fmax=self.fmax)

    def __len__(self):
        return len(self.data)
    
    @staticmethod
    def normalize(image):
        """Scale a uint8 image by 1/255 to float32 and stack it into 3 identical channels."""
        image = image.astype("float32", copy=False) / 255.0
        image = np.stack([image, image, image])
        return image
    
    def audio_to_image(self, audio):
        """Waveform window -> mel spectrogram -> uint8 image -> normalized 3-channel image."""
        melspec = self.mel_spec_computer(audio) 
        image = mono_to_color(melspec)
        image = self.normalize(image)
        return image

    def read_file(self, filepath):
        """Return the stacked window images of one audio file.

        Raises:
            ValueError: if the file holds less than one full window of audio.
        """
        audio, orig_sr = sf.read(filepath, dtype="float32")

        if self.resample and orig_sr != self.sr:
            # Keyword arguments: librosa >= 0.10 rejects positional orig_sr/target_sr.
            audio = lb.resample(audio, orig_sr=orig_sr, target_sr=self.sr, res_type=self.res_type)

        audios = []
        for i in range(self.audio_length, len(audio) + self.step, self.step):
            start = max(0, i - self.audio_length)
            end = start + self.audio_length
            audios.append(audio[start:end])

        # Drop the trailing window if it is shorter than a full one.
        if audios and len(audios[-1]) < self.audio_length:
            audios = audios[:-1]

        if not audios:
            raise ValueError(f"{filepath}: shorter than one {self.duration}s window")

        images = [self.audio_to_image(audio) for audio in audios]
        images = np.stack(images)
        
        return images

    def __getitem__(self, idx):
        return self.read_file(self.data.loc[idx, "filepath"])

In [None]:
# Whole training table as a dataset (used below to inspect one sample).
ds = BirdClefDataset(audio_image_store, meta=df, sr=SR, duration=DURATION, is_train=True)
len(ds)

In [None]:
# Draw one random training sample; `x` and `y` are inspected by the next cells.
x, y = ds[np.random.choice(len(ds))]
x.shape, y.shape, np.where(y >= 0.5)

In [None]:
# Spectrogram of the sample image's first channel (the three channels are identical copies).
lbd.specshow(x[0])

In [None]:
# First five target values; classes that are not labelled hold the 0.0025 smoothing floor.
y[:5]

# Training the model

In [None]:
# Validation is run on the labelled train soundscapes.
TEST_AUDIO_ROOT = Path("../input/birdclef-2021/train_soundscapes")
SAMPLE_SUB_PATH = None  # when set, preds_as_df aligns rows to this file and fills gaps with "nocall"
# SAMPLE_SUB_PATH = "../input/birdclef-2021/sample_submission.csv"
TARGET_PATH = Path("../input/birdclef-2021/train_soundscape_labels.csv")
THRESHS = [n/100 for n in range(15, 85)]  # candidate thresholds 0.15 .. 0.84 searched by evaluate()

df_train = pd.read_csv("../input/birdclef-2021/train_metadata.csv")

# NOTE: this rebinds the LABEL_IDS returned by get_df(). BirdClefDataset reads the
# global LABEL_IDS when items are fetched, so this mapping is the one training uses.
LABEL_IDS = {label: label_id for label_id,label in enumerate(sorted(df_train["primary_label"].unique()))}
INV_LABEL_IDS = {val: key for key,val in LABEL_IDS.items()}
# Alias 'rocpig1' to the 'rocpig' id (presumably a spelling that occurs in
# secondary_labels — TODO confirm). Added after INV_LABEL_IDS so ids still map back to 'rocpig'.
LABEL_IDS['rocpig1'] = LABEL_IDS['rocpig']

In [None]:
def one_step(xb, yb, net, criterion, optimizer, scheduler=None):
  """Run one optimisation step on a batch and compute batch metrics.

  Args:
    xb, yb: input batch and targets (targets are moved to DEVICE when they support `.to`).
    net, criterion, optimizer: model, loss and optimizer to step.
    scheduler: optional scheduler stepped after the batch.

  Returns:
    (loss, lrap, f1, rec, prec) as Python floats. Targets > 0.5 count as positives;
    sigmoid outputs > 0.5 count as predicted positives.
  """
  xb = xb.to(DEVICE)
  try:
    yb = yb.to(DEVICE)
  except AttributeError:
    # Non-tensor targets have no `.to`; leave them where they are. A bare `except`
    # would also have hidden real errors such as CUDA failures or KeyboardInterrupt.
    pass

  optimizer.zero_grad()
  o = net(xb)
  loss = criterion(o, yb)
  loss.backward()
  optimizer.step()

  with torch.no_grad():
    loss_value = loss.item()

    probas = o.sigmoid()
    y_true = (yb > 0.5) * 1.0
    lrap = label_ranking_average_precision_score(y_true.cpu().numpy(), probas.cpu().numpy())

    y_hat = (probas > 0.5) * 1.0
    true_pos = (y_hat * y_true).sum()
    prec = true_pos / (1e-6 + y_hat.sum())
    rec = true_pos / (1e-6 + y_true.sum())
    f1 = 2 * prec * rec / (1e-6 + prec + rec)

  if scheduler is not None:
    scheduler.step()

  return loss_value, lrap, f1.item(), rec.item(), prec.item()

In [None]:
def predict(nets, test_data, names=True):
    """Average the sigmoid outputs of ``nets`` on every item of ``test_data``.

    Each item (all windows of one file) is sent to DEVICE as a single batch.
    With ``names=True`` each averaged tensor is converted into a list of bird-name
    strings via get_thresh_preds / get_bird_names; otherwise the tensors are returned.
    The nets are used in whatever train/eval mode they are currently in.
    """
    n_nets = len(nets)
    outputs = []
    with torch.no_grad():
        for idx in tqdm(list(range(len(test_data)))):
            batch = torch.from_numpy(test_data[idx]).to(DEVICE)
            averaged = 0.
            for net in nets:
                averaged += torch.sigmoid(net(batch))
            averaged /= n_nets

            outputs.append(get_bird_names(get_thresh_preds(averaged)) if names else averaged)
    return outputs

def get_metrics(s_true, s_pred):
    """Set-based precision / recall / F1 between two space-separated label strings.

    Args:
        s_true: ground-truth labels, e.g. ``"amecro rewbla"`` or ``"nocall"``.
        s_pred: predicted labels in the same format.

    Returns:
        dict with ``f1``, ``prec``, ``rec``, ``n_true``, ``n_pred`` and ``n`` (the size
        of the intersection). Empty strings give 0 instead of a ZeroDivisionError.
    """
    s_true = set(s_true.split())
    s_pred = set(s_pred.split())
    n, n_true, n_pred = len(s_true.intersection(s_pred)), len(s_true), len(s_pred)

    prec = n/n_pred if n_pred else 0
    rec = n/n_true if n_true else 0
    f1 = 2*prec*rec/(prec + rec) if prec + rec else 0

    return {"f1": f1, "prec": prec, "rec": rec, "n_true": n_true, "n_pred": n_pred, "n": n}

# Index the validation soundscapes: each file stem is split on "_" into id, site and date.
data = pd.DataFrame(
     [(path.stem, *path.stem.split("_"), path) for path in Path(TEST_AUDIO_ROOT).glob("*.ogg")],
    columns = ["filename", "id", "site", "date", "filepath"]
)
test_data = TestBirdCLEFDataset(data=data)

def preds_as_df(preds, data=data):
    """Flatten per-file predictions into a ``row_id`` / ``birds`` submission frame.

    ``preds[k]`` is the list of bird strings for the k-th row of ``data``, one per
    window; window j (1-based) gets ``row_id = "<id>_<site>_<5*j>"``. When
    SAMPLE_SUB_PATH is set, rows are aligned to it and missing ones become "nocall".
    """
    row_ids, birds = [], []
    for row, pred in zip(data.itertuples(False), preds):
        birds.extend(pred)
        row_ids.extend(f"{row.id}_{row.site}_{5*i}" for i in range(1, len(pred)+1))

    sub = pd.DataFrame({"row_id": row_ids, "birds": birds})

    if not SAMPLE_SUB_PATH:
        return sub

    sample_sub = pd.read_csv(SAMPLE_SUB_PATH, usecols=["row_id"])
    sub = sample_sub.merge(sub, on="row_id", how="left")
    sub["birds"] = sub["birds"].fillna("nocall")
    return sub

@torch.no_grad()
def get_thresh_preds(out, thresh=None):
    """Per row of ``out``, return the class indices whose score exceeds ``thresh``.

    Indices are sorted by decreasing score. ``thresh`` falls back to a global
    ``THRESH`` when omitted; an explicit ``0`` is honoured.

    Args:
        out: 2-D tensor of scores, one row per window.
        thresh: score threshold.

    Returns:
        list (one per row) of lists of int class indices.

    Raises:
        ValueError: if ``thresh`` is omitted and no global THRESH is defined
            (this notebook does not define one, so pass it explicitly).
    """
    if thresh is None:
        thresh = globals().get("THRESH")
        if thresh is None:
            raise ValueError("get_thresh_preds: pass `thresh` explicitly or define a global THRESH")
    order = (-out).argsort(1)
    npreds = (out > thresh).sum(1)
    return [oo[:npred].cpu().numpy().tolist() for oo, npred in zip(order, npreds)]

def get_bird_names(preds):
    """Turn lists of class ids into space-joined label strings ("nocall" when empty)."""
    return [
        " ".join([INV_LABEL_IDS[bird_id] for bird_id in pred]) if pred else "nocall"
        for pred in preds
    ]

@torch.no_grad()
def evaluate(net, test_data=test_data):
    """Score ``net`` on the labelled soundscapes, choosing the best threshold in THRESHS.

    Puts ``net`` in eval mode, predicts once, then computes the mean row-wise
    F1/precision/recall against TARGET_PATH for every threshold. The best result
    is printed together with its threshold.

    Returns:
        (f1, prec, rec) for the threshold with the highest mean F1
        (the first one on ties).
    """
    net.eval()
    pred_probas = predict([net], test_data, names=False)

    # The ground truth does not depend on the threshold: read it once, not per threshold.
    target = pd.read_csv(TARGET_PATH)

    # Start from None so the first threshold is always recorded; `{"f1": 0}` made the
    # return below fail with KeyError when every threshold scored f1 == 0.
    top_metrics = None
    for thresh in THRESHS:
        preds = [get_bird_names(get_thresh_preds(pred, thresh=thresh)) for pred in pred_probas]
        sub = preds_as_df(preds)
        sub_target = target.merge(sub, how="left", on="row_id")

        assert sub_target["birds_x"].notnull().all()
        assert sub_target["birds_y"].notnull().all()

        df_metrics = pd.DataFrame([get_metrics(s_true, s_pred) for s_true, s_pred in zip(sub_target.birds_x, sub_target.birds_y)])
        metrics = df_metrics.mean().to_dict()
        if top_metrics is None or top_metrics["f1"] < metrics["f1"]:
            top_metrics = metrics
            top_metrics["thresh"] = thresh
    print("top_metrics:", top_metrics)

    return top_metrics["f1"], top_metrics["prec"], top_metrics["rec"]

In [None]:
def one_epoch(net, criterion, optimizer, scheduler, train_laoder):
  """Train for one pass over `train_laoder`, step the scheduler once, then validate.

  Returns:
    ((f1_train, f1_val), (rec_train, rec_val), (prec_train, prec_val)); the train
    values are means of the per-batch metrics from one_step.
  """
  net.train()
  totals = {"loss": 0., "lrap": 0., "f1": 0., "rec": 0., "prec": 0.}
  n_batches = 0
  progress = tqdm(train_laoder, leave = False)

  for xb, yb in progress:
    batch_loss, batch_lrap, batch_f1, batch_rec, batch_prec = one_step(xb, yb, net, criterion, optimizer)
    totals["loss"] += batch_loss
    totals["lrap"] += batch_lrap
    totals["f1"] += batch_f1
    totals["rec"] += batch_rec
    totals["prec"] += batch_prec
    n_batches += 1

    # Refresh the running averages on the progress bar every 10 batches.
    if hasattr(progress, "set_postfix") and n_batches % 10 == 0:
      progress.set_postfix(
        loss="{:.6f}".format(totals["loss"]/n_batches),
        lrap="{:.3f}".format(totals["lrap"]/n_batches),
        prec="{:.3f}".format(totals["prec"]/n_batches),
        rec="{:.3f}".format(totals["rec"]/n_batches),
        f1="{:.3f}".format(totals["f1"]/n_batches),
      )

  scheduler.step()

  f1_train = totals["f1"] / n_batches
  rec_train = totals["rec"] / n_batches
  prec_train = totals["prec"] / n_batches

  f1_val, prec_val, rec_val = evaluate(net)

  return (f1_train, f1_val), (rec_train, rec_val), (prec_train, prec_val)

In [None]:
class AutoSave:
  """Keep the `top_k` checkpoints with the largest `metric` value on disk, and log every epoch.

  Checkpoints are written to `root` (default MODEL_ROOT, which must already exist)
  as `<name>_epoch_<NN>_<metric>_<value>_<timestamp>.pth`, with characters outside
  `[\\w_-]` stripped (so the decimal point of the value disappears). All logged
  metric dicts are rewritten to `<name>_logs.json` after every save.

  NOTE(review): `mode` is stored but never used — whatever `mode` is, the list
  below is kept in ascending order and the smallest value is evicted, i.e. larger
  is always considered better.
  """
  def __init__(self, top_k=5, metric="f1", mode="min", root=None, name="ckpt"):
    self.top_k = top_k
    self.logs = []
    self.metric = metric
    self.mode = mode
    self.root = Path(root or MODEL_ROOT)
    assert self.root.exists()
    self.name = name

    # Parallel lists, both sorted by ascending metric value.
    self.top_models = []
    self.top_metrics = []

  def log(self, model, metrics):
    """Record `metrics` (must contain `self.metric` and "epoch") and checkpoint `model`."""
    metric = metrics[self.metric]
    rank = self.rank(metric)

    # Insert in ascending position; once over capacity, drop the smallest (index 0).
    self.top_metrics.insert(rank+1, metric)
    if len(self.top_metrics) > self.top_k:
      self.top_metrics.pop(0)

    self.logs.append(metrics)
    self.save(model, metric, rank, metrics["epoch"])


  def save(self, model, metric, rank, epoch):
    """Save `model`'s state_dict and delete the checkpoint that fell out of the top-k."""
    t = time.strftime("%Y%m%d%H%M%S")
    name = "{}_epoch_{:02d}_{}_{:.04f}_{}".format(self.name, epoch, self.metric, metric, t)
    name = re.sub(r"[^\w_-]", "", name) + ".pth"
    path = self.root.joinpath(name)

    old_model = None
    self.top_models.insert(rank+1, name)
    if len(self.top_models) > self.top_k:
      # May be the checkpoint being saved right now, if it ranks last: it is then
      # written below and immediately deleted.
      old_model = self.root.joinpath(self.top_models[0])
      self.top_models.pop(0)

    torch.save(model.state_dict(), path.as_posix())

    if old_model is not None:
      old_model.unlink()

    self.to_json()


  def rank(self, val):
    """Return the index after which `val` should be inserted in the ascending list.

    -1 means "insert at the front"; ties are placed before the equal existing value.
    """
    r = -1
    for top_val in self.top_metrics:
      if val <= top_val:
        return r
      r += 1

    return r
  
  def to_json(self):
    """Dump every logged metrics dict to `<name>_logs.json` under `root`."""
    # t = time.strftime("%Y%m%d%H%M%S")
    name = "{}_logs".format(self.name)
    name = re.sub(r"[^\w_-]", "", name) + ".json"
    path = self.root.joinpath(name)

    with path.open("w") as f:
      json.dump(self.logs, f, indent=2)


In [None]:
def one_fold(model_name, fold, train_set, val_set, epochs=20, save=True, save_root=None):
  """Train one model on rows `train_set` of `df` and keep the best checkpoints.

  Args:
    model_name: name understood by get_model.
    fold: fold id; only used in the checkpoint names.
    train_set: positional row indices into `df` (used with `df.iloc`).
    val_set: unused; validation runs on the soundscapes inside one_epoch.
    epochs: number of epochs, also the cosine-annealing period.
    save: whether to log metrics and checkpoint through AutoSave (keyed on "f1_val").
    save_root: checkpoint directory (defaults to MODEL_ROOT); must already exist.
  """
  # Apply the fallback before wrapping: Path(None) raises TypeError, and a Path is
  # always truthy, so `Path(save_root) or MODEL_ROOT` could never fall back.
  save_root = Path(save_root or MODEL_ROOT)

  saver = AutoSave(root=save_root, name=f"birdclef_{model_name}_fold{fold}", metric="f1_val")

  net = get_model(model_name).to(DEVICE)

  if USE_FOCAL_LOSS:
    criterion = BCEFocalLoss()
  else:
    criterion = nn.BCEWithLogitsLoss()

  optimizer = optim.Adam(net.parameters(), lr=8e-4)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=epochs)

  train_data = BirdClefDataset(audio_image_store, meta=df.iloc[train_set].reset_index(drop=True),
                           sr=SR, duration=DURATION, is_train=True)
  train_laoder = DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, num_workers=TRAIN_NUM_WORKERS, shuffle=True, pin_memory=True)

  epochs_bar = tqdm(list(range(epochs)), leave=False)
  for epoch in epochs_bar:
    epochs_bar.set_description(f"--> [EPOCH {epoch:02d}]")
    net.train()

    (f1, f1_val), (rec, rec_val), (prec, prec_val) = one_epoch(
        net=net,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        train_laoder=train_laoder,
      )

    epochs_bar.set_postfix(
      prec="({:.3f}, {:.3f})".format(prec, prec_val),
      rec="({:.3f}, {:.3f})".format(rec, rec_val),
      f1="({:.3f}, {:.3f})".format(f1, f1_val),
    )

    print(
        "[{epoch:02d}] f1: {f1} rec: {rec} prec: {prec}".format(
            epoch=epoch,
            prec="({:.3f}, {:.3f})".format(prec, prec_val),
            rec="({:.3f}, {:.3f})".format(rec, rec_val),
            f1="({:.3f}, {:.3f})".format(f1, f1_val),
        )
    )

    if save:
      metrics = {
          "f1": f1, "rec": rec, "prec": prec,
          "f1_val": f1_val, "rec_val": rec_val, "prec_val": prec_val,
          "epoch": epoch,
      }

      saver.log(net, metrics)

In [None]:
def train(model_name, epochs=20, save=True, n_splits=5, seed=177, save_root=None, suffix="", folds=None):
  """Train `model_name` on the whole of `df` (no CV split) and checkpoint under `save_root`.

  `save_root` defaults to `MODEL_ROOT/<model_name><suffix>` and is created if needed.
  NOTE(review): `n_splits`, `seed` and `folds` are accepted but currently unused.
  """
  gc.collect()
  torch.cuda.empty_cache()

  # Path(...) so a str save_root also supports .mkdir.
  save_root = Path(save_root or MODEL_ROOT/f"{model_name}{suffix}")
  save_root.mkdir(exist_ok=True, parents=True)

  # one_fold selects rows with df.iloc, so pass positions rather than index labels:
  # df concatenates several CSVs read with index_col=0, so its index need not be 0..n-1.
  train_set = np.arange(len(df))

  one_fold(model_name, fold=0, train_set=train_set, val_set=None, epochs=epochs, save=save, save_root=save_root)

  gc.collect()
  torch.cuda.empty_cache()

In [None]:
# Train every configured model in turn, each with its own secondary-label factor.
assert len(SECONDARY_FACTORS) >= len(MODEL_NAMES), "need one SECONDARY_FACTORS entry per model"
for model_name, secondary_factor in zip(MODEL_NAMES, SECONDARY_FACTORS):
  # BirdClefDataset reads the module-level SECONDARY_FACTOR, so rebind it per model.
  SECONDARY_FACTOR = secondary_factor
  print("\n\n###########################################", model_name.upper())
  try:
    train(model_name, epochs=EPOCHS, suffix=f"_sr{SR}_d{DURATION}_v1_v1", folds=FOLDS)
  except Exception as e:
    # Keep the original error chained, and say which model failed.
    raise ValueError(f"Training failed for model {model_name!r}") from e