In [1]:
# Setup
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from dataset import DrumPatternDataset
from hierarchical_vae import HierarchicalDrumVAE
from training_utils import kl_annealing_schedule, temperature_annealing_schedule
from train import compute_hierarchical_elbo
from visualize import plot_drum_pattern
from analyze_latent import (
    visualize_latent_hierarchy, interpolate_styles,
    measure_disentanglement, controllable_generation
)

# Output locations. Both folders are created up front so that later cells can
# write figures into them without checking for existence first.
results_dir = Path("results")
generated_dir = results_dir.joinpath("generated_patterns")
latent_dir = results_dir.joinpath("latent_analysis")

for out_dir in (generated_dir, latent_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Training
#
# Training is the expensive step of this notebook, so it is skipped when a
# checkpoint is already on disk; this keeps "Restart Kernel -> Run All" cheap.
# Set FORCE_RETRAIN = True to train from scratch regardless.
# NOTE(review): assumes train.main() writes results/best_model.pth (the file the
# next cell loads) — confirm against train.py.

from train import main as train_main

FORCE_RETRAIN = False
checkpoint_path = Path("results/best_model.pth")

if FORCE_RETRAIN or not checkpoint_path.exists():
    train_main()
else:
    print(f"Found {checkpoint_path}; skipping training (set FORCE_RETRAIN = True to retrain)")

  from .autonotebook import tqdm as notebook_tqdm


Epoch 0, Loss=99.7768, Recon=99.7768, KL_low=676.1357, KL_high=1386.4526, Beta=0.000, Temp=2.00
Epoch 0 Validation - Loss: 36094.6738 KL_high: 34735.4806 KL_low: 1259.4785 Validity: 0.968 Diversity: 0.085
Epoch 10, Loss=98.9621, Recon=98.9553, KL_low=0.0206, KL_high=0.0131, Beta=0.200, Temp=1.85
Epoch 10 Validation - Loss: 98.9123 KL_high: 0.0171 KL_low: 0.0237 Validity: 0.881 Diversity: 0.294
Epoch 20, Loss=99.0737, Recon=98.9184, KL_low=0.3805, KL_high=0.0080, Beta=0.400, Temp=1.70
Epoch 20 Validation - Loss: 98.9863 KL_high: 0.0089 KL_low: 0.4154 Validity: 0.936 Diversity: 0.273
Epoch 30, Loss=98.9551, Recon=98.7375, KL_low=0.3578, KL_high=0.0048, Beta=0.600, Temp=1.55
Epoch 30 Validation - Loss: 98.8856 KL_high: 0.0052 KL_low: 0.3241 Validity: 0.976 Diversity: 0.172
Epoch 40, Loss=99.2280, Recon=99.0483, KL_low=0.2212, KL_high=0.0035, Beta=0.800, Temp=1.40
Epoch 40 Validation - Loss: 98.9007 KL_high: 0.0035 KL_low: 0.2190 Validity: 0.963 Diversity: 0.205
Epoch 50, Loss=98.5480, Rec

In [None]:
# Generate and Save 10 Samples Per Style
#
# NOTE: the patterns saved here are reconstructions of validation patterns
# (a forward pass through the trained model), bucketed by the style label of
# the input pattern. Nothing is sampled from the prior in this cell.

from pathlib import Path
import torch
from torch.utils.data import DataLoader

# load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HierarchicalDrumVAE(
    z_high_dim=4,
    z_low_dim=12
).to(device)
model.load_state_dict(torch.load("results/best_model.pth", map_location=device))
model.eval()

gen_dir = Path("results/generated_patterns")
gen_dir.mkdir(parents=True, exist_ok=True)

# Style label i is rendered as style_names[i].
# NOTE(review): assumes the dataset's integer labels follow this order — confirm
# against dataset.py.
style_names = ["Rock", "Jazz", "Hip-hop", "Electronic", "Latin"]

# Collection targets; the bucket keys are derived from style_names so the two
# can never disagree.
n_samples_per_style = 10
all_generated = {s: [] for s in range(len(style_names))}

# Val Loader
data_dir = "../data/drums"
val_dataset = DrumPatternDataset(data_dir, split="val")
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

with torch.no_grad():
    for patterns, styles, densities in val_loader:
        patterns = patterns.to(device)
        out = model(patterns)
        recon = out['recon']
        # out['recon'] is treated as logits: squash, then binarise at 0.5
        sampled = (torch.sigmoid(recon) > 0.5).float().cpu()

        for p, s in zip(sampled, styles):
            s = int(s.item())
            if len(all_generated[s]) < n_samples_per_style:
                all_generated[s].append(p)

        # Stop reading the validation set once every bucket is full
        if all(len(v) >= n_samples_per_style for v in all_generated.values()):
            break

# Save results
for s, patterns in all_generated.items():
    for idx, p in enumerate(patterns):
        fig = plot_drum_pattern(p, title=f"Style {style_names[s]} - Sample {idx}")
        fig.savefig(gen_dir / f"style{s}_{style_names[s]}_sample{idx}.png")
        plt.close(fig)

