## RegNetY_800MF CNN transfer learning, audio files as spectrograms

https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_800mf.html#torchvision.models.RegNet_Y_800MF_Weights

**Results**: 48% accuracy; the model was still improving after epoch 5, so there is tuning potential.

**Conclusion**: Given its small number of trainable parameters and its nonetheless strong results, this could be the ideal CNN for this problem.

**Next**: Improve feature preparation, because there still appears to be a lot of noise in the data.

In [None]:
import pandas as pd
import numpy as np
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import regnet_y_800mf, RegNet_Y_800MF_Weights
from torchvision import transforms
from IPython.display import Audio
import librosa
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import torch.nn.functional as F
from tqdm import tqdm

import random
import glob
import os
import time

import sys
sys.path.append("..")
import utils

In [None]:
RANDOM_SEED = 21

# Seed every random number generator used in this notebook so that
# experiments are repeatable.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
# cuDNN benchmark mode auto-tunes the convolution algorithm at runtime and may
# pick a different one on each run, which defeats the reproducibility goal
# above, so it has to be switched off.
torch.backends.cudnn.benchmark = False

In [None]:
# Ask the project helper whether this notebook is running on Kaggle.
is_in_kaggle_env = utils.get_is_in_kaggle_env()

# Competition data location: the Kaggle input directory when on Kaggle,
# otherwise the local ../data folder.
data_path = '/kaggle/input/birdclef-2023' if is_in_kaggle_env else '../data'

# CPU is forced on Kaggle; locally the project helper picks the device.
device = 'cpu' if is_in_kaggle_env else utils.determine_device()

# Local runs only: fetch and unpack the competition archive once. The lines
# starting with "!" are IPython shell escapes and require the Kaggle CLI;
# the step is skipped as soon as ../data exists.
if not is_in_kaggle_env and not os.path.exists('../data'):
    print("Downloading data...")
    !kaggle competitions download -c 'birdclef-2023'
    !mkdir ../data
    !unzip -q birdclef-2023.zip -d ../data
    !rm birdclef-2023.zip

# Per-recording metadata; the "primary_label" column is used below as the class label.
df_metadata_csv = pd.read_csv(f"{data_path}/train_metadata.csv")

# Root folder that the dataset class joins with each row's audio file path.
audio_data_dir = f"{data_path}/train_audio/"

In [None]:
# Number of recordings per primary label.
class_counts = df_metadata_csv["primary_label"].value_counts()

# Labels with fewer than 3 recordings (i.e. one or two samples) and the rows
# that belong to them.
rare_labels = class_counts[class_counts < 3].index
two_or_less_samples_rows = df_metadata_csv[df_metadata_csv["primary_label"].isin(rare_labels)]
rare_unique_labels = two_or_less_samples_rows["primary_label"].unique()

# The messages state the same threshold as the filter above (< 3).
print(f"Number of unique classes with fewer than 3 samples: {len(rare_unique_labels)}")
print(f"Number of rows belonging to classes with fewer than 3 samples: {len(two_or_less_samples_rows)}")
print(f"Primary labels with fewer than 3 samples: {rare_unique_labels}")

In [None]:
# Remove every recording whose primary_label occurs at most twice.
rows_before_drop = len(df_metadata_csv)
print(f"Number of rows before dropping: {rows_before_drop}")

labels_to_drop = class_counts[class_counts < 3].index
is_rare_row = df_metadata_csv["primary_label"].isin(labels_to_drop)
df_metadata_csv = df_metadata_csv[~is_rare_row]

print(f"Number of rows after dropping: {len(df_metadata_csv)}")

In [None]:
# Distinct primary labels left in the metadata; each becomes one output class.
unique_classes = df_metadata_csv["primary_label"].unique()
num_unique_classes = len(unique_classes)
print(f"Number of classes: {num_unique_classes}")

