In [48]:
import torch
import numpy as np
from statistics import mean, stdev
import gc
from tqdm import tqdm
from scipy.stats import entropy
from music21 import converter, key, note, chord, pitch
import os
import random
import mido
import string
from sklearn.metrics.pairwise import cosine_similarity

In [49]:
# Root folder that holds the MIDI dataset. Set the MIDI_DATASET_DIR environment
# variable to point the notebook at another location; the original hard-coded
# path is kept as the fallback so existing setups keep working unchanged.
midi_directory = os.environ.get(
    "MIDI_DATASET_DIR",
    "C:/Users/hp/Desktop/SEM6_PROJECTS/MIDIMix/Full_Dataset",
)

In [50]:
# Evaluation settings.
# NOTE(review): `rounds` and `samples_per_round` are not referenced by any code
# visible in this part of the notebook — confirm they are still needed.
rounds = 100
samples_per_round = 500
# get_valid_midi_files only keeps piano rolls that are strictly longer than
# this many timesteps.
min_timesteps = 6000

In [51]:
def msg2dict(msg):
    """Parse the textual form of a MIDI message into its numeric fields.

    Args:
        msg (str): String representation of a message (callers in this file
            pass ``str(message)``), containing ``key=value`` tokens separated
            by spaces.

    Returns:
        list: ``[fields, on_]`` where ``fields`` always holds ``'time'`` and,
        for note messages, also ``'note'`` and ``'velocity'`` (all ints), and
        ``on_`` is True for ``note_on``, False for ``note_off`` and None for
        any other message.
    """
    # Punctuation is stripped from each value so trailing characters such as
    # the ")" in "time=0)" do not break the int conversion.
    strip_punct = str.maketrans({ch: None for ch in string.punctuation})

    def read_int(field):
        # Take the last occurrence of the field name, up to the next space.
        token = msg[msg.rfind(field):].split(' ')[0]
        return int(token.split('=')[1].translate(strip_punct))

    is_on = True if 'note_on' in msg else (False if 'note_off' in msg else None)

    fields = {'time': read_int('time')}
    if is_on is not None:
        fields['note'] = read_int('note')
        fields['velocity'] = read_int('velocity')
    return [fields, is_on]

In [52]:
def switch_note(last_state, note, velocity, on_=True):
    """Return a copy of the 88-key state with one key switched on or off.

    Args:
        last_state (list | None): Previous per-key velocities; None starts
            from an all-zero state of 88 keys.
        note (int): MIDI note number. Only 21..108 (the 88 piano keys) are
            applied; anything else leaves the state as it was.
        velocity (int): Velocity stored for the key when ``on_`` is truthy.
        on_ (bool): When falsy the key is set to 0 instead of ``velocity``.

    Returns:
        list: The new state; ``last_state`` itself is never modified.
    """
    # Work on a copy so the caller's previous state stays intact.
    state = last_state.copy() if last_state is not None else [0] * 88
    if note < 21 or note > 108:
        return state
    if on_:
        state[note - 21] = velocity
    else:
        state[note - 21] = 0
    return state

In [53]:
def get_new_state(new_msg, last_state):
    """Apply one message to the keyboard state.

    Args:
        new_msg: A message; it is converted with ``str()`` and parsed by
            ``msg2dict``.
        last_state (list): Per-key velocities before this message.

    Returns:
        list: ``[state, time]`` — the state after the message and the parsed
        ``time`` field of the message.
    """
    fields, is_on = msg2dict(str(new_msg))
    if is_on is None:
        # Not a note message: the state is carried over as the same object.
        state = last_state
    else:
        state = switch_note(
            last_state,
            note=fields['note'],
            velocity=fields['velocity'],
            on_=is_on,
        )
    return [state, fields['time']]

In [54]:
def track2seq(track):
    """Expand a track's messages into one keyboard state per time unit.

    Each message after the first contributes ``time`` copies of the state
    that was active *before* it. The ``time`` of the first message is not
    used, and the state produced by the final message is never emitted
    because no later message gives it a duration.

    Args:
        track: Indexable sequence of messages accepted by ``get_new_state``.

    Returns:
        list: One entry per time unit; repeated entries are references to the
        same state list. An empty track yields an empty list.
    """
    # An empty track has no first message to seed the state from.
    if len(track) == 0:
        return []

    result = []
    last_state, _ = get_new_state(str(track[0]), [0] * 88)
    for i in range(1, len(track)):
        new_state, elapsed = get_new_state(track[i], last_state)
        if elapsed > 0:
            # The previous state was held for `elapsed` time units.
            result.extend([last_state] * elapsed)
        last_state = new_state
    return result

