# Local CLaMP

Load stored CLaMP embeddings for segments and compare them by cosine similarity.

## imports and setup

In [1]:
import os
import torch
import pandas as pd
# from rich import print
from sklearn.metrics.pairwise import cosine_similarity
from midi_player import MIDIPlayer
from midi_player.stylers import dark

In [4]:
# Scratch check of the string slicing: take the last underscore-separated
# field of a sample name, then drop its first four characters.
s = "a_b_t01s03"
x = s.rsplit("_", 1)[-1]  # -> "t01s03"
x[4:]  # -> "03"

'03'

In [3]:
def get_matches(
    query_tensor: torch.Tensor, df: pd.DataFrame, n: int = 5, mode: str = "best"
) -> list:
    """Rank the rows of ``df`` by cosine similarity to ``query_tensor``.

    Args:
        query_tensor (torch.Tensor): The embedding to compare. It is reshaped to a
            single row and handed to ``cosine_similarity``, so any array-like with
            a ``reshape`` method (for example a value taken from the 'embedding'
            column) can be used as well.
        df (pd.DataFrame): The dataframe containing embeddings in the 'embedding'
            column. Its index supplies the names in the result.
        n (int): Maximum number of matches to return. Fewer are returned when
            ``df`` has fewer than ``n`` rows.
        mode (str): 'best' returns the most similar rows, highest similarity
            first; 'worst' returns the least similar rows, lowest similarity first.

    Returns:
        list: List of ``(index_label, similarity)`` tuples, with the similarity
            converted to a Python float. Empty when ``n`` is 0 or ``df`` has no rows.

    Raises:
        ValueError: If ``mode`` is neither 'best' nor 'worst', or ``n`` is negative.
    """
    # Validate up front: an unknown mode used to surface as an UnboundLocalError
    # only after the similarities had already been computed.
    if mode not in ("best", "worst"):
        raise ValueError(f"mode must be 'best' or 'worst', got {mode!r}")
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    # n == 0 needs its own exit: slicing with [-0:] selects *every* element.
    # An empty frame would otherwise fail inside torch.stack.
    if n == 0 or len(df) == 0:
        return []

    query_matrix = query_tensor.reshape(1, -1)
    embedding_matrix = torch.stack(
        [torch.Tensor(embedding) for embedding in df["embedding"]]
    ).numpy()

    similarities = cosine_similarity(query_matrix, embedding_matrix).flatten()
    ascending = similarities.argsort()
    if mode == "best":
        # Same selection and order as argsort()[-n:][::-1] for n >= 1.
        matches = ascending[::-1][:n]
    else:
        matches = ascending[:n]

    return [(df.index[i], float(similarities[i])) for i in matches]

In [4]:
# Load the table from a parquet file given by a relative path. The preview
# below shows the columns abc, id and embedding.
df = pd.read_parquet("test-clamp.parquet")
df.head()

Unnamed: 0,abc,id,embedding
baba-060-02_0005-0011,"M: 4/4\nL: 1/8\nQ:1/4=120\nK:A\nV:1\nz2 B,3-B,...","[46, 27, 1, 21, 16, 21, 97, 0, 0, 0, 0, 0, 0, ...","[0.5433163, -0.46188298, -0.1765729, 0.0864300..."
bbbb-060-01_0005-0011,"M: 4/4\nL: 1/8\nQ:1/4=120\nK:B\nV:1\nz2 B,3-B,...","[46, 27, 1, 21, 16, 21, 97, 0, 0, 0, 0, 0, 0, ...","[0.88814145, -0.43045604, -0.4846054, 0.208652..."
c4v100-060-03_0000-0005,M: 4/4\nL: 1/8\nQ:1/4=120\nK:C\nV:1\nz2 c6-|c6...,"[46, 27, 1, 21, 16, 21, 97, 0, 0, 0, 0, 0, 0, ...","[0.5935626, -0.27001852, -0.4847664, 0.1845568..."
c4v97-060-03_0005-0011,M: 4/4\nL: 1/8\nQ:1/4=120\nK:C\nV:1\nz2 c6-|c6...,"[46, 27, 1, 21, 16, 21, 97, 0, 0, 0, 0, 0, 0, ...","[0.5935626, -0.27001852, -0.4847664, 0.1845568..."
cascas-060-02_0000-0005,M: 4/4\nL: 1/8\nQ:1/4=120\nK:Bb\nV:1\nz2 C3-C/...,"[46, 27, 1, 21, 16, 21, 97, 0, 0, 0, 0, 0, 0, ...","[0.7008802, -0.57024217, -0.47158375, -0.19007..."


