# Fine-Tuning the Audio Spectrogram Transformer (AST) for Audio Classification

This Jupyter Notebook provides a comprehensive guide for fine-tuning the Audio Spectrogram Transformer (AST) model on your own audio classification dataset using tools from the HuggingFace ecosystem and PyTorch. The notebook covers the entire workflow, including data loading, preprocessing, applying audio augmentations, configuring the model, and setting up the training process.

**Published:** 30.07.2024  
**Author:** Marius Steger  
**Email:** [marius.steger@renumics.com](mailto:marius.steger@renumics.com)  
**Organization:** [Renumics](https://renumics.com/)  

## Step 1: Install Required Packages
Before we start, install all the required packages.

In [1]:
# !pip install transformers[torch] datasets[audio] audiomentations

## Step 2: Load Your Data in the Correct Format

In [15]:
from datasets import Dataset, Audio, ClassLabel, Features, load_dataset
import pandas as pd
import os 
import torch

In [3]:
# # Define class labels
# class_labels = ClassLabel(names=["bang", "dog_bark"])

# # Define features with audio and label columns
# features = Features({
#    "audio": Audio(),
#    "labels": class_labels
# })

# # Load data (example with a dictionary)
# dataset = Dataset.from_dict({
#    "audio": [r"C:\Users\pepij\OneDrive - Delft University of Technology\THESIS\data\WAV_Groningen_1\WAV_Groningen_1\Noorderplantsoen\NP101.wav",
#              r"C:\Users\pepij\OneDrive - Delft University of Technology\THESIS\data\WAV_Groningen_1\WAV_Groningen_1\Noorderplantsoen\NP102.wav"],
#    "labels": [0, 1],
# }, features=features)

In [4]:
# Load a pre-existing dataset from the HuggingFace Hub
esc50 = load_dataset("ashraq/esc50", split="train")

Repo card metadata block was not found. Setting CardData to empty.


In [16]:
# Load metadata CSV (assuming dataset includes a CSV file linking audio to perceptual attributes)
metadata = pd.read_excel(r"C:\Users\pepij\Downloads\noorderplantsoen.xlsx")

# Preview dataset
display(metadata.head())

metadata["audio_path"] = metadata["GroupID"].apply(lambda x: r"C:\Users\pepij\OneDrive - Delft University of Technology\THESIS\data\WAV_Groningen_1\WAV_Groningen_1\Noorderplantsoen\NP" + x[2:] + ".wav")

# Keep only rows where the file exists
metadata = metadata[metadata["audio_path"].apply(os.path.exists)]

# Reset index after filtering
metadata.reset_index(drop=True, inplace=True)

metadata = metadata[['GroupID', 'pleasant', 'chaotic', 'vibrant', 'uneventful', 'calm', 'annoying', 'eventful', 'monotonous', 'audio_path']]

columns_to_convert = [
    "pleasant", "chaotic", "vibrant", "uneventful", 
    "calm", "annoying", "eventful", "monotonous"
]

metadata[columns_to_convert] = metadata[columns_to_convert].astype(float).values

Unnamed: 0,LocationID,SessionID,GroupID,RecordID,start_time,end_time,latitude,longitude,Language,Survey_Version,...,RA_cp90,RA_cp95,THD_THD,THD_Min,THD_Max,THD_L5,THD_L10,THD_L50,THD_L90,THD_L95
0,Noorderplantsoen,Noorderplantsoen1,NP101,2,2020-03-11 08:54:00,2020-03-11 09:04:00,,,nld,nldSSIDv1,...,198.0,152.0,-6.0,-1312.0,5543.0,2294.0,1909.0,533.0,-993.0,-1104.0
1,Noorderplantsoen,Noorderplantsoen1,NP101,73,2020-03-13 00:49:00,2020-03-13 00:51:00,,,nld,nldSSIDv1,...,198.0,152.0,-6.0,-1312.0,5543.0,2294.0,1909.0,533.0,-993.0,-1104.0
2,Noorderplantsoen,Noorderplantsoen1,NP102,88,2020-03-13 12:04:00,2020-03-13 12:08:00,,,nld,nldSSIDv1,...,295.0,23.0,-275.0,-1402.0,6462.0,3921.0,323.0,1115.0,-1188.0,-126.0
3,Noorderplantsoen,Noorderplantsoen1,NP103,89,2020-03-13 12:12:00,2020-03-13 12:14:00,,,nld,nldSSIDv1,...,,,,,,,,,,
4,Noorderplantsoen,Noorderplantsoen1,NP106,98,2020-03-13 13:25:00,2020-03-13 13:32:00,,,nld,nldSSIDv1,...,257.0,203.0,-624.0,-737.0,6889.0,2914.0,2397.0,969.0,-352.0,-447.0


In [17]:
class SoundscapeDataset(Dataset):
    def __init__(self, metadata):
        self.metadata = metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        features = self.metadata.iloc[idx]["input_features"]
        labels = torch.tensor(self.metadata.iloc[idx][["pleasant", "vibrant", "eventful", "chaotic", 
                                                       "annoying", "monotonous", "uneventful", "calm"]].values, dtype=torch.float32)
        return features, labels


# Create dataset
dataset = SoundscapeDataset(metadata)

In [19]:
dataset

AttributeError: 'SoundscapeDataset' object has no attribute '_info'

## Step 3: Preprocess the Audio Data

In [26]:
import numpy as np
from datasets import Audio, ClassLabel
from transformers import ASTFeatureExtractor

In [7]:
esc50.features

{'filename': Value(dtype='string', id=None),
 'fold': Value(dtype='int64', id=None),
 'target': Value(dtype='int64', id=None),
 'category': Value(dtype='string', id=None),
 'esc10': Value(dtype='bool', id=None),
 'src_file': Value(dtype='int64', id=None),
 'take': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=None, mono=True, decode=True, id=None)}

