In [175]:
import os
import sys
import glob
import resampy
import librosa
import numpy as np
import scipy.signal
import librosa.display
import soundfile as sf
import pickle
import matplotlib.pyplot as plt
from skimage.feature import peak_local_max
import random


Functions for the system:

In [176]:
#this function picks peaks above a given threshold from the stft of an audio file
def peak_picking(
    audio_file : str,
    sample_rate : int = 22050,
    n_fft : int = 2048, 
    hop_length : int = 256,
    win_length : int = None,
    window : str = 'hann',
    peak_distance : int = 24,
    threshold_abs : float = -60,
    
) -> list:
#read the audio file
    x, sr = sf.read(audio_file)
#standardize all audio to 22.5KHz
    if sr != sample_rate:
            x = resampy.resample(x, sr, sample_rate)
#normalise the audio
    x /= np.max(np.abs(x))
#calculate stft
    D = librosa.stft(
            x, 
            n_fft=n_fft, 
            hop_length=hop_length,
            win_length=win_length,
            window=window
        )
#calculate magnitude
    X = np.abs(D)
    X /= np.max(X)

    
#find peaks in the stft based on the threshold
    peaks = peak_local_max(
                        X, 
                        min_distance=peak_distance, 
                        threshold_abs=threshold_abs
                    )
#sort the peaks in the time axis
    peaks = peaks[np.argsort(peaks[:, 1])]

    return peaks
        


In [177]:
#this function creates hashes from the peaks
def generate_hashes(
           audio_file: str,
           max_peaks : int = 24 ,
           target_zone_freq : int = 256,
           target_zone_time : int = 128,
           
        
):

    filename = os.path.basename(audio_file).replace(".wav", "")
    peaks = peak_picking(audio_file)
    hashes = []
    # iterate over each peak
    for n in range(peaks.shape[0]):
        anchor_peaks = peaks[n,:]

        # iterate over all peaks to find peaks within the target zone
        peaks_in_target_zone = 0
        for i in range(peaks.shape[0]):
            target_peaks = peaks[i,:] 
            # find peaks within the target zone to generate hashes
            if ((target_peaks[0] < anchor_peaks[0] + target_zone_freq) and (target_peaks[0] > anchor_peaks[0] - target_zone_freq) and 
                (target_peaks[1] < anchor_peaks[1] + target_zone_time) and (target_peaks[1] > anchor_peaks[1] + 1) and 
                not np.array_equal(target_peaks, anchor_peaks)):
                hashes.append({
                    "hash" : (anchor_peaks[0], target_peaks[0], target_peaks[1] - anchor_peaks[1]),
                            # (anchor_peaks freq, target_peaks freqs, time difference)
                    "song" : filename,
                    "timestep" : anchor_peaks[1]
                })
                peaks_in_target_zone += 1
            if peaks_in_target_zone > max_peaks: break

    if len(hashes) < 1:
        print(f"No hashes found. {filename}")

    return hashes

In [179]:
#this functions generates fingerprints for the database
def db_fingerprint(audiodata_dir : str,
                   fp_dir : str = "fingerprints.pkl",
                   
):


    audio_files = sorted(glob.glob(os.path.join(audiodata_dir, "*.wav")))
 
    db_songs_hashes = []

    for audio_file in audio_files:
        db_songs_hashes.append(generate_hashes(audio_file))


    db_songs = []

    # structure the list of songs
    # create a list of dicts, which the song id, and inverted lists
    for db_song in db_songs_hashes:
        db_song_inv_lists = {}
        if len(db_song) < 1:
            print(db_song)
        song_id = db_song[0]["song"]
        
        for db_song_hash in db_song:
            if db_song_hash["hash"] not in db_song_inv_lists:
                db_song_inv_lists[db_song_hash["hash"]] = [db_song_hash["timestep"]]
            else:
                db_song_inv_lists[db_song_hash["hash"]] += db_song_hash["timestep"]
        db_songs.append({
            "song_id" : song_id,
            "inverted_lists" : db_song_inv_lists
        })

    # save list of database song hashes to disk
    with open(fp_dir, 'wb') as f:
        pickle.dump(db_songs, f)    


    
    


