# Fine-Tuning the Audio Spectrogram Transformer (AST) for Audio Classification

This Jupyter Notebook provides a comprehensive guide for fine-tuning the Audio Spectrogram Transformer (AST) model on your own audio classification dataset using tools from the HuggingFace ecosystem and PyTorch. The notebook covers the entire workflow, including data loading, preprocessing, applying audio augmentations, configuring the model, and setting up the training process.

**Published:** 30.07.2024  
**Author:** Marius Steger  
**Email:** [marius.steger@renumics.com](mailto:marius.steger@renumics.com)  
**Organization:** [Renumics](https://renumics.com/)  

## Step 1: Install Required Packages
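Before we start, install the packages this notebook imports: `transformers` (with the `torch` extra), `datasets` (with the `audio` extra), `audiomentations`, and `evaluate`.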

In [16]:
# Core dependencies of this notebook; uncomment the line below to install them.
# !pip install transformers[torch] datasets[audio]


SyntaxError: invalid syntax (2312878907.py, line 2)

In [17]:
pip install audiomentations

Collecting audiomentations
  Downloading audiomentations-0.40.0-py3-none-any.whl.metadata (11 kB)
Collecting numpy<2,>=1.22.0 (from audiomentations)
  Using cached numpy-1.26.4-cp311-cp311-win_amd64.whl.metadata (61 kB)
Collecting numpy-minmax<1,>=0.3.0 (from audiomentations)
  Downloading numpy_minmax-0.4.0-cp311-cp311-win_amd64.whl.metadata (4.3 kB)
Collecting numpy-rms<1,>=0.4.2 (from audiomentations)
  Downloading numpy_rms-0.5.0-cp311-cp311-win_amd64.whl.metadata (3.6 kB)
Collecting librosa!=0.10.0,<0.11.0,>=0.8.0 (from audiomentations)
  Downloading librosa-0.10.2.post1-py3-none-any.whl.metadata (8.6 kB)
Collecting python-stretch<1,>=0.3.1 (from audiomentations)
  Downloading python_stretch-0.3.1-cp311-cp311-win_amd64.whl.metadata (3.7 kB)
INFO: pip is looking at multiple versions of numpy-minmax to determine which version is compatible with other requirements. This could take a while.
Collecting numpy-minmax<1,>=0.3.0 (from audiomentations)
  Downloading numpy_minmax-0.3.1-cp311

  You can safely remove it manually.
  You can safely remove it manually.


## Step 2: Load Your Data in the Correct Format

In [3]:
from datasets import Dataset, Audio, ClassLabel, Features, load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Template for building a dataset from your own audio files. It is kept commented
# out because the rest of this notebook loads ESC-50 from the Hub instead;
# uncomment it and adapt the class names and file paths to use local data.

# Define class labels
#class_labels = ClassLabel(names=["bang", "dog_bark"])

# Define features with audio and label columns
#features = Features({
#    "audio": Audio(),
#    "labels": class_labels
#})

# Load data (example with a dictionary)
#dataset = Dataset.from_dict({
#    "audio": ["/audio/fold1/7061-6-0-0.wav", "/audio/fold1/7383-3-0-0.wav"],
#    "labels": [0, 1],
#}, features=features)

In [45]:
# Pull the ESC-50 dataset from the HuggingFace Hub; only its "train" split is requested.
esc50 = load_dataset(
    "ashraq/esc50",
    split="train",
)

Repo card metadata block was not found. Setting CardData to empty.


## Step 3: Preprocess the Audio Data

In [46]:
import numpy as np
from datasets import Audio, ClassLabel
from transformers import ASTFeatureExtractor

In [47]:
# Build the list of class names ordered by integer target id.
# np.unique returns the sorted unique targets together with the row index of each
# target's first occurrence; those rows supply the matching category names.
# NOTE(review): using this list as ClassLabel names assumes the target ids are 0..N-1 — confirm for other datasets.
df = esc50.select_columns(["target", "category"]).to_pandas()
class_names = df.iloc[np.unique(df["target"], return_index=True)[1]]["category"].to_list()

# cast target and audio column
esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))

# rename the target feature
esc50 = esc50.rename_column("target", "labels")

# class_names holds exactly one entry per distinct target, so its length is the
# number of classes; this avoids reading the whole label column a second time.
num_labels = len(class_names)

In [48]:
# Checkpoint to fine-tune from; the feature extractor is loaded from the same checkpoint.
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

# Values read off the extractor and reused by the preprocessing functions.
SAMPLING_RATE = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]

In [49]:
# Preprocessing function
def preprocess_audio(batch):
    """Turn a batch of decoded audio into model inputs.

    Args:
        batch: Mapping with an "input_values" column whose entries expose the
            waveform under "array", and a "labels" column.

    Returns:
        Dict holding the extracted features under ``model_input_name`` and the
        batch labels as a list under "labels".
    """
    # No printing here: this function runs as an on-the-fly transform on every
    # dataset access, so any output would be repeated for each sample.
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(wavs, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {model_input_name: inputs.get(model_input_name), "labels": list(batch["labels"])}

In [50]:
# we use the esc50 train split for this tutorial on how to fine-tune the AST Model
dataset = esc50
# Mapping from class name (str) to class id (int). It is built from the public
# ClassLabel.names list, whose positions are the class ids, instead of reading
# the private `_str2int` attribute; this also yields an independent dict.
label2id = {name: idx for idx, name in enumerate(dataset.features["labels"].names)}

In [51]:
# split training data
# `dataset` is a single Dataset until it has been split; train_test_split turns
# it into a DatasetDict, so re-running this cell is then a no-op.
# A type check is used instead of `"test" not in dataset`: on a plain Dataset
# that expression is a membership test over the rows (iterating and decoding
# them), not a check for a split name.
if isinstance(dataset, Dataset):
    dataset = dataset.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")

In [52]:
# Display the resulting splits with their columns and row counts
dataset

DatasetDict({
    train: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'audio'],
        num_rows: 400
    })
})

## Step 4: Add Audio Augmentations

In [18]:
import torch
from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift

In [19]:
# Define audio augmentations
# Compose chains the transforms below; p=0.8 is the probability passed for the
# pipeline as a whole and shuffle=True asks Compose to randomize their order.
# NOTE(review): transforms without an explicit `p` use the library's default probability — confirm in the audiomentations docs.
audio_augmentations = Compose([
    AddGaussianSNR(min_snr_db=10, max_snr_db=20),  # added Gaussian noise, SNR range 10-20 dB
    Gain(min_gain_db=-6, max_gain_db=6),  # gain change within +/-6 dB
    GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit="fraction"),  # gain transition; duration bounds are given in "fraction" units
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),  # the only transform with an explicit per-transform probability
    TimeStretch(min_rate=0.8, max_rate=1.2),  # rate factor between 0.8 and 1.2
    PitchShift(min_semitones=-4, max_semitones=4),  # shift of up to 4 semitones in either direction
], p=0.8, shuffle=True)

In [20]:
# Preprocessing with augmentations
def preprocess_audio_with_transforms(batch):
    """Augment every waveform of the batch, then extract the model inputs.

    Reads the waveforms from the "array" field of each "input_values" entry,
    passes each one through ``audio_augmentations`` and hands the results to the
    feature extractor. Returns the extracted features under
    ``model_input_name`` together with the batch labels as a list.
    """
    augmented = []
    for entry in batch["input_values"]:
        augmented.append(audio_augmentations(entry["array"], sample_rate=SAMPLING_RATE))

    extracted = feature_extractor(augmented, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    model_inputs = extracted.get(model_input_name)
    return {model_input_name: model_inputs, "labels": list(batch["labels"])}

In [21]:
# Decode the audio at the feature extractor's sampling rate, then expose the
# column under the name the preprocessing functions read from.
dataset = (
    dataset
    .cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
    .rename_column("audio", "input_values")
)

In [22]:
# calculate values for normalization
feature_extractor.do_normalize = False  # we set normalization to False in order to calculate the mean + std of the dataset
mean = []
std = []

# we use the transformation w/o augmentation on the training dataset to calculate the mean + std
dataset["train"].set_transform(preprocess_audio, output_all_columns=False)

# Walk the transformed split once and read the features by their key.
# Each iteration already runs the transform, so indexing the dataset again inside
# the loop would decode and feature-extract the same clip several times.
for sample in dataset["train"]:
    features = sample[model_input_name]
    mean.append(torch.mean(features))
    std.append(torch.std(features))

# Dataset-level statistics are the averages of the per-sample mean and std.
feature_extractor.mean = np.mean(mean)
feature_extractor.std = np.mean(std)
feature_extractor.do_normalize = True

print("Calculated mean and std:", feature_extractor.mean, feature_extractor.std)

Calculated mean and std: -3.3504603 4.387065


In [43]:
# Attach the on-the-fly transforms: the test split is only preprocessed,
# the train split additionally goes through the augmentation pipeline.
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)
dataset["train"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)

In [44]:
# Display the splits again; the audio column now appears as "input_values"
dataset

DatasetDict({
    train: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'input_values'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['filename', 'fold', 'labels', 'category', 'esc10', 'src_file', 'take', 'input_values'],
        num_rows: 400
    })
})

## Step 5: Configure and Initialize the AST for Fine-Tuning

In [24]:
import evaluate
from transformers import ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer

In [25]:
# Start from the checkpoint's configuration and swap in our own label set
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
config.label2id = label2id
# id2label is label2id with keys and values swapped
config.id2label = {class_id: class_name for class_name, class_id in label2id.items()}

In [26]:
# Initialize the model with the updated configuration
# ignore_mismatched_sizes=True lets the checkpoint load although its classifier
# head (527 outputs) does not match the head built from our config; the loader
# reports those classifier weights as newly initialized.
model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)
# NOTE(review): the loader already initializes the mismatched classifier weights;
# confirm that this extra call leaves the pretrained weights untouched in the
# installed transformers version.
model.init_weights()

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Setup Metrics and Start Training

In [None]:
# FAST TRAINING SETUP
training_args = TrainingArguments(
    # --- output locations and reporting ---
    output_dir="./runs/ast_classifier_quick",
    logging_dir="./logs/ast_classifier_quick",
    report_to="none",  # no experiment tracker
    push_to_hub=False,
    disable_tqdm=False,
    # --- optimisation: one short pass with small batches ---
    learning_rate=1e-4,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    # --- evaluate every 10 steps, log every 5 steps ---
    eval_strategy="steps",
    eval_steps=10,
    logging_strategy="steps",
    logging_steps=5,
    # --- no checkpoints are written, so no best model is reloaded ---
    save_strategy="no",
    load_best_model_at_end=False,
)


In [28]:
# Load the metrics that compute_metrics reports
accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall = evaluate.load("recall")
f1 = evaluate.load("f1")

# "binary" averaging for up to two labels, "macro" averaging for more
AVERAGE = "binary" if config.num_labels <= 2 else "macro"

# setup metrics function
def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall and F1 for one evaluation pass.

    The predicted class of each sample is the argmax over axis 1 of
    ``eval_pred.predictions``; it is compared against ``eval_pred.label_ids``.
    Precision, recall and F1 are computed with the ``AVERAGE`` mode.
    Returns a single dict merging the results of all four metrics.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    true_ids = eval_pred.label_ids

    results = accuracy.compute(predictions=predicted_ids, references=true_ids)
    # the remaining metrics all take the averaging mode
    for averaged_metric in (precision, recall, f1):
        results.update(
            averaged_metric.compute(predictions=predicted_ids, references=true_ids, average=AVERAGE)
        )
    return results

In [35]:
# setup trainer
trainer = Trainer(
    model=model,
    args=training_args,  # we use our configured training arguments
    train_dataset=dataset["train"],  # split with the augmenting transform attached
    eval_dataset=dataset["test"],  # split with the plain preprocessing transform attached
    compute_metrics=compute_metrics,  # we use the metrics function from above
)

In [36]:
# start the training run with the model, arguments, datasets and metrics configured on `trainer`
trainer.train()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 