# Musical Pattern Recognition

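**Goal:** Identify repeating patterns (bars, phrases, sections) in an instrument track by splitting its piano roll into fixed-size windows and grouping windows that are highly similar.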

In [None]:
# Loads the autoreload extension
%load_ext autoreload
# Automatically reloads all modules before executing any code
%autoreload 2

In [None]:
# External Imports
import matplotlib.pyplot as plt
import torch as torch
import numpy as np
import pypianoroll as pr

# Internal Imports
import sys, os
sys.path.append(os.path.abspath('src'))

from src.util.types import Song, PianoState, NoteSample, PianoStateSamples
from src.util.globals import resolution, beats_per_bar, num_pitches
from src.util.process_audio import quantize_pianoroll
import src.util.plot as plot


from src.dataset.load import (
    load_multi_track,
    get_track_by_instrument,
)

In [None]:
# Testing sample: MSD track id (directory) and song id (npz file name)
dir_id = 'TRAAAGR128F425B14B'
song_id = 'b97c529ab9ef783a849b896816001748'

desired_instrument = 'Bass'

# Files are nested under the 3rd, 4th and 5th characters of the track id
# (Lakh Pianoroll Dataset layout), e.g. TRAAAGR... -> A/A/A/TRAAAGR.../<song_id>.npz.
# Building the path from the ids keeps them the single source of truth.
npz_path = f'{dir_id[2]}/{dir_id[3]}/{dir_id[4]}/{dir_id}/{song_id}.npz'

# Load the NPZ file into a Multitrack object and pick the desired instrument.
multi_track = load_multi_track(npz_path)
bass_track = get_track_by_instrument(multi_track, desired_instrument)

# Binarize note velocity (active / inactive); the print below verifies the
# resulting integer piano roll only holds the expected values.
binary_track = bass_track.binarize()

track_pr = binary_track.pianoroll.astype(int)

print('unique values =', np.unique(track_pr))

# Quantize! (grid of resolution // 2 — see quantize_pianoroll for exact semantics)
quantized_pr = quantize_pianoroll(track_pr, resolution//2)

plot.plot_pianoroll(quantized_pr, tick_resolution=resolution*8)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.colors as mcolors
from matplotlib.ticker import MultipleLocator


def window_similarity(pianoroll: np.ndarray, window_size: int):
	"""
	Split a piano roll into non-overlapping windows and compare them pairwise.

	The time axis is cut into consecutive windows of ``window_size`` rows;
	trailing rows that do not fill a whole window are dropped. Each window is
	flattened to a single vector and the cosine similarity between every pair
	of windows is computed.

	Args:
		pianoroll: 2D array of shape (time, pitch).
		window_size: number of time steps (rows) per window; must be > 0.

	Returns:
		tuple:
			1. Number of windows ``n`` (``pianoroll.shape[0] // window_size``).
			2. ``(n, n)`` cosine-similarity matrix between the windows.
			   NOTE: a window with no active notes is an all-zero vector, so
			   its similarity to every window (itself included) is 0.

	Raises:
		ValueError: if ``pianoroll`` is not 2D, ``window_size`` is not
			positive, or the piano roll is shorter than one window.
	"""
	if pianoroll.ndim != 2:
		raise ValueError(f'pianoroll must be 2D (time, pitch), got shape {pianoroll.shape}')
	if window_size <= 0:
		raise ValueError(f'window_size must be positive, got {window_size}')

	n_windows = pianoroll.shape[0] // window_size
	if n_windows == 0:
		raise ValueError(
			f'pianoroll has {pianoroll.shape[0]} time steps, '
			f'fewer than one window of {window_size}'
		)

	# Drop the incomplete tail, then turn each window into one flat row:
	# (n_windows * window_size, pitch) -> (n_windows, window_size * pitch).
	# Row-major reshaping keeps each window's time steps contiguous, so row k
	# is exactly pianoroll[k*window_size:(k+1)*window_size].ravel().
	windows_flat = pianoroll[:n_windows * window_size].reshape(n_windows, -1)

	similarity_matrix = cosine_similarity(windows_flat)

	return n_windows, similarity_matrix

def plot_sim_matrix(SM):
	"""
	Display a window-by-window similarity matrix as a heat map.

	Args:
		SM: 2D array of pairwise similarities, e.g. the matrix returned by
			``window_similarity``; entry (i, j) compares window i with window j.
	"""
	plt.figure(figsize=(10, 8))
	plt.imshow(SM, cmap='hot', interpolation='nearest')
	plt.colorbar(label='Cosine Similarity')

	plt.title('Self-Similarity Matrix')
	# Rows and columns both index windows, so both axes get the same label.
	for set_axis_label in (plt.xlabel, plt.ylabel):
		set_axis_label('Window Index')

	plt.show()

def find_similair_pairs(sm: np.ndarray, threshold=0.99) -> list[tuple[int,int]]:
	'''
	List window pairs ``(i, j)`` with ``i < j`` whose similarity exceeds ``threshold``.

	Only the strict upper triangle of ``sm`` is inspected, so the diagonal
	(self-similarity) and mirrored duplicates are skipped.

	Args:
		sm: square similarity matrix.
		threshold: a pair qualifies when ``sm[i, j] > threshold`` (strict).

	Returns:
		Pairs of Python ints in row-major order (by ``i``, then ``j``).
	'''
	size = len(sm)
	if size < 2:
		# No off-diagonal entries to compare.
		return []

	above = np.asarray(sm)[:size, :size] > threshold
	# k=1 clears the diagonal and everything below it.
	upper = np.triu(above, k=1)
	# argwhere yields row-major coordinates; tolist() converts to Python ints.
	return [(row, col) for row, col in np.argwhere(upper).tolist()]

def create_groups(sim_pairs: list[tuple[int,int]], n_windows: int):
	'''
	Merge pairwise "similar" links into groups of connected windows.

	Grouping is transitive: if (a, b) and (b, c) are both pairs, a, b and c
	end up in the same group (connected components of the pair graph).

	Args:
		sim_pairs: (a, b) window-index pairs, e.g. from ``find_similair_pairs``.
			Every index must be < ``n_windows``.
		n_windows: total number of windows.

	Returns:

		tuple:

		1. List of each window's group number (n_windows,). The number is an
		   index into the second element; -1 means the window is in no pair.
		2. List of groups, each a list of window indexes in the order they
		   were linked (n_groups, ?).
	'''
	groups: list[list[int]] = []
	window_groups = [-1] * n_windows  # group id of each window, -1 = unassigned

	for a, b in sim_pairs:
		if window_groups[a] == -1 and window_groups[b] == -1:
			# Both items are unassigned - create new group
			group_id = len(groups)
			window_groups[a] = group_id
			window_groups[b] = group_id
			groups.append([a, b])

		elif window_groups[a] == -1 and window_groups[b] != -1:
			# Add a to b's group
			groups[window_groups[b]].append(a)
			window_groups[a] = window_groups[b]

		elif window_groups[a] != -1 and window_groups[b] == -1:
			# Add b to a's group
			groups[window_groups[a]].append(b)
			window_groups[b] = window_groups[a]

		elif window_groups[a] != window_groups[b]:
			# Both belong to different groups - merge b's group into a's
			group_a = window_groups[a]
			group_b = window_groups[b]

			groups[group_a].extend(groups[group_b])
			for item in groups[group_b]:
				window_groups[item] = group_a

			# Leave an empty slot so existing group ids stay valid during the loop
			groups[group_b] = []

		# If window_groups[a] == window_groups[b] and both != -1,
		# they're already in the same group - do nothing

	# Drop the empty slots left by merges AND renumber the per-window ids so
	# they index into the filtered list (otherwise every id after an emptied
	# slot would point past the group it belongs to).
	new_id: dict[int, int] = {}
	final_groups: list[list[int]] = []
	for old_id, members in enumerate(groups):
		if members:
			new_id[old_id] = len(final_groups)
			final_groups.append(members)
	window_groups = [new_id.get(group_id, -1) for group_id in window_groups]

	return window_groups, final_groups

def plot_window_groups(
	pianoroll: np.ndarray,
	desired_instrument: str,
	window_resolution: int,
	groups: 'list | None' = None
):
	'''
	Plot a piano roll with a shaded background colour for each group of windows.

	Args:
		pianoroll: 2D array (time, pitch); drawn transposed so time runs along x.
		desired_instrument: instrument name, used only in the plot title.
		window_resolution: number of time steps per window. Major x ticks sit on
			window boundaries and are labelled with the window index.
		groups: optional list of groups, each a list of window indexes (e.g. the
			second value returned by ``create_groups``). Windows in no group are
			left unshaded. ``None`` (default) plots without groups.
	'''
	from matplotlib.ticker import FuncFormatter

	if groups is None:
		groups = []

	n = len(groups)
	num_windows = pianoroll.shape[0] // window_resolution

	# Plot the track: time on x, pitch on y.
	_, ax = plt.subplots(figsize=(12, 6))
	ax.imshow(pianoroll.T,
			  aspect='auto',
			  origin='lower',
			  cmap='binary',
			  interpolation='none')

	# One major tick (and grid line) per window boundary. A formatter keeps the
	# labels in sync with the locator; assigning fixed labels to a non-fixed
	# locator makes matplotlib warn and the labels go stale when zooming.
	ax.xaxis.set_major_locator(MultipleLocator(window_resolution))
	ax.xaxis.set_major_formatter(
		FuncFormatter(lambda x, _pos: f'{int(x / window_resolution)}')
	)
	ax.grid(visible=True, axis='x')

	# Plot labels
	title = f'{desired_instrument} Track - {num_windows} windows of size {window_resolution}, {n} groups'
	ax.set_title(title)
	ax.set_xlabel(f'Windows (={window_resolution} time steps)')

	if n:
		# One evenly spaced hue per group
		colors = [mcolors.to_hex(mcolors.hsv_to_rgb([i / n, 0.8, 0.9])) for i in range(n)]

		# Map each window to its group's colour (None = not in any group)
		window_to_color = [None] * num_windows
		for color, group in zip(colors, groups):
			for window_idx in group:
				window_to_color[window_idx] = color

		# Shade the time span of every grouped window behind the notes
		for window_idx, color in enumerate(window_to_color):
			if color is not None:
				ax.axvspan(
					xmin=window_idx * window_resolution,
					xmax=(window_idx + 1) * window_resolution,
					color=color,
					alpha=0.3,
					zorder=0
				)

		legend_elements = [
			plt.Line2D([0], [0], color=color, lw=4,
					   label=f'Group {group_idx} ({len(group)} windows)')
			for group_idx, (color, group) in enumerate(zip(colors, groups))
		]
		ax.legend(handles=legend_elements, loc='upper right')

	plt.tight_layout()
	plt.show()

In [None]:
# Window length in piano-roll time steps. Assuming `resolution` is time steps
# per beat (as documented in src.util.globals — TODO confirm it still holds
# after quantize_pianoroll), this is 32 beats per window.
window_resolution = 32 * resolution
n_windows, sim_matrix = window_similarity(
	quantized_pr,
	window_resolution
)

plot_sim_matrix(sim_matrix)

# Link windows whose cosine similarity exceeds 0.95, then merge the links
# into groups of repeating windows.
sim_pairs = find_similair_pairs(sim_matrix, threshold=0.95)
_, groups = create_groups(sim_pairs, n_windows)

print(f'Got {n_windows} windows @ resolution={window_resolution} time steps')
print(f'Found {len(groups)} unique groups:')
print(*groups, sep='\n')

In [None]:
# Shade each group of similar windows on top of the quantized track
plot_window_groups(
	pianoroll=quantized_pr,
	desired_instrument=desired_instrument,
	window_resolution=window_resolution,
	groups=groups,
)