In [None]:
import os
import IPython.display as ipd
import numpy as np
import librosa
import librosa.display
from matplotlib import pyplot as plt
from skimage.feature import peak_local_max
import json

# Set paths for database, fingerprints, query, and output as required
path_to_database = "C:/DriveSync/Queen_Mary/Modules/Music_Informatics/MI Coursework 2 and lab 5/MI coursework 2/database_recordings"
path_to_fingerprints = "C:/DriveSync/Queen_Mary/Modules/Music_Informatics/MI Coursework 2 and lab 5/MI coursework 2/Fingerprints"
query_path = "C:/DriveSync/Queen_Mary/Modules/Music_Informatics/MI Coursework 2 and lab 5/MI coursework 2/query_recordings"
path_to_output = "C:/DriveSync/Queen_Mary/Modules/Music_Informatics/MI Coursework 2 and lab 5/MI coursework 2/output.txt"


# Load audio
def load_audio(file_path):

        y, sr = librosa.load(file_path, sr=None)

        return y, sr

# Compute and plot STFT spectrogram
def compute_stft(y):
        D = np.abs(librosa.stft(y,n_fft=2048,window='hann',win_length=2048,hop_length=512))
        return D


def plot_stft(D, sr, file_name):
    plt.figure(figsize=(10, 5))
    librosa.display.specshow(librosa.amplitude_to_db(D, ref=np.max),
                             y_axis='linear', x_axis='time', cmap='gray_r', sr=sr)
    plt.title(f"{file_name} STFT plot")
    plt.tight_layout()
    plt.show()

        

def create_constellation(D):
        """
        Create a constellation map by detecting spectral peaks in a spectrogram.      
        Args:
                D: Magnitude spectrogram.

        Returns:
                Array of peak coordinates with shape (N, 2), where each row is [frequency_bin, time_bin].
    """
        
        constellation_coordinates = peak_local_max(np.log(D), min_distance=10,threshold_rel=0.05)
        return(constellation_coordinates)



def plot_constellation(constellation_coordinates, file_name, hashed=False):
    """
    Plot the constellation map.

    Args:
        constellation_coordinates: 2D array of [frequency, time] pairs.
        file_name (str): Name of the audio file (used in the plot title).
        hashed (bool): Whether the constellation is hashed or not (affects y-axis label and title).
    """
    plt.figure(figsize=(10, 5))
    plt.plot(constellation_coordinates[:, 1], constellation_coordinates[:, 0], 'r.')
    plt.xlabel("time step")
    plt.ylabel("hashed frequency" if hashed else "frequency")
    title_suffix = "hashed constellation" if hashed else "constellation"
    plt.title(f"{file_name} {title_suffix}")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

        
def hashing(constellation_coordinates, freq_bin_size=4):
        """
        Hash frequency bins to reduce frequency resolution.

        Args:
                constellation_coordinates: Array of [frequency_bin, time_bin] coordinates.
                freq_bin_size (int): Binning size for frequency.

        Returns:
                Hashed constellation coordinates with frequency bins compressed.
        """

        constellation_coordinates[:, 0] = constellation_coordinates[:, 0] // freq_bin_size
        return constellation_coordinates
       


