In [None]:
import os
# If the notebook is launched from inside the "tests" directory, move up one
# level so relative paths used below (e.g. "data/Chords1217/...") resolve from
# the parent directory. Compare the basename exactly: a plain
# endswith("tests") check would also fire for directories like "unittests".
pwd = os.getcwd()
if os.path.basename(pwd) == "tests":
    os.chdir(os.path.dirname(pwd))

import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
from marble.utils.utils import id2chord_str

def _merge_label_segments(label_seq, label_freq):
    """Collapse runs of identical frame labels into ``(start_s, end_s, label)`` tuples.

    Frame ``i`` is taken to cover ``[i / label_freq, (i + 1) / label_freq)``, so a
    run of frames ``s..e-1`` becomes the segment ``(s / label_freq, e / label_freq)``.
    Assumes ``label_seq`` is a 1-D numpy array of per-frame labels (TODO confirm
    the dataset never returns multi-dimensional targets). Returns ``[]`` for an
    empty sequence.
    """
    num_frames = len(label_seq)
    if num_frames == 0:
        return []
    # Indices where a frame's label differs from the previous frame mark run starts.
    change_points = np.flatnonzero(label_seq[1:] != label_seq[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_ends = np.concatenate((change_points, [num_frames]))
    return [
        (start / label_freq, end / label_freq, label_seq[start])
        for start, end in zip(run_starts, run_ends)
    ]


def display_sample_audio_and_labels(dataset: torch.utils.data.Dataset, sample_idx: int):
    """
    Plot an audio sample's waveform with its chord labels overlaid, and return
    an inline audio player for it (for use in Jupyter Notebooks).

    ``dataset[sample_idx]`` is unpacked as ``(waveform, targets, audio_path)``.
    Only the first channel of ``waveform`` is plotted. ``targets`` is treated as
    a 1-D sequence of chord ids sampled at ``dataset.label_freq`` frames per
    second (assumption, confirm against the dataset). Consecutive frames with
    the same id are merged into one labelled segment.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to extract the sample
            from. Must expose ``sample_rate`` and ``label_freq`` attributes.
        sample_idx (int): The index of the sample to visualize.

    Returns:
        IPython.display.Audio: Player for the squeezed waveform at
        ``dataset.sample_rate``. It renders automatically when it is the last
        expression of a notebook cell.
    """
    waveform, targets, audio_path = dataset[sample_idx]
    sample_rate = dataset.sample_rate
    label_freq = dataset.label_freq

    fig, ax = plt.subplots(figsize=(12, 6))

    # Plot the first channel with reduced opacity so the labels stay readable.
    # Sample k sits at k / sample_rate. linspace(..., endpoint=True) would
    # stretch the axis by one sample period.
    channel0 = waveform[0].detach().cpu().numpy()
    ax.plot(np.arange(len(channel0)) / sample_rate, channel0, color='gray', alpha=0.5)
    ax.set_title(f"Waveform of {audio_path}")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")

    label_seq = targets.detach().cpu().numpy()
    segments = _merge_label_segments(label_seq, label_freq)

    # x in data coordinates (seconds), y as a fraction of the axes height, so
    # labels remain inside the plot regardless of the waveform's amplitude.
    label_transform = ax.get_xaxis_transform()
    for start_time, end_time, chord_label in segments:
        ax.text(
            (start_time + end_time) / 2, 0.9, id2chord_str(chord_label),
            color="blue", fontsize=12, ha='center', va='center',
            transform=label_transform,
        )
        # Adjacent segments share a boundary, so draw only each segment's start
        # here and the final end once after the loop.
        ax.axvline(x=start_time, color='black', linestyle='-', linewidth=1)
    if segments:
        ax.axvline(x=segments[-1][1], color='black', linestyle='-', linewidth=1)

    # Tight layout to avoid label clipping
    fig.tight_layout()
    plt.show()

    # Playback the audio in the notebook. squeeze() drops the channel dim if mono.
    audio_data = waveform.squeeze().detach().cpu().numpy()
    return Audio(audio_data, rate=sample_rate)

from marble.tasks.Chords1217.datamodule import Chords1217AudioTrain

# Example usage: build the Chords1217 training split (mono, 44.1 kHz, 5 s clips,
# labels at 10 frames per second) and inspect a single clip.
dataset = Chords1217AudioTrain(
    sample_rate=44100,
    channels=1,
    clip_seconds=5.0,
    jsonl="data/Chords1217/Chords1217.train.jsonl",
    label_freq=10,
)
sample_idx = 100  # any valid index into `dataset`
# Kept as the cell's last expression so the returned Audio player renders.
display_sample_audio_and_labels(dataset, sample_idx)
