## Imports

In [29]:
import os
import pickle
import random
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import yaml
from IPython.display import Audio
from torchinfo import summary
from tqdm import tqdm

from data import CANONICAL_DRUM_MAP, SampleData, DrumMIDIFeature, DrumMIDIDataset
from models import GrooveIQ

## Setup

In [30]:
# ======= Set Experiment Path Here =======
EXPT_PATH = "expts/giq_exp5_heur_causal"
DATASET_PATH = "dataset/serialized/merged_ts=4-4_2bar_tr0.80-va0.10-te0.10_test.pkl"
CHECKPOINT_PATH = os.path.join(EXPT_PATH, "checkpoints", "checkpoint-ep1-model.pth")

# ======= Reproducibility =======
# Later cells pick random test samples and styles. Seeding once here makes a
# "Restart Kernel -> Run All" pass repeatable; change SEED to explore other draws.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ======= Config =======
config_path = os.path.join(EXPT_PATH, "config.yaml")
with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# ======= Audio Save Directory =======
audio_save_dir = os.path.join(EXPT_PATH, "_renders")
os.makedirs(audio_save_dir, exist_ok=True)

# ======= Parameters =======
# Mapping for button sequence: each canonical drum pitch -> [its column index]
fixed_grid_drum_mapping = {pitch: [i] for i, pitch in enumerate(CANONICAL_DRUM_MAP.keys())}
MAX_LENGTH = 33
E = 9 # Number of drum instruments
# NOTE(review): this was commented as "number of steps per quarter", but
# inference() puts (hit, velocity, offset) on the axis it labels M, so this
# looks like "features per step" instead - confirm against the model code.
M = 3

# ======= Model =======
# The experiment's model section, overridden with the grid dimensions above.
model_config = config["model"]
model_config.update(
    T=MAX_LENGTH,
    E=E,
    M=M
)

# ======= Load Model =======
model = GrooveIQ(**model_config)
# Dummy (batch, T, E, M) input shape for the summary, derived from the constants
# above so it cannot drift from the model configuration.
input_size = [(4, MAX_LENGTH, E, M)]
print(f"Loading model from {CHECKPOINT_PATH}: ", end="")
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
# strict=True: fail loudly if the checkpoint and the model definition disagree.
print(model.load_state_dict(checkpoint['model_state_dict'], strict=True))
summary(model, input_size=input_size, device = device)

# ======= Load Dataset =======
test_dataset = DrumMIDIDataset(
    path     = DATASET_PATH,
    num_bars = config["data"]["num_bars"],
    feature_type = config["data"]["feature_type"],
    steps_per_quarter = config["data"]["steps_per_quarter"],
    subset   = 1.0,
    aug_config = config["data"]["aug_config"],
    calc_desc = config["data"]["calc_desc"]
)

Using device: cpu
Loading model from expts/giq_exp5_heur_causal/checkpoints/checkpoint-ep1-model.pth: <All keys matched successfully>
Loading dataset from: dataset/serialized/merged_ts=4-4_2bar_tr0.80-va0.10-te0.10_test.pkl...
Processing 38370 samples...


Accumulating:: 100%|██████████| 38370/38370 [00:06<00:00, 6091.79sample/s]


Skipped 0 samples due to errors.
Loaded and processed 38370 samples.



