In [6]:
import tempfile
import os
import music21
import pandas as pd
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from collections import Counter
from scipy.stats import entropy
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import signal
from multiprocessing import Pool, cpu_count

warnings.filterwarnings('ignore')

# --- Configuration ---
USE_SMALL_SUBSET = False
TEST_SAMPLE_SIZE = 200
SAMPLES_PER_STYLE = 10000
FILE_TIMEOUT_SECONDS = 60

METRICS_LIST = [
    'note_density', 'avg_duration', 'pitch_range', 
    'rhythmic_entropy', 'chord_ratio', 'pitch_class_entropy', 'melodic_interval_mean'
]

# --- Metric Functions ---
def calc_note_density(all_notes, score):
    if not all_notes: return 0
    total_duration = score.highestTime
    return len(all_notes) / total_duration if total_duration > 0 else 0

def calc_avg_duration(all_notes, score):
    durations = [n.duration.quarterLength for n in all_notes]
    return np.mean(durations) if durations else 0

def calc_pitch_range(all_notes, score):
    pitches = []
    for n in all_notes:
        if n.isNote: pitches.append(n.pitch.ps)
        elif n.isChord: pitches.extend([p.ps for p in n.pitches])
    return (max(pitches) - min(pitches)) if pitches else 0

def calc_rhythmic_entropy(all_notes, score):
    durations = [n.duration.quarterLength for n in all_notes]
    if durations:
        counts = Counter(durations)
        probs = [c / len(durations) for c in counts.values()]
        return entropy(probs, base=2)
    return 0

def calc_chord_ratio(all_notes, score):
    if len(all_notes) > 0:
        chord_count = sum(1 for n in all_notes if n.isChord)
        return chord_count / len(all_notes)
    return 0

def calc_pitch_class_entropy(all_notes, score):
    pcs = []
    for n in all_notes:
        if n.isNote: pcs.append(n.pitch.pitchClass)
        elif n.isChord: pcs.extend([p.pitchClass for p in n.pitches])
    if pcs:
        counts = Counter(pcs)
        probs = [c / len(pcs) for c in counts.values()]
        return entropy(probs, base=2)
    return 0

def calc_melodic_interval_mean(all_notes, score):
    sorted_notes = sorted(all_notes, key=lambda x: x.offset)
    p_seq = []
    for n in sorted_notes:
        if n.isNote: p_seq.append(n.pitch.ps)
        elif n.isChord and n.pitches: p_seq.append(max(p.ps for p in n.pitches))
    
    if len(p_seq) > 1:
        diffs = [abs(p_seq[i] - p_seq[i-1]) for i in range(1, len(p_seq))]
        return np.mean(diffs)
    return 0

METRIC_FUNCTIONS = {
    'note_density': calc_note_density,
    'avg_duration': calc_avg_duration,
    'pitch_range': calc_pitch_range,
    'rhythmic_entropy': calc_rhythmic_entropy,
    'chord_ratio': calc_chord_ratio,
    'pitch_class_entropy': calc_pitch_class_entropy,
    'melodic_interval_mean': calc_melodic_interval_mean
}

# --- Timeout Handling ---
class TimeoutException(Exception):
    pass

def timeout_handler(signum, frame):
    raise TimeoutException

def process_single_sample(sample):
    raw_style = sample.get('style', None)
    if raw_style:
        style_label = str(raw_style).strip().title()
    else:
        return None

    midi_bytes = sample['midi']
    fp = tempfile.NamedTemporaryFile(suffix=".mid", delete=False)
    
    try:
        fp.write(midi_bytes)
        fp.close()

        # Set Linux timeout
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(FILE_TIMEOUT_SECONDS)

        score = music21.converter.parse(fp.name, forceSource=True)
        
        # Disable timeout
        signal.alarm(0)

        all_notes = score.flat.notes
        if len(all_notes) == 0:
            return None

        features = {
            'style': style_label,
            'composer': sample.get('composer', 'Unknown')
        }

        for metric_name in METRICS_LIST:
            calc_func = METRIC_FUNCTIONS.get(metric_name)
            if calc_func:
                features[metric_name] = calc_func(all_notes, score)
            else:
                features[metric_name] = 0
        
        return features

    except (TimeoutException, Exception):
        return None
    finally:
        signal.alarm(0)
        if os.path.exists(fp.name):
            try: os.remove(fp.name)
            except: pass