In [5]:
# List every row label; get_matches reports its matches by these labels.
for index in df.index:
    print(index)

baba-060-02_0005-0011
bbbb-060-01_0005-0011
c4v100-060-03_0000-0005
c4v97-060-03_0005-0011
cascas-060-02_0000-0005
cascd-060-02_0000-0005
cascd-060-02_0005-0011
cccc-060-01_0000-0005
leadtest-060-03_0000-0005
trailtest-060-03_0005-0011


In [11]:
# Spot-check a single stored embedding (position 2 is the third row, 0-based).
es = df["embedding"]
# Print index and content of third element in "es"
print("Index:", es.index[2])
print("Content:", es.iloc[2])


In [31]:
# Top-20 most similar rows to one segment, using that segment's own stored
# embedding as the query.
# NOTE(review): this label ends in "_t04s00", but the index listing printed
# earlier in the notebook shows no such suffix — presumably this cell was last
# run against a different parquet file; confirm before re-running.
get_matches(df.loc["cccc-060-01_0000-0005_t04s00", "embedding"], df, 20, mode="best")

[('bbbb-060-01_0005-0011_t05s00', 0.9999999403953552),
 ('cccc-060-01_0000-0005_t04s00', 0.9999999403953552),
 ('cccc-060-01_0000-0005_t04s01', 0.9804266691207886),
 ('bbbb-060-01_0005-0011_t05s01', 0.9804266691207886),
 ('cccc-060-01_0000-0005_t03s00', 0.9461910724639893),
 ('bbbb-060-01_0005-0011_t04s00', 0.9461910724639893),
 ('leadtest-060-03_0000-0005_t04s03', 0.9433386325836182),
 ('cccc-060-01_0000-0005_t04s04', 0.9393335580825806),
 ('bbbb-060-01_0005-0011_t05s04', 0.9393335580825806),
 ('bbbb-060-01_0005-0011_t05s02', 0.931435227394104),
 ('cccc-060-01_0000-0005_t04s02', 0.931435227394104),
 ('cccc-060-01_0000-0005_t03s01', 0.9300938248634338),
 ('bbbb-060-01_0005-0011_t04s01', 0.9300938248634338),
 ('baba-060-02_0005-0011_t07s00', 0.929381251335144),
 ('cascas-060-02_0000-0005_t06s00', 0.929381251335144),
 ('cccc-060-01_0000-0005_t06s00', 0.9127424955368042),
 ('bbbb-060-01_0005-0011_t07s00', 0.9127424955368042),
 ('bbbb-060-01_0005-0011_t05s05', 0.9062713384628296),
 ('cccc-

In [35]:
# Print the 'abc' column entry for this segment, then embed a player for the
# MIDI file of the same name.
# NOTE(review): the label's "_t04s00" suffix does not appear in the index
# listing printed earlier — confirm which dataset this cell expects.
print(df.loc["cccc-060-01_0000-0005_t04s00", "abc"])
MIDIPlayer(
    "../../disklavier/data/datasets/test/train/cccc-060-01_0000-0005_t04s00.mid",
    150,  # presumably the player height — TODO confirm against midi_player docs
    styler=dark,
)

In [36]:
# Print the 'abc' column entry for a second segment, then embed a player for
# the MIDI file of the same name.
# NOTE(review): this label is not among the index labels printed earlier in
# the notebook — confirm which dataset this cell expects.
print(df.loc["20240123-070-06_0020-0027_t07s02", "abc"])
MIDIPlayer(
    "../../disklavier/data/datasets/test/train/20240123-070-06_0020-0027_t07s02.mid",
    400,  # presumably the player height — TODO confirm against midi_player docs
    styler=dark,
)