In [180]:
#this function is used to find all the possible matches for the query
def find_matches(
      query_hashes : list,
      db_fingerprints: list
  ):
 
    matches = {}
    # iterate over each fingerprint in the database
    for db_fingerprint in db_fingerprints:
        fingerprint_id = db_fingerprint["song_id"]
        inverted_lists = db_fingerprint["inverted_lists"]
        match_timesteps = {}
        for q_hash in query_hashes: # each hash is a query
            if q_hash["hash"] in inverted_lists: # check if hash is in inverted lists
                for timestep in inverted_lists[q_hash["hash"]]:
                    shifted = timestep - q_hash["timestep"]
                    if shifted not in match_timesteps:
                        match_timesteps[shifted] = 1
                    else:
                        match_timesteps[shifted] += 1

        if len(match_timesteps.values()) > 0:
            # return the max value of the mathcing function histogram
            matches[fingerprint_id] = max(match_timesteps.values())

    return matches

In [181]:
#this function finda the matches
def audioid(
    query_audio: str,
    fingerprint_database: str = "fingerprints.pkl",
    ):
#generate fingerprint for the query
    query_hash = generate_hashes(query_audio)
    #print("\ngenerating query hashes..")
#load back the fingerprints for the database
    with open(fingerprint_database, 'rb') as f:
        db_fingerprints = pickle.load(f)
    
    query_match = [] # we will store all matches here
    #print("Finding matches...")
    query_match.append(find_matches(query_hash, db_fingerprints))

    return query_match
    



In [182]:
#This function finda the top match for a query
def top_match_(
    q_audio : str
):
    matches = audioid(q_audio)
    max_val = 0
    max_key = None
    #Check if the song from the database with highest number of matches is the same as the baseline truth
    for k in matches[0].keys():
    #print(matches[0][k])
        if max_val < matches[0][k]:
            max_key = k
            max_val = matches[0][k]

    query_id = os.path.basename(q_audio).strip(".wav")

    print("\n for %s top match is %s" %(query_id,max_key))
    return max_key



In [183]:
#To evaluate the system, we find the number of correct top matches
def accuracy(query_files : list,


):

    
    total_queries = len(query_files)
    correct_matches = 0
    progress = 0
    for audio_file in query_files:
        progress += 1
        print ("\n %s of %s "%(progress,total_queries))
        top_match = top_match_(audio_file)
        
        # get the ground truth song id

        query_id = os.path.basename(audio_file).strip(".wav")
        #print("\n",query_id)
        song_id = query_id.split('-')[0]
        if song_id == top_match:
            correct_matches += 1
        
    
    accuracy = correct_matches/total_queries * 100

    return accuracy


Get the paths : 

In [184]:
PATH = os.getcwd()
#print(PATH)
database = PATH + '/gtzan/database_recordings/'
query = PATH + '/gtzan/query_recordings/'

In [185]:
#generate fingerprints for the database
#db_fingerprint(database)
print("Already generated!")

Already generated!


In [186]:
query_files = sorted(glob.glob(os.path.join(query, "*.wav")))
total_queries = len(query_files)
print(total_queries)
type(query_files)

213


list

Testing the system for a sample query audio:

In [188]:
#Testing for query_audio

qaudio_file = query + 'classical.00000-snippet-10-0.wav'
match = top_match_(qaudio_file)
#print(match)

NameError: name 'n_fft' is not defined

Evaluate the system:

In [117]:
#Compute accuracy of the system
query_files = sorted(glob.glob(os.path.join(query_dir, "*.wav")))
accuracy_ = accuracy(query_files)




 1 of 213 

 for classical.00000-snippet-10-0 top match is classical.00000

 classical.00000-snippet-10-0

 2 of 213 

 for classical.00000-snippet-10-10 top match is classical.00000

 classical.00000-snippet-10-10

 3 of 213 

 for classical.00000-snippet-10-20 top match is classical.00000

 classical.00000-snippet-10-20

 4 of 213 

 for classical.00003-snippet-10-0 top match is classical.00003

 classical.00003-snippet-10-0

 5 of 213 

 for classical.00003-snippet-10-10 top match is classical.00003

 classical.00003-snippet-10-10

 6 of 213 

 for classical.00003-snippet-10-20 top match is pop.00036

 classical.00003-snippet-10-20

 7 of 213 

 for classical.00004-snippet-10-0 top match is classical.00004

 classical.00004-snippet-10-0

 8 of 213 

 for classical.00004-snippet-10-10 top match is classical.00004

 classical.00004-snippet-10-10

 9 of 213 

 for classical.00005-snippet-10-0 top match is classical.00005

 classical.00005-snippet-10-0

 10 of 213 

 for classical.0000