In [31]:
def inference(model : GrooveIQ, sample : SampleData, button_hits : torch.Tensor, grid : torch.Tensor, device : str = "cpu", threshold : float = 0.85):
    """
    Generate a drum sequence with a GrooveIQ model and build the matching button feature.

    Args:
        model (GrooveIQ): Model to run. It is moved to `device` and put in eval mode.
        sample (SampleData): Sample whose `from_fixed_grid` is used to wrap the generated grid.
        button_hits (Tensor or None): Button hits used to create button_embed.
                              Shape: (1, T, num_buttons).
                              If None, they are derived from `grid` via
                              `model.encode` + `model.make_button_hits`.
        grid (Tensor or None): Grid to condition on. Shape: (1, T, E, M).
                       If given, z comes from `model.make_z_post`;
                       if None, z comes from `model.make_z_prior`.
        device (str, optional): Device the model and the inputs are moved to. Defaults to "cpu".
        threshold (float, optional): Threshold for hit probability, passed to `model.generate`.
                                     Defaults to 0.85.

    Returns:
        DrumMIDIFeature: Generated feature
        DrumMIDIFeature: Button feature

    Raises:
        ValueError: If both `button_hits` and `grid` are None.
    """
    model.to(device)
    model.eval()
    with torch.no_grad():
        # The model was just moved to `device`; its inputs have to follow,
        # otherwise any non-CPU device fails with a device mismatch.
        if grid is not None:
            grid = grid.to(device)
        if button_hits is not None:
            button_hits = button_hits.to(device)

        # Encode grid if provided
        encoded, button_repr = None, None
        if grid is not None:
            encoded, button_repr = model.encode(grid)

        # Make button_embed from either user-provided button_hits or learned button_repr.
        # If both are available, button_hits is used.
        if button_hits is None:
            if button_repr is None:
                raise ValueError("button_repr is None. Either provide button_hits or grid.")
            button_hits = model.make_button_hits(button_repr)

        button_embed = model.make_button_embed(button_hits)
        if encoded is None:
            z, _, _ = model.make_z_prior(button_embed)
        else:
            z, _, _ = model.make_z_post(button_embed, encoded)

        generated_grid, _ = model.generate(button_embed, z, max_steps=MAX_LENGTH, threshold=threshold)
        # Drop SOS token and bring the result back to the CPU before handing
        # it to the feature builders.
        generated_grid = generated_grid[:, 1:, :, :].cpu()
        generated_sample = sample.from_fixed_grid(generated_grid.squeeze(0), steps_per_quarter=4)
        generated_feature = generated_sample.feature

        # Expand the hits into (hit, velocity, offset) triples on a new last axis.
        button_hits = button_hits.cpu()
        button_hvo = torch.cat(
                [
                    button_hits.unsqueeze(-1),
                    torch.ones_like(button_hits).unsqueeze(-1) * 0.8,  # 0.8 : velocity
                    torch.zeros_like(button_hits).unsqueeze(-1),       # 0 : offset
                ], dim=-1) # (1, T, num_buttons, M)
        button_hvo = button_hvo.squeeze(0) # (T, num_buttons, M)
        button_feature = generated_sample.feature.from_button_hvo(button_hvo, steps_per_quarter=4)

    return generated_feature, button_feature

## Random Playback
Ground Truth Drum Sequence -> Button Sequence -> Reconstructed Drum Sequence

In [None]:
# ======= User-defined gap (seconds) ========
gap_sec = 2.0  # <-- set your gap duration here
sample_rate = 44100  # or whatever your playback function uses
# NOTE(review): sample_z is not read anywhere in this cell - inference() below
# always receives grid=None. Confirm whether it was meant to toggle that.
sample_z = False
num_samples = 4

# Length of the silence gap does not change between samples.
gap_samples = int(sample_rate * gap_sec)

# Collect the pieces and join them once at the end instead of re-copying the
# whole growing array on every iteration.
audio_segments = []

for i in range(num_samples):
    # ======= Pick random test sample =======
    random_idx = np.random.randint(0, len(test_dataset))
    sample, grid, button_hvo, desc_label = test_dataset[random_idx]
    # Return value is discarded; presumably kept for a side effect of play() - TODO confirm.
    sample.feature.play()

    # Channel 0 of the last axis holds the hits; add a batch axis for the model.
    if button_hvo is not None:
        button_hits = button_hvo[:, :, 0].unsqueeze(0) # (1, T, num_buttons)
    else:
        button_hits = None

    # grid=None: inference() draws z from the prior of the button embedding.
    generated_feature, button_feature = inference(model, sample, button_hits, None, device="cpu", threshold=0.85)

    # ======= Render audio =======
    button_audio = button_feature.play_button_hvo(button_feature)
    generated_audio = generated_feature.play()

    # ======= Create silence gap =======
    silence_gap = np.zeros((gap_samples, 2), dtype=button_audio.dtype)

    # ======= Queue: buttons, gap, generated =======
    audio_segments.extend([button_audio, silence_gap, generated_audio])

combined_audio = np.concatenate(audio_segments, axis=0)

# ======= Save to file =======
# The file name carries only the index of the LAST sample drawn above.
output_path = os.path.join(audio_save_dir, f"combined_output_{random_idx}.wav")
sf.write(output_path, combined_audio, sample_rate)
print(f"Saved combined audio to {output_path}")

# ======= Play in notebook =======
Audio(output_path)

## Style Transfer

