In [2]:
from datasets import load_dataset, Audio

# Base checkpoint: the Audio Spectrogram Transformer released by MIT.
model_name = "MIT/ast-finetuned-audioset-16-16-0.442"

# NOTE(review): `trust_remote_code=True` allows code shipped with the dataset
# repo to run locally -- keep it only for sources you trust.
# The "train" split is re-split into a seeded 90/10 train/test pair.
gtzan = load_dataset("marsyas/gtzan", "all", trust_remote_code=True)["train"].train_test_split(
    test_size=0.1,
    shuffle=True,
    seed=42,
)
gtzan


DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 899
    })
    test: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 100
    })
})

In [4]:
from transformers import AutoFeatureExtractor

# Feature extractor that matches the checkpoint. The keyword arguments override
# the values in its saved preprocessing config.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_name,
    return_attention_mask=False,
    do_normalize=True,
)
sampling_rate = feature_extractor.sampling_rate

# Resample the audio column to the rate the extractor expects.
gtzan = gtzan.cast_column("audio", Audio(sampling_rate=sampling_rate))

In [5]:
# Longest stretch of each clip (in seconds) that is passed to the extractor.
max_duration = 30.0


def preprocess_function(examples):
    """Convert a batch of decoded audio into model input features.

    Args:
        examples: batched `datasets` rows whose "audio" column holds dicts with
            an "array" waveform. The waveforms are already resampled to
            `feature_extractor.sampling_rate` by the `cast_column` call.

    Returns:
        The feature extractor's output for the batch.

    Waveforms are clipped to `max_duration` seconds here, explicitly, because
    the `max_length` / `truncation` call kwargs are not honoured by every
    extractor. `ASTFeatureExtractor` accepts them only through `**kwargs` and
    ignores them. It pads/truncates each spectrogram to its own
    `feature_extractor.max_length` frames instead -- presumably 1024 frames
    (~10 s) for this checkpoint; verify if the effective window matters.
    """
    max_samples = int(feature_extractor.sampling_rate * max_duration)
    audio_arrays = [x["array"][:max_samples] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=feature_extractor.sampling_rate,
        return_attention_mask=False,
    )
    return inputs

# Featurise both splits in batches of 100, dropping the raw audio and file-path
# columns, then rename the target column to "label".
gtzan_encoded = gtzan.map(
    preprocess_function,
    batched=True,
    batch_size=100,
    num_proc=1,
    remove_columns=["audio", "file"],
).rename_column("genre", "label")

# Class index (as a string key) -> genre name, plus the inverse mapping.
id2label_fn = gtzan["train"].features["genre"].int2str
id2label = dict(
    (str(idx), id2label_fn(idx))
    for idx in range(len(gtzan_encoded["train"].features["label"].names))
)
label2id = dict(zip(id2label.values(), id2label.keys()))

Map:   0%|          | 0/899 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [7]:
from transformers import AutoModelForAudioClassification

# The checkpoint's classifier head is sized for a different label set
# (527 outputs in the load warning). `ignore_mismatched_sizes` lets it load
# with a freshly initialised head sized to the GTZAN genres.
num_labels = len(id2label)

model = AutoModelForAudioClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-16-16-0.442 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from transformers import TrainingArguments
import torch

# NOTE(review): this rebinds the global `model_name` from the Hub id to its
# last path component. The push-to-hub cell uses the short form, but
# re-running the extractor/model cells after this one would try to load a
# local path that does not exist.
model_name = model_name.split("/")[-1]
batch_size = 20
gradient_accumulation_steps = 1
num_train_epochs = 10

# Decide mixed precision once. On a machine without CUDA,
# `not torch.cuda.is_bf16_supported()` evaluates to True and would request
# fp16, which TrainingArguments rejects off-GPU. Gate both flags on CUDA.
has_cuda = torch.cuda.is_available()
use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
use_fp16 = has_cuda and not use_bf16