In [None]:
class BirdClef23Dataset(Dataset):
    """Dataset that turns audio recordings into normalized log-mel spectrograms.

    Each item is a tuple ``(row_id, audio_tensor, primary_label)``:

    - ``row_id``: file name of the recording without directory and extension.
    - ``audio_tensor``: float32 CPU tensor with a leading channel axis of
      size 1 (presumably ``(1, n_mels, frames)`` - the last two axes come
      straight from ``utils.normalize_spectrogram``).
    - ``primary_label``: integer class id produced by ``label_encoder``.

    Args:
        df: Metadata frame. Column 0 is read as the primary label and
            column 11 as the audio path relative to ``audio_data_dir``
            (NOTE(review): positional access - confirm against the CSV layout).
        audio_data_dir: Directory that the relative audio paths are joined to.
        label_encoder: Fitted encoder exposing ``transform``.
        seconds: Target clip length in seconds (int or float).
        n_mels: Number of mel bands passed to librosa.
        device: Kept for backward compatibility. Items are returned on the
            CPU so that ``DataLoader(pin_memory=True)`` and worker processes
            work; the consumer is expected to move batches to its device.
        pad_method: How clips shorter than ``seconds`` are extended: 'wrap',
            'constant_zero', 'reflect' or 'wrap_double_reflect'.
        sample_rate: Sample rate the audio is loaded at.

    Raises:
        ValueError: If ``pad_method`` is not one of the supported values.
    """

    _PAD_METHODS = ('wrap', 'constant_zero', 'reflect', 'wrap_double_reflect')

    def __init__(self, df, audio_data_dir, label_encoder, seconds, n_mels, device, pad_method='wrap', sample_rate=32000):
        # Fail early: an unknown method would otherwise leave short clips
        # unpadded and only surface later as a batch collation error.
        if pad_method not in self._PAD_METHODS:
            raise ValueError(f"Unknown pad_method '{pad_method}', expected one of {self._PAD_METHODS}")

        self.df = df
        self.audio_data_dir = audio_data_dir
        self.label_encoder = label_encoder
        self.seconds = seconds
        self.n_mels = n_mels
        self.device = device
        self.pad_method = pad_method
        self.sample_rate = sample_rate

    def _fit_length(self, audio_numpy):
        """Pad or truncate a 1-D waveform to the configured clip length."""
        # int() keeps the slice index valid when `seconds` is a float.
        target_len = int(self.sample_rate * self.seconds)
        current_len = audio_numpy.shape[0]

        if current_len < target_len:
            pad_width = (0, target_len - current_len)

            if self.pad_method in ('wrap', 'wrap_double_reflect'):
                # wrap: repeat the audio until the length is reached
                audio_numpy = np.pad(audio_numpy, pad_width, 'wrap')
            elif self.pad_method == 'constant_zero':
                # constant: pad with zeros --> 1-2% less accuracy
                audio_numpy = np.pad(audio_numpy, pad_width, 'constant', constant_values=0)
            elif self.pad_method == 'reflect':
                # reflect: pad with the mirrored signal
                audio_numpy = np.pad(audio_numpy, pad_width, 'reflect')
        elif current_len > target_len:
            # Keep only the first `seconds` of longer recordings.
            audio_numpy = audio_numpy[:target_len]

        # Append a time-reversed copy, doubling the clip length.
        if self.pad_method == 'wrap_double_reflect':
            audio_numpy = np.concatenate((audio_numpy, audio_numpy[::-1]))

        return audio_numpy

    def __getitem__(self, index):
        audio_path = os.path.join(self.audio_data_dir, self.df.iloc[index, 11])
        audio_numpy, audio_sr = librosa.load(audio_path, sr=self.sample_rate)

        # Defensive check that the loader honoured the requested rate.
        if audio_sr != self.sample_rate:
            raise ValueError(f"Sample rate is not {self.sample_rate}, it is {audio_sr} for {audio_path}")

        audio_numpy = self._fit_length(audio_numpy)

        # What is a mel-scaled spectrogram? https://www.youtube.com/watch?v=PYlr8ayHb4g
        # https://librosa.org/doc/latest/generated/librosa.feature.melspectrogram.html
        # melspectrogram returns a *power* spectrogram by default, so the dB
        # conversion must be power_to_db rather than amplitude_to_db.
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_numpy, sr=audio_sr, n_mels=self.n_mels)
        log_mel_spectrogram = librosa.power_to_db(mel_spectrogram)
        log_mel_spectrogram_norm = utils.normalize_spectrogram(log_mel_spectrogram)

        # Add the single channel axis expected by the CNN stem.
        spectrogram = log_mel_spectrogram_norm.reshape(
            (1, log_mel_spectrogram_norm.shape[0], log_mel_spectrogram_norm.shape[1])
        )

        # Stay on the CPU here; see the `device` note in the class docstring.
        audio_tensor = torch.from_numpy(spectrogram).float()

        primary_label_raw = self.df.iloc[index, 0]
        primary_label = self.label_encoder.transform([primary_label_raw])[0]

        # File name without directory and without extension.
        row_id = os.path.basename(audio_path).split('.')[0]

        return row_id, audio_tensor, primary_label

    def __len__(self):
        return len(self.df)


