In [3]:
import os
import shutil

source_folder = '/kaggle/input/environment-sound/Environment_Sound'
destination_folder = '/kaggle/working/dataset'

try:
    # Create the destination folder if it doesn't exist
    os.makedirs(destination_folder, exist_ok=True)
    print(f"Destination folder created or already exists: {destination_folder}")

    # Iterate through the subfolders in the source folder
    for item_name in os.listdir(source_folder):
        item_path = os.path.join(source_folder, item_name)
        destination_item_path = os.path.join(destination_folder, item_name)

        # If it's a directory, copy it recursively
        if os.path.isdir(item_path):
            shutil.copytree(item_path, destination_item_path)
            print(f"Copied folder: {item_name}")
        # If it's a file, copy it
        elif os.path.isfile(item_path):
            shutil.copy2(item_path, destination_item_path)  # copy2 preserves metadata
            print(f"Copied file: {item_name}")

    print("All folders and files copied successfully!")

except FileNotFoundError:
    print(f"Error: Source folder not found at {source_folder}")
except Exception as e:
    print(f"An error occurred: {e}")

Destination folder created or already exists: /kaggle/working/dataset
Copied folder: Bus
Copied folder: Park
Copied folder: Metro
Copied folder: University
Copied folder: Shopping_Mall
Copied folder: Restaurant
Copied folder: Metro_Station
All folders and files copied successfully!


In [4]:
import os
import numpy as np
import librosa
import soundfile as sf
from tqdm import tqdm
import random
import time
from scipy import signal
from google.colab import drive
import glob

# Paths
chunks_dir = '/kaggle/working/dataset'  # Input directory with 10-second chunks

# Set target files per class
TARGET_FILES_PER_CLASS = 1500

# Create output directory
os.makedirs(chunks_dir, exist_ok=True)

# Set random seed for reproducibility
np.random.seed(42)
random.seed(42)

# Augmentation functions
def add_white_noise(audio, noise_factor=0.005):
    """Add white noise to the audio signal."""
    noise = np.random.randn(len(audio))
    augmented_audio = audio + noise_factor * noise
    return augmented_audio

def time_stretch(audio, rate=1.0):
    """Time stretch the audio signal without changing the pitch."""
    return librosa.effects.time_stretch(audio, rate=rate)

def pitch_shift(audio, sr, n_steps=2):
    """Shift the pitch of the audio signal."""
    return librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)

def change_volume(audio, factor=1.0):
    """Change the volume of the audio signal."""
    return audio * factor

def add_background_noise(audio, noise_audio, noise_factor=0.2):
    """Mix background noise with audio."""
    # Make sure the noise is the same length as the audio
    if len(noise_audio) < len(audio):
        # If noise is shorter, repeat it
        repeats = int(np.ceil(len(audio) / len(noise_audio)))
        noise_audio = np.tile(noise_audio, repeats)

    # Trim to match the length of audio
    noise_audio = noise_audio[:len(audio)]

    # Mix with the given factor
    augmented_audio = audio + noise_factor * noise_audio

    # Normalize to prevent clipping
    if np.max(np.abs(augmented_audio)) > 1.0:
        augmented_audio = augmented_audio / np.max(np.abs(augmented_audio))

    return augmented_audio

def apply_filter(audio, sr, filter_type='lowpass', cutoff=1000):
    """Apply a filter to the audio signal."""
    nyquist = sr // 2
    normalized_cutoff = cutoff / nyquist

    if filter_type == 'lowpass':
        b, a = signal.butter(5, normalized_cutoff, btype='low')
    elif filter_type == 'highpass':
        b, a = signal.butter(5, normalized_cutoff, btype='high')
    elif filter_type == 'bandpass':
        b, a = signal.butter(5, [normalized_cutoff * 0.5, normalized_cutoff], btype='band')
    else:
        return audio

    return signal.filtfilt(b, a, audio)

def speed_up(audio, factor=1.2):
    """Speed up the audio by resampling."""
    indices = np.round(np.arange(0, len(audio), factor)).astype(int)
    indices = indices[indices < len(audio)]
    return audio[indices]

