In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import soundfile as sf
import io
import ast
import os
import librosa
import json
from collections import Counter
import random
import time
import tensorflow as tf
import tensorflow_hub as hub

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SAMPLE_RATE = 16000
MAX_AUDIO_LENGTH = 6 * SAMPLE_RATE

class YAMNetForAudioClassification(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        # We'll load YAMNet embedding model separately
        # Only need to create the classifier layers here
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),  # YAMNet outputs 1024-dim embeddings
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_labels)
        )
        self.num_labels = num_labels
        
    def forward(self, embeddings, labels=None):
        logits = self.classifier(embeddings)
        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            
        return {"loss": loss, "logits": logits} if loss is not None else logits


# Load YAMNet model
def load_yamnet():
    print("Loading YAMNet model...")
    yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')
    return yamnet_model


def get_yamnet_embeddings(audio_data):
    """Extract embeddings from YAMNet model"""
    scores, embeddings, _ = yamnet_model(audio_data)
    # Average the embeddings across time
    embedding = tf.reduce_mean(embeddings, axis=0).numpy()
    return embedding


def augment_audio(audio, augmentation_type="all"):
    """Apply audio augmentation techniques"""
    if augmentation_type == "none" or np.random.random() < 0.1:
        return audio
    
    # Make a copy to avoid modifying original
    augmented = np.copy(audio)
    
    if augmentation_type == "all" or augmentation_type == "noise":
        # Add random noise
        if np.random.random() < 0.5:
            noise_level = np.random.uniform(0.001, 0.01)
            noise = np.random.normal(0, noise_level, len(augmented))
            augmented = augmented + noise
    
    if augmentation_type == "all" or augmentation_type == "pitch":
        # Pitch shift
        if np.random.random() < 0.5:
            n_steps = np.random.uniform(-3, 3)
            augmented = librosa.effects.pitch_shift(augmented, sr=SAMPLE_RATE, n_steps=n_steps)
    
    if augmentation_type == "all" or augmentation_type == "speed":
        # Time stretching
        if np.random.random() < 0.5:
            rate = np.random.uniform(0.8, 1.2)
            augmented = librosa.effects.time_stretch(augmented, rate=rate)
            
            # Ensure the length is still correct
            if len(augmented) > MAX_AUDIO_LENGTH:
                augmented = augmented[:MAX_AUDIO_LENGTH]
            elif len(augmented) < MAX_AUDIO_LENGTH:
                padding = MAX_AUDIO_LENGTH - len(augmented)
                augmented = np.pad(augmented, (0, padding), 'constant')
    
    if augmentation_type == "all" or augmentation_type == "shift":
        # Time shifting
        if np.random.random() < 0.5:
            shift_amount = int(np.random.uniform(-SAMPLE_RATE * 0.5, SAMPLE_RATE * 0.5))
            augmented = np.roll(augmented, shift_amount)
    
    # Normalize the audio again
    augmented = augmented / (np.max(np.abs(augmented)) + 1e-6)
    
    return augmented.astype(np.float32)


def process_audio(audio_bytes, max_length=MAX_AUDIO_LENGTH, apply_augmentation=False):
    try:
        audio_data, sr = sf.read(io.BytesIO(audio_bytes))
        
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        
        if sr != SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)
        
        if len(audio_data) > max_length:
            audio_data = audio_data[:max_length]
        else:
            padding = max_length - len(audio_data)
            audio_data = np.pad(audio_data, (0, padding), 'constant')
        
        audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-6)
        
        if apply_augmentation:
            audio_data = augment_audio(audio_data)
            
        # Get YAMNet embeddings
        embedding = get_yamnet_embeddings(audio_data)
        return embedding
    except Exception as e:
        print(f"Error processing audio: {e}")
        return np.zeros(1024, dtype=np.float32)  # YAMNet embedding dimension is 1024


def process_audio_from_file(file_path, max_length=MAX_AUDIO_LENGTH):
    try:
        audio_data, sr = sf.read(file_path)
        
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        
        if sr != SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)
        
        if len(audio_data) > max_length:
            audio_data = audio_data[:max_length]
        else:
            padding = max_length - len(audio_data)
            audio_data = np.pad(audio_data, (0, padding), 'constant')
        
        audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-6)
        
        # Get YAMNet embeddings
        embedding = get_yamnet_embeddings(audio_data)
        return embedding
    except Exception as e:
        print(f"Error processing audio file: {e}")
        return np.zeros(1024, dtype=np.float32)


def create_augmented_dataset(dataframe, augmentation_multiplier=5):
    """Create a dataset with augmented audio samples"""
    embeddings = []
    labels = []
    
    for idx, row in dataframe.iterrows():
        try:
            if isinstance(row["audio"], str):
                audio_dict = ast.literal_eval(row["audio"])
            else:
                audio_dict = row["audio"]
                
            audio_bytes = audio_dict['bytes']
            
            # Add the original sample
            processed_embedding = process_audio(audio_bytes)
            embeddings.append(processed_embedding)
            labels.append(row["label_encoded"])
            
            # Add augmented samples
            for _ in range(augmentation_multiplier - 1):
                augmented_embedding = process_audio(audio_bytes, apply_augmentation=True)
                embeddings.append(augmented_embedding)
                labels.append(row["label_encoded"])
                
        except Exception as e:
            print(f"Error creating dataset: {e}")
    
    return {
        "embeddings": np.array(embeddings, dtype=np.float32),
        "label": np.array(labels, dtype=np.int64)
    }


