In [1]:
from pydub import AudioSegment
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
import numpy as np
from torchvision import transforms
import torch
from model.audioclip import AudioCLIP
import csv
from sklearn.metrics import average_precision_score
from tqdm import tqdm
torch.set_grad_enabled(False)
import librosa
import soundfile as sf

### Data

In [2]:
# Input files: the eval-split manifest (parsed as JSON below) and the class table
# CSV (index, MID, display name) used to decode label MIDs.
eval_dataset, labels_csv = './eval.json', './class_labels_indices.csv'

In [3]:
def get_labels(mid_str, mid_map=None, label_map=None):
    """Translate a comma-separated string of class MIDs into display names.

    Args:
        mid_str: MIDs joined by commas, e.g. "/m/03v3yw,/m/0k4j". Whitespace
            around each MID is ignored and empty entries (from "" or a trailing
            comma) are skipped instead of raising ``KeyError: ''``.
        mid_map: MID -> class-index mapping. Defaults to the notebook-level
            ``mid_to_idx`` built from the labels CSV.
        label_map: class-index -> display-name mapping. Defaults to the
            notebook-level ``idx_to_label``.

    Returns:
        List of display names, in the order the MIDs appear in ``mid_str``.

    Raises:
        KeyError: if a MID is missing from ``mid_map`` (or its index from
            ``label_map``).
    """
    # Resolve the globals lazily so this function can be defined before the
    # mappings are populated, exactly as the original relied on.
    if mid_map is None:
        mid_map = mid_to_idx
    if label_map is None:
        label_map = idx_to_label
    mids = (mid.strip() for mid in mid_str.split(','))
    return [label_map[mid_map[mid]] for mid in mids if mid]

mid_to_idx = {}
idx_to_label = {}
label_to_idx = {}

# Eval-split manifest; each entry under 'data' carries fields such as 'wav',
# 'labels' (comma-separated MIDs), 'video_id' and 'video_path'.
with open(eval_dataset, 'r') as f:
    data_json = json.load(f)

# Class table with columns (index, MID, display name). The header row is the one
# whose first column reads 'index'. newline='' is what the csv module requires
# for files it parses; blank lines are skipped instead of raising IndexError.
with open(labels_csv, 'r', newline='', encoding='utf-8') as f:
    for row in csv.reader(f):
        if not row or row[0] == 'index':
            continue
        idx, mid, display_name = int(row[0]), row[1], row[2]
        mid_to_idx[mid] = idx
        idx_to_label[idx] = display_name
        label_to_idx[display_name] = idx

In [4]:
# Peek at one manifest entry to see which fields are available.
sample_entry = data_json['data'][1]
sample_entry

{'wav': '/home/adrian/Projects/AudioCLIP/audio/-0RWZT-miFs.wav',
 'labels': '/m/03v3yw,/m/0k4j',
 'video_id': '-0RWZT-miFs',
 'video_path': '/home/adrian/Projects/AudioCLIP/frames'}

In [5]:
# Decode a hard-coded MID string into display names. It is presumably the
# ground truth for the clip scored further down (TODO confirm). To decode a
# manifest entry instead, use data_json['data'][i]['labels'].
gt_labels = "/m/03dnzn,/m/068hy,/m/07p7b8y,/m/07ptzwd,/m/0838f,/m/0jbk,/t/dd00088"
gt_label_names = get_labels(gt_labels)
gt_label_names

['Bathtub (filling or washing)',
 'Domestic animals, pets',
 'Fill (with liquid)',
 'Pump (liquid)',
 'Water',
 'Animal',
 'Gush']

### Model

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the fully-trained AudioCLIP checkpoint, move it to `device` and switch it
# to inference mode. The return values of .to()/.eval() are not needed (the
# original discarded them too). The trailing ';' keeps Jupyter from printing the
# module repr without rebinding IPython's `_` output variable.
model = AudioCLIP(pretrained='./assets/AudioCLIP-Full-Training.pt')
model.to(device)
model.eval();

Model has 861 parameters
Loaded state dict has 861 parameters
Missing keys: 0
Unexpected keys: 0
Loaded from assets


In [7]:
from utils.transforms import ToTensor1D, RandomPadding, RandomCrop
import torchvision as tv

# Eval-time audio pipeline: convert to a tensor, then pad and crop every clip to
# exactly 176400 samples (= 44100 * 4, matching the 4 s truncation done in
# get_norm_audio_embd). train=False presumably selects the deterministic
# behaviour of the pad/crop transforms. TODO confirm in utils.transforms.
audio_transforms = [
    ToTensor1D(),
    RandomPadding(out_len=176400, train=False),
    RandomCrop(out_len=176400, train=False),
]
transforms_test = tv.transforms.Compose(audio_transforms)

In [8]:
def scale(old_value, old_min, old_max, new_min, new_max):
    """Linearly map values from [old_min, old_max] onto [new_min, new_max].

    Works for Python/NumPy scalars and for arrays (bounds may broadcast, e.g.
    per-channel min/max with keepdims=True).

    A zero-width source range (e.g. a silent audio clip where min == max) would
    otherwise evaluate 0/0 and turn every value into NaN. Such values are mapped
    to the midpoint of the target range instead.

    Args:
        old_value: value(s) to rescale.
        old_min, old_max: bounds of the source range.
        new_min, new_max: bounds of the target range.

    Returns:
        The rescaled value(s), same shape as the broadcast inputs.
    """
    old_range = old_max - old_min
    new_range = new_max - new_min
    midpoint = new_min + new_range / 2

    if np.ndim(old_range) == 0:
        # Scalar range: the common case (whole-array min/max).
        if old_range == 0:
            # `old_value * 0` keeps the input's shape/array-ness.
            return old_value * 0 + midpoint
        return ((old_value - old_min) * new_range) / old_range + new_min

    # Element-wise ranges: guard only the zero-width entries.
    degenerate = old_range == 0
    safe_range = np.where(degenerate, 1, old_range)
    scaled = ((old_value - old_min) * new_range) / safe_range + new_min
    return np.where(degenerate, midpoint, scaled)

