### **Fine-tuning for Audio Classification with AST 🤗 Transformers**

In [1]:
model_checkpoint = "MIT/ast-finetuned-audioset-10-10-0.4593"
batch_size = 2

### Loading the dataset

In [2]:
from datasets import load_dataset, load_metric
from transformers import AutoProcessor, ASTConfig, ASTModel, AutoFeatureExtractor, ASTForAudioClassification, TrainingArguments, Trainer

import torch
import torch.nn as nn
import yaml
import os
import torch
import torchaudio
import numpy as np
import datasets

from datasets import ClassLabel

In [3]:
with open("./confs/default.yaml", "r") as f:
        configs = yaml.safe_load(f)

In [4]:
data_files = {"train": configs["data"]["synth_tsv"], "test": configs["data"]["synth_val_tsv"]}
raw_dataset = load_dataset("csv", data_files=data_files, sep = "\t")
metric = load_metric("accuracy")

Found cached dataset csv (/home/unegi/.cache/huggingface/datasets/csv/default-03475b778f293dce/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/2 [00:00<?, ?it/s]

  metric = load_metric("accuracy")


In [5]:
def process_data(example, audio_file_dir, num_samples, transformation, labels2id, target_sample_rate=16000):
    # This function will do what your __getitem__ method does
    
    # Load audio
    filepath = os.path.join(audio_file_dir, example['filename'])
    signal, sr = torchaudio.load(filepath)
    signal = signal.squeeze().numpy()
    #signal = signal.view(-1)
    label = example["event_label"]
    label_int = float(labels2id[label])
    example["label"] = label_int
    #example['label'] = torch.tensor(label_int)
    
    example["audio"] = {"array": np.array(signal), "path": filepath, "sampling_rate": sr}
    return example

In [6]:
from collections import OrderedDict


labels2id = OrderedDict(
    {
        "Alarm_bell_ringing": 0,
        "Blender": 1,
        "Cat": 2,
        "Dishes": 3,
        "Dog": 4,
        "Electric_shaver_toothbrush": 5,
        "Frying": 6,
        "Running_water": 7,
        "Speech": 8,
        "Vacuum_cleaner": 9,
    }
)

id2labels = {value: key for key, value in labels2id.items()}

In [8]:
SAMPLE_RATE = configs["data"]["fs"]
N_FFT = configs["feats"]["n_window"]
WIN_LENGTH = configs["feats"]["n_window"]
HOP_LENGTH = configs["feats"]["hop_length"]
F_MIN = configs["feats"]["f_min"]
F_MAX = configs["feats"]["f_max"]
N_MELS = configs["feats"]["n_mels"]
WINDOW_FN = torch.hamming_window
WKWARGS = {"periodic": False}
POWER = 1
NUM_SAMPLES = SAMPLE_RATE

In [8]:
partial_raw_train_dataset = raw_dataset["train"].select(range(10)).map(process_data, 
                                           fn_kwargs={'audio_file_dir': configs["data"]["synth_folder"], 
                                                      'num_samples': NUM_SAMPLES,
                                                       'transformation': None,
                                                        'labels2id': labels2id,
                                                        'target_sample_rate':16000, 
                                                        })

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [9]:
partial_raw_test_dataset = raw_dataset["test"].select(range(5)).map(process_data, 
                                           fn_kwargs={'audio_file_dir': configs["data"]["synth_folder"], 
                                                      'num_samples': NUM_SAMPLES,
                                                       'transformation': None,
                                                        'labels2id': labels2id,
                                                        'target_sample_rate':16000, 
                                                        })

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [9]:
raw_train_dataset = raw_dataset["train"].map(process_data, 
                                           fn_kwargs={'audio_file_dir': configs["data"]["synth_folder"], 
                                                      'num_samples': NUM_SAMPLES,
                                                       'transformation': None,
                                                        'labels2id': labels2id,
                                                        'target_sample_rate':16000, 
                                                        })

Loading cached processed dataset at /home/unegi/.cache/huggingface/datasets/csv/default-03475b778f293dce/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-324f9b638ea6090b.arrow


In [10]:
raw_test_dataset = raw_dataset["test"].map(process_data, 
                                           fn_kwargs={'audio_file_dir': configs["data"]["synth_folder"], 
                                                      'num_samples': NUM_SAMPLES,
                                                       'transformation': None,
                                                        'labels2id': labels2id,
                                                        'target_sample_rate':16000, 
                                                        })

Loading cached processed dataset at /home/unegi/.cache/huggingface/datasets/csv/default-03475b778f293dce/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-9dccc9695f7cf439.arrow


In [10]:
partial_raw_train_dataset.features

{'filename': Value(dtype='string', id=None),
 'onset': Value(dtype='float64', id=None),
 'offset': Value(dtype='float64', id=None),
 'event_label': Value(dtype='string', id=None),
 'label': Value(dtype='float64', id=None),
 'audio': {'array': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
  'path': Value(dtype='string', id=None),
  'sampling_rate': Value(dtype='int64', id=None)}}

In [11]:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint, device_map='cpu')

