In [2]:
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_distances

from m2t.evaluation.utils import strip_completion

In [3]:
from m2t.evaluation.utils import acc_at_k, confusion_matrix_from_distance_matrix

In [4]:
os.chdir("..")

In [7]:
# LLark
infer_results = pd.read_csv("inference-results/v5/infer_results_medleydb_genre_v5_100k.csv")

# ImageBind-LLM
# infer_results = pd.read_csv("inference-results/imagebind-llm/imagebind-llm_medleydb_genre.csv").rename(
#     columns={"imagebind_llm_caption":"model_completion_text",
#             "imagebind_llm_caption_embedding":"model_completion_text_embedding"})

# LTU
# infer_results = pd.read_csv("inference-results/ltu/listen_think_understand_medleydb_genre.csv").rename(
#     columns={"ltu_caption": "model_completion_text",
#              "ltu_caption_embedding": "model_completion_text_embedding"})

infer_results['example_id'] = infer_results['example_id'].apply(lambda x: os.path.basename(x).split("_MIX")[0])

In [8]:
infer_results

Unnamed: 0,example_id,prompt_text,model_completion_text,model_completion_text_embedding
0,MusicDelta_BebopJazz,What genre is this song?,The genre of this song is jazz.\n###,"[-0.013809381052851677, -0.018408285453915596,..."
1,MusicDelta_GriegTrolltog,What genre is this song?,This song belongs to the classical genre.\n###,"[0.012897172942757607, -0.005517087411135435, ..."
2,AmarLal_SpringDay1,What genre is this song?,This song falls under the genre of Folk.\n###,"[0.008347189985215664, -0.016116052865982056, ..."
3,MusicDelta_Disco,What genre is this song?,The genre of this song is funk.\n###,"[0.013618100434541702, -0.01634937897324562, -..."
4,MusicDelta_ModalJazz,What genre is this song?,The genre of this song is jazz.\n###,"[-0.013809381052851677, -0.018408285453915596,..."
...,...,...,...,...
93,Schubert_Erstarrung,What genre is this song?,The genre of this song is classical.\n###,"[0.009412404149770737, 0.0027184076607227325, ..."
94,BrandonWebster_YesSirICanFly,What genre is this song?,The genre of this song is pop.\n###,"[-0.0037041434552520514, -0.010868906043469906..."
95,PurlingHiss_Lolita,What genre is this song?,The genre of this song is rock.\n###,"[0.00020601178403012455, -0.016521435230970383..."
96,AlexanderRoss_VelvetCurtain,What genre is this song?,This song is in the pop genre.\n###,"[-0.012440926395356655, -0.005167915485799313,..."


In [11]:
# load medleydb results
import os
import yaml
from collections import defaultdict
from m2t.dataset_utils import fetch_audio_start_end


def load_medleydb_genres(medleydb_dir="datasets/medleydb"):
    audio_files = os.listdir(os.path.join(medleydb_dir, "wav-crop"))

    mdb_data = defaultdict(dict)
    for audio_file in audio_files:
        start, end = fetch_audio_start_end(audio_file)
        if start == 0. and end == 0.:
            print(f"skipping file {audio_file} with zero duration.")
            continue
        mdb_id = audio_file.split("_MIX-start")[0]
        mdb_data[mdb_id]["audio_file"] = audio_file
        mdb_data[mdb_id]["start"] = start
        mdb_data[mdb_id]["end"] = end
        meta_fp = os.path.join(medleydb_dir, "metadata", mdb_id + "_METADATA.yaml")
        with open(meta_fp, "r") as f:
            metadata = yaml.safe_load(f)
        mdb_data[mdb_id]["metadata"] = metadata

    return mdb_data


mdb_data = load_medleydb_genres()

genres = pd.DataFrame([(k, v["metadata"]["genre"]) for k, v in mdb_data.items()],
                      columns=["example_id", "genre"])

skipping file MusicDelta_Zeppelin_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_InTheHalloftheMountainKing_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Hendrix_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Beethoven_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Punk_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Rock_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Reggae_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Vivaldi_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Country2_MIX-start0.000-end0.000.wav with zero duration.
skipping file MusicDelta_Rockabilly_MIX-start0.000-end0.000.wav with zero duration.


In [12]:
def _map_fn(x):
    x = strip_completion(x)

    x = re.sub('\W+', ' ', x)  # remove multiple consecutive spaces
    x = x.replace('Hip Hop', 'hiphop')
    x = x.replace('hip hop', 'hiphop')
    return x.strip()


caption_colname = 'model_completion_text'

infer_results[caption_colname] = infer_results[caption_colname].map(_map_fn)

In [13]:
infer_results = infer_results.merge(genres, on=["example_id"])

In [14]:
infer_results[["model_completion_text", "genre"]]

Unnamed: 0,model_completion_text,genre
0,The genre of this song is jazz,Jazz
1,This song belongs to the classical genre,Classical
2,This song falls under the genre of Folk,Singer/Songwriter
3,The genre of this song is funk,Pop
4,The genre of this song is jazz,Jazz
...,...,...
93,The genre of this song is classical,Classical
94,The genre of this song is pop,Musical Theatre
95,The genre of this song is rock,Rock
96,This song is in the pop genre,Singer/Songwriter


In [15]:
GENRE_COLNAME = 'genre'
true_genre_in_completion_text = np.mean(
    [x[GENRE_COLNAME].lower() in x[caption_colname].replace('-', '').lower() for _, x in infer_results.iterrows()])
print(true_genre_in_completion_text)

0.4387755102040816