def get_data_loader(dataset, batch_size=32, data_percentage=None, shuffle=False, pin_memory=False):
    """Build a DataLoader over ``dataset`` or over a random fraction of it.

    When ``data_percentage`` is given (e.g. 0.5 for half of the data), a
    random subset of ``int(len(dataset) * data_percentage)`` items is drawn
    with ``random_split`` and the remainder is discarded. ``batch_size``,
    ``shuffle`` and ``pin_memory`` are forwarded to the DataLoader unchanged.
    """
    if data_percentage is None:
        selected = dataset
    else:
        total = len(dataset)
        keep = int(total * data_percentage)
        # random_split returns [kept, rest]; only the first part is used.
        selected = random_split(dataset, [keep, total - keep])[0]

    return DataLoader(selected, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)


def split_df(df, primary_label='primary_label', percentages=(60, 20, 20)):
    """Split ``df`` into stratified train/valid/test frames plus class weights.

    Args:
        df: Frame to split.
        primary_label: Name of the column used for stratification.
        percentages: ``(train, valid, test)`` shares in percent.

    Returns:
        ``(train_df, valid_df, test_df, class_weights)`` where
        ``class_weights`` are the 'balanced' weights computed on the training
        split, ordered like ``np.unique`` of the training labels.
    """
    train_perc, valid_perc, test_perc = percentages
    print(f"Splitting dataframe into train {train_perc}%, valid {valid_perc}%, test {test_perc}%, stratified by {primary_label}")

    # First carve off the test share, then split the remainder.
    temp_df, test_df = train_test_split(df, test_size=test_perc / 100, stratify=df[primary_label], random_state=RANDOM_SEED)

    # Validation share relative to train+valid. It is computed from the raw
    # percentages without rounding, so splits such as 70/15/15 keep their
    # exact proportions.
    valid_fraction = valid_perc / (train_perc + valid_perc)
    train_df, valid_df = train_test_split(temp_df, test_size=valid_fraction, stratify=temp_df[primary_label], random_state=RANDOM_SEED)

    classes = np.unique(train_df[primary_label])
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_df[primary_label])

    return train_df, valid_df, test_df, class_weights