def slow_down(audio, factor=0.8):
    """Slow down the audio by resampling."""
    return speed_up(audio, 1/factor)

def time_shift(audio, shift_factor=0.2):
    """Shift the audio in time."""
    shift = int(len(audio) * shift_factor)
    return np.roll(audio, shift)

def reverb_effect(audio, sr, decay=2.0):
    """Add a simple reverb effect by convolving with an exponentially decaying filter."""
    reverb_length = int(sr * decay)
    impulse_response = np.exp(-np.arange(reverb_length) / sr * 3)

    # Normalize impulse response
    impulse_response = impulse_response / np.sum(impulse_response)

    # Apply convolution
    reverb_audio = signal.convolve(audio, impulse_response, mode='full')

    # Trim to original length
    reverb_audio = reverb_audio[:len(audio)]

    return reverb_audio

# Function to randomly collect background noises for each category
def collect_background_noises(chunks_dir, num_per_category=3):
    """Collect a set of random audio samples to use as background noise."""
    background_noises = {}
    categories = [d for d in os.listdir(chunks_dir)
                  if os.path.isdir(os.path.join(chunks_dir, d))]

    for category in categories:
        category_path = os.path.join(chunks_dir, category)
        audio_files = [f for f in os.listdir(category_path) if f.endswith('.wav')]

        # Select random files for background noise
        if len(audio_files) > num_per_category:
            selected_files = random.sample(audio_files, num_per_category)
        else:
            selected_files = audio_files

        # Load audio files
        noises = []
        for file in selected_files:
            audio_path = os.path.join(category_path, file)
            try:
                audio, sr = librosa.load(audio_path, sr=None)
                noises.append((audio, sr))
            except Exception as e:
                print(f"Error loading {audio_path}: {e}")

        background_noises[category] = noises

    return background_noises

# Function to apply random augmentations to an audio file
def augment_audio(audio, sr, category, background_noises):
    """Apply multiple random augmentations to an audio file."""
    # Randomly select the number of augmentations to apply (1-4)
    num_augmentations = random.randint(1, 4)

    # List of available augmentation methods
    augmentation_methods = [
        "white_noise", "time_stretch", "pitch_shift", "volume_change",
        "background_noise", "filter", "speed", "time_shift", "reverb"
    ]

    # Randomly select which augmentations to apply
    selected_augmentations = random.sample(augmentation_methods, num_augmentations)

    # Apply the selected augmentations
    augmented_audio = audio.copy()
    augmentation_desc = []

    for aug_method in selected_augmentations:
        if aug_method == "white_noise":
            # Add white noise
            noise_factor = random.choice([0.003, 0.005, 0.01])
            augmented_audio = add_white_noise(augmented_audio, noise_factor=noise_factor)
            augmentation_desc.append(f"noise_{noise_factor:.3f}")

        elif aug_method == "time_stretch":
            # Time stretching
            stretch_rate = random.choice([0.85, 0.9, 1.1, 1.15])
            try:
                augmented_audio = time_stretch(augmented_audio, rate=stretch_rate)
                augmentation_desc.append(f"stretch_{stretch_rate:.2f}")
            except:
                # Skip if time stretching fails
                pass

        elif aug_method == "pitch_shift":
            # Pitch shifting
            n_steps = random.choice([-2, -1, 1, 2])
            augmented_audio = pitch_shift(augmented_audio, sr=sr, n_steps=n_steps)
            augmentation_desc.append(f"pitch_{n_steps}")

        elif aug_method == "volume_change":
            # Volume change
            vol_factor = random.choice([0.75, 1.25, 1.5])
            augmented_audio = change_volume(augmented_audio, factor=vol_factor)
            augmentation_desc.append(f"vol_{vol_factor:.2f}")

        elif aug_method == "background_noise":
            # Background noise addition (from different category)
            other_categories = [c for c in background_noises.keys() if c != category]
            if other_categories:
                # Select random category and noise
                noise_category = random.choice(other_categories)
                if background_noises[noise_category]:
                    noise_audio, noise_sr = random.choice(background_noises[noise_category])

                    # Apply background noise
                    noise_factor = random.choice([0.1, 0.15, 0.2])
                    augmented_audio = add_background_noise(augmented_audio, noise_audio, noise_factor=noise_factor)
                    augmentation_desc.append(f"bg_{noise_category}_{noise_factor:.2f}")

        elif aug_method == "filter":
            # Apply filters
            filter_type = random.choice(['lowpass', 'highpass'])
            cutoff = random.choice([1000, 2000, 3000])
            augmented_audio = apply_filter(augmented_audio, sr, filter_type=filter_type, cutoff=cutoff)
            augmentation_desc.append(f"filter_{filter_type[:2]}_{cutoff}")

        elif aug_method == "speed":
            # Speed modification
            speed_factor = random.choice([0.9, 1.1])
            if speed_factor > 1.0:
                augmented_audio = speed_up(augmented_audio, factor=speed_factor)
                augmentation_desc.append(f"speedup_{speed_factor:.2f}")
            else:
                augmented_audio = slow_down(augmented_audio, factor=speed_factor)
                augmentation_desc.append(f"slowdown_{speed_factor:.2f}")

        elif aug_method == "time_shift":
            # Time shifting
            shift_factor = random.choice([0.1, 0.2, 0.3])
            augmented_audio = time_shift(augmented_audio, shift_factor=shift_factor)
            augmentation_desc.append(f"shift_{shift_factor:.1f}")

        elif aug_method == "reverb":
            # Reverb effect
            decay = random.choice([1.0, 2.0, 3.0])
            augmented_audio = reverb_effect(augmented_audio, sr, decay=decay)
            augmentation_desc.append(f"reverb_{decay:.1f}")

    # Create a description of all applied augmentations
    aug_description = "_".join(augmentation_desc)

    return augmented_audio, aug_description

