In [1]:
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio
import os

# If the notebook is started from a folder named "tests", move up one level so
# relative paths used below (e.g. "data/GTZAN/...") resolve from the parent folder.
pwd = os.getcwd()
# Compare the final path component exactly: a bare endswith("tests") would also
# match unrelated folders such as "unittests".
if os.path.basename(pwd) == "tests":
    os.chdir(os.path.dirname(pwd))


def _square_pulse(freq, n_samples, fs, amplitude):
    """Return ``n_samples`` of a +/-``amplitude`` square wave at ``freq`` Hz sampled at ``fs`` Hz (float32)."""
    # Sample at exact multiples of 1/fs so the pulse pitch matches ``freq`` even
    # after the pulse length has been rounded to an integer number of samples.
    t = np.arange(n_samples) / fs
    return (amplitude * np.sign(np.sin(2 * np.pi * freq * t))).astype(np.float32)


def _overlay_pulses(signal, onset_times, pulse, fs):
    """Add ``pulse`` into ``signal`` in place at each onset time (seconds); pulses are truncated at the signal end."""
    n_signal = signal.shape[0]
    for onset in onset_times:
        start = int(round(onset * fs))
        if start < 0 or start >= n_signal:
            continue  # onset falls outside the clip
        end = min(start + pulse.shape[0], n_signal)
        signal[start:end] += pulse[: end - start]
    return signal


def visualize_and_play_with_onset_tone(
    dataset,
    beat_freq=440,
    downbeat_freq=880,
    tone_duration=0.05,
    tone_amplitude=0.3,
    idx=None,
):
    """
    Overlay audible square-wave clicks on a dataset clip at its labelled beat/downbeat
    onsets, plot the result, and return an Audio widget for listening.

    One clip is drawn from ``dataset`` (uniformly at random unless ``idx`` is given).
    Its first channel is mixed with short square-wave pulses: ``beat_freq`` at every
    nonzero frame of the beat mask and ``downbeat_freq`` at every nonzero frame of the
    downbeat mask (only if the labels contain a "downbeat" entry). A three-panel figure
    is shown: original vs. combined waveform with onset lines, the masks as stem plots,
    and the linear magnitude spectrum of the combined signal.

    Args:
        dataset: a GTZANBeatTracking-style dataset. It must support ``len()``, and
            ``dataset[i]`` must return ``(waveform, label_dict, audio_path)`` where
            ``waveform`` and the values of ``label_dict`` are torch tensors. The
            attributes ``sample_rate``, ``clip_seconds`` and ``label_freq`` are read.
        beat_freq: frequency (Hz) of the square-wave pulse for beats (default 440 Hz).
        downbeat_freq: frequency (Hz) of the square-wave pulse for downbeats (default 880 Hz).
        tone_duration: duration (seconds) of each pulse (default 0.05 s); at least one
            sample is always used.
        tone_amplitude: peak amplitude of each pulse (0.0 < amplitude <= 1.0).
        idx: index of the clip to show. ``None`` (default) picks one at random;
            pass an int for a reproducible view.

    Returns:
        IPython.display.Audio: player for the combined signal (clipped to [-1, 1])
        at ``dataset.sample_rate``.
    """
    # 1. Pick a clip (random unless the caller pins one).
    if idx is None:
        idx = random.randrange(len(dataset))
    waveform, label_dict, audio_path = dataset[idx]

    # 2. Work on the first channel only.
    wav_np = waveform.cpu().numpy()
    if wav_np.ndim > 1:
        wav_np = wav_np[0, :]
    clip_len = wav_np.shape[0]

    fs = dataset.sample_rate
    clip_seconds = dataset.clip_seconds

    # Time axis from the actual sample count, so it stays aligned with the onset
    # sample indices round(t * fs) even if the clip is not exactly clip_seconds long.
    t_wav = np.arange(clip_len) / fs

    # 3. Onset masks -> onset times. Mask frame i sits at time i / label_freq.
    beat_mask = label_dict["beat"].cpu().numpy()
    db_mask = label_dict.get("downbeat")
    if db_mask is not None:
        db_mask = db_mask.cpu().numpy()

    t_mask = np.arange(beat_mask.shape[0]) / dataset.label_freq

    # NOTE(review): every nonzero frame counts as an onset. If the dataset widens
    # labels onto neighbouring frames (cf. num_neighbors), one beat may produce
    # several overlapping pulses -- confirm against the datamodule.
    beat_times = t_mask[beat_mask.astype(bool)]
    if db_mask is not None:
        db_times = t_mask[db_mask.astype(bool)]
    else:
        db_times = np.array([], dtype=float)

    # 4. One square-wave pulse per onset type.
    tone_len = max(1, int(round(tone_duration * fs)))
    beat_pulse = _square_pulse(beat_freq, tone_len, fs, tone_amplitude)
    db_pulse = _square_pulse(downbeat_freq, tone_len, fs, tone_amplitude)

    # 5. Mix the pulses into a float32 copy of the waveform.
    combined = wav_np.astype(np.float32)  # astype returns a new array
    _overlay_pulses(combined, beat_times, beat_pulse, fs)
    _overlay_pulses(combined, db_times, db_pulse, fs)

    # 6. Keep samples within [-1, 1]; samples pushed past that range by the
    # overlay are hard-clipped (which itself adds some distortion).
    combined = np.clip(combined, -1.0, +1.0)

    # 7. Figure: waveforms + onset lines, masks as stems, magnitude spectrum.
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 10), constrained_layout=True)

    ax1 = axes[0]
    ax1.plot(t_wav, wav_np, color='gray', linewidth=1, label='Original Waveform')
    ax1.plot(t_wav, combined, color='orange', linewidth=1, alpha=0.6, label='Combined Waveform')
    ax1.set_xlim(0, clip_seconds)
    ax1.set_ylabel("Amplitude")
    ax1.set_title(f"Time-Domain Waveform with Onset Pulses\n{audio_path} (clip idx={idx})")

    # Only the first line of each kind gets a legend entry; labels starting with
    # an underscore are skipped by matplotlib's legend.
    for i, bt in enumerate(beat_times):
        ax1.axvline(bt, color='r', linestyle='--', alpha=0.7,
                    label='Beat Onset' if i == 0 else '_nolegend_')
    for i, dt in enumerate(db_times):
        ax1.axvline(dt, color='b', linestyle='-.', alpha=0.8,
                    label='Downbeat Onset' if i == 0 else '_nolegend_')
    ax1.legend(loc='upper right')

    ax2 = axes[1]
    markerline1, _, _ = ax2.stem(
        t_mask, beat_mask, linefmt='r-', markerfmt='ro', basefmt='k-', label='Beat Mask'
    )
    plt.setp(markerline1, markersize=4)
    if db_mask is not None:
        markerline2, _, _ = ax2.stem(
            t_mask, db_mask, linefmt='b-', markerfmt='bs', basefmt='k-', label='Downbeat Mask'
        )
        plt.setp(markerline2, markersize=4)
    ax2.set_xlim(0, clip_seconds)
    ax2.set_ylim(-0.1, 1.1)
    ax2.set_xlabel("Time (s)")
    ax2.set_ylabel("Mask Value")
    ax2.set_title("Beat & Downbeat Onset Masks")
    ax2.legend(loc='upper right')

    ax3 = axes[2]
    magnitude = np.abs(np.fft.rfft(combined))
    fft_freq = np.fft.rfftfreq(clip_len, d=1.0 / fs)
    ax3.plot(fft_freq, magnitude, color='purple', linewidth=0.8)  # linear scale
    ax3.set_xlim(0, fs / 2)
    ax3.set_xlabel("Frequency (Hz)")
    ax3.set_ylabel("Magnitude")
    ax3.set_title("Frequency-Domain (Magnitude Spectrum) of Combined Signal")

    plt.show()

    # 8. Audio widget for playback of the combined signal.
    print("▶️ Playing the combined audio with onset-synchronized square-wave pulses:")
    return Audio(combined, rate=fs)



