# RLHF Pipeline - Hybrid Approach

**Strategy**: Train on clean synthetic data, evaluate on real-world audio

## Workflow
1. **RLHF Training**: Collect feedback on synthetic base tones with text descriptions
2. **Evaluation**: Test improved model on REAL MusicCaps test set
3. **Analysis**: Does synthetic training → real improvement?

## Why This Works
- Clean RLHF signal (synthetic)
- Real evaluation (MusicCaps)
- Tests generalization\

In [1]:
# Cell 1: Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import json
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
import IPython.display as ipd
from IPython.display import display
import librosa
import warnings
warnings.filterwarnings('ignore')

# Import model components
from lstmabar_model import LSTMABAR
from archetype_predictor import ArchetypePredictionHead, RLHFTrainer

print("✓ Imports loaded")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

=== Testing Archetype Prediction ===
Predicted weights shape: torch.Size([8, 5])
Sample prediction: tensor([0.4174, 0.1978, 0.0785, 0.1063, 0.2000], grad_fn=<SelectBackward0>)
Sum of weights: 1.000000 (should be ~1.0)

Named predictions for sample 0:
  sine: 0.5278
  square: 0.1097
  sawtooth: 0.0864
  triangle: 0.1374
  noise: 0.1387

=== Testing Archetype Loss ===
MSE loss: 0.0263

=== Testing RLHF Trainer with Audio Playback ===

HUMAN FEEDBACK COLLECTION

Description: 'bright and cutting guitar tone'

Predicted Archetype Weights:
  sine      : ████████ 0.417
  square    : ███ 0.198
  sawtooth  : █ 0.079
  triangle  : ██ 0.106
  noise     : ████ 0.200

------------------------------------------------------------
AUDIO PLAYBACK
------------------------------------------------------------

▶️  ORIGINAL AUDIO:



▶️  TRANSFORMED AUDIO:



------------------------------------------------------------
RATING INSTRUCTIONS
------------------------------------------------------------
Rate how well the transformation matches the description:
  5 = Perfect match
  4 = Good match
  3 = Acceptable match
  2 = Poor match
  1 = Very poor match

✓ Feedback recorded: 5.0/5


=== Interactive RLHF Usage Example ===

# In Jupyter notebook, use this pattern:

# 1. Generate or load audio samples
original_audio = librosa.load('input.wav')[0]
transformed_audio = model.transform(original_audio, description)

# 2. Get embeddings and predictions
text_emb = text_encoder([description])
audio_emb = audio_encoder(torch.from_numpy(original_audio))
predicted_weights = predictor(text_emb, audio_emb)