# Main function to augment the dataset to reach target file count per class
def augment_dataset_to_target(chunks_dir, target_files_per_class=1500):
    """Augment audio files in each category until reaching the target count."""
    print(f"Starting audio augmentation to reach {target_files_per_class} files per class...")

    # Collect background noise samples
    print("Collecting background noises...")
    background_noises = collect_background_noises(chunks_dir)

    # Get all categories
    categories = [d for d in os.listdir(chunks_dir) if os.path.isdir(os.path.join(chunks_dir, d))]

    print(f"Found {len(categories)} categories: {', '.join(categories)}")

    # Start timing
    start_time = time.time()

    # Process each category
    for category in categories:
        category_path = os.path.join(chunks_dir, category)

        # Get existing files
        existing_files = [f for f in os.listdir(category_path) if f.endswith('.wav')]
        existing_count = len(existing_files)

        print(f"\nCategory: {category}")
        print(f"Existing files: {existing_count}")

        if existing_count >= target_files_per_class:
            print(f"Category {category} already has {existing_count} files. No augmentation needed.")
            continue

        # Calculate how many more files we need
        files_needed = target_files_per_class - existing_count
        print(f"Need to generate {files_needed} new files")

        # Track progress
        with tqdm(total=files_needed, desc=f"Augmenting {category}") as pbar:
            files_created = 0

            while files_created < files_needed:
                # Randomly select a file to augment
                source_file = random.choice(existing_files)
                source_path = os.path.join(category_path, source_file)

                try:
                    # Load audio file
                    audio, sr = librosa.load(source_path, sr=None)

                    # Apply random augmentations
                    augmented_audio, aug_description = augment_audio(audio, sr, category, background_noises)

                    # Create output filename
                    base_name = os.path.splitext(source_file)[0]
                    output_filename = f"{base_name}_aug_{files_created+1}_{aug_description}.wav"
                    output_path = os.path.join(category_path, output_filename)

                    # Save augmented audio
                    sf.write(output_path, augmented_audio, sr)

                    files_created += 1
                    pbar.update(1)

                except Exception as e:
                    print(f"Error processing {source_path}: {e}")
                    continue

        # Verify final count
        final_files = [f for f in os.listdir(category_path) if f.endswith('.wav')]
        final_count = len(final_files)
        print(f"Final file count for {category}: {final_count}")

    # Report total time
    total_time = time.time() - start_time
    print(f"\nAugmentation complete! Total time: {total_time:.1f} seconds")