In [None]:
# Bucket test-set indices by style (every style except "unknown"), collecting
# at most `samples_per_style` indices per style.
style_map = {name: [] for name in test_dataset.data_stats.style_map.keys() if name != "unknown"}
samples_per_style = 100
styles_complete = set()   # styles whose bucket has reached samples_per_style
errors = 0

pbar = tqdm(total=len(test_dataset), desc="Collecting Encoded Vectors")
for i in range(len(test_dataset)):
    # Nothing left to collect once every bucket is full.
    if len(styles_complete) == len(style_map):
        break
    try:
        sample, grid, button_hvo, desc_label = test_dataset[i]

        if sample.style != "unknown":
            bucket = style_map[sample.style]
            if len(bucket) < samples_per_style:
                bucket.append(i)
            # Covers both "just filled up" and "was already full".
            if len(bucket) >= samples_per_style:
                styles_complete.add(sample.style)
    except Exception as e:
        print(f"Error encoding grid {i}: {e}")
        errors += 1
    # Exactly one progress tick per visited index, whatever happened above.
    pbar.update(1)

pbar.close()
print(f"Total errors: {errors}")
for key, indices in style_map.items():
    print(f"Style: {key}, #Samples: {len(indices)}")


In [None]:
# ======= Settings ========
sample_rate = 44100  # or whatever your playback function uses
num_samples = 1
num_styles  = 4

combined_audio  = []  # one clip per entry: control clip first, then each generated clip
clip_sample_idx = []  # dataset index every clip was derived from (used in file names)
to_plots        = []  # features to plot, in the same order as combined_audio
plot_is_control = []  # True where the matching to_plots entry is a control sequence

# Only styles that actually collected samples can be drawn from;
# np.random.choice raises on an empty list.
available_styles = [style for style, indices in style_map.items() if len(indices) > 0]
if not available_styles:
    raise ValueError("style_map holds no samples; run the style collection cell first.")

for i in range(num_samples):
    # ======= Pick random test sample =======
    random_idx = np.random.randint(0, len(test_dataset))
    original_sample, _, button_hvo, _ = test_dataset[random_idx]

    # The control hits belong to the picked sample and are shared by all its variations.
    if button_hvo is not None:
        control_hits = button_hvo[:, :, 0].unsqueeze(0) # (1, T, num_buttons)
    else:
        control_hits = None

    for j in range(num_styles):
        if j == 0:
            # Baseline: no style reference, inference() draws z from the prior.
            style_input = None
        else:
            # ======= Pick a random style and a random sample of that style =======
            random_style = np.random.choice(available_styles)
            print(f"Random style: {random_style}")
            random_style_idx = int(np.random.choice(style_map[random_style]))
            print(f"Random style index: {random_style_idx}")
            _, style_grid, _, _ = test_dataset[random_style_idx]
            style_input = style_grid.unsqueeze(0)

        # Wrap the result with the sample picked above (not a `sample` left over
        # from an earlier cell).
        generated_feature, control_feature = inference(model, original_sample, control_hits, style_input, device="cpu", threshold=0.85)

        # ======= Render audio and queue plots (control only once per sample) =======
        if j == 0:
            control_audio = control_feature.play_button_hvo(control_feature)
            combined_audio.append(control_audio)
            clip_sample_idx.append(random_idx)
            to_plots.append(control_feature)
            plot_is_control.append(True)

        generated_audio = generated_feature.play()
        combined_audio.append(generated_audio)
        clip_sample_idx.append(random_idx)
        to_plots.append(generated_feature)
        plot_is_control.append(False)

# ======= Save to file =======
for i, (source_idx, audio) in enumerate(zip(clip_sample_idx, combined_audio)):
    output_path = os.path.join(audio_save_dir, f"style_{i}_output_{source_idx}.wav")
    sf.write(output_path, audio, sample_rate)
    print(f"Saved combined audio to {output_path}")

# ======= Plot =======
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(len(to_plots), 1, figsize=(15, 15), squeeze=False)
for i, (feature, is_control) in enumerate(zip(to_plots, plot_is_control)):
    if is_control:
        DrumMIDIFeature._grid_plot(
                feature.to_button_hvo(steps_per_quarter=4, num_buttons=2),
                ax=axes[i, 0],
                title="Control Sequence",
                xlabel="Time Step",
                ylabel="Control Class"
        )
    else:
        feature.fixed_grid_plot(ax=axes[i, 0])