In [94]:
# get target value - class name mappings
df = esc50.select_columns(["target", "category"]).to_pandas()
class_names = df.iloc[np.unique(df["target"], return_index=True)[1]]["category"].to_list()
print(class_names)

# cast target and audio column
esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))

# rename the target feature
esc50 = esc50.rename_column("target", "labels")
num_labels = len(np.unique(esc50["labels"]))

ValueError: Column name ['target'] not in the dataset. Current columns in the dataset: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'].

In [28]:
# Define the pretrained model and instantiate the feature extractor
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate

print(model_input_name, SAMPLING_RATE)

input_values 16000


In [29]:
# Preprocessing function
def preprocess_audio(batch):
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {model_input_name: inputs.get(model_input_name), "labels": list(batch["labels"])}

In [30]:
esc50

Dataset({
    features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'],
    num_rows: 2000
})

In [31]:
# we use the esc50 train split for this tutorial on how to fine-tune the AST Model
dataset = esc50
label2id = dataset.features["labels"]._str2int  # we add the mapping from INTs to STRINGs

In [32]:
print(dataset[0])

{'filename': '1-100032-A-0.wav', 'fold': 1, 'labels': 0, 'category': 'dog', 'esc10': True, 'src_file': 100032, 'take': 'A', 'audio': {'path': None, 'array': array([0., 0., 0., ..., 0., 0., 0.]), 'sampling_rate': 16000}}


In [33]:
# split training data
if "test" not in dataset:
    dataset = dataset.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")

## Step 4: Add Audio Augmentations

In [34]:
import torch
from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift

ModuleNotFoundError: No module named 'audiomentations'

In [12]:
# Define audio augmentations
audio_augmentations = Compose([
    AddGaussianSNR(min_snr_db=10, max_snr_db=20),
    Gain(min_gain_db=-6, max_gain_db=6),
    GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit="fraction"),
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.2),
    PitchShift(min_semitones=-4, max_semitones=4),
], p=0.8, shuffle=True)