# Run the augmentation
if __name__ == "__main__":
    augment_dataset_to_target(chunks_dir, TARGET_FILES_PER_CLASS)
    print("Audio dataset augmentation complete! Each class now has approximately 1500 files.")

Starting audio augmentation to reach 1500 files per class...
Collecting background noises...
Found 7 categories: Metro_Station, University, Metro, Restaurant, Park, Shopping_Mall, Bus

Category: Metro_Station
Existing files: 629
Need to generate 871 new files


Augmenting Metro_Station: 100%|██████████| 871/871 [01:52<00:00,  7.76it/s]


Final file count for Metro_Station: 1500

Category: University
Existing files: 724
Need to generate 776 new files


Augmenting University: 100%|██████████| 776/776 [01:26<00:00,  9.00it/s]


Final file count for University: 1500

Category: Metro
Existing files: 576
Need to generate 924 new files


Augmenting Metro: 100%|██████████| 924/924 [01:47<00:00,  8.59it/s]


Final file count for Metro: 1500

Category: Restaurant
Existing files: 988
Need to generate 512 new files


Augmenting Restaurant: 100%|██████████| 512/512 [01:01<00:00,  8.36it/s]


Final file count for Restaurant: 1500

Category: Park
Existing files: 496
Need to generate 1004 new files


Augmenting Park: 100%|██████████| 1004/1004 [02:00<00:00,  8.32it/s]


Final file count for Park: 1500

Category: Shopping_Mall
Existing files: 690
Need to generate 810 new files


Augmenting Shopping_Mall: 100%|██████████| 810/810 [01:35<00:00,  8.51it/s]


Final file count for Shopping_Mall: 1500

Category: Bus
Existing files: 1432
Need to generate 68 new files


Augmenting Bus: 100%|██████████| 68/68 [00:09<00:00,  7.44it/s]

Final file count for Bus: 1500

Augmentation complete! Total time: 592.3 seconds
Audio dataset augmentation complete! Each class now has approximately 1500 files.





In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torchaudio
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import LabelBinarizer
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define paths - update these based on your Kaggle environment
DATA_DIR = "/kaggle/working/dataset"  # Update this path
OUTPUT_DIR = "/kaggle/working/wav2vec_finetuned"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Create a function to load audio files and their labels
def load_dataset(data_dir):
    data = []
    
    # Each folder in the data directory is a class
    for class_name in os.listdir(data_dir):
        class_dir = os.path.join(data_dir, class_name)
        
        if os.path.isdir(class_dir):
            # Iterate through all audio files in the class directory
            for file_name in os.listdir(class_dir):
                if file_name.endswith(('.wav', '.mp3', '.flac')):  # Add other formats if needed
                    file_path = os.path.join(class_dir, file_name)
                    data.append((file_path, class_name))
    
    return data

# Custom Dataset class for audio classification
class AudioClassificationDataset(Dataset):
    def __init__(self, data, feature_extractor, max_length=16000*5):  # 5 seconds at 16kHz
        self.data = data
        self.feature_extractor = feature_extractor
        self.max_length = max_length
        
        # Create mapping from class names to indices
        self.class_names = sorted(list(set([label for _, label in data])))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        audio_path, class_name = self.data[idx]
        
        # Load audio file
        waveform, sample_rate = torchaudio.load(audio_path)
        
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        # Resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        
        # Ensure the waveform is the right shape (1, n)
        waveform = waveform.squeeze().numpy()
        
        # Truncate or pad to max_length
        if len(waveform) > self.max_length:
            waveform = waveform[:self.max_length]
        else:
            padding = np.zeros(self.max_length - len(waveform))
            waveform = np.concatenate((waveform, padding))
        
        # Get features - more recent versions might not return attention_mask
        inputs = self.feature_extractor(
            waveform, 
            sampling_rate=16000, 
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length
        )
        
        # Create attention mask if not provided by feature extractor
        if not hasattr(inputs, 'attention_mask'):
            # Create attention mask based on non-zero values in input_values
            # Assuming input_values are padded with zeros
            attention_mask = (inputs.input_values != 0).float()
        else:
            attention_mask = inputs.attention_mask
        
        # Get label
        label = self.class_to_idx[class_name]
        
        return {
            "input_values": inputs.input_values.squeeze(),
            "attention_mask": attention_mask.squeeze(),
            "labels": torch.tensor(label)
        }