class RegnetCNN(torch.nn.Module):
    """RegNetY-800MF with a 1-channel stem and a custom classification head.

    ``forward`` returns a tuple ``(logits, probas)`` where ``probas`` is the
    softmax of ``logits`` over dimension 1.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Backbone with the default pretrained weights.
        # https://pytorch.org/vision/stable/models.html
        self.regnet = regnet_y_800mf(weights=RegNet_Y_800MF_Weights.DEFAULT)

        # Swap the stem so it accepts 1 input channel instead of 3.
        # The original stem looks like this:
        #   (stem): SimpleStemIN(
        #     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        #     (2): ReLU(inplace=True)
        #   )
        # The BatchNorm2d arguments listed there are the layer's defaults.
        stem_conv = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
        stem_norm = nn.BatchNorm2d(32)
        self.regnet.stem = nn.Sequential(stem_conv, stem_norm, nn.ReLU(inplace=True))

        # New classifier head: backbone features -> 1024 -> 512 -> num_classes.
        backbone_features = self.regnet.fc.in_features
        head_layers = [
            nn.Linear(backbone_features, 1024),
            nn.BatchNorm1d(1024),
            nn.PReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.PReLU(),
            nn.Linear(512, num_classes),
        ]
        self.regnet.fc = nn.Sequential(*head_layers)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.regnet(x)
        return logits, self.softmax(logits)


def train(model, train_loader, valid_loader, loss_func, optimizer, num_epochs, validate_on_train, scheduler, device):
    """Train ``model`` for ``num_epochs`` epochs and validate after each one.

    Args:
        model: Module whose forward pass returns ``(logits, probas)``.
        train_loader: Loader yielding ``(row_id, audio_tensor, primary_label)``.
        valid_loader: Loader of the same form used for validation.
        loss_func: Loss applied to ``(logits, targets)``.
        optimizer: Optimizer updating the model parameters.
        num_epochs: Number of passes over ``train_loader``.
        validate_on_train: If true, also evaluate on ``train_loader`` after
            every epoch; otherwise zero placeholders are recorded.
        scheduler: Optional scheduler; it is stepped with the validation
            loss, so it must accept a metric argument.
        device: Device the batches are moved to.

    Returns:
        ``(minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst,
        valid_acc_lst)`` - the loss of every training batch followed by the
        per-epoch losses and accuracies.
    """
    minibatch_loss, train_acc_lst, valid_acc_lst, train_loss_lst, valid_loss_lst = [], [], [], [], []

    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        # validate() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()

        # tqdm shows a progress bar over the batches
        for _, audio_tensor, primary_label in tqdm(train_loader, total=len(train_loader), desc="Training batches"):
            features = audio_tensor.to(device)
            targets = primary_label.to(device)

            logits, _ = model(features)
            loss = loss_func(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            minibatch_loss.append(loss.item())
            optimizer.step()

        if validate_on_train:
            train_acc, train_loss = validate(model, train_loader, loss_func)
        else:
            train_acc, train_loss = torch.tensor(0.0), torch.tensor(0.0)
        train_acc_lst.append(train_acc)
        train_loss_lst.append(train_loss)

        valid_acc, valid_loss = validate(model, valid_loader, loss_func)
        valid_acc_lst.append(valid_acc)
        valid_loss_lst.append(valid_loss)

        if scheduler is not None:
            scheduler.step(valid_loss)

        print(f"Finished epoch {epoch+1}/{num_epochs}. Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Train Accuracy: {train_acc:.2f}%, Valid Accuracy: {valid_acc:.2f}%")

    return minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst


def validate(model, data_loader, loss_fn=F.cross_entropy, device=None):
    """Evaluate ``model`` on ``data_loader`` and return accuracy and mean loss.

    Args:
        model: Module whose forward pass returns ``(logits, probas)``.
        data_loader: Loader yielding ``(row_id, audio_tensor, primary_label)``.
        loss_fn: Loss applied to ``(logits, targets)``; expected to return the
            mean over the batch (the default reduction of ``F.cross_entropy``).
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has none).

    Returns:
        ``(accuracy, loss)`` as 0-dim tensors: accuracy in percent and the
        loss averaged over all examples.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no examples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()

    num_examples, correct_pred, cross_entropy = 0.0, 0.0, 0.0

    with torch.no_grad():
        for _, audio_tensor, primary_label in tqdm(data_loader, total=len(data_loader), desc='Validation batches'):
            features = audio_tensor.to(device)
            targets = primary_label.to(device)
            batch_examples = targets.size(0)

            logits, probas = model(features)
            # loss_fn yields a per-batch mean; weight it by the batch size so
            # that dividing by the example count below gives a true average
            # instead of a value shrunk by roughly the batch size.
            cross_entropy += loss_fn(logits, targets) * batch_examples

            predicted_labels = probas.argmax(dim=1)
            num_examples += batch_examples

            correct_pred += (predicted_labels == targets).sum()

    accuracy = correct_pred / num_examples * 100
    loss = cross_entropy / num_examples
    return accuracy, loss