def fingerprintBuilder (path_to_database, path_to_fingerprints):

        """
        Build hashed constellation fingerprints for each audio file in the database and save them to text files.

        Args:
                path_to_database (str): Path to folder containing .wav GTZAN files.
                path_to_fingerprints (str): Path to folder where fingerprint .txt files will be saved.

        Returns:
        dict: Dictionary mapping each database filename to its hashed constellation list.
        """

        fingerprints = {}
        # create an empty dictionary to store the fingerprints

        for file_name in os.listdir(path_to_database):
                if file_name.endswith('.wav'): 
                        dBfile_path = os.path.join(path_to_database, file_name)
                        dBy, dBsr = load_audio(dBfile_path)
                        dBstft = compute_stft(dBy)
                        # plot_stft(dBstft, dBsr, file_name)
                        dBconstellation = create_constellation(dBstft)
                        # plot_constellation(dBconstellation, file_name, hashed=False)
                        dBhashed_constellation = hashing(dBconstellation)
                        # plot_constellation(dBhashed_constellation, file_name, hashed=True)


                txt_file_name = file_name.replace('.wav', '.txt')
                txt_path = os.path.join(path_to_fingerprints, txt_file_name)

                # Save the dictionary to a text file (JSON-style)
                with open(txt_path, 'w') as f:
                        json.dump(dBhashed_constellation.tolist(), f)
        
                fingerprints[file_name] = dBhashed_constellation.tolist()  # convert to list to save as JSON
        
        return fingerprints # return the fingerprints dictionary



def prepare_query(path_to_queryset):
        
        """
        Prepare hashed constellation coordinates for all query audio files.

        Args:
                path_to_queryset (str): Path to folder containing the query .wav files.

        Returns:
                dict: Dictionary mapping each query to a list of [time_bin, frequency_bin] pairs.
        """

        query_dict = {}

        for file_name in os.listdir(path_to_queryset):
                if file_name.endswith('.wav'): 
                        file_path = os.path.join(query_path, file_name)

                        y, sr = load_audio(file_path)
                        # print(y.shape, "and sampling rate, ", sr) # can delete this line before submitting
                        target_sr = 22050
                        y_downsampled = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
                        # print(y_downsampled.shape, "and sampling rate ", target_sr) # can delete this line before submitting
                        stft = compute_stft(y_downsampled)
                        # plotted_stft = plot_stft(stft, target_sr)
                        # print("The stft is", stft)
                        # print("The shape after stft is", stft.shape, "in the form of (frequency, time)") # and each entry in this matrix is the magnitude of that frequency at that time (note magnitude is different to amplitude and energy)
                        constellation = create_constellation(stft)
                        # print("The output constellation has shape", constellation.shape, "as peak_local_max returns the coordinates of the peaks, with each row giving the frequency and time of the peak")
                        # print("The constellation is", constellation)
                        # plotted_constellation = plot_constellation(constellation)
        
                        hashed_constellation = hashing(constellation)
                        # print("The hashed constellation is", hashed_constellation)

                        # query_dict[file_name] = hashed_constellation.tolist() # convert to list to save as JSON
                        query_dict[file_name] = [[time, freq] for freq, time in hashed_constellation.tolist()]
                        # IMPORTNAT!! HERE I SWITCH THE ORDER OF TIME AND FREQUENCY HERE, AS I WANT TO HAVE THE TIME AS THE FIRST COLUMN AND FREQUENCY AS THE SECOND COLUMN IN THE FINAL OUTPUT

        return query_dict 


""" This 
Input: constellation coordinates for A SINGLE fingerprint
Output: inverted_list_dict for that fingerprint, a dictionary where the keys are frequency bins and the values are lists of time bins.
"""
def inverted_list_construction(fingerprint_constellation_coordinates):

        """
        Change each fingerprint to an inverted list dictionary.

        Args:
                fingerprint_constellation_coordinates (list): List of [frequency_bin, time_bin] pairs.

        Returns:
                Dictionary for a single fingerprint where keys are frequency bins and values are lists of time bins.
        """
        inverted_list_dict = {}
        for row in fingerprint_constellation_coordinates:
                frequency_bin = row[0]
                time_bin = row[1]
                if frequency_bin not in inverted_list_dict:
                        inverted_list_dict[int(frequency_bin)] = [] # setting value of the key (frquency bin) to an empty list
                inverted_list_dict[frequency_bin].append(int(time_bin))
        return inverted_list_dict