In [55]:
def mid2arry(mid, min_msg_pct=0.1):
    """Convert a MIDI file object into a single (timesteps, 88) piano roll.

    Tracks with more than ``min_msg_pct`` times the message count of the
    longest track are converted with ``track2seq``, padded with silence to a
    common length and merged by taking the per-key maximum velocity. Silent
    timesteps before the first and after the last sounding timestep are
    trimmed away.

    Args:
        mid: Object exposing ``tracks``, a sequence of message sequences.
        min_msg_pct (float): Relative message-count threshold a track must
            exceed to be included.

    Returns:
        np.ndarray: Merged piano roll from the first to the last timestep
        (inclusive) in which any key sounds.

    Raises:
        ValueError: If there are no tracks, no track passes the threshold,
            the selected tracks contain no timesteps, or no key ever sounds.
    """
    tracks = mid.tracks
    if len(tracks) == 0:
        raise ValueError("MIDI file contains no tracks")

    min_n_msg = max(len(tr) for tr in tracks) * min_msg_pct
    # Convert each sufficiently long track to a nested list.
    all_arys = [track2seq(tr) for tr in tracks if len(tr) > min_n_msg]
    if not all_arys:
        raise ValueError("no track exceeds the minimum message count")

    # Pad every track with silence so all of them share the same length.
    max_len = max(len(ary) for ary in all_arys)
    if max_len == 0:
        raise ValueError("selected tracks contain no timesteps")
    for ary in all_arys:
        if len(ary) < max_len:
            ary.extend([[0] * 88] * (max_len - len(ary)))

    merged = np.array(all_arys).max(axis=0)

    # Trim leading/trailing silence. np.where returns ascending indices, so
    # the first/last entries are the first/last sounding timesteps.
    active = np.where(merged.sum(axis=1) > 0)[0]
    if active.size == 0:
        raise ValueError("MIDI file contains no sounding notes")
    # +1 because the slice end is exclusive; without it the last sounding
    # timestep would be cut off.
    return merged[active[0]: active[-1] + 1]