In [None]:
# Preprocessing with augmentations
def preprocess_audio_with_transforms(batch):
    wavs = [audio_augmentations(audio["array"], sample_rate=SAMPLING_RATE) for audio in batch["input_values"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {model_input_name: inputs.get(model_input_name), "labels": list(batch["labels"])}

In [35]:
dataset = dataset.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
dataset = dataset.rename_column("audio", "input_values")

In [36]:
# calculate values for normalization
feature_extractor.do_normalize = False  # we set normalization to False in order to calculate the mean + std of the dataset
mean = []
std = []

# we use the transformation w/o augmentation on the training dataset to calculate the mean + std
dataset["train"].set_transform(preprocess_audio, output_all_columns=False)
for i, (audio_input, labels) in enumerate(dataset["train"]):
    cur_mean = torch.mean(dataset["train"][i][audio_input])
    cur_std = torch.std(dataset["train"][i][audio_input])
    mean.append(cur_mean)
    std.append(cur_std)

feature_extractor.mean = np.mean(mean)
feature_extractor.std = np.mean(std)
feature_extractor.do_normalize = True

print("Calculated mean and std:", feature_extractor.mean, feature_extractor.std)

Calculated mean and std: -3.3504603 4.387065


In [63]:
len(dataset['test'].features['labels'].names)

50

In [52]:
dataset['test'][0]['input_values']   #[0,0]

tensor([[-0.2134, -0.6156, -0.2232,  ..., -0.1845, -0.2843, -0.6573],
        [-0.0349, -0.3366,  0.0558,  ..., -0.2055, -0.3688, -0.6004],
        [-0.1835, -0.3470,  0.0454,  ..., -0.2269, -0.2812, -0.5483],
        ...,
        [ 0.3819,  0.3819,  0.3819,  ...,  0.3819,  0.3819,  0.3819],
        [ 0.3819,  0.3819,  0.3819,  ...,  0.3819,  0.3819,  0.3819],
        [ 0.3819,  0.3819,  0.3819,  ...,  0.3819,  0.3819,  0.3819]])

In [42]:
# Apply transforms
# dataset["train"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)

## Step 5: Configure and Initialize the AST for Fine-Tuning

In [64]:
import evaluate
from transformers import ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer

In [65]:
# Load configuration from the pretrained model
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
config.label2id = label2id
config.id2label = {v: k for k, v in label2id.items()}

In [66]:
# Initialize the model with the updated configuration
model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)
model.init_weights()

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Setup Metrics and Start Training

In [67]:
# Configure training arguments
training_args = TrainingArguments(
    output_dir=f"./runs/ast_classifier",
    logging_dir=f"./logs/ast_classifier",
    report_to="tensorboard",
    learning_rate=5e-5,  # LEARNING RATE
    push_to_hub=False,
    num_train_epochs=3,  # EPOCHS
    per_device_train_batch_size=8,  # BATCH SIZE
    eval_strategy="epoch",
    save_strategy="epoch",
    eval_steps=1,
    save_steps=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # eval_+metric ist utilized
    logging_strategy="steps",
    logging_steps=20,
)

In [68]:
# Define evaluation metrics
accuracy = evaluate.load("accuracy")
recall = evaluate.load("recall")
precision = evaluate.load("precision")
f1 = evaluate.load("f1")

AVERAGE = "macro" if config.num_labels > 2 else "binary"

# setup metrics function
def compute_metrics(eval_pred):
    # get predictions and scores
    logits = eval_pred.predictions
    predictions = np.argmax(logits, axis=1)

    # compute metrics
    metrics = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
    metrics.update(precision.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(recall.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))
    metrics.update(f1.compute(predictions=predictions, references=eval_pred.label_ids, average=AVERAGE))

    return metrics

Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 2.14MB/s]
Downloading builder script: 100%|██████████| 7.38k/7.38k [00:00<?, ?B/s]
Downloading builder script: 100%|██████████| 7.56k/7.56k [00:00<?, ?B/s]
Downloading builder script: 100%|██████████| 6.79k/6.79k [00:00<?, ?B/s]


In [69]:
# setup trainer
trainer = Trainer(
    model=model,
    args=training_args,  # we use our configured training arguments
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,  # we the metrics function from above
)

In [70]:
# start a training
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 