In [None]:
# Best model location: checkpoints/quantum_8qubit_depth3_high_reg_best.pth

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import json
from pathlib import Path
import IPython.display as ipd
from ipywidgets import interact, IntSlider, Button, Output, VBox, HBox, Label
from torch.utils.data import Dataset, DataLoader

# Import your scripts
from lstmabar_model import LSTMABAR
from macro_archetype_predictor import MacroArchetypePredictor
from parameter_decoder import ParameterDecoder
from harmonic_ddsp_engine import HarmonicDDSPEngine

# Pick the compute device once; later cells read this module-level `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [2]:
# --- CONFIGURATION ---
# Path of the pretrained checkpoint handed to load_pretrained_model() in this cell.
CHECKPOINT_PATH = 'checkpoints/quantum_8qubit_depth3_high_reg_best.pth'

def load_pretrained_model(checkpoint_path):
    """Build an LSTMABAR model, load checkpoint weights into it, and freeze the backbone.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that contains
            a ``'config'`` mapping and a ``'model_state_dict'`` entry.

    Returns:
        A ``(model, config)`` tuple: the constructed model and the checkpoint's
        ``'config'`` entry.

    Notes:
        Reads the notebook-level ``device``. Weights are loaded with
        ``strict=False`` because the generator parts are expected to be absent
        from the checkpoint; any *backbone* key that is missing, and any
        checkpoint key that matched nothing, is reported instead of being
        silently ignored.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint['config']

    # Initialize the Full Model
    model = LSTMABAR(
        embedding_dim=config['embedding_dim'],
        text_model='sentence-transformers/all-MiniLM-L6-v2',
        audio_architecture=config['audio_architecture'],
        sample_rate=44100,
        use_quantum_attention=config['use_quantum_attention'],
        n_qubits=config['n_qubits'],
        device=device
    )

    # Load Weights
    # Note: The checkpoint only has weights for Text/Audio/Aligner.
    # The Predictor/Decoder/Engine will remain random (initialized fresh).
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    # strict=False hides every mismatch, so surface the ones that matter:
    # backbone weights that were NOT restored, and checkpoint entries that
    # did not correspond to any model parameter.
    backbone_prefixes = ('text_encoder.', 'audio_encoder.', 'contrastive_aligner.')
    missing_keys = list(getattr(load_result, 'missing_keys', []))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
    missing_backbone = [key for key in missing_keys if key.startswith(backbone_prefixes)]
    if missing_backbone:
        print(
            f"WARNING: {len(missing_backbone)} backbone weight(s) were not found in the "
            f"checkpoint (e.g. '{missing_backbone[0]}'); they keep the values they were "
            f"constructed with."
        )
    if unexpected_keys:
        print(
            f"WARNING: {len(unexpected_keys)} checkpoint entr(y/ies) matched no model "
            f"parameter (e.g. '{unexpected_keys[0]}') and were ignored."
        )

    # Freeze the Backbone (We only want to train the Generator now)
    for backbone_part in (model.text_encoder, model.audio_encoder, model.contrastive_aligner):
        for param in backbone_part.parameters():
            param.requires_grad = False

    print("âœ“ Backbone loaded and frozen. Generator ready for training.")
    return model, config

# Build the model once for the rest of the notebook; `config` is the checkpoint's stored config.
model, config = load_pretrained_model(CHECKPOINT_PATH)

Loading checkpoint: checkpoints/quantum_8qubit_depth3_high_reg_best.pth
Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
ðŸ§  Initializing Macro Archetype Predictor...
Translation Layer: Parameter Decoder...
ðŸŽ¹ Initializing Harmonic DDSP Engine...
âœ“ Backbone loaded and frozen. Generator ready for training.


In [3]:
# --- GENERATE WARM-START METADATA ---
import soundfile as sf

# Sample rate (Hz) used when writing the placeholder WAV files.
SR = 44100

# Folder that receives the placeholder WAV files; created on first run.
DATASET_DIR = Path("synthetic_dynamic_dataset")
DATASET_DIR.mkdir(exist_ok=True)

# 1. Physics Rules (4 Timbre + 5 Dynamics)
# Oscillator mix per archetype, ordered [sine, square, saw, triangle]
# (the order in which the weights are written to the w_sine / w_sq / w_saw /
# w_tri columns further down in this cell).
ARCHETYPE_MAP = dict(
    pure=[1.0, 0.0, 0.0, 0.0],
    bright=[0.0, 0.0, 0.8, 0.2],
    harsh=[0.0, 0.5, 0.5, 0.0],
    warm=[0.6, 0.0, 0.0, 0.4],
    hollow=[0.0, 0.9, 0.0, 0.1],
)

def get_dynamics(text):
    """Return the five dynamics targets ``[A, D, S, R, Cutoff]`` for a prompt.

    The prompt is matched by substring: "pluck" is checked first, then "pad";
    a prompt containing neither gets the default values. A fresh list is
    returned on every call.
    """
    keyword_presets = (
        ("pluck", [0.01, 0.2, 0.0, 0.2, 1.0]),
        ("pad", [0.5, 0.0, 1.0, 0.5, 0.6]),
    )
    for keyword, preset in keyword_presets:
        if keyword in text:
            return preset
    # A, D, S, R, Cutoff
    return [0.1, 0.1, 1.0, 0.1, 1.0]

# Build one CSV record (plus one placeholder WAV) for every archetype combined
# with the "tone" / "pad" / "pluck" prompt variants.
data_records = []
for archetype, weights in ARCHETYPE_MAP.items():
    prompts = [f"{archetype} tone", f"{archetype} pad", f"{archetype} pluck"]
    for text in prompts:
        dyn = get_dynamics(text)

        # Dummy audio file (needed for AudioEncoder input)
        # One second of silence: for the warm start the mapping is
        # Text -> Params, so the audio content is not what is being learned.
        fname = f"dyn_{archetype}_{text.split()[-1]}.wav"
        path = DATASET_DIR / fname
        sf.write(path, np.zeros(SR), SR)

        # Assemble the row in column order: file, prompt, 4 oscillator
        # weights, then the 5 dynamics values.
        record = {"filepath": str(path), "text": text}
        record.update(zip(("w_sine", "w_sq", "w_saw", "w_tri"), weights))
        record.update(zip(("a", "d", "s", "r", "cutoff"), dyn))
        data_records.append(record)

df = pd.DataFrame(data_records)
df.to_csv("synthetic_dynamic_dataset.csv", index=False)
print(f"âœ“ Generated {len(df)} synthetic samples.")

âœ“ Generated 15 synthetic samples.


In [4]:
# --- SMART WARM START ---
# Maps simple CSV weights -> Complex DDSP Parameters

class SyntheticDDSPDataset(Dataset):
    """Dataset that turns rows of the synthetic CSV into DDSP training targets.

    Every item is the tuple
    ``(text_emb, audio_emb, target_harmonics, target_noise, target_adsr)``:

    * ``text_emb`` - output of ``backbone_model.text_encoder([text])``, moved to
      CPU and squeezed.
    * ``audio_emb`` - zeros of length ``backbone_model.embedding_dim``.
    * ``target_harmonics`` - weighted blend of the sine / square / saw /
      triangle harmonic series, divided by its sum (plus 1e-5), float32.
    * ``target_noise`` - 32 values of 0.01, multiplied by 5 when the row's
      ``cutoff`` exceeds 0.8.
    * ``target_adsr`` - the row's ``a``, ``d``, ``s``, ``r`` columns, float32.
    """

    def __init__(self, csv_path, backbone_model, n_harmonics=64, device='cpu'):
        """Read the CSV and pre-compute the harmonic series of the basic waves.

        Args:
            csv_path: CSV with ``text``, ``w_sine``, ``w_sq``, ``w_saw``,
                ``w_tri``, ``a``, ``d``, ``s``, ``r`` and ``cutoff`` columns.
            backbone_model: Object exposing ``embedding_dim`` and a callable
                ``text_encoder``.
            n_harmonics: Length of every harmonic series / harmonic target.
            device: Stored on the instance; not used by this class itself.
        """
        self.df = pd.read_csv(csv_path)
        self.backbone = backbone_model
        self.n_harmonics = n_harmonics
        self.device = device

        # Harmonic numbers 1..n and a mask selecting the odd ones.
        harmonic_idx = torch.arange(1, n_harmonics + 1).float()
        odd_only = (harmonic_idx % 2 != 0).float()

        self.saw_series = 1.0 / harmonic_idx
        self.square_series = (1.0 / harmonic_idx) * odd_only  # odd harmonics only
        self.tri_series = (1.0 / (harmonic_idx**2)) * odd_only  # odd, steeper dropoff
        self.sine_series = torch.zeros(n_harmonics)
        self.sine_series[0] = 1.0  # fundamental only

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the embeddings and targets for row ``idx`` (see class docstring)."""
        record = self.df.iloc[idx]

        # --- Inputs -------------------------------------------------------
        # Dummy audio embedding (zeros)
        audio_emb = torch.zeros(self.backbone.embedding_dim)

        # The text encoder is invoked through its call interface on a
        # one-element batch; no gradients are tracked.
        with torch.no_grad():
            text_emb = self.backbone.text_encoder([record['text']])
            if isinstance(text_emb, tuple):
                text_emb = text_emb[0]
            if not isinstance(text_emb, torch.Tensor):
                text_emb = torch.tensor(text_emb)
            text_emb = text_emb.cpu().squeeze()

        # --- Harmonic target ----------------------------------------------
        # Weights are cast to plain Python floats before scaling the series;
        # the terms are accumulated in sine, square, saw, triangle order.
        weighted_series = (
            ('w_sine', self.sine_series),
            ('w_sq', self.square_series),
            ('w_saw', self.saw_series),
            ('w_tri', self.tri_series),
        )
        target_harmonics = None
        for column, series in weighted_series:
            term = float(record[column]) * series
            target_harmonics = term if target_harmonics is None else target_harmonics + term
        # Normalise so the harmonics sum to (almost) one, then force float32.
        target_harmonics = (target_harmonics / (target_harmonics.sum() + 1e-5)).float()

        # --- ADSR target (float32) ----------------------------------------
        target_adsr = torch.tensor([record[column] for column in ('a', 'd', 's', 'r')]).float()

        # --- Noise target (float32) ---------------------------------------
        target_noise = torch.ones(32) * 0.01
        if record['cutoff'] > 0.8:
            target_noise = target_noise * 5.0
        target_noise = target_noise.float()

        return text_emb, audio_emb, target_harmonics, target_noise, target_adsr