# --- training
# End-to-end experiment: build the label encoder, split the metadata, create
# datasets/loaders, initialise model, loss and optimizer, then run train().

# With the flag set to True the encoder is always re-fitted and the file on
# disk is overwritten; set it to False to reuse an existing encoder.
ignore_existing_label_encoder = True
if ignore_existing_label_encoder or not os.path.exists('label_encoder.joblib'):
    print('Creating label encoder...')
    label_encoder = LabelEncoder()
    label_encoder.fit(list(unique_classes))
    joblib.dump(label_encoder, 'label_encoder.joblib')
else:
    print('Loading label encoder...')
    # NOTE(review): joblib.load unpickles the file - only load encoders
    # created by this notebook.
    label_encoder = joblib.load('label_encoder.joblib')

# Stratified split; class_weights are computed on the training split.
train_df, valid_df, test_df, class_weights = split_df(df_metadata_csv)

# Hyperparameters of this run.
seconds = 20 # 24 is the median - but 20 has better results
batch_size = 8
data_percentage = .5 # 1 means 100% of the data
num_epochs = 2
n_mels = 128 # 128 is the default value in librosa
learning_rate = 0.00008
# device = 'cpu'
pad_method = 'reflect' # wrap, constant_zero, reflect, wrap_double_reflect

pin_memory = True # https://discuss.pytorch.org/t/when-to-set-pin-memory-to-true/19723/6
validate_on_train = False

# The test dataset/loader are prepared but currently disabled.
train_dataset = BirdClef23Dataset(train_df, audio_data_dir, label_encoder, seconds, n_mels, device, pad_method=pad_method)
valid_dataset = BirdClef23Dataset(valid_df, audio_data_dir, label_encoder, seconds, n_mels, device, pad_method=pad_method)
# test_dataset = BirdClef23Dataset(test_df, audio_data_dir, label_encoder, seconds, n_mels, device, pad_method=pad_method)

# data_percentage sub-samples both the training and the validation data.
train_loader = get_data_loader(train_dataset, batch_size, data_percentage, shuffle=True, pin_memory=pin_memory)
valid_loader = get_data_loader(valid_dataset, batch_size, data_percentage, shuffle=False, pin_memory=pin_memory)
# test_loader = get_data_loader(test_dataset, batch_size, data_percentage, shuffle=False, pin_memory=pin_memory)

model = RegnetCNN(num_classes=len(unique_classes)).to(device)
print(f"Initialized model {model._get_name()}, trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# print(model)

# Class-weighted cross entropy.
# NOTE(review): the weight vector has one entry per class present in the
# training split, while the model has len(unique_classes) outputs - this
# assumes every class appears in train_df; confirm.
loss_function = torch.nn.CrossEntropyLoss(weight=torch.from_numpy(class_weights).float().to(device))
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0001) # worse accuracy
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)
scheduler = None

# Log the full configuration next to the results.
print(f"Seconds: {seconds}, batch_size: {batch_size}, data_percentage: {data_percentage}, num_epochs: {num_epochs}, n_mels: {n_mels}, learning_rate: {learning_rate}, pin_memory: {pin_memory}, validate_on_train: {validate_on_train}, device: {device}, pad_method: {pad_method}")

# The returned histories are consumed by the plotting cells below.
minibatch_loss, train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst = train(model, train_loader, valid_loader, loss_function, optimizer, num_epochs, validate_on_train, scheduler, device)

In [None]:
# Hand the per-batch training losses returned by train() to the project's plotting helper.
utils.plot_minibatch_loss(minibatch_loss)

In [None]:
# Hand the per-epoch loss and accuracy histories returned by train() to the project's plotting helper.
utils.plot_train_and_valid_loss_and_accuracy(train_loss_lst, valid_loss_lst, train_acc_lst, valid_acc_lst)