In [16]:
genre_embeds = pd.read_csv("inference-results/genres_medleydb.csv")

genre_embeds.head()

Unnamed: 0,genre,genre_embedding
0,Singer/Songwriter,"[-0.012259929440915585, -0.012756437063217163,..."
1,Pop,"[0.00818187277764082, -0.02048569917678833, -0..."
2,World/Folk,"[0.0006129959365352988, 0.0031966841779649258,..."
3,Classical,"[-0.013702787458896637, 0.00244644726626575, 0..."
4,Rock,"[0.012388329021632671, -0.021920479834079742, ..."


In [17]:
genre_embeds['genre'].value_counts() / len(genre_embeds)

genre
Singer/Songwriter    0.125
Pop                  0.125
World/Folk           0.125
Classical            0.125
Rock                 0.125
Electronic/Fusion    0.125
Jazz                 0.125
Musical Theatre      0.125
Name: count, dtype: float64

In [18]:
genre_idxs = genre_embeds[GENRE_COLNAME].to_dict()
genre_idxs

{0: 'Singer/Songwriter',
 1: 'Pop',
 2: 'World/Folk',
 3: 'Classical',
 4: 'Rock',
 5: 'Electronic/Fusion',
 6: 'Jazz',
 7: 'Musical Theatre'}

In [19]:
# Create a numeric label column
infer_results['genre_numeric'] = infer_results[GENRE_COLNAME].map({v: k for k, v in genre_idxs.items()})

# Sanity check the labels
pd.crosstab(infer_results[GENRE_COLNAME], infer_results['genre_numeric'])

genre_numeric,0,1,2,3,4,5,6,7
genre,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Classical,0,0,0,18,0,0,0,0
Electronic/Fusion,0,0,0,0,0,2,0,0
Jazz,0,0,0,0,0,0,11,0
Musical Theatre,0,0,0,0,0,0,0,3
Pop,0,10,0,0,0,0,0,0
Rock,0,0,0,0,14,0,0,0
Singer/Songwriter,22,0,0,0,0,0,0,0
World/Folk,0,0,18,0,0,0,0,0


# OpenAI Embeddings

In [20]:
import json

dist = cosine_distances(np.row_stack(infer_results[caption_colname + '_embedding'].apply(json.loads).tolist()),
                        np.row_stack(genre_embeds['genre_embedding'].apply(json.loads).tolist()))

In [21]:
dist.shape

(98, 8)

In [22]:
import scipy
import math


def clopper_pearson(acc, n, alpha=0.05):
    """Estimate the confidence interval for a sampled Bernoulli random
    variable.
    `x` is the number of successes and `n` is the number trials (x <=
    n). `alpha` is the confidence level (i.e., the true probability is
    inside the confidence interval with probability 1-alpha). The
    function returns a `(low, high)` pair of numbers indicating the
    interval on the probability.
    """
    x = int(acc * n)
    b = scipy.stats.beta.ppf
    lo = b(alpha / 2, x, n - x + 1)
    hi = b(1 - alpha / 2, x + 1, n - x)
    return 0.0 if math.isnan(lo) else lo, 1.0 if math.isnan(hi) else hi

In [29]:
import matplotlib

matplotlib.rcParams.update(matplotlib.rcParamsDefault)

plot_data = []
for k in range(1, 9):
    acc = acc_at_k(dist, k=k, labels=infer_results['genre_numeric'])
    plot_data.append((k, acc))

genre_results_at_k = pd.DataFrame(plot_data, columns=["k", "acc@k"])
genre_results_at_k["ci_upper"] = genre_results_at_k["acc@k"].apply(lambda x: clopper_pearson(x, 1000)[1])
genre_results_at_k["ci_lower"] = genre_results_at_k["acc@k"].apply(lambda x: clopper_pearson(x, 1000)[0])

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(genre_results_at_k["k"],
        genre_results_at_k["acc@k"],
        "-o",
        color="blue",
        label="LLark Acc@k")
ax.fill_between(genre_results_at_k["k"],
                genre_results_at_k["ci_upper"],
                genre_results_at_k["ci_lower"],
                alpha=0.5,
                color="blue",
                )

ax.plot(genre_results_at_k["k"],
        genre_results_at_k["k"] / 8,
        "-o",
        color="orange",
        label="Random Baseline")
ax.scatter(x=1, y=true_genre_in_completion_text, s=128, marker="*", color="blue",
           label="True genre text in LLark output")
ax.legend()
ax.set_title("Genre Classification Acc@K, MedleyDB Dataset")
ax.set_xlabel("k")
# ax.set_facecolor("white")
ax.grid()

plt.savefig("notebooks/acc-at-k-medleydb.pdf", bbox_inches="tight")

In [30]:
plot_data

[(1, 0.5612244897959183),
 (2, 0.7959183673469388),
 (3, 0.826530612244898),
 (4, 0.8979591836734694),
 (5, 0.9489795918367347),
 (6, 0.9693877551020408),
 (7, 1.0),
 (8, 1.0)]

In [179]:
cm = confusion_matrix_from_distance_matrix(dist, labels=infer_results['gtzan_genre_numeric'])

cm = pd.DataFrame(cm, index=genre_embeds['genre'], columns=genre_embeds['genre'])

import seaborn as sns

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 7))

sns.heatmap(cm, annot=True, ax=ax)
ax.set(xlabel='Predicted Genre', ylabel='GTZAN Labeled Genre',
       title="Confusion MAtrix\nZero-Shot GTZAN Genre Prediction")

KeyError: 'gtzan_genre_numeric'