## Fine-tuning an audio spectrogram transformer (AST) classifier on audio files (converted to spectrograms)

AST is one of the top-performing audio classification models (see the AudioSet leaderboard): https://paperswithcode.com/sota/audio-classification-on-audioset

**Paper abstract:**

In the past decade, convolutional neural networks (CNNs) have been widely adopted as the main building block for end-to-end audio classification models, which aim to learn a direct mapping from audio spectrograms to corresponding labels. To better capture long-range global context, a recent trend is to add a self-attention mechanism on top of the CNN, forming a CNN-attention hybrid model. However, it is unclear whether the reliance on a CNN is necessary, and if neural networks purely based on attention are sufficient to obtain good performance in audio classification. In this paper, we answer the question by introducing the Audio Spectrogram Transformer (AST), the first convolution-free, purely attention-based model for audio classification. We evaluate AST on various audio classification benchmarks, where it achieves new state-of-the-art results of 0.485 mAP on AudioSet, 95.6% accuracy on ESC-50, and 98.1% accuracy on Speech Commands V2.

https://arxiv.org/abs/2104.01778

**Results**: ~0.3% validation accuracy

**Notes**:
- The audio files had to be loaded at a 16,000 Hz sample rate to be compatible with the AST feature extractor
- Training took 41 min on the GPU even with heavily reduced data: at most 3 seconds of audio per file (truncated) and only 5% of the samples (the model is large: 86,594,063 parameters)
- More data and much more training time would be required for better results

**Conclusion**: Not the model for this competition (requirement: maximum 2h training on CPU)

**Next**: Research ...

In [None]:
import pandas as pd
import numpy as np
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from IPython.display import Audio
import librosa
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
import torch.nn.functional as F
from transformers import ASTFeatureExtractor, AutoModelForAudioClassification, TrainingArguments, Trainer
import torchaudio 

import random
import glob
import os
import time

import sys
sys.path.append("..")
import utils

In [None]:
RANDOM_SEED = 21

# Set seed for experiment reproducibility: seed every RNG the notebook touches.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
# cuDNN: `deterministic` restricts cuDNN to deterministic kernels, but
# `benchmark=True` auto-tunes the convolution algorithm per run (which may pick
# a different algorithm each time), undermining reproducibility. Keep it off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Environment setup: locate the competition data, pick the training device and,
# when running locally, download the data on first use.
# utils.get_is_in_kaggle_env / utils.determine_device come from the local ../utils
# module (their implementation is not visible here).
is_in_kaggle_env = utils.get_is_in_kaggle_env()

# On Kaggle the competition data is under /kaggle/input; locally it is expected in ../data.
data_path = '/kaggle/input/birdclef-2023' if is_in_kaggle_env else '../data'

# Always CPU on Kaggle; locally the device is chosen by utils.determine_device()
# (presumably CUDA/MPS when available — TODO confirm against utils).
device = 'cpu' if is_in_kaggle_env else utils.determine_device()

# First local run: fetch the competition archive and unpack it into ../data.
# NOTE: the `!` lines are IPython shell escapes, so this cell only runs inside
# Jupyter/IPython and needs the kaggle CLI (with API credentials) and unzip.
if not is_in_kaggle_env and not os.path.exists('../data'):
    print("Downloading data...")
    !kaggle competitions download -c 'birdclef-2023'
    !mkdir ../data
    !unzip -q birdclef-2023.zip -d ../data
    !rm birdclef-2023.zip

# Per-recording training metadata (labels, relative audio file paths, ...).
df_metadata_csv = pd.read_csv(f"{data_path}/train_metadata.csv")

# Root directory of the training audio; the metadata's file paths are joined onto it.
audio_data_dir = f"{data_path}/train_audio/"

In [None]:
# Count recordings per primary_label and inspect the classes that have fewer
# than 3 recordings (i.e. 2 or fewer); these are dropped in the next cell.
class_counts = df_metadata_csv["primary_label"].value_counts()

two_or_less_samples_rows = df_metadata_csv[df_metadata_csv["primary_label"].isin(class_counts[class_counts < 3].index)]

# The filter above is `< 3`, so the messages report "2 or fewer" samples.
print(f"Number of unique classes with 2 or fewer samples: {len(two_or_less_samples_rows['primary_label'].unique())}")
print(f"Number of rows belonging to classes with 2 or fewer samples: {len(two_or_less_samples_rows)}")
print(f"Primary labels with 2 or fewer samples: {two_or_less_samples_rows['primary_label'].unique()}")

