In [1]:
import os
import pickle
# import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# from scipy.stats import norm

# from utils_subdivision.gen_distribution_single_plots import analyze_phases
# from utils_subdivision.gen_distribution_subplot import analyze_single_type    # plot_combined_results
# from utils_subdivision.gen_distribution_merged_plot import plot_merged
# from utils_dot_plot.drum_single import analyze_phases
# from utils_dot_plot.drum_merged import plot_merged_per_mode

# from utils_subdivision.gen_distribution_subplot import analyze_single_type
from utils_dot_plot.kinematic_dot_plot import *
from utils_dot_plot.drum_merged import *
from utils_dot_plot.drum_dance_piece import *


# Piece types analysed throughout this notebook.
PIECE_TYPES = ["Suku", "Maraka", "Manjanin", "Wasulunka"]

# Every figure/pickle lives under base_output_dir; per-piece figures go in
# its "by_piece" sub-folder. Both folders are created up front (idempotent).
base_output_dir = "output_dot_plots"
by_piece_dir = os.path.join(base_output_dir, "by_piece")
for output_dir in (base_output_dir, by_piece_dir):
    os.makedirs(output_dir, exist_ok=True)


## All Modes

## Drum KDE Plot: All Modes Combined

In [50]:
def plot_combined_drum_stacked_all_modes(
    piece_type,
    piece_drum_phases_kde,  # Dictionary containing all modes' data
    figsize=(10, 3),
    dpi=200,
    legend_flag=True
):
    """Plot drum onsets for one piece type, pooled over every dance mode.

    Each drum ('Dun', 'J1', 'J2') gets its own horizontal band of scatter
    points (x = onset phase * 400, y = that drum's ``y_scaled`` values
    min-max rescaled into the band). A purple filled KDE of all pooled onset
    phases is drawn beneath the bands (y in [-5, 0]), and subdivision guide
    lines are overlaid.

    Parameters
    ----------
    piece_type : str
        Piece key (e.g. "Suku") looked up inside each mode's dict. Modes that
        do not contain it are skipped.
    piece_drum_phases_kde : dict
        ``{mode: {piece_type: [piece_data, ...]}}`` where each ``piece_data``
        maps a drum name to a dict holding at least ``"phases"`` and
        ``"y_scaled"`` sequences.
    figsize, dpi :
        Forwarded to ``plt.subplots``.
    legend_flag : bool
        If True, draw a legend with one entry per drum plus the KDE.

    Returns
    -------
    (fig, ax) : the matplotlib Figure and Axes that were drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Vertical band (y_min, y_max) reserved for each drum's scatter points.
    vertical_ranges = {
        'Dun': (1, 6),
        'J1': (8, 13),
        'J2': (15, 20)
    }

    # Fixed colors for each drum type
    drum_colors = {
        'Dun': '#1f77b4',   # blue
        'J1': '#d62728',    # red
        'J2': '#2ca02c'     # green
    }

    combined_phases = []   # every onset phase (all modes, all drums) for the KDE
    labeled_drums = set()  # drums that already own a legend entry

    for mode_data in piece_drum_phases_kde.values():
        if piece_type not in mode_data:
            continue

        # List of per-recording dicts for this piece type in this mode
        piece_data_list = mode_data[piece_type]

        for drum_type, color in drum_colors.items():
            # Pool phases and y_scaled from every recording of this piece
            all_phases = []
            all_y_scaled = []
            for piece_data in piece_data_list:
                if drum_type in piece_data:
                    all_phases.extend(piece_data[drum_type]["phases"])
                    all_y_scaled.extend(piece_data[drum_type]["y_scaled"])
                    combined_phases.extend(piece_data[drum_type]["phases"])

            if not all_phases:  # Skip if no data found
                continue

            phases = np.array(all_phases)
            y_scaled = np.asarray(all_y_scaled, dtype=float)

            # Min-max rescale y_scaled into this drum's band. When every value
            # is identical (e.g. a single onset) the span is 0 and the plain
            # formula yields NaN (points silently vanish), so centre them.
            y_min, y_max = vertical_ranges[drum_type]
            y_lo, y_hi = np.min(y_scaled), np.max(y_scaled)
            if y_hi > y_lo:
                y_band = y_min + (y_scaled - y_lo) * (y_max - y_min) / (y_hi - y_lo)
            else:
                y_band = np.full(y_scaled.shape, (y_min + y_max) / 2.0)

            # Each mode re-plots the same drum in the same colour; label only
            # the first scatter so the legend has one entry per drum.
            label = drum_type if drum_type not in labeled_drums else '_nolegend_'
            labeled_drums.add(drum_type)

            ax.scatter(phases * 400,
                       y_band,
                       s=5, alpha=0.4,  # low alpha so overlapping points stay visible
                       color=color,
                       label=label)

    # Combined KDE of all pooled phases, drawn as a filled curve in y in [-5, 0]
    if combined_phases:
        kde_xx, kde_h = kde_estimate(np.array(combined_phases), SIG=0.01)

        # Only plot the region that maps onto the visible x-axis (-33..400)
        mask = (kde_xx * 400 >= -33) & (kde_xx * 400 <= 400)
        kde_xx_plot = kde_xx[mask]
        kde_h_plot = kde_h[mask]

        # Guard against an empty mask (np.max raises on empty input) and an
        # all-zero density (division by zero).
        if kde_h_plot.size > 0 and np.max(kde_h_plot) > 0:
            kde_scaled = -5 + (5 * kde_h_plot / np.max(kde_h_plot))
            ax.fill_between(kde_xx_plot * 400, -5, kde_scaled, alpha=0.3, color='purple', label='Combined KDE')

    # Subdivision lines: 12 subdivisions across x = 0..400. Subdivisions
    # 1, 4, 7, 10 land on the labeled ticks (x = 0, 100, 200, 300) and are
    # drawn solid; the rest are faint dashed lines.
    for subdiv in range(1, 13):
        color = get_subdiv_color(subdiv)
        x_pos = ((subdiv - 1) * 400) / 12

        if subdiv in [1, 4, 7, 10]:
            ax.vlines(x_pos, -5.5, 20.5, color=color, linestyle='-', linewidth=1.5, alpha=0.7)
        else:
            ax.vlines(x_pos, -5.5, 20.5, color=color, linestyle='--', linewidth=1, alpha=0.3)

    # Styling
    xtick = [0, 100, 200, 300, 400]
    xtick_labels = [1, 2, 3, 4, 5]

    ax.set_xticks(xtick)
    ax.set_xticklabels(xtick_labels)
    ax.set_xlim(-33, 400)
    ax.set_xlabel('Beat span')

    ax.set_ylim(-5.5, 20.5)
    ax.set_yticks([3, 10, 17])  # roughly the centres of the three drum bands
    ax.set_yticklabels(['Dun', 'J1', 'J2'])
    ax.set_ylabel('Drum')
    ax.grid(True, alpha=0.3)

    # Title & legend
    title = f'Piece: {piece_type} | All Modes Combined'
    ax.set_title(title, pad=10)

    if legend_flag:
        ax.legend(loc='upper left', framealpha=0.4, fontsize=6)

    return fig, ax

In [53]:
# Load the pickled drum-onset phase data for each dance mode.
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# this project generated itself.
piece_drum_phases_kde = {}
for dance_mode in ["group", "individual", "audience"]:
    load_path = os.path.join(base_output_dir, f"piece_drum_phases_kde_{dance_mode}.pkl")
    with open(load_path, 'rb') as f:
        piece_drum_phases_kde[dance_mode] = pickle.load(f)

# Create one combined (all-modes) drum plot per piece type (PIECE_TYPES is
# defined in the setup cell). The output folder is loop-invariant.
save_dir = os.path.join(by_piece_dir, "drum_kde_by_piece", "all_modes")
os.makedirs(save_dir, exist_ok=True)

for piece_type in PIECE_TYPES:
    fig, ax = plot_combined_drum_stacked_all_modes(
        piece_type=piece_type,
        piece_drum_phases_kde=piece_drum_phases_kde,
        legend_flag=False
    )

    # Save and close this specific figure explicitly (not "the current one")
    # so figures do not accumulate in memory across the loop.
    save_path = os.path.join(save_dir, f"{piece_type}_all_modes_combined.png")
    fig.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close(fig)

# Structure of the loaded data (keys observed when inspecting the pickles):
#   piece_drum_phases_kde.keys()                       -> ['group', 'individual', 'audience']
#   piece_drum_phases_kde[mode].keys()                 -> ['Suku', 'Maraka', 'Wasulunka', 'Manjanin']
#   piece_drum_phases_kde[mode][piece][i].keys()       -> ['Dun', 'J1', 'J2']   (a list of per-recording dicts)
#   piece_drum_phases_kde[mode][piece][i][drum].keys() -> ['phases', 'y_scaled', 'kde_h', 'kde_xx']

## Dance KDE Plot: All Modes Combined

In [52]:
def plot_combined_foot_stacked_all_modes(
    piece_type,
    piece_dance_phases_kde,  # Dictionary containing all modes' data
    figsize=(10, 3),
    dpi=200,
    legend_flag=True
):
    """Plot foot onsets for one piece type, pooled over every dance mode.

    Each foot ('left', 'right') gets its own horizontal band of scatter
    points (x = onset phase * 400, y = that foot's ``y_scaled`` values
    min-max rescaled into the band). A purple filled KDE of all pooled onset
    phases is drawn beneath the bands (y in [-5, 0]), and subdivision guide
    lines are overlaid.

    Parameters
    ----------
    piece_type : str
        Piece key (e.g. "Suku") looked up inside each mode's dict. Modes that
        do not contain it are skipped.
    piece_dance_phases_kde : dict
        ``{mode: {piece_type: [piece_data, ...]}}`` where each ``piece_data``
        holds ``"phases"`` and ``"y_scaled"`` dicts keyed by foot name.
    figsize, dpi :
        Forwarded to ``plt.subplots``.
    legend_flag : bool
        If True, draw a legend with one entry per foot plus the KDE.

    Returns
    -------
    (fig, ax) : the matplotlib Figure and Axes that were drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Vertical band (y_min, y_max) reserved for each foot's scatter points.
    vertical_ranges = {
        'left': (1, 6),
        'right': (8, 13),
    }

    # Fixed colors for feet
    foot_colors = {
        'left': '#1f77b4',   # blue
        'right': '#d62728'   # red
    }

    combined_phases = []   # every onset phase (all modes, both feet) for the KDE
    labeled_feet = set()   # feet that already own a legend entry

    for mode_data in piece_dance_phases_kde.values():
        if piece_type not in mode_data:
            continue

        # List of per-recording dicts for this piece type in this mode
        piece_data_list = mode_data[piece_type]

        for foot_type, color in foot_colors.items():
            # Pool phases and y_scaled from every recording of this piece
            all_phases = []
            all_y_scaled = []
            for piece_data in piece_data_list:
                if foot_type in piece_data["phases"]:
                    all_phases.extend(piece_data["phases"][foot_type])
                    all_y_scaled.extend(piece_data["y_scaled"][foot_type])
                    combined_phases.extend(piece_data["phases"][foot_type])

            if not all_phases:  # Skip if no data found
                continue

            phases = np.array(all_phases)
            y_scaled = np.asarray(all_y_scaled, dtype=float)

            # Min-max rescale y_scaled into this foot's band. When every value
            # is identical (e.g. a single onset) the span is 0 and the plain
            # formula yields NaN (points silently vanish), so centre them.
            y_min, y_max = vertical_ranges[foot_type]
            y_lo, y_hi = np.min(y_scaled), np.max(y_scaled)
            if y_hi > y_lo:
                y_band = y_min + (y_scaled - y_lo) * (y_max - y_min) / (y_hi - y_lo)
            else:
                y_band = np.full(y_scaled.shape, (y_min + y_max) / 2.0)

            # Each mode re-plots the same foot in the same colour; label only
            # the first scatter so the legend has one entry per foot.
            label = f'{foot_type.capitalize()} Foot' if foot_type not in labeled_feet else '_nolegend_'
            labeled_feet.add(foot_type)

            ax.scatter(phases * 400,
                       y_band,
                       s=5, alpha=0.4,  # low alpha so overlapping points stay visible
                       color=color,
                       label=label)

    # Combined KDE of all pooled phases, drawn as a filled curve in y in [-5, 0]
    if combined_phases:
        kde_xx, kde_h = kde_estimate(np.array(combined_phases), SIG=0.01)

        # Only plot the region that maps onto the visible x-axis (-33..400)
        mask = (kde_xx * 400 >= -33) & (kde_xx * 400 <= 400)
        kde_xx_plot = kde_xx[mask]
        kde_h_plot = kde_h[mask]

        # Guard against an empty mask (np.max raises on empty input) and an
        # all-zero density (division by zero).
        if kde_h_plot.size > 0 and np.max(kde_h_plot) > 0:
            kde_scaled = -5 + (5 * kde_h_plot / np.max(kde_h_plot))
            ax.fill_between(kde_xx_plot * 400, -5, kde_scaled, alpha=0.3, color='purple', label='Combined KDE')

    # Subdivision lines: 12 subdivisions across x = 0..400. Subdivisions
    # 1, 4, 7, 10 land on the labeled ticks (x = 0, 100, 200, 300) and are
    # drawn solid; the rest are faint dashed lines.
    for subdiv in range(1, 13):
        color = get_subdiv_color(subdiv)
        x_pos = ((subdiv - 1) * 400) / 12

        if subdiv in [1, 4, 7, 10]:
            ax.vlines(x_pos, -5.5, 13.5, color=color, linestyle='-', linewidth=1.5, alpha=0.7)
        else:
            ax.vlines(x_pos, -5.5, 13.5, color=color, linestyle='--', linewidth=1, alpha=0.3)

    # Styling
    xtick = [0, 100, 200, 300, 400]
    xtick_labels = [1, 2, 3, 4, 5]

    ax.set_xticks(xtick)
    ax.set_xticklabels(xtick_labels)
    ax.set_xlim(-33, 400)
    ax.set_xlabel('Beat span')

    ax.set_ylim(-5.5, 13.5)
    ax.set_yticks([3, 10])  # roughly the centres of the two foot bands
    ax.set_yticklabels(['LF', 'RF'])
    ax.set_ylabel('Foot')
    ax.grid(True, alpha=0.3)

    # Title & legend
    title = f'Piece: {piece_type} | All Modes Combined'
    ax.set_title(title, pad=10)

    if legend_flag:
        ax.legend(loc='upper left', framealpha=0.4, fontsize=6)

    return fig, ax