training_args = TrainingArguments(
    f"{model_name}-finetuned-gtzan",
    # NOTE: newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=False,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    warmup_ratio=0.1,
    logging_steps=5,
    # 8-bit AdamW relies on bitsandbytes with a GPU; use torch AdamW otherwise.
    optim="adamw_8bit" if has_cuda else "adamw_torch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=use_fp16,
    bf16=use_bf16,
    bf16_full_eval=use_bf16,
    fp16_full_eval=use_fp16,
    push_to_hub=False,
)

In [10]:
import evaluate
import numpy as np
from transformers import Trainer

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass: arg-max class accuracy against the labels.

    Args:
        eval_pred: the prediction object passed in by Trainer. `.predictions`
            holds the logits (assumed shape (n_examples, n_classes) -- the
            arg-max runs over axis 1) and `.label_ids` holds the gold class ids.

    Returns:
        The result dict of the `evaluate` accuracy metric.
    """
    logits = eval_pred.predictions
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(references=eval_pred.label_ids, predictions=predicted_ids)


# Wire model, arguments, datasets and metric together. The feature extractor
# is passed as `tokenizer` so Trainer saves it alongside the model.
# NOTE: newer transformers releases rename `tokenizer=` to `processing_class=`.
trainer = Trainer(
    model,
    training_args,
    compute_metrics=compute_metrics,
    train_dataset=gtzan_encoded["train"],
    eval_dataset=gtzan_encoded["test"],
    tokenizer=feature_extractor,
)

In [11]:
# Fine-tune. With load_best_model_at_end=True and metric_for_best_model="accuracy",
# the best-accuracy epoch checkpoint is restored at the end; the returned
# TrainOutput is displayed as the cell result.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91
7,0.0152,0.348184,0.92


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91
7,0.0152,0.348184,0.92
8,0.0003,0.318657,0.94


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91
7,0.0152,0.348184,0.92
8,0.0003,0.318657,0.94
9,0.0003,0.325778,0.93


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91
7,0.0152,0.348184,0.92
8,0.0003,0.318657,0.94
9,0.0003,0.325778,0.93
10,0.0004,0.331511,0.93


Epoch,Training Loss,Validation Loss,Accuracy
1,0.8802,0.526695,0.85
2,0.3183,0.589314,0.81
3,0.1094,0.442113,0.89
4,0.0259,0.410036,0.88
5,0.0291,0.369513,0.9
6,0.0409,0.3071,0.91
7,0.0152,0.348184,0.92
8,0.0003,0.318657,0.94
9,0.0003,0.325778,0.93
10,0.0004,0.331511,0.93


TrainOutput(global_step=450, training_loss=0.21270850697066634, metrics={'train_runtime': 683.8016, 'train_samples_per_second': 13.147, 'train_steps_per_second': 0.658, 'total_flos': 6.056103807025152e+17, 'train_loss': 0.21270850697066634, 'epoch': 10.0})

In [12]:
# Model-card metadata for the Hub upload.
kwargs = {
    "dataset_tags": "marsyas/gtzan",
    "dataset": "GTZAN",
    "model_name": f"{model_name}-finetuned-gtzan",
    # `model_name` was cut down to its last path component in the
    # TrainingArguments cell, so it is no longer a valid Hub id. Take the base
    # checkpoint id recorded on the model config by `from_pretrained` instead
    # (expected to be "MIT/ast-finetuned-audioset-16-16-0.442" here).
    "finetuned_from": model.config.name_or_path,
    "tasks": "audio-classification",
}

trainer.push_to_hub(**kwargs)

Non-default generation parameters: {'max_length': 1024}


model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.05k [00:00<?, ?B/s]

events.out.tfevents.1713982047.LAPTOP-IV5HBHI1.55916.0:   0%|          | 0.00/27.6k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Ostixe360/ast-finetuned-audioset-16-16-0.442-finetuned-gtzan/commit/995dee18745a15568d1776b19efc2e9226bbca3c', commit_message='End of training', commit_description='', oid='995dee18745a15568d1776b19efc2e9226bbca3c', pr_url=None, pr_revision=None, pr_num=None)