# --- TRAINING LOOP ---
def run_smart_warm_start(model, epochs=50):
    """Fit the predictor and decoder to the synthetic targets with an MSE loss.

    Reads "synthetic_dynamic_dataset.csv" through ``SyntheticDDSPDataset``,
    trains ``model.predictor`` and ``model.decoder`` with Adam (lr 1e-3,
    batches of 8, shuffled) and prints the mean batch loss every 10th epoch.
    Tensors are moved to the notebook-level ``device``.

    Args:
        model: Object exposing ``predictor`` and ``decoder`` modules; it is
            also handed to the dataset as the backbone.
        epochs: Number of passes over the dataset.
    """
    print("ðŸ”¥ Warm-starting Generator (Predictor + Decoder)...")
    loader = DataLoader(
        SyntheticDDSPDataset("synthetic_dynamic_dataset.csv", model, device='cpu'),
        batch_size=8,
        shuffle=True,
    )

    # Only the generator side (Predictor + Decoder) is optimised.
    param_groups = [
        {'params': model.predictor.parameters()},
        {'params': model.decoder.parameters()},
    ]
    optimizer = optim.Adam(param_groups, lr=1e-3)
    criterion = nn.MSELoss()

    model.predictor.train()
    model.decoder.train()

    for epoch in range(epochs):
        total_loss = 0
        for batch in loader:
            text_emb, audio_emb, t_harm, t_noise, t_adsr = (
                tensor.to(device) for tensor in batch
            )

            # Collapse a singleton middle dimension if the embeddings arrive 3-D.
            if text_emb.dim() == 3:
                text_emb = text_emb.squeeze(1)
            if audio_emb.dim() == 3:
                audio_emb = audio_emb.squeeze(1)

            optimizer.zero_grad()

            # Forward: Predict -> Decode (the decoder's gain output is not used here)
            macros = model.predictor(text_emb, audio_emb)
            p_harm, p_noise, p_adsr, _gain = model.decoder(macros['combined'])

            # Sum of the three per-target MSE terms.
            loss = (
                criterion(p_harm, t_harm)
                + criterion(p_noise, t_noise)
                + criterion(p_adsr, t_adsr)
            )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss {total_loss/len(loader):.4f}")

    print("âœ“ Warm start complete. Generator knows physics!")