# Report what was actually written rather than what was requested: the loop
# above can run out of validation data before a bucket is full.
under_filled = {
    style_names[s]: len(v)
    for s, v in all_generated.items()
    if len(v) < n_samples_per_style
}
if under_filled:
    n_saved = sum(len(v) for v in all_generated.values())
    print(
        f"Saved {n_saved} samples to {gen_dir}; "
        f"styles with fewer than {n_samples_per_style}: {under_filled}"
    )
else:
    print(f"Saved {n_samples_per_style} samples per style to {gen_dir}")


Saved 10 samples per style to results/generated_patterns


In [None]:
# Latent Space Interpolation
#
# Encodes two validation patterns, blends their low- and high-level latent
# codes in lock-step, and saves the decoded pattern at every step.

import numpy as np

interp_dir = Path("results/generated_patterns")
interp_dir.mkdir(parents=True, exist_ok=True)

model.eval()

# Endpoints are the first two items of the validation split (fixed indices,
# not a random draw).
patterns_a, style_a, _ = val_dataset[0]
patterns_b, style_b, _ = val_dataset[1]

# Add a leading batch axis before handing the patterns to the model
patterns_a = patterns_a.unsqueeze(0).to(device)
patterns_b = patterns_b.unsqueeze(0).to(device)

n_steps = 10

with torch.no_grad():
    z_low_a, _, _, z_high_a, _, _ = model.encode_hierarchy(patterns_a)
    z_low_b, _, _, z_high_b, _, _ = model.encode_hierarchy(patterns_b)

    # t runs from 0 (pure A) to 1 (pure B), both endpoints included
    for idx, t in enumerate(np.linspace(0, 1, n_steps)):
        zl = z_low_a * (1 - t) + z_low_b * t
        zh = z_high_a * (1 - t) + z_high_b * t

        # NOTE(review): arguments are passed as (z_low, z_high) here, while the
        # humanization cell calls decode_hierarchy(z_high, z_low) — confirm the
        # parameter order in hierarchical_vae.py.
        recon = model.decode_hierarchy(zl, zh)
        sampled = (torch.sigmoid(recon) > 0.5).float().cpu()[0]

        fig = plot_drum_pattern(sampled, title=f"Interpolation Step {idx}")
        fig.savefig(interp_dir / f"interpolation_step{idx}.png")
        plt.close(fig)

print(f"Saved interpolation sequence to {interp_dir}")

Saved interpolation sequence to results/generated_patterns


In [None]:
# Style Transfer Examples
#
# For every (source, target) style pair: keep the source pattern's low-level
# code, swap in the target pattern's high-level code, decode, and save.

transfer_dir = Path("results/generated_patterns")
transfer_dir.mkdir(parents=True, exist_ok=True)

model.eval()
style_names = ["Rock", "Jazz", "Hip-hop", "Electronic", "Latin"]

n_styles = len(style_names)

# Pick one exemplar per style by its dataset label. Positional indexing
# (val_dataset[i] for style i) would only be right if the validation split
# happened to start with one pattern of each style, in label order.
exemplars = {}
for i in range(len(val_dataset)):
    pattern, style, _ = val_dataset[i]
    style_id = int(style)
    if 0 <= style_id < n_styles and style_id not in exemplars:
        exemplars[style_id] = pattern.unsqueeze(0).to(device)
        if len(exemplars) == n_styles:
            break

missing = [style_names[s] for s in range(n_styles) if s not in exemplars]
if missing:
    print(f"Warning: no validation pattern found for styles: {missing}")

with torch.no_grad():
    # Encode each exemplar once; every pair below reuses these codes, so a given
    # target style contributes the same z_high to all of its transfers.
    # encode_hierarchy returns
    # (z_low, mu_low, logvar_low, z_high, mu_high, logvar_high).
    encoded = {s: model.encode_hierarchy(x) for s, x in exemplars.items()}

    for src_idx in sorted(encoded):
        z_low = encoded[src_idx][0]

        for tgt_idx in sorted(encoded):
            z_high_tgt = encoded[tgt_idx][3]

            # NOTE(review): arguments are passed as (z_low, z_high) here, while
            # the humanization cell calls decode_hierarchy(z_high, z_low) —
            # confirm the parameter order in hierarchical_vae.py.
            recon = model.decode_hierarchy(z_low, z_high_tgt)
            sampled = (torch.sigmoid(recon) > 0.5).float().cpu()[0]

            fig = plot_drum_pattern(
                sampled,
                title=f"{style_names[src_idx]} → {style_names[tgt_idx]}"
            )
            fig.savefig(
                transfer_dir / f"{style_names[src_idx]}_to_{style_names[tgt_idx]}.png"
            )
            plt.close(fig)

print(f"Saved style transfer results to {transfer_dir}")


Saved style transfer results to results/generated_patterns


