In [None]:
import datasets

# Load the 'train' split of the renumics/esc50 dataset from the Hub.
ds = datasets.load_dataset(path="renumics/esc50", split='train')

In [None]:
from datasets import Audio
# Keep the audio column undecoded: each row then carries a dict with the raw
# 'bytes' (and 'path') instead of a decoded waveform, which CLAP reads itself.
raw_audio = Audio(sampling_rate=None, mono=True, decode=False, id=None)
ds = ds.cast_column("audio", raw_audio)

In [None]:
from msclap import CLAP

# Build text embeddings: one natural-language prompt per class name, in the
# same order as the ClassLabel ids of the 'label' column.
names = ds.features['label'].names

prompt = 'this is the sound of '
y = [f"{prompt}{class_name}" for class_name in names]

clap_model = CLAP(version = '2023', use_cuda=False)
label_embeddings = clap_model.get_text_embeddings(y)


In [None]:
# Add each row's label-prompt text embedding to the dataset as a list column.
# The integer ids in ds['label'] index the rows of label_embeddings, because the
# prompts were built from ds.features['label'].names in id order.
label_ids = list(ds['label'])  # list() so this also works if the column is returned lazily
# A single vectorized gather replaces slicing the embedding matrix once per row.
# .tolist() turns the (n_rows, dim) result into nested Python lists for Arrow.
label_embedding_column = label_embeddings[label_ids].tolist()
ds = ds.add_column('text_embedding', label_embedding_column)


In [None]:
from msclap import CLAP
from io import BytesIO
import torch.nn.functional as F
import numpy as np
from scipy.stats import entropy

def extract_embeddings(label_embeddings, clap_model=None):
    """Build a batched ``Dataset.map`` function that scores audio clips against label prompts.

    Args:
        label_embeddings: Text embeddings of the label prompts (e.g. the output of
            ``CLAP.get_text_embeddings``), one row per label in label-id order.
        clap_model: Optional already-loaded CLAP model to reuse. When ``None``
            (the default), a new ``CLAP(version='2023', use_cuda=False)`` is loaded,
            exactly as before.

    Returns:
        A function ``pp(batch)`` for use with ``ds.map(..., batched=True)``. It
        returns, per row:
          - ``prediction``: argmax label index.
          - ``entropy``: entropy of the softmax over labels.
          - ``audio_embedding``: the CLAP audio embedding.
    """
    if clap_model is None:
        clap_model = CLAP(version='2023', use_cuda=False)

    def pp(batch):
        # With decode=False each audio entry is a dict with 'bytes' and/or 'path'.
        # Use the in-memory bytes when present. Otherwise fall back to the path,
        # because BytesIO(None) would silently be an empty buffer.
        audio_sources = [
            BytesIO(audio['bytes']) if audio.get('bytes') is not None else audio['path']
            for audio in batch['audio']
        ]
        audio_embeddings = clap_model.get_audio_embeddings(audio_sources)

        # Assumes similarity is (n_audio, n_labels), so axis 1 ranges over labels.
        # detach() because the result may carry autograd history.
        similarity = clap_model.compute_similarity(audio_embeddings, label_embeddings)
        probs = F.softmax(similarity.detach().cpu(), dim=1).numpy()
        prediction = np.argmax(probs, axis=1)
        probs_entropy = entropy(probs, axis=1)  # natural-log entropy per clip

        return {
            'prediction': prediction,
            'entropy': probs_entropy,
            'audio_embedding': audio_embeddings.detach().cpu().numpy(),
        }

    return pp

In [None]:
# Score every clip against the label prompts in batches of 4. This adds the
# 'prediction', 'entropy' and 'audio_embedding' columns.
score_batch = extract_embeddings(label_embeddings)
ds = ds.map(score_batch, batched=True, batch_size=4)

In [None]:
# Give 'prediction' the same ClassLabel type as 'label' so it carries class names.
# Copy explicitly so the dataset's own Features object is never mutated.
features = ds.features.copy()
features['prediction'] = features['label']
ds = ds.cast(features)

# Flag misclassified clips. Reading the two columns once avoids materializing
# every full row (audio bytes, embeddings, ...) twice just to compare two ints.
pred_incorrect = [
    pred != label for pred, label in zip(ds['prediction'], ds['label'])
]
ds = ds.add_column('pred_incorrect', pred_incorrect)


In [None]:
# Drop source metadata, the ground-truth label and the raw audio column.
# NOTE(review): any later cell that uses `ds` no longer sees 'audio' or 'label'.
columns_to_drop = ['src_file', 'fold', 'label', 'esc10', 'take', 'audio']
ds = ds.remove_columns(columns_to_drop)

In [None]:
# Wrap the results as a single 'train' split and push them to the Hub.
ds_dict = datasets.DatasetDict(train=ds)
ds_dict.push_to_hub('renumics/esc50-clap2023-results')

In [None]:
from renumics import spotlight

# Tell Spotlight to treat both vector columns as embeddings.
embedding_dtypes = {
    'audio_embedding': spotlight.Embedding,
    'text_embedding': spotlight.Embedding,
}
spotlight.show(ds, dtype=embedding_dtypes)