In [23]:
import os
import torch
import numpy as np
import pandas as pd
from torchaudio.datasets import GTZAN
import sys
from omegaconf import DictConfig, OmegaConf
sys.path.append("../")
from model.modules import CNN1D, MusicTaggingTransformer
from model.emb_model import EmbModel
from model.lightning_model import ZSLRunner
from tqdm.notebook import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from gensim.models.keyedvectors import KeyedVectors

# DEMO Contents

- Download GTZAN
- Load Word Model: GloVe
- Load Audio Model: 1D CNN
- Projection to Joint Embedding Space
- Seen & Unseen Query Example

## Download GTZAN

In [2]:
# Fetch GTZAN into the current directory (download=True) and show how many tracks it holds.
dataset = GTZAN(root=".", download=True)
len(dataset)

1000

In [3]:
# Peek at the first entry; each dataset item unpacks into three values.
waveform, samplerate, label = dataset[0]

## Load Word Model

In [4]:
# Location of the pretrained GloVe vectors; replace the placeholder with your own directory.
glove_path = "{your own glove model}"
glove_name = "glove.42B.300d.txt"
glove_file = os.path.join(glove_path, glove_name)
# Parsing the text-format vectors is slow (the original author estimated about 3 minutes).
# NOTE(review): binary=False reads the word2vec *text* format, which presumably expects a
# "<vocab_size> <dim>" header line that raw GloVe downloads lack — confirm the file was converted.
glove_model = KeyedVectors.load_word2vec_format(glove_file, binary=False)

## Load Audio Model

In [5]:
# Experiment selectors; they are combined into the pretrained checkpoint directory below.
task_type = "zeroshot"
emb_type = "glove"
backbone = "CNN1D"
# backbone = "Transformer"
supervisions = "tag"
# Prefer the first GPU, but fall back to CPU so the demo still runs on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [6]:
# Pretrained artifacts live under <pretrained_path>/<task_type>/<emb_type>/<backbone>/<supervisions>.
pretrained_path = "../dataset/pretrained"
checkpoint_subdir = f"{task_type}/{emb_type}/{backbone}/{supervisions}"
save_path = os.path.join(pretrained_path, checkpoint_subdir)
# Hyperparameters stored next to the checkpoint.
hparams_file = os.path.join(save_path, "hparams.yaml")
args = OmegaConf.load(hparams_file)

In [7]:
# Build the audio encoder selected by `backbone` and wrap it in the embedding model.
# The selector string is left untouched (the module goes into `audio_backbone`),
# so re-running this cell rebuilds the same model instead of silently skipping both branches.
if backbone == "CNN1D":
    audio_backbone = CNN1D()
    model = EmbModel(
        audio_model=audio_backbone,
        projection_ndim=100,
    )
elif backbone == "Transformer":
    audio_backbone = MusicTaggingTransformer(conv_ndim=128, attention_ndim=64)
    model = EmbModel(
        audio_model=audio_backbone,
        projection_ndim=64,
    )
else:
    # Fail here with a clear message rather than with a NameError on `model` below.
    raise ValueError(f"Unsupported backbone: {backbone!r}")

# Wrap the model in the runner using the hyperparameters loaded from hparams.yaml.
runner = ZSLRunner(
    model=model,
    margin=args.margin,
    lr=args.lr,
    supervisions=args.supervisions,
    opt_type=args.opt_type,
)
# Map tensors to CPU while loading so a checkpoint saved from a GPU also opens on
# CPU-only machines.
checkpoint = torch.load(os.path.join(save_path, "best.ckpt"), map_location="cpu")
# Index the key directly: a checkpoint without "state_dict" should raise a KeyError,
# not pass None into load_state_dict.
runner.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

## Extract Music Embedding

In [8]:
# Move the runner to the target device and switch it to evaluation mode.
runner.to(DEVICE)
runner.eval()
# Handle to the wrapped embedding model used by the cells below.
zeroshot_model = runner.model

