In [1]:
import os
from deep_utils import warmup_cosine
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

In [2]:
# Load the feature extractor that ships with the pretrained wav2vec2 checkpoint;
# it is reused below for preprocessing, batching in the Trainer, and inference.
checkpoint_name = "facebook/wav2vec2-base"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_name)
feature_extractor



Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}

In [5]:
# CSV manifests for the two splits; each row carries an `audio_path` and a `label`.
train_path = "../data/train_gender.csv"
test_path = '../data/test_gender.csv'
split_files = {'train': train_path, 'test': test_path}
dataset = load_dataset('csv', data_files=split_files)

# Turn the path column into an Audio feature decoded at 16 kHz.
audio_feature = Audio(sampling_rate=16_000)
dataset = dataset.cast_column("audio_path", audio_feature)
dataset["train"][0]

Using custom data configuration default-72d9a721b9c4e5f6
Reusing dataset csv (/home/ai/.cache/huggingface/datasets/csv/default-72d9a721b9c4e5f6/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/2 [00:00<?, ?it/s]

{'audio_path': {'path': '/home/ai/projects/speech/dataset/irancel-voice-dataset/new-raw-dataset/samples_02/samples_02_01/wav_files/0845283_003_00_S00_F.wav',
  'array': array([  -0.060793,   -0.062068,   -0.062516, ...,   -0.077135,   -0.077294,    -0.07786], dtype=float32),
  'sampling_rate': 16000},
 'label': 'female'}

In [7]:
import random
import IPython.display as ipd
import librosa

# Pick a random training example and play it back as a quick sanity check.
# randrange excludes the upper bound; random.randint(0, n) could return n,
# which is one past the last valid row and raises IndexError.
index = random.randrange(len(dataset['train']))

# Fetch the row once: every indexing of the dataset decodes the audio again.
sample = dataset['train'][index]
path = sample['audio_path']['path']
waveform, sr = librosa.load(path)
text = sample['label']
print(text)
ipd.Audio(waveform, rate=sr, autoplay=True)

female


In [8]:
# Build the label <-> id lookup tables from the labels present in the train split.
# Sorting makes the assignment deterministic: iterating a bare set of strings can
# yield a different order in each kernel session, which would silently change
# which class each id refers to between training and later inference runs.
labels = sorted(set(dataset["train"]['label']))
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    # Ids are stored as strings here; consumers convert with int()/str() as needed.
    label2id[label] = str(i)
    id2label[str(i)] = label
label2id

{'female': '0', 'male': '1'}

In [9]:
def preprocess_function(examples):
    """Convert a batch of dataset rows into model inputs.

    Runs the raw waveforms from the ``audio_path`` column through the
    module-level ``feature_extractor`` (truncating to at most 16000 samples)
    and attaches integer class ids looked up in ``label2id``.

    Args:
        examples: batched rows with an ``audio_path`` column (decoded audio
            dicts exposing ``"array"``) and a ``label`` column.

    Returns:
        The feature extractor output with an extra ``"label"`` entry.
    """
    waveforms = []
    for audio in examples["audio_path"]:
        waveforms.append(audio["array"])

    features = feature_extractor(
        waveforms,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    # label2id stores ids as strings, so cast back to int for training.
    features["label"] = [int(label2id[name]) for name in examples["label"]]
    return features

In [10]:
# Apply the preprocessing to every split in batches and drop the raw audio column,
# leaving only the extracted inputs and the integer label.
encoded_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["audio_path"],
)
encoded_dataset['train'][0]

  0%|          | 0/3 [00:00<?, ?ba/s]

  tensor = as_tensor(value)


  0%|          | 0/1 [00:00<?, ?ba/s]

{'label': 0,
 'input_values': [-0.301716685295105,
  -0.34659481048583984,
  -0.36235311627388,
  -0.336408406496048,
  -0.27883854508399963,
  -0.22755521535873413,
  -0.20778848230838776,
  -0.22426968812942505,
  -0.2682577669620514,
  -0.31038370728492737,
  -0.33207207918167114,
  -0.3285040259361267,
  -0.30304965376853943,
  -0.27730026841163635,
  -0.2620263993740082,
  -0.2584722638130188,
  -0.268401563167572,
  -0.2824321687221527,
  -0.2997622787952423,
  -0.31838512420654297,
  -0.32690373063087463,
  -0.32586851716041565,
  -0.3136153519153595,
  -0.2931865155696869,
  -0.27628380060195923,
  -0.2629943788051605,
  -0.2589307427406311,
  -0.26843133568763733,
  -0.28585192561149597,
  -0.31021708250045776,
  -0.3278941214084625,
  -0.32716110348701477,
  -0.31459105014801025,
  -0.2982427775859833,
  -0.2947500944137573,
  -0.3089294731616974,
  -0.3239307701587677,
  -0.3295852541923523,
  -0.3178650140762329,
  -0.29343417286872864,
  -0.2734445631504059,
  -0.260920017

In [11]:
def compute_metrics(p):
    """Score a set of predictions for the Trainer.

    Args:
        p: a pair that unpacks into ``(predictions, labels)``; the predicted
            class is taken as the argmax of ``predictions`` along axis 1.

    Returns:
        dict with ``accuracy`` plus weighted-average ``f1-score``,
        ``recall-score`` and ``precision-score``.
    """
    scores, references = p
    predicted_ids = np.argmax(scores, axis=1)

    metrics = {"accuracy": accuracy_score(references, predicted_ids)}
    weighted_scorers = {
        "f1-score": f1_score,
        "recall-score": recall_score,
        "precision-score": precision_score,
    }
    for metric_name, scorer in weighted_scorers.items():
        metrics[metric_name] = scorer(references, predicted_ids, average="weighted")
    return metrics

In [13]:
import math
import torch
from transformers import EarlyStoppingCallback

# Stop training once the monitored metric (eval loss, see metric_for_best_model
# below) has not improved for 5 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

train_bs = 32
epochs = 25
lr = 5e-5
lrf = lr
output_dir = "./results"
# Number of optimizer steps for the full run: ceil(rows / batch size) per epoch
# times the number of epochs. Assumes one device and no gradient accumulation.
total_steps = int((np.ceil(encoded_dataset["train"].num_rows / train_bs) * epochs))

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base",
    num_labels=num_labels,
    # The notebook-level mappings keep ids as strings; hand the model config
    # integer ids so the saved config maps label -> int and int -> label.
    label2id={name: int(idx) for name, idx in label2id.items()},
    id2label={int(idx): name for idx, name in id2label.items()},
)

training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=epochs,
    report_to="tensorboard",
    load_best_model_at_end=True,
    save_total_limit=1,
    metric_for_best_model='loss',
    per_device_train_batch_size = train_bs,
    per_device_eval_batch_size = 64,
    logging_steps=1,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Warm up over the first 10% of the steps, then follow a cosine schedule.
# NOTE(review): LambdaLR multiplies the base lr by the callable's return value;
# presumably warmup_cosine returns such a multiplier -- verify in deep_utils.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine(total_steps//10,
                                                                       max_lr=lr,
                                                                       total_steps=total_steps,
                                                                       optimizer_lr=lr,
                                                                       min_lr=1e-6))