In [56]:
# Lookup table from a chord/interval name to a harmony score between 0.0 and
# 1.0 (higher values are given to names such as 'major triad' and
# 'Perfect Fifth'). chord_consonance_and_dissonance looks entries up by the
# `commonName` of a music21 Chord and falls back to 0.0 for missing names.
# NOTE(review): the provenance of the individual score values is not shown in
# this notebook — confirm before relying on their relative ordering.
harmony_scores = {
    'Arabian tetramirror': 0.25,
    'Augmented Eleventh': 0.50,
    'Augmented Fifth': 0.60,
    'Augmented Fifth with octave doublings': 0.55,
    'Augmented Fourth': 0.65,
    'Augmented Fourth with octave doublings': 0.60,
    'Augmented Ninth': 0.55,
    'Augmented Octave': 0.60,
    'Augmented Octave with octave doublings': 0.55,
    'Augmented Second': 0.60,
    'Augmented Second with octave doublings': 0.55,
    'Augmented Sixth': 0.60,
    'Augmented Sixth with octave doublings': 0.55,
    'Augmented Tenth': 0.60,
    'Augmented Third': 0.60,
    'Augmented Third with octave doublings': 0.55,
    'Augmented Thirteenth': 0.55,
    'Augmented Twelfth': 0.55,
    'Augmented Unison': 0.45,
    'Augmented Unison with octave doublings': 0.40,
    'Diminished Eleventh': 0.45,
    'Diminished Fifth': 0.55,
    'Diminished Fifth with octave doublings': 0.50,
    'Diminished Fourth': 0.50,
    'Diminished Fourth with octave doublings': 0.45,
    'Diminished Octave': 0.40,
    'Diminished Octave with octave doublings': 0.35,
    'Diminished Seventh': 0.60,
    'Diminished Seventh with octave doublings': 0.55,
    'Diminished Sixth': 0.50,
    'Diminished Sixth with octave doublings': 0.45,
    'Diminished Tenth': 0.40,
    'Diminished Third': 0.45,
    'Diminished Third with octave doublings': 0.40,
    'Diminished Twelfth': 0.35,
    'French augmented sixth chord': 0.50,
    'French augmented sixth chord in first inversion': 0.45,
    'French augmented sixth chord in root position': 0.50,
    'French augmented sixth chord in third inversion': 0.40,
    'German augmented sixth chord': 0.45,
    'German augmented sixth chord in root position': 0.50,
    'German augmented sixth chord in second inversion': 0.45,
    'German augmented sixth chord in third inversion': 0.40,
    'Italian augmented sixth chord': 0.50,
    'Italian augmented sixth chord in root position': 0.55,
    'Italian augmented sixth chord in second inversion': 0.50,
    'Major Fourteenth': 0.55,
    'Major Ninth': 0.70,
    'Major Second': 0.60,
    'Major Second with octave doublings': 0.55,
    'Major Seventh': 0.85,
    'Major Seventh with octave doublings': 0.80,
    'Major Sixth': 0.75,
    'Major Sixth with octave doublings': 0.70,
    'Major Tenth': 0.60,
    'Major Third': 0.65,
    'Major Third with octave doublings': 0.60,
    'Major Thirteenth': 0.65,
    'Messiaen\'s truncated mode 6': 0.40,
    'Minor Fourteenth': 0.50,
    'Minor Ninth': 0.60,
    'Minor Second': 0.55,
    'Minor Second with octave doublings': 0.50,
    'Minor Seventh': 0.60,
    'Minor Seventh with octave doublings': 0.55,
    'Minor Sixth': 0.60,
    'Minor Sixth with octave doublings': 0.55,
    'Minor Tenth': 0.50,
    'Minor Third': 0.50,
    'Minor Third with octave doublings': 0.45,
    'Minor Thirteenth': 0.55,
    'Perfect Eleventh': 0.75,
    'Perfect Fifth': 1.0,
    'Perfect Fifth with octave doublings': 0.90,
    'Perfect Fourth': 0.90,
    'Perfect Fourth with octave doublings': 0.85,
    'Perfect Octave': 1.0,
    'Perfect Twelfth': 0.85,
    'all-interval tetrachord': 0.30,
    'alternating tetramirror': 0.35,
    'augmented major tetrachord': 0.50,
    'augmented seventh chord': 0.65,
    'augmented triad': 0.75,
    'chromatic tetramirror': 0.20,
    'chromatic trimirror': 0.10,
    'diminished seventh chord': 0.65,
    'diminished triad': 0.80,
    'dominant seventh chord': 0.85,
    'double tritone tetramirror': 0.30,
    'double-fourth tetramirror': 0.35,
    'enharmonic equivalent to diminished triad': 0.70,
    'enharmonic equivalent to half-diminished seventh chord': 0.60,
    'enharmonic equivalent to major seventh chord': 0.75,
    'enharmonic equivalent to major triad': 0.90,
    'enharmonic equivalent to minor seventh chord': 0.60,
    'enharmonic equivalent to minor triad': 0.80,
    'enharmonic to dominant seventh chord': 0.80,
    'half-diminished seventh chord': 0.75,
    'harmonic minor tetrachord': 0.50,
    'incomplete dominant-seventh chord': 0.55,
    'incomplete half-diminished seventh chord': 0.60,
    'incomplete major-seventh chord': 0.70,
    'incomplete minor-seventh chord': 0.65,
    'lydian tetrachord': 0.55,
    'major seventh chord': 0.85,
    'major third tetracluster': 0.30,
    'major triad': 1.0,
    'major-diminished tetrachord': 0.50,
    'major-minor tetramirror': 0.45,
    'major-minor trichord': 0.55,
    'major-second major tetrachord': 0.50,
    'major-second minor tetrachord': 0.50,
    'major-second tetracluster': 0.40,
    'major-third diminished tetrachord': 0.50,
    'minor seventh chord': 0.60,
    'minor tetramirror': 0.40,
    'minor third tetracluster': 0.45,
    'minor triad': 0.95,
    'minor trichord': 0.50,
    'minor-augmented tetrachord': 0.45,
    'minor-diminished tetrachord': 0.45,
    'minor-second diminished tetrachord': 0.45,
    'minor-second quartal tetrachord': 0.50,
    'note': 0.0,
    'perfect fourth tetramirror': 0.40,
    'perfect-fourth diminished tetrachord': 0.45,
    'perfect-fourth major tetrachord': 0.50,
    'perfect-fourth minor tetrachord': 0.50,
    'phrygian tetrachord': 0.55,
    'phrygian trichord': 0.50,
    'quartal tetramirror': 0.40,
    'quartal trichord': 0.45,
    'tritone quartal tetrachord': 0.35,
    'tritone-fourth': 0.40,
    'whole-tone tetramirror': 0.30,
    'whole-tone trichord': 0.35,
    'unknown': 0.0
}