def train_model(model, train_loader, val_loader, num_epochs, optimizer, scheduler, device):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Initialize tqdm progress bar for training
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for batch in train_pbar:
            input_values = batch["input_values"].to(device)
            attention_mask = batch["attention_mask"].to(device) if "attention_mask" in batch else None
            labels = batch["labels"].to(device)
            
            optimizer.zero_grad()
            
            # Forward pass with or without attention mask
            if attention_mask is not None:
                outputs = model(input_values=input_values, attention_mask=attention_mask, labels=labels)
            else:
                outputs = model(input_values=input_values, labels=labels)
                
            loss = outputs.loss
            logits = outputs.logits
            
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            running_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Update progress bar
            current_loss = loss.item()
            current_acc = (predicted == labels).sum().item() / labels.size(0)
            train_pbar.set_postfix({'loss': f'{current_loss:.4f}', 'acc': f'{current_acc:.4f}'})
            
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        
        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_correct = 0
        val_total = 0
        
        # Initialize tqdm progress bar for validation
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Valid]")
        with torch.no_grad():
            for batch in val_pbar:
                input_values = batch["input_values"].to(device)
                attention_mask = batch["attention_mask"].to(device) if "attention_mask" in batch else None
                labels = batch["labels"].to(device)
                
                # Forward pass with or without attention mask
                if attention_mask is not None:
                    outputs = model(input_values=input_values, attention_mask=attention_mask, labels=labels)
                else:
                    outputs = model(input_values=input_values, labels=labels)
                    
                loss = outputs.loss
                logits = outputs.logits
                
                val_running_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(logits, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                # Update progress bar
                current_loss = loss.item()
                current_acc = (predicted == labels).sum().item() / labels.size(0)
                val_pbar.set_postfix({'loss': f'{current_loss:.4f}', 'acc': f'{current_acc:.4f}'})
        
        val_epoch_loss = val_running_loss / len(val_loader)
        val_epoch_acc = val_correct / val_total
        val_losses.append(val_epoch_loss)
        val_accuracies.append(val_epoch_acc)
        
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, "
              f"Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}")
        
        # Save memory
        torch.cuda.empty_cache()
        gc.collect()
    
    return train_losses, val_losses, train_accuracies, val_accuracies

def evaluate_model(model, test_loader, class_names, device):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    
    # Initialize tqdm progress bar for evaluation
    test_pbar = tqdm(test_loader, desc="Evaluating")
    
    with torch.no_grad():
        for batch in test_pbar:
            input_values = batch["input_values"].to(device)
            attention_mask = batch["attention_mask"].to(device) if "attention_mask" in batch else None
            labels = batch["labels"].to(device)
            
            # Forward pass with or without attention mask
            if attention_mask is not None:
                outputs = model(input_values=input_values, attention_mask=attention_mask)
            else:
                outputs = model(input_values=input_values)
                
            logits = outputs.logits
            
            # Get predictions
            probs = torch.nn.functional.softmax(logits, dim=1)
            _, preds = torch.max(logits, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            
            # Update progress bar with current batch accuracy
            current_acc = (preds == labels).sum().item() / labels.size(0)
            test_pbar.set_postfix({'acc': f'{current_acc:.4f}'})
    
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    
    # Print classification report
    report = classification_report(all_labels, all_preds, target_names=class_names, zero_division=0)
    print("\nClassification Report:")
    print(report)
    
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    cm = confusion_matrix(all_labels, all_preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()
    
    # Calculate ROC curve and AUC for each class (one-vs-rest)
    lb = LabelBinarizer()
    lb.fit(range(len(class_names)))
    all_labels_bin = lb.transform(all_labels)
    all_probs = np.array(all_probs)
    
    plt.figure(figsize=(10, 8))
    
    for i, class_name in enumerate(class_names):
        fpr, tpr, _ = roc_curve(all_labels_bin[:, i], all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'roc_curve.png'))
    plt.close()
    
    return accuracy, precision, recall, f1, all_preds, all_labels, all_probs

def plot_learning_curves(train_losses, val_losses, train_accs, val_accs):
    # Plot training and validation loss
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    
    # Plot training and validation accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Training Accuracy')
    plt.plot(val_accs, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'learning_curves.png'))
    plt.close()

