## RegNet Training

In [None]:
import pandas as pd
import numpy as np
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import regnet_y_800mf, RegNet_Y_800MF_Weights
from torchvision import transforms
from IPython.display import Audio
import librosa
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import torch.nn.functional as F
from torchaudio import functional as F_audio
from tqdm import tqdm
import timm
import torchaudio
import torchaudio.transforms as T_audio

import copy
import random
import glob
import os
import time
import sys
import re

In [None]:
RANDOM_SEED = 21

# Set seed for experiment reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Seed every CUDA device rather than only the current one (ignored when CUDA is unavailable).
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
# Benchmark mode lets cuDNN auto-tune and pick convolution algorithms per run, which
# can differ between runs; it has to be off for the runs to be reproducible.
torch.backends.cudnn.benchmark = False

In [None]:
## REUSE IN INFERENCE NOTEBOOK

custom_dataset_path = '/kaggle/input/birdclef2023-inference'

# Import utils.py from the attached Kaggle dataset when it is there, otherwise
# from the parent directory.
_utils_in_dataset = os.path.exists(os.path.join(custom_dataset_path, 'utils.py'))
sys.path.append(custom_dataset_path if _utils_in_dataset else '..')
import utils

IS_IN_KAGGLE_ENV = utils.get_is_in_kaggle_env()

# Paths and device depend on whether the notebook runs on Kaggle or locally.
if IS_IN_KAGGLE_ENV:
    DATA_PATH = '/kaggle/input/birdclef-2023'
    JOBLIB_PATH = custom_dataset_path
    DEVICE = 'cpu'
else:
    DATA_PATH = '../data'
    JOBLIB_PATH = './'
    DEVICE = utils.determine_device()

# Every clip is brought to this length (seconds) and sample rate (Hz).
AUDIO_LENGTH_S = 5
SAMPLE_RATE = 32_000

In [None]:
# Local runs only: download and unpack the competition data when DATA_PATH does not
# exist yet. The shell commands unpack into ../data, which is DATA_PATH outside Kaggle.
if not IS_IN_KAGGLE_ENV and not os.path.exists(DATA_PATH):
    print("Downloading data ...")
    !kaggle competitions download -c 'birdclef-2023'
    !mkdir ../data
    !unzip -q birdclef-2023.zip -d ../data
    !rm birdclef-2023.zip

# Training metadata; its "primary_label" column is the classification target.
DF_METADATA_CSV = pd.read_csv(f"{DATA_PATH}/train_metadata.csv")

# Directory that per-row audio paths are joined onto by the dataset class.
AUDIO_DATA_DIR = f"{DATA_PATH}/train_audio/"

In [None]:
# Rows with 1 sample: copy twice, rows with 2 samples: copy once
# This is to ensure that stratified sampling for train/valid/test splits works correctly
def ensure_min_two_samples(DF_METADATA_CSV):
    """Duplicate the rows of very rare classes and return the resulting frame.

    Class sizes are counted once, on the incoming frame. Rows of classes with
    exactly one row are appended twice (1 -> 3 rows); rows of classes with
    exactly two rows are appended once (2 -> 4 rows). The index is reset
    whenever rows are appended; if nothing is appended the input frame is
    returned as is.
    """
    label_column = "primary_label"
    counts_before = DF_METADATA_CSV[label_column].value_counts()

    def rows_of_classes_with(frame, n_rows):
        # Select by the ORIGINAL counts so earlier duplication does not shift the groups.
        wanted_labels = counts_before[counts_before == n_rows].index
        return frame[frame[label_column].isin(wanted_labels)]

    singles = rows_of_classes_with(DF_METADATA_CSV, 1)
    print(f"Number of rows with 1 sample: {len(singles)}")
    if len(singles):
        DF_METADATA_CSV = pd.concat([DF_METADATA_CSV, singles, singles], ignore_index=True)

    pairs = rows_of_classes_with(DF_METADATA_CSV, 2)
    print(f"Number of rows with 2 samples: {len(pairs)}")
    if len(pairs):
        DF_METADATA_CSV = pd.concat([DF_METADATA_CSV, pairs], ignore_index=True)

    return DF_METADATA_CSV

# First pass: duplicate the rows of classes that have only 1 or 2 rows.
DF_METADATA_CSV = ensure_min_two_samples(DF_METADATA_CSV)

# Run again to verify: after the first pass those classes have 3 or 4 rows,
# so both printed counts should now be 0 and the frame is returned unchanged.
DF_METADATA_CSV = ensure_min_two_samples(DF_METADATA_CSV)