In [None]:
# Keep only recordings whose primary_label occurs at least 3 times overall
# (class_counts comes from the previous cell), presumably so that every class
# has enough members for the stratified splits done in split_df.
rare_primary_labels = class_counts[class_counts < 3].index
print(f"Number of rows before dropping: {len(df_metadata_csv)}")
is_frequent_class = ~df_metadata_csv["primary_label"].isin(rare_primary_labels)
df_metadata_csv = df_metadata_csv[is_frequent_class]
print(f"Number of rows after dropping: {len(df_metadata_csv)}")

In [None]:
# Distinct primary labels remaining after filtering; the label encoder is fitted on these.
unique_classes = df_metadata_csv["primary_label"].unique()
n_unique_classes = len(unique_classes)
print(f"Number of classes: {n_unique_classes}")

In [None]:
class BirdClef23Dataset(Dataset):
    """Map-style dataset that turns BirdCLEF 2023 recordings into AST inputs.

    Each item loads one recording with librosa (resampled to 16 kHz, the rate
    the AST feature extractor is configured for in this notebook), keeps at
    most ``seconds`` of audio, converts it with ``feature_extractor`` and
    encodes its ``primary_label`` with ``label_encoder``.

    Args:
        df: Metadata dataframe with a ``primary_label`` column and a
            ``filename`` column holding paths relative to ``audio_data_dir``.
        audio_data_dir: Root directory of the audio files.
        label_encoder: Fitted ``LabelEncoder`` over the primary labels.
        feature_extractor: Callable returning an object with ``input_values``
            (an ``ASTFeatureExtractor`` in this notebook).
        seconds: Maximum duration kept per file, in seconds (may be fractional).

    Items are dicts with keys ``"row_id"`` (file stem), ``"input_values"``
    (features of the single example) and ``"labels"`` (encoded class id).
    """

    SAMPLE_RATE = 16000

    def __init__(self, df, audio_data_dir, label_encoder, feature_extractor, seconds=10):
        self.df = df
        self.audio_data_dir = audio_data_dir
        self.label_encoder = label_encoder
        self.feature_extractor = feature_extractor
        self.seconds = seconds

    @staticmethod
    def _truncate(audio, sample_rate, seconds):
        """Return at most the first ``seconds`` of the 1-D sample array ``audio``."""
        # int() so that fractional durations still produce a valid slice index.
        return audio[:int(sample_rate * seconds)]

    @staticmethod
    def _row_id(audio_path):
        """Return the file name up to its first dot, e.g. '.../XC1.ogg' -> 'XC1'."""
        return os.path.basename(audio_path).split('.')[0]

    def __getitem__(self, index):
        # Look columns up by name so the dataset does not depend on column order.
        row = self.df.iloc[index]
        audio_path = os.path.join(self.audio_data_dir, row["filename"])
        audio_numpy, audio_sr = librosa.load(audio_path, sr=self.SAMPLE_RATE)

        # Truncate the audio file to {seconds}
        audio_numpy = self._truncate(audio_numpy, audio_sr, self.seconds)

        # ASTFeatureExtractor pads/truncates the filterbank features to the
        # ``max_length`` it was constructed with and does not use per-call
        # ``padding``/``max_length`` arguments, so none are passed here.
        inputs = self.feature_extractor(
            audio_numpy,
            sampling_rate=audio_sr,
            return_tensors="pt"
        )
        input_values = inputs.input_values[0]  # the single example of the returned batch

        primary_label = self.label_encoder.transform([row["primary_label"]])[0]

        return {"row_id": self._row_id(audio_path), "input_values": input_values, "labels": primary_label}

    def __len__(self):
        return len(self.df)