# """

In [29]:

# Load the pickled foot-onset phase data for each dance mode.
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# this project generated itself.
piece_dance_phases_kde = {}
for mode in ["group", "individual", "audience"]:
    load_path = os.path.join(base_output_dir, f"piece_dance_phases_kde_{mode}.pkl")
    with open(load_path, 'rb') as f:
        piece_dance_phases_kde[mode] = pickle.load(f)

# Create one combined (all-modes) foot plot per piece type. The output
# folder is loop-invariant, so create it once.
save_dir = os.path.join(by_piece_dir, "dance_kde_by_piece", "all_modes")
os.makedirs(save_dir, exist_ok=True)

for piece_type in PIECE_TYPES:
    # A piece may be recorded in only some modes, so check all of them rather
    # than just "group"; skip only pieces absent from every mode.
    if not any(piece_type in mode_data for mode_data in piece_dance_phases_kde.values()):
        continue

    fig, ax = plot_combined_foot_stacked_all_modes(
        piece_type=piece_type,
        piece_dance_phases_kde=piece_dance_phases_kde,
        legend_flag=False
    )

    # Save and close this specific figure explicitly (not "the current one")
    # so figures do not accumulate in memory across the loop.
    save_path = os.path.join(save_dir, f"{piece_type}_all_modes_combined.png")
    fig.savefig(save_path, bbox_inches='tight', dpi=200)
    plt.close(fig)


dict_keys(['group', 'individual', 'audience'])