In [1]:
# pip install librosa

In [2]:
# pip install soundfile

In [3]:
# pip install accelerate -U

In [4]:
# pip install wandb

In [5]:
# wandb login

In [6]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor
import evaluate
import numpy as np
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
from enum import Enum
import random
import torch
from pydub import AudioSegment
import os



In [7]:
class Model(Enum):
    FacebookWav2Vec2 = 1
    HUBERT = 2

In [8]:
SEED = 1
SPLIT_SILENCE = False
DATASET_PATH = 'train/audio'
LEARNING_RATE = 1e-4
PER_DEVICE_TRAIN_BATCH_SIZE = 32
GRADIENT_ACCUMULATION_STEPS = 4
PER_DEVICE_EVAL_BATCH_SIZE = 32
NUM_TRAIN_EPOCHS = 5
WARMUP_RATIO = 0.1
LOGGING_STEPS = 10
MODEL = Model.HUBERT
MODEL_NAMES = { Model.FacebookWav2Vec2: "Wav2Vec-LR", Model.HUBERT: "HUBERT-LR" }
MODEL_NAME = MODEL_NAMES[MODEL]

In [9]:
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7d4e367a43f0>

In [10]:
k = 0

def split_audio(file_path, output_folder, k):
    audio = AudioSegment.from_file(file_path)
    length_ms = len(audio)

    for i in range(0, length_ms, 1000):
        end = i + 1000

        if end > length_ms:
            end = length_ms

        chunk = audio[i:end]
        chunk_name = f"{output_folder}/chunk_{k:03d}.wav"
        chunk.export(chunk_name, format="wav")

        k += 1

    print(f"Audio split into {length_ms//1000} chunks.")

    return k

if SPLIT_SILENCE:
    for file in os.listdir(f"{DATASET_PATH}/_background_noise_/"):
        if file.endswith(".wav"):
            k = split_audio(f"{DATASET_PATH}/_background_noise_/{file}", f"{DATASET_PATH}/silence", k)

In [11]:
data = load_dataset(DATASET_PATH, split='train')
data = data.train_test_split(test_size=0.2, seed=SEED)

Resolving data files:   0%|          | 0/65123 [00:00<?, ?it/s]

In [12]:
data["train"][0]

{'audio': {'path': '/home/wojtek/Studia/dlm2/train/audio/marvin/3a789a0d_nohash_1.wav',
  'array': array([-0.02600098, -0.02432251, -0.02545166, ..., -0.02835083,
         -0.0284729 , -0.02923584]),
  'sampling_rate': 16000},
 'label': 12}

In [13]:
data['test'][0]

{'audio': {'path': '/home/wojtek/Studia/dlm2/train/audio/yes/ec74a8a5_nohash_1.wav',
  'array': array([-9.15527344e-05, -9.15527344e-05, -6.10351562e-05, ...,
         -5.79833984e-04, -2.44140625e-04, -3.66210938e-04]),
  'sampling_rate': 16000},
 'label': 29}

In [14]:
labels = data["train"].features["label"].names
labels

['bed',
 'bird',
 'cat',
 'dog',
 'down',
 'eight',
 'five',
 'four',
 'go',
 'happy',
 'house',
 'left',
 'marvin',
 'nine',
 'no',
 'off',
 'on',
 'one',
 'right',
 'seven',
 'sheila',
 'silence',
 'six',
 'stop',
 'three',
 'tree',
 'two',
 'up',
 'wow',
 'yes',
 'zero']

In [15]:
label2id, id2label = dict(), dict()

for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

id2label[str(2)]

'cat'

In [16]:
model_name = "facebook/wav2vec2-base" if MODEL == Model.FacebookWav2Vec2 else "facebook/hubert-base-ls960"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays, sampling_rate=feature_extractor.sampling_rate, max_length=16000, truncation=True
    )
    return inputs

data = data.map(preprocess_function, remove_columns="audio", batched=True)



In [17]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)

In [18]:
num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    model_name, num_labels=num_labels, label2id=label2id, id2label=id2label
)

