In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

cwd = os.getcwd()
code_dir = os.path.dirname(cwd)
sys.path.append(code_dir)

import seaborn as sns

from lib.directories import *
from lib.plotting import *

In [2]:
# Checkpoint to load. The model-loading cell below chooses the architecture
# by matching this file name, so only the two names listed here are supported.
# Alternative checkpoint:
# model_name = os.path.join('/home/roguz/freesound/freesound-perceptual_similarity/models', 'CAVMAE-audio_model.21.pth')
model_name = os.path.join('/home/roguz/freesound/freesound-perceptual_similarity/models', 'CAVMAE-as_46.6.pth')

In [3]:
import librosa
import torch, torchaudio
from torch.cuda.amp import autocast

# Build the architecture that matches the checkpoint named in `model_name`,
# then load its weights. `mode` records how `extract_embeddings` should call
# `forward_feat` for this checkpoint (None -> default call, unpack 2 outputs;
# 'a' -> audio-only call).
if 'as_46.6.pth' in model_name:
    # CAVMAEFT variant with label_dim=527 (presumably the AudioSet classes).
    from lib.cavmae.src.models import CAVMAEFT as CAVMAE
    model = CAVMAE(label_dim=527, modality_specific_depth=11)
    mode = None
elif 'CAVMAE-audio_model.21.pth' in model_name:
    from lib.cavmae.src.models import CAVMAE
    model = CAVMAE(modality_specific_depth=11)
    mode = 'a'
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        f"Unrecognised checkpoint {model_name!r}: expected a file name containing "
        "'as_46.6.pth' or 'CAVMAE-audio_model.21.pth'"
    )

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
sdA = torch.load(model_name, map_location='cpu')
# Presumably the checkpoint was saved from a DataParallel-wrapped model
# (keys prefixed with `module.`), so wrap before the strict load.
if not isinstance(model, torch.nn.DataParallel):
    model = torch.nn.DataParallel(model)
msg = model.load_state_dict(sdA, strict=True)
print(msg)
model.eval()

def extract_embeddings(model, audio_path, mode=None, max_seconds=30, target_length=1024):
    """Compute one clip-level embedding vector for an audio file.

    Pipeline: load audio -> crop to `max_seconds` -> remove DC offset ->
    128-bin Kaldi-compatible log-mel filterbank (10 ms frame shift) ->
    zero-pad / crop to `target_length` frames -> normalise with fixed stats ->
    `forward_feat` -> mean-pool over the token axis.

    Args:
        model: module exposing `forward_feat(a, v, ...)`. It may be wrapped in
            torch.nn.DataParallel, in which case its `.module` is used.
        audio_path: path to an audio file readable by `torchaudio.load`.
        mode: None calls `forward_feat(a, v)` and keeps the first of its two
            outputs. Any other value calls `forward_feat(a, v, mode='a')`,
            which is expected to return only the audio tokens.
        max_seconds: audio beyond this duration is dropped before feature
            extraction. Use None to keep the whole file. Only the first
            `target_length` frames reach the model either way.
        target_length: number of filterbank frames fed to the model.

    Returns:
        1-D float32 numpy array: the audio output averaged over its token axis
        (768 values for the checkpoints used in this notebook).
    """
    waveform, sr = torchaudio.load(audio_path)  # shape (channels, samples)
    if max_seconds is not None:
        # Crop along the time axis (dim 1), not the channel axis, using the
        # file's actual sample rate.
        waveform = waveform[:, :int(max_seconds * sr)]
    waveform = waveform - waveform.mean()
    # NOTE(review): features are computed at the file's native rate. Confirm this
    # matches the sample rate the checkpoint was trained with (resample if not).
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)

    # Zero-pad or crop along time so the model always sees `target_length` frames.
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        fbank = torch.nn.functional.pad(fbank, (0, 0, 0, p))
    elif p < 0:
        fbank = fbank[:target_length, :]

    # Fixed normalisation (mean -5.081, std 4.4849). These are presumably the
    # training-set statistics of the checkpoint, so keep them in sync with it.
    fbank = (fbank - (-5.081)) / 4.4849
    fbank = fbank.unsqueeze(0)  # add batch dim -> (1, target_length, 128)

    net = getattr(model, 'module', model)  # unwrap DataParallel if present
    fbank = fbank.to(next(net.parameters()).device)
    # A zero tensor shaped like the audio input stands in for the visual input.
    # NOTE(review): confirm forward_feat accepts this shape on the mode=None path.
    dummy_visual = torch.zeros_like(fbank)

    with torch.no_grad(), autocast():
        if mode is None:
            audio_output, _ = net.forward_feat(fbank, dummy_visual)
        else:
            audio_output = net.forward_feat(fbank, dummy_visual, mode='a')

    # Cast to float32 on CPU before pooling, because autocast may yield fp16 on
    # GPU. Then drop the batch dim and average over tokens -> (dim,).
    pooled = audio_output.float().cpu().squeeze(0).mean(dim=0)
    return pooled.numpy()

  from .autonotebook import tqdm as notebook_tqdm


Use norm_pix_loss:  False
Number of Audio Patches: 512, Visual Patches: 196
Audio Positional Embedding Shape: torch.Size([1, 512, 768])
Visual Positional Embedding Shape: torch.Size([1, 196, 768])
<All keys matched successfully>


  return torch._C._cuda_getDeviceCount() > 0


In [4]:
# Smoke test: embed a single clip from the FSD50K eval_audio directory.
# NOTE(review): mode='a' is hard-coded here rather than passing the `mode` chosen
# when the checkpoint was loaded (None for the as_46.6 checkpoint). Confirm this
# is intended.
embeddings = extract_embeddings(model, '/data/FSD50K/FSD50K.eval_audio/271617.wav', mode='a')



In [5]:
# Sanity check: mean-pooling over tokens should leave a single 1-D vector.
assert embeddings.ndim == 1, f"expected a 1-D embedding, got shape {embeddings.shape}"
embeddings.shape

(768,)