In [None]:
import os

import evaluate
import gradio as gr
import numpy as np
import pinecone
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model, BertTokenizer, BertModel

from multimodal_training_pipeline import MultimodalCommandClassifier


In [None]:
# Pinecone setup
# Credentials are read from the environment so that no API key is ever stored
# in the notebook (or in its saved outputs). Export PINECONE_API_KEY, and
# optionally PINECONE_ENVIRONMENT, before starting the kernel. The fallbacks
# reproduce the previous literals (empty key, "us-east-1").
pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY", ""),
    environment=os.environ.get("PINECONE_ENVIRONMENT", "us-east-1"),
)
index = pinecone.Index("voice-command-index")


In [None]:
# Pre-processing front-ends: Wav2Vec2 processor for raw audio, BERT tokenizer for text.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Command vocabulary. The list length sets the classifier's number of classes,
# and predicted class indices are mapped back to text through this list.
labels = [
    "turn on light",
    "turn off light",
    "play music",
    "stop music",
    "increase volume",
    "decrease volume",
    "open window",
    "close window",
    "set alarm",
    "what time is it",
]

In [None]:
# Pick the inference device: GPU when one is visible, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the classifier and restore the saved weights. The checkpoint is
# always deserialised onto the CPU first and only then moved to `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model = MultimodalCommandClassifier(num_classes=len(labels))
model.load_state_dict(
    torch.load(
        "multimodal_spoken_cmd_model.pt",
        map_location=torch.device("cpu"),
    )
)
model.eval()
model.to(device)

# Metric objects from the `evaluate` library, used by the benchmark below.
accuracy, f1 = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))


In [None]:
def predict_command(audio_path):
    """Predict the spoken command contained in an audio file.

    The audio is encoded once and fed to the classifier together with the
    tokenised text of every candidate in ``labels``. The logits of all these
    passes are averaged and the arg-max class index is mapped back to its
    label string.

    Args:
        audio_path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        The entry of ``labels`` whose class index has the highest mean logit.
    """
    waveform, sr = torchaudio.load(audio_path)

    # torchaudio returns (channels, frames). Averaging over the channel axis
    # yields a 1-D mono signal; for a single-channel file this equals the
    # former squeeze(), while stereo files no longer reach the processor as a
    # 2-row array.
    waveform = waveform.mean(dim=0)

    # The Wav2Vec2 feature extractor rejects audio whose sampling rate differs
    # from the one it was configured with, so resample when necessary.
    expected_sr = processor.feature_extractor.sampling_rate
    if sr != expected_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=expected_sr)
        sr = expected_sr

    # padding="max_length" / truncation=True need an explicit max_length, which
    # was never supplied; with a single clip, padding=True is a no-op and keeps
    # the full utterance. The attention mask is requested explicitly because
    # the extractor does not necessarily return one by default.
    audio_inputs = processor(
        waveform.numpy(),
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )
    audio_input = audio_inputs.input_values.to(device)
    audio_attention = audio_inputs.attention_mask.to(device)

    # Pair the audio with every candidate command text and collect the logits.
    # NOTE(review): assumes the model returns logits shaped (1, num_classes) - confirm.
    all_logits = []
    with torch.no_grad():
        for label_text in labels:
            text_inputs = tokenizer(label_text, return_tensors="pt", padding="max_length", truncation=True)
            text_input = text_inputs.input_ids.to(device)
            text_attention = text_inputs.attention_mask.to(device)

            logits = model(audio_input, audio_attention, text_input, text_attention)
            all_logits.append(logits.squeeze(0))

    # Average over the text candidates, then pick the best-scoring class.
    mean_logits = torch.stack(all_logits).mean(dim=0)
    pred_index = torch.argmax(mean_logits).item()
    return labels[pred_index]


In [None]:
# Benchmark over test set
def evaluate_on_testset(test_dataset):
    """Run the saved model over a test set and report accuracy and weighted F1.

    Args:
        test_dataset: Iterable of items exposing ``item["audio"]["path"]`` and
            ``item["label"]``. A label may be a command string from ``labels``
            or an integer class id.

    Returns:
        Tuple ``(acc, f1_score)`` holding the result dictionaries of the
        accuracy and F1 metrics.

    Raises:
        ValueError: If a string label is not present in ``labels``.
    """

    def _class_id(value):
        # The accuracy / F1 metrics compare integer class ids, whereas
        # predict_command returns the command text, so text is mapped to its
        # position in `labels`.
        # NOTE(review): integer labels are assumed to follow the order of
        # `labels` - confirm against the dataset.
        return labels.index(value) if isinstance(value, str) else int(value)

    preds, ground_truth = [], []
    for item in test_dataset:
        preds.append(_class_id(predict_command(item["audio"]["path"])))
        ground_truth.append(_class_id(item["label"]))

    # Scoring lives in the same function (and cell) as the loop: splitting the
    # body across two cells leaves the first half returning None and the
    # second half an orphaned indented block.
    acc = accuracy.compute(predictions=preds, references=ground_truth)
    f1_score = f1.compute(predictions=preds, references=ground_truth, average="weighted")
    print(f"Saved Model Inference Accuracy: {acc['accuracy']:.4f}, F1 Score: {f1_score['f1']:.4f}")
    return acc, f1_score


