In [1]:
import os
import sys
import glob
import tqdm

# AudioCLIP's demo resolves its assets relative to AudioCLIP/demo/.
# Guard the chdir so re-running this cell (cwd already changed) does not raise
# FileNotFoundError on the relative path. On a first run from the wrong place it
# still raises, which is the desired fail-fast behaviour.
DEMO_DIR = os.path.join("AudioCLIP", "demo")
if not os.getcwd().endswith(os.sep + DEMO_DIR):
    os.chdir(DEMO_DIR)

import librosa
import librosa.display

import simplejpeg
import numpy as np

import torch
import torchvision as tv

import matplotlib.pyplot as plt

from PIL import Image
from IPython.display import Audio, display

# Make the AudioCLIP package root (parent of demo/) importable for `model` and
# `utils`; skip the append on re-runs so sys.path does not accumulate duplicates.
AUDIOCLIP_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if AUDIOCLIP_ROOT not in sys.path:
    sys.path.append(AUDIOCLIP_ROOT)

from model import AudioCLIP
from utils.transforms import ToTensor1D

# Note: this rebinds the name `tqdm` from the module (imported above) to the
# progress-bar callable, which is what the rest of the notebook uses.
from tqdm import tqdm

# Inference only: disable autograd globally for the whole notebook.
torch.set_grad_enabled(False)

MODEL_FILENAME = 'AudioCLIP-Partial-Training.pt'
# derived from ESResNeXt
SAMPLE_RATE = 44100
# derived from CLIP
IMAGE_SIZE = 224
IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

# NOTE(review): unless AudioCLIP's constructor switches itself to eval mode,
# BatchNorm/Dropout layers run in training mode here; consider `aclp.eval()`,
# but confirm first, since it may change embeddings relative to ones already saved.
aclp = AudioCLIP(pretrained=f'../assets/{MODEL_FILENAME}')


In [2]:

# Converts a 1-D waveform array into the tensor form AudioCLIP's audio head expects.
audio_transforms = ToTensor1D()

# CLIP-style image preprocessing:
# HWC uint8 array -> float CHW tensor -> shorter side resized to IMAGE_SIZE (bicubic)
# -> centre crop IMAGE_SIZE x IMAGE_SIZE -> per-channel normalisation with CLIP's mean/std.
# InterpolationMode is torchvision's documented enum; it maps to the same bicubic
# filter as the PIL constant Image.BICUBIC.
image_transforms = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Resize(IMAGE_SIZE, interpolation=tv.transforms.InterpolationMode.BICUBIC),
    tv.transforms.CenterCrop(IMAGE_SIZE),
    tv.transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
])

In [3]:
def getAudioEmbeddings(sourcePath, destPath, instrument):
    """Embed every ``*.wav`` clip of one instrument with AudioCLIP and save the result.

    Parameters
    ----------
    sourcePath : str
        Root directory holding one sub-directory per instrument.
    destPath : str
        Output directory; ``{instrument}_audio_embeddings.pt`` is written here.
        Created if it does not exist.
    instrument : str
        Name of the instrument sub-directory to process.

    The saved tensor holds one L2-normalised audio embedding per clip, in sorted
    path order, so row ``i`` corresponds to the ``i``-th sorted ``.wav`` path.

    Raises
    ------
    FileNotFoundError
        If no ``.wav`` file exists for ``instrument``.
    """
    instrumentDir = f"{sourcePath}/{instrument}"
    # glob order is filesystem-dependent; sort so embedding rows map reproducibly to files.
    audioPaths = sorted(glob.glob(f"{instrumentDir}/*.wav"))
    if not audioPaths:
        raise FileNotFoundError(f"No .wav files found in {instrumentDir}")

    # Only the waveform is needed for the embedding. The spectrogram previously
    # computed here was never used.
    tracks = []
    for path in tqdm(audioPaths, desc=f"{{{instrument}}}"):
        track, _ = librosa.load(path, sr=SAMPLE_RATE, dtype=np.float32)
        tracks.append(track)

    # NOTE(review): torch.stack requires every clip to have the same length;
    # this assumes the dataset contains fixed-length clips.
    batch = torch.stack([audio_transforms(track.reshape(1, -1)) for track in tracks])
    ((audio_features, _, _), _), _ = aclp(audio=batch)
    audio_features = audio_features / torch.linalg.norm(audio_features, dim=-1, keepdim=True)

    os.makedirs(destPath, exist_ok=True)
    torch.save(audio_features, f"{destPath}/{instrument}_audio_embeddings.pt")