def compute_metrics(pred):
    logits = pred.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    pred_labels = np.argmax(logits, axis=1)
    
    accuracy = np.mean(pred_labels == pred.label_ids)
    
    # Calculate confusion matrix
    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(pred.label_ids, pred_labels)
    
    # Per-class metrics
    class_accuracies = {}
    for class_id in range(len(np.unique(pred.label_ids))):
        class_mask = pred.label_ids == class_id
        if np.sum(class_mask) > 0:
            class_acc = np.mean(pred_labels[class_mask] == class_id)
            class_accuracies[f"class_{class_id}_acc"] = class_acc
    
    return {"accuracy": accuracy, **class_accuracies}


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            "embeddings": self.embeddings[idx],
            "labels": self.labels[idx]
        }


def collate_fn(batch):
    embeddings = torch.tensor(np.stack([item["embeddings"] for item in batch]), dtype=torch.float32)
    labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
    
    return {
        "embeddings": embeddings,
        "labels": labels
    }


def train_model(model, train_dataset, val_dataset, num_epochs=5, batch_size=32, learning_rate=1e-4):
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        collate_fn=collate_fn
    )
    
    val_loader = torch.utils.data.DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False,
        collate_fn=collate_fn
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        total_steps=num_epochs * len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )
    
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0
    best_model_state = None
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch in train_loader:
            embeddings = batch["embeddings"].to(device)
            labels = batch["labels"].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(embeddings, labels)
            loss = outputs["loss"]
            logits = outputs["logits"]
            
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        
        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                embeddings = batch["embeddings"].to(device)
                labels = batch["labels"].to(device)
                
                outputs = model(embeddings, labels)
                loss = outputs["loss"]
                logits = outputs["logits"]
                
                val_loss += loss.item()
                
                _, predicted = torch.max(logits, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        val_loss = val_loss / len(val_loader)
        val_acc = val_correct / val_total
        
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = model.state_dict().copy()
    
    # Load best model
    model.load_state_dict(best_model_state)
    return model, best_val_acc


def predict_audio(audio_path, model, label_mapping):
    processed_embedding = process_audio_from_file(audio_path)
    input_values = torch.tensor(processed_embedding, dtype=torch.float32).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(input_values)
    
    if isinstance(outputs, dict):
        logits = outputs["logits"]
    else:
        logits = outputs
        
    predicted_class_idx = torch.argmax(logits, dim=1).item()
    predicted_class = label_mapping[predicted_class_idx]
    
    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    probability = probs[predicted_class_idx].item()
    
    all_probs = {label_mapping[i]: probs[i].item() for i in range(len(label_mapping))}
    
    return predicted_class, probability, all_probs


if __name__ == "__main__":
    # Set random seeds for reproducibility
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Load YAMNet model
    yamnet_model = load_yamnet()

    # Data loading
    esc50_path = 'data/train-00000-of-00001-cd782ca55710a2e6.parquet'
    esc50_df = pd.read_parquet("hf://datasets/mskov/ESC50/" + esc50_path)

    target_categories = [
        'dog', 'chirping_birds', 'crow', 'sheep', 'frog', 
        'cow', 'insects', 'hen', 'pig', 'cat', 'crickets'
    ]

    esc50_df = esc50_df[['audio', 'category']]
    esc50_df = esc50_df[esc50_df['category'].isin(target_categories)]

    human_df = pd.read_parquet("hf://datasets/pipecat-ai/human_5_all/data/train-00000-of-00001.parquet")
    human_df = human_df[human_df['endpoint_bool'] == True][['audio']].copy()
    human_df['category'] = 'human'

    # Downsample human class to prevent imbalance
    human_df = human_df.sample(100, random_state=seed)

    combined_df = pd.concat([esc50_df, human_df], ignore_index=True)

    print(f"Original class distribution:")
    print(combined_df['category'].value_counts())

    # Balance dataset
    category_counts = combined_df['category'].value_counts()
    min_samples = min(category_counts)

    balanced_df = pd.DataFrame()
    for category in combined_df['category'].unique():
        category_df = combined_df[combined_df['category'] == category]
        if len(category_df) > min_samples:
            category_df = category_df.sample(min_samples, random_state=seed)
        balanced_df = pd.concat([balanced_df, category_df], ignore_index=True)

    print(f"Balanced class distribution:")
    print(balanced_df['category'].value_counts())

    # Shuffle the dataset
    balanced_df = balanced_df.sample(frac=1, random_state=seed).reset_index(drop=True)

    # Encode labels
    label_encoder = LabelEncoder()
    balanced_df["label_encoded"] = label_encoder.fit_transform(balanced_df["category"])
    num_labels = len(label_encoder.classes_)

    print(f"Label mapping:")
    label_mapping = {i: label for i, label in enumerate(label_encoder.classes_)}
    for i, label in label_mapping.items():
        print(f"{i}: {label}")

    # Split into train/validation sets
    train_df, eval_df = train_test_split(balanced_df, test_size=0.2, stratify=balanced_df["category"], random_state=seed)

    print(f"Number of training samples: {len(train_df)}")
    print(f"Number of validation samples: {len(eval_df)}")

    # Create datasets with augmentation for training
    augmentation_multiplier = 10  # Create 10x more samples
    print(f"Applying {augmentation_multiplier}x augmentation to training data...")
    print("Extracting YAMNet embeddings for training data...")
    train_data = create_augmented_dataset(train_df, augmentation_multiplier=augmentation_multiplier)
    print("Extracting YAMNet embeddings for validation data...")
    eval_data = create_augmented_dataset(eval_df, augmentation_multiplier=1)  # No augmentation for eval

    print(f"Number of augmented training samples: {len(train_data['label'])}")
    print(f"Number of validation samples: {len(eval_data['label'])}")

    # Create custom datasets
    train_dataset = CustomDataset(train_data['embeddings'], train_data['label'])
    eval_dataset = CustomDataset(eval_data['embeddings'], eval_data['label'])

    # Create model
    model = YAMNetForAudioClassification(num_labels)
    model = model.to(device)

    # Train model
    print("Starting training...")
    model, best_val_acc = train_model(
        model,
        train_dataset,
        eval_dataset,
        num_epochs=5,
        batch_size=32,
        learning_rate=5e-5
    )
    print(f"Training completed! Best validation accuracy: {best_val_acc:.4f}")

    # Save model
    output_dir = f"./final-yamnet-audio-classifier-{int(time.time())}"
    os.makedirs(output_dir, exist_ok=True)

    torch.save(model.state_dict(), f"{output_dir}/model_state_dict.pt")

    with open(f'{output_dir}/label_mapping.json', 'w') as f:
        json.dump(label_mapping, f)

    # Test prediction
    test_audio_path = "/kaggle/input/sjc-js/mixkit-dog-barking-twice-1.wav"
    predicted_class, confidence, all_probs = predict_audio(test_audio_path, model, label_mapping)
    print(f"Predicted class: {predicted_class}")
    print(f"Confidence: {confidence:.4f}")
    print("All class probabilities:")
    for label, prob in sorted(all_probs.items(), key=lambda x: x[1], reverse=True):
        print(f"  {label}: {prob:.4f}")

2025-04-18 19:49:56.345834: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745005796.590872      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745005796.672349      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading YAMNet model...


I0000 00:00:1745005812.838787      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1745005812.839524      19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


Original class distribution:
category
human             100
crow               24
insects            24
sheep              20
crickets           20
cow                20
cat                19
dog                18
frog               17
chirping_birds     16
hen                16
pig                16
Name: count, dtype: int64
Balanced class distribution:
category
dog               16
chirping_birds    16
crow              16
sheep             16
frog              16
cow               16
insects           16
hen               16
pig               16
cat               16
crickets          16
human             16
Name: count, dtype: int64
Label mapping:
0: cat
1: chirping_birds
2: cow
3: crickets
4: crow
5: dog
6: frog
7: hen
8: human
9: insects
10: pig
11: sheep
Number of training samples: 153
Number of validation samples: 39
Applying 10x augmentation to training data...
Extracting YAMNet embeddings for training data...


I0000 00:00:1745005846.797703      63 cuda_dnn.cc:529] Loaded cuDNN version 90300


