# Whisper training for background sounds


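## Expected sound-folder layout

Each sub-folder of `sounds_folder/` is one class; its folder name is used as the label of the audio files it contains.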
```
sounds_folder/
├── class1/
│   ├── *.wav (audio file)
├── class2/
│   ├── *.wav (audio file)
└── class3/
    ├── *.wav (audio file)
```


## 1. Environment Configuration

Install the dependencies listed in `requirements.txt` (for example, `pip install -r requirements.txt`).

## 2. Utilities

### 2.1 Imports

In [None]:
import os
import time
import torch
import librosa
import numpy as np
import pickle as pkl
from torch.utils.data import Dataset, DataLoader
from transformers import WhisperProcessor, WhisperForConditionalGeneration, logging
from sklearn.model_selection import train_test_split

# Only report errors from the transformers library (hides its info/warning messages).
logging.set_verbosity_error()

### 2.2 Constants

In [15]:
# Pretrained Whisper checkpoint every experiment starts from.
WHISPER_BASE_MODEL: str = "openai/whisper-small"

# Audio input settings.
AUDIO_TYPE: str = ".wav"       # file suffix collected by load_audio_data (a later cell reassigns it)
SAMPLING_RATE: int = 16000     # Hz; passed to librosa.load and to the Whisper processor
AUDIO_TIME_DURATION: int = 30  # seconds; clips are truncated or zero-padded to this length

# Label token sequences are padded/truncated to this many tokens.
TOKEN_MAX_LENGTH: int = 10

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 2.3 Load audio and label list

In [3]:
def load_audio_data(audio_dir, audio_type=None):
    """Collect audio files under *audio_dir*, labelling each by its parent folder.

    The tree is walked recursively in sorted order. ``os.walk`` on its own
    yields entries in arbitrary, filesystem-dependent order, which would make
    the seeded ``train_test_split`` calls produce different splits on
    different machines.

    Args:
        audio_dir: Root folder to scan.
        audio_type: File suffix to keep (e.g. ``".wav"``). Defaults to the
            module-level ``AUDIO_TYPE``, read at call time so notebook cells
            that reassign ``AUDIO_TYPE`` keep working.

    Returns:
        ``(audio_paths, labels)``: equal-length lists where ``labels[i]`` is
        the name of the directory that directly contains ``audio_paths[i]``.
    """
    suffix = AUDIO_TYPE if audio_type is None else audio_type
    audio_paths = []
    labels = []
    for root, dirs, files in os.walk(audio_dir):
        dirs.sort()  # in-place sort controls the order os.walk descends into sub-folders
        label = os.path.basename(root)
        for file in sorted(files):
            if file.endswith(suffix):
                audio_paths.append(os.path.join(root, file))
                labels.append(label)
    return audio_paths, labels

### 2.4 Librosa audio
Returns the first 30 seconds of the audio to be processed (shorter clips are zero-padded to 30 seconds).

In [4]:
def get_audio(audio_path, sampling_rate=None, duration=None):
    """Load an audio clip and force it to exactly ``sampling_rate * duration`` samples.

    Paths ending in ``.pkl`` are read as a pickled mapping and its ``'Audio'``
    entry is used as the samples. Any other path is decoded with
    ``librosa.load`` at ``sampling_rate``. The samples are then truncated, or
    right-padded with zeros, to the target length.

    Args:
        audio_path: Path of a ``.pkl`` clip or an audio file librosa can read.
        sampling_rate: Target rate in Hz; defaults to ``SAMPLING_RATE``.
        duration: Target length in seconds; defaults to ``AUDIO_TIME_DURATION``.

    Returns:
        A 1-D-indexable ``numpy`` array of length ``sampling_rate * duration``
        along its first axis.
    """
    if sampling_rate is None:
        sampling_rate = SAMPLING_RATE
    if duration is None:
        duration = AUDIO_TIME_DURATION

    # Dispatch on the file suffix, not on "pkl" appearing anywhere in the path.
    if str(audio_path).endswith(".pkl"):
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        # NOTE(review): pickled samples are not resampled; assumed to already be
        # at `sampling_rate` -- confirm against how the .pkl files were produced.
        with open(audio_path, "rb") as f:
            audio = pkl.load(f)["Audio"]
    else:
        audio, _ = librosa.load(audio_path, sr=sampling_rate)

    audio = np.asarray(audio)
    target_length = int(sampling_rate * duration)
    if len(audio) < target_length:
        return np.pad(audio, (0, target_length - len(audio)))
    return audio[:target_length]


### 2.5 Save Model

In [5]:
def save_model(model, processor, dir_name):
    """Write *model* and then *processor* into the same directory *dir_name*.

    Both objects are persisted through their own ``save_pretrained`` method.
    """
    for component in (model, processor):
        component.save_pretrained(dir_name)

## 3. Class Definition

In [6]:
class DatasetClass(Dataset):
    def __init__(self, audio_paths, labels, processor):
        self.audio_paths = audio_paths
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]
            
        inputs = self.processor(
            get_audio(audio_path),
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding=True
        )
        
        label_tokens = self.processor.tokenizer(
            label,
            padding="max_length",
            max_length=TOKEN_MAX_LENGTH,
            truncation=True,
            return_tensors="pt"
        )
            
        return {
            "input_features": inputs.input_features[0],
            "labels": label_tokens.input_ids[0]
        }
    
    def get_audio_paths(self):
        return self.audio_paths

    def get_labels(self):
        return self.labels


## 4. Audio Classifier

In [7]:
def transcribe_audio(audio_path, model, processor):
    """Transcribe one audio file with *model* and return the plain decoded text.

    The clip is loaded with ``get_audio``, converted to input features by
    *processor*, moved to ``DEVICE`` and decoded with English transcription
    settings. All special tokens are removed from the returned string.
    """
    inputs = processor(
        get_audio(audio_path),
        sampling_rate=SAMPLING_RATE,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        generated_ids = model.generate(
            inputs.input_features,
            language="en",
            task="transcribe",
            use_cache=False,
        )

    # skip_special_tokens drops every special token (including <|endoftext|>),
    # not just the start-of-transcript / language / task / timestamp markers.
    return processor.decode(generated_ids[0], skip_special_tokens=True).strip()

In [8]:
def check_transcription_for_label(audio_path, label, model, processor):
    """Return whether *label* occurs in the model's transcription of *audio_path*.

    This is a plain substring test, so it is lenient: any transcription that
    merely contains the label text counts as a correct classification.
    """
    return label in transcribe_audio(audio_path, model, processor)

In [9]:
def classify_audios(audio_files, labels, model, processor):
    """Transcribe every file, print the classification accuracy and return it.

    A file counts as correct when ``check_transcription_for_label`` is true for
    its expected label.

    Args:
        audio_files: Audio file paths to evaluate.
        labels: Expected label for each path (same length as ``audio_files``).
        model: Model passed to ``check_transcription_for_label``; it is put in
            eval mode for the duration and then restored to its previous mode.
        processor: Processor passed to ``check_transcription_for_label``.

    Returns:
        The accuracy as a percentage, or ``None`` when ``audio_files`` is empty.

    Raises:
        ValueError: If ``audio_files`` and ``labels`` differ in length.
    """
    if len(audio_files) != len(labels):
        raise ValueError(
            f"Got {len(audio_files)} audio files but {len(labels)} labels."
        )

    number_of_audio_files = len(audio_files)
    if number_of_audio_files == 0:
        print("No files .wav found.")
        return None

    # Evaluate with dropout disabled, then restore whatever mode the caller
    # was in (this is also called from inside the training loop).
    was_training = model.training
    model.eval()
    try:
        correct_classification = sum(
            check_transcription_for_label(file_path, label, model, processor)
            for file_path, label in zip(audio_files, labels)
        )
    finally:
        model.train(was_training)

    accuracy_rate = (correct_classification / number_of_audio_files) * 100
    print(f"Accuracy rate: {accuracy_rate:.2f}% ({correct_classification}/{number_of_audio_files}) correct:{correct_classification} numFiles:{number_of_audio_files}")
    return accuracy_rate

In [10]:
def classify_audio_folder(folter_path, model, processor):
    """Evaluate *model* on every audio file found under *folter_path*.

    Paths and labels (label = name of the containing folder) come from
    ``load_audio_data`` and are scored by ``classify_audios``. The parameter
    keeps its original spelling so keyword callers continue to work.
    """
    classify_audios(*load_audio_data(folter_path), model, processor)

In [11]:
def classify_audio_dataset(dataSet, model, processor):
    """Evaluate *model* on the audio paths and labels held by a ``DatasetClass``."""
    classify_audios(
        dataSet.get_audio_paths(),
        dataSet.get_labels(),
        model,
        processor,
    )

## 5. Model training

In [12]:
def train_whisper(model, dataloader, test_dataset, num_epochs=10, processor=None):
    """Fine-tune *model* with AdamW (lr 1e-5), evaluating after every epoch.

    Args:
        model: Model called as ``model(input_features=..., labels=...)`` whose
            output exposes ``.loss``. It is not moved here; the notebook cells
            move it to ``DEVICE`` before calling this.
        dataloader: Sized iterable of batches with ``"input_features"`` and
            ``"labels"`` tensors; each batch is moved to ``DEVICE``.
        test_dataset: Dataset scored with ``classify_audio_dataset`` after each
            epoch. Evaluation is skipped when this or *processor* is ``None``.
        num_epochs: Number of passes over *dataloader*.
        processor: Processor used for evaluation transcriptions.

    Returns:
        Mean per-batch training loss of the last epoch, or ``nan`` when no
        batch was processed (``num_epochs == 0`` or an empty dataloader).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    num_batches = len(dataloader)
    epoch_loss = float("nan")

    model.train()
    for epoch in range(num_epochs):
        start_time = time.time()
        total_loss = 0.0
        for batch in dataloader:
            input_features = batch["input_features"].to(DEVICE, non_blocking=True)
            labels = batch["labels"].to(DEVICE)

            outputs = model(input_features=input_features, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            print(f"Item loss: {batch_loss}")

        end_time = time.time()
        epoch_loss = total_loss / num_batches if num_batches else float("nan")

        if test_dataset is not None and processor is not None:
            # Evaluate without dropout, then go back to training mode.
            model.eval()
            try:
                classify_audio_dataset(test_dataset, model, processor)
            finally:
                model.train()

        print(f"Epochs {epoch+1}/{num_epochs}, Loss: {epoch_loss}")
        print(f"Epoch time: {end_time - start_time:.2f} seconds\n")

    return epoch_loss

## Analysis

### Siren and traffic

In [None]:
# Experiment: siren vs. traffic sounds, one class per sub-folder of AUDIO_DIR.
AUDIO_DIR = "sounds"
MODEL_OUTPUT_DIR = "modelo_sirenes_completo"
NUM_EPOCHS = 10

# Start from fresh pretrained weights for this experiment.
processor = WhisperProcessor.from_pretrained(WHISPER_BASE_MODEL)
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_BASE_MODEL)
model.to(DEVICE)

# Labels are the names of the folders containing each file; seeded 70/30 split.
audio_paths, labels = load_audio_data(AUDIO_DIR)
audio_paths_train, audio_paths_test, labels_train, labels_test = train_test_split(
    audio_paths, labels, test_size=0.3, random_state=42
)

print(f"Training audios: {audio_paths_train}")
print(f"Training labels: {labels_train}")
print(f"Test audios: {audio_paths_test}")
print(f"Test labels: {labels_test}")

train_dataset = DatasetClass(audio_paths_train, labels_train, processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)

test_dataset = DatasetClass(audio_paths_test, labels_test, processor)
# NOTE(review): test_dataloader is never used; per-epoch evaluation goes
# through test_dataset inside train_whisper.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, pin_memory=True)

train_whisper(model, train_dataloader, test_dataset, NUM_EPOCHS, processor)

# NOTE(review): this re-scores every file under AUDIO_DIR, including the 70%
# used for training, so the printed accuracy is not a held-out score.
classify_audio_folder(AUDIO_DIR, model, processor)

save_model(model, processor, MODEL_OUTPUT_DIR)

### Vehicle sounds (filtered, 300 per class)

In [None]:
# Experiment: filtered vehicle sounds (per the section heading, 300 clips per
# class -- not checked here), one class per sub-folder of AUDIO_DIR.
AUDIO_DIR = "vehicleSoundsFiltred"
MODEL_OUTPUT_DIR = "modelo_vehicle_filtred"
NUM_EPOCHS = 10

# Start from fresh pretrained weights for this experiment.
processor = WhisperProcessor.from_pretrained(WHISPER_BASE_MODEL)
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_BASE_MODEL)
model.to(DEVICE)

# Labels are the names of the folders containing each file; seeded 70/30 split.
audio_paths, labels = load_audio_data(AUDIO_DIR)
audio_paths_train, audio_paths_test, labels_train, labels_test = train_test_split(
    audio_paths, labels, test_size=0.3, random_state=42
)

print(f"Training audios: {audio_paths_train}")
print(f"Training labels: {labels_train}")
print(f"Test audios: {audio_paths_test}")
print(f"Test labels: {labels_test}")

# NOTE(review): pin_memory differs between the train (False) and test (True)
# loaders -- confirm whether that is intentional.
train_dataset = DatasetClass(audio_paths_train, labels_train, processor)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=False)

test_dataset = DatasetClass(audio_paths_test, labels_test, processor)
# NOTE(review): test_dataloader is never used; per-epoch evaluation goes
# through test_dataset inside train_whisper.
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, pin_memory=True)

train_whisper(model, train_dataloader, test_dataset, NUM_EPOCHS, processor)

# NOTE(review): this re-scores every file under AUDIO_DIR, including the 70%
# used for training, so the printed accuracy is not a held-out score.
classify_audio_folder(AUDIO_DIR, model, processor)

save_model(model, processor, MODEL_OUTPUT_DIR)

### Samosa

In [None]:
# Experiment: "Samosa" dataset of pickled clips whose labels are encoded in the
# file names rather than in the folder names.
AUDIO_DIR = "TrainingDataset"
MODEL_OUTPUT_DIR = "modelo_samosa"
NUM_EPOCHS = 10
# Reassigns the module-level suffix read by load_audio_data so that only .pkl
# files are collected (this also affects later cells that call load_audio_data).
AUDIO_TYPE = ".pkl"

# Start from fresh pretrained weights for this experiment.
processor = WhisperProcessor.from_pretrained(WHISPER_BASE_MODEL)
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_BASE_MODEL)
model.to(DEVICE)

audio_paths, labels = load_audio_data(AUDIO_DIR)
# Label = third "---"-separated field of the file name (extension removed),
# with underscores turned into spaces. os.path.splitext removes exactly the
# extension; str.rstrip(".pkl") would strip any trailing '.', 'p', 'k', 'l'
# characters and corrupt names such as "...---help.pkl".
for i, path in enumerate(audio_paths):
    stem = os.path.splitext(os.path.basename(path))[0]
    labels[i] = stem.split("---")[2].replace('_', ' ')

audio_paths_train, audio_paths_test, labels_train, labels_test = train_test_split(
    audio_paths, labels, test_size=0.3, random_state=42
)

print(f"Training audios: {audio_paths_train}")
print(f"Training labels: {labels_train}")
print(f"Test audios: {audio_paths_test}")
print(f"Test labels: {labels_test}")

train_dataset = DatasetClass(audio_paths_train, labels_train, processor)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=False)

test_dataset = DatasetClass(audio_paths_test, labels_test, processor)
# NOTE(review): test_dataloader is never used; per-epoch evaluation goes
# through test_dataset inside train_whisper.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, pin_memory=True)

train_whisper(model, train_dataloader, test_dataset, NUM_EPOCHS, processor)

# Score every file against the file-name labels parsed above. classify_audio_folder
# would re-derive labels from folder names, which are not this dataset's labels.
# NOTE(review): this includes the training files, so it is not a held-out score.
classify_audios(audio_paths, labels, model, processor)

save_model(model, processor, MODEL_OUTPUT_DIR)

## 6. Spectrogram visualization

To view the audio spectrogram:

In [None]:
import pylab
import wave

import numpy as np

def visualizar_espectrograma(wav_file):
    """Plot the spectrogram of the first channel of a PCM WAV file.

    Args:
        wav_file: Path or file object accepted by ``wave.open``.

    Raises:
        ValueError: If the file's sample width is not 1, 2 or 4 bytes.
    """
    # The context manager closes the file even if reading fails.
    with wave.open(wav_file, 'r') as wav:
        frames = wav.readframes(-1)
        frame_rate = wav.getframerate()
        n_channels = wav.getnchannels()
        sample_width = wav.getsampwidth()

    # WAV PCM: 8-bit samples are unsigned, wider ones are signed integers.
    dtype = {1: np.uint8, 2: np.int16, 4: np.int32}.get(sample_width)
    if dtype is None:
        raise ValueError(f"Unsupported WAV sample width: {sample_width} bytes")
    # frombuffer replaces the deprecated binary mode of fromstring. Channels are
    # interleaved frame by frame, so taking every n-th value keeps channel 0.
    sound_info = np.frombuffer(frames, dtype=dtype)[::n_channels]

    pylab.figure(figsize=(10, 4))
    pylab.specgram(sound_info, Fs=frame_rate)
    pylab.xlabel('Tempo (s)')
    pylab.ylabel('Frequência (Hz)')
    pylab.colorbar(label='Intensidade (dB)')
    pylab.title('Espectrograma do Áudio')
    pylab.show()