def getImageEmbeddings(sourcePath, destPath, instrument, patterns=("*.jpg", "*.jpeg")):
    """Embed every JPEG image of one instrument with AudioCLIP and save the result.

    Parameters
    ----------
    sourcePath : str
        Root directory holding one sub-directory per instrument.
    destPath : str
        Output directory; ``{instrument}_image_embeddings.pt`` is written here.
        Created if it does not exist.
    instrument : str
        Name of the instrument sub-directory to process.
    patterns : tuple of str, optional
        Glob patterns, relative to the instrument directory, selecting the images.
        Matched files must be JPEGs because they are decoded with ``simplejpeg``.

    The saved tensor holds one L2-normalised image embedding per file, in sorted
    path order.

    Raises
    ------
    FileNotFoundError
        If no file matches ``patterns`` for ``instrument``.
    """
    instrumentDir = f"{sourcePath}/{instrument}"
    # Match image files (the original matched *.wav, handing audio to the JPEG
    # decoder). The set removes duplicates from overlapping patterns; sorting
    # makes the row order reproducible.
    imagePaths = sorted({
        path
        for pattern in patterns
        for path in glob.glob(f"{instrumentDir}/{pattern}")
    })
    if not imagePaths:
        raise FileNotFoundError(f"No images matching {patterns} found in {instrumentDir}")

    images = []
    for path in tqdm(imagePaths, desc=f"{{{instrument}}}"):
        with open(path, 'rb') as jpg:
            images.append(simplejpeg.decode_jpeg(jpg.read()))

    # Images may differ in size; image_transforms resizes/crops them to a common
    # shape before stacking.
    batch = torch.stack([image_transforms(image) for image in images])
    ((_, image_features, _), _), _ = aclp(image=batch)
    image_features = image_features / torch.linalg.norm(image_features, dim=-1, keepdim=True)

    os.makedirs(destPath, exist_ok=True)
    torch.save(image_features, f"{destPath}/{instrument}_image_embeddings.pt")



In [6]:
# Validation-split audio clips and where their embeddings are written.
dataPathAudio = "../../Data/SubURMPExtendedAudio/clean/validation"
embeddingPathAudio = "../../Embeddings/audio/subURMPExtendedAudio/validation"

# Validation-split images and where their embeddings are written.
dataPathImage = "../../Data/SubURMPClean/images/validation/"
embeddingPathImage = "../../Embeddings/images/validation/"

# Instruments = sub-directories of the audio root, skipping hidden entries
# (names starting with '.'), in alphabetical order.
instrumentNames = sorted(
    entry for entry in os.listdir(dataPathAudio) if not entry.startswith('.')
)

print(instrumentNames)


['bassoon', 'cello', 'clarinet', 'double_bass', 'flute', 'horn', 'oboe', 'sax', 'trombone', 'trumpet', 'tuba', 'viola', 'violin']


In [8]:
# Process only this subset of instruments.
# Validate it against the discovered names so a typo fails here, rather than
# surfacing later as an empty glob inside getAudioEmbeddings.
instrumentSubset = ["sax", "trombone", "trumpet", "tuba", "viola", "violin"]
unknownInstruments = sorted(set(instrumentSubset) - set(instrumentNames))
assert not unknownInstruments, f"Unknown instruments: {unknownInstruments}"
instrumentNames = instrumentSubset

In [9]:
# Compute and save the audio embeddings for each selected instrument.
for instrument in instrumentNames:
    getAudioEmbeddings(
        sourcePath=dataPathAudio,
        destPath=embeddingPathAudio,
        instrument=instrument,
    )

{sax}: 100%|██████████| 17/17 [00:03<00:00,  5.55it/s]
{trombone}: 100%|██████████| 16/16 [00:02<00:00,  5.74it/s]
{trumpet}: 100%|██████████| 10/10 [00:01<00:00,  5.74it/s]
{tuba}: 100%|██████████| 10/10 [00:01<00:00,  5.69it/s]
{viola}: 100%|██████████| 9/9 [00:01<00:00,  5.66it/s]
{violin}: 100%|██████████| 18/18 [00:03<00:00,  5.68it/s]