Extracting YAMNet embeddings for validation data...
Number of augmented training samples: 1530
Number of validation samples: 39
Starting training...
Epoch 1/5 | Train Loss: 2.4371 | Train Acc: 0.1765 | Val Loss: 2.4960 | Val Acc: 0.3846
Epoch 2/5 | Train Loss: 2.2649 | Train Acc: 0.5516 | Val Loss: 2.4070 | Val Acc: 0.5897
Epoch 3/5 | Train Loss: 2.1383 | Train Acc: 0.6595 | Val Loss: 2.2562 | Val Acc: 0.6923
Epoch 4/5 | Train Loss: 2.0687 | Train Acc: 0.7124 | Val Loss: 2.2146 | Val Acc: 0.7179
Epoch 5/5 | Train Loss: 2.0344 | Train Acc: 0.7163 | Val Loss: 2.2077 | Val Acc: 0.7179
Training completed! Best validation accuracy: 0.7179
Error processing audio file: Error opening '/kaggle/input/sjc-js/mixkit-dog-barking-twice-1.wav': System error.
Predicted class: insects
Confidence: 0.0855
All class probabilities:
  insects: 0.0855
  sheep: 0.0851
  chirping_birds: 0.0847
  human: 0.0845
  cow: 0.0841
  cat: 0.0839
  hen: 0.0832
  pig: 0.0830
  crow: 0.0826
  dog: 0.0815
  crickets: 0.081