fig.tight_layout()

## Progressive Masking

### Setup

In [None]:
import random
import torch

def mask_beats_structured(
    button_hvo_base: torch.Tensor,
    steps_per_beat: int = 4,
    keep_beats: list = [0, 2],
    total_beats: int = 4
) -> torch.Tensor:
    """
    Zero every time step that does not fall inside one of the beats in `keep_beats`.

    The time axis is cut into bars of `steps_per_beat * total_beats` steps.
    In each bar the steps of the listed beats are kept; a trailing partial
    bar is treated the same way, truncated at the end of the sequence.

    Args:
        button_hvo_base (Tensor): shape (B, T, num_buttons, M)
        steps_per_beat (int): time steps per beat (default 4 = 16th-note grid at 4/4 time)
        keep_beats (list): beat indices to retain (e.g., [0, 2] = keep beat 1 and 3)
        total_beats (int): beats per bar (default 4 for 4/4)

    Returns:
        Tensor: `button_hvo_base` multiplied by the 0/1 step mask, same shape
    """
    _, T, _, _ = button_hvo_base.shape
    bar_len = steps_per_beat * total_beats
    full_bars, remainder = divmod(T, bar_len)
    keep = torch.zeros((T,), device=button_hvo_base.device)

    # Complete bars: every listed beat contributes one window of steps.
    windows = [
        bar * bar_len + beat * steps_per_beat
        for bar in range(full_bars)
        for beat in keep_beats
    ]
    for lo in windows:
        hi = lo + steps_per_beat
        if hi <= T:
            keep[lo:hi] = 1.0

    # Trailing partial bar: same windows, cut off at the end of the sequence.
    if remainder > 0:
        tail_start = full_bars * bar_len
        for beat in keep_beats:
            lo = tail_start + beat * steps_per_beat
            if lo < T:
                keep[lo:min(lo + steps_per_beat, T)] = 1.0

    # (1, T, 1, 1) broadcasts over batch, buttons and the last axis.
    return button_hvo_base * keep.view(1, T, 1, 1)


def shift_button_sequence(button_hvo_base, shift_range=(-1, 1)):
    """
    Move every time step of the sequence by its own random offset.

    Each source step draws one integer offset from `shift_range` and its
    content is added onto the target step, wrapping around at the sequence
    boundaries. Steps that land on the same target accumulate; the result
    is clamped to [0, 1].

    Args:
        button_hvo_base: (1, T, B, M) tensor
        shift_range: (min_shift, max_shift), inclusive
    """
    _, T, _, _ = button_hvo_base.shape
    shifted = torch.zeros_like(button_hvo_base)

    # One draw per source step, taken in step order.
    offsets = [random.randint(*shift_range) for _ in range(T)]

    for src, offset in enumerate(offsets):
        dst = (src + offset) % T  # circular wrap
        shifted[:, dst] += button_hvo_base[:, src]

    # Overlapping steps may have summed above 1.
    return shifted.clamp(0, 1.0)