if __name__ == '__main__':
    print("Loading Dataset...")
    full_dataset = load_dataset('TiMauzi/imslp-midi-by-sa', split='train')

    if USE_SMALL_SUBSET:
        print(f"Test Mode: {TEST_SAMPLE_SIZE} samples")
        analysis_dataset = full_dataset.shuffle(seed=42).select(range(TEST_SAMPLE_SIZE))
    else:
        print("Full Mode: Analyzing entire dataset")
        analysis_dataset = full_dataset

    # Detect CPU cores (Linux/Slurm optimized)
    try:
        num_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        num_cores = max(1, cpu_count() - 1)
        
    print(f"Using {num_cores} cores")
    
    results = []
    style_counts = {}

    with Pool(processes=num_cores) as pool:
        for data in tqdm(pool.imap_unordered(process_single_sample, analysis_dataset), total=len(analysis_dataset)):
            if data:
                s = data['style']
                if style_counts.get(s, 0) < SAMPLES_PER_STYLE:
                    style_counts[s] = style_counts.get(s, 0) + 1
                    results.append(data)
                
    print("\nGenerating Charts...")
    df = pd.DataFrame(results)

    if not df.empty:
        # --- Data Cleaning (Critical for preventing TypeErrors) ---
        print("Cleaning data types...")
        for metric in METRICS_LIST:
            if metric in df.columns:
                df[metric] = pd.to_numeric(df[metric], errors='coerce')
        
        initial_len = len(df)
        df.dropna(subset=METRICS_LIST, inplace=True)
        if len(df) < initial_len:
            print(f"Dropped {initial_len - len(df)} invalid rows.")

        # --- Filtering for Target Styles ---
        TARGET_STYLES = ['Baroque', 'Romantic']
        df_plot = df[df['style'].isin(TARGET_STYLES)].copy()

        if df_plot.empty:
            print("Warning: No Baroque or Romantic samples found.")
        else:
            print(f"Plotting for: {TARGET_STYLES}")
            print(df_plot['style'].value_counts())

            order = TARGET_STYLES
            sns.set_style("whitegrid")
            
            for metric in METRICS_LIST:
                try:
                    plt.figure(figsize=(8, 6))
                    sns.boxplot(x='style', y=metric, data=df_plot, order=order, palette="Set3", showfliers=False)
                    sns.stripplot(x='style', y=metric, data=df_plot, order=order, color='black', alpha=0.3, jitter=True)
                    plt.title(f'{metric}')
                    plt.ylabel(metric)
                    plt.xlabel('Style')
                    plt.tight_layout()
                    plt.savefig(f'analysis_boxplot_{metric}.png')
                    plt.close()
                except Exception as e:
                    print(f"Error plotting {metric}: {e}")
            
            print("Charts saved.")

        df.to_csv('final_analysis_results.csv', index=False)
        print("Saved final_analysis_results.csv")
    else:
        print("No valid data extracted.")

Loading Dataset...
Full Mode: Analyzing entire dataset
Using 16 cores


100%|██████████| 5593/5593 [08:45<00:00, 10.64it/s]



Generating Charts...
Cleaning data types...
Dropped 82 invalid rows.
Plotting for: ['Baroque', 'Romantic']
style
Romantic    2168
Baroque     1589
Name: count, dtype: int64
Charts saved.
Saved final_analysis_results.csv


In [8]:
import pandas as pd

# Set Pandas display options
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)

# 1. Load Data
df = pd.read_csv('final_analysis_results.csv')
# A. Data Availability
style_counts = df['style'].value_counts()
print("1. Sample Counts per Style:")
print(style_counts)
print("\n" + "-"*30 + "\n")

# Define metrics
metrics = [
    'note_density', 'avg_duration', 'pitch_range', 
    'rhythmic_entropy', 'chord_ratio', 
    'pitch_class_entropy', 'melodic_interval_mean'
]

# Force numeric conversion for metrics
for m in metrics:
    if m in df.columns:
        df[m] = pd.to_numeric(df[m], errors='coerce')

# Filter for styles with sufficient samples (> 50)
valid_styles = style_counts[style_counts > 50].index
df_filtered = df[df['style'].isin(valid_styles)]

# B. Calculate Aggregate Statistics (Mean, Median, Std)
agg_df = df_filtered.groupby('style')[metrics].agg(['mean', 'median', 'std']).round(2)

print(f"2. Statistical analysis performed on {len(valid_styles)} valid styles.")
print("   (Styles with >50 samples included)\n")

# C. Generate Summary CSV
# Flatten multi-index columns for cleaner CSV headers
agg_df_flat = agg_df.copy()
agg_df_flat.columns = ['_'.join(col).strip() for col in agg_df.columns.values]
agg_df_flat.reset_index(inplace=True)

output_filename = 'style_statistics_summary.csv'
agg_df_flat.to_csv(output_filename, index=False)
print(f"Statistics summary saved to: {output_filename}")
print(f"File dimensions: {len(agg_df_flat)} rows x {len(agg_df_flat.columns)} columns")
print("\n" + "-"*30 + "\n")

# D. Style Leaderboard (for Phase 2 selection)
print("3. Style Leaderboard (Distinctive Features):")
print("(Identifying styles with maximum divergence)\n")

means_df = df_filtered.groupby('style')[metrics].mean()

for m in metrics:
    if m in means_df.columns:
        # Identify highest and lowest
        highest_style = means_df[m].idxmax()
        highest_val = means_df[m].max()
        
        lowest_style = means_df[m].idxmin()
        lowest_val = means_df[m].min()
        
        print(f"Metric: {m}")
        print(f"   Highest: {highest_style:<15} (Mean: {highest_val:.2f})")
        print(f"   Lowest:  {lowest_style:<15} (Mean: {lowest_val:.2f})")
        print("-" * 40)

1. Sample Counts per Style:
style
Romantic                 2168
Baroque                  1589
Renaissance               679
Classical                 419
Modern                    402
Early 20Th Century        187
Traditional                24
Jazz                       13
Medieval                    6
Ancient                     6
Non-Western Classical       3
Name: count, dtype: int64

------------------------------

2. Statistical analysis performed on 6 valid styles.
   (Styles with >50 samples included)

Statistics summary saved to: style_statistics_summary.csv
File dimensions: 6 rows x 22 columns

------------------------------

3. Style Leaderboard (Distinctive Features):
(Identifying styles with maximum divergence)

Metric: note_density
   Highest: Early 20Th Century (Mean: 4.69)
   Lowest:  Renaissance     (Mean: 2.68)
----------------------------------------
Metric: avg_duration
   Highest: Renaissance     (Mean: 1.60)
   Lowest:  Classical       (Mean: 0.67)
----------------