In [None]:
from marble.tasks.GTZANBeatTracking.datamodule import GTZANBeatTrackingAudioTrain

# Dataset built from the GTZAN beat-tracking val jsonl
# (edit the path to match your local data layout).
dataset = GTZANBeatTrackingAudioTrain(
    jsonl="data/GTZAN/GTZANBeatTracking.val.jsonl",
    sample_rate=22050,
    channels=1,
    channel_mode="mix",
    clip_seconds=10.0,
    min_clip_ratio=0.5,
    label_freq=50,
    num_neighbors=2,
)

# Overlay 440 Hz beat clicks and 880 Hz downbeat clicks (50 ms, amplitude 0.3)
# on a randomly chosen clip. The returned Audio object renders as a player
# because it is the last expression of the cell.
audio_widget = visualize_and_play_with_onset_tone(
    dataset,
    beat_freq=440,
    downbeat_freq=880,
    tone_duration=0.05,
    tone_amplitude=0.3,
)
audio_widget

In [14]:
from marble.tasks.GTZANBeatTracking.datamodule import GTZANBeatTrackingAudioTrain

# Same val-jsonl dataset configuration as the visualization cell, but with
# label_freq=75 instead of 50 (adjust the jsonl path to your data layout).
dataset = GTZANBeatTrackingAudioTrain(
    jsonl="data/GTZAN/GTZANBeatTracking.val.jsonl",
    sample_rate=22050,
    channels=1,
    channel_mode="mix",
    clip_seconds=10.0,
    min_clip_ratio=0.5,
    label_freq=75,
    num_neighbors=2,
)

In [15]:
# Positive-class loss weights (negative / positive ratio) for the beat and
# downbeat targets, accumulated over every item of `dataset`.
# Positives are the summed mask values (torch.sum); if the masks are not strictly
# 0/1 this is total label mass rather than a frame count.
total_num_pos_beat = 0
total_num_neg_beat = 0
total_num_pos_downbeat = 0
total_num_neg_downbeat = 0

for idx in range(len(dataset)):
    # Index the dataset once per item: each dataset[idx] call presumably loads
    # audio, and one fetch guarantees beat and downbeat come from the same sample.
    labels = dataset[idx][1]
    beat = labels['beat']
    downbeat = labels['downbeat']

    num_pos_beat = torch.sum(beat).item()
    num_pos_downbeat = torch.sum(downbeat).item()
    # numel() matches torch.sum, which reduces over every element of the mask.
    total_num_pos_beat += num_pos_beat
    total_num_neg_beat += beat.numel() - num_pos_beat
    total_num_pos_downbeat += num_pos_downbeat
    total_num_neg_downbeat += downbeat.numel() - num_pos_downbeat

assert total_num_pos_beat > 0, "no positive beat frames found; cannot compute beat loss weight"
assert total_num_pos_downbeat > 0, "no positive downbeat frames found; cannot compute downbeat loss weight"

loss_weight_beat = total_num_neg_beat / total_num_pos_beat
loss_weight_downbeat = total_num_neg_downbeat / total_num_pos_downbeat
print(f"pos loss weight for beat: {loss_weight_beat:.2f}")
print(f"pos loss weight for downbeat: {loss_weight_downbeat:.2f}")

pos loss weight for beat: 13.78
pos loss weight for downbeat: 57.40
