# setup
libraries and functions

In [None]:
!pip install community -q
!pip install pretty_midi pydub PyWavelets -q

In [None]:
######### necessary libraries #########
import pretty_midi
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
from pydub import AudioSegment
import pywt
import scipy.signal
from scipy.signal import correlate
from scipy.fftpack import fft
from scipy.io import wavfile
from scipy.interpolate import interp1d
from matplotlib.patches import Patch
import networkx as nx
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.cm as cm
import matplotlib.colors as mcolors

## data processing

In [None]:
######### data cleaning/preprocessing functions #########
def combine_mp3(path, common_file_name, num_files):
    '''
    Concatenate a numbered series of MP3 files into one MP3 file.

    The inputs are expected at "{path}/{common_file_name}{i}.mp3" for
    i = 1 .. num_files and are joined in that order.

    Parameters:
        path (str): Directory holding the numbered MP3 files
        common_file_name (str): Shared file-name prefix (e.g., "song" for "song1.mp3", "song2.mp3", ...)
        num_files (int): How many numbered files to join

    Returns:
        None: The result is written to "{path}/combined_{common_file_name}.mp3"
    '''
    # Build the ordered list of input paths, reporting each one as it is queued
    mp3_files = []
    for index in range(1, num_files + 1):
        print(f"--------Appending {common_file_name}{index} to list of mp3 files--------")
        mp3_files.append(f"{path}/{common_file_name}{index}.mp3")
    print("--------File appending complete--------")

    # The first file seeds the running audio segment
    print("--------Loading first file--------")
    combined_audio = AudioSegment.from_mp3(mp3_files[0])

    print(".\n.\n.\n")
    # Every remaining file is decoded and appended to the end
    for next_path in mp3_files[1:]:
        print(f"--------Adding {next_path} to the combined audio segment--------")
        combined_audio += AudioSegment.from_mp3(next_path)

    # Write the joined audio next to the inputs
    output_name = f"combined_{common_file_name}.mp3"
    combined_audio.export(f"{path}/{output_name}", format="mp3")

    print(f"Files concatenated successfully into {output_name}")

def mp3_to_wav(mp3_filename, wav_filename):
    '''
    Decode an MP3 file and write it back out as WAV.

    Parameters:
        mp3_filename (str): Path of the MP3 file to read
        wav_filename (str): Path the WAV file is written to

    Returns:
        None: The WAV file is saved to disk and a confirmation is printed
    '''
    # pydub decodes the MP3 and re-encodes it in the requested format
    AudioSegment.from_mp3(mp3_filename).export(wav_filename, format="wav")
    print(f"WAV file saved as {wav_filename}")


## data analyses

### midi data

In [None]:
######### all functions for midi note number analysis #########
def get_midi_notes_over_time(midi_data):
    '''
    Collect the start time and pitch of every note in a MIDI object.

    Notes are visited instrument by instrument, in the order each
    instrument stores them; the output is NOT sorted by time.

    Parameters:
        midi_data (PrettyMIDI): PrettyMIDI object containing the MIDI data

    Returns:
        tuple: A tuple containing:
            - start_times (list): Start time of each note
            - note_numbers (list): MIDI note number of each note, aligned with start_times
    '''
    # One (start, pitch) pair per note, flattened across all instruments
    events = [
        (note.start, note.pitch)
        for instrument in midi_data.instruments
        for note in instrument.notes
    ]

    # Split the pairs back into two parallel lists
    start_times = [start for start, _ in events]
    note_numbers = [pitch for _, pitch in events]
    return start_times, note_numbers

def get_onset_times(midi_data):
    '''
    Return every note onset in a MIDI object, in ascending order.

    Parameters:
        midi_data (PrettyMIDI): PrettyMIDI object containing the MIDI data

    Returns:
        onset_times (list): Sorted note start times pooled over all instruments
    '''
    # Pool the starts of all instruments and sort them in one step
    return sorted(
        note.start
        for instrument in midi_data.instruments
        for note in instrument.notes
    )

def calculate_iois(midi_data):
    '''
    Compute the inter-onset intervals (IOIs) of a MIDI object.

    Parameters:
        midi_data (PrettyMIDI): PrettyMIDI object containing the MIDI data

    Returns:
        iois (np array): Differences between consecutive sorted onset times
    '''
    # get_onset_times returns the onsets sorted, so each difference is the
    # gap between one onset and the next
    return np.diff(get_onset_times(midi_data))

def analyze_midi_distributions(midi_files, labels=None, mode='overlay'):
    '''
    Summarize and plot the MIDI note-number distribution of several MIDI objects.

    For each object the pitch histogram and descriptive statistics are
    computed, printed, and drawn as a bar chart. A failure on one object is
    reported with a printed message and the remaining objects are still processed.

    Parameters:
        midi_files (list): List of PrettyMIDI objects already loaded
        labels (list, optional): One label per object (defaults to "MIDI 1", "MIDI 2", ...)
        mode (str): 'individual' draws one figure per object, 'overlay' draws all
                    histograms on a single shared figure

    Returns:
        results (dict): Maps the object's index to (note_counts, stats), where
              note_counts is a Counter of pitches and stats is a dictionary
              with mean, min, max, variance, and std
    '''
    from collections import Counter
    import numpy as np
    import matplotlib.pyplot as plt

    # Fall back to generic labels when the caller supplies none
    if labels is None:
        labels = [f"MIDI {idx+1}" for idx in range(len(midi_files))]

    overlay = mode == 'overlay'
    results = {}

    # Overlay mode shares one figure across all objects, so open it up front
    if overlay:
        plt.figure(figsize=(12, 6))

    for idx, (midi_obj, tag) in enumerate(zip(midi_files, labels)):
        try:
            # Pool the pitches of every instrument
            pitches = [
                note.pitch
                for instrument in midi_obj.instruments
                for note in instrument.notes
            ]

            # Histogram of pitches, ordered by note number for plotting
            note_counts = Counter(pitches)
            ordered = sorted(note_counts.items())
            if ordered:
                keys, values = zip(*ordered)
            else:
                keys, values = [], []

            # Descriptive statistics; all zeros when the object holds no notes
            if pitches:
                summary = {
                    'mean': np.mean(pitches),
                    'min': np.min(pitches),
                    'max': np.max(pitches),
                    'variance': np.var(pitches),
                    'std': np.std(pitches),
                }
            else:
                summary = dict.fromkeys(('mean', 'min', 'max', 'variance', 'std'), 0)

            results[idx] = (note_counts, summary)

            print(f"\nDescriptive Statistics for {tag}:")
            print(f"Mean Note: {summary['mean']:.2f}")
            print(f"Min Note: {summary['min']}")
            print(f"Max Note: {summary['max']}")
            print(f"Variance: {summary['variance']:.2f}")
            print(f"Standard Deviation: {summary['std']:.2f}")

            if mode == 'individual':
                # Stand-alone histogram with the mean marked by a dotted line
                plt.figure(figsize=(10, 5))
                plt.bar(keys, values, color='blue', alpha=0.7)
                plt.axvline(summary['mean'], color='red', linestyle='dotted', linewidth=2,
                            label=f'Mean={summary["mean"]:.2f}')
                plt.xlabel('MIDI Note Number')
                plt.ylabel('Frequency')
                plt.title(f"Histogram of {tag} MIDI Note Frequencies")
                plt.legend()
                plt.show()
            elif overlay:
                # Semi-transparent bars so overlapping histograms stay visible
                plt.bar(keys, values, alpha=0.5, label=f"{tag} (Mean: {summary['mean']:.2f})")

        except Exception as exc:
            # Best effort: report the failure and move on to the next object
            print(f"Error processing MIDI {idx+1}: {exc}")

    # Finish the shared figure only if at least one object was processed
    if overlay and results:
        plt.xlabel('MIDI Note Number')
        plt.ylabel('Frequency')
        plt.title("Distribution of MIDI Notes")
        plt.legend()
        plt.show()

    return results

#### f-test ####
def perform_f_test(var1, var2, n1, n2, alpha=0.05):
    """
    Performs a two-tailed F-test to determine if two variances are significantly different.

    The larger variance is placed in the numerator, the upper-tail probability
    of the F distribution is taken, and it is doubled (capped at 1) to obtain
    the two-tailed p-value.

    Parameters:
        var1 (float): Variance of the first sample
        var2 (float): Variance of the second sample
        n1 (int): Size of the first sample
        n2 (int): Size of the second sample
        alpha (float, optional): Significance level, default is 0.05

    Returns:
        dict: Dictionary with keys "F_statistic", "p_value" (two-tailed),
              "df1", "df2" (numerator / denominator degrees of freedom)
              and "conclusion" (human-readable test outcome)

    Raises:
        ValueError: If either sample has fewer than 2 observations or either
                    variance is not strictly positive (the F ratio is undefined)
    """
    # Imported here so the function works on its own: the file's top-level
    # imports do not bring `scipy.stats` into scope under the name `stats`.
    from scipy import stats

    if n1 < 2 or n2 < 2:
        raise ValueError("Each sample needs at least 2 observations for an F-test")
    if var1 <= 0 or var2 <= 0:
        raise ValueError("Variances must be strictly positive for an F-test")

    # Calculate F statistic (ratio of variances)
    # Convention is to have the larger variance in the numerator
    if var1 > var2:
        f_statistic = var1 / var2
        df1, df2 = n1 - 1, n2 - 1
    else:
        f_statistic = var2 / var1
        df1, df2 = n2 - 1, n1 - 1

    # Upper-tail probability; the survival function keeps precision for very
    # small p-values, where `1 - cdf` would round to 0.
    p_value = stats.f.sf(f_statistic, df1, df2)

    # Two-tailed test: double the tail, but a probability cannot exceed 1
    # (doubling overshoots when F is close to 1).
    p_value_two_tailed = min(1.0, float(2 * p_value))

    # Test conclusion
    if p_value_two_tailed < alpha:
        conclusion = "Reject null hypothesis: variances are significantly different"
    else:
        conclusion = "Fail to reject null hypothesis: variances are not significantly different"

    return {
        "F_statistic": f_statistic,
        "p_value": p_value_two_tailed,
        "df1": df1,
        "df2": df2,
        "conclusion": conclusion
    }

def compare_variances_to_head(moaning_distributions_artists, artist_names=None):
    """
    Compares variances of artist distributions to the head distribution using F-tests.

    Each element of the input is indexed as element[0] for the artist's
    distribution and element[1] for the head distribution, each being a
    (note_counts, stats) pair where stats has a 'variance' entry. The head
    is read from the first element only and reused for every comparison.

    Parameters:
        moaning_distributions_artists (list): List of distribution data for each artist
        artist_names (list, optional): Display names aligned with the input list.
            Defaults to ['lee', 'fred', 'art', 'roy', 'ter']. Entries beyond the
            end of this list are labelled "artist N".

    Returns:
        None: Results are printed to standard output
    """
    if artist_names is None:
        artist_names = ['lee', 'fred', 'art', 'roy', 'ter']

    # Nothing to compare; avoids indexing into an empty list below
    if not moaning_distributions_artists:
        return

    # Extract the head distribution from the first artist's data (used as the reference for all)
    head_entry = moaning_distributions_artists[0][1]
    n_head = sum(head_entry[0].values())  # total number of notes in the head
    head_variance = head_entry[1]['variance']

    # Loop through all artist distributions
    for i, distributions in enumerate(moaning_distributions_artists):
        # Extract the artist's own distribution (index 0)
        artist_distribution = distributions[0][0]
        n_artist = sum(artist_distribution.values())
        artist_variance = distributions[0][1]['variance']

        # Perform F-test
        result = perform_f_test(artist_variance, head_variance, n_artist, n_head)

        # Fall back to a positional label instead of failing when there are
        # more artists than names
        artist_name = artist_names[i] if i < len(artist_names) else f"artist {i + 1}"

        # Print results
        print(f"F-test for equality of variances between {artist_name.capitalize()} and Head")
        print(f"F-statistic: {result['F_statistic']:.4f}")
        print(f"p-value: {result['p_value']:.20f}")
        print(f"Conclusion: {result['conclusion']}")
        print('-' * 50)
#### Kullback–Leibler divergence ####
def calculate_kl_divergence(p, q):
    """
    Compute the Kullback-Leibler divergence D(p || q) of two count distributions.

    Both inputs are turned into probabilities over the union of their keys.
    A small constant is added to every probability (followed by
    renormalization) so that keys missing from one side do not produce
    log(0) or a division by zero.

    Parameters:
        p (Counter or dict): The first distribution (reference), as key -> count.
        q (Counter or dict): The second distribution, as key -> count.

    Returns:
        kl_div (float): KL divergence value (natural logarithm).
    """
    # Shared support: every key seen in either distribution, in sorted order
    support = sorted(set(p) | set(q))

    p_total = sum(p.values())
    q_total = sum(q.values())

    # Smoothing constant that keeps every probability strictly positive
    epsilon = 1e-10

    def _smoothed(counts, total):
        # Relative frequencies over the shared support, smoothed and renormalized
        probs = np.array([counts.get(key, 0) / total for key in support]) + epsilon
        return probs / np.sum(probs)

    p_dist = _smoothed(p, p_total)
    q_dist = _smoothed(q, q_total)

    # sum_i p_i * ln(p_i / q_i)
    return np.sum(p_dist * np.log(p_dist / q_dist))