In [12]:
partial_raw_train_dataset_2 = partial_raw_train_dataset.remove_columns(["onset", "offset", "event_label"])

In [13]:
partial_raw_test_dataset_2 = partial_raw_test_dataset.remove_columns(["onset", "offset", "event_label"])

In [12]:
raw_train_dataset_2 = raw_train_dataset.remove_columns(["onset", "offset", "event_label"])

In [13]:
raw_test_dataset_2 = raw_test_dataset.remove_columns(["onset", "offset", "event_label"])

In [14]:
raw_train_dataset_2 = raw_train_dataset_2.cast_column("label", ClassLabel(names = list(labels2id.keys()) ))

Loading cached processed dataset at /home/unegi/.cache/huggingface/datasets/csv/default-03475b778f293dce/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-264ba74a4b672df3.arrow


In [15]:
raw_test_dataset_2 = raw_test_dataset_2.cast_column("label", ClassLabel(names = list(labels2id.keys()) ))

Loading cached processed dataset at /home/unegi/.cache/huggingface/datasets/csv/default-03475b778f293dce/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-79248f4d9aa294c1.arrow


In [14]:
partial_raw_train_dataset_2 = partial_raw_train_dataset_2.cast_column(
        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )
partial_raw_train_dataset_2 = partial_raw_train_dataset_2.cast_column("label", ClassLabel(names = list(labels2id.keys()) ))

Casting the dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

In [15]:
partial_raw_test_dataset_2 = partial_raw_test_dataset_2.cast_column(
        "audio", datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
    )
partial_raw_test_dataset_2 = partial_raw_test_dataset_2.cast_column("label", ClassLabel(names = list(labels2id.keys()) ))

Casting the dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

In [16]:
partial_raw_test_dataset_2.features

{'filename': Value(dtype='string', id=None),
 'label': ClassLabel(names=['Alarm_bell_ringing', 'Blender', 'Cat', 'Dishes', 'Dog', 'Electric_shaver_toothbrush', 'Frying', 'Running_water', 'Speech', 'Vacuum_cleaner'], id=None),
 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}

In [17]:
partial_raw_train_dataset_2.features["label"]

ClassLabel(names=['Alarm_bell_ringing', 'Blender', 'Cat', 'Dishes', 'Dog', 'Electric_shaver_toothbrush', 'Frying', 'Running_water', 'Speech', 'Vacuum_cleaner'], id=None)

We can see that the audio file has automatically been loaded. This is thanks to the new [`"Audio"` feature](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=audio#datasets.Audio) introduced in `datasets == 1.13.3`, which loads and resamples audio files on-the-fly upon calling.

The sampling rate is set to 16kHz which is what `Wav2Vec2` expects as an input.

In [19]:
import random
from IPython.display import Audio, display

for _ in range(1):
    rand_idx = random.randint(0, len(partial_raw_train_dataset_2)-1)
    example = partial_raw_train_dataset_2[rand_idx]
    audio = example["audio"]
    print(f'Label: {id2labels[(example["label"])]}')
    print(f'Shape: {audio["array"].shape}, sampling rate: {audio["sampling_rate"]}')
    display(Audio(audio["array"], rate=audio["sampling_rate"]))
    print()

Label: Speech
Shape: (160000,), sampling rate: 16000





### Preprocessing the data

In [16]:
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
feature_extractor

ASTFeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "ASTFeatureExtractor",
  "feature_size": 1,
  "max_length": 1024,
  "mean": -4.2677393,
  "num_mel_bins": 128,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "std": 4.5689974
}

In [17]:
max_duration = 3.0  # seconds

In [18]:
def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=int(feature_extractor.sampling_rate * max_duration),
        truncation=True,
    )
    return inputs

In [None]:
preprocess_function(raw_train_dataset_2[:5])

In [20]:
#encoded_dataset_train = partial_raw_train_dataset_2.map(preprocess_function, remove_columns=["audio", "filename"], batched=True)
encoded_dataset_train = raw_train_dataset_2.map(preprocess_function, remove_columns=["audio", "filename"], batched=True)
encoded_dataset_train

Map:   0%|          | 0/32096 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_values'],
    num_rows: 32096
})

In [21]:
#encoded_dataset_test = partial_raw_test_dataset_2.map(preprocess_function, remove_columns=["audio", "filename"], batched=True)
encoded_dataset_test = raw_test_dataset_2.map(preprocess_function, remove_columns=["audio", "filename"], batched=True)
encoded_dataset_test