training_args = TrainingArguments(
    output_dir=MODEL_NAME,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data["train"].with_format("torch"),
    eval_dataset=data["test"].with_format("torch"),
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

trainer.train()

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwojtek2288[0m ([33mwojtek2288-org[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2035 [00:00<?, ?it/s]

  return F.conv1d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


{'loss': 3.4345, 'grad_norm': 0.9672372937202454, 'learning_rate': 4.901960784313726e-06, 'epoch': 0.02}
{'loss': 3.4269, 'grad_norm': 0.7214069962501526, 'learning_rate': 9.803921568627451e-06, 'epoch': 0.05}
{'loss': 3.4193, 'grad_norm': 0.6514737606048584, 'learning_rate': 1.4705882352941177e-05, 'epoch': 0.07}
{'loss': 3.4044, 'grad_norm': 0.6940809488296509, 'learning_rate': 1.9607843137254903e-05, 'epoch': 0.1}
{'loss': 3.3728, 'grad_norm': 1.5090889930725098, 'learning_rate': 2.4509803921568626e-05, 'epoch': 0.12}
{'loss': 3.3045, 'grad_norm': 1.686798334121704, 'learning_rate': 2.9411764705882354e-05, 'epoch': 0.15}
{'loss': 3.1435, 'grad_norm': 7.129352569580078, 'learning_rate': 3.431372549019608e-05, 'epoch': 0.17}
{'loss': 2.9468, 'grad_norm': 7.47405481338501, 'learning_rate': 3.9215686274509805e-05, 'epoch': 0.2}
{'loss': 2.7463, 'grad_norm': 4.668152809143066, 'learning_rate': 4.411764705882353e-05, 'epoch': 0.22}
{'loss': 2.5484, 'grad_norm': 25.23488998413086, 'learnin

  0%|          | 0/408 [00:00<?, ?it/s]

{'eval_loss': 0.33710142970085144, 'eval_accuracy': 0.9484836852207293, 'eval_runtime': 72.3264, 'eval_samples_per_second': 180.086, 'eval_steps_per_second': 5.641, 'epoch': 1.0}
{'loss': 0.5797, 'grad_norm': 16.127525329589844, 'learning_rate': 8.874931731294374e-05, 'epoch': 1.01}
{'loss': 0.5315, 'grad_norm': 29.789026260375977, 'learning_rate': 8.820316766794102e-05, 'epoch': 1.03}
{'loss': 0.5597, 'grad_norm': 11.37763786315918, 'learning_rate': 8.765701802293829e-05, 'epoch': 1.06}
{'loss': 0.5641, 'grad_norm': 13.575957298278809, 'learning_rate': 8.711086837793556e-05, 'epoch': 1.08}
{'loss': 0.4855, 'grad_norm': 17.259918212890625, 'learning_rate': 8.656471873293282e-05, 'epoch': 1.1}
{'loss': 0.5415, 'grad_norm': 23.361732482910156, 'learning_rate': 8.60185690879301e-05, 'epoch': 1.13}
{'loss': 0.5183, 'grad_norm': 13.002607345581055, 'learning_rate': 8.547241944292737e-05, 'epoch': 1.15}
{'loss': 0.4989, 'grad_norm': 6.121351718902588, 'learning_rate': 8.492626979792464e-05, 

  0%|          | 0/408 [00:00<?, ?it/s]

{'eval_loss': 0.19000662863254547, 'eval_accuracy': 0.9642226487523993, 'eval_runtime': 71.8153, 'eval_samples_per_second': 181.368, 'eval_steps_per_second': 5.681, 'epoch': 2.0}
{'loss': 0.463, 'grad_norm': 9.59345531463623, 'learning_rate': 6.635718186783179e-05, 'epoch': 2.01}
{'loss': 0.3763, 'grad_norm': 9.673075675964355, 'learning_rate': 6.581103222282906e-05, 'epoch': 2.04}
{'loss': 0.3515, 'grad_norm': 6.468060493469238, 'learning_rate': 6.526488257782632e-05, 'epoch': 2.06}
{'loss': 0.3383, 'grad_norm': 9.938617706298828, 'learning_rate': 6.47187329328236e-05, 'epoch': 2.09}
{'loss': 0.303, 'grad_norm': 20.537630081176758, 'learning_rate': 6.417258328782087e-05, 'epoch': 2.11}
{'loss': 0.3531, 'grad_norm': 8.246866226196289, 'learning_rate': 6.362643364281814e-05, 'epoch': 2.14}
{'loss': 0.3341, 'grad_norm': 10.045408248901367, 'learning_rate': 6.30802839978154e-05, 'epoch': 2.16}
{'loss': 0.3425, 'grad_norm': 6.7844367027282715, 'learning_rate': 6.253413435281268e-05, 'epoch

  0%|          | 0/408 [00:00<?, ?it/s]

{'eval_loss': 0.11444877833127975, 'eval_accuracy': 0.9744337811900192, 'eval_runtime': 71.4791, 'eval_samples_per_second': 182.221, 'eval_steps_per_second': 5.708, 'epoch': 3.0}
{'loss': 0.2584, 'grad_norm': 7.418062210083008, 'learning_rate': 4.396504642271983e-05, 'epoch': 3.02}
{'loss': 0.2462, 'grad_norm': 5.906749725341797, 'learning_rate': 4.341889677771709e-05, 'epoch': 3.04}
{'loss': 0.2385, 'grad_norm': 4.689990043640137, 'learning_rate': 4.2872747132714365e-05, 'epoch': 3.07}
{'loss': 0.2456, 'grad_norm': 8.792654991149902, 'learning_rate': 4.232659748771163e-05, 'epoch': 3.09}
{'loss': 0.2717, 'grad_norm': 16.692411422729492, 'learning_rate': 4.1780447842708904e-05, 'epoch': 3.12}
{'loss': 0.2517, 'grad_norm': 4.360097885131836, 'learning_rate': 4.123429819770617e-05, 'epoch': 3.14}
{'loss': 0.2508, 'grad_norm': 10.163043022155762, 'learning_rate': 4.0688148552703444e-05, 'epoch': 3.17}
{'loss': 0.2286, 'grad_norm': 10.060879707336426, 'learning_rate': 4.014199890770071e-05

  0%|          | 0/408 [00:00<?, ?it/s]

{'eval_loss': 0.09795144200325012, 'eval_accuracy': 0.9781190019193858, 'eval_runtime': 72.2277, 'eval_samples_per_second': 180.333, 'eval_steps_per_second': 5.649, 'epoch': 4.0}
{'loss': 0.2322, 'grad_norm': 11.299735069274902, 'learning_rate': 2.2119060622610596e-05, 'epoch': 4.0}
{'loss': 0.1822, 'grad_norm': 3.855872392654419, 'learning_rate': 2.1572910977607866e-05, 'epoch': 4.03}
{'loss': 0.1994, 'grad_norm': 2.8331615924835205, 'learning_rate': 2.1026761332605136e-05, 'epoch': 4.05}
{'loss': 0.2436, 'grad_norm': 11.57017993927002, 'learning_rate': 2.0480611687602405e-05, 'epoch': 4.08}
{'loss': 0.2229, 'grad_norm': 5.952452659606934, 'learning_rate': 1.9934462042599672e-05, 'epoch': 4.1}
{'loss': 0.1869, 'grad_norm': 5.162848949432373, 'learning_rate': 1.9388312397596942e-05, 'epoch': 4.13}
{'loss': 0.2021, 'grad_norm': 9.352360725402832, 'learning_rate': 1.884216275259421e-05, 'epoch': 4.15}
{'loss': 0.1975, 'grad_norm': 4.194348335266113, 'learning_rate': 1.829601310759148e-05

  0%|          | 0/408 [00:00<?, ?it/s]

{'eval_loss': 0.08980436623096466, 'eval_accuracy': 0.9776583493282149, 'eval_runtime': 72.341, 'eval_samples_per_second': 180.05, 'eval_steps_per_second': 5.64, 'epoch': 5.0}
{'train_runtime': 5735.5395, 'train_samples_per_second': 45.417, 'train_steps_per_second': 0.355, 'train_loss': 0.5633986990340512, 'epoch': 5.0}


TrainOutput(global_step=2035, training_loss=0.5633986990340512, metrics={'train_runtime': 5735.5395, 'train_samples_per_second': 45.417, 'train_steps_per_second': 0.355, 'total_flos': 2.36389907860608e+18, 'train_loss': 0.5633986990340512, 'epoch': 4.996930632289748})