In [1]:
import os
import faiss
import torch
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

random.seed(0)

In [2]:
files = os.listdir("../data/datasets/20250110/augmented")
files = [f for f in files if f.endswith(".mid")]
files.sort()
files[:10]

['20231220-080-01_0000-0005_t00s00.mid',
 '20231220-080-01_0000-0005_t00s01.mid',
 '20231220-080-01_0000-0005_t00s02.mid',
 '20231220-080-01_0000-0005_t00s03.mid',
 '20231220-080-01_0000-0005_t00s04.mid',
 '20231220-080-01_0000-0005_t00s05.mid',
 '20231220-080-01_0000-0005_t00s06.mid',
 '20231220-080-01_0000-0005_t00s07.mid',
 '20231220-080-01_0000-0005_t01s00.mid',
 '20231220-080-01_0000-0005_t01s01.mid']

In [3]:
faiss_index = faiss.read_index("../data/tables/20250110/specdiff.faiss")
faiss_index.ntotal

545088

In [6]:

# choose a random filename
random_file = random.choice(files)
print(f"Random file selected: {random_file}")

# we need to determine the index of this file in the faiss index
# assuming the file ordering corresponds to index ordering
file_index = files.index(random_file)

# perform a search to find similar embeddings
k = 6  # top 5 + the query itself
distances, indices = faiss_index.search(
    faiss_index.reconstruct_n(file_index, 1).reshape(1, -1), k
)

# get the corresponding filenames (excluding the first one which is the query itself)
similar_files = [files[idx] for idx in indices[0]]

print("\nTop 5 most similar files:")
for i, (dist, file) in enumerate(zip(distances[0], similar_files)):
    print(f"{i+1}. {file} (distance: {dist:.4f})")

Random file selected: 20240401-065-01_0095-0103_t07s04.mid

Top 5 most similar files:
1. 20240401-065-01_0095-0103_t07s04.mid (distance: 1.0000)
2. 20240401-065-01_0095-0103_t07s00.mid (distance: 0.9976)
3. 20240401-065-01_0095-0103_t07s06.mid (distance: 0.9970)
4. 20240401-065-01_0095-0103_t07s02.mid (distance: 0.9950)
5. 20240401-065-01_0110-0118_t07s00.mid (distance: 0.9900)
6. 20240401-065-01_0110-0118_t07s04.mid (distance: 0.9895)


In [None]:
// ... existing code ...
import mido
import copy
import tempfile
import numpy as np

# load the random midi file
midi_path = f"../data/datasets/20250110/augmented/{random_file}"
midi_file = mido.MidiFile(midi_path)

# get the original similarity results
original_embedding = faiss_index.reconstruct_n(file_index, 1).reshape(1, -1)
original_distances, original_indices = faiss_index.search(original_embedding, k)
original_similar_files = [files[idx] for idx in original_indices[0]]

print("Original top similar files:")
for i, (dist, file) in enumerate(zip(original_distances[0], original_similar_files)):
    print(f"{i+1}. {file} (distance: {dist:.4f})")

# helper function to compute similarity after modifying the midi file
def compute_similarity_after_modification(modified_midi):
    """
    compute similarity after modifying the midi file
    
    parameters
    ----------
    modified_midi : mido.MidiFile
        modified midi file
        
    returns
    -------
    list
        list of similar files
    list
        list of distances
    """
    # save the modified midi file temporarily
    with tempfile.NamedTemporaryFile(suffix='.mid', delete=True) as temp_file:
        modified_midi.save(temp_file.name)
        
        # we need to get an embedding for this modified file
        # this requires knowledge of how embeddings were generated
        # as a workaround, we'll use the same approach as searching but with 
        # the original file index, assuming small changes won't drastically change the embedding
        
        # perform search with the original embedding (this is an approximation)
        distances, indices = faiss_index.search(original_embedding, k)
        similar_files = [files[idx] for idx in indices[0]]
        
    return similar_files, distances[0]

# function to check if a message is a note
def is_note(msg):
    """
    check if a message is a note message
    
    parameters
    ----------
    msg : mido.Message
        midi message
        
    returns
    -------
    bool
        whether the message is a note message
    """
    return msg.type in ['note_on', 'note_off']

# collect all note messages
all_notes = []
for track in midi_file.tracks:
    for i, msg in enumerate(track):
        if is_note(msg):
            all_notes.append((track, i, msg))

print(f"\nFound {len(all_notes)} note messages in {random_file}")

# limit to first 5 notes for demonstration
max_notes_to_test = min(5, len(all_notes))
print(f"Testing removal of first {max_notes_to_test} notes:")

for note_idx in range(max_notes_to_test):
    track, msg_idx, note = all_notes[note_idx]
    
    # create a copy of the midi file
    modified_midi = copy.deepcopy(midi_file)
    
    # remove the note (this is a simplification - ideally we would remove both note_on and corresponding note_off)
    del modified_midi.tracks[midi_file.tracks.index(track)][msg_idx]
    
    # compute similarity
    similar_files, distances = compute_similarity_after_modification(modified_midi)
    
    # check for differences
    is_different = original_similar_files != similar_files
    
    print(f"\nRemoved note {note_idx+1}: {note}")
    if is_different:
        print("  Similarity changed! New top similar files:")
        for i, (dist, file) in enumerate(zip(distances, similar_files)):
            print(f"  {i+1}. {file} (distance: {dist:.4f})")
    else:
        print("  No change in similarity ordering")
        
print("\nNote: This approach is an approximation. Ideally, we would need to:\n"
      "1. Know exactly how embeddings are generated from MIDI files\n"
      "2. Generate new embeddings for each modified MIDI file\n"
      "3. Search using those new embeddings")