def split_df(df, primary_label='primary_label', percentages=(60, 20, 20)):
    """
    Split a dataframe into train/valid/test dataframes, stratified by ``primary_label``.

    - percentages: ``(train, valid, test)`` percentages; three positive values
      summing to 100 (any 3-item sequence is accepted).
    - Both splits use the module-level ``RANDOM_SEED``.
    - Returns ``(train_df, valid_df, test_df, class_weights)``, where
      ``class_weights`` are sklearn's 'balanced' weights computed on the
      training set only, ordered like ``np.unique(train_df[primary_label])``.

    Raises:
        ValueError: if ``percentages`` is not three positive values summing to 100.
    """
    train_pct, valid_pct, test_pct = percentages
    if min(train_pct, valid_pct, test_pct) <= 0 or not np.isclose(train_pct + valid_pct + test_pct, 100):
        raise ValueError(f"percentages must be three positive values summing to 100, got {percentages!r}")

    print(f"Splitting dataframe into train {train_pct}%, valid {valid_pct}%, test {test_pct}%, stratified by {primary_label}")

    # Step 1: carve the test set off the full dataframe.
    temp_df, test_df = train_test_split(df, test_size=test_pct / 100, stratify=df[primary_label], random_state=RANDOM_SEED)

    # Step 2: split the remainder into train/valid. The validation share of the
    # remainder is valid / (train + valid), computed directly from the
    # percentages without intermediate rounding (20 / 80 == 0.25 exactly for
    # the defaults; 70/15/15 gives exactly 15/85 of the remainder).
    valid_share = valid_pct / (train_pct + valid_pct)
    train_df, valid_df = train_test_split(temp_df, test_size=valid_share, stratify=temp_df[primary_label], random_state=RANDOM_SEED)

    classes = np.unique(train_df[primary_label])
    class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_df[primary_label])

    return train_df, valid_df, test_df, class_weights


# --- training
# Create (or reload) the label encoder that maps primary_label strings to the
# integer class ids used as training targets; it is persisted with joblib.
ignore_existing_label_encoder = True
if ignore_existing_label_encoder or not os.path.exists('label_encoder.joblib'):
    print('Creating label encoder...')
    label_encoder = LabelEncoder()
    label_encoder.fit(list(unique_classes))
    joblib.dump(label_encoder, 'label_encoder.joblib')
else:
    print('Loading label encoder...')
    label_encoder = joblib.load('label_encoder.joblib')

# Experiment knobs; the data is heavily subsampled to keep training time manageable.
data_percentage = 5  # percentage of each split that is actually used
seconds = 3          # max seconds of audio kept per recording
batch_size = 16
num_epochs = 1
# NOTE(review): 5e-4 is high for fine-tuning a pretrained transformer
# (values around 1e-5..5e-5 are more common) — worth tuning.
learning_rate = 0.0005

feature_extractor = ASTFeatureExtractor(
    sampling_rate=16000,
)

# NOTE: class_weights is computed here but not passed to the Trainer below.
train_df, valid_df, test_df, class_weights = split_df(df_metadata_csv)

print(f"Using {data_percentage}% of the data")
train_df = train_df.sample(frac=data_percentage/100, random_state=RANDOM_SEED)
valid_df = valid_df.sample(frac=data_percentage/100, random_state=RANDOM_SEED)
test_df = test_df.sample(frac=data_percentage/100, random_state=RANDOM_SEED)

train_dataset = BirdClef23Dataset(train_df, audio_data_dir, label_encoder, feature_extractor, seconds)
valid_dataset = BirdClef23Dataset(valid_df, audio_data_dir, label_encoder, feature_extractor, seconds)
test_dataset = BirdClef23Dataset(test_df, audio_data_dir, label_encoder, feature_extractor, seconds)  # not used in this cell

# Load the AudioSet-finetuned AST checkpoint but give it a classification head
# sized for *our* label set. Without num_labels the model keeps the checkpoint's
# AudioSet output layer, so logits/argmax range over the wrong classes.
# ignore_mismatched_sizes makes from_pretrained re-initialise that head instead
# of raising on the shape mismatch.
class_names = [str(label) for label in label_encoder.classes_]
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    num_labels=len(class_names),
    id2label=dict(enumerate(class_names)),
    label2id={name: idx for idx, name in enumerate(class_names)},
    ignore_mismatched_sizes=True,
)
print(f"Initialized model {model._get_name()}")

# MPS Training class https://github.com/huggingface/transformers/issues/17971
# Forces the Trainer onto the notebook-level `device` (e.g. Apple MPS).
class TrainingArgumentsWithMPSSupport(TrainingArguments):
    @property
    def device(self) -> torch.device:
        return torch.device(device)

training_args = TrainingArgumentsWithMPSSupport(
    output_dir="birdclef-2023-ast",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

def custom_metrics_fn(eval_pred):
    """Top-1 accuracy of the argmax prediction over the evaluation set."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": (predictions == labels).mean().item()
    }

trainer = Trainer(
    model=model.to(device),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=feature_extractor,
    compute_metrics=custom_metrics_fn,
)

trainer.train()

In [None]:
# Report the total parameter count of the model held by the trainer.
param_count = trainer.model.num_parameters()
print(f"Model has {param_count} parameters")