In [1]:
import torch
import torchaudio
import sys
import os
import random
import IPython.display as ipd
import numpy as np
import gradio as gr

In [2]:
# Probe CUDA once, report the runtime environment, then pick the compute device.
cuda_available = torch.cuda.is_available()
print("[setup]: determining CUDA support...")
print("PyTorch version:", torch.__version__)
print("Torchaudio version:", torchaudio.__version__)
print("CUDA is available:", cuda_available)
device = torch.device("cuda" if cuda_available else "cpu")


[setup]: determining CUDA support...
PyTorch version: 1.13.0+cu117
Torchaudio version: 0.13.0+cu117
CUDA is available: True


In [3]:
# libs = [
#     "pytorchvideo@git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d",
#     "timm",
#     "ftfy",
#     "regex",
#     "einops",
#     "fvcore",
#     "decord"
# ]

# for lib in libs:
#     command = f"pip install {lib}"
#     os.system(command)

# os.system("wget -nc -P models https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth")
# model_path = "./.checkpoints/imagebind_huge.pth"

# Make the local ImageBind source folders importable: resolve both directories
# to absolute paths and add each to sys.path only if it is not already there.
imagebindmodels_path = os.path.abspath('./models')
imagebindbpe_path = os.path.abspath('./bpe')
for extra_path in (imagebindmodels_path, imagebindbpe_path):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

import data
from models import imagebind_model
from models.imagebind_model import ModalityType



bpe/bpe_simple_vocab_16e6.txt.gz


In [4]:
# Use the first CUDA device when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the pretrained ImageBind model, switch it to evaluation mode,
# and move it onto the selected device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model = model.to(device)

In [5]:
from torchaudio.datasets import SPEECHCOMMANDS
import os

class SubsetSC(SPEECHCOMMANDS):
    """Speech Commands dataset restricted to a single split.

    The parent dataset is constructed with root ``./data`` and ``download=True``.
    ``subset`` may be ``"training"``, ``"validation"`` or ``"testing"``; any
    other value (including the default ``None``) leaves the walker list built
    by the parent class unchanged.
    """

    def __init__(self, subset: str = None):
        super().__init__("./data", download=True)

        def read_split(list_name):
            # Join every stripped line of the split file onto self._path and
            # normalise the result, so entries can be compared with the walker.
            resolved = []
            with open(os.path.join(self._path, list_name)) as handle:
                for entry in handle:
                    resolved.append(os.path.normpath(os.path.join(self._path, entry.strip())))
            return resolved

        if subset == "training":
            # Training is everything that is in neither held-out list.
            held_out = set(read_split("validation_list.txt") + read_split("testing_list.txt"))
            self._walker = [item for item in self._walker if item not in held_out]
        elif subset == "validation":
            self._walker = read_split("validation_list.txt")
        elif subset == "testing":
            self._walker = read_split("testing_list.txt")

# Build the three dataset splits; each call runs SubsetSC.__init__, which
# passes download=True to the parent dataset.
train_set = SubsetSC("training")
test_set = SubsetSC("testing")
val_set = SubsetSC("validation")

# Distinct keyword labels, taken from element 2 of every training datapoint.
# NOTE(review): iterating the dataset presumably decodes every clip just to
# read its label, which makes this cell slow — consider caching the result.
labels = sorted({datapoint[2] for datapoint in train_set})

# Two extra catch-all classes. Both must be added with append():
# list.extend("silence") iterates the string and would add the seven single
# characters 's', 'i', 'l', 'e', 'n', 'c', 'e' as separate labels.
labels.append("unknown")
labels.append("silence")

In [6]:
# Draw one random utterance from the validation split and unpack its fields.
sample = random.choice(val_set)
(waveform, sample_rate, label, speaker_id, utterance_number) = sample

# Rebuild the clip's location on disk from its metadata.
sample_path = f"data/SpeechCommands/speech_commands_v0.02/{label}/{speaker_id}_nohash_{utterance_number}.wav"

print("ground truth label:", label)
print("file path:",sample_path)

# Trailing expression, so the notebook displays the clip.
ipd.Audio(waveform.numpy(), rate=sample_rate)

ground truth label: cat
file path: data/SpeechCommands/speech_commands_v0.02/cat/2a89ad5c_nohash_2.wav


In [9]:
def classify_audio(filepath):
    """
    Classify an audio recording from the keyword dataset.

    The recording and every entry of the module-level ``labels`` list are
    embedded with the module-level ``model``; the label whose text embedding
    scores highest against the audio embedding is returned.

    inputs:
        filepath(string): file path to the audio file

    returns:
        string: the entry of ``labels`` with the highest probability

    raises:
        ValueError: if ``filepath`` is empty or None (presumably what the
            Gradio microphone input passes when nothing was recorded)
    """
    # Fail early with a clear message instead of an opaque error in the loader.
    if not filepath:
        raise ValueError("classify_audio: no audio file path was provided")

    text_list = labels
    audio_paths = [filepath]

    # Load data
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }

    with torch.no_grad():
        embeddings = model(inputs)

    # Audio-vs-label similarity scores, normalised over the label axis.
    probs = torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1)

    # np.argmax without an axis works on the flattened array; this assumes
    # probs has a single row (one audio file) — TODO confirm.
    index = int(np.argmax(probs.cpu().numpy()))
    return labels[index]
    

# Gradio demo: record from the microphone, hand the saved file path to
# classify_audio, and show the predicted keyword as text.
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
)
# share=True requests a public share link in addition to the local URL.
demo.launch(share=True)


# Smoke test on the validation clip sampled earlier.
label = classify_audio(sample_path)
print(label)

Running on local URL:  http://127.0.0.1:7862
Running on public URL: https://435dde82558a88a5c5.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




go