# Gradio

In [None]:
# Predict command using Pinecone

def encode_audio(audio_path):
    """Embed an audio file as a single vector by mean-pooling encoder states.

    Args:
        audio_path: Location of the audio file, passed to ``torchaudio.load``.

    Returns:
        numpy array holding the time-averaged ``last_hidden_state`` of
        ``wav_model`` for the clip.
    """
    # NOTE(review): `wav_proc` and `wav_model` are not defined in the cells
    # visible here - confirm they are created before this cell runs.
    waveform, sr = torchaudio.load(audio_path)

    # torchaudio returns (channels, frames); average the channels so that a
    # stereo file still produces exactly one embedding vector.
    waveform = waveform.mean(dim=0)

    # Resample to the rate the feature extractor declares, when it declares
    # one; Wav2Vec2 extractors reject a mismatching sampling rate.
    extractor = getattr(wav_proc, "feature_extractor", wav_proc)
    expected_sr = getattr(extractor, "sampling_rate", sr)
    if sr != expected_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=expected_sr)
        sr = expected_sr

    inputs = wav_proc(waveform.numpy(), sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        hidden_states = wav_model(**inputs).last_hidden_state

    # Mean over the time axis, drop the batch axis, and hand back a CPU array.
    return hidden_states.mean(dim=1).squeeze().cpu().numpy()


def top_k_predictions(audio_path, k=3):
    """Retrieve the ``k`` nearest entries of the Pinecone index for an audio clip.

    Args:
        audio_path: Location of the audio file to embed with ``encode_audio``.
        k: Number of matches to request from the index.

    Returns:
        List of ``(id, score)`` tuples, one per returned match, in the order
        the index returned them.
    """
    audio_vec = encode_audio(audio_path)

    # The query vector is sent as a plain list of Python floats rather than a
    # numpy array, so the client can serialise it.
    query_vector = np.asarray(audio_vec).tolist()

    res = index.query(vector=query_vector, top_k=k, include_metadata=True)
    return [(match['id'], match['score']) for match in res['matches']]

# UI function
def gradio_pipeline(audio):
    """Gradio callback: return the waveform and the top-3 retrieved commands.

    Args:
        audio: File path of the uploaded or recorded clip; empty/None when the
            user submitted nothing.

    Returns:
        Tuple ``(waveform, text)``: the loaded samples as a numpy array and
        one ``"<command>: <score>"`` line per match. ``(None, message)`` when
        no audio was provided.
    """
    # Submitting the form without a clip passes no path; answer with a
    # message instead of failing inside torchaudio.load.
    if not audio:
        return None, "No audio provided."

    waveform, _ = torchaudio.load(audio)
    preds = top_k_predictions(audio, k=3)
    pred_text = "\n".join(f"{cmd}: {score:.3f}" for cmd, score in preds)

    # NOTE(review): this value feeds a gr.Plot output, which is documented to
    # take a figure object rather than a raw array - confirm rendering, or
    # build a figure here.
    return waveform.squeeze().numpy(), pred_text

# Gradio app
# Input widget: the clip is handed to the callback as a path on disk.
audio_input = gr.Audio(label="Input Audio", type="filepath")

# Output widgets, in the order gradio_pipeline returns its values.
text_output = gr.Textbox(label="Top-3 Predicted Commands with Scores")
wave_output = gr.Plot(label="Waveform")

# Wire the callback and the widgets into a single-page interface.
app = gr.Interface(
    title="Multimodal Spoken Command Inference with Pinecone",
    description="Upload or record audio. Predict top-k commands using Wav2Vec2 + Pinecone retrieval.",
    fn=gradio_pipeline,
    inputs=audio_input,
    outputs=[wave_output, text_output],
)

# Launch only when this runs as the main module.
if __name__ == "__main__":
    app.launch()