In [118]:
print("Accuracy of the system (in %): ", accuracy_)

Accuracy of the system (in %):  81.2206572769953


In [157]:
query_files = (glob.glob(os.path.join(query, "*.wav")))

In [173]:

#Checking accuracy on a test query set of 10 examples for different n_fft values
query_files = (glob.glob(os.path.join(query, "*.wav")))

#type(query_files)
test_qf = []
test_qf = random.choices(query_files,k=25)
#print(test_qf) 


In [174]:
#Accuracy for n_fft =2048
accuracy_ = accuracy(test_qf)
print("Accuracy for n_fft = 2048: ", accuracy_)



 1 of 25 

 for pop.00003-snippet-10-0 top match is classical.00018

 2 of 25 

 for classical.00003-snippet-10-20 top match is classical.00007

 3 of 25 

 for pop.00011-snippet-10-10 top match is None

 4 of 25 

 for classical.00013-snippet-10-0 top match is classical.00083

 5 of 25 

 for pop.00020-snippet-10-0 top match is pop.00045

 6 of 25 

 for classical.00020-snippet-10-10 top match is classical.00010

 7 of 25 

 for pop.00045-snippet-10-10 top match is None

 8 of 25 

 for pop.00084-snippet-10-0 top match is classical.00001

 9 of 25 

 for pop.00039-snippet-10-0 top match is None

 10 of 25 

 for classical.00039-snippet-10-10 top match is classical.00022

 11 of 25 
No hashes found. classical.00000-snippet-10-0

 for classical.00000-snippet-10-0 top match is None

 12 of 25 

 for pop.00075-snippet-10-0 top match is None

 13 of 25 

 for classical.00075-snippet-10-0 top match is classical.00084

 14 of 25 

 for classical.00085-snippet-10-10 top match is None

 15 of

KeyboardInterrupt: 

In [65]:
#debug code

qaudio_file = database + 'classical.00000.wav'
#peaks = peak_picking(audio_file)
#print(peaks)
#type(peaks)
#peaks.shape
#hash = generate_hashes(audio_file)
matches = audioid(qaudio_file)
print(matches)
type(matches)
len(matches)
#type(hash)
# len(hash)
#print(hash)
#matches = audioid(qaudio_file)
#matches = [(db_song_id, score) for db_song_id, score in matches.items()]
# matches = sorted(matches, key=lambda a: a[1], reverse=True)
# print("\nsorted \n" ,matches)



 generating query hashes..
Finding matches...
[{'classical.00000': 14640, 'classical.00001': 1, 'classical.00002': 3, 'classical.00003': 1, 'classical.00004': 1, 'classical.00005': 1, 'classical.00006': 1, 'classical.00007': 1, 'classical.00008': 1, 'classical.00009': 2, 'classical.00010': 1, 'classical.00011': 1, 'classical.00012': 1, 'classical.00013': 2, 'classical.00014': 1, 'classical.00015': 1, 'classical.00016': 1, 'classical.00017': 3, 'classical.00018': 1, 'classical.00019': 1, 'classical.00020': 1, 'classical.00021': 1, 'classical.00022': 3, 'classical.00023': 1, 'classical.00024': 1, 'classical.00025': 2, 'classical.00026': 1, 'classical.00027': 1, 'classical.00028': 1, 'classical.00029': 3, 'classical.00030': 1, 'classical.00031': 1, 'classical.00032': 1, 'classical.00033': 2, 'classical.00034': 1, 'classical.00035': 1, 'classical.00036': 1, 'classical.00037': 1, 'classical.00038': 1, 'classical.00039': 1, 'classical.00040': 1, 'classical.00041': 1, 'classical.00042': 2, '

1

In [84]:
print(type(matches))
#print(matches[0])
max_val = 0
max_key = None
for k in matches[0].keys():
    #print(matches[0][k])
    if max_val < matches[0][k]:

        max_key = k
        max_val = matches[0][k]

print(max_key, max_val)

<class 'list'>
classical.00000 14640
