<a href="https://colab.research.google.com/github/wolfram-laube/mlpc-project_team-park/blob/wl/pre-trained-v2/fastlane.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# All-in-one Pre-trained Word Tokenizer v2

In [1]:
# Install necessary libraries if not already installed
!pip install transformers librosa torch datasets noisereduce evaluate jiwer pandas accelerate



Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting noisereduce
  Downloading noisereduce-3.0.2-py3-none-any.whl (22 kB)
Collecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl (21 kB)
Collecting accelerate
  Downloading accelerate-0.31.0-py3-none-any.whl (309 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.4/309.4 kB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12

In [1]:
# Dataset root. '/content/dataset' is the Colab runtime location; on a local
# checkout the dataset lives in '../dataset' instead.
data_dir = "/content/dataset"

## Preprocess

### Load fresh data

In [3]:
import os
import sys
import shutil

# Check if the environment is Google Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # If in Google Colab
    from google.colab import drive
    import gdown

    # Option 1: Download the file by its public link and expand it to the Colab runtime
    import urllib.request
    import zipfile

    scnwavzip_file_id = '1oI1EsH1krrEPbH9MSZRzLHu-_4p6-njR' # https://drive.google.com/file/d/1oI1EsH1krrEPbH9MSZRzLHu-_4p6-njR/view?usp=sharing
    scnnpyzip_file_id = '1oKgurvIgT93RGkxvxq8AA423VKlEVT7O' # https://drive.google.com/file/d/1oKgurvIgT93RGkxvxq8AA423VKlEVT7O/view?usp=sharing
    wrdwavzip_file_id = '1o1yBqdtqH3tjOHN4GKISJHlY2Qyu_ouX' # https://drive.google.com/file/d/1o1yBqdtqH3tjOHN4GKISJHlY2Qyu_ouX/view?usp=sharing
    wrdnpyzip_file_id = '1o2fj6QAM00zg8YMxsHwcNa2lkIXLXDYs' # https://drive.google.com/file/d/1o2fj6QAM00zg8YMxsHwcNa2lkIXLXDYs/view?usp=sharing
    annotation_file_id = '1xLxget7c5nCkwYt9Ru2RpYi5rMkk_pl0'  # https://drive.google.com/file/d/1xLxget7c5nCkwYt9Ru2RpYi5rMkk_pl0/view?usp=sharing
    scenes_file_id = '1xLgB7-cCz6nReyQbFJJcJGOUKCCbNhCG'  # https://drive.google.com/file/d/1xLgB7-cCz6nReyQbFJJcJGOUKCCbNhCG/view?usp=sharing

    scnwavzip_url = f'https://drive.google.com/uc?id={scnwavzip_file_id}'
    scnnpyzip_url = f'https://drive.google.com/uc?id={scnnpyzip_file_id}'
    wrdwavzip_url = f'https://drive.google.com/uc?id={wrdwavzip_file_id}'
    wrdnpyzip_url = f'https://drive.google.com/uc?id={wrdnpyzip_file_id}'
    annotation_url = f'https://drive.google.com/uc?id={annotation_file_id}'
    scenes_url = f'https://drive.google.com/uc?id={scenes_file_id}'

    scnwavzip_path = '/content/scenes_data.zip'
    scnnpyzip_path = '/content/scenes_feat.zip'
    wrdwavzip_path = '/content/words_data.zip'
    wrdnpyzip_path = '/content/words_feat.zip'
    data_dir = '/content/dataset'
    scenes_dir = f'{data_dir}/scenes'
    words_dir = f'{data_dir}/words'
    scenes_wav_dir = f'{scenes_dir}/wav'
    scenes_npy_dir = f'{scenes_dir}/npy'
    words_wav_dir = f'{data_dir}/words'
    words_npy_dir = f'{data_dir}/words'

    # Download the WAVZIP file
    #urllib.request.urlretrieve(wavzip_url, wavzip_path)
    gdown.download(scnwavzip_url, scnwavzip_path, quiet=False)

    # Unzip the file
    with zipfile.ZipFile(scnwavzip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    print(f"Scenes training data extracted to {data_dir}")

     # Create the 'scenes/wav' folder structure
    os.makedirs(scenes_wav_dir, exist_ok=True)

    # Copy .wav files to 'scenes/wav'
    extracted_scenes_dir = os.path.join(data_dir, 'mlpc24_speech_commands', 'scenes')
    for root, dirs, files in os.walk(extracted_scenes_dir):
        for file in files:
            if file.endswith('.wav'):
                src_path = os.path.join(root, file)
                dst_path = os.path.join(scenes_wav_dir, file)
                shutil.copy(src_path, dst_path)

    print(f"Scenes training .wav files moved to {scenes_wav_dir}")

    # Download the SCNNPYZIP file
    gdown.download(scnnpyzip_url, scnnpyzip_path, quiet=False)

    # Unzip the file
    with zipfile.ZipFile(scnnpyzip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    print(f"Scenes training features extracted to {data_dir}")

     # Create the 'scenes/npy' folder structure
    os.makedirs(scenes_npy_dir, exist_ok=True)

    # Copy .npy files to 'scenes/npy'
    extracted_scenes_dir = os.path.join(data_dir, 'development_scenes')
    for root, dirs, files in os.walk(extracted_scenes_dir):
        for file in files:
            if file.endswith('.npy'):
                src_path = os.path.join(root, file)
                dst_path = os.path.join(scenes_npy_dir, file)
                shutil.copy(src_path, dst_path)

    print(f"Scenes training .npy files moved to {scenes_npy_dir}")

    # Download the WRDWAVZIP file
    #urllib.request.urlretrieve(wavzip_url, wavzip_path)
    gdown.download(wrdwavzip_url, wrdwavzip_path, quiet=False)

    # Unzip the file
    with zipfile.ZipFile(wrdwavzip_path, 'r') as zip_ref:
        zip_ref.extractall(words_wav_dir)

    print(f"Words training data extracted to {words_wav_dir}")

    # Download the WRDNPYZIP file
    gdown.download(wrdnpyzip_url, wrdnpyzip_path, quiet=False)

    # Unzip the file
    with zipfile.ZipFile(wrdnpyzip_path, 'r') as zip_ref:
        zip_ref.extractall(words_npy_dir)

    print(f"Words training ,npy files s extracted to {words_npy_dir}")


    # Download the CSV files into the data_dir
    annotation_orig_path = os.path.join(data_dir, 'development_scene_annotations.csv.orig') # Keep a backup copy because it needs fixing
    annotation_path = os.path.join(data_dir, 'development_scene_annotations.csv')
    scenes_path = os.path.join(data_dir, 'development_scenes.csv')

    gdown.download(annotation_url, annotation_orig_path, quiet=False)
    gdown.download(annotation_url, annotation_path, quiet=False)
    gdown.download(scenes_url, scenes_path, quiet=False)

    print(f"CSV files downloaded to {scenes_dir}")

    # Option 2: Mount Google Drive and use the training data
    # Note this really takes some time for preprocessing file by file
    #drive.mount('/content/drive')
    #data_dir = '/content/drive/My Drive/Dropbox/public/mlpc/dataset'

    # Use this option to read from Google Drive instead
    #print(f"Using training data from {data_dir}")
else:
    # If on local machine
    data_dir = '../dataset'
    print(f"Using local training data from {data_dir}")

# Use the data_dir variable as the path to your training data

Downloading...
From (original): https://drive.google.com/uc?id=1oI1EsH1krrEPbH9MSZRzLHu-_4p6-njR
From (redirected): https://drive.google.com/uc?id=1oI1EsH1krrEPbH9MSZRzLHu-_4p6-njR&confirm=t&uuid=c1f59364-3d29-4307-b1aa-123b1c5cd50d
To: /content/scenes_data.zip
100%|██████████| 305M/305M [00:11<00:00, 25.9MB/s]


Scenes training data extracted to /content/dataset
Scenes training .wav files moved to /content/dataset/scenes/wav


Downloading...
From (original): https://drive.google.com/uc?id=1oKgurvIgT93RGkxvxq8AA423VKlEVT7O
From (redirected): https://drive.google.com/uc?id=1oKgurvIgT93RGkxvxq8AA423VKlEVT7O&confirm=t&uuid=0666d085-b9c6-47d6-aeab-2c9cfdf2ad0f
To: /content/scenes_feat.zip
100%|██████████| 422M/422M [00:04<00:00, 93.9MB/s]


Scenes training features extracted to /content/dataset
Scenes training .npy files moved to /content/dataset/scenes/npy


Downloading...
From (original): https://drive.google.com/uc?id=1o1yBqdtqH3tjOHN4GKISJHlY2Qyu_ouX
From (redirected): https://drive.google.com/uc?id=1o1yBqdtqH3tjOHN4GKISJHlY2Qyu_ouX&confirm=t&uuid=659a29ed-97d3-4509-add1-810d9c19271a
To: /content/words_data.zip
100%|██████████| 1.17G/1.17G [00:14<00:00, 82.8MB/s]


Words training data extracted to /content/dataset/words


Downloading...
From (original): https://drive.google.com/uc?id=1o2fj6QAM00zg8YMxsHwcNa2lkIXLXDYs
From (redirected): https://drive.google.com/uc?id=1o2fj6QAM00zg8YMxsHwcNa2lkIXLXDYs&confirm=t&uuid=04f22cff-36a0-4a80-b937-21a1fa8a5449
To: /content/words_feat.zip
100%|██████████| 1.51G/1.51G [00:52<00:00, 28.5MB/s]


Words training ,npy files s extracted to /content/dataset/words


Downloading...
From: https://drive.google.com/uc?id=1xLxget7c5nCkwYt9Ru2RpYi5rMkk_pl0
To: /content/dataset/development_scene_annotations.csv.orig
100%|██████████| 70.4k/70.4k [00:00<00:00, 58.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1xLxget7c5nCkwYt9Ru2RpYi5rMkk_pl0
To: /content/dataset/development_scene_annotations.csv
100%|██████████| 70.4k/70.4k [00:00<00:00, 18.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=1xLgB7-cCz6nReyQbFJJcJGOUKCCbNhCG
To: /content/dataset/development_scenes.csv
100%|██████████| 29.5k/29.5k [00:00<00:00, 33.4MB/s]

CSV files downloaded to /content/dataset/scenes





### Determine CPU/GPU

In [4]:
def is_gpu_available():
    """Report whether a GPU is usable from this runtime.

    PyTorch is tried first, then TensorFlow. The first framework that imports
    successfully decides the answer. The result, or the reason a framework was
    skipped, is printed.

    Returns:
        bool: True if the first importable framework reports a GPU, else False
        (also False when neither framework is installed).
    """
    try:
        import torch
    except ImportError as ie:
        print("PyTorch not installed, cannot query GPU via torch:", ie)
    else:
        is_gpu = torch.cuda.is_available()
        print(f'GPU available: {is_gpu}')
        return is_gpu

    try:
        import tensorflow as tf
    except ImportError as ie:
        print("TensorFlow not installed, cannot query GPU via tensorflow:", ie)
    else:
        is_gpu = bool(tf.config.list_physical_devices('GPU'))
        print(f'GPU available: {is_gpu}')
        return is_gpu

    print("No GPU support found")
    return False

is_gpu_available()

GPU available: True


True

### Fix erroneous metadata

#### Before

In [5]:
import pandas as pd

# Load the raw annotation and scene metadata tables.
scene_annotations_df = pd.read_csv(f'{data_dir}/development_scene_annotations.csv')
scenes_df = pd.read_csv(f'{data_dir}/development_scenes.csv')

# Peek at both tables to understand their columns.
for frame in (scene_annotations_df, scenes_df):
    print(frame.head())

# Label distributions: command labels in the annotations, speakers in the scenes.
label_distribution_annotations = scene_annotations_df['command'].value_counts()
label_distribution_scenes = scenes_df['speaker_id'].value_counts()
for title, distribution in (
    ("Label Distribution in development_scene_annotations.csv:", label_distribution_annotations),
    ("Label Distribution in development_scenes.csv:", label_distribution_scenes),
):
    print(title)
    print(distribution)


                        filename         command     start       end
0         2_speech_true_Ofen_aus        Ofen aus  11.25230  12.07747
1         3_speech_true_Radio_an  Staubsauger an  21.48040  23.18083
2         4_speech_true_Alarm_an        Alarm an  14.45720  16.08301
3        9_speech_true_Radio_aus  Staubsauger an   3.67909   5.63126
4  11_speech_false_Fernseher_aus  Staubsauger an  10.57850  11.67886
                        filename  speaker_id
0         2_speech_true_Ofen_aus         132
1         3_speech_true_Radio_an         132
2         4_speech_true_Alarm_an         132
3        9_speech_true_Radio_aus         132
4  11_speech_false_Fernseher_aus         132
Label Distribution in development_scene_annotations.csv:
command
Staubsauger an     288
Licht aus           77
Licht an            64
Fernseher an        56
Alarm an            56
Heizung an          55
Heizung aus         54
Radio aus           53
Radio an            52
Ofen aus            49
Alarm aus           4

#### Fix

In [6]:
import os
import re
import shutil
import pandas as pd

# The raw download is kept as '.orig'. A scratch copy ('.0') is what gets
# parsed, and the corrected table is written back under the canonical name.
corrected_file_path = f'{data_dir}/development_scene_annotations.csv'
original_file_path = corrected_file_path + '.orig'
working_copy_path = corrected_file_path + '.0'

# Steps 1 and 2: copy the original to the working file, then load that copy.
# shutil.copy returns the destination path, which is passed straight to read_csv.
df = pd.read_csv(shutil.copy(original_file_path, working_copy_path))

# Scene filenames look like '<id>_speech_<true|false>_<Device>_<an|aus>[_<Device>_<an|aus>...]'.
filename_pattern = re.compile(r'(\d+)_speech_(true|false)_((?:[a-zA-ZäöüÄÖÜß]+_(?:an|aus)_?)+)', re.UNICODE)

# A single '<Device>_<an|aus>' pair inside the command part of a filename.
_command_pair_pattern = re.compile(r'([a-zA-ZäöüÄÖÜß]+)_(an|aus)')


def parse_filename(filename):
    """Extract the ordered list of spoken commands encoded in a scene filename.

    Example: '1098_speech_true_Alarm_an_Staubsauger_an' -> ['Alarm an', 'Staubsauger an'].

    Args:
        filename: Scene filename without extension.

    Returns:
        list[str]: Commands as '<Device> <an|aus>' in filename order. The list
        is empty if ``filename`` is not a string or does not follow the
        naming scheme.
    """
    if not isinstance(filename, str):  # e.g. NaN from a malformed CSV row
        return []
    match = filename_pattern.match(filename)
    if not match:
        return []
    # Pairing device/state with findall, instead of splitting on '_' and
    # reading items two at a time, means a trailing '_' left by the optional
    # separator (when extra text follows the last command) no longer yields an
    # odd-length split and an IndexError.
    return [f"{device} {state}" for device, state in _command_pair_pattern.findall(match.group(3))]

# Parse the commands from filenames and add to the DataFrame
df['parsed_commands'] = df['filename'].apply(parse_filename)

# Step 3: Order rows by filename, then by segment start time. A plain
# multi-key sort yields the same row order as
# groupby('filename').apply(lambda x: x.sort_values(by='start')), without
# pandas' deprecation warning about apply operating on the grouping column.
grouped = df.sort_values(by=['filename', 'start']).reset_index(drop=True)

# Step 4: Assign the correct labels based on the order of commands in the filename
def assign_labels(group):
    """Relabel one scene file's segments with the commands parsed from its filename.

    The i-th row (rows are expected to be sorted by start time) receives the
    i-th parsed command. Rows beyond the number of parsed commands keep their
    original label, and a warning is printed for each.

    Args:
        group: Rows of a single scene file with 'command' and 'parsed_commands'
            columns. The file name for the warning is read from a 'filename'
            column when present; otherwise from ``group.name``, which
            ``groupby().apply`` sets when pandas excludes the grouping column.

    Returns:
        A new DataFrame (fresh RangeIndex) with corrected 'command' values; the
        input is not modified.
    """
    commands = group['parsed_commands'].iloc[0]  # get the parsed commands from the first row
    # Resolve the name before reset_index, which returns a frame without `.name`.
    if 'filename' in group.columns:
        filename = group['filename'].iloc[0]
    else:
        filename = getattr(group, 'name', None)
    group = group.reset_index(drop=True)
    for i in range(len(group)):
        if i < len(commands):
            group.at[i, 'command'] = commands[i]
        else:
            print(f"Warning: More segments than commands in {filename}")
    return group

# Apply the label assignment per scene file. Iterating the groups, instead of
# groupby().apply, keeps the 'filename' column in every group on all pandas
# versions and avoids the apply-on-grouping-columns deprecation warning.
labelled_groups = [assign_labels(group) for _, group in grouped.groupby('filename', sort=True)]
# pd.concat([]) raises, so an empty annotation table maps to an empty frame.
corrected_df = (
    pd.concat(labelled_groups, ignore_index=True) if labelled_groups else grouped.iloc[0:0]
)

# Drop the temporary column
corrected_df = corrected_df.drop(columns=['parsed_commands'])

# Step 5: Save the corrected DataFrame to a new CSV file
corrected_df.to_csv(corrected_file_path, index=False)

# Verify the saved corrections
print("Label corrections applied and saved successfully.")
print(corrected_df.head())


Label corrections applied and saved successfully.
                        filename       command     start       end
0    1003_speech_false_Licht_aus     Licht aus  12.20090  13.57599
1       1008_speech_true_Ofen_an       Ofen an   6.90112   8.52638
2      1010_speech_true_Radio_an      Radio an  13.03100  14.03146
3  1011_speech_true_Fernseher_an  Fernseher an  14.11030  15.36121
4   1012_speech_true_Heizung_aus   Heizung aus  11.20520  12.70590


#### After

In [7]:
import pandas as pd

# Re-load both metadata tables, now with the corrected annotation labels.
scene_annotations_df = pd.read_csv(f'{data_dir}/development_scene_annotations.csv')
scenes_df = pd.read_csv(f'{data_dir}/development_scenes.csv')

# Command labels per annotated segment, and recordings per speaker.
label_distribution_annotations = scene_annotations_df['command'].value_counts()
label_distribution_scenes = scenes_df['speaker_id'].value_counts()

# Show the first rows of each table, then both distributions (title on its own line).
print(scene_annotations_df.head())
print(scenes_df.head())
print("Label Distribution in development_scene_annotations.csv:", label_distribution_annotations, sep="\n")
print("Label Distribution in development_scenes.csv:", label_distribution_scenes, sep="\n")


                        filename       command     start       end
0    1003_speech_false_Licht_aus     Licht aus  12.20090  13.57599
1       1008_speech_true_Ofen_an       Ofen an   6.90112   8.52638
2      1010_speech_true_Radio_an      Radio an  13.03100  14.03146
3  1011_speech_true_Fernseher_an  Fernseher an  14.11030  15.36121
4   1012_speech_true_Heizung_aus   Heizung aus  11.20520  12.70590
                        filename  speaker_id
0         2_speech_true_Ofen_aus         132
1         3_speech_true_Radio_an         132
2         4_speech_true_Alarm_an         132
3        9_speech_true_Radio_aus         132
4  11_speech_false_Fernseher_aus         132
Label Distribution in development_scene_annotations.csv:
command
Licht aus          86
Licht an           78
Heizung an         76
Fernseher an       74
Radio aus          69
Heizung aus        67
Alarm an           66
Radio an           65
Lüftung aus        64
Ofen aus           64
Lüftung an         63
Ofen an            63

### Preprocess audio data

In [8]:
import os
import numpy as np
import pandas as pd
import librosa
import soundfile as sf
import random
from IPython.display import Audio
from sklearn.decomposition import FastICA

# Function to apply ICA on audio segments
def apply_ica(segment, sr, random_state=None):
    """Run single-component FastICA on a mono audio segment.

    NOTE: with one input channel and ``n_components=1``, ICA cannot separate
    sources. After whitening, the single unmixing weight is +/-1, so the
    output is the centred segment rescaled (possibly sign-flipped).

    Args:
        segment: 1-D array of audio samples.
        sr: Sample rate (unused; kept for interface compatibility).
        random_state: Passed to FastICA. Use an int for reproducible output;
            the default ``None`` keeps the previous, unseeded behaviour.

    Returns:
        1-D numpy array with the same length as ``segment``. Segments with
        fewer than two samples, or with constant value, are returned as an
        unchanged copy, because FastICA cannot whiten them (it would fail or
        produce NaN).
    """
    segment = np.asarray(segment)
    if segment.size < 2 or np.ptp(segment) == 0:
        return segment.copy()
    ica = FastICA(n_components=1, whiten='arbitrary-variance', random_state=random_state)  # Explicitly set whiten parameter
    return ica.fit_transform(segment.reshape(-1, 1)).flatten()

# Function to preprocess segments and optionally save to the filesystem
def preprocess_and_save_segments(scenes_dir, annotations_path, save_dir=None, save_to_filesystem=False, apply_ica_flag=False):
    """Cut every annotated command out of its scene recording.

    Args:
        scenes_dir: Directory containing the '<filename>.wav' scene recordings.
        annotations_path: CSV with 'filename', 'command', 'start', 'end'
            columns; start/end are multiplied by the sample rate, i.e. seconds.
        save_dir: Directory for the clipped segments when ``save_to_filesystem``.
        save_to_filesystem: Also write each segment to ``save_dir`` as
            '<filename>_<start>_<end>.wav'.
        apply_ica_flag: Post-process each peak-normalised segment with ``apply_ica``.

    Returns:
        list of ``(filename, command, segment, sr)`` tuples in annotation order.
        Annotations whose time range contains no samples are skipped with a
        printed warning (previously ``librosa.util.normalize`` crashed on them).
    """
    annotations_df = pd.read_csv(annotations_path)

    if save_to_filesystem and save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    preprocessed_segments = []
    # A scene usually carries several annotations. Keep the most recently
    # loaded recording, so consecutive rows for the same file decode it only once.
    cached_filename, y, sr = None, None, None

    for row in annotations_df.itertuples(index=False):
        filename, command, start, end = row.filename, row.command, row.start, row.end

        if filename != cached_filename:
            file_path = os.path.join(scenes_dir, f"{filename}.wav")
            y, sr = librosa.load(file_path, sr=None)
            cached_filename = filename

        # Clamp at 0: a negative start would otherwise slice from the end.
        start_sample = max(int(start * sr), 0)
        end_sample = int(end * sr)
        segment = y[start_sample:end_sample]
        if segment.size == 0:
            print(f"Warning: empty segment for {filename} ({start}-{end}s), skipped")
            continue

        segment = librosa.util.normalize(segment)

        if apply_ica_flag:
            segment = apply_ica(segment, sr)

        preprocessed_segments.append((filename, command, segment, sr))

        if save_to_filesystem and save_dir is not None:
            save_path = os.path.join(save_dir, f"{filename}_{start}_{end}.wav")
            sf.write(save_path, segment, sr)

    return preprocessed_segments

# Function to play a random segment from preprocessed segments
def play_random_segment(preprocessed_segments):
    """Pick one in-memory segment at random, print its metadata and return a player.

    Args:
        preprocessed_segments: Non-empty sequence of
            ``(filename, command, audio, sample_rate)`` tuples.

    Returns:
        IPython ``Audio`` object for the chosen segment.
    """
    filename, command, audio_data, sample_rate = random.choice(preprocessed_segments)
    print(f"Filename: {filename}")
    print(f"Command: {command}")
    player = Audio(audio_data, rate=sample_rate)
    return player

# Function to play a random segment from the filesystem
def play_random_segment_from_filesystem(save_dir, annotations_path):
    """Pick a random clipped segment from ``save_dir``, print its annotation and return a player.

    Clip names are expected as '<filename>_<start>_<end>.wav'. The command is
    looked up in the annotations CSV by filename and start/end time.

    Args:
        save_dir: Directory containing the clipped '.wav' segments.
        annotations_path: CSV with 'filename', 'command', 'start', 'end' columns.

    Returns:
        IPython ``Audio`` for the segment, or None (after printing why) when
        the directory holds no clips, the clip name cannot be parsed, or no
        annotation matches.
    """
    segment_files = [f for f in os.listdir(save_dir) if f.endswith('.wav')]
    if not segment_files:
        print(f"No .wav segments found in {save_dir}")
        return None

    random_segment_file = random.choice(segment_files)
    random_segment_path = os.path.join(save_dir, random_segment_file)

    # The scene filename itself contains '_', so peel the two time fields off the right.
    try:
        filename, start_str, end_str = random_segment_file[:-len('.wav')].rsplit('_', 2)
        start_time, end_time = float(start_str), float(end_str)
    except ValueError:
        print(f"Cannot parse start/end times from {random_segment_file}")
        return None

    # Find the command in the annotations. A small absolute tolerance replaces
    # exact float equality, in case the CSV was re-saved with different precision.
    annotations_df = pd.read_csv(annotations_path)
    command_row = annotations_df[
        (annotations_df['filename'] == filename)
        & np.isclose(annotations_df['start'], start_time, rtol=0, atol=1e-6)
        & np.isclose(annotations_df['end'], end_time, rtol=0, atol=1e-6)
    ]

    if command_row.empty:
        print(f"No matching annotation found for {random_segment_file}")
        return None

    command = command_row.iloc[0]['command']

    # Load the audio segment
    y, sr = librosa.load(random_segment_path, sr=None)

    # Print the command and play the audio segment
    print(f"Filename: {filename}")
    print(f"Command: {command}")

    return Audio(y, rate=sr)

# Example usage: clip every annotated command out of its scene recording.
scenes_dir = f'{data_dir}/scenes/wav'
annotations_path = f'{data_dir}/development_scene_annotations.csv'
save_dir = f'{data_dir}/clipped_commands'

# Keep the clips in memory and also write them to save_dir, with the ICA step enabled.
preprocessed_segments = preprocess_and_save_segments(
    scenes_dir,
    annotations_path,
    save_dir,
    save_to_filesystem=True,
    apply_ica_flag=True,
)

# Listen to one random clip from memory, then one read back from disk.
audio_memory = play_random_segment(preprocessed_segments)
display(audio_memory)

audio_filesystem = play_random_segment_from_filesystem(save_dir, annotations_path)
display(audio_filesystem)


Filename: 1098_speech_true_Alarm_an_Staubsauger_an
Command: Staubsauger an


Filename: 1733_speech_true_Radio_aus
Command: Radio aus


## Main

### Libraries

### Training

In [2]:
import os
import re
import torch
import librosa
import logging
from torch.utils.data import DataLoader, Dataset
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from tqdm import tqdm
import numpy as np

# Configure logging to output to console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Remove (and close) handlers left by a previous run of this cell so messages
# are not repeated.
for handler in logger.handlers[:]:
    logger.removeHandler(handler)
    handler.close()
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
# Do not also forward records to the root logger. If the root logger already
# had a handler, basicConfig() above was a no-op, and every message would be
# printed twice: once here, once in the root's 'INFO:__main__:...' format.
logger.propagate = False

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Load pre-trained tokenizer and model
def load_model_and_tokenizer(model_name="facebook/wav2vec2-large-xlsr-53-german"):
    """Load a wav2vec2 CTC checkpoint together with its processor.

    The model is moved to the module-level ``device``.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint.

    Returns:
        tuple: ``(processor, model)``.
    """
    logger.info("Loading model and tokenizer...")
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    ctc_model = Wav2Vec2ForCTC.from_pretrained(model_name)
    ctc_model = ctc_model.to(device)
    logger.info("Model and tokenizer loaded successfully.")
    return processor, ctc_model

# Extract labels from filenames
def extract_labels_from_filename(filename):
    """Derive the transcript label from a '...speech_true_<Word>_<Word>....wav' file name.

    Underscores in the part after 'speech_true_' become spaces. Names without
    a 'speech_true_...wav' part (e.g. 'speech_false' files) yield ''.
    """
    found = re.search(r'speech_true_(.*)\.wav', filename)
    return found.group(1).replace('_', ' ') if found else ''

# Dataset class with data augmentation
class AudioDataset(Dataset):
    """Map-style dataset yielding ``(input_values, label_ids)`` pairs for CTC training.

    Each item of ``audio_files`` is a ``(file_path, audio, sr)`` tuple. The
    transcript label is derived from the file name via
    ``extract_labels_from_filename``.
    """

    def __init__(self, audio_files, processor, augment=False):
        self.audio_files = audio_files
        self.processor = processor
        self.augment = augment

    def __len__(self):
        return len(self.audio_files)

    def augment_audio(self, audio, sr):
        """Independently apply, each with probability 0.5: pitch shift, time stretch and additive noise."""
        if np.random.rand() > 0.5:
            audio = librosa.effects.pitch_shift(y=audio, sr=sr, n_steps=np.random.uniform(-2, 2), bins_per_octave=24)
        if np.random.rand() > 0.5:
            audio = librosa.effects.time_stretch(audio, rate=np.random.uniform(0.8, 1.2))
        if np.random.rand() > 0.5:
            audio = audio + 0.005 * np.random.randn(len(audio))
        return audio

    def __getitem__(self, idx):
        """Return the 1-D input values and 1-D int64 label ids for item ``idx``."""
        file_path, audio, sr = self.audio_files[idx]
        if self.augment:
            audio = self.augment_audio(audio, sr)
        inputs = self.processor(audio, return_tensors="pt", padding="longest", sampling_rate=sr)
        label = extract_labels_from_filename(os.path.basename(file_path))
        label_ids = self.processor.tokenizer(label, return_tensors="pt").input_ids
        # squeeze(0) removes only the batch axis. A bare squeeze() would turn a
        # single-token label (or single-sample input) into a 0-d tensor that
        # pad_sequence cannot handle. An empty label ('' for speech_false files)
        # can come back as a float tensor, hence the explicit cast to long.
        return inputs.input_values.squeeze(0), label_ids.squeeze(0).to(torch.long)

# Collate function to handle padding in DataLoader
def collate_fn(batch):
    """Right-pad a list of ``(input_values, label_ids)`` pairs into two batch tensors.

    Audio is padded with 0 and labels with -100, so padded label positions can
    be distinguished from real token ids.

    Returns:
        tuple: ``(input_values, label_ids)``, each shaped (batch, longest_item).
    """
    pad = torch.nn.utils.rnn.pad_sequence
    audio_seqs = [sample[0] for sample in batch]
    label_seqs = [sample[1] for sample in batch]
    return (
        pad(audio_seqs, batch_first=True, padding_value=0),
        pad(label_seqs, batch_first=True, padding_value=-100),
    )

# Load audio files
def load_audio_files(directory, max_files=None):
    """Load every '.wav' file below ``directory`` resampled to 16 kHz.

    Files are visited in sorted path order, so the result (and any ``[:n]``
    subset a caller takes) is reproducible; ``os.walk`` alone yields a
    filesystem-dependent order.

    Args:
        directory: Root directory, searched recursively.
        max_files: If given, only the first ``max_files`` paths (in sorted
            order) are decoded. This avoids loading every file just to keep a
            small subset.

    Returns:
        list of ``(file_path, audio, sr)`` tuples.
    """
    logger.info(f"Loading audio files from {directory}...")
    wav_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(directory)
        for file in files
        if file.endswith('.wav')
    )
    if max_files is not None:
        wav_paths = wav_paths[:max_files]

    audio_data = []
    for file_path in wav_paths:
        y, sr = librosa.load(file_path, sr=16000)  # Ensuring consistent sampling rate
        audio_data.append((file_path, y, sr))
    logger.info(f"Loaded {len(audio_data)} audio files from {directory}.")
    return audio_data

# CTC loss helpers shared by fine-tuning, training and validation
def _blank_token_id(model):
    """Return the CTC blank id for ``model``: its ``config.pad_token_id``, or 0 if unset.

    Hugging Face's ``Wav2Vec2ForCTC`` uses the pad token as the CTC blank.
    """
    pad_id = getattr(model.config, 'pad_token_id', None)
    return 0 if pad_id is None else pad_id


def compute_ctc_loss(logits, label_ids, blank=0, zero_infinity=True):
    """CTC loss of a batch of wav2vec2 logits against -100-padded label ids.

    ``torch.nn.CTCLoss`` expects log-probabilities, so the raw logits go
    through ``log_softmax`` first. Passing raw logits, as before, gives a
    meaningless loss.

    Args:
        logits: Tensor of shape (batch, frames, vocab), as returned by the model.
        label_ids: Tensor of shape (batch, max_label_len), padded with -100.
        blank: Index of the CTC blank token.
        zero_infinity: Replace infinite losses (e.g. a label longer than the
            number of frames) with zero instead of propagating inf/NaN gradients.

    Returns:
        Scalar loss tensor (mean reduction).
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
    # NOTE(review): every sample is treated as spanning all frames, including
    # frames computed from zero padding. collate_fn does not return the
    # original audio lengths needed for exact per-sample input lengths.
    input_lengths = torch.full((logits.shape[0],), logits.shape[1], dtype=torch.long, device=logits.device)
    label_lengths = torch.sum(label_ids != -100, dim=1)
    # Positions past each label length are never read; replacing -100 with the
    # blank id just keeps negative class ids out of the kernel.
    targets = label_ids.masked_fill(label_ids == -100, blank)
    return torch.nn.functional.ctc_loss(
        log_probs, targets, input_lengths, label_lengths,
        blank=blank, reduction='mean', zero_infinity=zero_infinity,
    )


# Fine-tuning function
def fine_tune_model(model, processor, words_loader, num_epochs=10, lr=1e-5):
    """Fine-tune ``model`` on the isolated-words loader, one optimizer step per batch.

    Args:
        model: wav2vec2 CTC model already on ``device``.
        processor: Unused; kept for interface compatibility.
        words_loader: DataLoader yielding ``(input_values, label_ids)`` batches.
        num_epochs: Number of passes over ``words_loader``.
        lr: AdamW learning rate.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    blank = _blank_token_id(model)

    for epoch in range(num_epochs):
        epoch_loss = 0
        logger.info(f"Starting fine-tuning epoch {epoch + 1}/{num_epochs}...")
        with tqdm(total=len(words_loader), desc=f"Fine-tuning Epoch {epoch + 1}") as pbar:
            for input_values, label_ids in words_loader:
                input_values, label_ids = input_values.to(device), label_ids.to(device)

                optimizer.zero_grad()
                logits = model(input_values).logits
                loss = compute_ctc_loss(logits, label_ids, blank=blank)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                pbar.update(1)

        logger.info(f"Fine-tuning epoch {epoch + 1} completed. Loss: {epoch_loss / max(len(words_loader), 1):.4f}")


# Training function with validation and model checkpointing
def train_model(model, processor, train_loader, val_loader, num_epochs=15, lr=1e-5, accumulation_steps=4,
                output_dir="fine_tuned_wav2vec2"):
    """Train with gradient accumulation, validate each epoch and keep the best checkpoint.

    Args:
        model: wav2vec2 CTC model already on ``device``.
        processor: Saved alongside the best model.
        train_loader: DataLoader of ``(input_values, label_ids)`` training batches.
        val_loader: DataLoader of validation batches (see ``validate_model``).
        num_epochs: Number of training epochs.
        lr: AdamW learning rate.
        accumulation_steps: Batches whose gradients are averaged per optimizer step.
        output_dir: Directory the best model and processor are saved to.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_val_loss = float('inf')
    blank = _blank_token_id(model)
    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        epoch_loss = 0
        logger.info(f"Starting epoch {epoch + 1}/{num_epochs}...")
        with tqdm(total=num_batches, desc=f"Training Epoch {epoch + 1}") as pbar:
            optimizer.zero_grad()
            for i, (input_values, label_ids) in enumerate(train_loader):
                input_values, label_ids = input_values.to(device), label_ids.to(device)

                logits = model(input_values).logits

                # Debugging output
                logger.debug(f"logits: {logits.shape}")
                logger.debug(f"label_ids: {label_ids}")

                loss = compute_ctc_loss(logits, label_ids, blank=blank)
                # Divide by the size of the current accumulation window, so
                # each optimizer step uses the mean gradient of its window
                # (the last window of an epoch may be shorter).
                window_start = (i // accumulation_steps) * accumulation_steps
                window_size = min(accumulation_steps, num_batches - window_start)
                (loss / window_size).backward()

                # Also step on the final batch, so a trailing partial window is
                # not thrown away by the zero_grad() at the next epoch's start.
                if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                    optimizer.step()
                    optimizer.zero_grad()

                epoch_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                pbar.update(1)

        val_loss = validate_model(model, val_loader)
        logger.info(f"Epoch {epoch + 1} completed. Training Loss: {epoch_loss / max(num_batches, 1):.4f}, Validation Loss: {val_loss:.4f}")

        # Check if this is the best model so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logger.info("Saving the new best model...")
            model.save_pretrained(output_dir)
            processor.save_pretrained(output_dir)


# Validation function
def validate_model(model, val_loader):
    """Return the mean per-batch CTC loss over ``val_loader``.

    The model is evaluated in eval mode without gradients, then switched back
    to train mode. An empty loader yields ``inf``, so it never counts as an
    improvement.
    """
    model.eval()
    val_loss = 0
    blank = _blank_token_id(model)
    logger.info("Starting validation...")
    with tqdm(total=len(val_loader), desc="Validation") as pbar:
        with torch.no_grad():
            for input_values, label_ids in val_loader:
                input_values, label_ids = input_values.to(device), label_ids.to(device)

                logits = model(input_values).logits

                # Debugging output
                logger.debug(f"Validation logits: {logits.shape}")
                logger.debug(f"Validation label_ids: {label_ids}")

                loss = compute_ctc_loss(logits, label_ids, blank=blank)
                val_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})
                pbar.update(1)
    model.train()
    if len(val_loader) == 0:
        logger.warning("Validation loader is empty; reporting infinite validation loss.")
        return float('inf')
    return val_loss / len(val_loader)

# Inference function with timestamps
def infer_with_timestamps(model, processor, audio_file):
    """Transcribe an audio file with greedy CTC decoding and per-word timestamps.

    Args:
        model: wav2vec2 CTC model already on ``device``.
        processor: The matching ``Wav2Vec2Processor``.
        audio_file: Path of the audio file; it is resampled to 16 kHz.

    Returns:
        tuple: ``(transcription, word_timestamps)``, where ``word_timestamps``
        is a list of ``(word, start_seconds, end_seconds)``.
    """
    y, sr = librosa.load(audio_file, sr=16000)  # Ensuring consistent sampling rate
    inputs = processor(y, return_tensors="pt", padding="longest", sampling_rate=sr).to(device)

    # Dropout must be off for inference; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(inputs.input_values).logits
    finally:
        model.train(was_training)

    predicted_ids = torch.argmax(logits, dim=-1)[0].cpu()

    # The CTC tokenizer collapses repeated frames, drops blank/pad frames and
    # splits on the word delimiter, reporting each word's start/end frame
    # index. The old per-frame loop split words at every blank frame and did
    # not collapse repeats.
    decoded = processor.tokenizer.decode(predicted_ids, output_word_offsets=True)

    # Seconds of audio per logit frame.
    frame_duration = model.config.inputs_to_logits_ratio / sr
    word_timestamps = [
        (offset["word"], offset["start_offset"] * frame_duration, offset["end_offset"] * frame_duration)
        for offset in decoded.word_offsets
    ]

    return decoded.text, word_timestamps

# Main execution
if __name__ == "__main__":
    # Seed the RNGs used by augmentation, DataLoader shuffling and the split below.
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    #data_dir = '/content/dataset'
    scenes_path = f'{data_dir}/scenes/wav'
    words_path = f'{data_dir}/words'

    # Only the first 100 files of each set are used, to keep the run short.
    scenes_audio = load_audio_files(scenes_path)[:100]
    words_audio = load_audio_files(words_path)[:100]

    processor, model = load_model_and_tokenizer()

    words_dataset = AudioDataset(words_audio, processor, augment=True)
    words_loader = DataLoader(words_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

    # Fine-tune on words dataset
    fine_tune_model(model, processor, words_loader, num_epochs=10, lr=1e-5)

    # Shuffle before splitting, so validation is not just the first files in directory order.
    shuffled_idx = np.random.permutation(len(scenes_audio))
    scenes_audio = [scenes_audio[i] for i in shuffled_idx]
    val_split = int(len(scenes_audio) * 0.2)
    train_dataset = AudioDataset(scenes_audio[val_split:], processor, augment=True)
    val_dataset = AudioDataset(scenes_audio[:val_split], processor)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

    # Train on scenes dataset
    train_model(model, processor, train_loader, val_loader, num_epochs=15, lr=1e-5)

    unseen_audio_file = f'{scenes_path}/98_speech_true_Alarm_aus_Lüftung_aus_Heizung_aus.wav'
    if not os.path.exists(unseen_audio_file):
        logger.warning(f"Inference file not found: {unseen_audio_file}")
    else:
        if any(path == unseen_audio_file for path, _, _ in train_dataset.audio_files):
            logger.warning("Inference file was part of the training subset; the result is not a held-out evaluation.")
        transcription, word_timestamps = infer_with_timestamps(model, processor, unseen_audio_file)

        logger.info("Transcription: " + transcription)
        logger.info("Word Timestamps:")
        for word, start, end in word_timestamps:
            logger.info(f"Word: {word}, Start: {start:.2f}s, End: {end:.2f}s")


2024-06-14 10:28:58,097 - INFO - Using device: cuda
INFO:__main__:Using device: cuda
2024-06-14 10:28:58,105 - INFO - Loading audio files from /content/dataset/scenes/wav...
INFO:__main__:Loading audio files from /content/dataset/scenes/wav...
2024-06-14 10:29:01,119 - INFO - Loaded 814 audio files from /content/dataset/scenes/wav.
INFO:__main__:Loaded 814 audio files from /content/dataset/scenes/wav.
2024-06-14 10:29:01,169 - INFO - Loading audio files from /content/dataset/words...
INFO:__main__:Loading audio files from /content/dataset/words...
2024-06-14 10:29:15,344 - INFO - Loaded 45296 audio files from /content/dataset/words.
INFO:__main__:Loaded 45296 audio files from /content/dataset/words.
2024-06-14 10:29:15,368 - INFO - Loading model and tokenizer...
INFO:__main__:Loading model and tokenizer...
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), s