# 3. Collect interactive feedback with audio playback
rating = rlhf_trainer.collect_feedback_with_audio(
description="bright and crunchy",
original_audio=original_audio,
transformed_audio=transformed_audio,
predicted_weights=predicted_weights[0].cpu().n

In [None]:
# Cell 2: Configuration

# Model config
MODEL_CONFIG = {
    'use_quantum_attention': True,
    'checkpoint_path': 'tuning_checkpoints/mids266final_bestmodel/best_model.pth',
    'sample_rate': 44100,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Data paths
DATA_CONFIG = {
    'test_data_path': 'musiccaps_training_data_test.npz',
}

# RLHF config
RLHF_CONFIG = {
    'num_feedback_samples': 30,
    'reward_threshold': 3.0,
    'learning_rate': 1e-5,
    'num_rlhf_epochs': 5,
    'update_frequency': 5,
    'amplify_archetypes': True,      # NEW: Amplify archetype differences
    'amplification_temperature': 2.0  # NEW: 2.0 = moderate, 3.0+ = more extreme
}

# Base tone config
BASE_TONE_CONFIG = {
    'duration': 2.0,
    'sample_rate': 44100,
    'fundamental_freq': 220.0,
}

# Output
OUTPUT_DIR = Path('rlhf_hybrid_results')
OUTPUT_DIR.mkdir(exist_ok=True)

print("✓ Configuration loaded")
print(f"  RLHF training: {RLHF_CONFIG['num_feedback_samples']} synthetic samples")
print(f"  Amplification: {RLHF_CONFIG['amplify_archetypes']} (temp={RLHF_CONFIG['amplification_temperature']})")
print(f"  Evaluation: Real test data from {DATA_CONFIG['test_data_path']}")
print(f"  Output dir: {OUTPUT_DIR}")

✓ Configuration loaded
  RLHF training: 30 synthetic samples
  Amplification: True (temp=2.0)
  Evaluation: Real test data from musiccaps_training_data_test.npz
  Output dir: rlhf_hybrid_results


In [45]:
# Cell 3: Load Model

print("="*60)
print("LOADING MODEL")
print("="*60)

device = torch.device(MODEL_CONFIG['device'])
print(f"\nUsing device: {device}")

# Load checkpoint
checkpoint_path = Path(MODEL_CONFIG['checkpoint_path'])
if not checkpoint_path.exists():
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

print(f"Loading from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)

# Initialize model with Quantum attention
model = LSTMABAR(
    sample_rate=MODEL_CONFIG['sample_rate'],
    use_quantum_attention=MODEL_CONFIG['use_quantum_attention'],
    device=device
)

# Load weights
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✓ Model loaded successfully")
print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
print(f"  Loss: {checkpoint.get('loss', 'unknown')}")
print(f"\n✓ Model has archetype predictor: {hasattr(model, 'archetype_predictor')}")
print(f"✓ Model has transform_audio: {hasattr(model, 'transform_audio')}")
print(f"✓ Model has inference: {hasattr(model, 'inference')}")

LOADING MODEL

Using device: cpu
Loading from: tuning_checkpoints/mids266final_bestmodel/best_model.pth
Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
✓ Model loaded successfully
  Epoch: 7
  Loss: unknown

✓ Model has archetype predictor: True
✓ Model has transform_audio: True
✓ Model has inference: True


In [46]:
# Cell 4: Load Real Test Data

print("="*60)
print("LOADING REAL TEST DATA")
print("="*60)

print(f"\nLoading from: {DATA_CONFIG['test_data_path']}")
test_data = np.load(DATA_CONFIG['test_data_path'], allow_pickle=True)

print(f"Available keys: {list(test_data.keys())}")

# Extract data
test_descriptions = test_data['descriptions']
test_archetypes = test_data['archetype_vectors']
test_audio_paths = test_data['audio_paths']
archetype_order = test_data['archetype_order']

print(f"\n✓ Metadata loaded: {len(test_descriptions)} test samples")
print(f"  Archetype order: {list(archetype_order)}")

# Load audio
def load_audio_from_paths(audio_paths, sample_rate=44100, max_duration=10.0):
    """Load audio from file paths"""
    loaded_audio = []
    failed = []
    max_length = int(sample_rate * max_duration)
    
    for audio_path in tqdm(audio_paths, desc="Loading test audio"):
        audio_path = str(audio_path)
        try:
            audio, sr = librosa.load(audio_path, sr=sample_rate, duration=max_duration)
            if len(audio) < max_length:
                audio = np.pad(audio, (0, max_length - len(audio)))
            else:
                audio = audio[:max_length]
            loaded_audio.append(audio)
        except Exception as e:
            if len(failed) < 5:
                print(f"\n  Failed: {audio_path}: {e}")
            loaded_audio.append(np.zeros(max_length))
            failed.append(audio_path)
    
    if failed:
        print(f"\n  {len(failed)}/{len(audio_paths)} files failed")
    
    return np.array(loaded_audio), failed

print("\n🎵 Loading test audio files...")
test_audio, test_failed = load_audio_from_paths(
    test_audio_paths, 
    MODEL_CONFIG['sample_rate']
)

print(f"\n✓ Test data loaded")
print(f"  Audio shape: {test_audio.shape}")
print(f"  Descriptions: {len(test_descriptions)}")
print(f"  Target archetypes: {test_archetypes.shape}")

LOADING REAL TEST DATA

Loading from: musiccaps_training_data_test.npz
Available keys: ['archetype_vectors', 'descriptions', 'audio_paths', 'archetype_order']

✓ Metadata loaded: 73 test samples
  Archetype order: [np.str_('sine'), np.str_('square'), np.str_('sawtooth'), np.str_('triangle'), np.str_('noise')]

🎵 Loading test audio files...


Loading test audio: 100%|██████████| 73/73 [00:00<00:00, 93.57it/s]



✓ Test data loaded
  Audio shape: (73, 441000)
  Descriptions: 73
  Target archetypes: (73, 5)


In [None]:
# Cell 5: Generate Synthetic Base Tones

print("="*60)
print("GENERATING SIMPLE SINE WAVE BASE TONE")
print("="*60)

# Simple sine wave - just like archetype_predictor.py example
duration = BASE_TONE_CONFIG['duration']
sample_rate = BASE_TONE_CONFIG['sample_rate']
frequency = BASE_TONE_CONFIG['fundamental_freq']  # 220 Hz (A3)

num_samples = int(duration * sample_rate)
t = np.linspace(0, duration, num_samples)

# Pure sine wave at 0.5 amplitude (mellow, like the example)
default_base_tone = (np.sin(2 * np.pi * frequency * t) * 0.5).astype(np.float32)

print(f"\n✓ Generated simple sine wave base tone: {default_base_tone.shape}")
print(f"  Frequency: {frequency} Hz")
print(f"  Duration: {duration} seconds")
print(f"  Amplitude: 0.5 (mellow)")
print(f"  Mean: {default_base_tone.mean():.6f}")
print(f"  Std: {default_base_tone.std():.6f}")

# Play it so you can hear the base
print("\n🎵 Base tone (will be transformed):")
display(ipd.Audio(default_base_tone, rate=sample_rate))

GENERATING SIMPLE SINE WAVE BASE TONE

✓ Generated simple sine wave base tone: (88200,)
  Frequency: 220.0 Hz
  Duration: 2.0 seconds
  Amplitude: 0.5 (mellow)
  Mean: -0.000000
  Std: 0.353551

🎵 Base tone (will be transformed):


In [48]:
# Cell 6: Load Training Descriptions

TRAINING_DESCRIPTIONS = [
    "bright, cutting guitar tone",
    "warm, smooth guitar melody with gentle sustain",
    "harsh, digital synth-like guitar with buzzy edges",
    "sharp, metallic guitar plucks with quick decay",
    "crunchy overdriven guitar riffs with grit",
    "warm, mellow acoustic guitar strumming",
    "glassy, chiming harmonics on electric guitar",
    "thick, muffled palm-muted guitar rhythm",
    "raw, gritty blues guitar with expressive bends",
    "lush, chorus-soaked clean guitar chords",
    "dark, smoky cello ensemble in low register",
    "warm, smooth piano chords with soft transients",
    "bright, percussive piano stabs",
    "mellow, emotional grand piano melody",
    "dark, muted piano with felt-like tone",
    "sparkly, bell-like piano arpeggios",
    "soulful Hammond organ chords with slow rotary effect",
    "punchy, percussive clavinet riff with sharp attack",
    "dreamy, detuned synth-piano hybrid with soft transients",
    "crunchy, slightly distorted Rhodes with bite",
    "aggressive, bright synth lead with sharp harmonics",
    "warm, analog synth pad with gentle movement",
    "grainy, lofi synth lead with noise texture",
    "soft, airy synth lead with gentle brightness",
    "detuned, wobbling synth tone with drifting pitch",
    "raspy, resonant filter-sweep synth lead",
    "sparkly, crystalline synth plucks with short decay",
    "thick, wide supersaw lead with stereo spread",
    "hollow, formant-shifted synth tone with vowel-like quality",
    "lush, wide pad with dreamy texture",
    "dark, evolving ambient pad with low rumble",
    "celestial, shimmering pad with high-frequency sparkle",
    "hollow, airy pad with subtle modulation",
    "moody drone with slow-moving harmonics",
    "detuned, warm synth pad with analog drift",
    "ethereal, floating texture with soft overtones",
    "deep, warm sub-bass with smooth sine texture",
    "gritty, distorted bass with heavy saturation",
    "rubbery, bouncy synth bass with fast transients",
    "thick, resonant low-end bass with movement",
    "clean, round bass with gentle harmonics",
    "fuzzy, aggressive bass with buzz-saw texture",
    "white-noise burst with bright edges",
    "grainy, textured noise with soft filtering",
    "distorted, chaotic noise bed with harsh peaks",
    "warm, analog noise bed with subtle movement",
    "glitchy, stuttering texture with digital artifacts",
    "noisy, rough, chaotic saw-like texture",
    "tight, punchy kick drum with sharp attack",
    "snappy, bright snare with crisp transient",
    "warm, rounded tom hits with soft decay",
    "bright hi-hat pattern with metallic shimmer",
    "airy, crisp percussion groove with stereo shimmer",
    "tight, dry breakbeat with fast transients",
    "boomy, cinematic taiko-style drum hits",
    "gritty, overcompressed drum loop with pumping artifacts",
    "deep, resonant floor tom with long sustain",
    "sharp, percussive rimshot with strong transient",
    "bright, resonant violin melody",
    "soft, expressive flute line with airy tone",
    "warm clarinet phrase with smooth transitions",
    "dark, breathy saxophone melody",
    "bold brass stabs with powerful attack",
    "soft, lush orchestral strings with gentle swelling",
    "warm, woody double-bass plucks",
    "smooth horn section with warm resonance",
    "bright marimba strikes with clean attack",
    "glassy, shimmering celeste notes",
    "warm vibraphone chords with soft tremolo",
    "sharp, metallic bell hits with long decay",
    "hollow kalimba plucks with woody texture",
    "distant, echoing ambient chords",
    "soft, hazy reverb-washed tones",
    "crystalline ambient washes with airy diffusion",
    "deep, cavernous drone with subharmonics",
    "slow, swelling cinematic texture",
    "retro 80s synthwave lead with chorus",
    "dark industrial tone with metallic grit",
    "lofi, tape-warped acoustic texture",
    "robotic, vocoder-like synthetic tone",
    "electronic plucks with rapid transient snap",
    "warm, resonant plucks with rounded body",
    "metallic, atonal texture with shifting harmonics",
    "heavy, saturated resonant tone with compression pump"
]

# Randomly sample descriptions for this RLHF run
print("="*60)
print("TRAINING DESCRIPTIONS FOR RLHF")
print("="*60)
print(f"Total available: {len(TRAINING_DESCRIPTIONS)}")
print(f"Sampling: {RLHF_CONFIG['num_feedback_samples']} descriptions")
print()

# Set seed for reproducibility (change seed for different random samples)
import random
random_seed = 42  # Change this to get different samples each run
np.random.seed(random_seed)
random.seed(random_seed)

# Randomly sample from all descriptions
train_descriptions = random.sample(TRAINING_DESCRIPTIONS, RLHF_CONFIG['num_feedback_samples'])

print(f"✓ Randomly selected {len(train_descriptions)} descriptions")
print(f"  (To get different samples, change random_seed in this cell)")
print(f"\nFirst 5 selected:")
for i, desc in enumerate(train_descriptions[:5]):
    print(f"  {i+1}. {desc}")


TRAINING DESCRIPTIONS FOR RLHF
Total available: 84
Sampling: 30 descriptions

✓ Randomly selected 30 descriptions
  (To get different samples, change random_seed in this cell)

First 5 selected:
  1. warm, resonant plucks with rounded body
  2. dark, muted piano with felt-like tone
  3. sharp, metallic guitar plucks with quick decay
  4. ethereal, floating texture with soft overtones
  5. celestial, shimmering pad with high-frequency sparkle


---
## Helper Functions
---

In [54]:
# Cell 7: Helper Functions (with AMPLIFICATION)

def amplify_archetypes(archetypes, temperature=2.0):
    """
    Amplify archetype differences by applying temperature scaling
    
    temperature > 1.0: Makes distribution more extreme (sharper)
    temperature < 1.0: Makes distribution more uniform (softer)
    """
    # Apply temperature scaling
    logits = np.log(archetypes + 1e-8) / temperature
    exp_logits = np.exp(logits - np.max(logits))  # Subtract max for numerical stability
    amplified = exp_logits / exp_logits.sum()
    return amplified


def transform_with_text(model, base_audio, description, device, amplify=True, temperature=2.0):
    """
    Transform base audio using text description
    
    CRITICAL FIX: Uses TEXT-ONLY archetype prediction!
    amplify: If True, amplify archetype differences for more noticeable transformations
    temperature: Higher = more extreme differences (try 2.0-5.0)
    """
    # Convert audio to tensor
    audio_tensor = torch.from_numpy(base_audio).unsqueeze(0).float().to(device)
    
    with torch.no_grad():
        # Get text embedding
        text_emb = model.encode_text([description])
        
        # CRITICAL FIX: Use ZERO audio embedding to force text-only prediction
        zero_audio_emb = torch.zeros_like(text_emb).to(device)
        
        # Predict archetypes using text-only
        archetype_weights = model.predict_archetypes(text_emb, zero_audio_emb)
        archetype_weights_np = archetype_weights.squeeze(0).cpu().detach().numpy()
        
        # Amplify differences if requested
        if amplify:
            original_weights = archetype_weights_np.copy()
            archetype_weights_np = amplify_archetypes(archetype_weights_np, temperature)
            archetype_weights = torch.from_numpy(archetype_weights_np).unsqueeze(0).float().to(device)
        
        # Transform audio using DDSP
        transformed = model.transform_audio(audio_tensor, archetype_weights)
    
    # Convert back
    transformed_audio = transformed.squeeze(0).cpu().numpy()
    
    return transformed_audio, archetype_weights_np


def evaluate_on_real_test_set(model, test_audio, test_descriptions, test_archetypes, device):
    """Evaluate model on real test data"""
    model.eval()
    results = []
    
    with torch.no_grad():
        for i in tqdm(range(len(test_audio)), desc="Evaluating"):
            audio = test_audio[i]
            description = test_descriptions[i]
            target_archetypes = test_archetypes[i]
            
            # Model prediction - use REAL audio embeddings for evaluation
            audio_tensor = torch.from_numpy(audio).unsqueeze(0).float().to(device)
            _, metadata = model.inference([description], audio_tensor)
            pred_arch = metadata['predicted_weights'][0]
            
            # Metrics
            arch_mse = np.mean((pred_arch - target_archetypes) ** 2)
            cos_sim = np.dot(pred_arch, target_archetypes) / (
                np.linalg.norm(pred_arch) * np.linalg.norm(target_archetypes) + 1e-8
            )
            pred_max = np.argmax(pred_arch)
            target_max = np.argmax(target_archetypes)
            arch_correct = int(pred_max == target_max)
            
            results.append({
                'description': description,
                'predicted_archetypes': pred_arch.tolist(),
                'target_archetypes': target_archetypes.tolist(),
                'archetype_mse': float(arch_mse),
                'cosine_similarity': float(cos_sim),
                'archetype_correct': arch_correct
            })
    
    # Aggregate
    avg_mse = np.mean([r['archetype_mse'] for r in results])
    avg_cos_sim = np.mean([r['cosine_similarity'] for r in results])
    arch_accuracy = np.mean([r['archetype_correct'] for r in results]) * 100
    
    return results, {
        'avg_mse': avg_mse,
        'avg_cosine_similarity': avg_cos_sim,
        'archetype_accuracy': arch_accuracy
    }


def play_comparison(original, transformed, sample_rate=44100):
    """Play audio comparison"""
    print("🎵 ORIGINAL:")
    display(ipd.Audio(original, rate=sample_rate))
    print("\n🎵 TRANSFORMED:")
    display(ipd.Audio(transformed, rate=sample_rate))


def get_rating_from_user(description, attempt=0, max_attempts=3):
    """Get rating from user"""
    print("\n" + "="*60)
    print("RATE THE TRANSFORMATION")
    print("="*60)
    print(f"Description: '{description}'")
    print("\nHow well does transformation match description?")
    print("  5 = Perfect")
    print("  4 = Good")
    print("  3 = Acceptable")
    print("  2 = Poor")
    print("  1 = Very poor")
    
    try:
        rating = input("\nYour rating (1-5): ").strip()
        rating = int(rating)
        if rating not in [1, 2, 3, 4, 5]:
            print("⚠️  Rating must be 1-5")
            if attempt < max_attempts:
                return get_rating_from_user(description, attempt + 1, max_attempts)
            return 3
        return rating
    except ValueError:
        print("⚠️  Please enter 1-5")
        if attempt < max_attempts:
            return get_rating_from_user(description, attempt + 1, max_attempts)
        return 3
    except KeyboardInterrupt:
        print("\n\n⚠️  Interrupted")
        raise


def compute_reward(rating, threshold=3.0):
    """Convert rating to reward"""
    return (rating - 3.0) / 2.0


print("✓ Helper functions loaded")
print("✓ Using TEXT-ONLY archetype prediction with AMPLIFICATION!")
print("  Temperature should make archetype differences more extreme")

✓ Helper functions loaded
✓ Using TEXT-ONLY archetype prediction with AMPLIFICATION!
  Temperature should make archetype differences more extreme


---
## Pre-RLHF Baseline
---

In [8]:
# Cell 8: Pre-RLHF Baseline

print("="*60)
print("PRE-RLHF BASELINE EVALUATION")
print("="*60)
print(f"\nEvaluating on {len(test_audio)} REAL test samples")
print("Baseline before RLHF...\n")

baseline_results, baseline_metrics = evaluate_on_real_test_set(
    model, test_audio, test_descriptions, test_archetypes, device
)

print("\n" + "="*60)
print("BASELINE RESULTS")
print("="*60)
print(f"MSE:              {baseline_metrics['avg_mse']:.6f}")
print(f"Cosine sim:       {baseline_metrics['avg_cosine_similarity']:.4f}")
print(f"Accuracy:         {baseline_metrics['archetype_accuracy']:.2f}%")

with open(OUTPUT_DIR / 'baseline_results.json', 'w') as f:
    json.dump({
        'metrics': baseline_metrics,
        'detailed_results': baseline_results[:10]
    }, f, indent=2)

print(f"\n✓ Saved: {OUTPUT_DIR / 'baseline_results.json'}")

PRE-RLHF BASELINE EVALUATION

Evaluating on 73 REAL test samples
Baseline before RLHF...



Evaluating: 100%|██████████| 73/73 [00:52<00:00,  1.39it/s]


BASELINE RESULTS
MSE:              0.105025
Cosine sim:       0.5820
Accuracy:         31.51%

✓ Saved: rlhf_hybrid_results/baseline_results.json





In [41]:
# Diagnotic Test 1: Testing out that ddsp is working

print("="*60)
print("TESTING TEXT-ONLY PREDICTIONS")
print("="*60)

test_descs = [
    "bright, cutting guitar tone",
    "warm, smooth piano",
    "deep sub-bass"
]

model.eval()
for desc in test_descs:
    text_emb = model.encode_text([desc])
    zero_audio_emb = torch.zeros_like(text_emb).to(device)
    archetypes = model.predict_archetypes(text_emb, zero_audio_emb)
    arch_np = archetypes.squeeze(0).cpu().detach().numpy()
    
    print(f"\n'{desc}':")
    for name, weight in zip(['sine', 'square', 'sawtooth', 'triangle', 'noise'], arch_np):
        print(f"  {name:10s}: {weight:.4f}")

print("\n" + "="*60)

TESTING TEXT-ONLY PREDICTIONS

'bright, cutting guitar tone':
  sine      : 0.2680
  square    : 0.1669
  sawtooth  : 0.1923
  triangle  : 0.2100
  noise     : 0.1628

'warm, smooth piano':
  sine      : 0.2201
  square    : 0.1185
  sawtooth  : 0.3346
  triangle  : 0.1641
  noise     : 0.1627

'deep sub-bass':
  sine      : 0.2201
  square    : 0.1877
  sawtooth  : 0.2522
  triangle  : 0.1838
  noise     : 0.1562



In [51]:
# Diagnostic Test 2

print("="*60)
print("TESTING DDSP WITH TEXT-ONLY PREDICTIONS")
print("="*60)

# Test with actual text-only predictions
test_descs = [
    "bright, cutting guitar tone",
    "warm, smooth piano",
    "deep sub-bass"
]

model.eval()
base_audio_tensor = torch.from_numpy(default_base_tone).unsqueeze(0).float().to(device)

print(f"\nBase tone stats:")
print(f"  Mean: {default_base_tone.mean():.6f}")
print(f"  Std:  {default_base_tone.std():.6f}")
print(f"  Max:  {np.abs(default_base_tone).max():.6f}")

transformed_audios = []

for desc in test_descs:
    with torch.no_grad():
        text_emb = model.encode_text([desc])
        zero_audio_emb = torch.zeros_like(text_emb).to(device)
        archetypes = model.predict_archetypes(text_emb, zero_audio_emb)
        
        # Transform with DDSP
        transformed = model.transform_audio(base_audio_tensor, archetypes)
        transformed_np = transformed.squeeze(0).cpu().numpy()
        transformed_audios.append(transformed_np)
        
        arch_np = archetypes.squeeze(0).cpu().detach().numpy()
        
        print(f"\n'{desc}':")
        print(f"  Archetypes: {arch_np}")
        print(f"  Transformed mean: {transformed_np.mean():.6f}")
        print(f"  Transformed std:  {transformed_np.std():.6f}")
        print(f"  Transformed max:  {np.abs(transformed_np).max():.6f}")
        
        # Check difference from base
        min_len = min(len(default_base_tone), len(transformed_np))
        diff = np.abs(default_base_tone[:min_len] - transformed_np[:min_len])
        print(f"  Diff from base: mean={diff.mean():.6f}, max={diff.max():.6f}")

# Check if transformed audios differ from each other
print("\n" + "="*60)
print("COMPARING TRANSFORMED AUDIOS")
print("="*60)

for i in range(len(transformed_audios)):
    for j in range(i+1, len(transformed_audios)):
        min_len = min(len(transformed_audios[i]), len(transformed_audios[j]))
        diff = np.abs(transformed_audios[i][:min_len] - transformed_audios[j][:min_len])
        same = np.allclose(transformed_audios[i][:min_len], transformed_audios[j][:min_len], atol=1e-4)
        
        print(f"\n'{test_descs[i]}' vs '{test_descs[j]}':")
        print(f"  Diff: mean={diff.mean():.6f}, max={diff.max():.6f}")
        print(f"  Same? {same}")

print("\n" + "="*60)
print("CONCLUSION")
print("="*60)

all_same = True
for i in range(len(transformed_audios)):
    for j in range(i+1, len(transformed_audios)):
        min_len = min(len(transformed_audios[i]), len(transformed_audios[j]))
        if not np.allclose(transformed_audios[i][:min_len], transformed_audios[j][:min_len], atol=1e-4):
            all_same = False
            break

if all_same:
    print("   PROBLEM: All transformed audios are IDENTICAL!")
    print("   DDSP is not responding to different archetype weights.")
    print("   Possible causes:")
    print("   1. DDSP weights weren't trained")
    print("   2. Archetype differences too subtle")
    print("   3. DDSP needs more extreme weight differences")
    print("\n   SOLUTION: Use manual transforms with amplified differences")
else:
    print("✓ GOOD: Transformed audios ARE different!")
    print("   DDSP is working with TEXT-ONLY predictions.")
    print("   Differences might be subtle - listen carefully!")

TESTING DDSP WITH TEXT-ONLY PREDICTIONS

Base tone stats:
  Mean: -0.000000
  Std:  0.353551
  Max:  0.500000

'bright, cutting guitar tone':
  Archetypes: [0.26796505 0.166875   0.19232759 0.21002379 0.16280863]
  Transformed mean: 0.000067
  Transformed std:  0.591250
  Transformed max:  0.950000
  Diff from base: mean=0.217201, max=0.451453

'warm, smooth piano':
  Archetypes: [0.22011507 0.11852196 0.33460677 0.164053   0.16270319]
  Transformed mean: 0.000101
  Transformed std:  0.588581
  Transformed max:  0.950000
  Diff from base: mean=0.213869, max=0.454205

'deep sub-bass':
  Archetypes: [0.2201132  0.18768564 0.25220728 0.1838188  0.15617506]
  Transformed mean: -0.000068
  Transformed std:  0.579212
  Transformed max:  0.950000
  Diff from base: mean=0.206497, max=0.450023

COMPARING TRANSFORMED AUDIOS

'bright, cutting guitar tone' vs 'warm, smooth piano':
  Diff: mean=0.040918, max=0.223468
  Same? False

'bright, cutting guitar tone' vs 'deep sub-bass':
  Diff: mean=0.04

---
## RLHF Training (Synthetic)
---

In [None]:
# Cell 9: RLHF Training (with debugging)

print("="*60)
print("RLHF FEEDBACK COLLECTION")
print("="*60)
print(f"\nRating {len(train_descriptions)} transformations")
print("Training on SYNTHETIC, testing on REAL!\n")
print(f"Time: ~{len(train_descriptions) * 1.5:.0f}-{len(train_descriptions) * 2:.0f} min\n")

input("Press Enter to start...")

# Setup
feedback_history = []
optimizer = optim.Adam(model.parameters(), lr=RLHF_CONFIG['learning_rate'])
archetype_names = ['sine', 'square', 'sawtooth', 'triangle', 'noise']

for sample_idx, description in enumerate(train_descriptions):
    print(f"\n{'='*60}")
    print(f"Sample {sample_idx + 1}/{len(train_descriptions)}")
    print(f"{'='*60}")
    
    # Keep model in EVAL mode for transformation
    model.eval()
    
    # 🔍 DEBUG: Check base tone before transformation
    print(f"\n🔍 DEBUG - Before transform:")
    print(f"  default_base_tone id: {id(default_base_tone)}")
    print(f"  default_base_tone stats: mean={default_base_tone.mean():.6f}, std={default_base_tone.std():.6f}")
    
    # Transform
    transformed, archetypes = transform_with_text(
        model, default_base_tone, description, device,
        amplify=RLHF_CONFIG['amplify_archetypes'],
        temperature=RLHF_CONFIG['amplification_temperature']
)
    
    # 🔍 DEBUG: Check after transformation
    print(f"\n🔍 DEBUG - After transform:")
    print(f"  transformed id: {id(transformed)}")
    print(f"  transformed stats: mean={transformed.mean():.6f}, std={transformed.std():.6f}")
    print(f"  Same object? {id(default_base_tone) == id(transformed)}")
    
    min_len = min(len(default_base_tone), len(transformed))
    diff_check = np.abs(default_base_tone[:min_len] - transformed[:min_len])
    print(f"  Diff: mean={diff_check.mean():.6f}, max={diff_check.max():.6f}")
    print(f"  Are they identical? {np.allclose(default_base_tone[:min_len], transformed[:min_len], atol=1e-6)}")
    
    print(f"\nDescription: '{description}'")
    print(f"\nPredicted archetypes:")
    for name, weight in zip(archetype_names, archetypes):
        bar = '█' * int(weight * 30)
        print(f"  {name:10s}: {weight:.3f} {bar}")
    
    # Detailed audio diagnostics
    print(f"\n🔍 Transformation diagnostics:")
    orig_mean = np.mean(default_base_tone)
    orig_std = np.std(default_base_tone)
    orig_max = np.max(np.abs(default_base_tone))
    trans_mean = np.mean(transformed)
    trans_std = np.std(transformed)
    trans_max = np.max(np.abs(transformed))
    
    print(f"  Original:    mean={orig_mean:.6f}, std={orig_std:.6f}, max={orig_max:.6f}")
    print(f"  Transformed: mean={trans_mean:.6f}, std={trans_std:.6f}, max={trans_max:.6f}")
    
    # Check if actually different
    if np.allclose(default_base_tone[:min_len], transformed[:min_len], atol=1e-4):
        print(f"  ⚠️  WARNING: Audio appears UNCHANGED!")
        print(f"  This means DDSP didn't transform it.")
    else:
        # Calculate difference
        diff = np.abs(default_base_tone[:min_len] - transformed[:min_len])
        print(f"  ✓ Audio IS different (avg diff: {np.mean(diff):.6f})")
        print(f"  Transformation strength: {(np.mean(diff) / orig_std * 100):.1f}% of original std")
    
    # 🔍 DEBUG: What are we about to play?
    print(f"\n🔍 DEBUG - About to play:")
    print(f"  Original audio: shape={default_base_tone.shape}, first 5 values: {default_base_tone[:5]}")
    print(f"  Transformed audio: shape={transformed.shape}, first 5 values: {transformed[:5]}")
    
    # Play
    print("\n" + "-"*60)
    print("🎵 Playing ORIGINAL:")
    display(ipd.Audio(default_base_tone, rate=MODEL_CONFIG['sample_rate']))
    print("\n🎵 Playing TRANSFORMED:")
    display(ipd.Audio(transformed, rate=MODEL_CONFIG['sample_rate']))
    print("-"*60)
    
    # Rate
    rating = get_rating_from_user(description)
    reward = compute_reward(rating, RLHF_CONFIG['reward_threshold'])
    
    print(f"\n✓ Rating: {rating}/5, Reward: {reward:+.2f}")
    
    # Store
    feedback_history.append({
        'sample_idx': sample_idx,
        'description': description,
        'archetypes': archetypes.tolist(),
        'rating': int(rating),
        'reward': float(reward),
        'timestamp': datetime.now().isoformat()
    })
    
    # Update model
    if (sample_idx + 1) % RLHF_CONFIG['update_frequency'] == 0:
        print(f"\n🔄 Updating model...")
        
        # Switch to TRAIN mode for gradient updates
        model.train()
        
        recent = feedback_history[-RLHF_CONFIG['update_frequency']:]
        avg_reward = np.mean([f['reward'] for f in recent])
        print(f"  Recent avg reward: {avg_reward:+.3f}")
        
        # Policy gradient
        for epoch in range(RLHF_CONFIG['num_rlhf_epochs']):
            epoch_loss = 0.0
            
            # Create batch from recent feedback for BatchNorm
            batch_texts = [f['description'] for f in recent]
            batch_rewards = torch.tensor([f['reward'] for f in recent]).float().to(device)
            
            # Get text embeddings for batch
            text_emb_batch = model.encode_text(batch_texts)
            
            # CRITICAL FIX: Use ZERO audio embeddings (TEXT-ONLY prediction)
            zero_audio_emb_batch = torch.zeros_like(text_emb_batch).to(device)
            
            # Predict archetypes using text-only for entire batch
            predicted_archetypes_batch = model.predict_archetypes(text_emb_batch, zero_audio_emb_batch)
            
            # Policy gradient loss for batch
            log_probs = torch.log(predicted_archetypes_batch + 1e-8).sum(dim=1)
            loss = -(log_probs * batch_rewards).mean()
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_loss += loss.item()
        
        print(f"  ✓ Updated (loss: {epoch_loss/RLHF_CONFIG['num_rlhf_epochs']:.4f})")
        
        # Switch back to EVAL mode
        model.eval()

print(f"\n\n{'='*60}")
print("RLHF COMPLETE")
print(f"{'='*60}")
print(f"Total: {len(feedback_history)}")
print(f"Avg rating: {np.mean([f['rating'] for f in feedback_history]):.2f}")
print(f"Avg reward: {np.mean([f['reward'] for f in feedback_history]):+.3f}")
print(f"Positive (≥3): {sum(1 for f in feedback_history if f['rating'] >= 3)}/{len(feedback_history)}")

with open(OUTPUT_DIR / 'feedback_history.json', 'w') as f:
    json.dump(feedback_history, f, indent=2)
print(f"\n✓ Saved: {OUTPUT_DIR / 'feedback_history.json'}")

RLHF FEEDBACK COLLECTION

Rating 30 transformations
Training on SYNTHETIC, testing on REAL!

Time: ~45-60 min


Sample 1/30

🔍 DEBUG - Before transform:
  default_base_tone id: 136505601351984
  default_base_tone stats: mean=-0.000000, std=0.353551

🔍 DEBUG - After transform:
  transformed id: 136505138577008
  transformed stats: mean=-0.000185, std=0.579759
  Same object? False
  Diff: mean=0.206665, max=0.458053
  Are they identical? False

Description: 'warm, resonant plucks with rounded body'

Predicted archetypes:
  sine      : 0.370 ███████████
  square    : 0.157 ████
  sawtooth  : 0.157 ████
  triangle  : 0.157 ████
  noise     : 0.157 ████

🔍 Transformation diagnostics:
  Original:    mean=-0.000000, std=0.353551, max=0.500000
  Transformed: mean=-0.000185, std=0.579759, max=0.950000
  ✓ Audio IS different (avg diff: 0.206665)
  Transformation strength: 58.5% of original std

🔍 DEBUG - About to play:
  Original audio: shape=(88200,), first 5 values: [0.         0.01566995 0.03


🎵 Playing TRANSFORMED:


------------------------------------------------------------

RATE THE TRANSFORMATION
Description: 'warm, resonant plucks with rounded body'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor

✓ Rating: 4/5, Reward: +0.50

Sample 2/30

🔍 DEBUG - Before transform:
  default_base_tone id: 136505601351984
  default_base_tone stats: mean=-0.000000, std=0.353551

🔍 DEBUG - After transform:
  transformed id: 136505044501200
  transformed stats: mean=0.000136, std=0.574320
  Same object? False
  Diff: mean=0.201660, max=0.450710
  Are they identical? False

Description: 'dark, muted piano with felt-like tone'

Predicted archetypes:
  sine      : 0.157 ████
  square    : 0.157 ████
  sawtooth  : 0.157 ████
  triangle  : 0.370 ███████████
  noise     : 0.157 ████

🔍 Transformation diagnostics:
  Original:    mean=-0.000000, std=0.353551, max=0.500000
  Transformed: mean=0.000136, std=0.574320, max=0.950000
  ✓ Audio IS different


🎵 Playing TRANSFORMED:


------------------------------------------------------------

RATE THE TRANSFORMATION
Description: 'dark, muted piano with felt-like tone'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor

✓ Rating: 2/5, Reward: -0.50

Sample 3/30

🔍 DEBUG - Before transform:
  default_base_tone id: 136505601351984
  default_base_tone stats: mean=-0.000000, std=0.353551

🔍 DEBUG - After transform:
  transformed id: 136505598031024
  transformed stats: mean=-0.000055, std=0.598653
  Same object? False
  Diff: mean=0.223477, max=0.465019
  Are they identical? False

Description: 'sharp, metallic guitar plucks with quick decay'

Predicted archetypes:
  sine      : 0.157 ████
  square    : 0.157 ████
  sawtooth  : 0.370 ███████████
  triangle  : 0.157 ████
  noise     : 0.157 ████

🔍 Transformation diagnostics:
  Original:    mean=-0.000000, std=0.353551, max=0.500000
  Transformed: mean=-0.000055, std=0.598653, max=0.950000
  ✓ Audio IS 


🎵 Playing TRANSFORMED:


------------------------------------------------------------

RATE THE TRANSFORMATION
Description: 'sharp, metallic guitar plucks with quick decay'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor
⚠️  Please enter 1-5

RATE THE TRANSFORMATION
Description: 'sharp, metallic guitar plucks with quick decay'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor
⚠️  Please enter 1-5

RATE THE TRANSFORMATION
Description: 'sharp, metallic guitar plucks with quick decay'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor
⚠️  Please enter 1-5

RATE THE TRANSFORMATION
Description: 'sharp, metallic guitar plucks with quick decay'

How well does transformation match description?
  5 = Perfect
  4 = Good
  3 = Acceptable
  2 = Poor
  1 = Very poor


⚠️  Interrupted


KeyboardInterrupt: Interrupted by user

---
## Post-RLHF Evaluation (Real)
---

In [None]:
# Cell 10: Post-RLHF Evaluation

print("="*60)
print("POST-RLHF EVALUATION")
print("="*60)
print(f"\nEvaluating on {len(test_audio)} REAL test samples")
print("Did synthetic training help?\n")

postrlhf_results, postrlhf_metrics = evaluate_on_real_test_set(
    model, test_audio, test_descriptions, test_archetypes, device
)

print("\n" + "="*60)
print("POST-RLHF RESULTS")
print("="*60)
print(f"MSE:              {postrlhf_metrics['avg_mse']:.6f}")
print(f"Cosine sim:       {postrlhf_metrics['avg_cosine_similarity']:.4f}")
print(f"Accuracy:         {postrlhf_metrics['archetype_accuracy']:.2f}%")

with open(OUTPUT_DIR / 'postrlhf_results.json', 'w') as f:
    json.dump({
        'metrics': postrlhf_metrics,
        'detailed_results': postrlhf_results[:10]
    }, f, indent=2)

print(f"\n✓ Saved: {OUTPUT_DIR / 'postrlhf_results.json'}")

In [None]:
# Cell 11: Compare Results

print("="*60)
print("PRE vs POST RLHF COMPARISON")
print("="*60)
print("\nTrained: Synthetic (30 samples)")
print("Tested:  Real MusicCaps (73 samples)\n")

print("Metric                  Before         After          Change")
print("-" * 65)

mse_before = baseline_metrics['avg_mse']
mse_after = postrlhf_metrics['avg_mse']
mse_change = mse_after - mse_before
mse_pct = (mse_change / mse_before) * 100 if mse_before > 0 else 0
print(f"MSE:                    {mse_before:.6f}     {mse_after:.6f}     {mse_change:+.6f} ({mse_pct:+.1f}%)")

cos_before = baseline_metrics['avg_cosine_similarity']
cos_after = postrlhf_metrics['avg_cosine_similarity']
cos_change = cos_after - cos_before
cos_pct = (cos_change / cos_before) * 100 if cos_before > 0 else 0
print(f"Cosine Sim:             {cos_before:.6f}     {cos_after:.6f}     {cos_change:+.6f} ({cos_pct:+.1f}%)")

acc_before = baseline_metrics['archetype_accuracy']
acc_after = postrlhf_metrics['archetype_accuracy']
acc_change = acc_after - acc_before
print(f"Accuracy:               {acc_before:.2f}%        {acc_after:.2f}%        {acc_change:+.2f}%")

print("\n" + "="*60)
print("INTERPRETATION")
print("="*60)

if mse_change < 0:
    print("✓ MSE decreased (better predictions)")
else:
    print("⚠ MSE increased")

if cos_change > 0:
    print("✓ Cosine similarity increased (better alignment)")
else:
    print("⚠ Cosine similarity decreased")

if acc_change > 0:
    print("✓ Accuracy increased")
else:
    print("⚠ Accuracy decreased")

print("\nSynthetic → Real transfer: ", end="")
if mse_change < 0 and cos_change > 0:
    print("✓ SUCCESS!")
elif abs(mse_change/mse_before) < 0.05:
    print("✓ STABLE (didn't overfit)")
else:
    print("⚠ MIXED RESULTS")

# Save comparison
comparison = {
    'training': {
        'data_type': 'synthetic',
        'num_samples': len(feedback_history),
        'avg_rating': float(np.mean([f['rating'] for f in feedback_history])),
        'positive_pct': float(100 * sum(1 for f in feedback_history if f['rating'] >= 3) / len(feedback_history))
    },
    'evaluation': {
        'data_type': 'real',
        'num_samples': len(test_audio)
    },
    'metrics': {
        'before': baseline_metrics,
        'after': postrlhf_metrics,
        'changes': {
            'mse_change': float(mse_change),
            'mse_pct': float(mse_pct),
            'cos_change': float(cos_change),
            'cos_pct': float(cos_pct),
            'acc_change': float(acc_change)
        }
    }
}

with open(OUTPUT_DIR / 'comparison.json', 'w') as f:
    json.dump(comparison, f, indent=2)

print(f"\n✓ Saved: {OUTPUT_DIR / 'comparison.json'}")

---
## Save & Visualize
---

In [None]:
# Cell 12: Save Model

print("="*60)
print("SAVING MODEL")
print("="*60)

final_model_path = OUTPUT_DIR / 'rlhf_final_model.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': MODEL_CONFIG,
    'rlhf_config': RLHF_CONFIG,
    'training_summary': comparison['training'],
    'evaluation_summary': comparison['evaluation'],
    'timestamp': datetime.now().isoformat()
}, final_model_path)

print(f"✓ Saved: {final_model_path}")
print(f"  Size: {final_model_path.stat().st_size / 1024 / 1024:.1f} MB")

In [None]:
# Cell 13: Visualizations

print("="*60)
print("VISUALIZATIONS")
print("="*60)

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Hybrid RLHF: Synthetic Training → Real Evaluation', fontsize=16, fontweight='bold')

# Training plots
ratings = [f['rating'] for f in feedback_history]
rewards = [f['reward'] for f in feedback_history]

axes[0, 0].hist(ratings, bins=5, range=(0.5, 5.5), edgecolor='black', alpha=0.7, color='skyblue')
axes[0, 0].set_title('Training: Ratings (Synthetic)')
axes[0, 0].set_xlabel('Rating')
axes[0, 0].set_xticks([1, 2, 3, 4, 5])
axes[0, 0].grid(axis='y', alpha=0.3)

axes[0, 1].plot(rewards, 'o-', alpha=0.6, color='green')
axes[0, 1].axhline(y=0, color='r', linestyle='--', alpha=0.5)
axes[0, 1].set_title('Training: Rewards')
axes[0, 1].set_xlabel('Sample')
axes[0, 1].grid(alpha=0.3)

cumulative = np.cumsum([1 if r >= 3 else 0 for r in ratings])
axes[0, 2].plot(cumulative, 'o-', color='green', alpha=0.6)
axes[0, 2].set_title('Training: Cumulative Positive')
axes[0, 2].set_xlabel('Sample')
axes[0, 2].grid(alpha=0.3)

# Evaluation plots
axes[1, 0].bar(['Before', 'After'], [mse_before, mse_after], color=['coral', 'lightgreen'], edgecolor='black')
axes[1, 0].set_title('Eval: MSE (Real Test)')
axes[1, 0].grid(axis='y', alpha=0.3)

axes[1, 1].bar(['Before', 'After'], [cos_before, cos_after], color=['coral', 'lightgreen'], edgecolor='black')
axes[1, 1].set_title('Eval: Cosine Sim (Real Test)')
axes[1, 1].grid(axis='y', alpha=0.3)

axes[1, 2].bar(['Before', 'After'], [acc_before, acc_after], color=['coral', 'lightgreen'], edgecolor='black')
axes[1, 2].set_title('Eval: Accuracy (Real Test)')
axes[1, 2].grid(axis='y', alpha=0.3)

plt.tight_layout()
viz_path = OUTPUT_DIR / 'rlhf_hybrid_analysis.png'
plt.savefig(viz_path, dpi=150, bbox_inches='tight')
print(f"\n✓ Saved: {viz_path}")
plt.show()

In [None]:
# Cell 14: Final Report

report = f"""
HYBRID RLHF - FINAL REPORT
{'='*70}

APPROACH
--------
Training:   Synthetic base tones + text descriptions ({len(feedback_history)} samples)
Evaluation: Real MusicCaps test audio ({len(test_audio)} samples)
Goal:       Test if synthetic RLHF → real improvement

TRAINING (Synthetic)
--------------------
Samples:    {len(feedback_history)}
Avg rating: {np.mean([f['rating'] for f in feedback_history]):.2f}/5
Positive:   {sum(1 for f in feedback_history if f['rating'] >= 3)}/{len(feedback_history)} ({100*sum(1 for f in feedback_history if f['rating'] >= 3)/len(feedback_history):.1f}%)

EVALUATION (Real Test Data)
---------------------------
Metric              Before         After          Change
──────────────────────────────────────────────────────────────
MSE:                {mse_before:.6f}     {mse_after:.6f}     {mse_change:+.6f} ({mse_pct:+.1f}%)
Cosine Sim:         {cos_before:.6f}     {cos_after:.6f}     {cos_change:+.6f} ({cos_pct:+.1f}%)
Accuracy:           {acc_before:.2f}%        {acc_after:.2f}%        {acc_change:+.2f}%

CONCLUSION
----------
"""

if mse_change < 0 and cos_change > 0:
    report += "✓ SUCCESS: Synthetic training improved real performance!\n"
elif abs(mse_change / mse_before) < 0.05:
    report += "✓ STABLE: Model maintained performance (didn't overfit)\n"
else:
    report += "⚠ MIXED: Some metrics improved, others declined\n"

report += f"""
This demonstrates {'successful' if mse_change < 0 else 'the challenge of'} transfer learning
from synthetic to real audio in RLHF settings.

FILES
-----
Model:       {OUTPUT_DIR / 'rlhf_final_model.pth'}
Feedback:    {OUTPUT_DIR / 'feedback_history.json'}
Baseline:    {OUTPUT_DIR / 'baseline_results.json'}
Post-RLHF:   {OUTPUT_DIR / 'postrlhf_results.json'}
Comparison:  {OUTPUT_DIR / 'comparison.json'}
Viz:         {OUTPUT_DIR / 'rlhf_hybrid_analysis.png'}

{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

print(report)

with open(OUTPUT_DIR / 'final_report.txt', 'w') as f:
    f.write(report)

print(f"✓ Saved: {OUTPUT_DIR / 'final_report.txt'}")
print("\n" + "="*60)
print("✅ COMPLETE!")
print("="*60)
print(f"\nAll outputs: {OUTPUT_DIR}/")
print("\n🎉 Hybrid RLHF done! Trained synthetic, tested real.")