def audio_matching(query_dict, all_inverted_lists):

        """
        Compare each query against all database fingerprints, obtaining a matching score for each possible time shift,
        (L(h) - n), pick the best time shift for each query as the best alignment of the query with the fingerprint,
        and return the top 3 highest scoring fingerprints for each query.


    Args:
        query_dict (dict): Dictionary mapping each query file to its list of [n, h] pairs.
        all_inverted_lists (dict): Dictionary mapping each database file to its inverted list.

    Returns:
        dict: Dictionary mapping each query to its top 3 matched database files.
    """
       

        final = {}  # Store final results (top-3 matches)

        for query_file, nh_list in query_dict.items():
                scores_for_one_query = {}  # Dictionary to keep scores for all database files for this query

                for fingerprint_file, inverted_list in all_inverted_lists.items(): # for each fingerprint in the database...
                        aggregated_shifted_lists = []  # store all numbers that come from the shifted inverted list for this query (no need to have different lists for every [n,h] pair, just aggregate them all)

                        for n, h in nh_list: # for each [n,h] pair in the query
                                h = int(h)
                                n = int(n)

                                 # Only proceed if this frequency bin exists in the fingerprint's inverted list
                                for freq_bin in inverted_list:
                                        if h == freq_bin: 

                                                for times_list in inverted_list[freq_bin]:
                                                        aggregated_shifted_lists.append(times_list - n) # # Compute the time shift (L(h) - n) and store it

                        # Count how many times each time shift occurred
                        dict_for_counting = {} 

                        for number in aggregated_shifted_lists:
                                if number not in dict_for_counting:
                                        dict_for_counting[number] = 1
                                else:
                                        dict_for_counting[number] += 1

                        # # Sort the time shifts by the number of votes (most common shifts first)
                        sorted_dict_for_counting = sorted(dict_for_counting.items(), key=lambda x: x[1], reverse=True)  # sort the dictionary by value in descending order
                        
                         # Use the count of the most frequent shift as the matching score
                        if sorted_dict_for_counting:
                                most_common_time, count = sorted_dict_for_counting[0]
                                best_count = count
                        else:
                                best_count = 0  # or some fallback if there's no match

                        scores_for_one_query[fingerprint_file] = best_count  # store the best count for this query and fingerprint

                ranked_scores_for_one_query = sorted(scores_for_one_query.items(), key=lambda x: x[1], reverse=True)  # sort the scores for this query in descending order
                top_3_matches = [item[0] for item in ranked_scores_for_one_query[:3]]  # get the top 3 matches
                final[query_file] = top_3_matches  # store the top 3 matches for this query

        return final 




def audioIdentification(path_to_queryset, path_to_fingerprints, path_to_output):
        
        """
        Perform audio matching for each query file against the fingerprinted database,
        and write the top 3 matches for each query to a text file.

        Args:
                path_to_queryset (str): Path to the folder containing query audio files.
                path_to_fingerprints (dict): Dictionary of saved hashed constellations
                path_to_output (str): File path to write the results.

        Returns:
                dict: Dictionary mapping each query filename to a list of its top 3 matched database filenames.
        """
        prepared_queries_dict = prepare_query(query_path)

        all_inverted_lists = {} # Dictionary to store inverted lists for all fingerprints
        # Build inverted lists for all fingerprints
        for file_name in saved_fingerprints:
                constellation = saved_fingerprints[file_name] # get the hashed constellation (list of lists) for the file
                inverted_database_list = inverted_list_construction(constellation) # apply inverted_list_construction to the hashed constellation
                all_inverted_lists[file_name] = inverted_database_list # store the inverted list in the dictionary with the file name as the key

        audio_matching_results = audio_matching(prepared_queries_dict, all_inverted_lists)

        with open(path_to_output, 'w') as f:
                for query_file, top_matches in audio_matching_results.items():
                        line = '\t'.join([query_file] + top_matches)
                        f.write(line + '\n')

        return audio_matching_results




# Call the fingerprintBuilder and audioIdentification functions to get results