Map:   0%|          | 0/8132 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_values'],
    num_rows: 8132
})

### Training the model

In [22]:
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

num_labels = len(id2labels)
model = AutoModelForAudioClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    label2id=labels2id,
    id2label=id2labels,
    ignore_mismatched_sizes=True
)


Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([10]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Set the device to CPU
device = torch.device("cpu")

In [24]:
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-ks",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

In [26]:
import numpy as np

def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

In [113]:
partial_raw_train_dataset_2

Dataset({
    features: ['filename', 'label', 'audio'],
    num_rows: 10
})

In [27]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset_train,
    eval_dataset=encoded_dataset_test,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

[codecarbon INFO @ 20:15:31] [setup] RAM Tracking...
[codecarbon INFO @ 20:15:31] [setup] GPU Tracking...
[codecarbon INFO @ 20:15:31] No GPU found.
[codecarbon INFO @ 20:15:31] [setup] CPU Tracking...
[codecarbon INFO @ 20:15:33] CPU Model on constant consumption mode: Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz
[codecarbon INFO @ 20:15:33] >>> Tracker's metadata:
[codecarbon INFO @ 20:15:33]   Platform system: Linux-5.19.0-43-generic-x86_64-with-glibc2.10
[codecarbon INFO @ 20:15:33]   Python version: 3.8.5
[codecarbon INFO @ 20:15:33]   CodeCarbon version: 2.2.4
[codecarbon INFO @ 20:15:33]   Available RAM : 7.623 GB
[codecarbon INFO @ 20:15:33]   CPU count: 8
[codecarbon INFO @ 20:15:33]   CPU model: Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz
[codecarbon INFO @ 20:15:33]   GPU count: None
[codecarbon INFO @ 20:15:33]   GPU model: None


In [28]:
trainer.train()

[codecarbon INFO @ 20:15:52] Energy consumed for RAM : 0.000012 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:15:52] Energy consumed for all CPUs : 0.000135 kWh. Total CPU Power : 32.5 W
[codecarbon INFO @ 20:15:52] 0.000147 kWh of electricity used since the beginning.
[codecarbon INFO @ 20:16:07] Energy consumed for RAM : 0.000024 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:16:07] Energy consumed for all CPUs : 0.000271 kWh. Total CPU Power : 32.5 W
[codecarbon INFO @ 20:16:07] 0.000295 kWh of electricity used since the beginning.


Epoch,Training Loss,Validation Loss


[codecarbon INFO @ 20:16:22] Energy consumed for RAM : 0.000036 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:16:22] Energy consumed for all CPUs : 0.000406 kWh. Total CPU Power : 32.5 W
[codecarbon INFO @ 20:16:22] 0.000442 kWh of electricity used since the beginning.
[codecarbon INFO @ 20:16:37] Energy consumed for RAM : 0.000048 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:16:37] Energy consumed for all CPUs : 0.000542 kWh. Total CPU Power : 32.5 W
[codecarbon INFO @ 20:16:37] 0.000589 kWh of electricity used since the beginning.
[codecarbon INFO @ 20:16:52] Energy consumed for RAM : 0.000060 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:16:52] Energy consumed for all CPUs : 0.000677 kWh. Total CPU Power : 32.5 W
[codecarbon INFO @ 20:16:52] 0.000737 kWh of electricity used since the beginning.
[codecarbon INFO @ 20:17:07] Energy consumed for RAM : 0.000071 kWh. RAM Power : 2.858745574951172 W
[codecarbon INFO @ 20:17:07] Energy consumed for a

In [33]:
trainer.evaluate()

{'eval_loss': 1.7789722681045532,
 'eval_accuracy': 0.4,
 'eval_runtime': 11.8693,
 'eval_samples_per_second': 0.421,
 'eval_steps_per_second': 0.253,
 'epoch': 4.0}

#### Prediction

In [52]:
idx = random.randint(0, len(partial_raw_train_dataset_2)-1)
os.listdir(configs["data"]["synth_folder"])[idx]

'183.wav'

In [53]:
test_file = os.path.join(configs["data"]["synth_folder"], os.listdir(configs["data"]["synth_folder"])[idx])
test_file

'../../data/dcase/dataset/dcase_synth/audio/train/synthetic21_train/soundscapes_16k/183.wav'

In [54]:
waveform, sampling_rate = torchaudio.load(test_file)
waveform = waveform.squeeze().numpy()

In [55]:
inputs = feature_extractor(waveform, sampling_rate=sampling_rate, padding="max_length", return_tensors="pt")
input_values = inputs.input_values

In [56]:
with torch.no_grad():
  outputs = model(input_values)

In [57]:
from IPython.display import Audio

Audio(test_file)

In [58]:
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: Frying