def main():
    # Step 1: Load dataset
    print("Loading dataset...")
    data = load_dataset(DATA_DIR)
    
    # Print dataset statistics
    class_names = sorted(list(set([label for _, label in data])))
    print(f"Found {len(data)} audio files across {len(class_names)} classes:")
    for class_name in class_names:
        count = sum(1 for _, label in data if label == class_name)
        print(f"  - {class_name}: {count} files")
    
    # Step 2: Split into train, validation and test sets (80%, 10%, 10%)
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42, stratify=[label for _, label in data])
    train_data, val_data = train_test_split(train_data, test_size=0.125, random_state=42, stratify=[label for _, label in train_data])  # 0.125 of 80% = 10% of total
    
    print(f"Train set: {len(train_data)} samples")
    print(f"Validation set: {len(val_data)} samples")
    print(f"Test set: {len(test_data)} samples")
    
    # Step 3: Initialize model and feature extractor
    print("Initializing model and feature extractor...")
    model_name = "facebook/wav2vec2-base"  # You can try other variants if needed
    
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    
    # Create a model for sequence classification
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(class_names),
        attention_dropout=0.1
    )
    
    # Move model to device
    model = model.to(device)
    
    # Step 4: Create datasets and dataloaders
    print("Creating datasets and dataloaders...")
    
    # Show progress with tqdm
    print("Preparing train dataset...")
    train_dataset = AudioClassificationDataset(train_data, feature_extractor)
    
    print("Preparing validation dataset...")
    val_dataset = AudioClassificationDataset(val_data, feature_extractor)
    
    print("Preparing test dataset...")
    test_dataset = AudioClassificationDataset(test_data, feature_extractor)
    
    # Smaller batch size to avoid out of memory issues on Kaggle
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8)
    test_loader = DataLoader(test_dataset, batch_size=8)
    
    # Step 5: Set up training parameters
    num_epochs = 20  # Start with a small number for testing
    
    # Optimizer with weight decay
    optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    
    # Learning rate scheduler
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    
    # Step 6: Train the model
    print("Training model...")
    train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, num_epochs, optimizer, scheduler, device
    )
    
    # Step 7: Save the model
    model.save_pretrained(OUTPUT_DIR)
    feature_extractor.save_pretrained(OUTPUT_DIR)
    print(f"Model saved to {OUTPUT_DIR}")
    
    # Step 8: Plot learning curves
    print("Plotting learning curves...")
    plot_learning_curves(train_losses, val_losses, train_accs, val_accs)
    
    # Step 9: Evaluate on test set
    print("Evaluating model on test set...")
    accuracy, precision, recall, f1, all_preds, all_labels, all_probs = evaluate_model(
        model, test_loader, train_dataset.class_names, device
    )
    
    # Save results to CSV
    results_df = pd.DataFrame({
        'Metric': ['Accuracy', 'Precision', 'Recall', 'F1 Score'],
        'Value': [accuracy, precision, recall, f1]
    })
    results_df.to_csv(os.path.join(OUTPUT_DIR, 'metrics.csv'), index=False)
    
    print("\nEvaluation complete. Results saved to", OUTPUT_DIR)

if __name__ == "__main__":
    main()

Using device: cuda
Loading dataset...
Found 10500 audio files across 7 classes:
  - Bus: 1500 files
  - Metro: 1500 files
  - Metro_Station: 1500 files
  - Park: 1500 files
  - Restaurant: 1500 files
  - Shopping_Mall: 1500 files
  - University: 1500 files
Train set: 7350 samples
Validation set: 1050 samples
Test set: 2100 samples
Initializing model and feature extractor...


Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Creating datasets and dataloaders...
Preparing train dataset...
Preparing validation dataset...
Preparing test dataset...
Training model...


Epoch 1/20 [Train]:   0%|          | 0/919 [00:00<?, ?it/s]