# reduce lr with a cosine annealing if total_steps is set to total_steps
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, scheduler),
    # The callback was previously created but never registered, so early
    # stopping never took effect and all epochs always ran.
    callbacks=[early_stopping],
)

trainer.train()
# load_best_model_at_end=True means the weights saved here are the best checkpoint's.
trainer.save_model(os.path.join(output_dir, "best"))

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['quantizer.weight_proj.weight', 'project_q.bias', 'project_hid.weight', 'project_hid.bias', 'project_q.weight', 'quantizer.codevectors', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['projector.weight', 'classifier.weight', 'projecto

Epoch,Training Loss,Validation Loss,Accuracy,F1-score,Recall-score,Precision-score
1,0.4252,0.527729,0.646976,0.511197,0.646976,0.771954
2,0.2949,0.208912,0.946554,0.946251,0.946554,0.946541
3,0.3302,0.416463,0.887482,0.881644,0.887482,0.903025
4,0.0944,0.199156,0.945148,0.944322,0.945148,0.947085
5,0.1522,0.186297,0.954993,0.954644,0.954993,0.955373
6,0.3935,0.187644,0.95218,0.951956,0.95218,0.952147
7,0.1658,0.176772,0.956399,0.956039,0.956399,0.956898
8,0.0081,0.19694,0.954993,0.955032,0.954993,0.955088
9,0.1856,0.238893,0.912799,0.913863,0.912799,0.919218
10,0.2357,0.206266,0.949367,0.94913,0.949367,0.949306


***** Running Evaluation *****
  Num examples = 711
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-89
Configuration saved in ./results/checkpoint-89/config.json
Model weights saved in ./results/checkpoint-89/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-89/preprocessor_config.json
Deleting older checkpoint [results/checkpoint-1869] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 711
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-178
Configuration saved in ./results/checkpoint-178/config.json
Model weights saved in ./results/checkpoint-178/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-178/preprocessor_config.json
Deleting older checkpoint [results/checkpoint-2225] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 711
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-267
Configuration saved in ./results/checkpoint-267/config.json
Model weig

  Batch size = 64
Saving model checkpoint to ./results/checkpoint-1869
Configuration saved in ./results/checkpoint-1869/config.json
Model weights saved in ./results/checkpoint-1869/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-1869/preprocessor_config.json
Deleting older checkpoint [results/checkpoint-1780] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 711
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-1958
Configuration saved in ./results/checkpoint-1958/config.json
Model weights saved in ./results/checkpoint-1958/pytorch_model.bin
Feature extractor saved in ./results/checkpoint-1958/preprocessor_config.json
Deleting older checkpoint [results/checkpoint-1869] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 711
  Batch size = 64
Saving model checkpoint to ./results/checkpoint-2047
Configuration saved in ./results/checkpoint-2047/config.json
Model weights saved in ./results/checkpoint-2047

In [None]:
import torchaudio  # no longer used in this cell; kept in case later cells rely on it
import torch
import librosa

# Classify a single audio file with the fine-tuned model.
device = "cpu"
model = model.to(device)
# After trainer.train() the model is left in training mode; switch to eval so
# dropout and training-time masking do not perturb the prediction.
model.eval()

# Let librosa resample straight to the rate the feature extractor expects,
# instead of loading at librosa's default rate and resampling a second time.
waveform, sr = librosa.load("../audio_samples/man_02.mp4", sr=feature_extractor.sampling_rate)

# Pass the mono waveform as a 1-D numpy array: a 2-D torch tensor is not treated
# as a single waveform by the extractor, so the 16000-sample truncation used
# during training would not be applied to it.
inputs = feature_extractor(waveform, sampling_rate=sr,
                           max_length=16000, truncation=True, return_tensors="pt")
input_values = inputs['input_values'].to(device)  # shape (1, num_samples)
with torch.no_grad():
    output = model(input_values)
    logits = output['logits'][0]
    label_id = torch.argmax(logits).item()
# id2label is keyed by string ids in this notebook.
label_name = id2label[str(label_id)]
print(label_name)