In [46]:
import os
import torchaudio
from torch.utils.data import Dataset, DataLoader
import random
import torch.nn as nn
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import Wav2Vec2Model, BertConfig, BertModel
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm


class SpeechCommandsDataset(Dataset):
    """Map-style dataset over a Speech Commands style directory tree.

    Every sub-directory of ``dataset_path`` whose name does not start with
    ``'_'`` is one keyword class, and every ``.wav`` file inside it is one
    sample. Label index 0 is reserved for ``unknown_label``; keyword classes
    get indices 1, 2, ... in sorted directory-name order. Files are also
    visited in sorted order so sample indices are reproducible across machines.

    Args:
        dataset_path: Root directory containing one sub-directory per word.
        background_noise_path: Not read by this class; kept for backward
            compatibility.
        segment_length: Not read by this class; kept for backward compatibility.
        unknown_label: Name registered at index 0 of ``label_to_index``.
        target_background_samples: Stored on the instance; this class adds no
            background samples itself.
        selected_words: Optional whitelist of word directories. ``None`` or an
            empty list means every word directory is used.

    ``audio_paths`` normally holds file paths. ``__getitem__`` also accepts
    ``(path, start_frame, end_frame, ...)`` tuples appended by callers and
    loads only that frame range.
    """

    def __init__(self, dataset_path, background_noise_path, segment_length=1, unknown_label='unknown',
                 target_background_samples=2375, selected_words=None):
        self.selected_words = selected_words if selected_words is not None else []
        self.audio_labels = []
        self.audio_paths = []
        self.label_to_index = {unknown_label: 0}
        self.total_size_bytes = 0
        self.target_background_samples = target_background_samples
        self.background_samples_added = 0

        for label in sorted(os.listdir(dataset_path)):
            # '_'-prefixed directories (e.g. '_background_noise_') are not keyword classes.
            if label.startswith('_') or (self.selected_words and label not in self.selected_words):
                continue
            label_path = os.path.join(dataset_path, label)
            if not os.path.isdir(label_path):
                continue
            label_index = len(self.label_to_index)
            self.label_to_index[label] = label_index

            # os.listdir order is filesystem-dependent; sort so sample indices are reproducible.
            for audio_file in sorted(os.listdir(label_path)):
                if not audio_file.endswith('.wav'):
                    continue
                file_path = os.path.join(label_path, audio_file)
                self.audio_paths.append(file_path)
                self.audio_labels.append(label_index)
                self.total_size_bytes += os.path.getsize(file_path)

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(waveform, label_index)``.

        The loaded waveform has singleton dimensions squeezed away (a mono
        ``(1, N)`` file becomes ``(N,)``). Tuple entries are read as the frame
        range ``[start_frame, end_frame)`` of ``path``.
        """
        audio_info = self.audio_paths[idx]
        label = self.audio_labels[idx]
        if isinstance(audio_info, tuple):
            path, start, end = audio_info[:3]
            waveform, _ = torchaudio.load(path, frame_offset=start, num_frames=end - start)
        else:
            waveform, _ = torchaudio.load(audio_info)
        waveform = waveform.squeeze()
        if waveform.numel() == 0:  # Check if the waveform is empty
            print(f"Empty waveform at idx: {idx}, path: {audio_info}")
        return waveform, label

In [44]:
class AudioClassifier(nn.Module):
    """Keyword classifier: frozen wav2vec2 encoder -> 1-layer BERT encoder -> linear head.

    The optimizer, loss and LR scheduler are built here so ``fit`` can be
    called directly after construction.

    Args:
        wav2vec_model_name: Name or path passed to ``Wav2Vec2Model.from_pretrained``.
        num_labels: Number of output classes.
        learning_rate: Initial Adam learning rate.
        weight_decay: Adam ``weight_decay`` (L2 penalty).
    """

    def __init__(self, wav2vec_model_name, num_labels, learning_rate=0.001, weight_decay=0.01):
        super(AudioClassifier, self).__init__()
        self.wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_name)
        # forward() runs the encoder under no_grad, so it never trains: make the
        # freeze explicit and keep its parameters out of the optimizer.
        self.wav2vec.requires_grad_(False)

        transformer_config = BertConfig(
            hidden_size=self.wav2vec.config.hidden_size,
            num_attention_heads=self.wav2vec.config.num_attention_heads,
            num_hidden_layers=1,
        )
        self.transformer = BertModel(transformer_config)
        self.classifier = nn.Linear(self.wav2vec.config.hidden_size, num_labels)

        trainable_params = [p for p in self.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()
        # lr_lambda's argument is an epoch count: the LR is multiplied by 0.7
        # every 5 epochs. fit() therefore steps the scheduler once per epoch.
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.7 ** (epoch // 5))

    def forward(self, input_values):
        """Return class logits for a batch of raw waveforms.

        Args:
            input_values: Padded waveform batch, presumably ``(batch, num_samples)``
                at the sample rate the wav2vec checkpoint expects — TODO confirm.

        Returns:
            Logits of shape ``(batch, num_labels)``.
        """
        with torch.no_grad():
            extracted_features = self.wav2vec(input_values).last_hidden_state
        transformer_output = self.transformer(inputs_embeds=extracted_features)
        # No [CLS] token is prepended: this is the encoder output at the first
        # audio frame (which attends to every frame). Zero padding is not masked.
        cls_output = transformer_output.last_hidden_state[:, 0, :]
        return self.classifier(cls_output)

    def fit(self, train_loader, epochs, device):
        """Train the non-frozen layers for ``epochs`` passes over ``train_loader``.

        Args:
            train_loader: Iterable of ``(input_values, labels)`` batches.
            epochs: Number of passes over the data.
            device: Device the model and each batch are moved to.

        Returns:
            List with the mean training loss of each epoch.
        """
        self.to(device)
        self.train()
        # The frozen encoder stays in eval mode so any dropout inside it does
        # not perturb the features fed to the trainable layers.
        self.wav2vec.eval()

        epoch_losses = []
        for epoch in range(epochs):
            running_loss, num_batches = 0.0, 0
            # tqdm is used here for the progress bar in the epochs loop
            with tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as tepoch:
                for input_values, labels in tepoch:
                    input_values = input_values.to(device)
                    labels = labels.to(device)

                    loss = self.criterion(self(input_values), labels)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    running_loss += loss.item()
                    num_batches += 1
                    tepoch.set_postfix(loss=loss.item())
            # Once per epoch, matching lr_lambda's epoch-based schedule.
            self.scheduler.step()
            epoch_losses.append(running_loss / max(num_batches, 1))
        return epoch_losses

In [47]:
DATASET_PATH = './train/train/audio'
BACKGROUND_NOISE_PATH = f'{DATASET_PATH}/_background_noise_'

dataset = SpeechCommandsDataset(DATASET_PATH, BACKGROUND_NOISE_PATH)
# NOTE: no collate_fn here, so default batching would fail if clip lengths differ;
# training below goes through `train_loader`, which pads.
loader = DataLoader(dataset, shuffle=True, batch_size=32)

# Sanity check: number of samples labelled 'unknown', then 'yes'.
for _word in ('unknown', 'yes'):
    print(dataset.audio_labels.count(dataset.label_to_index[_word]))

0
2377


In [48]:
import os
import re
import hashlib

MAX_NUM_WAVS_PER_CLASS = 2**27 - 1  # size of the hash space mapped onto [0, 100]


def which_set(filename, validation_percentage, testing_percentage):
    """Deterministically assign a file to 'training', 'validation' or 'testing'.

    Everything from ``_nohash_`` onwards is stripped from the basename before
    hashing, so all files sharing that prefix always land in the same split.

    Args:
        filename: Path of the audio file; only its basename is used.
        validation_percentage: Share (0-100) of the hash range sent to validation.
        testing_percentage: Share (0-100) of the hash range sent to testing.

    Returns:
        One of ``'training'``, ``'validation'``, ``'testing'``.
    """
    group_key = re.sub(r'_nohash_.*$', '', os.path.basename(filename))
    digest = int(hashlib.sha1(group_key.encode('utf-8')).hexdigest(), 16)
    bucket = digest % (MAX_NUM_WAVS_PER_CLASS + 1)
    percentage_hash = bucket * (100.0 / MAX_NUM_WAVS_PER_CLASS)

    if percentage_hash < validation_percentage:
        return 'validation'
    if percentage_hash < (testing_percentage + validation_percentage):
        return 'testing'
    return 'training'

validation_percentage = 10.0
testing_percentage = 10.0

# Route every file to its hash-determined split; which_set only returns these three keys.
files_by_split = {'training': [], 'validation': [], 'testing': []}
for file_path in dataset.audio_paths:
    files_by_split[which_set(file_path, validation_percentage, testing_percentage)].append(file_path)

train_files = files_by_split['training']
validation_files = files_by_split['validation']
test_files = files_by_split['testing']

print(len(train_files), len(validation_files), len(test_files))

51088 6798 6835


In [49]:
print(*(len(split_files) for split_files in (train_files, validation_files, test_files)))

51088 6798 6835


In [50]:
import random

# Keep every 'yes' and 'no' clip; down-sample 'unknown' (all other words) to the
# number of 'yes' clips in the same split.
SPLIT_SEED = 0  # fixed so the sampled 'unknown' subsets and the shuffles are reproducible
split_rng = random.Random(SPLIT_SEED)

# Get label indices for 'yes' and 'no'
yes_index = dataset.label_to_index['yes']
no_index = dataset.label_to_index['no']

# paths_by_split[set_type][group] -> list of file paths
paths_by_split = {
    set_type: {'yes': [], 'no': [], 'unknown': []}
    for set_type in ('training', 'validation', 'testing')
}
for file_path, label in zip(dataset.audio_paths, dataset.audio_labels):
    set_type = which_set(file_path, validation_percentage, testing_percentage)
    if label == yes_index:
        group = 'yes'
    elif label == no_index:
        group = 'no'
    else:
        group = 'unknown'
    paths_by_split[set_type][group].append(file_path)

train_yes = paths_by_split['training']['yes']
train_no = paths_by_split['training']['no']
train_unknown = paths_by_split['training']['unknown']
validation_yes = paths_by_split['validation']['yes']
validation_no = paths_by_split['validation']['no']
validation_unknown = paths_by_split['validation']['unknown']
test_yes = paths_by_split['testing']['yes']
test_no = paths_by_split['testing']['no']
test_unknown = paths_by_split['testing']['unknown']

num_yes_train = len(train_yes)
num_yes_test = len(test_yes)
num_yes_validation = len(validation_yes)

# dataset.audio_paths is grouped by word in sorted directory order, so every
# 'unknown' list must be shuffled before truncation; otherwise the kept subset
# would contain only the first few words alphabetically.
for unknown_paths in (train_unknown, validation_unknown, test_unknown):
    split_rng.shuffle(unknown_paths)
train_unknown = train_unknown[:num_yes_train]
validation_unknown = validation_unknown[:num_yes_validation]
test_unknown = test_unknown[:num_yes_test]

train_files = train_yes + train_no + train_unknown
validation_files = validation_yes + validation_no + validation_unknown
test_files = test_yes + test_no + test_unknown
for split_files in (train_files, validation_files, test_files):
    split_rng.shuffle(split_files)

len(train_files), len(validation_files), len(test_files)

(5573, 792, 764)

In [55]:
import os
import random
import torchaudio

import os
import random
import torchaudio


def process_and_distribute_background_noise(background_noise_path, segment_length, target_background_samples,
                                            num_train_samples, num_val_samples, num_test_samples, rng=None):
    """Draw random fixed-length background-noise segments and split them into train/val/test.

    Each segment is drawn by picking a noise file uniformly among the ``.wav``
    files that hold at least one full segment, then picking a start offset
    uniformly. Segments are returned as references
    ``(noise_path, start_frame, end_frame, label_index)``; no audio is sliced
    here. The label is always 0, the index of 'unknown'.

    Args:
        background_noise_path: Directory holding the background noise ``.wav`` files.
        segment_length: Segment duration in seconds (multiplied by each file's sample rate).
        target_background_samples: Total number of segments to draw.
        num_train_samples: Size of the returned training split.
        num_val_samples: Size of the returned validation split.
        num_test_samples: Size of the returned test split.
        rng: Optional ``random.Random``-like object for reproducible sampling;
            defaults to the global ``random`` module.

    Returns:
        ``(train_segments, val_segments, test_segments)`` lists.

    Raises:
        ValueError: if the three split sizes add up to more than
            ``target_background_samples``, or if segments are requested but
            no noise file is long enough to supply one.
    """
    rng = random if rng is None else rng
    unknown_label = 0  # 'unknown' is index 0 of SpeechCommandsDataset.label_to_index

    total_required_samples = num_train_samples + num_val_samples + num_test_samples
    # Exactly target_background_samples segments are drawn, so this can be checked before any I/O.
    if total_required_samples > target_background_samples:
        raise ValueError("Requested number of samples exceeds the total number of collected samples.")

    # Decode each file once to learn its length instead of once per drawn segment.
    candidates = []  # (noise_path, samples_per_segment, max_start)
    noise_files = sorted(f for f in os.listdir(background_noise_path) if f.endswith('.wav'))
    for noise_file in noise_files:
        noise_path = os.path.join(background_noise_path, noise_file)
        waveform, sample_rate = torchaudio.load(noise_path)
        samples_per_segment = int(sample_rate * segment_length)
        max_start = waveform.size(1) - samples_per_segment
        # max_start == 0 is valid: the file holds exactly one segment.
        if max_start >= 0:
            candidates.append((noise_path, samples_per_segment, max_start))

    if target_background_samples > 0 and not candidates:
        raise ValueError("No background noise file is long enough for one segment.")

    # Draws are independent, so the list order is already random.
    audio_segments = []
    for _ in range(target_background_samples):
        noise_path, samples_per_segment, max_start = rng.choice(candidates)
        start = rng.randint(0, max_start)
        audio_segments.append((noise_path, start, start + samples_per_segment, unknown_label))

    val_end = num_train_samples + num_val_samples
    train_segments = audio_segments[:num_train_samples]
    val_segments = audio_segments[num_train_samples:val_end]
    test_segments = audio_segments[val_end:total_required_samples]
    return train_segments, val_segments, test_segments

# Draw one background-noise 'unknown' segment per 'yes' clip in each split.
background_noise_path = BACKGROUND_NOISE_PATH
segment_length = 1  # seconds
target_background_samples = num_yes_train + num_yes_test + num_yes_validation
num_train_samples = num_yes_train
# Each split is sized by its OWN 'yes' count (validation <- validation, test <- test).
num_val_samples = num_yes_validation
num_test_samples = num_yes_test

train_noise_segments, val_noise_segments, test_noise_segments = process_and_distribute_background_noise(
    background_noise_path, segment_length, target_background_samples,
    num_train_samples, num_val_samples, num_test_samples
)
# Backward-compatible names. Note that this rebinds `test_unknown`, which
# previously held 'unknown' word file paths, to the test noise segments.
train_unknow, val_unknown, test_unknown = train_noise_segments, val_noise_segments, test_noise_segments

In [42]:
def collate_fn(batch):
    """Combine ``(waveform, label)`` samples into one padded batch.

    Waveforms are right-padded with 0.0 to the longest waveform in the batch.

    Returns:
        ``(audios_padded, labels)``: the padded waveforms (batch first) and a
        tensor built from the labels.
    """
    waveforms = [waveform for waveform, _ in batch]
    label_ids = [label for _, label in batch]
    return (
        pad_sequence(waveforms, batch_first=True, padding_value=0.0),
        torch.tensor(label_ids),
    )

# NOTE(review): this iterates the full `dataset` (all words, all hash splits),
# not the balanced `train_files` built above — confirm that is intended.
train_loader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=32,
)

In [35]:
# One output per dataset label (index 0 is 'unknown'), derived rather than hard-coded.
num_labels = len(dataset.label_to_index)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AudioClassifier('facebook/wav2vec2-base-960h', num_labels=num_labels)
epoch_losses = model.fit(train_loader, epochs=10, device=device)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/10:   0%|          | 1/2023 [00:01<48:41,  1.44s/batch, loss=3.63]


KeyboardInterrupt: 