In [1]:
import os

from tqdm.auto import tqdm
import numpy as np

from datasets import load_dataset
from panns_inference import AudioTagging
import soundfile as sf

import os
from dotenv import load_dotenv
load_dotenv()

import cassio

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
audio_dataset = load_dataset("ashraq/esc50", split="train")



In [3]:
from transformers import BertModel, BertTokenizer
import torch

text_model_name = 'bert-base-uncased'
text_model = BertModel.from_pretrained(text_model_name)
tokenizer = BertTokenizer.from_pretrained(text_model_name)

def get_bert_embeddings_batch(texts):
    encoded_input = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    
    with torch.no_grad():
        output = text_model(**encoded_input)
    
    last_hidden_state = output.last_hidden_state

    attention_mask = encoded_input['attention_mask'].unsqueeze(-1)
    sum_embeddings = torch.sum(last_hidden_state * attention_mask, dim=1)
    sum_mask = torch.sum(attention_mask, dim=1)
    embeddings = sum_embeddings / sum_mask
    embeddings = [embedding.tolist() for embedding in embeddings]
    
    return embeddings

texts = ["str2",
         "str3",
         "str1"]

embeddings = get_bert_embeddings_batch(texts)
print(embeddings)

[[-0.017076456919312477, -0.25058627128601074, -0.17424842715263367, -0.16393253207206726, 0.024736713618040085, -4.159212039667182e-05, 0.29253286123275757, -0.4058496356010437, 0.1399151086807251, -0.04161446541547775, -0.08672268688678741, 0.2666170597076416, 0.06936009973287582, 0.17359909415245056, 0.016446828842163086, 0.13338644802570343, 0.052063845098018646, 0.13876721262931824, 0.22121784090995789, 0.11722362041473389, 0.062032680958509445, -0.027402931824326515, 0.0756334736943245, 0.07299835234880447, 0.5308029651641846, 0.20805203914642334, -0.3630518317222595, -0.02770533785223961, -0.009243829175829887, 0.012901735492050648, 0.06954209506511688, -0.16931724548339844, -0.12104751914739609, -0.17488297820091248, -0.43916773796081543, -0.16322734951972961, -0.23581556975841522, -0.19000443816184998, -0.2570960223674774, 0.017036497592926025, -0.1859849989414215, -0.31987059116363525, 0.006716573145240545, 0.09521719068288803, 0.30772632360458374, -0.10284501314163208, -0.43

In [4]:
GPU_AVAILABLE = torch.cuda.device_count() > 0

if GPU_AVAILABLE:
    audio_model = AudioTagging(checkpoint_path=None, device="cuda")
    print("\nLoaded the sound embedding model on the GPU.")
else:
    print("GPU not available")

Checkpoint path: /home/yhchen2001/panns_data/Cnn14_mAP=0.431.pth
GPU number: 5

Loaded the sound embedding model on the GPU.


    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [5]:
from unqlite import UnQLite
import pickle

# Open (or create) the database
class EmbedDB():
    def __init__(self, name):
        self.db = UnQLite(name)

    def store_audio(self, file_name, category, text_embedding, audio_embedding):
        record = {
            'category': category, 
            'text_embedding': pickle.dumps(text_embedding),
            'audio_embedding': pickle.dumps(audio_embedding),
        }

        serialized_record = pickle.dumps(record)

        # Store in the database
        self.db[file_name] = serialized_record

    def retrieve_audio(self, file_name):
        # Retrieve the serialized record from the database
        serialized_record = self.db[file_name]
        
        # Deserialize the record
        record = pickle.loads(serialized_record)

        # Extract and deserialize each component
        category = record['category']
        text_embedding = pickle.loads(record['text_embedding'])
        audio_embedding = pickle.loads(record['audio_embedding'])
        
        return {
            'category': category, 
            'text_embedding': text_embedding,
            'audio_embedding': audio_embedding,
        }

    def list_filenames(self):
        files = []
        for key in self.db.keys():
            files.append(key)
        return files
    
    def db_commit(self):
        self.db.commit()

embed_db = EmbedDB("my_audio.db")

In [7]:
BATCH_SIZE = 100
SAMPLES_TO_PROCESS = len(audio_dataset)


for i in tqdm(range(0, SAMPLES_TO_PROCESS, BATCH_SIZE)):
    # Find end of batch
    i_end = min(i + BATCH_SIZE, SAMPLES_TO_PROCESS)

    filenames = audio_dataset["filename"][i:i_end]
    cats = audio_dataset["category"][i:i_end]

    audios_for_embs = np.array([item["array"] for item in audio_dataset["audio"][i:i_end]])

    _, audio_embs = audio_model.inference(audios_for_embs)
    text_embs = get_bert_embeddings_batch(cats)


    for j in range(i_end - i):
        embed_db.store_audio(filenames[j], cats[j], text_embs[j], audio_embs[j])


embed_db.db.commit()

100%|██████████| 20/20 [01:06<00:00,  3.30s/it]


True