def _load_track(path_to_audio, sample_rate, mono):
    """Read one audio file and return it preprocessed by ``transforms_test``."""
    wav, sample_rate_ = sf.read(path_to_audio, dtype='float32', always_2d=True)

    # sf.read(always_2d=True) gives (frames, channels). When mono output is
    # requested, average the channels so every track is a single channel;
    # otherwise the reshape(1, -1) in the caller would splice the channels
    # end-to-end into one track of double length.
    if mono and wav.shape[1] > 1:
        wav = wav.mean(axis=1, keepdims=True)

    # librosa >= 0.10 accepts orig_sr / target_sr only as keywords.
    wav = librosa.resample(wav.T, orig_sr=sample_rate_, target_sr=sample_rate)

    if wav.shape[0] == 1 and not mono:
        wav = np.concatenate((wav, wav), axis=0)

    # Keep at most the first 4 seconds, then stretch the clip's own min/max onto
    # [-32768, 32767] (kept as float32).
    wav = wav[:, :sample_rate * 4]
    wav = scale(wav, wav.min(), wav.max(), -32768.0, 32767.0).astype(np.float32)
    return transforms_test(wav)


def get_norm_audio_embd(paths_to_audio, aclp, mono=True):
    """Embed audio files with AudioCLIP's audio head and L2-normalise the result.

    Each file is read with soundfile, resampled to 44.1 kHz, truncated to its
    first 4 s, min-max rescaled, passed through the notebook-level
    ``transforms_test`` (pad/crop to a fixed length), and all tracks are
    embedded as a single batch.

    Args:
        paths_to_audio: iterable of audio file paths.
        aclp: AudioCLIP model. It is switched to eval mode, and the batch is
            sent to the device its parameters live on, so a CUDA model works.
        mono: if True, multi-channel files are down-mixed to one channel. If
            False, single-channel files are duplicated to two channels.
            NOTE(review): the reshape(1, -1) below flattens channels
            end-to-end, so mono=False looks unsupported by this path. TODO confirm.

    Returns:
        Tensor of audio embeddings, one row per input file, each scaled to unit
        L2 norm along the last dimension.
    """
    sample_rate = 44100
    tracks = [_load_track(path, sample_rate, mono) for path in paths_to_audio]

    # Without this the CPU batch would be fed to a model living on `cuda`.
    device = next(aclp.parameters()).device
    audio_batch = torch.stack([track.reshape(1, -1) for track in tracks]).to(device)

    aclp.eval()
    with torch.no_grad():
        ((audio_features, _, _), _), _ = aclp(audio=audio_batch)

    audio_features = audio_features / torch.linalg.norm(audio_features, dim=-1, keepdim=True)
    return audio_features

In [16]:
# One text prompt per class, built in class-index order. The top-N cell maps
# logit positions back through idx_to_label, so prompt i must be class i.
# Building from label_to_idx keys would break that silently if two classes
# shared a display name or the CSV rows were not sorted by index.
# Only the part of each display name before the first comma is used,
# e.g. 'Domestic animals, pets' -> 'Domestic animals'.
class_indices = sorted(idx_to_label)
assert class_indices == list(range(len(class_indices))), \
    "class indices must be contiguous 0..K-1 so that prompt position == class index"
prompts = [[idx_to_label[idx].split(",")[0]] for idx in class_indices]

# Encode every prompt with the text head. The model output is unpacked the same
# way as for audio: only the third slot of the first group (text) is kept.
with torch.no_grad():
    text_output = model(text=prompts)
((_, _, text_features), _), _ = text_output

# Normalise each text embedding to unit L2 norm.
text_norms = torch.linalg.norm(text_features, dim=-1, keepdim=True)
text_features = text_features / text_norms

In [17]:
# Embed a single clip. NOTE: absolute, machine-specific path; adjust locally.
audio_path = '/mnt/user/saksham/AV_robust/adrian_audioCLIP/audio/-1EXhfqLLwQ.wav'
audio_feat = get_norm_audio_embd(paths_to_audio=[audio_path], aclp=model)

  wav = librosa.resample(wav.T, sample_rate_, sample_rate)


In [18]:
# The model's exponentiated audio-text logit scale, clamped to [1, 100],
# multiplied into the audio/text dot products (both feature sets were
# unit-normalised in earlier cells, so these are scaled cosine similarities).
scale_audio_text = model.logit_scale_at.exp().clamp(min=1.0, max=100.0)
logits_audio_text = torch.matmul(scale_audio_text * audio_feat, text_features.T)

In [19]:
# Rank all classes for the first (here: only) clip and keep the N best labels.
N = 5
ranked_idx = logits_audio_text[0].detach().cpu().argsort(descending=True)
top5_idx = ranked_idx[:N]
top5_labels = [idx_to_label[int(i)] for i in top5_idx]
top5_labels

['Gush',
 'Sink (filling or washing)',
 'Pump (liquid)',
 'Toilet flush',
 'Fill (with liquid)']

In [15]:
from IPython.display import Audio
# Render an inline player for `audio_path`, so the predicted labels can be
# sanity-checked by ear.
Audio(audio_path)