In [16]:
# Split every track into evenly spaced windows, embed the windows as one batch,
# and keep the mean embedding per track.
# Window length in samples: clip duration from the hparams times a hard-coded 22050
# (presumably the sample rate of the audio — TODO confirm). Cast to int so it is a valid slice bound.
input_length = int(args.duration * 22050)
num_chunks = 16
item_dict = {}
for idx, (audio, _, label) in enumerate(dataset):
    samples = audio.squeeze(0).numpy()
    if len(samples) < input_length:
        # A negative hop would yield ragged windows and an obscure failure further down.
        # Raise instead of skipping: later cells match item_dict keys to embedding rows.
        raise ValueError(
            f"Track {idx} has {len(samples)} samples, fewer than the required {input_length}."
        )
    # Stride between window starts so that num_chunks windows fit inside the track.
    hop = (len(samples) - input_length) // num_chunks
    chunks = np.stack(
        [samples[i * hop : i * hop + input_length] for i in range(num_chunks)]
    ).astype("float32")
    chunks = torch.from_numpy(chunks).to(DEVICE)
    with torch.no_grad():
        chunk_embs = zeroshot_model.audio_model(chunks)
    item_dict[idx] = {
        # Average over the chunk axis to get a single vector for the track.
        "audio_emb": chunk_embs.mean(0).cpu().numpy(),
        "label": label,
    }

In [17]:
# Sanity check: number of tracks for which an embedding was stored.
len(item_dict)

1000

## Query by Tag

In [18]:
# Stack the per-track embeddings into one matrix; rows follow item_dict's insertion order.
audio_embs = np.stack([entry["audio_emb"] for entry in item_dict.values()])

In [19]:
def query_by_tag(query, word_model, zeroshot_model, audio_embs, top_k=5):
    """Rank tracks by cosine similarity between a tag and the audio embeddings.

    The tag's word vector is looked up in ``word_model``, passed through
    ``zeroshot_model.text_projection``, and compared against every row of
    ``audio_embs``.

    Relies on two notebook globals: ``DEVICE`` (where the word vector is sent
    before projection) and ``item_dict`` (track labels for the printed summary;
    its integer keys are assumed to match the row order of ``audio_embs``).

    Args:
        query: Tag to search for; used as the key into ``word_model``.
        word_model: Object indexable by word that returns a NumPy vector.
        zeroshot_model: Model exposing a ``text_projection`` callable.
        audio_embs: 2-D array with one audio embedding per row.
        top_k: Number of best matches to print and return (default 5).

    Returns:
        pandas Series named ``query`` holding the similarity scores of the
        ``top_k`` best tracks in descending order, indexed by row position in
        ``audio_embs``.
    """
    query_emb = torch.from_numpy(word_model[query]).to(DEVICE)
    with torch.no_grad():
        joint_emb = zeroshot_model.text_projection(query_emb)
    # cosine_similarity expects 2-D inputs, so add a leading batch axis.
    joint_emb = joint_emb.unsqueeze(0).cpu().numpy()
    similarities = cosine_similarity(joint_emb, audio_embs)[0]
    top_tracks = (
        pd.Series(similarities, name=query)
        .sort_values(ascending=False)
        .head(top_k)
    )
    print(f"top{top_k} music tracks: ", [item_dict[i]['label'] + str(i) for i in top_tracks.index])
    return top_tracks

In [20]:
# Retrieve the tracks whose embeddings are most similar to the tag "guitar".
query = "guitar"
top5_music = query_by_tag(query, glove_model, zeroshot_model, audio_embs)

top5 music tracks:  ['blues49', 'rock907', 'country261', 'rock910', 'country271']


In [21]:
# Same retrieval for the tag "happy"; this overwrites top5_music from the previous query.
query = "happy"
top5_music = query_by_tag(query, glove_model, zeroshot_model, audio_embs)

top5 music tracks:  ['rock958', 'pop740', 'country244', 'rock959', 'disco341']


In [22]:
# Show the similarity scores returned by the most recent query.
top5_music

958    0.644991
740    0.635107
244    0.628895
959    0.623938
341    0.618849
Name: happy, dtype: float32