In [None]:
# Latent Space Visualization with t-SNE
#
# Collects the high-level latent code of every validation pattern and hands the
# stacked codes, with their style labels, to plot_latent_space_2d.
# NOTE(review): the 2-D projection method is whatever plot_latent_space_2d
# implements; the "t-SNE" in the heading and file name is not verifiable from
# this notebook.

from pathlib import Path
# Import only the name this cell uses; a wildcard import would silently rebind
# any notebook global that visualize.py also happens to define.
from visualize import plot_latent_space_2d

latent_dir = Path("results/latent_analysis")
latent_dir.mkdir(parents=True, exist_ok=True)

latent_batches = []
label_batches = []

model.eval()

with torch.no_grad():
    for patterns, styles, _ in val_loader:
        patterns = patterns.to(device)
        z_low, mu_low, logvar_low, z_high, mu_high, logvar_high = model.encode_hierarchy(patterns)

        latent_batches.append(z_high.cpu())
        label_batches.append(styles)

# Stack the per-batch tensors into single arrays for plotting
all_latents = torch.cat(latent_batches, dim=0).numpy()
all_labels = torch.cat(label_batches, dim=0).numpy()

fig = plot_latent_space_2d(all_latents, labels=all_labels, title="Latent Space (z_high)")
fig.savefig(latent_dir / "latent_space_tsne.png")
plt.close(fig)

print(f"Saved latent space visualization to {latent_dir}")

Saved latent space visualization to results/latent_analysis


In [None]:
# Latent Analysis: t-SNE
#
# Delegates to analyze_latent.visualize_latent_hierarchy, which receives the
# trained model and the validation loader and returns the high-level codes,
# low-level codes and labels it worked with.
# NOTE(review): latent_dir is not passed to the helper, so where it writes its
# plots is decided inside analyze_latent.py — confirm it matches the message
# printed below.

from pathlib import Path
from analyze_latent import visualize_latent_hierarchy, plot_kl_trends

latent_dir = Path("results") / "latent_analysis"
latent_dir.mkdir(parents=True, exist_ok=True)

hierarchy_outputs = visualize_latent_hierarchy(model, val_loader, device=device)
z_high, z_low, labels = hierarchy_outputs

print(f"Saved latent t-SNE plots to {latent_dir}")

Saved latent t-SNE plots to results/latent_analysis


In [None]:
# Humanization (Variation Injection)
#
# For each style: encode a real validation pattern of that style, hold its
# high-level code z_high fixed, and decode it several times with fresh z_low
# draws from a standard normal. Anchoring z_high to a labelled pattern is what
# makes the style name in each title and file name meaningful; a z_high drawn
# from the prior has no relation to the style it is saved under.

import torch
from pathlib import Path
import matplotlib.pyplot as plt

# Reload trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HierarchicalDrumVAE(
    z_high_dim=4,
    z_low_dim=12
).to(device)
model.load_state_dict(torch.load("results/best_model.pth", map_location=device))
model.eval()

humanize_dir = Path("results/generated_patterns")
humanize_dir.mkdir(parents=True, exist_ok=True)

# Style
style_names = ["Rock", "Jazz", "Hip-hop", "Electronic", "Latin"]

n_variations = 5

# One reference pattern per style label, taken from the validation split. The
# dataset is opened here so this cell does not depend on the generation cell
# having been run.
humanize_dataset = DrumPatternDataset("../data/drums", split="val")
style_refs = {}
for i in range(len(humanize_dataset)):
    ref_pattern, ref_style, _ = humanize_dataset[i]
    style_id = int(ref_style)
    if 0 <= style_id < len(style_names) and style_id not in style_refs:
        style_refs[style_id] = ref_pattern.unsqueeze(0).to(device)
        if len(style_refs) == len(style_names):
            break

with torch.no_grad():
    for style_idx, style_name in enumerate(style_names):
        if style_idx not in style_refs:
            print(f"Warning: no validation pattern found for style {style_name}; skipping")
            continue

        # fix z_high: the high-level code of this style's reference pattern
        # (4th element of encode_hierarchy's return tuple)
        _, _, _, z_high, _, _ = model.encode_hierarchy(style_refs[style_idx])

        for v in range(n_variations):
            # inject random variation into z_low
            z_low = torch.randn(1, model.z_low_dim, device=device)
            # NOTE(review): arguments are passed as (z_high, z_low) here, while
            # the interpolation and style-transfer cells call
            # decode_hierarchy(z_low, z_high) — confirm the parameter order in
            # hierarchical_vae.py.
            logits = model.decode_hierarchy(z_high, z_low)
            # Cast to float so the plotted tensor has the same dtype as in the
            # other cells of this notebook
            pattern = (torch.sigmoid(logits) > 0.5).float().cpu().squeeze(0)

            # Save visualization
            fig = plot_drum_pattern(pattern, title=f"{style_name} - Humanized {v}")
            fig.savefig(humanize_dir / f"{style_name}_variation{v}.png")
            plt.close(fig)

print(f"Saved humanization results to {humanize_dir}")

Saved humanization results to results/generated_patterns