def analyze_and_visualize_divergence(distributions_data, artist_names, title=None):
    """
    Plot the KL divergence of each artist's note distribution from the head melody.

    Artists are read from keys 0 .. len(artist_names) - 1 of distributions_data
    and the head melody from key len(artist_names). The divergence is computed
    with the artist as the first argument and the head as the second.

    Note: this updates matplotlib's global rcParams (font sizes), which stays
    in effect after the call.

    Parameters:
        distributions_data (dict): Dictionary with indices as keys and (Counter, ...) tuples as values.
        artist_names (list): List of artist names corresponding to indices in distributions_data.
        title (str): Optional title for the visualization (default: None).

    Returns:
        None: Displays a bar chart of KL divergence values.
    """
    n_artists = len(artist_names)

    # The head melody sits right after the last artist; each entry's first
    # element is its Counter of notes
    head_counts = distributions_data[n_artists][0]
    solo_counts = [distributions_data[idx][0] for idx in range(n_artists)]

    # Artist-to-head divergence, one value per artist
    head_divergence = np.zeros(n_artists)
    for idx, counts in enumerate(solo_counts):
        head_divergence[idx] = calculate_kl_divergence(counts, head_counts)

    # Larger fonts for readability
    plt.rcParams.update({
        'font.size': 14,
        'axes.titlesize': 18,
        'axes.labelsize': 16,
        'xtick.labelsize': 14,
        'ytick.labelsize': 14
    })

    # Slightly taller figure to leave room for the larger fonts
    fig, ax1 = plt.subplots(figsize=(14, 7))

    # One bar per artist, colored along the viridis map
    metric_name = "KL Divergence"
    ax1.bar(artist_names, head_divergence,
            color=plt.cm.viridis(np.linspace(0, 0.8, n_artists)))

    # Use the caller's title when given, otherwise a generic one
    if title:
        heading = f'{metric_name} — {title}'
    else:
        heading = f'{metric_name} Between Artists and Head Melody'
    ax1.set_title(heading, pad=15)

    ax1.set_ylabel(metric_name)
    ax1.set_xlabel('Artists')

    # Fix the tick locations first, then attach the rotated labels
    ax1.set_xticks(range(len(artist_names)))
    ax1.set_xticklabels(artist_names, rotation=45, ha='right')

    # Print each value slightly above its bar
    for position, value in enumerate(head_divergence):
        ax1.text(position, value + max(head_divergence)*0.05, f'{value:.3f}',
                 ha='center', fontsize=14)

    # Horizontal grid lines for easier reading
    ax1.grid(axis='y', linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()


### network

In [None]:
######### all functions for network analysis #########
def midi_to_pitch_class(midi_number):
    """
    Map a MIDI note number to its pitch-class name, discarding the octave.

    Parameters:
        midi_number (int): MIDI note number (0-127)

    Returns:
        pitch_class (str): One of C, Db, D, Eb, E, F, F#, G, Ab, A, Bb, B
    """
    # Names indexed by semitone distance above C
    names = "C Db D Eb E F F# G Ab A Bb B".split()
    semitone = midi_number % 12
    return names[semitone]

#### just all pitches ####
def get_prob_transitions(midi_data):
    """
    Compute relative frequencies of pitch transitions and of inter-onset intervals.

    Pitch transitions pair each note with the one that follows it in the order
    returned by get_midi_notes_over_time; each pair's count is divided by the
    total number of transitions. Inter-onset intervals are taken between
    consecutive sorted start times; each distinct interval's count is divided
    by the total number of intervals.

    Parameters:
        midi_data (PrettyMIDI): PrettyMIDI object containing the MIDI data

    Returns:
        transition_probs (dict): Maps (from_pitch, to_pitch) to its relative frequency
        iois_probs (dict): Maps each IOI value to its relative frequency
        iois (np array): The inter-onset intervals themselves
    """
    note_times, pitches = get_midi_notes_over_time(midi_data)

    def _relative_frequencies(items):
        # count / total for every distinct item; empty dict when there are no items
        tally = Counter(items)
        total = sum(tally.values())
        if total > 0:
            return {key: count / total for key, count in tally.items()}
        return {}

    # Pitch: adjacent pairs (pitches[i], pitches[i + 1])
    transition_probs = _relative_frequencies(zip(pitches, pitches[1:]))

    # Rhythm: gaps between consecutive onsets once the start times are sorted
    iois = np.diff(sorted(note_times))
    iois_probs = _relative_frequencies(iois)

    return transition_probs, iois_probs, iois

def visualize_strongest_trans(transitions, 
                               plot_type='overlay',
                               top_n=5,
                               title="Strongest Pitch Transition Network",
                               labels=None,
                               edge_colors=None,
                               edge_thickness=30,
                               prob_threshold=0.01,
                               fixed_positions=None):
    """
    Keep only transitions among the most active pitches, then plot them.

    A pitch's activity is the sum of the probabilities of every transition it
    takes part in (as source or as target), pooled over all transition
    dictionaries. Only transitions whose two endpoints are both among the
    top_n most active pitches are passed on to visualize_all_pitch_trans.

    Parameters:
        transitions (list): List of transition dictionaries where each dict maps (from_pitch, to_pitch) -> probability
        plot_type (str): 'overlay' for single plot with all transitions, 'individual' for separate plots
        top_n (int): Number of highest-activity nodes to retain in the network
        title (str or list): Graph title(s). If list, must match length of transitions for 'individual' mode
        labels (list): Labels for each transition dictionary. Must match length of transitions
        edge_colors (list): Colors for edges. Must match length of transitions
        edge_thickness (int): Base multiplier for edge thickness
        prob_threshold (float): Minimum probability to display edge labels
        fixed_positions (dict): Dictionary mapping node IDs to (x,y) coordinates. If None, positions are calculated once and reused

    Returns:
        tuple: (Graph object(s), node positions dictionary), as returned by visualize_all_pitch_trans
    """
    network_count = len(transitions)

    # Generic labels when none are supplied
    if labels is None:
        labels = [f"Network {idx+1}" for idx in range(network_count)]

    # Accumulate activity per pitch across every transition dictionary
    activity = {}
    for table in transitions:
        for (source, target), prob in table.items():
            activity[source] = activity.get(source, 0) + prob
            activity[target] = activity.get(target, 0) + prob

    # The top_n pitches with the highest pooled activity
    ranked = sorted(activity, key=activity.get, reverse=True)
    top_nodes = set(ranked[:top_n])

    # Drop every transition that touches a pitch outside that set
    filtered_transitions = [
        {
            pair: prob
            for pair, prob in table.items()
            if pair[0] in top_nodes and pair[1] in top_nodes
        }
        for table in transitions
    ]

    # Delegate the drawing to the general-purpose plotting function
    return visualize_all_pitch_trans(filtered_transitions,
                                     plot_type=plot_type,
                                     title=title,
                                     labels=labels,
                                     edge_colors=edge_colors,
                                     edge_thickness=edge_thickness,
                                     prob_threshold=prob_threshold,
                                     fixed_positions=fixed_positions)

def visualize_all_pitch_trans(transitions, 
                             plot_type='overlay',
                             title="Directed Weighted Pitch Transition Network",
                             labels=None,
                             edge_colors=None,
                             edge_thickness=30,
                             prob_threshold=0.01,
                             fixed_positions=None):
    """
    Create a network graph visualization with directed, weighted edges, probability labels,
    and nodes colored based on activity.

    Node positions are computed once over the union of all pitches (or taken
    from fixed_positions) and shared by every figure, so the same pitch sits
    at the same place in each plot.

    Parameters:
        transitions (list): List of transition dictionaries where each dict maps (from_pitch, to_pitch) -> probability
        plot_type (str): 'overlay' for single plot with all transitions, 'individual' for separate plots
        title (str or list): Graph title(s). If list, must match length of transitions for 'individual' mode
        labels (list): Labels for each transition dictionary. Must match length of transitions
        edge_colors (list): Colors for edges. Must match length of transitions
        edge_thickness (int): Base multiplier for edge thickness
        prob_threshold (float): Minimum probability to display edge labels
        fixed_positions (dict): Dictionary mapping node IDs to (x,y) coordinates. If None, positions are calculated once and reused

    Returns:
        tuple: (graph, positions) in 'overlay' mode, or (list of graphs, positions)
               in 'individual' mode, where positions maps node -> (x, y)

    Raises:
        ValueError: If labels, edge_colors or a list of titles (in 'individual' mode)
                    do not match the number of transition sets, or if plot_type
                    is neither 'overlay' nor 'individual'
    """
    # Validate and prepare inputs
    n_networks = len(transitions)
    
    if labels is None:
        labels = [f"Network {i+1}" for i in range(n_networks)]
    elif len(labels) != n_networks:
        raise ValueError(f"Number of labels ({len(labels)}) must match number of transition sets ({n_networks})")
    
    if edge_colors is None:
        # Use a colormap to automatically generate colors.
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        edge_cmap = plt.get_cmap('tab10')
        edge_colors = [edge_cmap(i % 10) for i in range(n_networks)]
    elif len(edge_colors) != n_networks:
        raise ValueError(f"Number of edge colors ({len(edge_colors)}) must match number of transition sets ({n_networks})")
        
    if isinstance(title, list) and plot_type == 'individual' and len(title) != n_networks:
        raise ValueError(f"Number of titles ({len(title)}) must match number of transition sets ({n_networks})")

    def midi_to_note_name(midi_num):
        # Note name plus octave, with MIDI 60 mapping to C4
        notes = ['C', 'Db', 'D', 'Eb', 'E', 'F', 'F#', 'G', 'Ab', 'A', 'Bb', 'B']
        octave = midi_num // 12 - 1
        note = notes[midi_num % 12]
        return f"{note}{octave}"
    
    # Create a combined graph with all nodes first to establish fixed positions
    G_for_positions = nx.DiGraph()
    all_nodes = set()
    
    # Collect all nodes from all transition sets
    for trans_dict in transitions:
        for from_pitch, to_pitch in trans_dict.keys():
            all_nodes.add(from_pitch)
            all_nodes.add(to_pitch)
    
    # Add all nodes to the positioning graph
    for node in all_nodes:
        G_for_positions.add_node(node)
    
    # Generate positions only once if not provided (fixed seed keeps the layout reproducible)
    global_positions = fixed_positions
    if global_positions is None:
        global_positions = nx.spring_layout(G_for_positions, k=0.3, iterations=100, seed=42)
    
    def create_graph(transition_sets, current_labels, current_colors, current_title):
        # Draw one figure for the given transition sets and return its graph
        G_combined = nx.DiGraph()
        
        # Add edges and compute node activity
        node_activity = {}
        all_transitions = {}
        
        # Combine all transitions for node activity calculation.
        # NOTE(review): dict.update keeps only the last set's probability when the
        # same (from, to) pair occurs in several sets — confirm this is intended.
        for trans_dict in transition_sets:
            all_transitions.update(trans_dict)
            
        for (from_pitch, to_pitch), probability in all_transitions.items():
            G_combined.add_edge(from_pitch, to_pitch, weight=probability)
            node_activity[from_pitch] = node_activity.get(from_pitch, 0) + probability
            node_activity[to_pitch] = node_activity.get(to_pitch, 0) + probability
        
        # Normalize activity levels for colormap scaling
        activity_values = np.array(list(node_activity.values()))
        if activity_values.size:
            norm = mcolors.Normalize(vmin=min(activity_values), vmax=max(activity_values))
        else:
            # No transitions at all (e.g. everything was filtered out):
            # use a placeholder range instead of failing on min()/max() of nothing
            norm = mcolors.Normalize(vmin=0.0, vmax=1.0)
        node_cmap = plt.get_cmap('viridis')  # Node color map based on activity
        node_colors = {node: node_cmap(norm(activity)) for node, activity in node_activity.items()}
        
        plt.figure(figsize=(14, 12))
        
        # Use the pre-computed fixed positions
        pos = global_positions
        
        # Draw nodes
        node_labels = {n: midi_to_note_name(n) for n in G_combined.nodes()}
        nx.draw_networkx_nodes(G_combined, pos, 
                              node_size=700, 
                              node_color=[node_colors[n] for n in G_combined.nodes()],
                              edgecolors='black',
                              alpha=0.9)
        nx.draw_networkx_labels(G_combined, pos, labels=node_labels, font_size=10, font_weight="bold")
        
        # Draw edges with different colors
        legend_handles = []
        
        for i, transition_dict in enumerate(transition_sets):
            edges = [(u, v, w) for (u, v), w in transition_dict.items()]
            if edges:
                # Scale edge widths by weight, with a minimum width of 1
                edge_widths = [max(w * edge_thickness, 1) for _, _, w in edges]
                
                # Draw edges with enhanced arrow style
                nx.draw_networkx_edges(G_combined, pos,
                                      edgelist=[(u, v) for u, v, _ in edges],
                                      width=edge_widths,
                                      edge_color=current_colors[i],
                                      alpha=0.7,
                                      arrows=True, 
                                      arrowsize=20,
                                      arrowstyle='-|>', 
                                      connectionstyle='arc3,rad=0.1')
                
                # Add to legend
                legend_handles.append(Line2D([0], [0], color=current_colors[i], lw=4, label=current_labels[i]))
                
        # Annotate edges with probability labels (only for significant transitions)
        for (u, v, w) in G_combined.edges(data='weight'):
            if w >= prob_threshold:
                if u == v:  # Self-loop: position label above node
                    # Offset the label position above the node
                    xy = (pos[u][0], pos[u][1] + 0.06)
                else:
                    # Normal edge: place the label 30% of the way from source to target
                    xy = (pos[u][0] * 0.7 + pos[v][0] * 0.3, pos[u][1] * 0.7 + pos[v][1] * 0.3)
        
                plt.annotate(f"{w:.4f}", 
                            xy=xy,
                            fontsize=9,
                            weight='bold',
                            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", alpha=0.8))
        
        # Legend with thicker lines
        plt.legend(handles=legend_handles, loc='upper right', fontsize=12)
        
        # Display node activity using a color bar
        sm = cm.ScalarMappable(cmap=node_cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=plt.gca(), shrink=0.7, aspect=30)
        cbar.set_label('Node Activity Level', fontsize=12)
        
        plt.title(current_title, fontsize=16)
        plt.axis('off')
        plt.tight_layout()
        
        return G_combined
    
    # Create plots based on plot_type
    if plot_type == 'overlay':
        # Create a single plot with all transitions; a list of titles falls back to the default title
        current_title = title if isinstance(title, str) else "Directed Weighted Pitch Transition Network"
        G = create_graph(transitions, labels, edge_colors, current_title)
        plt.show()
        return G, global_positions
    
    elif plot_type == 'individual':
        # Create separate plots for each transition set
        graphs = []
        titles = title if isinstance(title, list) else [f"{labels[i]} Transitions" for i in range(n_networks)]
        
        for i, trans in enumerate(transitions):
            current_title = titles[i]
            G = create_graph([trans], [labels[i]], [edge_colors[i]], current_title)
            graphs.append(G)
            plt.show()
        
        return graphs, global_positions
    
    else:
        raise ValueError("plot_type must be either 'overlay' or 'individual'")

#### only pitches, no octaves ####
def extract_transitions_from_midi(midi_file_path):
    """
    Load a MIDI file and tabulate its note-to-note transitions by interval size.

    Notes are grouped into voices by (program, is_drum) and ordered by start
    time; drum voices are skipped. Each pair of consecutive notes whose
    absolute pitch difference belongs to an interval category adds 1 to that
    pair in the category's table. If the pair starts a run of at least three
    notes moving in the same direction by intervals of the same category,
    every pair in that run receives an extra 0.5. Each table is finally
    normalized so its weights sum to 1 (or left empty if it has no entries).

    Parameters:
        midi_file_path (str): Path to the MIDI file

    Returns:
        tuple: Transition dictionaries mapping (from_pitch, to_pitch) -> weight, in the order:
               (chromatic_transitions, diatonic_transitions, third_transitions,
                fourth_transitions, fifth_transitions, sixth_transitions,
                seventh_transitions, octave_transitions)
    """
    import pretty_midi
    from collections import defaultdict

    # Semitone distances that belong to each interval category
    category_semitones = {
        'chromatic': (1,),      # 1 semitone
        'diatonic': (2,),       # 2 semitones (whole tone)
        'third': (3, 4),        # minor / major third
        'fourth': (5,),         # perfect fourth
        'fifth': (7,),          # perfect fifth
        'sixth': (8, 9),        # minor / major sixth
        'seventh': (10, 11),    # minor / major seventh
        'octave': (12,)         # octave
    }

    score = pretty_midi.PrettyMIDI(midi_file_path)

    # Flatten every note of every instrument into a plain record
    note_records = [
        {
            'note': note.pitch,
            'start_time': note.start,
            'end_time': note.end,
            'velocity': note.velocity,
            'instrument': instrument.program,
            'is_drum': instrument.is_drum
        }
        for instrument in score.instruments
        for note in instrument.notes
    ]

    # Group the records, in start-time order, by (program, is_drum)
    voices = defaultdict(list)
    for record in sorted(note_records, key=lambda rec: rec['start_time']):
        voices[(record['instrument'], record['is_drum'])].append(record)

    # One weight table per category, in the same order as the categories
    weights = {name: defaultdict(float) for name in category_semitones}

    # Shortest run (counted in notes) that earns the extra weight
    min_run_notes = 3

    def _direction(first, second):
        # +1 for rising, -1 for falling, 0 for a repeated pitch
        if second > first:
            return 1
        if second < first:
            return -1
        return 0

    for (_, is_drum), voice in voices.items():
        # Drum voices carry no melodic information
        if is_drum:
            continue

        voice.sort(key=lambda rec: rec['start_time'])
        pitches = [rec['note'] for rec in voice]
        last = len(pitches) - 1

        for i in range(last):
            here, there = pitches[i], pitches[i + 1]
            gap = abs(there - here)
            heading = _direction(here, there)

            for name, semitones in category_semitones.items():
                if gap not in semitones:
                    continue

                # Base count for this transition
                weights[name][(here, there)] += 1

                # Extend the run while the following steps stay in the same
                # category and keep the same direction
                run_notes = 2
                for j in range(i + 1, last):
                    step_from, step_to = pitches[j], pitches[j + 1]
                    same_category = abs(step_to - step_from) in semitones
                    if same_category and _direction(step_from, step_to) == heading:
                        run_notes += 1
                    else:
                        break

                # Long enough runs boost every transition they contain
                if run_notes >= min_run_notes:
                    for k in range(i, i + run_notes - 1):
                        weights[name][(pitches[k], pitches[k + 1])] += 0.5

    # Turn each table's weights into proportions of that table's total
    normalized = {}
    for name, table in weights.items():
        mass = sum(table.values())
        if mass > 0:
            normalized[name] = {pair: weight / mass for pair, weight in table.items()}
        else:
            normalized[name] = dict()

    return (
        normalized['chromatic'],
        normalized['diatonic'],
        normalized['third'],
        normalized['fourth'],
        normalized['fifth'],
        normalized['sixth'],
        normalized['seventh'],
        normalized['octave']
    )

def visualize_pitch_trans_mod(transitions,
                              plot_type='overlay',
                              title="Pitch Transition Network",
                              labels=None,
                              edge_colors=None,
                              save_path=None):
    """
    Create a network graph visualization with notes grouped by pitch class (ignoring octave).

    Every transition dictionary is collapsed from MIDI pitches to pitch classes with
    midi_to_pitch_class, re-normalized so that the outgoing weights of each source
    pitch class sum to 1, and drawn as a directed graph on a circle-of-fifths layout.

    Parameters:
        transitions (list): List of transition dictionaries where each dict maps (from_pitch, to_pitch) -> probability
        plot_type (str): 'overlay' for single plot with all transitions, 'individual' for separate plots
        title (str or list): Title(s) for the graph(s). A list is only honoured for 'individual';
                             in 'overlay' mode a non-string title falls back to "Pitch Transition Network"
        labels (list): Labels for each transition dictionary in the legend (default: "Network 1", "Network 2", ...)
        edge_colors (list): Colors to use for each transition set (default: taken from the 'tab10' colormap)
        save_path (str or list): Path(s) to save the figure(s). If None, no saving is performed

    Returns:
        tuple or list: For 'overlay', a tuple (graphs, pos, pc_transitions) where graphs is the list of
                       networkx DiGraph objects, pos maps each node to its (x, y) position and
                       pc_transitions is the list of normalized pitch-class transition dictionaries.
                       For 'individual', a list holding one such tuple per transition set.

    Raises:
        ValueError: If the number of labels, edge colors, titles or save paths does not match the
                    number of transition sets, if a plot is requested for zero transition sets,
                    or if plot_type is neither 'overlay' nor 'individual'.
    """
    import networkx as nx
    import numpy as np
    import matplotlib.pyplot as plt

    # Validate inputs
    n_networks = len(transitions)

    if labels is None:
        labels = [f"Network {i+1}" for i in range(n_networks)]
    elif len(labels) != n_networks:
        raise ValueError(f"Number of labels ({len(labels)}) must match number of transition sets ({n_networks})")

    if edge_colors is None:
        # Default color palette. plt.get_cmap is used instead of plt.cm.get_cmap,
        # which was deprecated in Matplotlib 3.7 and removed in 3.9.
        cmap = plt.get_cmap('tab10')
        edge_colors = [cmap(i % 10) for i in range(n_networks)]
    elif len(edge_colors) != n_networks:
        raise ValueError(f"Number of edge colors ({len(edge_colors)}) must match number of transition sets ({n_networks})")

    if isinstance(title, list) and plot_type == 'individual' and len(title) != n_networks:
        raise ValueError(f"Number of titles ({len(title)}) must match number of transition sets ({n_networks})")

    if save_path is not None and isinstance(save_path, list) and plot_type == 'individual' and len(save_path) != n_networks:
        raise ValueError(f"Number of save paths ({len(save_path)}) must match number of transition sets ({n_networks})")

    def normalize_transitions(pc_probs):
        """Normalize transition probabilities for each source pitch class"""
        # Total outgoing weight per source pitch class
        source_totals = {}
        for (src, _dst), prob in pc_probs.items():
            source_totals[src] = source_totals.get(src, 0) + prob

        # Divide each transition by the total of its source (0 when the total is not positive)
        return {
            (src, dst): (prob / source_totals[src] if source_totals[src] > 0 else 0)
            for (src, dst), prob in pc_probs.items()
        }

    def process_transitions(transition_dict):
        """Convert MIDI note transitions to pitch class transitions and normalize"""
        pc_transitions = {}
        for (from_pitch, to_pitch), probability in transition_dict.items():
            # Transitions that differ only by octave accumulate on the same pitch-class pair
            pc_key = (midi_to_pitch_class(from_pitch), midi_to_pitch_class(to_pitch))
            pc_transitions[pc_key] = pc_transitions.get(pc_key, 0) + probability

        # Normalize to make them probabilities again
        return normalize_transitions(pc_transitions)

    # Process all transition dictionaries
    all_pc_transitions = [process_transitions(trans) for trans in transitions]

    # Create layout - arrange pitch classes in a circle of fifths
    # NOTE(review): assumes midi_to_pitch_class returns these spellings ('Db', 'F#', ...);
    # any other node name is placed at the center of the plot by create_plot.
    circle_of_fifths = ['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'Db', 'Ab', 'Eb', 'Bb', 'F']
    pos = {}
    n_nodes = len(circle_of_fifths)
    for i, pc in enumerate(circle_of_fifths):
        angle = 2 * np.pi * i / n_nodes
        pos[pc] = (np.cos(angle), np.sin(angle))

    # Define note colors for consistent appearance
    note_colors = {
        'C': '#FF0000',   # Red
        'Db': '#FF7F00',  # Orange
        'D': '#FFFF00',   # Yellow
        'Eb': '#7FFF00',  # Chartreuse
        'E': '#00FF00',   # Green
        'F': '#00FF7F',   # Spring Green
        'F#': '#00FFFF',  # Cyan
        'G': '#007FFF',   # Azure
        'Ab': '#0000FF',  # Blue
        'A': '#7F00FF',   # Violet
        'Bb': '#FF00FF',  # Magenta
        'B': '#FF007F',   # Rose
    }

    def create_plot(pc_transitions_list, current_labels, current_colors, current_title, current_save_path=None):
        """Create a plot for the given transitions"""
        if not pc_transitions_list:
            # Without this guard the code below would open an empty figure and then
            # fail with an IndexError on graphs[0].
            raise ValueError("At least one transition set is required to create a plot")

        # Extract all unique pitch classes
        pitch_classes = set()
        for pc_trans in pc_transitions_list:
            for from_pc, to_pc in pc_trans:
                pitch_classes.add(from_pc)
                pitch_classes.add(to_pc)
        # Freeze one ordering so node list and node colors are guaranteed to line up
        node_order = list(pitch_classes)

        # Create graphs (every graph carries the full node set, edges only where probability > 0)
        graphs = []
        for pc_trans in pc_transitions_list:
            G = nx.DiGraph()
            G.add_nodes_from(node_order)
            for (from_pc, to_pc), probability in pc_trans.items():
                if probability > 0:
                    G.add_edge(from_pc, to_pc, weight=probability)
            graphs.append(G)

        plt.figure(figsize=(14, 14))

        # Add any missing positions for pitch classes not in circle of fifths
        for pc in node_order:
            if pc not in pos:
                pos[pc] = (0, 0)  # Place in center as fallback

        # Draw nodes only once (all graphs share the same node set)
        nx.draw_networkx_nodes(graphs[0], pos,
                               nodelist=node_order,
                               node_size=1500,
                               node_color=[note_colors.get(node, '#AAAAAA') for node in node_order],
                               edgecolors="black",
                               linewidths=1.5,
                               alpha=0.9)

        # Calculate angle offsets to avoid edge overlap in overlay mode
        rad_offset = 0.15 / len(pc_transitions_list)
        rad_start = -0.15 if len(pc_transitions_list) > 1 else 0

        # Draw edges for each transition set
        legend_elements = []
        threshold = 0.1  # Threshold for edge labels

        for i, (G, color) in enumerate(zip(graphs, current_colors)):
            # Give every overlaid network its own curvature so edges do not coincide
            rad = rad_start + i * rad_offset * 2

            # Draw edges with arrows; line width grows with the transition probability
            edge_weights = [d['weight'] * 4 + 0.5 for _, _, d in G.edges(data=True)]
            if edge_weights:  # Check if there are any edges
                nx.draw_networkx_edges(
                    G, pos, width=edge_weights,
                    alpha=0.7, edge_color=color,
                    connectionstyle=f'arc3,rad={rad}',  # Curved edges
                    arrowstyle='-|>',  # Arrow style
                    arrowsize=20       # Arrow size
                )

            # Create edge labels
            edge_labels = {(u, v): f"{d['weight']:.2f}"
                           for u, v, d in G.edges(data=True)
                           if d['weight'] > threshold}

            # Position labels with slight offset based on the graph
            pos_labels = {k: (v[0] * (1 + 0.05 * i), v[1] * (1 + 0.05 * i)) for k, v in pos.items()}

            # Only draw edge labels if there are any meeting the threshold
            if edge_labels:
                nx.draw_networkx_edge_labels(
                    G, pos_labels, edge_labels=edge_labels,
                    font_size=12, font_color=color,
                    label_pos=0.5 + 0.1 * i  # Adjust position along edge
                )

            # Add to legend
            legend_elements.append(
                plt.Line2D([0], [0], color=color, lw=4, marker='>', markersize=10, label=current_labels[i])
            )

        # Add node labels
        nx.draw_networkx_labels(graphs[0], pos, font_size=16, font_weight="bold", font_color="black")

        # Add legend
        plt.legend(handles=legend_elements, loc='upper right', fontsize=14)

        plt.title(current_title, fontsize=18)
        plt.axis('off')

        plt.tight_layout()

        if current_save_path:
            plt.savefig(current_save_path, dpi=300, bbox_inches='tight')

        plt.show()

        return graphs, pos, pc_transitions_list

    # Create plots based on plot_type
    if plot_type == 'overlay':
        # Single plot with all transitions
        current_title = title if isinstance(title, str) else "Pitch Transition Network"
        # A list of save paths is tolerated in overlay mode: only its first entry is used
        current_save = save_path if isinstance(save_path, str) or save_path is None else save_path[0]
        return create_plot(all_pc_transitions, labels, edge_colors, current_title, current_save)

    elif plot_type == 'individual':
        # Individual plots for each transition set
        results = []
        titles = title if isinstance(title, list) else [f"{title}\n{labels[i]} Transitions" for i in range(n_networks)]
        # A single save path string is reused for every figure, so later figures overwrite earlier ones
        save_paths = save_path if isinstance(save_path, list) or save_path is None else [save_path] * n_networks

        for i, pc_trans in enumerate(all_pc_transitions):
            current_save = None if save_paths is None else save_paths[i]
            results.append(create_plot([pc_trans], [labels[i]], [edge_colors[i]], titles[i], current_save))

        return results

    else:
        raise ValueError("plot_type must be either 'overlay' or 'individual'")

def visualize_strongest_trans_mod(transitions_tuple,
                                  threshold=0.15,
                                  plot_type='overlay',
                                  title=None,
                                  labels=None):
    """
    Plot only the octave-modulo transitions whose probability reaches a threshold.

    Each transition set is paired with a label, filtered down to the entries with
    probability >= threshold and handed to visualize_pitch_trans_mod for drawing.

    Parameters:
        transitions_tuple (tuple): Transition dictionaries in the order
                                  (chromatic, diatonic, third, fourth, fifth, sixth, seventh, octave)
        threshold (float): Minimum probability a transition needs to be kept
        plot_type (str): 'overlay' draws every set on one figure, 'individual' draws one figure per set
        title (str or None): Custom plot title; a default title mentioning the threshold is used when None
        labels (list or None): One label per transition set; the eight interval names are used when None.
                               Transition sets beyond the last label are ignored
    Returns:
        None: Figures are displayed (or a message is printed when nothing qualifies)
    """
    if labels is None:
        labels = ['Chromatic', 'Diatonic', 'Third', 'Fourth', 'Fifth', 'Sixth', 'Seventh', 'Octave']

    # Pair sets with labels (zip stops at the shorter one) and keep only strong transitions
    strongest = {
        name: {pair: prob for pair, prob in probs.items() if prob >= threshold}
        for probs, name in zip(transitions_tuple, labels)
    }

    plot_title = f"Strongest Transitions (Probability > {threshold})" if title is None else title

    if plot_type == 'overlay':
        # One figure holding every transition set, provided at least one set is non-empty
        if not any(strongest.values()):
            print(f"No transitions with probability >= {threshold} found")
            return
        visualize_pitch_trans_mod(
            list(strongest.values()),
            plot_type='overlay',
            title=plot_title,
            labels=list(strongest.keys())
        )
        return

    if plot_type == 'individual':
        # One figure per transition set; empty sets are reported instead of drawn
        for name, kept in strongest.items():
            if not kept:
                print(f"No transitions with probability >= {threshold} found for {name}")
                continue
            visualize_pitch_trans_mod(
                [kept],
                plot_type='overlay',  # a single set is drawn through the overlay path
                title=f"{name}: {plot_title}",
                labels=[name]
            )
        return

    print(f"Invalid plot_type: {plot_type}. Please use 'overlay' or 'individual'")

def count_interval_edges(data_transitions, probability_threshold=0.0):
    """
    Count, for each interval type, the transitions whose weight reaches a threshold.

    Parameters:
    data_transitions (tuple): Exactly eight transition dictionaries in the order
        (chromatic, diatonic, thirds, fourths, fifths, sixths, sevenths, octave),
        each mapping a (from, to) pair to a weight
    probability_threshold (float): Minimum weight for a transition to be counted

    Returns:
    counts (dict): Number of qualifying transitions keyed by 'Chromatic', 'Diatonic',
        'Thirds', 'Fourths', 'Fifths', 'Sixths', 'Sevenths' and 'Octave Jump'
    """
    # Unpacking enforces that exactly eight dictionaries were supplied
    chromatic, diatonic, thirds, fourths, fifths, sixths, sevenths, octave = data_transitions

    named_tables = (
        ('Chromatic', chromatic),
        ('Diatonic', diatonic),
        ('Thirds', thirds),
        ('Fourths', fourths),
        ('Fifths', fifths),
        ('Sixths', sixths),
        ('Sevenths', sevenths),
        ('Octave Jump', octave),
    )

    counts = {}
    for name, table in named_tables:
        counts[name] = sum(
            1
            for (_src, _dst), weight in table.items()
            if weight >= probability_threshold
        )
    return counts

def plot_interval_usage(edge_counts_list, artist_names):
    """
    Convert edge counts into a DataFrame and plot interval type usage and standard deviation per artist.

    Parameters:
        edge_counts_list (list): List of dictionaries, each mapping interval type -> edge count for one artist.
        artist_names (list): List of artist names corresponding to the edge counts.

    Returns:
        None: Displays two plots - interval type usage and standard deviation per artist.

    Raises:
        ValueError: If edge_counts_list and artist_names differ in length.
    """
    # Validate input lengths
    if len(edge_counts_list) != len(artist_names):
        raise ValueError("Length of edge_counts_list must match length of artist_names")

    # The notebook's setup cell does not import pandas, so bring it into scope here;
    # without this the DataFrame construction below fails with a NameError on `pd`.
    import pandas as pd

    # Interval types are the union of all keys, sorted for a stable column order
    interval_types = sorted(set().union(*edge_counts_list))

    # One column per interval type; an artist missing a type gets a count of 0
    data = {'Artist': artist_names}
    for interval in interval_types:
        data[interval] = [edge_counts.get(interval, 0) for edge_counts in edge_counts_list]

    # Create DataFrame indexed by artist
    df = pd.DataFrame(data).set_index('Artist')

    # Plot 1: Interval Type Usage by Artist (grouped bars, one group per artist)
    df.plot(kind='bar', figsize=(12, 6), width=0.8)
    plt.title('Interval Type Usage by Artist')
    plt.xlabel('Artist')
    plt.ylabel('Count')
    plt.xticks(rotation=0)
    plt.legend(title='Interval Type', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

    # Standard deviation across interval types for each artist, largest first
    std_devs = df.std(axis=1).sort_values(ascending=False)

    # Plot 2: Standard Deviation of Interval Usage per Artist
    std_devs.plot(kind='bar', figsize=(10, 6))
    plt.title('Interval Type Usage Variance per Artist')
    plt.xlabel('Artist')
    plt.ylabel('Standard Deviation')
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Display labels for the eight interval types, in the order
# chromatic, diatonic, third, fourth, fifth, sixth, seventh, octave
labels = ['Chromatic', 'Diatonic', 'Third', 'Fourth', 'Fifth', 'Sixth', 'Seventh', 'Octave']

#### probability transition matrices heatmap, simulations not used ####
def build_transition_matrix(transition_probs):
    """
    Turn a dictionary of pitch-to-pitch weights into a row-normalized matrix.

    Parameters:
        transition_probs (dict): Maps (from_pitch, to_pitch) tuples to probabilities.

    Returns:
        (P, states) (tuple):
            - P (numpy.ndarray): Square matrix where P[i, j] is the weight of states[i] -> states[j],
              scaled so that every row with a positive total sums to 1. Rows of pitches that
              have no outgoing weight stay all zero.
            - states (list): Sorted unique pitches appearing as source or destination.
    """
    # Every pitch seen on either side of a transition becomes a state
    states = sorted({pitch for pair in transition_probs for pitch in pair})
    index_of = {pitch: position for position, pitch in enumerate(states)}

    size = len(states)
    P = np.zeros((size, size))
    for (source, target), weight in transition_probs.items():
        P[index_of[source], index_of[target]] = weight

    # Scale each row by its total in one vectorized step, leaving non-positive rows untouched
    row_totals = P.sum(axis=1, keepdims=True)
    np.divide(P, row_totals, out=P, where=row_totals > 0)

    return P, states

def compute_steady_state(P, max_iterations=1000, tolerance=1e-6):
    """
    Estimate the steady state distribution of a Markov chain by power iteration.

    Starting from the uniform distribution, the row vector is multiplied by P until
    the largest element-wise change between two iterates drops below the tolerance.

    Parameters:
        P (numpy.ndarray): Transition matrix.
        max_iterations (int): Maximum number of iterations (default: 1000).
        tolerance (float): Convergence tolerance (default: 1e-6).

    Returns:
        steady_state, iterations, converged (tuple):
            - steady_state (array): The converged distribution, or the last iterate
              when the iteration limit was reached first.
            - iterations (int): Number of iterations performed.
            - converged (bool): Whether the algorithm converged.
    """
    num_states = P.shape[0]
    distribution = np.ones(num_states) / num_states

    for step in range(1, max_iterations + 1):
        updated = distribution @ P
        largest_change = np.abs(updated - distribution).max()
        if largest_change < tolerance:
            return updated, step, True
        distribution = updated

    # Iteration budget exhausted without meeting the tolerance
    return distribution, max_iterations, False

def analyze_markov_chain(transition_probs, show_heatmap=False, show_plot=False, artist_name=None):
    """
    Build the transition matrix for a set of transitions, report it, and estimate its steady state.

    Parameters:
        transition_probs (dict): Dictionary with (from_pitch, to_pitch) tuples as keys and probabilities as values.
        show_heatmap (bool): Whether to plot the transition matrix heatmap (default: False).
        show_plot (bool): Whether to plot the steady state distribution (default: False).
        artist_name (str): Name used in printed output and plot titles; 'Artist' when empty or None.

    Returns:
        P, states, steady_state (tuple):
            - P (array): Transition matrix.
            - states (list): List of states (MIDI notes).
            - steady_state (dict): Dictionary mapping pitch to steady state probability.
    """
    display_name = artist_name if artist_name else 'Artist'

    P, states = build_transition_matrix(transition_probs)
    num_states = len(states)

    banner = '=' * 20
    print(f"\n{banner} {display_name} {banner}")
    print(f"States (MIDI notes): {states}")
    print(f"Matrix shape: {P.shape[0]} x {P.shape[1]}")

    # Print the whole matrix when it has at most 10 states, otherwise a window around its center
    if num_states <= 10:
        print("\nFull Transition Matrix:")
        window = range(num_states)
    else:
        print("\nSample of Transition Matrix (central portion):")
        middle = num_states // 2
        half_width = min(5, num_states // 2)
        window = range(max(0, middle - half_width), min(num_states, middle + half_width + 1))

    # First column holds the row's state, the remaining columns the probabilities
    row_format = "{:>5}" + "{:>8}" * len(window)
    print(row_format.format("", *[f"{states[col]}" for col in window]))
    for row in window:
        print(row_format.format(f"{states[row]}", *[f"{P[row, col]:.3f}" for col in window]))

    if show_heatmap:
        plt.figure(figsize=(10, 8))
        plt.imshow(P, cmap='Blues')
        plt.colorbar(label='Transition Probability')
        plt.title(f'Transition Matrix for {display_name}')

        # Label every state up to 20 states; beyond that only every (n // 10)-th one
        tick_step = 1 if num_states <= 20 else num_states // 10
        tick_positions = range(0, num_states, tick_step)
        tick_labels = [states[position] for position in tick_positions]
        plt.xticks(tick_positions, tick_labels, rotation=90)
        plt.yticks(tick_positions, tick_labels)

        plt.xlabel('To MIDI Note')
        plt.ylabel('From MIDI Note')
        plt.tight_layout()
        plt.show()

    pi, iterations, converged = compute_steady_state(P)
    steady_state = dict(zip(states, pi))

    print(f"\n{display_name} - Markov chain analysis completed in {iterations} iterations.")
    outcome = ("Steady state distribution converged."
               if converged
               else "Warning: Steady state distribution did not converge.")
    print(f"{display_name} - {outcome}")

    most_likely_pitch = states[np.argmax(pi)]
    print(f"{display_name} - Most likely pitch in steady state: {most_likely_pitch} (MIDI note)")

    if show_plot:
        plt.figure(figsize=(12, 6))
        plt.bar(states, pi)
        plt.xlabel('Pitch (MIDI note)')
        plt.ylabel('Steady State Probability')
        plt.title(f'Steady State Distribution for {display_name}')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()

    return P, states, steady_state

def simulate_markov_process(P, states, num_simulations=100, num_steps=1000, initial_state=None):
    """
    Run repeated random walks on a Markov chain and average where they spend their final steps.

    Parameters:
        P (numpy.ndarray): Transition matrix; row i is used as the probability vector for leaving state i.
        states (list): List of states (MIDI notes). Not used by the simulation itself.
        num_simulations (int): Number of simulations to run (default: 100).
        num_steps (int): Number of steps in each simulation (default: 1000).
        initial_state (int): Initial state index (default: None, drawn uniformly at random per simulation).

    Returns:
        avg_distribution (array): Share of visits per state over the steps with index >= 0.8 * num_steps,
            pooled across all simulations; uniform when no step was counted.
    """
    num_states = P.shape[0]
    visits = np.zeros(num_states)

    # Steps before this index are treated as warm-up and not counted
    first_counted_step = 0.8 * num_steps

    for _ in range(num_simulations):
        if initial_state is None:
            position = np.random.choice(num_states)
        else:
            position = initial_state

        for step in range(num_steps):
            # Draw the next state from the current state's row
            position = np.random.choice(num_states, p=P[position])
            if step >= first_counted_step:
                visits[position] += 1

    total_visits = np.sum(visits)
    if total_visits > 0:
        return visits / total_visits

    # Nothing was counted (e.g. zero steps): fall back to the uniform distribution
    return np.ones(num_states) / num_states

def run_multiple_simulations(transition_probs, num_simulations=100, num_steps=1000, show_final_plot=True, artist_name=None):
    """
    Simulate the Markov chain repeatedly and compare the outcome with its computed steady state.

    Parameters:
        transition_probs (dict): Dictionary with (from_pitch, to_pitch) tuples as keys and probabilities as values.
        num_simulations (int): Number of simulations to run (default: 100).
        num_steps (int): Number of steps in each simulation (default: 1000).
        show_final_plot (bool): Whether to draw the comparison bar chart and the matrix heatmap (default: True).
        artist_name (str): Name used in printed output and plot titles; 'Artist' when empty or None.

    Returns:
        P, states, theoretical_dist, simulated_dist (tuple):
            - P (numpy.ndarray): Transition matrix.
            - states (list): List of states (MIDI notes).
            - theoretical_dist (dict): Steady state probabilities from analyze_markov_chain.
            - simulated_dist (dict): Probabilities estimated by simulate_markov_process.
    """
    display_name = artist_name if artist_name else 'Artist'

    # Transition matrix and theoretical steady state (plots suppressed at this stage)
    P, states, theoretical_dist = analyze_markov_chain(
        transition_probs, show_heatmap=False, show_plot=False, artist_name=artist_name
    )

    print(f"\nRunning {num_simulations} simulations with {num_steps} steps each for {display_name}...")
    avg_distribution = simulate_markov_process(P, states, num_simulations, num_steps)
    simulated_dist = dict(zip(states, avg_distribution))

    # Line the theoretical values up with `states` so both vectors share an order
    theoretical_probs = np.array([theoretical_dist[state] for state in states])
    mae = np.mean(np.abs(theoretical_probs - avg_distribution))
    print(f"{display_name} - Mean absolute error between theoretical and simulated: {mae:.6f}")

    if not show_final_plot:
        return P, states, theoretical_dist, simulated_dist

    # Side-by-side bars: theoretical on the left, simulated on the right of each tick
    num_states = len(states)
    x = np.arange(num_states)
    width = 0.35

    plt.figure(figsize=(12, 6))
    plt.bar(x - width/2, theoretical_probs, width, label='Theoretical', alpha=0.7)
    plt.bar(x + width/2, avg_distribution, width, label='Simulated', alpha=0.7)
    plt.xlabel('Pitch (MIDI note)')
    plt.ylabel('Steady State Probability')
    plt.title(f'Theoretical vs Simulated Steady State for {display_name}')
    plt.xticks(x, states, rotation=90)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Heatmap of the transition matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(P, cmap='Blues')
    plt.colorbar(label='Transition Probability')
    plt.title(f'Transition Matrix for {display_name}')

    # Label every state up to 20 states; beyond that only every (n // 10)-th one
    tick_step = 1 if num_states <= 20 else num_states // 10
    tick_positions = range(0, num_states, tick_step)
    tick_labels = [states[position] for position in tick_positions]
    plt.xticks(tick_positions, tick_labels, rotation=90)
    plt.yticks(tick_positions, tick_labels)

    plt.xlabel('To MIDI Note')
    plt.ylabel('From MIDI Note')
    plt.tight_layout()
    plt.show()

    return P, states, theoretical_dist, simulated_dist

def compare_markov_chains(all_transition_probs, all_labels, show_plot=False):
    """
    Compare steady state distributions of multiple Markov chains.

    Every chain is analysed with `analyze_markov_chain`. When `show_plot` is True the
    steady states are drawn as grouped bars, followed by one heatmap per transition
    matrix. All heatmaps share a single colour scale so that the one colorbar drawn
    next to them is valid for every panel.

    Parameters:
        all_transition_probs (list): List of transition probability dictionaries.
        all_labels (list): List of labels for each transition probability dictionary.
        show_plot (bool): Whether to plot the comparison (default: False).

    Returns:
        all_steady_states, all_matrices, all_states_lists (tuple):
            - all_steady_states (list): List of (steady_state, label) tuples.
            - all_matrices (list): List of transition matrices.
            - all_states_lists (list): List of state lists.
    """
    # Track all pitches across all artists
    all_pitches = set()
    all_steady_states = []
    all_matrices = []
    all_states_lists = []

    # Compute steady states for each artist
    for transition_probs, label in zip(all_transition_probs, all_labels):
        # Extract artist name from label (assuming format "Moanin - ArtistName")
        artist_name = label.split(" - ")[1] if " - " in label else label

        P, states, steady_state = analyze_markov_chain(transition_probs, show_heatmap=False, show_plot=False, artist_name=artist_name)
        all_steady_states.append((steady_state, label))
        all_matrices.append(P)
        all_states_lists.append(states)
        all_pitches.update(steady_state.keys())

    # Nothing to draw (also avoids dividing by zero when no chains were given)
    if not show_plot or not all_steady_states:
        return all_steady_states, all_matrices, all_states_lists

    # Sort pitches for consistent x-axis
    all_pitches = sorted(all_pitches)
    n_artists = len(all_steady_states)

    # Define a color map for multiple artists
    colors = plt.cm.tab10(np.linspace(0, 1, n_artists))

    # First plot the bar chart comparison. The figure is opened only now so the
    # bars cannot land on a figure created while the chains were being analysed.
    plt.figure(figsize=(15, 8))
    bar_width = 0.8 / n_artists
    for i, (steady_state, label) in enumerate(all_steady_states):
        # Fill in missing pitches with zeros
        probs = [steady_state.get(pitch, 0) for pitch in all_pitches]

        # Centre the group of bars on each pitch
        positions = np.array(all_pitches) + (i - n_artists / 2 + 0.5) * bar_width

        plt.bar(positions, probs, width=bar_width, color=colors[i], label=label, alpha=0.7)

    plt.xlabel('Pitch (MIDI note)')
    plt.ylabel('Steady State Probability')
    plt.title('Comparison of Steady State Distributions')
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()

    # Now plot all heatmaps in a row
    fig, axes = plt.subplots(1, n_artists, figsize=(5 * n_artists, 6))

    # If there's only one artist, axes won't be an array
    if n_artists == 1:
        axes = [axes]

    # One colour scale for every panel; otherwise each imshow autoscales on its
    # own data and the shared colorbar would only describe the last panel.
    shared_vmax = max((float(np.max(P)) for P in all_matrices if np.size(P)), default=1.0)

    for i, (P, states, (_, label)) in enumerate(zip(all_matrices, all_states_lists, all_steady_states)):
        artist_name = label.split(" - ")[1] if " - " in label else label
        im = axes[i].imshow(P, cmap='Blues', vmin=0.0, vmax=shared_vmax)
        axes[i].set_title(f'{artist_name}')

        # Add tick labels (may be too crowded for large matrices)
        if len(states) <= 10:
            tick_positions = range(len(states))
        else:
            # For larger matrices, show fewer ticks
            tick_positions = range(0, len(states), len(states) // 5)
        tick_labels = [states[j] for j in tick_positions]
        axes[i].set_xticks(tick_positions)
        axes[i].set_xticklabels(tick_labels, rotation=90)
        axes[i].set_yticks(tick_positions)
        axes[i].set_yticklabels(tick_labels)

        # Only add y-label for the first subplot
        if i == 0:
            axes[i].set_ylabel('From MIDI Note')

        axes[i].set_xlabel('To MIDI Note')

    fig.suptitle('Transition Matrices for All Artists', fontsize=16)

    # Lay out the panels first (leaving the right margin free), then place the
    # colorbar axes by hand in that margin.
    fig.tight_layout(rect=[0, 0, 0.9, 0.95])
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax, label='Transition Probability')
    plt.show()

    return all_steady_states, all_matrices, all_states_lists

def run_all_simulations(midi_files, num_simulations=100, num_steps=1000, show_individual_plots=True, show_comparison=True):
    """
    Simulate the pitch Markov chain of every artist and optionally compare them.

    For each (artist, midi) pair the transition probabilities are extracted, an
    ergodicity report is printed, and `run_multiple_simulations` is executed.

    Parameters:
        midi_files (list): List of (artist_name, midi_data) tuples.
        num_simulations (int): Number of simulations per artist (default: 100).
        num_steps (int): Number of steps per simulation (default: 1000).
        show_individual_plots (bool): Whether to show individual artist plots (default: True).
        show_comparison (bool): Whether to show comparison plot of all artists (default: True).

    Returns:
        all_theoretical_dists, all_simulated_dists, all_matrices, all_states_lists (tuple):
            - all_theoretical_dists (list): List of (theoretical_dist, label) tuples.
            - all_simulated_dists (list): List of (simulated_dist, label) tuples.
            - all_matrices (list): List of transition matrices.
            - all_states_lists (list): List of state lists.
    """
    chains, labels = [], []
    theoretical, simulated = [], []
    matrices, state_lists = [], []

    for artist, midi_file in midi_files:
        display_name = artist.capitalize()
        transition_probs, _, _ = get_prob_transitions(midi_file)

        # Called for the report it prints; its return value is not used here
        integrate_with_existing_code(transition_probs, artist_name=display_name)

        label = f"Moanin - {display_name}"
        chains.append(transition_probs)
        labels.append(label)

        # Simulate this artist's chain
        P, states, theoretical_dist, simulated_dist = run_multiple_simulations(
            transition_probs,
            num_simulations=num_simulations,
            num_steps=num_steps,
            show_final_plot=show_individual_plots,
            artist_name=display_name
        )

        theoretical.append((theoretical_dist, label))
        simulated.append((simulated_dist, label))
        matrices.append(P)
        state_lists.append(states)

    if not show_comparison:
        return theoretical, simulated, matrices, state_lists

    # Side-by-side comparison of all artists
    print("\nComparing theoretical steady states across all artists:")
    compare_markov_chains(chains, labels, show_plot=True)

    print("\nMost likely notes for each artist:")
    for steady_state, label in theoretical:
        # The artist name is the part after "Moanin - "
        artist_name = label.split(" - ")[1] if " - " in label else label

        # Three most probable notes, highest probability first
        ranked = sorted(steady_state.items(), key=lambda x: x[1], reverse=True)

        print(f"\n{artist_name}:")
        for note, prob in ranked[:3]:
            print(f"  MIDI note {note}/{midi_to_pitch_class(note)}: {prob:.4f} probability")

    return theoretical, simulated, matrices, state_lists

#### check ergodicity of transition matrix, this was not used ####
def is_irreducible(P, tolerance=1e-10):
    """
    Check if a Markov chain transition matrix is irreducible.

    The check is made on the transition graph (an edge i -> j exists when
    P[i, j] > tolerance) instead of on numeric powers of P. Along a multi-step
    path the probabilities multiply and can drop below `tolerance`, which would
    make a reachable state look unreachable.

    Parameters:
        P (array): Transition matrix.
        tolerance (float): Numerical tolerance for considering a probability non-zero (default: 1e-10).

    Returns:
        bool: True if every state can reach every state (itself included) in one
              or more steps, False otherwise.
    """
    # reach[i, j] is True when j can be reached from i in at least one step
    reach = np.asarray(P) > tolerance
    n = reach.shape[0]

    # Warshall's transitive closure: additionally allow state k as a stop-over
    for k in range(n):
        reach |= reach[:, [k]] & reach[[k], :]

    return bool(reach.all())

def get_period(P, state, tolerance=1e-10):
    """
    Calculate the period of a specific state in a Markov chain.

    The period is the gcd of the lengths of all closed walks through `state`.
    It is computed on the transition graph (edge i -> j when P[i, j] > tolerance):
    within the set of states that `state` can reach and that can reach `state`
    back, the period equals the gcd of depth[u] + 1 - depth[v] over all edges
    u -> v, where depth is the breadth-first distance from `state`. Looking only
    at return times up to n steps can miss longer closed walks and report a
    period that is too large.

    Parameters:
        P (array): Transition matrix.
        state (int): Index of the state to check.
        tolerance (float): Numerical tolerance for considering a probability non-zero (default: 1e-10).

    Returns:
        period (int): Period of the state, or 0 if the state doesn't return to itself.
    """
    from collections import deque
    from math import gcd

    adjacency = np.asarray(P) > tolerance
    n = adjacency.shape[0]
    state = range(n)[state]  # normalises negative indices, IndexError if out of range

    successors = [np.flatnonzero(row).tolist() for row in adjacency]
    predecessors = [np.flatnonzero(col).tolist() for col in adjacency.T]

    # Breadth-first distance from `state` to everything it can reach
    depth = {state: 0}
    queue = deque([state])
    while queue:
        u = queue.popleft()
        for v in successors[u]:
            if v not in depth:
                depth[v] = depth[u] + 1
                queue.append(v)

    # States that can get back to `state` (search along reversed edges)
    returning = {state}
    stack = [state]
    while stack:
        u = stack.pop()
        for w in predecessors[u]:
            if w not in returning:
                returning.add(w)
                stack.append(w)

    # Only edges inside the communicating class of `state` lie on closed walks
    # through it; with no such edge the gcd stays 0 (state never returns).
    period = 0
    for u, depth_u in depth.items():
        if u not in returning:
            continue
        for v in successors[u]:
            if v in returning:
                period = gcd(period, depth_u + 1 - depth[v])

    return period

def is_aperiodic(P, tolerance=1e-10):
    """
    Check if a Markov chain transition matrix is aperiodic.

    A chain is treated as aperiodic when every state that can return to itself
    has period 1. States that never return (period 0 from `get_period`) are
    ignored. Stopping at the first aperiodic state is only valid for irreducible
    chains, where all states share one period; for a reducible chain another
    class may still be periodic.

    Parameters:
        P (array): Transition matrix.
        tolerance (float): Numerical tolerance for considering a probability non-zero (default: 1e-10).

    Returns:
        bool: True if the Markov chain is aperiodic, False otherwise
              (also False when no state returns to itself).
    """
    n = P.shape[0]
    saw_returning_state = False

    for i in range(n):
        # A self-loop gives a return time of 1, so the state's period is 1
        if P[i, i] > tolerance:
            saw_returning_state = True
            continue

        period = get_period(P, i, tolerance)
        if period == 0:
            continue  # This state doesn't return to itself
        if period != 1:
            return False  # Found a periodic state, so the chain is periodic
        saw_returning_state = True

    return saw_returning_state

def is_ergodic(P, tolerance=1e-10):
    """
    Decide whether a transition matrix describes an ergodic Markov chain.

    The chain is reported as ergodic when `is_irreducible` and `is_aperiodic`
    both hold.

    Parameters:
        P (array): Transition matrix.
        tolerance (float): Numerical tolerance for considering a probability non-zero (default: 1e-10).

    Returns:
        ergodic, info (tuple):
            - ergodic (bool): True if the Markov chain is ergodic, False otherwise.
            - info (dict): Dictionary with 'irreducible', 'aperiodic', and 'ergodic' keys.
    """
    info = {
        'irreducible': is_irreducible(P, tolerance),
        'aperiodic': is_aperiodic(P, tolerance),
    }
    # Ergodic means both properties hold at once
    info['ergodic'] = info['irreducible'] and info['aperiodic']

    return info['ergodic'], info

def eigenvalue_analysis(P, tolerance=1e-10):
    """
    Analyze eigenvalues of a transition matrix to check for ergodicity.

    Counts the eigenvalues whose magnitude is within `tolerance` of 1 and
    reports the chain as ergodic by this criterion when there is exactly one.

    Parameters:
        P (array): Transition matrix.
        tolerance (float): Numerical tolerance for eigenvalue magnitude (default: 1e-10).

    Returns:
        info (dict): Dictionary with eigenvalue analysis information including 
                    'eigenvalues', 'unit_eigenvalues', 'num_unit_eigenvalues', 
                    and 'ergodic_by_eigenvalues'.
    """
    # Use numpy's solver through the `np` alias the file already imports
    # (a bare `linalg` name is not among the file's imports).
    eigenvalues = np.linalg.eigvals(P)

    # Eigenvalues lying on the unit circle, up to the tolerance
    unit_eigenvalues = [ev for ev in eigenvalues if abs(abs(ev) - 1) < tolerance]

    info = {
        'eigenvalues': eigenvalues,
        'unit_eigenvalues': unit_eigenvalues,
        'num_unit_eigenvalues': len(unit_eigenvalues),
        'ergodic_by_eigenvalues': len(unit_eigenvalues) == 1
    }

    return info

def analyze_ergodicity(P, states=None, tolerance=1e-10, artist_name=None):
    """
    Run the ergodicity checks on a transition matrix and print a report.

    Parameters:
        P (array): Transition matrix.
        states (list): List of state labels (default: None). Accepted but not used by the report.
        tolerance (float): Numerical tolerance for tests (default: 1e-10).
        artist_name (str): Name for output (default: None, printed as 'Matrix').

    Returns:
        results (dict): Dictionary with 'ergodic', 'irreducible', 'aperiodic', and 'eigenvalue_analysis' results.
    """
    size = P.shape[0]

    # Graph-based checks and the eigenvalue-based check
    _, flags = is_ergodic(P, tolerance)
    spectrum = eigenvalue_analysis(P, tolerance)

    # Assemble the report, then print it line by line
    report = [
        f"\n{'=' * 20} Ergodicity Analysis for {artist_name if artist_name else 'Matrix'} {'=' * 20}",
        f"Matrix size: {size} x {size}",
        f"Irreducible: {flags['irreducible']}",
        f"Aperiodic: {flags['aperiodic']}",
        f"Ergodic: {flags['ergodic']}",
        f"Number of eigenvalues with magnitude 1: {spectrum['num_unit_eigenvalues']}",
        f"Ergodic by eigenvalue analysis: {spectrum['ergodic_by_eigenvalues']}",
    ]
    if not flags['irreducible']:
        report.append("Warning: The chain is not irreducible. There may be states that cannot reach other states.")
    if not flags['aperiodic']:
        report.append("Warning: The chain is periodic. The steady state may not be unique or convergence may not occur.")

    for line in report:
        print(line)

    return {
        'ergodic': flags['ergodic'],
        'irreducible': flags['irreducible'],
        'aperiodic': flags['aperiodic'],
        'eigenvalue_analysis': spectrum,
    }

def test_example_matrices():
    """
    Run the ergodicity analysis on three small hand-made chains.

    Parameters:
        None

    Returns:
        None: Prints analysis results for example matrices.
    """
    examples = [
        # Ergodic chain
        (np.array([[0.7, 0.3],
                   [0.4, 0.6]]),
         "Example 1 (Should be ergodic)"),
        # Reducible chain (state 1 can't reach state 0)
        (np.array([[0.8, 0.2],
                   [0.0, 1.0]]),
         "Example 2 (Should be reducible, not ergodic)"),
        # Periodic chain with period 2
        (np.array([[0.0, 1.0],
                   [1.0, 0.0]]),
         "Example 3 (Should be periodic, not ergodic)"),
    ]

    print("\nTesting example matrices:")
    for matrix, description in examples:
        analyze_ergodicity(matrix, artist_name=description)

def integrate_with_existing_code(transition_probs, artist_name=None):
    """
    Build the transition matrix for a chain and run the ergodicity analysis on it.

    Parameters:
        transition_probs (dict): Dictionary with (from_pitch, to_pitch) tuples as keys and probabilities as values.
        artist_name (str): Name of the artist for output (default: None).

    Returns:
        P, states, ergodicity_results (tuple):
            - P (array): Transition matrix.
            - states (list): List of states (MIDI notes).
            - ergodicity_results (dict): Ergodicity analysis results.
    """
    # Matrix and state labels come from the existing builder
    matrix, state_labels = build_transition_matrix(transition_probs)

    # The analysis prints its report and returns the result dictionary
    return matrix, state_labels, analyze_ergodicity(matrix, state_labels, artist_name=artist_name)

### morlet

In [None]:
##### all functions for Morlet wavelet analysis #####
def prep_data_morlet(data_file_path, sample_rate=1000, frequencies=None, wavelet='cmor1.5-1.0', wav=False):
    """
    Prepare data for Morlet wavelet analysis from MIDI or WAV files.

    WAV input is normalised (integer PCM only), mixed down to mono, optionally
    resampled, and transformed. MIDI input is rendered to a piano roll whose
    rows are summed into one series, which is then transformed.

    Parameters:
        data_file_path (str): File path to the data (MIDI or WAV).
        sample_rate (int): Sampling rate in Hz (default: 1000). For WAV input, None keeps the
            file's own rate and any other value resamples the audio to it. For MIDI input
            it is the `fs` used to render the piano roll.
        frequencies (list): Values handed to the wavelet transform
            (default: None, uses np.arange(60, 2000, 10)).
        wavelet (str): Wavelet type for transform (default: 'cmor1.5-1.0', where 1.5 is bandwidth, 1.0 is center frequency).
        wav (bool): Whether to process a WAV file (True) or MIDI file (False) (default: False).

    Returns:
        tuple: Depending on `wav`:
            - If wav=True: (coefficients, cwt_freqs, audio_data, sample_rate, frequencies)
            - If wav=False: (coefficients, cwt_freqs, time_series, velocity_series, sample_rate, frequencies)
          where `cwt_freqs` is the second array returned by `pywt.cwt`.
    """
    # Use default frequencies if none provided
    if frequencies is None:
        frequencies = np.arange(60, 2000, 10)
    else:
        frequencies = np.array(frequencies)

    # NOTE(review): the second positional argument of pywt.cwt is `scales`, not
    # frequencies in Hz. The values in `frequencies` are therefore used as wavelet
    # scales, so the Hz axis drawn by the plotting helpers may not match what was
    # analysed. Left unchanged because correcting it would alter every downstream
    # figure -- confirm the intent before changing.

    if wav:
        # Load the WAV file
        file_sample_rate, audio_data = wavfile.read(data_file_path)

        # Normalize integer PCM by the dtype's max. This must happen before the
        # channel average: np.mean returns floats, which would make the integer
        # check fail and leave stereo files un-normalized.
        if np.issubdtype(audio_data.dtype, np.integer):
            full_scale = np.iinfo(audio_data.dtype).max
            audio_data = audio_data.astype(np.float32) / full_scale

        # If the audio is stereo, average the channels to get mono
        if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)

        # Use the provided sample rate or the file's sample rate
        if sample_rate is None:
            sample_rate = file_sample_rate
        # Resample if necessary (basic implementation - for better results use resampy or librosa)
        elif sample_rate != file_sample_rate:
            from scipy import signal
            new_length = int(len(audio_data) * sample_rate / file_sample_rate)
            audio_data = signal.resample(audio_data, new_length)

        # Perform Continuous Wavelet Transform (CWT) with the specified wavelet
        coefficients, cwt_freqs = pywt.cwt(audio_data, frequencies, wavelet)

        return coefficients, cwt_freqs, audio_data, sample_rate, frequencies

    # Load the MIDI file
    midi_file = pretty_midi.PrettyMIDI(data_file_path)

    # Piano roll rendered at `sample_rate`, then summed along axis 0 into one series
    time_series = midi_file.get_piano_roll(fs=sample_rate)
    velocity_series = np.sum(time_series, axis=0)

    # Perform Continuous Wavelet Transform (CWT) with the specified wavelet
    coefficients, cwt_freqs = pywt.cwt(velocity_series, frequencies, wavelet)

    return coefficients, cwt_freqs, time_series, velocity_series, sample_rate, frequencies

def get_amplitude_phase(transformed_data, wav=False):
    """
    Split wavelet coefficients into phase and amplitude.

    Parameters:
        transformed_data (tuple): Output of prep_data_morlet. Five items for WAV data
            (coefficients, _, audio_data, sample_rate, frequencies), six for MIDI data
            (coefficients, _, time_series, velocity_series, sample_rate, frequencies).
        wav (bool): Whether the data is from a WAV file (True) or MIDI file (False) (default: False).

    Returns:
        tuple: Depending on `wav`:
            - If wav=True: (audio_data, sample_rate, phase, amplitude, frequencies)
            - If wav=False: (time_series, velocity_series, sample_rate, phase, amplitude, frequencies)
    """
    # The leading items differ between the two formats; the rest is shared
    if wav:
        coefficients, _, audio_data, sample_rate, frequencies = transformed_data
        leading = (audio_data,)
    else:
        coefficients, _, time_series, velocity_series, sample_rate, frequencies = transformed_data
        leading = (time_series, velocity_series)

    # Magnitude and argument of the complex coefficients
    magnitude = np.abs(coefficients)
    angle = np.angle(coefficients)

    return (*leading, sample_rate, angle, magnitude, frequencies)

def morlet_analysis(data_list, time_window=(0, 16), titles=None, 
                   figsize=(15, 12), wav=False):
    """
    Draw, for each dataset, its waveform (top row) and wavelet amplitude map (bottom row).

    Parameters:
        data_list (list): List of tuples from get_amplitude_phase (WAV or MIDI format).
        time_window (tuple): Time range (start_time, end_time) in seconds (default: (0, 16)).
        titles (list): List of titles for each dataset (default: None, uses ["Data 1", "Data 2", ...]).
        figsize (tuple): Figure size as (width, height) (default: (15, 12)); the width is
            scaled by half the number of datasets.
        wav (bool): Whether data is from WAV files (True) or MIDI files (False) (default: False).

    Returns:
        matplotlib.figure.Figure: Figure object containing the plots.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Validate input
    num_datasets = len(data_list)
    if num_datasets < 1:
        raise ValueError("At least one dataset must be provided")

    # Fall back to generic titles
    if titles is None:
        titles = [f"Data {i+1}" for i in range(num_datasets)]
    elif len(titles) != num_datasets:
        raise ValueError(f"Number of titles ({len(titles)}) must match number of datasets ({num_datasets})")

    # One record per dataset: (series, windowed time axis, start index, stop index, amplitude, frequencies)
    panels = []
    for data in data_list:
        if wav:
            series, sample_rate, _, amplitude, frequencies = data
        else:
            _, series, sample_rate, _, amplitude, frequencies = data

        time_axis = np.arange(len(series)) / sample_rate
        start = int(time_window[0] * sample_rate)
        # Never index past the end of the series
        stop = min(int(time_window[1] * sample_rate), len(series))

        panels.append((series, time_axis[start:stop], start, stop, amplitude, frequencies))

    # Two rows (waveform, spectrogram), one column per dataset
    fig = plt.figure(figsize=(figsize[0] * num_datasets / 2, figsize[1]))
    grid = fig.add_gridspec(2, num_datasets)

    # Row 0: waveforms
    for col, (series, window_time, start, stop, _, _) in enumerate(panels):
        ax = fig.add_subplot(grid[0, col])
        ax.plot(window_time, series[start:stop])
        ax.set_title(f'{titles[col]} Waveform')

        # Y-label on the leftmost panel only
        if col == 0:
            ax.set_ylabel('Amplitude')

    # Row 1: spectrograms
    for col, (_, window_time, start, stop, amplitude, frequencies) in enumerate(panels):
        ax = fig.add_subplot(grid[1, col])

        # Slicing clips at the array edge, so a short amplitude array is handled too
        amp_spec = amplitude[:, start:stop]

        # Trim the time axis to the number of columns actually available
        mesh = ax.pcolormesh(window_time[:amp_spec.shape[1]], frequencies, amp_spec,
                             shading='gouraud', cmap='viridis')
        ax.set_yscale('log')
        ax.set_title(f'{titles[col]} Spectrogram')
        ax.set_xlabel('Time (s)')

        # Y-label on the leftmost panel only
        if col == 0:
            ax.set_ylabel('Frequency (Hz)')

        fig.colorbar(mesh, ax=ax, label='Magnitude')

    # Adjust layout for better spacing
    plt.tight_layout()

    return fig

def plot_wavelet_spectrogram(data_list, wav=False, time_window=None, titles=None, figsize=(12, 6)):
    """
    Draw one wavelet power map (squared amplitude) per dataset, side by side.

    Parameters:
        data_list (list): List of tuples from get_amplitude_phase (WAV or MIDI format).
            A single non-list item is wrapped into a one-element list.
        wav (bool): Whether data is from WAV files (True) or MIDI files (False) (default: False).
        time_window (tuple): Time range (start_time, end_time) in seconds (default: None, shows all).
            Either bound may be None to leave that side unrestricted.
        titles (list): List of titles for each dataset (default: None, uses ["Dataset 1", "Dataset 2", ...]).
        figsize (tuple): Figure size as (width, height) (default: (12, 6)); the width is
            multiplied by the number of datasets.

    Returns:
        fig (matplotlib.figure): Figure object containing the spectrogram plots.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Handle single dataset case
    if not isinstance(data_list, list):
        data_list = [data_list]

    # Validate input
    num_datasets = len(data_list)
    if num_datasets < 1:
        raise ValueError("At least one dataset must be provided")

    # Fall back to generic titles
    if titles is None:
        titles = [f"Dataset {i+1}" for i in range(num_datasets)]
    elif len(titles) != num_datasets:
        raise ValueError(f"Number of titles ({len(titles)}) must match number of datasets ({num_datasets})")

    # One column per dataset
    fig, axes = plt.subplots(1, num_datasets, figsize=(figsize[0] * num_datasets, figsize[1]))

    # A single Axes is wrapped so it can be iterated like the multi-panel case
    if num_datasets == 1:
        axes = np.array([axes])

    for idx, (data, title, ax) in enumerate(zip(data_list, titles, axes)):
        # The series used for the time axis sits at a different position per format
        if wav:
            series, sample_rate, _, amplitude, freqs = data
        else:
            _, series, sample_rate, _, amplitude, freqs = data

        count = len(series)
        time = np.linspace(0, (count - 1) / sample_rate, count)

        power = amplitude ** 2  # Compute power of coefficients

        # Visible time range: the full signal, narrowed by time_window if given
        x_min, x_max = time[0], time[-1]
        if time_window is not None:
            start_time, end_time = time_window
            if start_time is not None:
                x_min = max(start_time, time[0])
            if end_time is not None:
                x_max = min(end_time, time[-1])

        # Plot the spectrogram
        im = ax.imshow(power, aspect='auto', extent=[time[0], time[-1], freqs[-1], freqs[0]],
                       cmap='inferno', interpolation='nearest')
        ax.set_xlim(x_min, x_max)

        plt.colorbar(im, ax=ax, label="Power")

        ax.set_xlabel("Time (s)")
        # Y-label on the leftmost panel only
        if idx == 0:
            ax.set_ylabel("Frequency (Hz)")

        ax.set_title(f"{title}\n({x_min:.2f}s - {x_max:.2f}s)")

    plt.tight_layout()
    return fig

def plot_amp_phase(data_list, wav=False, time_window=None, titles=None, figsize=(15, 12)):
    """
    Draw amplitude maps (top row) and phase maps (bottom row), one column per dataset.

    Parameters:
        data_list (list): List of tuples from get_amplitude_phase (WAV or MIDI format).
        wav (bool): Whether data is from WAV files (True) or MIDI files (False) (default: False).
        time_window (tuple): Time range (start_time, end_time) in seconds (default: None, shows all).
        titles (list): List of titles for each dataset (default: None, uses ["Data 1", "Data 2", ...]).
        figsize (tuple): Figure size as (width, height) (default: (15, 12)); the width is
            scaled by half the number of datasets.

    Returns:
        fig (matplotlib.figure): Figure object containing amplitude and phase plots.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Validate input
    num_datasets = len(data_list)
    if num_datasets < 1:
        raise ValueError("At least one dataset must be provided")

    # Fall back to generic titles
    if titles is None:
        titles = [f"Data {i+1}" for i in range(num_datasets)]
    elif len(titles) != num_datasets:
        raise ValueError(f"Number of titles ({len(titles)}) must match number of datasets ({num_datasets})")

    # Per dataset: phase map, amplitude map and the image extent
    phases, amplitudes, extents = [], [], []
    for data in data_list:
        if wav:
            # WAV format: time axis length comes from the amplitude columns
            _, sample_rate, phase, amplitude, frequencies = data
            time_axis = np.arange(amplitude.shape[1]) / sample_rate
        else:
            # MIDI format: time axis length comes from the velocity series
            _, velocity, sample_rate, phase, amplitude, frequencies = data
            time_axis = np.arange(len(velocity)) / sample_rate

        phases.append(phase)
        amplitudes.append(amplitude)
        # [x_min, x_max, y_min, y_max]; first frequency at the bottom (origin='lower')
        extents.append([0, time_axis[-1], frequencies[0], frequencies[-1]])

    # 2 rows (amplitude and phase) x num_datasets columns
    fig, axes = plt.subplots(2, num_datasets, figsize=(figsize[0] * num_datasets / 2, figsize[1]))

    # With one dataset plt.subplots returns a 1-D array; make it 2-D
    if num_datasets == 1:
        axes = np.array(axes).reshape(2, 1)

    # (row, images, colormap, title suffix, colorbar label)
    rows = [
        (0, amplitudes, 'viridis', 'Amplitude', 'Magnitude'),
        (1, phases, 'twilight', 'Phase', 'Phase (radians)'),
    ]
    for row, images, cmap, suffix, bar_label in rows:
        for col in range(num_datasets):
            ax = axes[row, col]
            image = ax.imshow(images[col], extent=extents[col], aspect='auto',
                              cmap=cmap, origin='lower')
            ax.set_title(f'{titles[col]} {suffix}')

            # Y-label on the leftmost panel only
            if col == 0:
                ax.set_ylabel('Frequency (Hz)')

            ax.set_xlabel('Time (s)')
            plt.colorbar(image, ax=ax, label=bar_label)

            # Apply time window if specified
            if time_window is not None:
                start_time, end_time = time_window
                ax.set_xlim(start_time, end_time)

    plt.tight_layout()

    return fig

#### dominant frequencies (which I ended up not using) ####
def find_dominant_frequencies(amplitude, freqs, threshold=0.1, n_std=1.5):
    """
    Identify dominant frequencies from time-averaged wavelet coefficient magnitudes.

    The coefficient magnitudes are averaged across time (axis 1) to get one
    value per frequency. A frequency is dominant when its value is strictly
    greater than ``mean + n_std * std`` of those per-frequency values.

    Parameters:
        amplitude (array): Wavelet coefficients (frequencies x time).
        freqs (array): Frequency values, one per row of ``amplitude``.
            Lists and other sequences are accepted.
        threshold (float): Unused. It belonged to an earlier
            "fraction of the max coefficient" rule and is kept only so that
            existing callers passing it keep working (default: 0.1).
        n_std (float): Number of standard deviations above the mean that a
            frequency's mean coefficient must exceed (default: 1.5).

    Returns:
        dominant_freqs (array): Array of dominant frequencies.
    """
    import numpy as np

    # Boolean-mask indexing below needs an ndarray; a plain list would raise.
    freqs = np.asarray(freqs)

    mean_coeffs = np.mean(np.abs(amplitude), axis=1)  # Average across time

    # Cutoff relative to the spread of the per-frequency means.
    cutoff = np.mean(mean_coeffs) + n_std * np.std(mean_coeffs)

    return freqs[mean_coeffs > cutoff]

def plot_dominant_frequencies(freqs_list, amp_list, labels, threshold=0.1, title=None, figsize=(10, 6)):
    """
    Draw the time-averaged wavelet coefficient curve of several datasets on one
    figure and mark the dominant frequencies of each.

    Parameters:
        freqs_list (list): One frequency array per dataset.
        amp_list (list): One wavelet coefficient array (frequencies x time) per dataset.
        labels (list): One legend label per dataset.
        threshold (float): Passed through to find_dominant_frequencies (default: 0.1).
        title (str): Plot title; "Dominant Frequencies Comparison" when None.
        figsize (tuple): Figure size as (width, height) (default: (10, 6)).

    Returns:
        dominant_freqs_list (list): The dominant-frequency array of each dataset,
            in input order.

    Raises:
        ValueError: If the three input lists differ in length.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # All three inputs must describe the same number of datasets
    n_sets = len(labels)
    if len(freqs_list) != len(amp_list) or len(amp_list) != n_sets:
        raise ValueError("freqs_list, amp_list, and labels must have the same length")

    plt.figure(figsize=figsize)

    dominant_freqs_list = []
    for set_freqs, set_amps, set_label in zip(freqs_list, amp_list, labels):
        # One magnitude per frequency: average |coefficient| over the time axis
        time_averaged = np.mean(np.abs(set_amps), axis=1)

        selected = find_dominant_frequencies(set_amps, set_freqs, threshold)
        dominant_freqs_list.append(selected)

        # Full curve first, then the dominant points on top of it
        plt.plot(set_freqs, time_averaged, label=f"Mean Coefficients - {set_label}")
        is_dominant = np.isin(set_freqs, selected)
        plt.scatter(
            set_freqs[is_dominant],
            time_averaged[is_dominant],
            label=f"Dominant Freqs - {set_label}",
            s=50,
        )

    plot_title = "Dominant Frequencies Comparison" if title is None else title
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Wavelet Coefficient Magnitude")
    plt.title(plot_title)
    plt.legend()
    plt.tight_layout()
    plt.show()

    return dominant_freqs_list

def compare_midi_numbers(midi_lists):
    """
    Split several lists of MIDI numbers into the shared part and the per-list unique parts.

    Parameters:
        midi_lists (list): List of lists, each containing MIDI numbers (integers).

    Returns:
        result (list): [overlapping_numbers, leftover1, leftover2, ...], where:
                            - overlapping_numbers: sorted MIDI numbers present in every list.
                            - leftoveri: sorted MIDI numbers that occur in list i and in no other list.
                       An empty list is returned when midi_lists is empty or None.
    """
    if not midi_lists or len(midi_lists) == 0:
        return []

    # Work on sets so duplicates inside one list do not matter
    note_sets = [set(numbers) for numbers in midi_lists]

    shared = set.intersection(*note_sets) if note_sets else set()
    result = [sorted(shared)]

    for position, own_notes in enumerate(note_sets):
        # Everything that appears in any list other than the current one
        elsewhere = set().union(
            *(other for k, other in enumerate(note_sets) if k != position)
        )
        result.append(sorted(own_notes - elsewhere))

    return result

def analyze_midi_distributions_with_lines(midi_files, line_groups, labels=None, mode='overlay', linestyle='dotted', linewidth=2):
    """
    Analyze and visualize MIDI note distributions with reference lines for dominant frequencies.

    The pitches of all notes on all instruments of each MIDI object are counted
    and drawn as a bar chart. Vertical lines mark the MIDI numbers given in
    ``line_groups``: numbers present in every group are drawn in green, numbers
    unique to one group are drawn in the bar color with the same index.

    Parameters:
        midi_files (list): List of PrettyMIDI objects (only
            ``.instruments[*].notes[*].pitch`` is read).
        line_groups (list): List of MIDI number lists for comparison. May be
            empty, in which case no reference lines are drawn.
        labels (list): Labels for each file (default: None, uses ["MIDI 1", "MIDI 2", ...]).
            Files beyond the end of ``labels`` are not processed.
        mode (str): Visualization mode ('individual' or 'overlay') (default: 'overlay').
        linestyle (str): Style of vertical lines (default: 'dotted').
        linewidth (float): Width of vertical lines (default: 2).

    Returns:
        results (dict): Maps the index of each successfully processed file to
            (note_counts, mean_note). note_counts is a Counter of pitches;
            mean_note is the mean of the distinct pitches (each pitch counted
            once, however often it occurs), or 0 when the file has no notes.
    """
    from collections import Counter
    import numpy as np
    import matplotlib.pyplot as plt

    # Set default labels if not provided
    if labels is None:
        labels = [f"MIDI {i+1}" for i in range(len(midi_files))]

    # Colors for bars and leftover lines; cycled when there are more
    # datasets/groups than colors instead of raising IndexError.
    bar_colors = ['blue', 'orange', 'purple', 'pink', 'darkgreen', 'brown']

    def color_for(idx):
        return bar_colors[idx % len(bar_colors)]

    def group_label(idx):
        # line_groups may be longer than labels; fall back to a generic name.
        return labels[idx] if idx < len(labels) else f"Group {idx+1}"

    # Dictionary to store results
    results = {}

    # Collect every pitch of every instrument, one list per file
    all_notes = [
        [note.pitch for instrument in midi_file.instruments for note in instrument.notes]
        for midi_file in midi_files
    ]

    # Compute overlapping and leftover numbers.
    # compare_midi_numbers returns [] for empty input, so guard the indexing.
    sorted_line_grps = compare_midi_numbers(line_groups)
    if sorted_line_grps:
        overlapping_notes = sorted_line_grps[0]
        leftover_notes = sorted_line_grps[1:]  # [leftover1, leftover2, ...]
    else:
        overlapping_notes, leftover_notes = [], []

    def draw_reference_lines(label_leftovers):
        # Leftover notes in the color of their group
        for group_idx, leftover_group in enumerate(leftover_notes):
            for x in leftover_group:
                plt.axvline(x, color=color_for(group_idx), linestyle=linestyle, linewidth=linewidth,
                            label=f'{group_label(group_idx)} Dom Freq={x}' if label_leftovers else None)
        # Overlapping notes (green)
        for x in overlapping_notes:
            plt.axvline(x, color='green', linestyle=linestyle, linewidth=linewidth-1,
                        label=f'Overlap={x}', alpha=0.4)

    # If overlay mode, prepare a single figure
    if mode == 'overlay':
        plt.figure(figsize=(12, 6))

    # Process each file
    for i, (notes, label) in enumerate(zip(all_notes, labels)):
        try:
            # Count note occurrences
            note_counts = Counter(notes)

            # Sort the notes for consistent plotting
            sorted_notes = sorted(note_counts.items())
            keys, values = zip(*sorted_notes) if sorted_notes else ([], [])

            # Mean over the distinct pitches (not weighted by occurrence count)
            mean_note = np.mean(list(note_counts.keys())) if note_counts else 0

            # Store results
            results[i] = (note_counts, mean_note)

            # Visualization
            if mode == 'individual':
                # Create a separate plot for each file
                plt.figure(figsize=(10, 5))
                plt.bar(keys, values, color=color_for(i), alpha=0.7)
                # Mean line (red dotted)
                plt.axvline(mean_note, color='red', linestyle='dotted', linewidth=2,
                            label=f'Mean={mean_note:.2f}')
                # Leftover lines get legend entries only on the first figure
                draw_reference_lines(label_leftovers=(i == 0))
                plt.xlabel('MIDI Note Number')
                plt.ylabel('Frequency')
                plt.title(f"Histogram of {label} MIDI Note Frequencies")
                plt.legend()
                plt.show()
            elif mode == 'overlay':
                # Add to the overlay plot
                plt.bar(keys, values, alpha=0.5, color=color_for(i), label=f"{label} (Mean: {mean_note:.2f})")

        except Exception as e:
            # Best effort: report the failing file and continue with the others
            print(f"Error processing MIDI {i+1}: {e}")

    # Add lines and show the overlay plot if needed
    if mode == 'overlay' and results:
        # Mean lines only for files that were actually processed, so a failed
        # or unlabeled file cannot trigger a KeyError here.
        for i, (_, mean_note) in results.items():
            plt.axvline(mean_note, color='red', linestyle='dotted', linewidth=2,
                        label=f'{labels[i]} Mean={mean_note:.2f}')
        draw_reference_lines(label_leftovers=True)
        plt.xlabel('MIDI Note Number')
        plt.ylabel('Frequency')
        plt.title("Distribution of MIDI Notes with Dominant Frequencies")
        plt.legend()
        plt.show()

    return results

#### periodicity vs wavelet entropy ####
def periodicity_wavelet_entropy(data, wav=False):
    """
    Calculate periodicity and wavelet entropy from MIDI or WAV data.

    For MIDI data (wav=False), data should contain, in order:
    - time_series: Time values of the MIDI
    - velocity_series: MIDI note velocities over time
    - sample_rate: Sampling rate of the MIDI
    - phase: Phase values
    - amplitude: Amplitude values
    - frequencies: Extracted frequency spectrum

    For WAV data (wav=True), data should contain, in order:
    - audio_data: Raw audio waveform
    - sample_rate: Sampling rate of the audio
    - phase: Phase values
    - amplitude: Amplitude values
    - frequencies: Extracted frequency spectrum

    Only the signal (velocity_series or audio_data) and amplitude are used.

    Parameters:
        data (tuple): Data from get_amplitude_phase (WAV or MIDI format).
        wav (bool): Whether data is from WAV files (True) or MIDI files (False) (default: False).

    Returns:
        tuple: (periodicity_score, wavelet_entropy) where:
            - periodicity_score: Peak of the non-negative-lag autocorrelation
              divided by its sum; 0.0 for an all-zero signal.
            - wavelet_entropy: Shannon entropy (natural log) of the squared
              amplitudes normalized along axis 0, averaged over the remaining axis.
              Zero-energy entries contribute 0 to the entropy.
    """
    if wav:
        # WAV layout: the signal is the raw audio waveform
        signal, _sample_rate, _phase, amplitude, _frequencies = data
    else:
        # MIDI layout: the signal is the velocity series
        _time_series, signal, _sample_rate, _phase, amplitude, _frequencies = data

    # ---- Periodicity via Autocorrelation ----
    # Work in float so integer samples (e.g. int16 audio) cannot overflow
    # inside np.correlate.
    signal = np.asarray(signal, dtype=float)
    autocorr = np.correlate(signal, signal, mode='full')
    autocorr = autocorr[len(autocorr)//2:]  # Take non-negative lags only
    autocorr_total = np.sum(autocorr)
    # The total is zero only for an all-zero signal; avoid dividing by it.
    periodicity_score = np.max(autocorr) / autocorr_total if autocorr_total != 0 else 0.0

    # ---- Wavelet Entropy ----
    energy_distribution = np.asarray(amplitude) ** 2
    energy_total = np.sum(energy_distribution, axis=0)

    # np.where evaluates both branches, so silence the warnings from the
    # entries that are discarded anyway.
    with np.errstate(divide='ignore', invalid='ignore'):
        normalized_energy = np.where(energy_total == 0, 0, energy_distribution / energy_total)
        # 0 * log(0) is defined as 0 for Shannon entropy; computing it
        # directly would give NaN and poison the whole sum.
        entropy_terms = np.where(normalized_energy == 0, 0.0,
                                 normalized_energy * np.log(normalized_energy))

    entropy = -np.sum(entropy_terms, axis=0)  # Shannon Entropy
    wavelet_entropy = np.mean(entropy)  # Average over time

    return periodicity_score, wavelet_entropy

def plot_entropy_vs_periodicity(entropy_list, periodicity_list, artists, title="Wavelet Entropy vs Periodicity"):
    """
    Plot wavelet entropy versus periodicity for multiple artists.

    Each data point is drawn as a scatter marker colored by artist and
    annotated with the artist's name.

    Parameters:
        entropy_list (list): List of wavelet entropy values (x-axis).
        periodicity_list (list): List of periodicity scores (y-axis).
        artists (list): List of artist names corresponding to each data point.
        title (str): Plot title (default: "Wavelet Entropy vs Periodicity").

    Returns:
        None: Displays the plot.

    Raises:
        ValueError: If the three lists do not have the same length.
    """
    # Ensure the lists are of the same length
    if len(entropy_list) != len(periodicity_list) or len(entropy_list) != len(artists):
        raise ValueError("The lists and the artists must have the same length.")

    # De-duplicate while keeping first-appearance order, so every artist gets
    # the same color and legend position on every run (set() order is not
    # stable across interpreter sessions for strings).
    unique_artists = list(dict.fromkeys(artists))

    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap accepts
    # the same (name, lut) arguments.
    colors = plt.get_cmap('viridis', len(unique_artists))

    # Create the plot
    plt.figure(figsize=(8, 6))

    # Loop over the unique artists and plot each one with a unique color
    for idx, artist in enumerate(unique_artists):
        # Get the indices for the current artist
        indices = [i for i, a in enumerate(artists) if a == artist]

        # Plot the data points for the current artist
        plt.scatter([entropy_list[i] for i in indices],
                    [periodicity_list[i] for i in indices],
                    color=colors(idx), label=artist)

        # Add labels for each point
        for i in indices:
            plt.text(entropy_list[i], periodicity_list[i], artist, fontsize=10, ha='center', va='top')

    # Labels and title
    plt.xlabel('Wavelet Entropy')
    plt.ylabel('Periodicity')
    plt.title(title)  # Use the custom title parameter

    # Add legend and show the plot
    plt.legend(title='Artists')
    plt.show()

def shapiro_wilk_test(coeff, n_values=None, step=10, plot=True, titles=None):
    """
    Perform Shapiro-Wilk test on wavelet coefficients to assess normality.

    For every entry of ``coeff`` the test is run on (a) the first ``n`` values
    of every ``step``-th coefficient array and (b) the first ``n`` values of
    all coefficient arrays concatenated, for each ``n`` in ``n_values``.
    When ``n`` exceeds the available length, all available values are used.

    Note: SciPy documents that for more than 5000 samples the p-value may not
    be accurate.

    Parameters:
        coeff (list): List of coefficient lists; each inner list holds the
            coefficient arrays of one wavelet transformation.
        n_values (list): Sample sizes to test (default: None, uses [10, 50, 100, 500, 800, 5000, 10000]).
        step (int): Step size for sampling coefficients (default: 10).
        plot (bool): Whether to generate plots (default: True). matplotlib is
            only imported when this is True.
        titles (list): Titles for each coefficient list (default: None, uses ["Coeff 1", "Coeff 2", ...]).

    Returns:
        all_results (dict): Keyed by title. Each value holds 'indices' (the
            sampled positions), and 'individual' and 'flattened' dicts with
            'statistics' and 'p_values' keyed by n.

    Raises:
        ValueError: If ``titles`` and ``coeff`` differ in length.
    """
    import numpy as np
    from scipy import stats

    # Set default n_values if not provided
    if n_values is None:
        n_values = [10, 50, 100, 500, 800, 5000, 10000]

    # Set default titles if not provided
    if titles is None:
        titles = [f"Coeff {i+1}" for i in range(len(coeff))]

    # Ensure titles match the length of coeff
    if len(titles) != len(coeff):
        raise ValueError("Length of 'titles' must match the number of coefficient lists in 'coeff'")

    # Initialize results dictionary for all coefficient lists
    all_results = {}

    # Number of coefficient lists
    num_coeffs = len(coeff)

    # Prepare for plotting
    if plot:
        # Imported lazily so plot=False works without a plotting backend
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, num_coeffs, figsize=(5 * num_coeffs, 10), sharex=True)
        if num_coeffs == 1:  # subplots returns a 1-D array for a single column
            axes = np.array([[axes[0]], [axes[1]]])

    # Loop through each coefficient list
    for i, coeff_list in enumerate(coeff):
        # Flatten all coefficients into one array
        flattened_coeffs = np.concatenate(coeff_list)

        # Positions of the coefficient arrays that get tested individually
        indices = range(0, len(coeff_list), step)

        # Per-n result containers
        stats_results = {n: [] for n in n_values}
        p_results = {n: [] for n in n_values}
        stats_flat = {n: None for n in n_values}
        p_flat = {n: None for n in n_values}

        # Perform Shapiro-Wilk test for each sampled coefficient array
        for idx in indices:
            sub_coeff_list = coeff_list[idx]
            for n in n_values:
                # Ensure n doesn't exceed the length of the coefficient array
                n_test = min(n, len(sub_coeff_list))
                stat, p = stats.shapiro(sub_coeff_list[:n_test])
                stats_results[n].append(stat)
                p_results[n].append(p)

        # Perform Shapiro-Wilk test for flattened data
        for n in n_values:
            n_test = min(n, len(flattened_coeffs))
            stat, p = stats.shapiro(flattened_coeffs[:n_test])
            stats_flat[n] = stat
            p_flat[n] = p

        # Store results for current coefficient list
        all_results[titles[i]] = {
            'indices': indices,
            'individual': {
                'statistics': stats_results,
                'p_values': p_results
            },
            'flattened': {
                'statistics': stats_flat,
                'p_values': p_flat
            }
        }

        # Generate plots if requested
        if plot:
            ax1, ax2 = axes[0, i], axes[1, i]

            # Plot Shapiro-Wilk statistic
            for n in n_values:
                ax1.plot(indices, stats_results[n], marker='o', label=f'n={n} (Lists)')
                ax1.axhline(y=stats_flat[n], linestyle='--', label=f'n={n} (Flattened)', alpha=0.7)

            ax1.set_title(titles[i])
            ax1.set_ylabel("Test Statistic")
            ax1.legend()
            ax1.grid(True, linestyle='--', alpha=0.7)

            # Plot p-values
            for n in n_values:
                ax2.plot(indices, p_results[n], marker='o', label=f'n={n} (Lists)')
                ax2.axhline(y=p_flat[n], linestyle='--', label=f'n={n} (Flattened)', alpha=0.7)

            ax2.set_xlabel("Coefficient List Index")
            ax2.set_ylabel("p-value")
            ax2.axhline(y=0.05, color='r', linestyle='--', label='p=0.05 threshold')
            ax2.legend()
            ax2.grid(True, linestyle='--', alpha=0.7)

    if plot:
        plt.tight_layout()
        plt.show()

    # Print summary for every coefficient list
    example_idx = 0
    for current_title in titles:
        entry = all_results[current_title]

        print(f"\nShapiro-Wilk Test Results for Flattened Coefficients - {current_title}:")
        for n in n_values:
            stat = entry['flattened']['statistics'][n]
            p = entry['flattened']['p_values'][n]
            print(f"n={n}: Statistic={stat:.4f}, p-value={p:.4f}, "
                  f"{'Normal' if p > 0.05 else 'Not Normal'}")

        # Print summary for a specific index as an example.
        # Use the indices stored for this title, not the variable left over
        # from the last iteration of the loop above.
        print(f"\nExample results for index {entry['indices'][example_idx]} (individual list):")
        for n in n_values:
            stat = entry['individual']['statistics'][n][example_idx]
            p = entry['individual']['p_values'][n][example_idx]
            print(f"n={n}: Statistic={stat:.4f}, p-value={p:.4f}, "
                  f"{'Normal' if p > 0.05 else 'Not Normal'}")

    return all_results

def print_periodicity_entropy(morlet_list, header=None, wav=True, artists=None):
    """
    Calculate and print periodicity and wavelet entropy for multiple datasets.

    Parameters:
        morlet_list (list): List of Morlet wavelet data tuples from get_amplitude_phase.
        header (str): Optional header printed before the results; nothing is
            printed for it when None (default: None).
        wav (bool): Whether data is from WAV files (True) or MIDI files (False) (default: True).
        artists (list): List of artist names (default: None, uses ["Artist 1", "Artist 2", ...]).

    Returns:
        results (dict): Maps each artist name to a dict with the keys
            'Periodicity' and 'Wavelet Entropy'.

    Raises:
        ValueError: If ``artists`` and ``morlet_list`` differ in length.
    """
    # Use default artist names if none provided
    if artists is None:
        artists = [f"Artist {i+1}" for i in range(len(morlet_list))]

    # Check if the number of artist names matches the number of morlet lists
    if len(artists) != len(morlet_list):
        raise ValueError("Length of 'artists' must match the number of Morlet wavelet lists in 'morlet_list'")

    # Initialize dictionary to store results
    results = {}

    # Print the header only when one was given (avoids printing "None")
    if header is not None:
        print(header)

    # Loop through each artist and calculate periodicity and entropy
    for artist, morlet_data in zip(artists, morlet_list):
        periodicity, entropy = periodicity_wavelet_entropy(morlet_data, wav=wav)

        # Save results in dictionary
        results[artist] = {
            'Periodicity': periodicity,
            'Wavelet Entropy': entropy
        }

        # Print results for the current artist
        print(f"\n{artist}:")
        print(f"Periodicity: {periodicity:.7f}")
        print(f"Wavelet Entropy: {entropy:.3f}")

    return results


# pipeline

## initialize midi files

In [None]:
# Load every MIDI file used by the pipeline as a pretty_midi.PrettyMIDI object.
# Paths are relative to the notebook's working directory.

# moanin (individual)
moanin_lee_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_lee.mid')
moanin_freddie_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_freddie.mid')
moanin_art_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_art.mid')
moanin_roy_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_roy.mid')
moanin_ter_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_ter.mid')
moanin_head_midi = pretty_midi.PrettyMIDI('data/moanin/midi/moanin_head.mid')

# i remember clifford
irmb_lee_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_lee.mid')
irmb_fred_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_fred.mid')
irmb_chet_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_chet.mid')
irmb_art_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_art.mid')
irmb_roy_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_roy.mid')
irmb_ter_midi = pretty_midi.PrettyMIDI('data/irmbclifford/midi/irmb_ter.mid')

### songs not used in final writeup, but for previous iterations of capstone ###
# blue train
bluetrain_lee_midi = pretty_midi.PrettyMIDI('data/bluetrain/midi/bluetrain_lee.mid')
bluetrain_chet_midi = pretty_midi.PrettyMIDI('data/bluetrain/midi/bluetrain_chet.mid')

# moanin' combined
lee_midi = pretty_midi.PrettyMIDI('data/moanin/midi/lee.mid')
others_midi = pretty_midi.PrettyMIDI('data/moanin/midi/others.mid')

# canon in d
hiromi_midi = pretty_midi.PrettyMIDI('data/canoninD/midi/hiromi.mid')
# NOTE: the variable is spelled `pachabel_midi` although the file is `pachelbel.mid`
pachabel_midi = pretty_midi.PrettyMIDI('data/canoninD/midi/pachelbel.mid')

## distributions

### moanin'

In [None]:
# Note distributions for all six Moanin' MIDI objects (five artists plus "Head"),
# requested with mode='overlay'.
# NOTE(review): analyze_midi_distributions is defined outside this excerpt.
moaning_distributions = analyze_midi_distributions([moanin_lee_midi, moanin_freddie_midi, 
                                                    moanin_art_midi, moanin_roy_midi, 
                                                    moanin_ter_midi, moanin_head_midi], 
                                                    ["Lee", "Freddie", "Art", "Roy", "Terence", "Head"], 
                                                    mode='overlay')

#### individual plots of artist vs head

In [None]:
# One call per artist: the artist's MIDI paired with the head MIDI, mode='overlay'.
# The five results are collected into a list in a later cell.
moaning_distributions_lee = analyze_midi_distributions([moanin_lee_midi, moanin_head_midi], 
                                                       ["Lee","Head"], mode='overlay')
moaning_distributions_fred = analyze_midi_distributions([moanin_freddie_midi, moanin_head_midi], 
                                                        ["Freddie", "Head"], mode='overlay')
moaning_distributions_art = analyze_midi_distributions([moanin_art_midi, moanin_head_midi], 
                                                       ["Art","Head"], mode='overlay')
moaning_distributions_roy = analyze_midi_distributions([moanin_roy_midi, moanin_head_midi], 
                                                       ["Roy", "Head"], mode='overlay')
moaning_distributions_ter = analyze_midi_distributions([moanin_ter_midi, moanin_head_midi], 
                                                       ["Terence", "Head"], mode='overlay')

#### f-test and KL divergence

In [None]:
# Gather the artist-vs-head results (order: Lee, Freddie, Art, Roy, Terence)
# and pass them to compare_variances_to_head.
# NOTE(review): compare_variances_to_head is defined outside this excerpt.
moaning_distributions_artists = [moaning_distributions_lee, 
                                moaning_distributions_fred, 
                                moaning_distributions_art, 
                                moaning_distributions_roy, 
                                moaning_distributions_ter]

compare_variances_to_head(moaning_distributions_artists)

In [None]:
# Divergence analysis on the six-way Moanin' result.
# NOTE(review): moaning_distributions was built from six MIDI objects (including
# "Head") while artist_names has five entries — confirm that
# analyze_and_visualize_divergence (defined outside this excerpt) expects this.
artist_names = ['Lee', 'Freddie', 'Art', 'Roy', 'Terence']
analyze_and_visualize_divergence(moaning_distributions, artist_names, title="Moanin'")

## networks

### moanin'

#### all artists in one plot

In [None]:
# (artist name, PrettyMIDI object) pairs for the Moanin' network analysis
midi_files = [
    ('lee', moanin_lee_midi),
    ('freddie', moanin_freddie_midi),
    ('art', moanin_art_midi),
    ('roy', moanin_roy_midi),
    ('terence', moanin_ter_midi)
]

# Parallel lists: one transition-probability result and one plot label per artist
all_transition_probs = []
all_labels = []

for artist, midi_file in midi_files:
    # Only the first of the three returned values is used below
    transition_probs, iois_probs, iois = get_prob_transitions(midi_file)
    
    all_transition_probs.append(transition_probs)
    all_labels.append(f"Moanin - {artist.capitalize()}")

# all pitch transitions
# `positions` is kept and passed as fixed_positions to the later network plots
G, positions = visualize_all_pitch_trans(
                                    all_transition_probs,
                                    plot_type='overlay',  
                                    labels=all_labels,
                                    title="'Moanin' Pitch Transitions"
                                )

In [None]:
# strongest pitch transitions
# Overlay plot, reusing the node positions returned by the all-transitions plot
visualize_strongest_trans(
    all_transition_probs,
    plot_type='overlay',
    labels=all_labels,
    title="Moanin' Strongest Pitch Transitions",
    fixed_positions=positions
    )

#### each individual artist plot

In [None]:
# all pitch transitions
# plot_type='individual', reusing the positions from the overlay plot
visualize_all_pitch_trans(
    all_transition_probs,
    plot_type='individual',  
    labels=all_labels,
    title="Moanin' Pitch Transitions",
    fixed_positions=positions
)

In [None]:
# strongest pitch transitions
# plot_type='individual', reusing the positions from the overlay plot
visualize_strongest_trans(
    all_transition_probs,
    plot_type='individual',
    labels=all_labels,
    title="Moanin' Strongest Pitch Transitions",
    fixed_positions=positions
    )

#### heatmap

In [None]:
# Run simulations for all artists
# midi_files is the list of (artist name, PrettyMIDI) pairs built above.
# NOTE(review): run_all_simulations is defined outside this excerpt.
all_theoretical_dists, all_simulated_dists, all_matrices, all_states_lists = run_all_simulations(
    midi_files,
    num_simulations=100,  # Number of simulations per artist
    num_steps=10000,       # Steps per simulation
    show_individual_plots=False,  # Individual artist plots are switched off here
    show_comparison=True  # Show comparison heatmap plot of all artists
)


#### interval types

In [None]:
###### for interval types ######
# get all transition probabilities
# (read from the same Moanin' .mid paths that were loaded as PrettyMIDI objects earlier)
moanin_lee_trans = extract_transitions_from_midi('data/moanin/midi/moanin_lee.mid')
moanin_fred_trans = extract_transitions_from_midi('data/moanin/midi/moanin_freddie.mid')
moanin_art_trans = extract_transitions_from_midi('data/moanin/midi/moanin_art.mid')
moanin_roy_trans = extract_transitions_from_midi('data/moanin/midi/moanin_roy.mid')
moanin_ter_trans = extract_transitions_from_midi('data/moanin/midi/moanin_ter.mid')

# and all edgecounts
# (one count_interval_edges result per artist, computed from the transitions above)
moanin_lee_edgecount = count_interval_edges(moanin_lee_trans)
moanin_fred_edgecount = count_interval_edges(moanin_fred_trans)
moanin_art_edgecount = count_interval_edges(moanin_art_trans)
moanin_roy_edgecount = count_interval_edges(moanin_roy_trans)
moanin_ter_edgecount = count_interval_edges(moanin_ter_trans)

In [None]:
# visualize for total, strongest only, and each interval type 
# plot for each artist
# NOTE(review): `labels` is not assigned in the visible pipeline cells —
# confirm it is defined earlier in the notebook.
moanin_lee_nx1 = visualize_pitch_trans_mod(moanin_lee_trans, title="Moanin — Lee", plot_type='overlay', labels=labels)
moanin_lee_nx2 = visualize_strongest_trans_mod(moanin_lee_trans, title="Moanin — Lee", plot_type='overlay', labels=labels)
moanin_lee_nx3 = visualize_pitch_trans_mod(moanin_lee_trans, title="Moanin — Lee", plot_type='individual', labels=labels)

In [None]:
# Freddie: all transitions (overlay), strongest only (overlay), all transitions (individual)
# NOTE(review): `labels` is not assigned in the visible pipeline cells — confirm it exists.
moanin_fred_nx1 = visualize_pitch_trans_mod(moanin_fred_trans, title="Moanin — Freddie", plot_type='overlay', labels=labels)
moanin_fred_nx2 = visualize_strongest_trans_mod(moanin_fred_trans, title="Moanin — Freddie", plot_type='overlay', labels=labels)
moanin_fred_nx3 = visualize_pitch_trans_mod(moanin_fred_trans, title="Moanin — Freddie", plot_type='individual', labels=labels)

In [None]:
# Art: all transitions (overlay), strongest only (overlay), all transitions (individual)
# NOTE(review): `labels` is not assigned in the visible pipeline cells — confirm it exists.
moanin_art_nx1 = visualize_pitch_trans_mod(moanin_art_trans, title="Moanin — Art", plot_type='overlay', labels=labels)
moanin_art_nx2 = visualize_strongest_trans_mod(moanin_art_trans, title="Moanin — Art", plot_type='overlay', labels=labels)
moanin_art_nx3 = visualize_pitch_trans_mod(moanin_art_trans, title="Moanin — Art", plot_type='individual', labels=labels)

In [None]:
# Roy: all transitions (overlay), strongest only (overlay), all transitions (individual)
# NOTE(review): `labels` is not assigned in the visible pipeline cells — confirm it exists.
moanin_roy_nx1 = visualize_pitch_trans_mod(moanin_roy_trans, title="Moanin — Roy", plot_type='overlay', labels=labels)
moanin_roy_nx2 = visualize_strongest_trans_mod(moanin_roy_trans, title="Moanin — Roy", plot_type='overlay', labels=labels)
moanin_roy_nx3 = visualize_pitch_trans_mod(moanin_roy_trans, title="Moanin — Roy", plot_type='individual', labels=labels)

In [None]:
# Terence: all transitions (overlay), strongest only (overlay), all transitions (individual)
# NOTE(review): `labels` is not assigned in the visible pipeline cells — confirm it exists.
moanin_ter_nx1 = visualize_pitch_trans_mod(moanin_ter_trans, title="Moanin — Terence", plot_type='overlay', labels=labels)
moanin_ter_nx2 = visualize_strongest_trans_mod(moanin_ter_trans, title="Moanin — Terence", plot_type='overlay', labels=labels)
moanin_ter_nx3 = visualize_pitch_trans_mod(moanin_ter_trans, title="Moanin — Terence", plot_type='individual', labels=labels)

In [None]:
# edge counts
# Print the count_interval_edges result of each Moanin' artist under its name
print("Edge counts for each interval type")
print("Lee:")
print(moanin_lee_edgecount)
print("Freddie:")
print(moanin_fred_edgecount)
print("Art:")
print(moanin_art_edgecount)
print("Roy:")
print(moanin_roy_edgecount)
print("Terence:")
print(moanin_ter_edgecount)

In [None]:
# plot bar chart of edge counts for each interval type
# and standard deviation bar chart
# edge_counts and artists are parallel lists (same artist order in both).
# NOTE(review): plot_interval_usage is defined outside this excerpt.
edge_counts = [
    moanin_lee_edgecount,
    moanin_fred_edgecount,
    moanin_art_edgecount,
    moanin_roy_edgecount,
    moanin_ter_edgecount
]
artists = ['Lee', 'Freddie', 'Art', 'Roy', 'Terence']

plot_interval_usage(edge_counts, artists)

### i remember clifford

#### all artists in one plot

In [None]:
# (artist name, PrettyMIDI object) pairs for I Remember Clifford.
# This rebinds midi_files, all_transition_probs, all_labels, G and positions,
# replacing the values left over from the Moanin' section.
midi_files = [
    ('lee', irmb_lee_midi),
    ('freddie', irmb_fred_midi),
    ('chet', irmb_chet_midi),
    ('art', irmb_art_midi),
    ('roy', irmb_roy_midi),
    ('terence', irmb_ter_midi)
]

# Parallel lists: one transition-probability result and one plot label per artist
all_transition_probs = []
all_labels = []

for artist, midi_file in midi_files:
    # Only the first of the three returned values is used below
    transition_probs, iois_probs, iois = get_prob_transitions(midi_file)
    
    all_transition_probs.append(transition_probs)
    all_labels.append(f"I Remember Clifford - {artist.capitalize()}")


# all pitch transitions
# `positions` is kept and passed as fixed_positions to the later network plots
G, positions = visualize_all_pitch_trans(
                                    all_transition_probs,
                                    plot_type='overlay',  
                                    labels=all_labels,
                                    title="I Remember Clifford Pitch Transitions"
                                )

In [None]:
# strongest pitch transitions
# Overlay plot, reusing the node positions returned by the all-transitions plot.
# The title says "Strongest" so this figure is distinguishable from the
# all-transitions figure, matching the corresponding Moanin' cell.
visualize_strongest_trans(
    all_transition_probs,
    plot_type='overlay',
    labels=all_labels,
    title="I Remember Clifford Strongest Pitch Transitions",
    fixed_positions=positions
    )

#### each individual artist plot


In [None]:
# all pitch transitions
# plot_type='individual', reusing the positions from the overlay plot
visualize_all_pitch_trans(
    all_transition_probs,
    plot_type='individual',  
    labels=all_labels,
    title="I Remember Clifford Pitch Transitions",
    fixed_positions=positions
)

In [None]:
# strongest pitch transitions with plot_type='individual', using the same
# labels, title and fixed_positions as the other cells in this section
visualize_strongest_trans(
    all_transition_probs,
    plot_type='individual',
    labels=all_labels,
    title="I Remember Clifford Pitch Transitions",
    fixed_positions=positions
    )

#### heatmap

In [None]:
# Run simulations for all artists in midi_files (the "I Remember Clifford"
# list built above) and keep the four returned collections
all_theoretical_dists, all_simulated_dists, all_matrices, all_states_lists = run_all_simulations(
    midi_files,
    num_simulations=100,  # Number of simulations per artist
    num_steps=10000,       # Steps per simulation
    show_individual_plots=False,  # individual artist plots are switched off here
    show_comparison=True  # Show comparison heatmap plot of all artists
)

#### interval types

In [None]:
###### for interval types ######
# transitions for every artist, extracted in the order lee, fred, chet,
# art, roy, ter and unpacked into the per-artist names
(irmb_lee_trans, irmb_fred_trans, irmb_chet_trans,
 irmb_art_trans, irmb_roy_trans, irmb_ter_trans) = (
    extract_transitions_from_midi(midi_path)
    for midi_path in (
        'data/irmbclifford/midi/irmb_lee.mid',
        'data/irmbclifford/midi/irmb_fred.mid',
        'data/irmbclifford/midi/irmb_chet.mid',
        'data/irmbclifford/midi/irmb_art.mid',
        'data/irmbclifford/midi/irmb_roy.mid',
        'data/irmbclifford/midi/irmb_ter.mid',
    )
)

# edge counts computed from those transitions, same artist order
(irmb_lee_edgecount, irmb_fred_edgecount, irmb_chet_edgecount,
 irmb_art_edgecount, irmb_roy_edgecount, irmb_ter_edgecount) = map(
    count_interval_edges,
    (irmb_lee_trans, irmb_fred_trans, irmb_chet_trans,
     irmb_art_trans, irmb_roy_trans, irmb_ter_trans),
)

In [None]:
# Lee: total, strongest only, and each interval type.
# The three calls run in table order and share every argument except the
# plotting function and plot_type.
irmb_lee_nx1, irmb_lee_nx2, irmb_lee_nx3 = (
    draw(irmb_lee_trans, title="I Remember Clifford — Lee", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# Freddie: all transitions (overlay), strongest transitions (overlay),
# then all transitions with plot_type='individual' -- executed in that order
irmb_fred_nx1, irmb_fred_nx2, irmb_fred_nx3 = (
    draw(irmb_fred_trans, title="I Remember Clifford — Freddie", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# Chet: all transitions (overlay), strongest transitions (overlay),
# then all transitions with plot_type='individual' -- executed in that order
irmb_chet_nx1, irmb_chet_nx2, irmb_chet_nx3 = (
    draw(irmb_chet_trans, title="I Remember Clifford — Chet", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# Art: all transitions (overlay), strongest transitions (overlay),
# then all transitions with plot_type='individual' -- executed in that order
irmb_art_nx1, irmb_art_nx2, irmb_art_nx3 = (
    draw(irmb_art_trans, title="I Remember Clifford — Art", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# Roy: all transitions (overlay), strongest transitions (overlay),
# then all transitions with plot_type='individual' -- executed in that order
irmb_roy_nx1, irmb_roy_nx2, irmb_roy_nx3 = (
    draw(irmb_roy_trans, title="I Remember Clifford — Roy", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# Terence: all transitions (overlay), strongest transitions (overlay),
# then all transitions with plot_type='individual' -- executed in that order
irmb_ter_nx1, irmb_ter_nx2, irmb_ter_nx3 = (
    draw(irmb_ter_trans, title="I Remember Clifford — Terence", plot_type=mode, labels=labels)
    for draw, mode in (
        (visualize_pitch_trans_mod, 'overlay'),
        (visualize_strongest_trans_mod, 'overlay'),
        (visualize_pitch_trans_mod, 'individual'),
    )
)

In [None]:
# edge counts: one newline-separated print call with the header first,
# then each artist's label followed by that artist's edge-count object
print(
    "Edge counts for each interval type",
    "Lee:", irmb_lee_edgecount,
    "Freddie:", irmb_fred_edgecount,
    "Chet:", irmb_chet_edgecount,
    "Art:", irmb_art_edgecount,
    "Roy:", irmb_roy_edgecount,
    "Terence:", irmb_ter_edgecount,
    sep="\n",
)

In [None]:
# bar chart of edge counts for each interval type (and the standard
# deviation bar chart); every artist label sits next to its edge counts
# and the pairs are split back into the two parallel lists
artists, edge_counts = map(list, zip(
    ('Lee', irmb_lee_edgecount),
    ('Freddie', irmb_fred_edgecount),
    ('Chet', irmb_chet_edgecount),
    ('Art', irmb_art_edgecount),
    ('Roy', irmb_roy_edgecount),
    ('Terence', irmb_ter_edgecount),
))

plot_interval_usage(edge_counts, artists)

## wavelet

In [None]:
#### convert mp3 to wav files ####
# Each call takes (source mp3 path, destination wav path). The wav paths
# written here are the ones read later by the prep_data_morlet(..., wav=True)
# cells, so this cell has to be run before them.
# moanin individual
mp3_to_wav('data/moanin/stemmed/moanin_lee.mp3', 'data/moanin/wav/moanin_lee.wav')
# NOTE: the source stem is "moanin_freddie" but the output is named "moanin_fred"
mp3_to_wav('data/moanin/stemmed/moanin_freddie.mp3', 'data/moanin/wav/moanin_fred.wav')
mp3_to_wav('data/moanin/stemmed/moanin_art.mp3', 'data/moanin/wav/moanin_art.wav')
mp3_to_wav('data/moanin/stemmed/moanin_roy.mp3', 'data/moanin/wav/moanin_roy.wav')
mp3_to_wav('data/moanin/stemmed/moanin_ter.mp3', 'data/moanin/wav/moanin_ter.wav')

# i remember clifford
mp3_to_wav('data/irmbclifford/stemmed/irmb_lee.mp3', 'data/irmbclifford/wav/irmb_lee.wav')
mp3_to_wav('data/irmbclifford/stemmed/irmb_fred.mp3', 'data/irmbclifford/wav/irmb_fred.wav')
mp3_to_wav('data/irmbclifford/stemmed/irmb_chet.mp3', 'data/irmbclifford/wav/irmb_chet.wav')
mp3_to_wav('data/irmbclifford/stemmed/irmb_art.mp3', 'data/irmbclifford/wav/irmb_art.wav')
mp3_to_wav('data/irmbclifford/stemmed/irmb_roy.mp3', 'data/irmbclifford/wav/irmb_roy.wav')
mp3_to_wav('data/irmbclifford/stemmed/irmb_ter.mp3', 'data/irmbclifford/wav/irmb_ter.wav')

### songs not used in the final writeup; kept from previous iterations of the capstone ###
# blue train
mp3_to_wav('data/bluetrain/stemmed/bluetrain_lee.mp3', 'data/bluetrain/wav/bluetrain_lee.wav')
mp3_to_wav('data/bluetrain/stemmed/bluetrain_chet.mp3', 'data/bluetrain/wav/bluetrain_chet.wav')

# canon in d
mp3_to_wav('data/canoninD/stemmed/hiromi.mp3', 'data/canoninD/wav/hiromi.wav')
mp3_to_wav('data/canoninD/stemmed/pachelbel.mp3', 'data/canoninD/wav/pachelbel.wav')


### moanin'

#### prep data (midi)
Included to show the differences in the plots between MIDI and WAV input. After this section, only WAV data is used, as it was more accurate.

In [None]:
# get the coefficients from the MIDI files, one prep_data_morlet call per
# artist, in the order lee, freddie, art, roy, ter
lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff = (
    prep_data_morlet(midi_path)
    for midi_path in (
        'data/moanin/midi/moanin_lee.mid',
        'data/moanin/midi/moanin_freddie.mid',
        'data/moanin/midi/moanin_art.mid',
        'data/moanin/midi/moanin_roy.mid',
        'data/moanin/midi/moanin_ter.mid',
    )
)

#### experiment: changing the resolution via the wavelet width ####
###### not included in the final writeup, because the resolution of the
###### final image did not really seem to change
# lee_coeff_hr = prep_data_morlet('data/moanin/midi/moanin_lee.mid', wavelet='cmor1.0-1.0')
# fred_coeff_hr = prep_data_morlet('data/moanin/midi/moanin_freddie.mid', wavelet='cmor1.0-1.0')
# art_coeff_hr = prep_data_morlet('data/moanin/midi/moanin_art.mid', wavelet='cmor1.0-1.0')

# lee_morlet_hr = get_amplitude_phase(lee_coeff_hr)
# fred_morlet_hr = get_amplitude_phase(fred_coeff_hr)
# art_morlet_hr = get_amplitude_phase(art_coeff_hr)

In [None]:
# get amplitude and phase for every artist; the list is built first and
# then unpacked so the per-artist names refer to the same objects
moanin_midi_list_morlet = [
    get_amplitude_phase(coeff)
    for coeff in (lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
]
lee_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet = moanin_midi_list_morlet

#### plots (midi)

In [None]:
# Morlet analysis of the MIDI-derived results, restricted to time_window=(0, 150)
# (presumably seconds -- confirm against morlet_analysis); titles follow list order
morlet_analysis(moanin_midi_list_morlet, time_window=(0, 150), titles=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'])

In [None]:
# amplitude/phase plots for the MIDI-derived results, no time_window passed
plot_amp_phase(moanin_midi_list_morlet, titles=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'])

In [None]:
# same plot restricted to time_window=(18, 22); the WAV section below uses
# the same (18, 22) window for one of its plots
plot_amp_phase(moanin_midi_list_morlet, titles=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'], time_window=(18, 22))

In [None]:
##### not used in writeup
# frequencies (index 4) and amplitudes (index 3) pulled out of each artist's
# get_amplitude_phase result, plus element 0 of each prep_data_morlet result
freqs_list, amp_list = (
    [morlet[idx] for morlet in (lee_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet)]
    for idx in (4, 3)
)
wavelet_coeff_list = [
    coeff[0] for coeff in (lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
]

# plot dominant frequencies
dominant_freqs_list = plot_dominant_frequencies(
    freqs_list,
    amp_list,
    title="Moanin' (midi)",
    labels=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'],
)


#### periodicity vs entropy (midi)

In [None]:
# periodicity/entropy summary for the MIDI-derived results; wav=False is passed
# explicitly here. The result is consumed in the next cell as a mapping whose
# values carry 'Wavelet Entropy' and 'Periodicity' entries.
moanin_period_entropy = print_periodicity_entropy(moanin_midi_list_morlet, header="Moanin' (midi)", artists=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'], wav=False)

In [None]:
# split the per-artist summary into two parallel metric lists (entropy
# first, then periodicity) and the matching artist keys
entropy_list, periodicity_list = (
    [stats[metric] for stats in moanin_period_entropy.values()]
    for metric in ('Wavelet Entropy', 'Periodicity')
)
artists = [*moanin_period_entropy.keys()]

plot_entropy_vs_periodicity(entropy_list, periodicity_list, artists, title="Moanin' (midi)")

#### prep data 

In [None]:
# get the coefficients from the WAV files (wav=True); this rebinds the
# *_coeff names that the MIDI section assigned earlier
lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff = (
    prep_data_morlet(wav_path, wav=True)
    for wav_path in (
        'data/moanin/wav/moanin_lee.wav',
        'data/moanin/wav/moanin_fred.wav',
        'data/moanin/wav/moanin_art.wav',
        'data/moanin/wav/moanin_roy.wav',
        'data/moanin/wav/moanin_ter.wav',
    )
)

In [None]:
# get amplitude and phase (wav=True) for each artist, in the same order
# as the coefficients were prepared
lee_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet = (
    get_amplitude_phase(coeff, wav=True)
    for coeff in (lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
)

In [None]:
# adjust who you want to see here: one (display name, result) pair per
# artist, split into the two parallel lists used by the plotting cells
artists_names, moanin_list_morlet = map(list, zip(
    ('Lee', lee_morlet),
    ('Freddie', fred_morlet),
    ('Art', art_morlet),
    ('Roy', roy_morlet),
    ('Terence', ter_morlet),
))

#### plots

In [None]:
# Morlet analysis of the WAV-derived results (wav=True), time_window=(0, 150);
# titles come from the artists_names list defined in the selection cell
morlet_analysis(moanin_list_morlet, wav=True, time_window=(0, 150), 
                titles=artists_names)

In [None]:
# amplitude/phase plots for the WAV-derived results, no time_window passed
plot_amp_phase(moanin_list_morlet, wav=True, 
               titles=artists_names)

In [None]:
#### time window of 36-40 seconds
# same amplitude/phase plot, restricted via time_window=(36, 40)
plot_amp_phase(moanin_list_morlet, wav=True, 
               titles=artists_names, 
               time_window=(36, 40))

In [None]:
#### time window of 18-22 seconds
# same window as the MIDI-based plot_amp_phase call with time_window=(18, 22)
plot_amp_phase(moanin_list_morlet, wav=True, 
               titles=artists_names, 
               time_window=(18,22))

In [None]:
##### not used in writeup
# frequencies (index 4) and amplitudes (index 3) pulled out of each artist's
# get_amplitude_phase result, plus element 0 of each prep_data_morlet result
# (wavelet_coeff_list is what the Shapiro-Wilk cell consumes)
freqs_list, amp_list = (
    [morlet[idx] for morlet in (lee_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet)]
    for idx in (4, 3)
)
wavelet_coeff_list = [
    coeff[0] for coeff in (lee_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
]

# plot dominant frequencies
dominant_freqs_list = plot_dominant_frequencies(
    freqs_list,
    amp_list,
    title="Moanin'",
    labels=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'],
)

In [None]:
# Shapiro-Wilk test on wavelet_coeff_list from the previous cell (run that cell
# first); the titles are hard-coded in the same artist order as that list
shapiro_wilk_test(wavelet_coeff_list, plot=True, titles=['Lee', 'Freddie', 'Art', 'Roy', 'Terence'])

#### periodicity vs entropy

In [None]:
# periodicity/entropy summary for the WAV-derived results; no wav argument is
# passed, unlike the MIDI call which passes wav=False. This rebinds
# moanin_period_entropy, replacing the MIDI-based result.
moanin_period_entropy = print_periodicity_entropy(moanin_list_morlet, 
                                                  header="Moanin'", 
                                                  artists=artists_names)

In [None]:
# split the per-artist summary into two parallel metric lists (entropy
# first, then periodicity) and the matching artist keys
entropy_list, periodicity_list = (
    [stats[metric] for stats in moanin_period_entropy.values()]
    for metric in ('Wavelet Entropy', 'Periodicity')
)
artists = [*moanin_period_entropy.keys()]

plot_entropy_vs_periodicity(entropy_list, periodicity_list, artists, title="Moanin'")

### i remember clifford

#### prep data

In [None]:
# get the coefficients from the WAV files (wav=True), in the order lee,
# chet, fred, art, roy, ter; this rebinds the *_coeff names used for Moanin'
lee_coeff, chet_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff = (
    prep_data_morlet(wav_path, wav=True)
    for wav_path in (
        'data/irmbclifford/wav/irmb_lee.wav',
        'data/irmbclifford/wav/irmb_chet.wav',
        'data/irmbclifford/wav/irmb_fred.wav',
        'data/irmbclifford/wav/irmb_art.wav',
        'data/irmbclifford/wav/irmb_roy.wav',
        'data/irmbclifford/wav/irmb_ter.wav',
    )
)

In [None]:
# get amplitude and phase (wav=True) for each artist, in the same order
# as the coefficients were prepared
lee_morlet, chet_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet = (
    get_amplitude_phase(coeff, wav=True)
    for coeff in (lee_coeff, chet_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
)

In [None]:
# adjust who you want to see here: one (display name, result) pair per
# artist, split into the two parallel lists used by the plotting cells
artists_names, irmb_list_morlet = map(list, zip(
    ('Lee', lee_morlet),
    ('Chet', chet_morlet),
    ('Freddie', fred_morlet),
    ('Art', art_morlet),
    ('Roy', roy_morlet),
    ('Terence', ter_morlet),
))

#### plots

In [None]:
# Morlet analysis of the "I Remember Clifford" WAV results (wav=True),
# time_window=(0, 150); titles come from artists_names
morlet_analysis(irmb_list_morlet, wav=True, time_window=(0, 150), 
                titles=artists_names)

In [None]:
# amplitude/phase plots for the "I Remember Clifford" WAV results,
# no time_window passed
plot_amp_phase(irmb_list_morlet, wav=True, 
               titles=artists_names)

In [None]:
#### time window of 22-24 seconds
# same amplitude/phase plot, restricted via time_window=(22, 24)
plot_amp_phase(irmb_list_morlet, wav=True, 
               titles=artists_names, 
               time_window=(22,24))

In [None]:
##### not used in writeup
# frequencies (index 4) and amplitudes (index 3) pulled out of each artist's
# get_amplitude_phase result, plus element 0 of each prep_data_morlet result
# (wavelet_coeff_list is what the Shapiro-Wilk cell consumes)
freqs_list, amp_list = (
    [morlet[idx]
     for morlet in (lee_morlet, chet_morlet, fred_morlet, art_morlet, roy_morlet, ter_morlet)]
    for idx in (4, 3)
)
wavelet_coeff_list = [
    coeff[0]
    for coeff in (lee_coeff, chet_coeff, fred_coeff, art_coeff, roy_coeff, ter_coeff)
]

# plot dominant frequencies
dominant_freqs_list = plot_dominant_frequencies(
    freqs_list,
    amp_list,
    title="I Remember Clifford",
    labels=['Lee', 'Chet', 'Freddie', 'Art', 'Roy', 'Terence'],
)


In [None]:
# Shapiro-Wilk test on wavelet_coeff_list from the previous cell (run that cell
# first); the titles are hard-coded in the same artist order as that list
shapiro_wilk_test(wavelet_coeff_list, plot=True, titles=['Lee', 'Chet', 'Freddie', 'Art', 'Roy', 'Terence'])

#### periodicity vs entropy

In [None]:
# periodicity/entropy summary for the "I Remember Clifford" WAV results; the
# result is consumed in the next cell as a mapping whose values carry
# 'Wavelet Entropy' and 'Periodicity' entries
irmb_period_entropy = print_periodicity_entropy(irmb_list_morlet, 
                                                header="I Remember Clifford", 
                                                artists=artists_names)

In [None]:
# split the per-artist summary into two parallel metric lists (entropy
# first, then periodicity) and the matching artist keys
entropy_list, periodicity_list = (
    [stats[metric] for stats in irmb_period_entropy.values()]
    for metric in ('Wavelet Entropy', 'Periodicity')
)
artists = [*irmb_period_entropy.keys()]

plot_entropy_vs_periodicity(entropy_list, periodicity_list, artists, title="I Remember Clifford")