In [None]:
!pip install torch numpy librosa soundfile transformers gradio

In [None]:
import torch
import numpy as np
import librosa
import json
import soundfile as sf
import torch.nn as nn
from transformers import HubertModel
from torch.nn import functional as F
import os
import gradio as gr

# Constants (matching those used during training)
# Sampling rate, in Hz, that incoming audio is resampled to before inference.
SAMPLE_RATE = 16000
# Fixed clip length in samples; audio is trimmed or zero-padded to this size.
MAX_AUDIO_LENGTH = 6 * SAMPLE_RATE  # 6 seconds at 16kHz
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model class (same as in your training code)
class HubertForAudioClassification(nn.Module):
    """
    HuBERT encoder followed by a single linear classification layer.

    The encoder's ``last_hidden_state`` is averaged over dimension 1 and the
    pooled vector is projected to ``num_labels`` logits.

    Args:
        model_name: Identifier or path passed to ``HubertModel.from_pretrained``.
        num_labels: Number of output classes.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they must
        # stay in sync with the checkpoint being loaded.
        self.hubert = HubertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.hubert.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, input_values, attention_mask=None, labels=None):
        """
        Compute class logits and, when labels are given, the training loss.

        Args:
            input_values: Waveform batch forwarded to the HuBERT encoder.
            attention_mask: Optional mask forwarded unchanged to the encoder.
            labels: Optional class indices; enables the cross-entropy loss.

        Returns:
            The logits tensor when ``labels`` is None, otherwise a dict with
            the keys ``"loss"`` and ``"logits"``.
        """
        encoded = self.hubert(input_values=input_values, attention_mask=attention_mask)
        # Collapse dimension 1 by averaging before classification.
        pooled = encoded.last_hidden_state.mean(dim=1)
        logits = self.classifier(pooled)

        # Pure inference: hand back the bare logits.
        if labels is None:
            return logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}

def process_audio_for_inference(audio_data, sr, max_length=MAX_AUDIO_LENGTH):
    """
    Convert raw audio into a fixed-length, peak-normalized float32 waveform.

    Steps, in order: down-mix multi-dimensional input by averaging over
    axis 1, cast non-floating samples to float32, resample to SAMPLE_RATE,
    trim or zero-pad to ``max_length`` samples, and divide by the peak
    absolute amplitude.

    Args:
        audio_data: Array of samples; 1-D, or 2-D in which case axis 1 is
            averaged away (i.e. a (samples, channels) layout is assumed).
        sr: Sampling rate of ``audio_data`` in Hz.
        max_length: Number of samples in the returned array.

    Returns:
        A 1-D float32 array of ``max_length`` samples. If any processing
        step raises, the error is printed and an all-zero array of that
        length is returned instead (best-effort fallback).
    """
    try:
        audio_data = np.asarray(audio_data)

        # Ensure audio is mono
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # Integer PCM must become floating point before resampling:
        # librosa.resample rejects non-floating input, and that error would
        # otherwise be swallowed below and silently replaced by zeros.
        if not np.issubdtype(audio_data.dtype, np.floating):
            audio_data = audio_data.astype(np.float32)

        # Resample if needed
        if sr != SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)

        # Handle length (trim or pad)
        if len(audio_data) > max_length:
            audio_data = audio_data[:max_length]
        else:
            padding = max_length - len(audio_data)
            audio_data = np.pad(audio_data, (0, padding), 'constant')

        # Peak-normalize; the epsilon keeps all-zero input from dividing by 0.
        audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-6)

        return audio_data.astype(np.float32)

    except Exception as e:
        # Deliberate best-effort: report the failure and fall back to silence
        # so the caller still receives an array of the expected size.
        print(f"Error processing audio data: {e}")
        return np.zeros(max_length, dtype=np.float32)

def predict_audio_from_gradio(audio_input, model, label_mapping):
    """
    Classify one Gradio audio clip and summarize the class probabilities.

    Args:
        audio_input: Tuple of (sample_rate, audio_data) from Gradio
        model: Loaded HuBERT classification model
        label_mapping: Dictionary whose keys are the class indices as
            strings ("0", "1", ...) and whose values are class names

    Returns:
        Dictionary with two keys: "prediction", a human-readable summary of
        the top class and its confidence, and "chart", a (labels, values)
        tuple of class names and probabilities sorted by descending
        probability.
    """
    sample_rate, waveform = audio_input

    # Preprocess, then build a batch of one on the target device.
    clip = process_audio_for_inference(waveform, sample_rate)
    batch = torch.tensor(clip, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    # Every sample is marked as valid: the mask is all ones, shaped like the batch.
    mask = torch.ones_like(batch).to(DEVICE)

    # Evaluation mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        raw_output = model(input_values=batch, attention_mask=mask)

    # The model may return either a dict holding the logits or the logits themselves.
    logits = raw_output["logits"] if isinstance(raw_output, dict) else raw_output

    # Probabilities for the single item in the batch.
    probs = F.softmax(logits, dim=1)[0]
    top_idx = torch.argmax(probs).item()
    top_label = label_mapping[str(top_idx)]
    top_prob = probs[top_idx].item()

    # Name -> probability for every class listed in the mapping.
    per_class = {}
    for idx in range(len(label_mapping)):
        per_class[label_mapping[str(idx)]] = float(probs[idx].item())

    # Highest probability first, for a more readable chart.
    ranked = sorted(per_class.items(), key=lambda pair: pair[1], reverse=True)
    names = [name for name, _ in ranked]
    scores = [score for _, score in ranked]

    return {
        "prediction": f"Predicted: {top_label} (Confidence: {top_prob:.2f})",
        "chart": (names, scores),
    }

def create_gradio_interface():
    """
    Load the model and label mapping and build the Gradio demo.

    Reads the label mapping from ``label_mapping.json`` and the weights from
    ``model_state_dict.pt`` (both relative paths; edit them below), wraps
    the model in a prediction callback, and lays out a Blocks UI with an
    audio input, a button, a text result and a bar chart of probabilities.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).

    Raises:
        Whatever ``open``, ``json.load``, ``torch.load`` or
        ``load_state_dict`` raise when the files are missing or do not match
        the model.
    """
    # Imported here because pandas is only needed to feed gr.BarPlot, which
    # takes its data as a DataFrame.
    import pandas as pd

    # Model and label mapping paths - update these to your actual paths
    model_path = "model_state_dict.pt"  # Update this path
    label_path = "label_mapping.json"   # Update this path

    # Load the label mapping (explicit encoding so non-ASCII class names are
    # read the same way on every platform)
    with open(label_path, 'r', encoding="utf-8") as f:
        label_mapping = json.load(f)

    # Initialize the model
    num_labels = len(label_mapping)
    model = HubertForAudioClassification("facebook/hubert-base-ls960", num_labels)

    # Load model weights
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (or pass weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)

    def _as_chart_frame(labels, values):
        """Build the DataFrame whose columns match the BarPlot's x and y."""
        return pd.DataFrame({"Class": list(labels), "Probability": list(values)})

    # Define the prediction function. It returns one value per output
    # component, in the same order as the `outputs` list passed to click().
    def predict(audio):
        if audio is None:
            return "No audio provided", _as_chart_frame([], [])

        result = predict_audio_from_gradio(audio, model, label_mapping)
        labels, values = result["chart"]
        return result["prediction"], _as_chart_frame(labels, values)

    # Create Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Audio Classification with HuBERT")

        with gr.Row():
            audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio Input")

        with gr.Row():
            predict_btn = gr.Button("Predict")

        with gr.Row():
            prediction_output = gr.Textbox(label="Prediction")

        with gr.Row():
            chart_output = gr.BarPlot(
                x="Class",
                y="Probability",
                title="Class Probabilities",
                x_title="Class",
                y_title="Probability",
                height=400,
                width=600
            )

        # Set up the click event. `outputs` must be components (a list here),
        # not a dict keyed by strings.
        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[prediction_output, chart_output],
        )

    return demo

# Script entry point: build the demo and start the Gradio server.
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)  # Set share=False if you don't want a public link