def mask_random_within_regions(button_hvo_base, steps_per_beat=4, beats_to_keep=[0, 2], total_beats=4, retain_ratio=0.5):
    """
    Keeps only the given beats and randomly drops time steps inside them.

    Steps outside `beats_to_keep` are always zeroed. Each step inside a kept
    beat survives independently with probability `retain_ratio`; a dropped
    step is zeroed for every button and every channel of the last axis.
    A trailing partial bar is handled like a full one.

    Args:
        button_hvo_base: (1, T, B, M) tensor
        steps_per_beat: how many steps per beat (usually 4 for 16th grid)
        beats_to_keep: list of beat indices (within a bar) to retain
        total_beats: total number of beats in the bar (default 4 for 4/4)
        retain_ratio: probability of keeping a step inside a kept beat
                      (1.0 keeps the kept beats intact, 0.0 drops everything)

    Returns:
        Tensor: masked copy of `button_hvo_base`, same shape
    """
    _, T, _, _ = button_hvo_base.shape
    device = button_hvo_base.device
    steps_per_bar = steps_per_beat * total_beats
    kept_beats = set(beats_to_keep)

    # True for steps whose beat index within its bar is one of the kept beats.
    in_kept_beat = torch.tensor(
        [(t % steps_per_bar) // steps_per_beat in kept_beats for t in range(T)],
        dtype=torch.bool,
        device=device,
    )
    # Independent keep/drop draw per step; torch.rand is in [0, 1), so a
    # ratio of 1.0 always keeps and 0.0 never does.
    survives = torch.rand(T, device=device) < retain_ratio

    mask = torch.zeros((T,), device=device)
    mask[in_kept_beat & survives] = 1.0

    mask = mask.view(1, T, 1, 1)  # broadcast to match button_hvo shape
    return button_hvo_base * mask

In [None]:
model.eval()

# ======= Pick random result =======
random_idx = np.random.randint(0, len(test_dataset))
sample_rate = 44100
num_expts = 4  # how many entries of mask_progression to run, starting from the top

# ======= Masking =======
total_beats = 4
# NOTE(review): this was 5. The grids in this cell are rendered with
# steps_per_quarter=4 and the beat labels below only hold when one beat spans
# 4 steps, so 5 shifted every mask window off the beat. Confirm.
steps_per_beat = 4
mask_progression = [
    [0, 1, 2, 3], # No mask
    [0, 2],       # Strong beat only
    [1, 3],       # Off beat only
    [0, 1],       # Downbeat and upbeat
    [0],          # Downbeat only
]  # Beat indices to keep, one list per experiment
threshold = 0.85

# The same sample is reused for every experiment; only the mask changes.
data = test_dataset[random_idx]
sample, grid, button_hvo, desc_label = data

button_audios = []
generated_audios = []
button_plots = []
generated_plots = []

# Pure inference: no autograd graph needed.
with torch.no_grad():
    # Slicing (instead of indexing by range(num_expts)) cannot run past the list.
    for keep_beats in mask_progression[:num_expts]:
        button_hvo_mask = mask_beats_structured(button_hvo.unsqueeze(0), steps_per_beat, keep_beats, total_beats)

        button_hits = button_hvo_mask[:, :, :, 0] # (B, T, num_buttons)
        button_embed = model.make_button_embed(button_hits)
        z = model.sample_z_from_button_embed(button_embed)
        generated_grids, hit_probs = model.generate(button_embed, z, max_steps=MAX_LENGTH, threshold=threshold)
        generated_grids = generated_grids[:, 1:, :, :] # Drop SOS token
        generated_grids = generated_grids.squeeze(0)

        generated_sample = sample.from_fixed_grid(generated_grids, steps_per_quarter=4)
        generated_feature = generated_sample.feature

        # Expand the masked hits into (hit, velocity, offset) triples on a new last axis.
        button_hvo_tmp = torch.cat(
                            [
                                button_hits.unsqueeze(-1),
                                torch.ones_like(button_hits).unsqueeze(-1) * 0.8,  # 0.8 : velocity
                                torch.zeros_like(button_hits).unsqueeze(-1),       # 0 : offset
                            ], dim=-1) # (1, T, num_buttons, M)
        button_hvo_tmp = button_hvo_tmp.squeeze(0) # (T, num_buttons, M)
        button_feature = generated_sample.feature.from_button_hvo(button_hvo_tmp, steps_per_quarter=4)

        button_plots.append(button_feature)
        generated_plots.append(generated_feature)

        button_audios.append(button_feature.play_button_hvo(button_feature))
        generated_audios.append(generated_feature.play())

# ======= Save audio: one button clip and one generated clip per mask =======
for i, (button_audio, generated_audio) in enumerate(zip(button_audios, generated_audios)):
    button_output_path = os.path.join(audio_save_dir, f"mask_{i}_button_output_{random_idx}.wav")
    generated_output_path = os.path.join(audio_save_dir, f"mask_{i}_generated_output_{random_idx}.wav")

    sf.write(button_output_path, button_audio, sample_rate)
    sf.write(generated_output_path, generated_audio, sample_rate)

# ======= Plot: control sequence (left) next to generated grid (right) =======
# squeeze=False keeps `axes` two-dimensional even with a single experiment.
fig, axes = plt.subplots(len(button_plots), 2, figsize=(15, 15), squeeze=False)

for i in range(len(button_plots)):
    DrumMIDIFeature._grid_plot(
        button_plots[i].to_button_hvo(steps_per_quarter=4, num_buttons=2),
        ax=axes[i, 0],
        title="Control Sequence",
        xlabel="Time Step",
        ylabel="Control Class"
    )
    generated_plots[i].fixed_grid_plot(ax=axes[i, 1])

fig.tight_layout()