# Run the warm start with the default number of epochs (50) on the loaded model.
run_smart_warm_start(model)

ðŸ”¥ Warm-starting Generator (Predictor + Decoder)...
Epoch 10: Loss 0.0905
Epoch 20: Loss 0.0850
Epoch 30: Loss 0.0408
Epoch 40: Loss 0.0147
Epoch 50: Loss 0.0082
âœ“ Warm start complete. Generator knows physics!


In [7]:
class RLHFTrainerMacro:
    """Stores rated samples and updates the predictor with a reward-weighted loss.

    Ratings on a 1-5 scale are mapped to rewards in [-1.0, 1.0]. ``update``
    draws a random batch from the stored samples, re-runs ``model.predictor``
    on each and minimises ``-reward * mean(latents ** 2)``, where ``latents``
    is the predictor's ``'combined'`` output.
    """

    def __init__(self, model, learning_rate=1e-4):
        """Create the optimizer over predictor and decoder parameters.

        Args:
            model: Object exposing ``predictor`` and ``decoder`` modules.
            learning_rate: Adam learning rate.
        """
        self.model = model
        self.optimizer = optim.Adam([
            {'params': model.predictor.parameters()},
            {'params': model.decoder.parameters()}
        ], lr=learning_rate)
        # Use a list as the buffer
        self.experience_buffer = []

    def _ensure_tensor(self, x):
        """Safely extract tensor from tuple/list if needed"""
        if isinstance(x, (tuple, list)):
            return x[0]
        return x

    def _predictor_device(self):
        """Return the device of the predictor's first parameter, or None if it has none."""
        first_param = next(iter(self.model.predictor.parameters()), None)
        return None if first_param is None else first_param.device

    def add_feedback(self, text_emb, audio_emb, rating):
        """Store one rated sample.

        Args:
            text_emb: Text embedding tensor, or a tuple/list whose first item is one.
            audio_emb: Audio embedding tensor, or a tuple/list whose first item is one.
            rating: Rating on the 1-5 scale; 3 is neutral.

        Raises:
            ValueError: If either embedding is ``None`` (nothing was generated
                yet), instead of failing later with an ``AttributeError``.
        """
        # 1. Unpack tuples if necessary
        text_emb = self._ensure_tensor(text_emb)
        audio_emb = self._ensure_tensor(audio_emb)
        if text_emb is None or audio_emb is None:
            raise ValueError(
                "No embeddings to rate: generate a sample before submitting feedback."
            )

        # 2. Convert 1-5 rating to reward (-1.0 to 1.0)
        reward = (rating - 3) / 2.0

        # 3. Store DETACHED tensors so the buffer holds no autograd graph
        self.experience_buffer.append({
            'text': text_emb.detach(),
            'audio': audio_emb.detach(),
            'reward': reward
        })

    def update(self, batch_size=4):
        """Run one optimizer step on a random batch of stored samples.

        Args:
            batch_size: Number of distinct stored samples to draw.

        Returns:
            The mean per-sample loss as a float (which can be exactly 0.0), or
            ``None`` when fewer than ``batch_size`` samples are stored.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if len(self.experience_buffer) < batch_size:
            return None

        # Sample distinct buffer positions rather than handing the list of
        # dicts to NumPy, which would first coerce it into an object array.
        picked = np.random.choice(len(self.experience_buffer), size=batch_size, replace=False)

        # Run on whatever device the predictor lives on, so this method does
        # not depend on a notebook-level global.
        target_device = self._predictor_device()

        loss_accum = 0
        self.optimizer.zero_grad()

        for position in picked:
            item = self.experience_buffer[int(position)]

            # Re-run forward pass to get gradient graph
            t_emb = item['text']
            a_emb = item['audio']
            if target_device is not None:
                t_emb = t_emb.to(target_device)
                a_emb = a_emb.to(target_device)

            # Ensure correct shape for predictor (Batch, Dim): [Dim] -> [1, Dim]
            if t_emb.dim() == 1: t_emb = t_emb.unsqueeze(0)
            if a_emb.dim() == 1: a_emb = a_emb.unsqueeze(0)

            # Predict Macros
            macros = self.model.predictor(t_emb, a_emb)
            latents = macros['combined']

            # Simple Reward-Weighted Regularization Loss
            # Pushes latents towards active state if reward > 0
            # Pushes latents towards zero if reward < 0
            loss = -1.0 * item['reward'] * torch.mean(latents**2)

            # Gradients from every sample accumulate before the single step below.
            loss.backward()
            loss_accum += loss.item()

        self.optimizer.step()
        return loss_accum / batch_size

# Re-Initialize the trainer
# A new trainer starts with an empty experience buffer and a fresh Adam
# optimizer (default learning rate 1e-4).
trainer = RLHFTrainerMacro(model)
print("âœ“ Trainer patched. Ready for feedback.")

âœ“ Trainer patched. Ready for feedback.


In [11]:
# --- WIDGET SETUP ---
SR = 44100
# Base audio for transformation (Sine Wave Canvas): 2 seconds of a 440 Hz
# sine at amplitude 0.5, as a (1, n_samples) float32 tensor on `device`.
# Sample times are n / SR. np.linspace(0, 2.0, N) would include the endpoint,
# making the step 2 / (N - 1) instead of 1 / SR and slightly detuning the tone.
t = np.arange(int(SR * 2.0)) / SR
base_audio = 0.5 * np.sin(2 * np.pi * 440 * t)
base_audio_tensor = torch.from_numpy(base_audio).float().unsqueeze(0).to(device)

# Tasks
# Text prompts presented one at a time in the rating widget; the rating
# callback advances through them in order and wraps around at the end.
tasks = [
    "bright, cutting tone",
    "warm, smooth melody with gentle sustain",
    "harsh, digital sound with buzzy edges",
    "sharp, metallic plucks with quick decay",
    "crunchy overdriven riffs with grit",
    "warm, mellow strumming texture",
    "glassy, chiming harmonics",
    "thick, muffled rhythm",
    "raw, gritty tone with expressive bends",
    "lush, chorus-soaked clean chords",
    "dark, smoky ensemble in low register",
    "warm, smooth chords with soft transients",
    "bright, percussive stabs",
    "mellow, emotional melody",
    "dark, muted tone with felt-like texture",
    "sparkly, bell-like arpeggios",
    "soulful chords with slow rotary effect",
    "punchy, percussive riff with sharp attack",
    "dreamy, detuned hybrid tone with soft transients",
    "crunchy, slightly distorted tone with bite",
    "aggressive, bright lead with sharp harmonics",
    "warm, analog pad with gentle movement",
    "grainy, lofi lead with texture",
    "soft, airy lead with gentle brightness",
    "detuned, wobbling tone with drifting pitch",
    "raspy, resonant filter-sweep lead",
    "sparkly, crystalline plucks with short decay",
    "thick, wide supersaw lead with stereo spread",
    "hollow, formant-shifted tone with vowel-like quality",
    "lush, wide pad with dreamy texture",
    "dark, evolving ambient pad with low rumble",
    "celestial, shimmering pad with high-frequency sparkle",
    "hollow, airy pad with subtle modulation",
    "moody drone with slow-moving harmonics",
    "detuned, warm pad with analog drift",
    "ethereal, floating texture with soft overtones",
    "deep, warm sub-bass with smooth sine texture",
    "gritty, distorted tone with heavy saturation",
    "rubbery, bouncy bass with fast transients",
    "thick, resonant low-end with movement",
    "clean, round tone with gentle harmonics",
    "fuzzy, aggressive tone with buzz-saw texture",
    "burst with bright edges",
    "grainy, textured sound with soft filtering",
    "distorted, chaotic bed with harsh peaks",
    "warm, analog bed with subtle movement",
    "glitchy, stuttering texture with digital artifacts",
    "rough, chaotic saw-like texture",
    "tight, punchy hit with sharp attack",
    "snappy, bright sound with crisp transient",
    "warm, rounded hits with soft decay",
    "bright pattern with metallic shimmer",
    "airy, crisp groove with stereo shimmer",
    "tight, dry beat with fast transients",
    "boomy, cinematic hits",
    "gritty, overcompressed loop with pumping artifacts",
    "deep, resonant tone with long sustain",
    "sharp, percussive hit with strong transient",
    "bright, resonant melody",
    "soft, expressive line with airy tone",
    "warm phrase with smooth transitions",
    "dark, breathy melody",
    "bold stabs with powerful attack",
    "soft, lush strings with gentle swelling",
    "warm, woody plucks",
    "smooth section with warm resonance",
    "bright strikes with clean attack",
    "glassy, shimmering notes",
    "warm chords with soft tremolo",
    "sharp, metallic hits with long decay",
    "hollow plucks with woody texture",
    "distant, echoing ambient chords",
    "soft, hazy reverb-washed tones",
    "crystalline ambient washes with airy diffusion",
    "deep, cavernous drone with subharmonics",
    "slow, swelling cinematic texture",
    "retro 80s lead with chorus",
    "dark industrial tone with metallic grit",
    "lofi, tape-warped acoustic texture",
    "robotic, vocoder-like synthetic tone",
    "electronic plucks with rapid transient snap",
    "warm, resonant plucks with rounded body",
    "metallic, atonal texture with shifting harmonics",
    "heavy, saturated resonant tone with compression pump"
]
# Index into `tasks` of the prompt currently shown; starts at the first prompt.
current_task_idx = 0

# UI Components
# Label showing the current prompt (starts at the first task).
lbl_task = Label(f"Task: {tasks[0]}")
# Button wired to on_gen.
btn_gen = Button(description="Generate")
# Rating control: integers 1-5, defaulting to 3.
slider = IntSlider(min=1, max=5, value=3, description="Rating")
# Button wired to on_rate.
btn_rate = Button(description="Submit")
# Output area that both callbacks print into.
out = Output()

# Global variables to store current step data
# Set by on_gen and read by on_rate; None until the first generation.
curr_text_emb = None
curr_audio_emb = None

def on_gen(b):
    """Button callback: generate audio for the current task and display it.

    Clears the output area, runs the full model on the current prompt and the
    base audio without tracking gradients, stores the returned text/audio
    embeddings in the module-level ``curr_text_emb`` / ``curr_audio_emb``,
    prints the first four timbre and envelope latents, and shows an audio
    player for the result.

    Args:
        b: The clicked button (unused).
    """
    global curr_text_emb, curr_audio_emb

    with out:
        out.clear_output()
        desc = tasks[current_task_idx]
        print(f"Generating for: {desc}")

        with torch.no_grad():
            # Run Full Model
            # generate_audio=True enables the Predictor -> Decoder -> Engine flow
            outputs = model([desc], base_audio_tensor, generate_audio=True)

            # Keep this sample's embeddings so the rating callback can store them.
            curr_text_emb = outputs['text_emb']
            curr_audio_emb = outputs['audio_emb']

            # Show only the first four values of each latent group for the
            # first item in the batch.
            macros = outputs['latents']
            t_latents, e_latents = (
                macros[group][0, :4].cpu().numpy().round(2)
                for group in ('timbre', 'envelope')
            )

            print(f"\nTimbre Latents (first 4):   {t_latents}")
            print(f"Envelope Latents (first 4): {e_latents}")

            # Play the first generated waveform at the notebook sample rate.
            audio = outputs['audio_output'].cpu().numpy()[0]
            ipd.display(ipd.Audio(audio, rate=SR))

def on_rate(b):
    """Button callback: store the slider rating for the last generated sample.

    Hands the current embeddings and rating to ``trainer``, triggers an update
    once at least four samples are stored, then advances to the next task and
    refreshes the task label.

    If no sample is waiting to be rated (Generate was not clicked, or the last
    sample was already rated) a hint is printed and nothing else happens. The
    stored embeddings are cleared after each rating so one generated sample
    cannot be rated twice under the next task's label.

    Args:
        b: The clicked button (unused).
    """
    global current_task_idx, curr_text_emb, curr_audio_emb
    with out:
        # Guard: without a generated sample there is nothing to attach the rating to.
        if curr_text_emb is None or curr_audio_emb is None:
            print("Nothing to rate yet - click Generate first.")
            return

        rating = slider.value

        # Add to trainer memory, then consume the sample.
        trainer.add_feedback(curr_text_emb, curr_audio_emb, rating)
        curr_text_emb = None
        curr_audio_emb = None
        print(f"Rated: {rating}. Updating...")

        if len(trainer.experience_buffer) >= 4:
            loss = trainer.update(batch_size=4)
            # Compare against None: a loss of exactly 0.0 is still a real update.
            if loss is not None:
                print(f"ðŸ“‰ Model Updated (Loss: {loss:.4f})")
            else:
                print("Update skipped (insufficient buffer)")
        else:
            print(f"Stored ({len(trainer.experience_buffer)}/4 samples needed for update)")

        # Advance task
        current_task_idx = (current_task_idx + 1) % len(tasks)
        lbl_task.value = f"Task: {tasks[current_task_idx]}"

# Bind buttons
# Each click invokes the callback with the clicked button as its argument.
btn_gen.on_click(on_gen)
btn_rate.on_click(on_rate)

# Display UI
# Stack the widgets vertically: task label, Generate, rating slider, Submit, output area.
display(VBox([lbl_task, btn_gen, slider, btn_rate, out]))

VBox(children=(Label(value='Task: bright, cutting tone'), Button(description='Generate', style=ButtonStyle()),â€¦