In [57]:
# Note name -> MIDI note number for the 88 piano keys, A0 (21) through C8 (108),
# in ascending MIDI order. Sharps are used for the black keys and the octave
# number changes at every C (MIDI 60 is 'C4').
piano_notes_dict = {
    ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")[midi % 12]
    + str(midi // 12 - 1): midi
    for midi in range(21, 109)
}

In [58]:
# Reverse lookup of piano_notes_dict: MIDI note number -> note name.
note_piano_dict = dict(zip(piano_notes_dict.values(), piano_notes_dict.keys()))

In [59]:
def post_processing(generated_melody, repeat=150, max_length=8000):
    """Turn a generated melody into a sparse, velocity-scaled piano roll.

    For every input row a random k in 1..4 is drawn (from NumPy's global
    random state) and only the k largest entries are kept; the rest become 0.
    The row is multiplied by 127, truncated to int and held for ``repeat``
    consecutive timesteps.

    Args:
        generated_melody: Iterable of 1-D numpy arrays (e.g. a 2-D array).
            Assumes values lie in 0..1 so that the scaled result is a MIDI
            velocity — TODO confirm against the generator.
        repeat (int): Number of timesteps each row is held for.
        max_length (int): Maximum number of timesteps in the returned roll.

    Returns:
        np.ndarray: Integer piano roll with at most ``max_length`` rows.

    Raises:
        ValueError: If ``generated_melody`` contains no rows (nothing to stack).
    """
    sparse_rows = []
    for row in generated_melody:
        k = np.random.randint(1, 5)  # random value between 1 and 4
        # argpartition on the negated row puts the k largest entries first.
        top_k_indices = np.argpartition(-row, k - 1)[:k]
        sparse_row = np.zeros_like(row)
        sparse_row[top_k_indices] = row[top_k_indices]
        # Scale to the MIDI velocity range and convert to int.
        sparse_rows.append((sparse_row * 127).astype(int))

    # Hold each row for `repeat` timesteps, then clip to the maximum length.
    return np.repeat(np.vstack(sparse_rows), repeat, axis=0)[:max_length]


In [60]:
def average_velocity_range(piano_roll):
    """Average per-timestep velocity range (max - min) of a piano roll.

    The range is computed on the values exactly as given — no rescaling is
    applied — and silent timesteps contribute a range of 0.

    Args:
        piano_roll (array-like): Shape (timesteps, keys) velocity values.

    Returns:
        float: Mean of ``max - min`` over all timesteps, or 0.0 when the roll
        has no timesteps.
    """
    roll = np.asarray(piano_roll)
    if roll.size == 0:
        return 0.0
    # An all-zero row has max == min == 0, so it needs no special case.
    return np.mean(roll.max(axis=1) - roll.min(axis=1))


In [61]:
def chord_consonance_and_dissonance(piano_roll):
    """Average harmony score over all timesteps in which a key sounds.

    For each timestep the sounding keys (value > 0) are mapped to note names
    via ``note_piano_dict`` (key index + 21), turned into a music21 Chord and
    scored by looking the chord's ``commonName`` up in ``harmony_scores``
    (0.0 when the name is missing). Silent timesteps are skipped.

    Args:
        piano_roll: Iterable of per-timestep velocity rows.

    Returns:
        float: Mean score over the sounding timesteps, or 0.0 if there are
        none.
    """
    # Consecutive timesteps usually hold the same notes, so each distinct set
    # of sounding keys is analysed only once per call.
    score_by_keys = {}
    step_scores = []
    for time_step in piano_roll:
        sounding = tuple(np.flatnonzero(np.asarray(time_step) > 0).tolist())
        if not sounding:
            continue
        if sounding not in score_by_keys:
            note_names = [note_piano_dict.get(i + 21) for i in sounding]
            common_name = chord.Chord(note_names).commonName
            score_by_keys[sounding] = harmony_scores.get(common_name, 0.0)
        step_scores.append(score_by_keys[sounding])

    if not step_scores:
        return 0.0
    return sum(step_scores) / len(step_scores)

In [62]:
def deduplicate_timesteps(piano_roll):
    """Collapse runs of identical consecutive timesteps into one timestep.

    Only *adjacent* duplicates are removed; a pattern that reappears later in
    the roll is kept again.

    Args:
        piano_roll (array-like): Timesteps along the first axis.

    Returns:
        np.ndarray: The roll with each run of equal consecutive rows reduced
        to its first row. An empty roll is returned unchanged.
    """
    roll = np.asarray(piano_roll)
    if roll.shape[0] == 0:
        return roll

    # A timestep is kept when it differs from its predecessor in any entry.
    differs = roll[1:] != roll[:-1]
    if differs.ndim > 1:
        differs = differs.any(axis=tuple(range(1, differs.ndim)))
    keep = np.concatenate(([True], differs))  # the first timestep always stays
    return roll[keep]

In [63]:
def rhythmic_measure(deduped_roll, binarize=True):
    """Mean pairwise cosine similarity between distinct timesteps.

    Args:
        deduped_roll (array-like): Shape (timesteps, keys); callers in this
            notebook pass the output of ``deduplicate_timesteps``.
        binarize (bool): When True, values > 0 are mapped to 1 and the rest
            to 0 before the similarity is computed.

    Returns:
        float: Mean of the off-diagonal entries of the cosine-similarity
        matrix, or NaN when there are fewer than two timesteps (no pairs to
        compare).
    """
    roll = np.asarray(deduped_roll)
    # With fewer than two rows there is no off-diagonal entry; return NaN
    # directly instead of taking the mean of an empty selection, which only
    # produces the same NaN plus NumPy RuntimeWarnings.
    if roll.shape[0] < 2:
        return np.nan

    input_roll = (roll > 0).astype(int) if binarize else roll
    sim_matrix = cosine_similarity(input_roll)
    # Exclude the diagonal (each timestep compared with itself).
    off_diagonal = ~np.eye(sim_matrix.shape[0], dtype=bool)
    return sim_matrix[off_diagonal].mean()

In [64]:
def pitch_entropy(piano_roll):
    """Shannon entropy (in bits) of how often each pitch is active.

    Args:
        piano_roll (np.ndarray): Shape (timesteps, keys); an entry counts as
            active when it is > 0.

    Returns:
        Entropy (base 2) of the per-key activation frequencies, or 0 when no
        key is ever active.
    """
    # Number of timesteps in which each key sounds.
    per_key_counts = np.count_nonzero(piano_roll > 0, axis=0)
    n_activations = per_key_counts.sum()
    if n_activations != 0:
        return entropy(per_key_counts / n_activations, base=2)
    return 0

In [65]:
def extreme_pitch_density(piano_roll, low_extreme_end=13, high_extreme_start=72):
    """Fraction of active notes that lie in the extreme low/high key ranges.

    Key indices are column positions of the roll. With the index -> MIDI
    mapping used elsewhere in this notebook (MIDI note = index + 21) the
    defaults correspond to index 13 = MIDI 34 (A#1) and index 72 = MIDI 93
    (A6).
    NOTE(review): the earlier documentation labelled these defaults "B1" and
    "C6", which does not match that mapping — confirm which range is intended.

    Args:
        piano_roll (array-like): Shape (timesteps, keys), binary or velocity
            values; an entry counts as a note when it is > 0.
        low_extreme_end (int): Last key index of the low extreme range
            (inclusive).
        high_extreme_start (int): First key index of the high extreme range
            (inclusive).

    Returns:
        Proportion of active entries whose key index is <= ``low_extreme_end``
        or >= ``high_extreme_start``; 0 when no entry is active.
    """
    active = np.asarray(piano_roll) > 0
    total_notes = int(active.sum())
    if total_notes == 0:
        return 0

    # Select the extreme columns once; combining both conditions in a single
    # mask counts a key only once even if the two ranges overlap.
    key_index = np.arange(active.shape[-1])
    is_extreme = (key_index <= low_extreme_end) | (key_index >= high_extreme_start)
    extreme_notes = int(active[..., is_extreme].sum())
    return extreme_notes / total_notes

In [84]:
# Names of the evaluation metrics. They are used as the keys of the
# `all_metrics` result dictionary, and each one matches the name of the
# function in this notebook that computes it.
metric_names = [
    "chord_consonance_and_dissonance",
    "rhythmic_measure",
    "pitch_entropy",
    "extreme_pitch_density",
    "average_velocity_range"
]

In [89]:
def get_valid_midi_files(midi_directory, num_files=500):
    all_files = [f for f in os.listdir(midi_directory) if f.endswith('.mid') or f.endswith('.midi')]
    random.shuffle(all_files)
    valid_midi_files = []

    for midi_file in all_files:
        if len(valid_midi_files) >= num_files:
            break
        midi_path = os.path.join(midi_directory, midi_file)
        try:
            mid = mido.MidiFile(midi_path)
            arr = mid2arry(mid)

            if arr.size == 0:
                continue

        except Exception:
            continue

        if len(arr) > min_timesteps:
            valid_midi_files.append(arr)

    return valid_midi_files

In [90]:
# Number of dataset files requested for the evaluation.
num_eval_files = 5000
valid_midi_files = get_valid_midi_files(midi_directory, num_files=num_eval_files)

if len(valid_midi_files) < num_eval_files:
    print(f"Warning: Only found {len(valid_midi_files)} valid MIDI files.")

# One list of per-file values for every metric.
all_metrics = {name: [] for name in metric_names}

# Evaluate all files (CPU-only). The progress label reports the number of
# files that were actually loaded, which can be lower than requested.
for piano_roll in tqdm(valid_midi_files, desc=f"Evaluating {len(valid_midi_files)} MIDI files"):
    deduped = deduplicate_timesteps(piano_roll)
    # A roll with a single distinct timestep has no pairs for the rhythmic
    # measure; skip the file for all metrics so the lists stay aligned.
    if deduped.shape[0] <= 1:
        continue
    all_metrics["average_velocity_range"].append(average_velocity_range(piano_roll))
    all_metrics["pitch_entropy"].append(pitch_entropy(piano_roll))
    all_metrics["extreme_pitch_density"].append(extreme_pitch_density(piano_roll))
    all_metrics["rhythmic_measure"].append(rhythmic_measure(deduped))
    all_metrics["chord_consonance_and_dissonance"].append(chord_consonance_and_dissonance(piano_roll))

  return sim_matrix[mask].mean()
  ret = ret.dtype.type(ret / rcount)
  return sim_matrix[mask].mean()
  ret = ret.dtype.type(ret / rcount)
Evaluating 5000 MIDI files: 100%|████████████████████████████████████████████████| 5000/5000 [1:11:18<00:00,  1.17it/s]


In [100]:
# Summarise every metric as mean and standard deviation (rounded to four
# decimals) over its per-file values, ignoring NaN entries.
final_metrics = {}
for name, values in all_metrics.items():
    clean_values = list(filter(lambda v: not np.isnan(v), values))
    final_metrics[name] = dict(
        mean=round(np.mean(clean_values), 4),
        std=round(np.std(clean_values), 4),
    )

print("\nFinal Evaluation Metrics (from 5000 samples):")
for name in final_metrics:
    stats = final_metrics[name]
    print(f"{name}: {stats['mean']} ± {stats['std']}")


Final Evaluation Metrics (from 5000 samples):
chord_consonance_and_dissonance: 0.3628 ± 0.2244
rhythmic_measure: 0.1648 ± 0.0899
pitch_entropy: 3.6068 ± 0.6853
extreme_pitch_density: 0.0362 ± 0.0755
average_velocity_range: 89.2089 ± 20.1188