saved_fingerprints = fingerprintBuilder(path_to_database, path_to_fingerprints)
final_results = audioIdentification(query_path, saved_fingerprints, path_to_output)


  constellation_coordinates = peak_local_max(np.log(D), min_distance=10,threshold_rel=0.05)


In [39]:
def evaluate_identification(path_to_output):
    correct_top1 = 0
    correct_top3 = 0
    total_queries = 0

    classical_top1 = 0
    classical_top3 = 0
    classical_total = 0

    pop_top1 = 0
    pop_top3 = 0
    pop_total = 0

    completely_wrong_queries = []  # To store query filenames with no correct matches

    with open(path_to_output, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 4:
                continue  # skip malformed lines

            query_file = parts[0]
            predictions = parts[1:4]

            # Get the correct expected match
            if "-snippet-" not in query_file:
                continue  # skip unexpected file formats

            expected = query_file.split("-snippet-")[0] + ".wav"

            is_top1_correct = predictions[0].strip() == expected
            is_top3_correct = expected in predictions

            if not is_top3_correct:
                completely_wrong_queries.append(query_file)

            # Overall counts
            total_queries += 1
            if is_top1_correct:
                correct_top1 += 1
            if is_top3_correct:
                correct_top3 += 1

            # Class-specific breakdown
            if query_file.lower().startswith('classical'):
                classical_total += 1
                if is_top1_correct:
                    classical_top1 += 1
                if is_top3_correct:
                    classical_top3 += 1

            elif query_file.lower().startswith('pop'):
                pop_total += 1
                if is_top1_correct:
                    pop_top1 += 1
                if is_top3_correct:
                    pop_top3 += 1

    print(" Overall Accuracy")
    print(f"  Top-1: {100 * correct_top1 / total_queries:.2f}%")
    print(f"  Top-3: {100 * correct_top3 / total_queries:.2f}%")

    print("\n Classical Accuracy")
    print(f"  Top-1: {100 * classical_top1 / classical_total:.2f}%")
    print(f"  Top-3: {100 * classical_top3 / classical_total:.2f}%")

    print("\n Pop Accuracy")
    print(f"  Top-1: {100 * pop_top1 / pop_total:.2f}%")
    print(f"  Top-3: {100 * pop_top3 / pop_total:.2f}%")

    print("\n The number of queries for which the correct match wasn't in the top 3 guesses is:", len(completely_wrong_queries))
    print(f" These were:")
    for query in completely_wrong_queries:
        print(f"  - {query}")

    return completely_wrong_queries

# Call the evaluation function  
evaluation = evaluate_identification("C:/DriveSync/Queen_Mary/Modules/Music_Informatics/MI Coursework 2 and lab 5/MI coursework 2/output.txt")

 Overall Accuracy
  Top-1: 72.77%
  Top-3: 77.46%

 Classical Accuracy
  Top-1: 58.33%
  Top-3: 60.19%

 Pop Accuracy
  Top-1: 87.62%
  Top-3: 95.24%

 The number of queries for which the correct match wasn't in the top 3 guesses is: 48
 These were:
  - classical.00003-snippet-10-0.wav
  - classical.00003-snippet-10-10.wav
  - classical.00003-snippet-10-20.wav
  - classical.00009-snippet-10-0.wav
  - classical.00009-snippet-10-10.wav
  - classical.00009-snippet-10-20.wav
  - classical.00014-snippet-10-0.wav
  - classical.00014-snippet-10-20.wav
  - classical.00019-snippet-10-0.wav
  - classical.00019-snippet-10-10.wav
  - classical.00019-snippet-10-20.wav
  - classical.00020-snippet-10-0.wav
  - classical.00039-snippet-10-0.wav
  - classical.00039-snippet-10-10.wav
  - classical.00039-snippet-10-20.wav
  - classical.00040-snippet-10-0.wav
  - classical.00040-snippet-10-10.wav
  - classical.00044-snippet-10-0.wav
  - classical.00044-snippet-10-20.wav
  - classical.00050-snippet-10-0.wav