In [None]:
## REUSE IN INFERENCE NOTEBOOK

class BirdMelspecClf(nn.Module):
    """torchvision RegNetY-800MF adapted to single-channel input.

    The backbone's stem is replaced by a freshly initialised 1-channel version
    and its ``fc`` head is replaced by a linear layer with ``out_features``
    outputs. ``forward`` returns both the raw logits and their softmax.
    """

    def __init__(self, out_features, pretrained):
        super().__init__()

        # https://pytorch.org/vision/stable/models.html
        if pretrained:
            self.regnet = regnet_y_800mf(weights=RegNet_Y_800MF_Weights.DEFAULT)
        else:
            self.regnet = regnet_y_800mf()

        # Replace the stem to take 1 channel instead of 3. The original stem:
        #   (stem): SimpleStemIN(
        #     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        #     (2): ReLU(inplace=True)
        #   )
        stem_conv = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        stem_norm = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        stem_act = nn.ReLU(inplace=True)
        self.regnet.stem = nn.Sequential(stem_conv, stem_norm, stem_act)

        # Replace original classifier: (fc): Linear(in_features=784, out_features=1000, bias=True)
        head_in_features = self.regnet.fc.in_features
        self.regnet.fc = nn.Linear(head_in_features, out_features)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(logits, probas)`` with ``probas`` the softmax of ``logits`` over dim 1."""
        logits = self.regnet(x)
        return logits, self.softmax(logits)


def get_model(out_features, device, pretrained=False, load_state_dict=True, state_dict_starts_with=f"{AUDIO_LENGTH_S}s_regnetY800MF"):
    """Build a ``BirdMelspecClf`` on ``device`` and optionally restore saved weights.

    Args:
        out_features: Number of output classes.
        device: Device the model (and any loaded checkpoint) is placed on.
        pretrained: Passed through to ``BirdMelspecClf``.
        load_state_dict: When True, look in ``JOBLIB_PATH`` for ``.pt`` files whose
            name starts with ``state_dict_starts_with`` and load the newest one.
        state_dict_starts_with: Filename prefix of candidate checkpoints.

    Returns:
        The model on ``device``. If no matching checkpoint exists, the freshly
        constructed model is returned after printing a notice.
    """
    model = BirdMelspecClf(out_features=out_features, pretrained=pretrained)
    print(f"Loaded model {model.__class__.__name__} with {sum(p.numel() for p in model.parameters())} parameters, pretained={pretrained}")
    model.to(device)

    if not load_state_dict:
        return model

    model_files = [f for f in os.listdir(JOBLIB_PATH) if f.startswith(state_dict_starts_with) and f.endswith('.pt')]
    if not model_files:
        print(f"No model starting with {state_dict_starts_with} found in {JOBLIB_PATH}")
        return model

    def trailing_number(filename):
        # The last run of digits in the name is the sort key (train() writes
        # "<start timestamp><3-digit epoch>" there); names without digits sort first.
        numbers = re.findall(r'\d+', filename)
        return int(numbers[-1]) if numbers else -1

    model_files.sort(key=trailing_number)

    # The latest model file is the last one in the sorted list
    latest_model_file = model_files[-1]
    model_path = os.path.join(JOBLIB_PATH, latest_model_file)

    # map_location makes the load independent of the device the checkpoint was
    # saved from (e.g. CUDA tensors loaded on a CPU-only machine).
    state_dict = torch.load(model_path, map_location=device)
    # load_state_dict copies into the existing parameters, so the model stays on `device`.
    model.load_state_dict(state_dict)
    print(f"Loaded model weights from {model_path}")

    return model

In [None]:
# Distinct target labels, in order of first appearance in the metadata.
UNIQUE_CLASSES = DF_METADATA_CSV["primary_label"].unique()
N_CLASSES = len(UNIQUE_CLASSES)
print(f"Number of classes: {N_CLASSES}")

In [None]:
# Sanity check: list the entries directly under train_audio and display how many
# there are together with the first one (raises IndexError if the glob is empty).
train_classes_paths = glob.glob(f"{DATA_PATH}/train_audio/*")
len(train_classes_paths), train_classes_paths[0]

In [None]:
# Row count after ensure_min_two_samples, i.e. including the duplicated rows.
print(f"Total samples in dataset: {len(DF_METADATA_CSV)}")

In [None]:
# Count rows per class once; value_counts sorts descending, so the last entry is the rarest.
_label_counts = DF_METADATA_CSV.primary_label.value_counts(sort=True)
class_with_fewest_samples = _label_counts.index[-1]
print(f"Class with fewest samples: {class_with_fewest_samples}: {_label_counts.min()} samples")

In [None]:
## REUSE IN INFERENCE NOTEBOOK

def resample(audio, current_sample_rate, desired_sample_rate=SAMPLE_RATE):
    """Return ``audio`` resampled from ``current_sample_rate`` to ``desired_sample_rate``."""
    # A new Resample transform is built on every call.
    to_desired_rate = torchaudio.transforms.Resample(
        orig_freq=current_sample_rate,
        new_freq=desired_sample_rate,
    )
    return to_desired_rate(audio)

def load_audio(audio_path, sample_rate=SAMPLE_RATE):
    """Load ``audio_path`` with torchaudio and return the waveform at ``sample_rate``.

    The file is resampled only when its native rate differs from ``sample_rate``.
    """
    waveform, native_rate = torchaudio.load(audio_path)
    if native_rate == sample_rate:
        return waveform
    return resample(waveform, native_rate, sample_rate)

# Using librosa defaults for n_fft and hop_length
def get_melspec_transform(sample_rate=SAMPLE_RATE, n_fft=2048, hop_length=512, n_mels=128):
    """Build a torchaudio ``MelSpectrogram`` transform with the given settings."""
    settings = {
        "sample_rate": sample_rate,
        "n_fft": n_fft,
        "hop_length": hop_length,
        "n_mels": n_mels,
    }
    return torchaudio.transforms.MelSpectrogram(**settings)

# Using librosa defaults for top_db
def get_melspec_db_transform(stype='power', top_db=80):
    """Build a torchaudio ``AmplitudeToDB`` transform with the given ``stype`` and ``top_db``."""
    to_db = torchaudio.transforms.AmplitudeToDB(stype=stype, top_db=top_db)
    return to_db

# Copied from torchaudio/transforms/_transforms.py (to avoid converting to melspec twice)
# DCT matrix built once at import time.
# NOTE(review): the arguments (40, 128) presumably mean 40 coefficients over 128 mel
# bins; the training config uses n_mels = 256, which would not match - confirm
# before re-enabling MFCC features.
dct_mat = F_audio.create_dct(40, 128, "ortho")
def get_mfcc_from_melspec(melspec):
    """Project ``melspec`` onto ``dct_mat`` along its second-to-last axis."""
    # Move the mel axis last for the matmul, then restore the original layout.
    mel_last = melspec.transpose(-1, -2)
    projected = torch.matmul(mel_last, dct_mat)
    return projected.transpose(-1, -2)

def normalize_tensor(tensor):
    """Min-max scale ``tensor`` to the range [0, 1] using its global min and max.

    A tensor whose values are all equal is returned unchanged (same object) to
    avoid dividing by zero.
    """
    lowest = torch.min(tensor)
    value_range = torch.max(tensor) - lowest
    if value_range == 0:
        return tensor
    return (tensor - lowest) / value_range

In [None]:
class BirdClef23Dataset(Dataset):
    """Dataset yielding ``(features, encoded_label)`` pairs for one metadata row each.

    A clip is loaded at ``SAMPLE_RATE``, padded or randomly cropped to exactly
    ``AUDIO_LENGTH_S`` seconds, converted to a dB-scaled mel spectrogram,
    min-max normalised and optionally passed through ``transforms``.

    Args:
        df: Metadata frame. Column 0 is read as the raw label and column 11 as
            the audio path relative to ``audio_data_dir``.
        audio_data_dir: Directory the per-row audio paths are joined onto.
        label_encoder: Fitted encoder whose ``transform`` maps raw labels to ints.
        n_mels: Number of mel bins for the spectrogram.
        device: Stored on the instance; not used by the dataset itself.
        pad_method: ``'wrap'`` repeats the clip from its start until it is long
            enough, ``'zeros'`` appends silence. Any other value leaves short
            clips unpadded.
        transforms: Optional callable applied to the normalised features.
    """

    def __init__(self, df, audio_data_dir, label_encoder, n_mels, device, pad_method='wrap', transforms=None):
        self.df = df
        self.audio_data_dir = audio_data_dir
        self.label_encoder = label_encoder
        self.device = device
        self.pad_method = pad_method
        self.melspec_transform = get_melspec_transform(n_mels=n_mels)
        self.melspec_db_transform = get_melspec_db_transform()
        self.transforms = transforms

    @staticmethod
    def _fit_to_length(audio, target_len, pad_method):
        """Pad or randomly crop ``audio`` (channels, samples) to ``target_len`` samples."""
        n_samples = audio.shape[1]

        if n_samples < target_len:
            if pad_method == 'wrap' and n_samples > 0:
                # True wrap-around: tile the clip and cut to length. Replicate
                # padding would only repeat the final sample, not the clip.
                n_repeats = -(-target_len // n_samples)  # ceiling division
                audio = audio.repeat(1, n_repeats)[:, :target_len]
            elif pad_method in ('wrap', 'zeros'):
                # 'zeros', or an empty clip that has nothing to wrap around.
                audio = F.pad(audio, (0, target_len - n_samples), mode='constant', value=0)

        elif n_samples > target_len:
            # randint's upper bound is exclusive; +1 makes the last valid start reachable.
            last_start = n_samples - target_len
            start_idx = torch.randint(0, last_start + 1, (1,)).item()
            audio = audio[:, start_idx:start_idx + target_len]

        return audio

    def __getitem__(self, index):
        # NOTE(review): column 11 is presumably the 'filename' column - confirm against the CSV.
        audio_path = os.path.join(self.audio_data_dir, self.df.iloc[index, 11])
        audio = load_audio(audio_path, SAMPLE_RATE)

        # Bring every clip to exactly AUDIO_LENGTH_S seconds so batches can be stacked.
        audio = self._fit_to_length(audio, AUDIO_LENGTH_S * SAMPLE_RATE, self.pad_method)

        melspec = self.melspec_db_transform(self.melspec_transform(audio))
        norm_melspec = normalize_tensor(melspec)

        # mfcc = get_mfcc_from_melspec(melspec)
        # norm_mfcc = normalize_tensor(mfcc)

        # features = torch.cat((norm_melspec, norm_mfcc), dim=1)

        features = norm_melspec

        if self.transforms:
            features = self.transforms(features)

        primary_label_raw = self.df.iloc[index, 0]
        primary_label = self.label_encoder.transform([primary_label_raw])[0]

        return features, primary_label

    def __len__(self):
        return len(self.df)


def get_data_loader(dataset, batch_size=32, data_percentage=None, shuffle=False, pin_memory=False):
    """Wrap ``dataset`` in a ``DataLoader``, optionally on a random subset.

    When ``data_percentage`` (a fraction such as 0.5) is given, a random subset
    of ``int(len(dataset) * data_percentage)`` items is drawn with
    ``random_split`` and only that subset is loaded.
    """
    if data_percentage is not None:
        n_keep = int(len(dataset) * data_percentage)
        n_drop = len(dataset) - n_keep
        dataset = random_split(dataset, [n_keep, n_drop])[0]

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)


def split_df(df, primary_label='primary_label', percentages=(60, 20, 20)):
    """
    Split a dataframe into train / valid / test frames, stratified by a label column.

    Args:
        df: Frame to split.
        primary_label: Name of the column used for stratification and class weights.
        percentages: ``(train, valid, test)`` on a 0-100 scale (each value is divided
            by 100). A test value of 0 skips the test split.

    Returns:
        ``(train_df, valid_df, test_df, class_weights)`` where ``test_df`` is None
        when the test percentage is 0 and ``class_weights`` are sklearn "balanced"
        weights computed on the training frame, ordered like the sorted unique
        labels of that frame.
    """
    train_pct, valid_pct, test_pct = percentages
    print(f"Splitting dataframe into train {train_pct}%, valid {valid_pct}%, test {test_pct}%, stratified by {primary_label}")

    train_perc, valid_perc, test_perc = [perc / 100 for perc in percentages]
    # Share of the non-test rows that goes to training (rounded to two decimals).
    train_valid_split = round(train_perc / (train_perc + valid_perc), 2)

    if test_perc == 0:
        temp_df = df
        test_df = None
    else:
        temp_df, test_df = train_test_split(df, test_size=test_perc, stratify=df[primary_label], random_state=RANDOM_SEED)

    train_df, valid_df = train_test_split(temp_df, test_size=1-train_valid_split, stratify=temp_df[primary_label], random_state=RANDOM_SEED)

    classes = np.unique(train_df[primary_label])
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_df[primary_label])

    return train_df, valid_df, test_df, class_weights


def convert_to_three_digits(num):
    """Return ``num`` as a string, left-padded with zeros to three characters when below 100."""
    text = str(num)
    return text.zfill(3) if num < 100 else text
    

def train(model, train_loader, valid_loader, loss_func, optimizer, num_epochs, validate_on_train, validate_train_pct, scheduler, device):
    """Train ``model`` and save a checkpoint whenever validation accuracy improves.

    The model is validated once before training so that, for a model started
    from loaded weights, new weights are only saved when they beat the loaded ones.

    Args:
        model: Module whose forward returns ``(logits, probas)``.
        train_loader / valid_loader: Loaders yielding ``(features, labels)`` batches.
        loss_func: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        validate_on_train: When True, also evaluate on part of the training data
            after every epoch; otherwise train accuracy/loss are reported as 0.
        validate_train_pct: Fraction of training batches used for that evaluation.
        scheduler: Optional scheduler stepped with the validation accuracy, or None.
        device: Device batches are moved to.

    Returns:
        ``(minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst)``:
        the loss of every training step followed by the per-epoch metrics.
    """
    minibatch_loss, train_acc_lst, valid_acc_lst, train_loss_lst, valid_loss_lst = [], [], [], [], []
    train_start_ts = str(time.time()).split('.')[0]

    # Initial validation step for models with loaded weights, so new weights are only saved if they improve on the loaded ones
    valid_acc, valid_loss = validate(model, device, valid_loader, loss_func)
    best_valid_acc = valid_acc
    print(f"Initial validation accuracy: {valid_acc:.4f}%")

    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        # validate() leaves the model in eval mode, so switch back every epoch.
        model.train()

        for melspec, primary_label in tqdm(train_loader, total=len(train_loader), desc="Training batches"):
            features = melspec.to(device)
            targets = primary_label.to(device)

            logits, probas = model(features)
            loss = loss_func(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            minibatch_loss.append(loss.item())
            optimizer.step()

        if validate_on_train:
            train_acc, train_loss = validate(model, device, train_loader, loss_func, validate_train_pct)
        else:
            train_acc, train_loss = torch.tensor(0.0), torch.tensor(0.0)
        valid_acc, valid_loss = validate(model, device, valid_loader, loss_func)

        train_acc_lst.append(train_acc)
        train_loss_lst.append(train_loss)
        valid_acc_lst.append(valid_acc)
        valid_loss_lst.append(valid_loss)

        if valid_acc > best_valid_acc:
            best_valid_acc = valid_acc
            # The trailing "<start timestamp><3-digit epoch>" makes later checkpoints sort last.
            model_name = f"{AUDIO_LENGTH_S}s_regnetY800MF_e{epoch}_valacc{valid_acc:.0f}_traacc{train_acc:.0f}_{train_start_ts}{convert_to_three_digits(epoch)}.pt"
            print(f"Moving model to CPU and saving it as {model_name} ...")
            # Copy only the tensors to CPU instead of deep-copying the whole model,
            # which would briefly hold a second full copy on the training device.
            cpu_state_dict = {key: value.cpu() for key, value in model.state_dict().items()}
            model_path = os.path.join(JOBLIB_PATH, model_name)
            torch.save(cpu_state_dict, model_path)

        if scheduler is not None:
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(valid_acc)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_before != lr_after:
                print(f"Learning rate changed from {lr_before} to {lr_after}")

        print(f"Finsished epoch {epoch+1}/{num_epochs}. Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Train Accuracy: {train_acc:.2f}%, Valid Accuracy: {valid_acc:.2f}%")

    return minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst


def validate(model, device, data_loader, loss_fn=F.cross_entropy, data_percentage=None):
    """Evaluate ``model`` on ``data_loader`` and return ``(accuracy, loss)``.

    Args:
        model: Module whose forward returns ``(logits, probas)``; it is put in
            eval mode and left there.
        device: Device batches are moved to.
        data_loader: Loader yielding ``(features, labels)`` batches.
        loss_fn: Loss applied to ``(logits, targets)``. It is assumed to return
            the MEAN loss of the batch (the default reduction of cross entropy).
        data_percentage: Optional fraction of the loader's batches to evaluate
            (taken from the start); at least one batch is always evaluated.

    Returns:
        ``accuracy`` in percent and ``loss`` as the per-example average, both tensors.

    Raises:
        ValueError: If the loader yields no examples.
    """
    from itertools import islice

    model.eval()

    num_examples, correct_pred, loss_sum = 0, 0.0, 0.0

    total_batches = len(data_loader)
    if data_percentage is None:
        max_batches = total_batches
    else:
        max_batches = max(1, int(total_batches * data_percentage))

    with torch.no_grad():
        # islice stops after max_batches without fetching (and decoding) one more batch.
        batches = islice(data_loader, max_batches)
        for melspec, primary_label in tqdm(batches, total=max_batches, desc='Validation batches'):
            features = melspec.to(device)
            targets = primary_label.to(device)
            batch_size = targets.size(0)

            logits, probas = model(features)
            # loss_fn yields a batch mean; weight it by the batch size so that the
            # final division by num_examples gives a true per-example average.
            loss_sum += loss_fn(logits, targets) * batch_size

            _, predicted_labels = torch.max(probas, 1)

            num_examples += batch_size
            correct_pred += (predicted_labels == targets).sum()

    if num_examples == 0:
        raise ValueError("validate() received no examples: the data loader is empty")

    accuracy = correct_pred / num_examples * 100
    loss = loss_sum / num_examples
    return accuracy, loss


# --- training
print('Creating label encoder...')
label_encoder = LabelEncoder()
label_encoder.fit(list(UNIQUE_CLASSES))
# Saved so the same label <-> index mapping can be reloaded elsewhere.
joblib.dump(label_encoder, 'label_encoder.joblib')

# split_df expects percentages on a 0-100 scale: 80% train, 20% valid, no test split.
train_df, valid_df, test_df, class_weights = split_df(DF_METADATA_CSV, percentages=[80, 20, 0])

batch_size = 8
data_percentage = 1 # 1 means 100% of the data
num_epochs = 20
n_mels = 256 # 128 is the default value in librosa
learning_rate = 0.0001
# DEVICE = 'cpu'
pad_method = 'wrap' # wrap, zeros

# Augmentations applied to the training features only.
# NOTE(review): torchaudio documents TimeStretch as TimeStretch(hop_length, n_freq, fixed_rate)
# operating on complex spectrograms, so 0.8 is taken as hop_length here and the rate
# is fixed_rate=True. This looks unintended - confirm against the torchaudio docs.
augmentation_transforms = nn.Sequential(
    T_audio.TimeStretch(0.8, fixed_rate=True),
    T_audio.FrequencyMasking(freq_mask_param=30),
    T_audio.TimeMasking(time_mask_param=80),
)
pin_memory = True # https://discuss.pytorch.org/t/when-to-set-pin-memory-to-true/19723/6
validate_on_train = True
validate_train_pct = 0.333

train_dataset = BirdClef23Dataset(train_df, AUDIO_DATA_DIR, label_encoder, n_mels, DEVICE, pad_method=pad_method, transforms=augmentation_transforms)
valid_dataset = BirdClef23Dataset(valid_df, AUDIO_DATA_DIR, label_encoder, n_mels, DEVICE, pad_method=pad_method)

train_loader = get_data_loader(train_dataset, batch_size, data_percentage, shuffle=True, pin_memory=pin_memory)
valid_loader = get_data_loader(valid_dataset, batch_size, data_percentage, shuffle=False, pin_memory=pin_memory)

# Continue from the newest saved checkpoint if one exists; otherwise start from random weights.
model = get_model(out_features=N_CLASSES, device=DEVICE, pretrained=False, load_state_dict=True, state_dict_starts_with='5s_regnetY800MF_e')
# model = get_model(out_features=N_CLASSES, device=DEVICE, pretrained=True, load_state_dict=False)

# Class weights come from the training split and counteract class imbalance in the loss.
loss_function = torch.nn.CrossEntropyLoss(weight=torch.from_numpy(class_weights).float().to(DEVICE))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.01)
# mode='max' because train() steps the scheduler with the validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.9, patience=2, verbose=True)
# scheduler = None

print(f"batch_size: {batch_size}, data_percentage: {data_percentage}, num_epochs: {num_epochs}, n_mels: {n_mels}, learning_rate: {learning_rate}, pin_memory: {pin_memory}, \nvalidate_on_train: {validate_on_train}, device: {DEVICE}, pad_method: {pad_method}, validate_train_pct: {validate_train_pct}")

minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst = train(model, train_loader, valid_loader, loss_function, optimizer, num_epochs, validate_on_train, validate_train_pct, scheduler, DEVICE)

In [None]:
# Plot the per-step training losses collected by train().
utils.plot_minibatch_loss(minibatch_loss)

In [None]:
# Plot the per-epoch train/validation loss and accuracy returned by train().
utils.plot_train_and_valid_loss_and_accuracy(train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst)