# The New Bayesian Proof

This notebook contains the updated code for fine-tuning an audio fingerprinting and matching system for Ghanaian music and radio clips using Bayesian optimization with Optuna.

## Prerequisites

- **Audio Files**: Place your full song WAV files in `./audio/songs` and the corresponding 10-second clip WAV files in `./audio/clips`. Each clip filename must contain its song's filename without the `.wav` extension (e.g., `song1.wav` and `song1_clip1.wav`).
- **Dependencies**: Install the required Python packages: `librosa`, `numba`, `xxhash`, `matplotlib`, `numpy`, `scipy`, `pandas`, and `optuna`. The optimization plots produced with `optuna.visualization` additionally require `plotly`.
- **Environment**: Run in a Python environment with access to ffmpeg (for audio conversion/extraction if needed).

## Setup

Install dependencies and create the audio directories.

In [None]:
!pip install librosa numba xxhash matplotlib numpy scipy pandas optuna
!mkdir -p ./audio/songs ./audio/clips

## Define Core Functions

Define the necessary functions for audio loading, fingerprinting, and matching, including the corrected `get_2D_peaks_numba` and `generate_song_fingerprints` functions.

In [None]:
import os
import glob
import pandas as pd
import optuna
import numpy as np
import librosa
import matplotlib.pyplot as plt
from numba import jit
import xxhash
from operator import itemgetter
from collections import Counter
import logging
from typing import List, Tuple

# Notebook-wide logging: INFO level, one logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default fingerprinting settings. They seed the keyword defaults of the
# functions below; the Optuna search passes its own values as keyword arguments.
CONFIG = dict(
    DEFAULT_FS=44100,               # sample rate used by load_audio
    DEFAULT_WINDOW_SIZE=2048,       # STFT n_fft
    DEFAULT_OVERLAP_RATIO=0.5,      # hop_length = window * (1 - ratio)
    DEFAULT_FAN_VALUE=15,           # each peak pairs with up to fan_value - 1 successors
    DEFAULT_AMP_MIN=-20,            # dB threshold (spectrogram is relative to its max)
    PEAK_NEIGHBORHOOD_SIZE=10,      # local-max window width (radius = width // 2)
    MIN_HASH_TIME_DELTA=0,          # smallest frame gap between paired peaks
    MAX_HASH_TIME_DELTA=500,        # largest frame gap between paired peaks
    FINGERPRINT_REDUCTION=20,       # hex chars kept from each hash digest
    PEAK_SORT=True,                 # sort peaks by time before pairing
)

@jit(nopython=True)
def get_2D_peaks_numba(arr2D: np.ndarray, amp_min: float, peak_neighborhood_size: int) -> List[Tuple[int, int]]:
    """Find local maxima in a 2-D array (numba-compiled).

    A cell ``(row, col)`` is reported when its value is strictly greater than
    ``amp_min`` and no cell within ``peak_neighborhood_size // 2`` rows and
    columns of it is strictly greater. Ties are allowed, so a flat plateau can
    yield several peaks. Cells closer than that radius to any edge are never
    examined.

    Returns:
        List of ``(row, col)`` tuples in row-major scan order.
    """
    radius = peak_neighborhood_size // 2
    n_rows, n_cols = arr2D.shape
    found = []
    for r in range(radius, n_rows - radius):
        for c in range(radius, n_cols - radius):
            center = arr2D[r, c]
            if not center > amp_min:
                continue
            dominated = False
            for dr in range(-radius, radius + 1):
                for dc in range(-radius, radius + 1):
                    # The center itself is skipped; comparing it would be False anyway.
                    if (dr != 0 or dc != 0) and arr2D[r + dr, c + dc] > center:
                        dominated = True
                        break
                if dominated:
                    break
            if not dominated:
                found.append((r, c))
    return found

def get_2D_peaks(arr2D: np.ndarray, plot: bool = False, amp_min: float = CONFIG['DEFAULT_AMP_MIN'],
                 peak_neighborhood_size: int = CONFIG['PEAK_NEIGHBORHOOD_SIZE']) -> List[Tuple[int, int]]:
    """Extract local-maximum peaks from a spectrogram, optionally plotting them.

    Args:
        arr2D: 2-D array indexed ``[row, col]``. ``fingerprint`` passes a dB
            spectrogram whose rows are frequency bins and columns time frames.
        plot: If True, display the spectrogram with the peaks overlaid.
        amp_min: A cell must be strictly greater than this to be a peak.
        peak_neighborhood_size: Neighborhood width forwarded to
            ``get_2D_peaks_numba`` (radius = width // 2).

    Returns:
        List of ``(row, col)`` peak coordinates. Any exception (including one
        raised while plotting) is logged with its traceback and ``[]`` is
        returned, so a bad configuration yields no peaks instead of a crash.
    """
    try:
        peaks = get_2D_peaks_numba(arr2D, amp_min, peak_neighborhood_size)
        logger.debug("Detected %d peaks with amp_min=%s", len(peaks), amp_min)
        if plot:
            plt.figure(figsize=(10, 6))
            image = plt.imshow(arr2D, origin='lower', aspect='auto', cmap='viridis')
            if peaks:
                freqs, times = zip(*peaks)
                plt.scatter(times, freqs, c='r', s=10, label='Peaks')
                # Only add a legend when a labelled artist exists; otherwise
                # matplotlib warns that no artists with labels were found.
                plt.legend()
            # Pass the image explicitly: pyplot.scatter makes the scatter the
            # current mappable, so a bare colorbar() would describe the red
            # dots rather than the spectrogram's dB scale.
            plt.colorbar(image, label='Amplitude (dB)')
            plt.xlabel('Time (frames)')
            plt.ylabel('Frequency (bins)')
            plt.title(f'Spectrogram with Detected Peaks (amp_min={amp_min})')
            plt.show()
        return peaks
    except Exception as e:
        logger.exception(f"Peak detection failed: {e}")
        return []

def generate_hashes(peaks: List[Tuple[int, int]], fan_value: int = CONFIG['DEFAULT_FAN_VALUE'],
                    min_hash_time_delta: int = CONFIG['MIN_HASH_TIME_DELTA'],
                    max_hash_time_delta: int = CONFIG['MAX_HASH_TIME_DELTA'],
                    fingerprint_reduction: int = CONFIG['FINGERPRINT_REDUCTION'],
                    peak_sort: bool = CONFIG['PEAK_SORT']) -> List[Tuple[str, int]]:
    """Turn peaks into ``(hash, anchor_time)`` fingerprints.

    Each peak ``(freq, time)`` is paired with the next ``fan_value - 1`` peaks
    in (optionally time-sorted) order. A pair whose time gap lies in
    ``[min_hash_time_delta, max_hash_time_delta]`` is hashed as
    ``xxh64("freq1|freq2|t_delta")``, and the hex digest is truncated to
    ``fingerprint_reduction`` characters. An xxh64 hex digest is only 16
    characters long, so any reduction >= 16 keeps the whole digest.

    Args:
        peaks: ``(frequency_bin, time_frame)`` tuples. The list is not modified.
        fan_value: Pair each peak with at most ``fan_value - 1`` successors.
        min_hash_time_delta: Smallest allowed ``t2 - t1`` for a pair.
        max_hash_time_delta: Largest allowed ``t2 - t1`` for a pair.
        fingerprint_reduction: Number of hex characters kept per hash.
        peak_sort: If True, pair peaks in time order (stable sort on time);
            otherwise pair them in the order given.

    Returns:
        List of ``(hash_hex, t1)``, where ``t1`` is the anchor peak's time.
        Returns ``[]`` if hashing fails (the error is logged with traceback).
    """
    try:
        # Sort a copy so the caller's list is not reordered as a side effect.
        ordered = sorted(peaks, key=itemgetter(1)) if peak_sort else list(peaks)
        # Number of successors to pair with; clamped so a fan_value < 1 yields
        # no pairs (a negative slice end would otherwise wrap around).
        span = max(fan_value - 1, 0)
        hashes = []
        for i, (freq1, t1) in enumerate(ordered):
            for freq2, t2 in ordered[i + 1:i + 1 + span]:
                t_delta = t2 - t1
                if min_hash_time_delta <= t_delta <= max_hash_time_delta:
                    h = xxhash.xxh64(f"{freq1}|{freq2}|{t_delta}".encode('utf-8'))
                    hashes.append((h.hexdigest()[:fingerprint_reduction], t1))
        logger.debug("Generated %d hashes from %d peaks", len(hashes), len(ordered))
        return hashes
    except Exception as e:
        logger.exception(f"Hash generation failed: {e}")
        return []

def fingerprint(channel_samples: np.ndarray, Fs: int = CONFIG['DEFAULT_FS'],
                wsize: int = CONFIG['DEFAULT_WINDOW_SIZE'], wratio: float = CONFIG['DEFAULT_OVERLAP_RATIO'],
                fan_value: int = CONFIG['DEFAULT_FAN_VALUE'], amp_min: float = CONFIG['DEFAULT_AMP_MIN'],
                peak_neighborhood_size: int = CONFIG['PEAK_NEIGHBORHOOD_SIZE'],
                min_hash_time_delta: int = CONFIG['MIN_HASH_TIME_DELTA'],
                max_hash_time_delta: int = CONFIG['MAX_HASH_TIME_DELTA'],
                fingerprint_reduction: int = CONFIG['FINGERPRINT_REDUCTION'],
                peak_sort: bool = CONFIG['PEAK_SORT']) -> List[Tuple[str, int]]:
    """Generate ``(hash, time_frame)`` fingerprints from int16-scaled samples.

    Pipeline: rescale by 1/32768 -> Hann-window STFT (``n_fft=wsize``,
    hop ``wsize * (1 - wratio)``) -> magnitude in dB relative to the maximum
    -> local-maximum peaks -> pairwise peak hashes.

    Args:
        channel_samples: Mono samples scaled like ``load_audio`` output (int16).
        Fs: Sample rate; only used for the debug log of the audio duration.
        wsize, wratio: STFT window size and overlap ratio.
        Remaining keywords are forwarded to ``get_2D_peaks`` / ``generate_hashes``.

    Returns:
        List of ``(hash_hex, time_frame)``, or ``[]`` on any failure (the error
        is logged with its traceback).
    """
    try:
        # Undo the int16 scaling applied by load_audio.
        samples = channel_samples.astype(np.float32) / 32768.0
        # A wratio close to 1 would truncate the hop to 0 samples, which is not
        # a valid STFT hop; keep it at least 1.
        hop_length = max(1, int(wsize * (1 - wratio)))
        S = librosa.stft(samples, n_fft=wsize, hop_length=hop_length, window='hann')
        arr2D = librosa.amplitude_to_db(np.abs(S), ref=np.max)
        local_maxima = get_2D_peaks(arr2D, amp_min=amp_min, peak_neighborhood_size=peak_neighborhood_size)
        hashes = generate_hashes(local_maxima, fan_value=fan_value,
                                 min_hash_time_delta=min_hash_time_delta,
                                 max_hash_time_delta=max_hash_time_delta,
                                 fingerprint_reduction=fingerprint_reduction,
                                 peak_sort=peak_sort)
        logger.debug("Generated %d fingerprints for %.2fs audio",
                     len(hashes), len(samples) / Fs if Fs else 0.0)
        return hashes
    except Exception as e:
        logger.exception(f"Fingerprinting failed: {e}")
        return []

def load_audio(file_path: str) -> Tuple[np.ndarray, int]:
    """Load an audio file as mono int16 samples at ``CONFIG['DEFAULT_FS']``.

    Args:
        file_path: Path to the audio file.

    Returns:
        ``(samples, sample_rate)`` with ``samples`` as an int16 array. On a
        missing file or a load error, the problem is logged and
        ``(empty int16 array, 0)`` is returned.
    """
    empty = np.array([], dtype=np.int16)
    if not os.path.exists(file_path):
        logger.error(f"File not found: {file_path}")
        return empty, 0
    try:
        samples, sr = librosa.load(file_path, sr=CONFIG['DEFAULT_FS'], mono=True)
    except Exception as e:
        logger.error(f"Failed to load {file_path}: {e}")
        return empty, 0
    # Float samples are nominally in [-1, 1], but a full-scale 1.0 maps to
    # 32768 and resampling may overshoot; both exceed int16. Clip before the
    # cast so such samples saturate instead of wrapping to the opposite sign.
    pcm = np.clip(samples * 32768.0, -32768, 32767).astype(np.int16)
    return pcm, sr

# Minimal stand-in for the Django ``Song`` model (simplified).
class Song:
    """Plain record holding a song's ``id`` and ``title``."""

    def __init__(self, id, title):
        self.id, self.title = id, title

def generate_song_fingerprints(samples, sr, song_id, **params):
    """Fingerprint a full song and tag every hash with ``song_id``.

    Only the keyword arguments ``fingerprint`` understands are forwarded;
    anything else in ``params`` (for example matching thresholds) is ignored.

    Returns:
        List of ``(song_id, hash, offset)`` rows, or ``[]`` (with an error
        logged) when ``samples`` is empty.
    """
    if not len(samples):
        logger.error("No samples provided for fingerprinting")
        return []

    accepted = ('wsize', 'wratio', 'fan_value', 'amp_min',
                'peak_neighborhood_size', 'min_hash_time_delta',
                'max_hash_time_delta', 'fingerprint_reduction', 'peak_sort')
    forwarded = {name: params[name] for name in accepted if name in params}

    return [(song_id, hash_value, offset)
            for hash_value, offset in fingerprint(samples, Fs=sr, **forwarded)]

def match_clip(clip_samples, clip_sr, db_fingerprints, **params):
    """Fingerprint a clip and match it against fingerprint rows.

    Each clip hash found in ``db_fingerprints`` votes for a
    ``(song_id, db_offset - clip_offset)`` alignment. The alignment with the
    most votes (ties broken by insertion order) is accepted only if it clears
    all three thresholds.

    Args:
        clip_samples: Clip audio samples, passed to ``fingerprint``.
        clip_sr: Clip sample rate, passed to ``fingerprint`` as ``Fs``.
        db_fingerprints: Iterable of ``(song_id, hash, offset)`` rows, the
            shape ``generate_song_fingerprints`` returns. It is iterated twice,
            so pass a list rather than a one-shot iterator.
        **params: Fingerprinting keywords plus the thresholds
            ``min_match_count`` (default 10), ``min_input_conf`` (percent,
            default 10.0) and ``min_db_conf`` (percent, default 2.0).

    Returns:
        A dict that always contains ``match``, ``hashes_matched``,
        ``input_confidence`` and ``db_confidence``. On success it also has
        ``song_id`` and ``offset``; otherwise it has a ``reason``. Confidences
        are percentages: winning votes / clip hashes, and winning votes /
        rows belonging to the winning song.
    """
    def _no_match(reason):
        # Every failure result carries the same keys, with zeroed metrics.
        return {"match": False, "reason": reason, "hashes_matched": 0,
                "input_confidence": 0.0, "db_confidence": 0.0}

    if len(clip_samples) == 0:
        return _no_match("No samples in clip")

    # Extract only fingerprinting parameters for generating clip fingerprints
    fingerprint_params = {
        key: params[key] for key in ['wsize', 'wratio', 'fan_value', 'amp_min',
                                     'peak_neighborhood_size', 'min_hash_time_delta',
                                     'max_hash_time_delta', 'fingerprint_reduction', 'peak_sort']
        if key in params
    }
    clip_fingerprints = fingerprint(clip_samples, Fs=clip_sr, **fingerprint_params)
    if not clip_fingerprints:
        return _no_match("No fingerprints extracted")

    # Set membership keeps the DB filter linear. Identical DB rows collapse
    # into one entry of db_hashes.
    query_hashes = {h for h, _ in clip_fingerprints}
    db_hashes = {(h, o, song_id) for song_id, h, o in db_fingerprints if h in query_hashes}
    if not db_hashes:
        return _no_match("No matching hashes found")

    # Index the surviving rows by hash, so each clip hash only visits its own
    # candidates instead of every matching DB row.
    rows_by_hash = {}
    for h, db_offset, song_id in db_hashes:
        rows_by_hash.setdefault(h, []).append((db_offset, song_id))

    match_map = Counter()
    for h, query_offset in clip_fingerprints:
        for db_offset, song_id in rows_by_hash.get(h, ()):
            match_map[(song_id, db_offset - query_offset)] += 1

    if not match_map:
        return _no_match("No offset alignment found")

    (song_id, offset_diff), match_count = match_map.most_common(1)[0]
    # Count every clip hash, duplicates included (not the de-duplicated set).
    total_query_hashes = len(clip_fingerprints)
    total_db_hashes = sum(1 for sid, _, _ in db_fingerprints if sid == song_id)
    input_confidence = (match_count / total_query_hashes) * 100
    db_confidence = (match_count / total_db_hashes) * 100 if total_db_hashes else 0

    min_match_count = params.get('min_match_count', 10)
    min_input_conf = params.get('min_input_conf', 10.0)
    min_db_conf = params.get('min_db_conf', 2.0)

    if match_count < min_match_count or input_confidence < min_input_conf or db_confidence < min_db_conf:
        return {
            "match": False,
            "reason": "Low confidence match",
            "hashes_matched": match_count,
            "input_confidence": input_confidence,
            "db_confidence": db_confidence
        }

    return {
        "match": True,
        "song_id": song_id,
        "offset": offset_diff,
        "hashes_matched": match_count,
        "input_confidence": input_confidence,
        "db_confidence": db_confidence
    }

## Load Dataset

Load the audio files from the specified directories and create song-clip pairs for evaluation.

In [None]:
# Input locations. Before running this cell, put the full-length song WAVs in
# SONG_DIR and the 10-second clip WAVs in CLIP_DIR. Each clip's filename must
# contain its song's filename stem (e.g. song1.wav -> song1_clip1.wav).
SONG_DIR, CLIP_DIR = './audio/songs', './audio/clips'


def load_audio_files(directory: str) -> dict:
    """Recursively load every ``.wav`` file under ``directory``.

    ``glob`` returns paths in filesystem-dependent order, so the paths are
    sorted first. That makes the result's key order reproducible, along with
    the song ids later assigned in that order.

    Args:
        directory: Root directory to search (subdirectories included).

    Returns:
        Dict mapping each file's basename to ``(samples, sample_rate)`` as
        returned by ``load_audio``. Files that load with no samples are
        skipped. If two files share a basename, the later path (in sorted
        order) wins and a warning is logged.
    """
    audio_data = {}
    for file_path in sorted(glob.glob(os.path.join(directory, '**/*.wav'), recursive=True)):
        samples, sr = load_audio(file_path)
        if len(samples) == 0:
            continue
        name = os.path.basename(file_path)
        if name in audio_data:
            logger.warning(f"Duplicate filename {name!r} at {file_path}; replacing earlier file")
        audio_data[name] = (samples, sr)
    return audio_data

# Load all songs and clips
all_songs = load_audio_files(SONG_DIR)
all_clips = load_audio_files(CLIP_DIR)

print(f"Loaded {len(all_songs)} songs from {SONG_DIR}")
print(f"Loaded {len(all_clips)} clips from {CLIP_DIR}")

# Assign each clip to the song whose filename stem it contains
# (e.g. song1.wav -> song1_clip1.wav, song1_clip2.wav). When several stems
# match -- "song1" and "song10" both occur in "song10_clip1.wav" -- the longest
# stem wins, so a clip is never paired with a song whose name is merely a
# prefix of the real one. os.path.splitext keeps dotted names
# ("artist.title.wav") intact.
song_stems = {name: os.path.splitext(name)[0] for name in all_songs}
clips_by_song = {name: [] for name in all_songs}
for clip_filename in all_clips:
    candidates = [name for name, stem in song_stems.items() if stem and stem in clip_filename]
    if not candidates:
        logger.warning(f"No corresponding song found for clip: {clip_filename}")
        continue
    owner = max(candidates, key=lambda name: len(song_stems[name]))
    clips_by_song[owner].append(clip_filename)

# Build song-clip pairs; song ids are assigned 1, 2, ... in all_songs order.
song_clip_pairs = []
song_id_counter = 0
song_mapping = {}  # Map song filename to song_id

for song_filename, (song_samples, song_sr) in all_songs.items():
    song_id_counter += 1
    current_song_id = song_id_counter
    song_mapping[song_filename] = current_song_id

    if not clips_by_song[song_filename]:
        logger.warning(f"No corresponding clips found for song: {song_filename}")
        continue

    for clip_filename in clips_by_song[song_filename]:
        clip_samples, clip_sr = all_clips[clip_filename]
        song_clip_pairs.append({
            'song_id': current_song_id,
            'song_filename': song_filename,
            'song_samples': song_samples,
            'song_sr': song_sr,
            'clip_filename': clip_filename,
            'clip_samples': clip_samples,
            'clip_sr': clip_sr
        })

print(f"Created {len(song_clip_pairs)} song-clip pairs for evaluation.")

## Define Optimization Objective and Parameter Ranges

Define the objective function for Optuna and the search space for the parameters.

In [None]:
# Optuna search space tailored for Ghanaian music.
# objective_dataset reads each (low, high) tuple as suggest_int/suggest_float
# bounds and each list as suggest_categorical choices.
param_ranges = {
    'amp_min': (-35, -10),
    'peak_neighborhood_size': (5, 12),
    'fan_value': (10, 20),
    'min_match_count': (5, 15),
    'min_input_conf': (5.0, 15.0),
    'min_db_conf': (1.0, 3.0),
    'wsize': [1024, 2048, 4096],
    'wratio': (0.5, 0.85),
    'min_hash_time_delta': (0, 10),
    'max_hash_time_delta': (200, 500),
    # Hashes are xxh64 hex digests truncated to this many characters, and an
    # xxh64 hex digest is only 16 characters long. Any value >= 16 therefore
    # keeps the full digest; the previous choices 16/20/24 were identical, and
    # trials were wasted telling them apart. These values differ.
    'fingerprint_reduction': [12, 14, 16],
    'peak_sort': [True, False]
}

# Evaluate one parameter set over the whole song/clip dataset.
def evaluate_params_on_dataset(dataset: List[dict], params: dict):
    """Score one parameter set over every song/clip pair in ``dataset``.

    For each pair, the song is fingerprinted and the clip is matched against
    that song's fingerprints only. The pair then gets a score:

        0.6 + 0.2 * input_confidence + 0.2 * db_confidence   if matched
        0                                                    otherwise

    The 0.6 comes from two fingerprint-count ratios that are always 1 for a
    matched pair. The confidences are percentages, so they dominate the score.

    NOTE(review): each clip is compared only with its own song, so false
    positives against other songs are never penalised. The optimizer has no
    incentive to keep the matching thresholds strict.

    Args:
        dataset: Dicts with ``song_id``, ``song_filename``, ``song_samples``,
            ``song_sr``, ``clip_filename``, ``clip_samples`` and ``clip_sr``.
        params: Fingerprinting keywords plus matching thresholds.

    Returns:
        ``(average_score, per_pair_results)``; ``(0, [])`` for an empty dataset.
    """
    total_score = 0
    trial_results = []

    if not dataset:
        return 0, []  # Return 0 score and empty results if dataset is empty

    # Only these keys are understood by fingerprint(); thresholds stay in params.
    fingerprint_params = {
        key: params[key] for key in ['wsize', 'wratio', 'fan_value', 'amp_min',
                                     'peak_neighborhood_size', 'min_hash_time_delta',
                                     'max_hash_time_delta', 'fingerprint_reduction', 'peak_sort']
        if key in params
    }

    # A song with several clips appears in several items. Fingerprinting a full
    # song is the expensive step, so do it once per song (keyed by id, sample
    # array identity and rate) and reuse the rows; the result is deterministic.
    song_fingerprint_cache = {}

    for item in dataset:
        song_id = item['song_id']
        song_samples = item['song_samples']
        song_sr = item['song_sr']
        clip_samples = item['clip_samples']
        clip_sr = item['clip_sr']

        cache_key = (song_id, id(song_samples), song_sr)
        if cache_key not in song_fingerprint_cache:
            song_fingerprint_cache[cache_key] = generate_song_fingerprints(
                song_samples, song_sr, song_id, **fingerprint_params)
        song_fingerprints = song_fingerprint_cache[cache_key]

        result = match_clip(clip_samples, clip_sr, song_fingerprints, **params)

        # match_clip does not report how many hashes the clip produced, so
        # fingerprint the clip again to record that count.
        clip_fingerprints_len = len(fingerprint(clip_samples, Fs=clip_sr, **fingerprint_params))
        song_fingerprints_len = len(song_fingerprints)

        max_clip_fingerprints = clip_fingerprints_len if clip_fingerprints_len > 0 else 1
        max_song_fingerprints = song_fingerprints_len if song_fingerprints_len > 0 else 1

        # Each ratio below is count / count (1) or 0 / 1 (0); see docstring.
        score = (
            (clip_fingerprints_len / max_clip_fingerprints) * 0.3 +
            (song_fingerprints_len / max_song_fingerprints) * 0.3 +
            result.get('input_confidence', 0.0) * 0.2 +
            result.get('db_confidence', 0.0) * 0.2
        ) if result['match'] else 0

        trial_results.append({
            'song_id': song_id,
            'song_filename': item['song_filename'],
            'clip_filename': item['clip_filename'],
            'clip_fingerprints': clip_fingerprints_len,
            'song_fingerprints': song_fingerprints_len,
            'match': result['match'],
            'hashes_matched': result.get('hashes_matched', 0),
            'input_confidence': result.get('input_confidence', 0.0),
            'db_confidence': result.get('db_confidence', 0.0),
            'reason': result.get('reason', ''),
            'score': score,
            **params  # Include parameters in the results for analysis
        })
        total_score += score

    average_score = total_score / len(dataset)
    return average_score, trial_results


# Optuna objective evaluated on the whole song/clip dataset.
def objective_dataset(trial):
    """Sample one configuration from ``param_ranges`` and return its dataset score.

    Two-element tuples in ``param_ranges`` become int/float bounds and lists
    become categorical choices. The configuration is scored on
    ``song_clip_pairs`` and the average score is returned for the study to
    maximise.
    """
    def as_float(name):
        bounds = param_ranges[name]
        return trial.suggest_float(name, bounds[0], bounds[1])

    def as_int(name):
        bounds = param_ranges[name]
        return trial.suggest_int(name, bounds[0], bounds[1])

    def as_choice(name):
        return trial.suggest_categorical(name, param_ranges[name])

    # Dict literals evaluate top to bottom, so parameters are suggested in
    # exactly this order.
    params = {
        'amp_min': as_float('amp_min'),
        'peak_neighborhood_size': as_int('peak_neighborhood_size'),
        'fan_value': as_int('fan_value'),
        'min_match_count': as_int('min_match_count'),
        'min_input_conf': as_float('min_input_conf'),
        'min_db_conf': as_float('min_db_conf'),
        'wsize': as_choice('wsize'),
        'wratio': as_float('wratio'),
        'min_hash_time_delta': as_int('min_hash_time_delta'),
        'max_hash_time_delta': as_int('max_hash_time_delta'),
        'fingerprint_reduction': as_choice('fingerprint_reduction'),
        'peak_sort': as_choice('peak_sort'),
    }
    average_score, _ = evaluate_params_on_dataset(song_clip_pairs, params)
    print(f"Trial {trial.number}: average_score={average_score:.4f}, params={params}")
    return average_score

## Run Bayesian Optimization

Execute the Bayesian optimization process. This can be computationally intensive depending on the dataset size and number of trials.

In [None]:
# Run Bayesian optimization on the dataset; skip it when there is nothing to evaluate.
if not song_clip_pairs:
    print("\nNo song-clip pairs found. Please add audio files to ./audio/songs and ./audio/clips.")
else:
    print(f"Starting Bayesian Optimization with {len(song_clip_pairs)} song-clip pairs and {param_ranges} parameter ranges.")
    study_dataset = optuna.create_study(direction='maximize')
    study_dataset.optimize(objective_dataset, n_trials=500)

    # The objective only returns the average, so re-run the winning
    # configuration to get the per-pair breakdown.
    best_params_dataset = study_dataset.best_trial.params
    average_score, all_trial_results = evaluate_params_on_dataset(song_clip_pairs, best_params_dataset)
    df_all_trials = pd.DataFrame(all_trial_results)

    print("\nBest Configuration (Dataset Evaluation):")
    print(best_params_dataset)
    print(f"Average Score with Best Configuration: {average_score:.4f}")

    print("\nDetailed Results for Best Configuration on Dataset:")
    summary_columns = ['song_filename', 'clip_filename', 'match', 'hashes_matched',
                       'input_confidence', 'db_confidence', 'score', 'reason']
    display(df_all_trials[summary_columns])

## Analyze Optimization Results

Analyze the results of the optimization study, including the best performing parameters and visualizations.

In [None]:
import optuna.visualization as optuna_vis

# Analyze only if the optimization cell created a study with at least one
# COMPLETE trial. study.best_trial raises ValueError when no trial has
# completed (e.g. the run was interrupted during its first trial), even though
# study.trials is non-empty in that case.
has_completed_trial = 'study_dataset' in locals() and any(
    t.state == optuna.trial.TrialState.COMPLETE for t in study_dataset.trials
)

if has_completed_trial:
    print("\nBest trial:")
    print(f"  Value (Score): {study_dataset.best_trial.value:.4f}")
    print("  Params:")
    for key, value in study_dataset.best_trial.params.items():
        print(f"    {key}: {value}")

    # optuna.visualization needs plotly; report its absence instead of crashing.
    try:
        print("\nGenerating optimization history plot...")
        fig_history = optuna_vis.plot_optimization_history(study_dataset)
        fig_history.show()

        print("\nGenerating parallel coordinate plot...")
        fig_parallel = optuna_vis.plot_parallel_coordinate(study_dataset)
        fig_parallel.show()
    except ImportError as e:
        print(f"Could not generate Optuna plots: {e}. Install plotly to enable optuna.visualization.")
    else:
        try:
            print("\nGenerating parameter importance plot...")
            fig_importance = optuna_vis.plot_param_importances(study_dataset)
            fig_importance.show()
        except (RuntimeError, ValueError) as e:
            # Importance evaluation fails with too few completed trials or
            # when every trial has the same score.
            print(f"Could not generate parameter importance plot: {e}. This can happen with too few completed trials or zero variance in the trial results.")

else:
    print("Optimization study was not run or contains no trials. Please ensure audio files are in place and the optimization step completed successfully.")