# RLHF Pipeline for LSTMABAR Model (v5)

This notebook implements **Reinforcement Learning from Human Feedback (RLHF)** for fine-tuning the LSTMABAR model.

## Pipeline Overview:
1. Load the best model from `tuning_checkpoints/mids266final_bestmodel/best_model.pth`
2. Load train/val/test data
3. **Sample from training set** for human feedback (1-5 rating scale)
4. Fine-tune the archetype predictor (`model.archetype_predictor`) using RLHF, with optional validation monitoring
5. Evaluate final model on **held-out test set**
6. Save results and updated model

## 1. Setup and Imports

In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import librosa
import soundfile as sf
from pathlib import Path
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import IPython.display as ipd
from IPython.display import display, HTML

# Import model components
from lstmabar_model import LSTMABAR, LSTMABARTrainer
from archetype_predictor import RLHFTrainer, ArchetypePredictionHead
from musiccaps_loader import MusicCapsLoader

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")



=== Testing Archetype Prediction ===
Predicted weights shape: torch.Size([8, 5])
Sample prediction: tensor([0.1012, 0.3700, 0.2950, 0.0840, 0.1498], grad_fn=<SelectBackward0>)
Sum of weights: 1.000000 (should be ~1.0)

Named predictions for sample 0:
  sine: 0.1187
  square: 0.0663
  sawtooth: 0.4614
  triangle: 0.1262
  noise: 0.2274

=== Testing Archetype Loss ===
MSE loss: 0.0164

=== Testing RLHF Trainer with Audio Playback ===

HUMAN FEEDBACK COLLECTION

Description: 'bright and cutting guitar tone'

Predicted Archetype Weights:
  sine      : ██ 0.101
  square    : ███████ 0.370
  sawtooth  : █████ 0.295
  triangle  : █ 0.084
  noise     : ██ 0.150

------------------------------------------------------------
AUDIO PLAYBACK
------------------------------------------------------------

▶️  ORIGINAL AUDIO:



▶️  TRANSFORMED AUDIO:



------------------------------------------------------------
RATING INSTRUCTIONS
------------------------------------------------------------
Rate how well the transformation matches the description:
  5 = Perfect match
  4 = Good match
  3 = Acceptable match
  2 = Poor match
  1 = Very poor match

✓ Feedback recorded: 5.0/5


=== Interactive RLHF Usage Example ===

# In Jupyter notebook, use this pattern:

# 1. Generate or load audio samples
original_audio = librosa.load('input.wav')[0]
transformed_audio = model.transform(original_audio, description)

# 2. Get embeddings and predictions
text_emb = text_encoder([description])
audio_emb = audio_encoder(torch.from_numpy(original_audio))
predicted_weights = predictor(text_emb, audio_emb)

# 3. Collect interactive feedback with audio playback
rating = rlhf_trainer.collect_feedback_with_audio(
description="bright and crunchy",
original_audio=original_audio,
transformed_audio=transformed_audio,
predicted_weights=predicted_weights[0].cpu().n

## 2. Configuration

In [2]:
# Paths
MODEL_PATH = 'tuning_checkpoints/mids266final_bestmodel/best_model.pth'
TRAIN_DATA_PATH = 'musiccaps_training_data_train.npz'  # For RLHF feedback
VAL_DATA_PATH = 'musiccaps_training_data_val.npz'      # For monitoring
TEST_DATA_PATH = 'musiccaps_training_data_test.npz'    # For final evaluation
AUDIO_DIR = 'musiccaps_audio'
OUTPUT_DIR = 'rlhf_results'

# Create output directory
Path(OUTPUT_DIR).mkdir(exist_ok=True)

# RLHF Configuration
RLHF_CONFIG = {
    'num_feedback_samples': 30,  # Number of TRAIN samples to collect feedback on
    'reward_threshold': 3.0,     # Rating threshold for positive reward
    'learning_rate': 1e-5,       # Lower learning rate for fine-tuning
    'batch_size': 8,             # Batch size for RLHF updates
    'num_rlhf_epochs': 5,        # Number of fine-tuning epochs
    'update_frequency': 5,       # Update model every N samples
    'use_val_monitoring': True,  # Monitor validation loss during RLHF
    'val_check_frequency': 10,   # Check val loss every N samples
    'early_stopping_patience': 3 # Early-stop after N consecutive val checks without improvement
}

# Model Configuration (should match training config)
MODEL_CONFIG = {
    'embedding_dim': 768,
    'text_model': 'sentence-transformers/all-MiniLM-L6-v2',
    'audio_architecture': 'resnet',
    'num_archetypes': 5,
    'sample_rate': 44100,
    'use_quantum_attention': True,  # Must match the setting the best model was trained with
    'temperature': 0.07
}

print("Configuration loaded successfully!")
print(f"\n📁 Data Paths:")
print(f"  Model: {MODEL_PATH}")
print(f"  Train (RLHF feedback): {TRAIN_DATA_PATH}")
print(f"  Val (monitoring): {VAL_DATA_PATH}")
print(f"  Test (final eval): {TEST_DATA_PATH}")
print(f"  Output: {OUTPUT_DIR}/")

print(f"\n⚙️  RLHF Settings:")
print(f"  Feedback samples: {RLHF_CONFIG['num_feedback_samples']} (from train set)")
print(f"  Val monitoring: {RLHF_CONFIG['use_val_monitoring']}")
print(f"  Learning rate: {RLHF_CONFIG['learning_rate']}")

Configuration loaded successfully!

📁 Data Paths:
  Model: tuning_checkpoints/mids266final_bestmodel/best_model.pth
  Train (RLHF feedback): musiccaps_training_data_train.npz
  Val (monitoring): musiccaps_training_data_val.npz
  Test (final eval): musiccaps_training_data_test.npz
  Output: rlhf_results/

⚙️  RLHF Settings:
  Feedback samples: 30 (from train set)
  Val monitoring: True
  Learning rate: 1e-05
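The `reward_threshold` entry suggests that 1-5 ratings are turned into rewards relative to 3.0, but this notebook does not show how `RLHFTrainer` does that internally. A minimal sketch, assuming a linear mapping centred on the threshold (the `ratings_to_rewards` name and the [-1, 1] scaling are hypothetical, not the trainer's actual code):

```python
import numpy as np


def ratings_to_rewards(ratings, threshold=3.0, min_rating=1.0, max_rating=5.0):
    """Map ratings to rewards in [-1, 1], with 0 at `threshold`.

    Hypothetical mapping: ratings above the threshold are scaled into (0, 1],
    ratings below it into [-1, 0). The real RLHFTrainer may differ.
    """
    r = np.asarray(ratings, dtype=float)
    return np.where(
        r >= threshold,
        (r - threshold) / (max_rating - threshold),   # positive side
        (r - threshold) / (threshold - min_rating),   # negative side
    )


print(ratings_to_rewards([5, 4, 3, 2, 1]))
```

With the default threshold of 3.0, ratings 5, 4, 3, 2, 1 map to rewards 1.0, 0.5, 0.0, -0.5, -1.0, so an "Acceptable match" leaves the predictor unchanged under this scheme.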


## 3. Load Best Model

In [3]:
print("=" * 60)
print("LOADING BEST MODEL")
print("=" * 60)

# Initialize model
model = LSTMABAR(
    embedding_dim=MODEL_CONFIG['embedding_dim'],
    text_model=MODEL_CONFIG['text_model'],
    audio_architecture=MODEL_CONFIG['audio_architecture'],
    num_archetypes=MODEL_CONFIG['num_archetypes'],
    sample_rate=MODEL_CONFIG['sample_rate'],
    use_quantum_attention=MODEL_CONFIG['use_quantum_attention'],
    temperature=MODEL_CONFIG['temperature'],
    device=device
)

# Load checkpoint
try:
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"✓ Model loaded successfully!")
    print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
    print(f"  Validation Loss: {checkpoint.get('val_loss', 'N/A')}")
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    
except Exception as e:
    print(f"✗ Error loading model: {e}")
    raise

model.eval()
print("\nModel ready for inference and RLHF!")

LOADING BEST MODEL
Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
✓ Model loaded successfully!
  Epoch: 7
  Validation Loss: N/A
  Total parameters: 37,804,309
  Trainable parameters: 37,804,309

Model ready for inference and RLHF!
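The counts above cover the whole model, but the RLHF trainer initialized below is given only `model.archetype_predictor`. To see how much of the network that head represents, parameters can be counted per top-level submodule. A small sketch on a toy module (the `ToyModel` class and its `encoder` child are placeholders; only the `archetype_predictor` name mirrors this notebook):

```python
import torch.nn as nn


def count_params_by_child(model: nn.Module):
    """Return {child_name: (total_params, trainable_params)} per top-level submodule."""
    counts = {}
    for name, child in model.named_children():
        total = sum(p.numel() for p in child.parameters())
        trainable = sum(p.numel() for p in child.parameters() if p.requires_grad)
        counts[name] = (total, trainable)
    return counts


class ToyModel(nn.Module):
    """Placeholder model; LSTMABAR's real submodules are not listed in this notebook."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 8)               # 16*8 weights + 8 biases = 136
        self.archetype_predictor = nn.Linear(8, 5)    # 8*5 weights + 5 biases = 45


toy = ToyModel()
for p in toy.encoder.parameters():   # freeze the encoder to show the trainable split
    p.requires_grad = False

print(count_params_by_child(toy))
```

On the real model, `count_params_by_child(model)` would list LSTMABAR's own children and show how the 37.8M parameters are split between the encoders and the predictor head.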


## 4. Load All Data Splits

In [16]:
print("=" * 60)
print("LOADING DATA SPLITS")
print("=" * 60)

# Helper function to load audio from files
def load_audio_from_folder(audio_paths, sample_rate=44100, max_duration=2.0):
    """Load, resample and pad/trim audio files to a fixed length.

    `audio_paths` already include the directory prefix (e.g. 'musiccaps_audio/...').
    Files that fail to load are replaced with silence (zeros) so that row i of the
    returned array always corresponds to audio_paths[i] and to the matching
    description/archetype entry; their paths are returned in `failed_files`.
    """
    max_length = int(sample_rate * max_duration)
    loaded_audio = []
    failed_files = []
    
    print(f"Loading {len(audio_paths)} audio files...")
    for audio_path in tqdm(audio_paths, desc="Loading audio", ncols=80):
        audio_path = str(audio_path)
        
        try:
            # Resample to `sample_rate` and read at most `max_duration` seconds
            audio, _ = librosa.load(audio_path, sr=sample_rate, duration=max_duration)
            
            # Pad or trim to exactly max_length samples
            if len(audio) < max_length:
                audio = np.pad(audio, (0, max_length - len(audio)))
            else:
                audio = audio[:max_length]
        except Exception as e:
            if len(failed_files) < 5:  # Only print the first 5 errors
                print(f"\n⚠️  Failed to load {audio_path}: {e}")
            failed_files.append(audio_path)
            # Keep the audio array index-aligned with the metadata arrays
            audio = np.zeros(max_length, dtype=np.float32)
        
        loaded_audio.append(audio)
    
    if failed_files:
        print(f"\n⚠️  Warning: {len(failed_files)}/{len(audio_paths)} files failed to load (replaced with silence)")
    
    return np.array(loaded_audio), failed_files

# Load training data
print("\n📚 Loading TRAINING data (for RLHF feedback)...")
train_data = np.load(TRAIN_DATA_PATH, allow_pickle=True)

print(f"Available keys: {list(train_data.keys())}")

# Key names as stored in the .npz files
train_descriptions = train_data['descriptions']
train_archetypes = train_data['archetype_vectors']  # Note: 'archetype_vectors' not 'archetype_weights'
train_audio_paths = train_data['audio_paths']
archetype_order = train_data['archetype_order']

print(f"✓ Metadata loaded: {len(train_descriptions)} samples")
print(f"  Archetype order: {list(archetype_order)}")

# Load actual audio
print(f"\n🎵 Loading training audio...")
train_audio, train_failed = load_audio_from_folder(train_audio_paths, MODEL_CONFIG['sample_rate'])
print(f"✓ Train audio: {train_audio.shape}")

# Load validation data
if RLHF_CONFIG['use_val_monitoring']:
    print("\n📊 Loading VALIDATION data...")
    val_data = np.load(VAL_DATA_PATH, allow_pickle=True)
    
    val_descriptions = val_data['descriptions']
    val_archetypes = val_data['archetype_vectors']
    val_audio_paths = val_data['audio_paths']
    
    print(f"🎵 Loading validation audio...")
    val_audio, val_failed = load_audio_from_folder(val_audio_paths, MODEL_CONFIG['sample_rate'])
    print(f"✓ Val audio: {val_audio.shape}")

# Load test data
print("\n🎯 Loading TEST data (for final evaluation)...")
test_data = np.load(TEST_DATA_PATH, allow_pickle=True)

test_descriptions = test_data['descriptions']
test_archetypes = test_data['archetype_vectors']
test_audio_paths = test_data['audio_paths']

print(f"🎵 Loading test audio...")
test_audio, test_failed = load_audio_from_folder(test_audio_paths, MODEL_CONFIG['sample_rate'])
print(f"✓ Test audio: {test_audio.shape}")

# Summary
print("\n" + "=" * 60)
print("DATA SPLIT SUMMARY")
print("=" * 60)
print(f"Train: {len(train_descriptions)} samples (RLHF feedback collection)")
if RLHF_CONFIG['use_val_monitoring']:
    print(f"Val:   {len(val_descriptions)} samples (monitoring)")
print(f"Test:  {len(test_descriptions)} samples (held-out evaluation)")
print("\n✓ No data leakage - test set remains unseen!")

LOADING DATA SPLITS

📚 Loading TRAINING data (for RLHF feedback)...
Available keys: ['archetype_vectors', 'descriptions', 'audio_paths', 'archetype_order']
✓ Metadata loaded: 333 samples
  Archetype order: [np.str_('sine'), np.str_('square'), np.str_('sawtooth'), np.str_('triangle'), np.str_('noise')]

🎵 Loading training audio...
Loading 333 audio files...


Loading audio: 100%|█████████████████████████| 333/333 [00:00<00:00, 430.37it/s]


✓ Train audio: (333, 88200)

📊 Loading VALIDATION data...
🎵 Loading validation audio...
Loading 71 audio files...


Loading audio: 100%|███████████████████████████| 71/71 [00:00<00:00, 466.46it/s]


✓ Val audio: (71, 88200)

🎯 Loading TEST data (for final evaluation)...
🎵 Loading test audio...
Loading 73 audio files...


Loading audio: 100%|███████████████████████████| 73/73 [00:00<00:00, 452.28it/s]

✓ Test audio: (73, 88200)

DATA SPLIT SUMMARY
Train: 333 samples (RLHF feedback collection)
Val:   71 samples (monitoring)
Test:  73 samples (held-out evaluation)

✓ No data leakage - test set remains unseen!
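Each clip is 2.0 s at 44,100 Hz, i.e. 2.0 × 44,100 = 88,200 samples, which matches the shapes above. Because descriptions, archetype vectors and audio live in separate arrays indexed by position, row *i* of the audio must always belong to row *i* of the metadata, even when some files fail to load (none did in this run: 333/333, 71/71 and 73/73). A sketch of one way to guarantee that, assuming a hypothetical `load_aligned` helper that returns the indices that loaded successfully so the metadata can be filtered with the same indices (`fake_loader` stands in for a `librosa.load` wrapper):

```python
import numpy as np


def pad_or_trim(audio: np.ndarray, target_len: int) -> np.ndarray:
    """Zero-pad on the right or truncate so that len(audio) == target_len."""
    if len(audio) < target_len:
        return np.pad(audio, (0, target_len - len(audio)))
    return audio[:target_len]


def load_aligned(paths, loader, target_len):
    """Load each path with `loader` (path -> 1-D array) and keep successful indices."""
    clips, kept = [], []
    for i, path in enumerate(paths):
        try:
            clips.append(pad_or_trim(loader(path), target_len))
            kept.append(i)
        except Exception:
            continue  # skip the file; its index is simply not in `kept`
    if not clips:
        return np.empty((0, target_len), dtype=np.float32), np.array([], dtype=int)
    return np.stack(clips), np.array(kept, dtype=int)


def fake_loader(path):
    """Stand-in loader: fails for one path, returns a short clip otherwise."""
    if path == "missing.wav":
        raise FileNotFoundError(path)
    return np.ones(3, dtype=np.float32)


audio, kept = load_aligned(["a.wav", "missing.wav", "b.wav"], fake_loader, target_len=5)
descriptions = np.array(["first", "second", "third"])[kept]  # filter metadata the same way
print(audio.shape, kept.tolist(), descriptions.tolist())
```

Filtering every metadata array with `kept` keeps the splits consistent without having to invent audio for missing files.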





## 5. Initialize RLHF Trainer

In [17]:
print("=" * 60)
print("INITIALIZING RLHF TRAINER")
print("=" * 60)

# Initialize RLHF trainer for the archetype predictor
rlhf_trainer = RLHFTrainer(
    predictor=model.archetype_predictor,
    learning_rate=RLHF_CONFIG['learning_rate'],
    reward_threshold=RLHF_CONFIG['reward_threshold']
)

print(f"✓ RLHF Trainer initialized")
print(f"  Learning rate: {RLHF_CONFIG['learning_rate']}")
print(f"  Reward threshold: {RLHF_CONFIG['reward_threshold']}")
print(f"  Batch size: {RLHF_CONFIG['batch_size']}")

# Storage for feedback collection
feedback_history = {
    'samples': [],
    'ratings': [],
    'timestamps': [],
    'descriptions': [],
    'predicted_weights': [],
    'data_split': []  # Track which split each sample came from
}

# Storage for validation monitoring
val_history = {
    'iteration': [],
    'val_mse': [],
    'val_cosine_sim': []
}

INITIALIZING RLHF TRAINER
✓ RLHF Trainer initialized
  Learning rate: 1e-05
  Reward threshold: 3.0
  Batch size: 8


## 6. Helper Functions

In [18]:
def prepare_audio_sample(audio_array, sample_rate=44100):
    """
    Prepare audio array for model input
    """
    audio_tensor = torch.from_numpy(audio_array).float().to(device)
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    return audio_tensor


def generate_transformation(model, description, audio):
    """
    Generate audio transformation using the model
    """
    model.eval()
    
    with torch.no_grad():
        # Prepare inputs
        audio_tensor = prepare_audio_sample(audio)
        descriptions = [description]
        
        # Get embeddings
        text_emb = model.encode_text(descriptions)
        audio_emb, _ = model.encode_audio(audio_tensor)
        
        # Predict archetype weights
        predicted_weights = model.predict_archetypes(text_emb, audio_emb)
        
        # Transform audio
        transformed_audio = model.transform_audio(audio_tensor, predicted_weights)
    
    return {
        'transformed_audio': transformed_audio.cpu().numpy().squeeze(),
        'predicted_weights': predicted_weights.cpu().numpy().squeeze(),
        'text_embedding': text_emb,
        'audio_embedding': audio_emb
    }


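def evaluate_on_val_set(model, val_audio, val_descriptions, val_archetypes, num_samples=50, seed=0):
    """
    Evaluate model on a fixed random subset of the validation set (for monitoring during RLHF).
    Using the same `seed` at every check keeps the subset identical, so changes in the
    metrics reflect model updates rather than a different sample of validation data.
    """
    model.eval()
    
    # Sample from val set with a dedicated RNG (does not advance the global NumPy seed)
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(val_descriptions), min(num_samples, len(val_descriptions)), replace=False)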
    
    mse_list = []
    cosine_sim_list = []
    
    with torch.no_grad():
        for idx in indices:
            result = generate_transformation(model, val_descriptions[idx], val_audio[idx])
            predicted = result['predicted_weights']
            target = val_archetypes[idx]
            
            # MSE
            mse = np.mean((predicted - target) ** 2)
            mse_list.append(mse)
            
            # Cosine similarity
            cos_sim = np.dot(predicted, target) / (
                np.linalg.norm(predicted) * np.linalg.norm(target) + 1e-10
            )
            cosine_sim_list.append(cos_sim)
    
    return {
        'mse': np.mean(mse_list),
        'cosine_similarity': np.mean(cosine_sim_list)
    }


def display_archetype_weights(weights, title="Archetype Weights"):
    """
    Display archetype weights as a bar chart
    """
    archetype_names = ['Sine', 'Square', 'Sawtooth', 'Triangle', 'Noise']
    
    fig, ax = plt.subplots(figsize=(10, 4))
    bars = ax.bar(archetype_names, weights, color='steelblue', alpha=0.7)
    
    # Add value labels on bars
    for bar, weight in zip(bars, weights):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{weight:.3f}',
                ha='center', va='bottom', fontsize=10)
    
    ax.set_ylabel('Weight', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylim([0, max(weights) * 1.2])
    ax.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("Helper functions loaded!")

Helper functions loaded!


## 7. Pre-RLHF Baseline Evaluation on Test Set

**Important**: Establish baseline performance on the test set *before* RLHF fine-tuning,
so that the post-RLHF results have a fair comparison point.

In [19]:
print("=" * 80)
print("PRE-RLHF BASELINE EVALUATION ON TEST SET")
print("=" * 80)

baseline_results = {
    'mse': [],
    'cosine_similarity': [],
    'archetype_accuracy': []
}

model.eval()
print(f"\nEvaluating {len(test_descriptions)} test samples...\n")

with torch.no_grad():
    for i in tqdm(range(len(test_descriptions)), desc="Baseline Test"):
        result = generate_transformation(model, test_descriptions[i], test_audio[i])
        predicted = result['predicted_weights']
        target = test_archetypes[i]
        
        # MSE
        mse = np.mean((predicted - target) ** 2)
        baseline_results['mse'].append(float(mse))
        
        # Cosine similarity
        cos_sim = np.dot(predicted, target) / (
            np.linalg.norm(predicted) * np.linalg.norm(target) + 1e-10
        )
        baseline_results['cosine_similarity'].append(float(cos_sim))
        
        # Archetype accuracy
        pred_archetype = np.argmax(predicted)
        target_archetype = np.argmax(target)
        acc = 1.0 if pred_archetype == target_archetype else 0.0
        baseline_results['archetype_accuracy'].append(float(acc))

print("\n" + "=" * 80)
print("BASELINE TEST RESULTS (PRE-RLHF)")
print("=" * 80)
print(f"Mean Squared Error:      {np.mean(baseline_results['mse']):.6f}")
print(f"Cosine Similarity:       {np.mean(baseline_results['cosine_similarity']):.6f}")
print(f"Archetype Accuracy:      {np.mean(baseline_results['archetype_accuracy'])*100:.2f}%")
print("\n💾 Saved for comparison after RLHF")

PRE-RLHF BASELINE EVALUATION ON TEST SET

Evaluating 73 test samples...



Baseline Test: 100%|██████████| 73/73 [00:12<00:00,  5.91it/s]


BASELINE TEST RESULTS (PRE-RLHF)
Mean Squared Error:      0.099907
Cosine Similarity:       0.608475
Archetype Accuracy:      43.84%

💾 Saved for comparison after RLHF
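The baseline loop computes MSE, cosine similarity and top-1 archetype accuracy one sample at a time. Once predictions and targets are stacked into `(n_samples, 5)` arrays, the same metrics can be computed in a single vectorized pass. A sketch (the `archetype_metrics` name is illustrative, not part of the notebook):

```python
import numpy as np


def archetype_metrics(predicted, target, eps=1e-10):
    """Batch version of the per-sample metrics used in this notebook.

    predicted, target: arrays of shape (n_samples, n_archetypes).
    Returns mean MSE, mean cosine similarity and top-1 archetype accuracy.
    """
    predicted = np.asarray(predicted, dtype=float)
    target = np.asarray(target, dtype=float)
    mse = np.mean((predicted - target) ** 2, axis=1)                 # per-sample MSE
    cos = np.sum(predicted * target, axis=1) / (
        np.linalg.norm(predicted, axis=1) * np.linalg.norm(target, axis=1) + eps
    )
    acc = (predicted.argmax(axis=1) == target.argmax(axis=1)).astype(float)
    return {"mse": mse.mean(), "cosine_similarity": cos.mean(), "archetype_accuracy": acc.mean()}


pred = [[0.7, 0.1, 0.1, 0.05, 0.05],
        [0.2, 0.2, 0.2, 0.2, 0.2]]
target = [[1, 0, 0, 0, 0],
          [0, 1, 0, 0, 0]]
print(archetype_metrics(pred, target))
```

`argmax` breaks ties by taking the first index, so a perfectly uniform prediction (the second row above) always counts as `sine`, the first entry of `archetype_order`. The per-sample loop above behaves the same way.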





## 8. Interactive Feedback Collection (FROM TRAINING SET)

### Important Notes:
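- ✅ Samples are drawn at random **from the training set**
- ✅ The test set stays completely held out until the final evaluation
- ✅ Optional validation monitoring helps detect overfitting to the feedback
- ✅ Keeping feedback, monitoring and evaluation on separate splits avoids leaking test data into fine-tuning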

### Instructions:
For each sample, you will:
1. See the text description
2. Listen to the **original audio**
3. Listen to the **transformed audio**
4. See the predicted archetype weights
5. **Rate the transformation** from 1-5
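In the recorded run further down, pressing Enter without typing a number made `float('')` raise `ValueError`, which printed "Invalid input" twice before the cell was interrupted. Pulling the validation into a small function makes it testable and lets empty input be handled explicitly. A sketch (the `parse_rating` helper is hypothetical; the notebook's loop calls `input()` and `float()` directly):

```python
from typing import Optional


def parse_rating(text: str, low: float = 1.0, high: float = 5.0) -> Optional[float]:
    """Return the rating as a float if it is a number in [low, high], else None."""
    text = text.strip()
    if not text:                 # a bare Enter press gives an empty string
        return None
    try:
        value = float(text)
    except ValueError:           # non-numeric input such as "abc"
        return None
    if not (low <= value <= high):   # also rejects "nan" and "inf"
        return None
    return value


# Possible use inside the feedback loop:
# rating = None
# while rating is None:
#     rating = parse_rating(input("Your rating (1-5), then press Enter: "))
```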

In [25]:
# Select random samples FROM TRAINING SET for feedback
num_samples = min(RLHF_CONFIG['num_feedback_samples'], len(train_descriptions))
feedback_indices = np.random.choice(len(train_descriptions), num_samples, replace=False)

print(f"Selected {num_samples} samples from TRAINING SET for feedback")
print(f"Model will update every {RLHF_CONFIG['update_frequency']} samples")
if RLHF_CONFIG['use_val_monitoring']:
    print(f"Validation monitoring every {RLHF_CONFIG['val_check_frequency']} samples\n")

Selected 30 samples from TRAINING SET for feedback
Model will update every 5 samples
Validation monitoring every 10 samples



In [26]:
# MAIN FEEDBACK COLLECTION LOOP
print("=" * 80)
print("STARTING INTERACTIVE FEEDBACK COLLECTION (FROM TRAIN SET)")
print("=" * 80)
print("\nPlease rate each transformation on a scale of 1-5")
print("Press Enter after listening to both audio samples\n")

sample_rate = MODEL_CONFIG['sample_rate']
update_losses = []
early_stopping_counter = 0
best_val_mse = float('inf')

for idx, sample_idx in enumerate(feedback_indices):
    print("\n" + "=" * 80)
    print(f"SAMPLE {idx + 1}/{num_samples} (Train Index: {sample_idx})")
    print("=" * 80)
    
    # Get sample data FROM TRAINING SET
    description = train_descriptions[sample_idx]
    original_audio = train_audio[sample_idx]
    
    print(f"\n📝 Description:\n{description}\n")
    
    # Generate transformation
    print("⚙️  Generating transformation...")
    result = generate_transformation(model, description, original_audio)
    
    transformed_audio = result['transformed_audio']
    predicted_weights = result['predicted_weights']
    text_emb = result['text_embedding']
    audio_emb = result['audio_embedding']
    
    # Display predicted archetype weights
    print("\n🎯 Predicted Archetype Weights:")
    archetype_names = ['sine', 'square', 'sawtooth', 'triangle', 'noise']
    for name, weight in zip(archetype_names, predicted_weights):
        bar = "█" * int(weight * 30)
        print(f"  {name:10s}: {bar} {weight:.3f}")
    
    # # Visualize weights
    # display_archetype_weights(predicted_weights, 
    #                         title=f"Sample {idx+1}: Predicted Archetype Weights")
    
    # Audio playback
    print("\n" + "-" * 80)
    print("🔊 AUDIO PLAYBACK")
    print("-" * 80)
    
    print("\n▶️  ORIGINAL AUDIO:")
    display(ipd.Audio(original_audio, rate=sample_rate, autoplay=False))
    
    print("\n▶️  TRANSFORMED AUDIO (Model Output):")
    display(ipd.Audio(transformed_audio, rate=sample_rate, autoplay=False))
    
    # Get user rating
    print("\n" + "-" * 80)
    print("📊 RATING")
    print("-" * 80)
    print("How well does the transformation match the description?")
    print("  5 = Perfect match")
    print("  4 = Good match")
    print("  3 = Acceptable match")
    print("  2 = Poor match")
    print("  1 = Very poor match")
    
    while True:
        try:
            rating_input = input("\nYour rating (1-5): ")
            rating = float(rating_input)
            if 1 <= rating <= 5:
                break
            else:
                print("⚠️  Please enter a number between 1 and 5.")
        except ValueError:
            print("⚠️  Invalid input. Please enter a number between 1 and 5.")
    
    # Add feedback to RLHF trainer
    predicted_weights_tensor = torch.from_numpy(predicted_weights).float()
    rlhf_trainer.add_feedback(
        text_emb[0],
        audio_emb[0],
        predicted_weights_tensor,
        rating
    )
    
    # Store in history
    feedback_history['samples'].append(sample_idx)
    feedback_history['ratings'].append(rating)
    feedback_history['timestamps'].append(datetime.now().isoformat())
    feedback_history['descriptions'].append(description)
    feedback_history['predicted_weights'].append(predicted_weights.tolist())
    feedback_history['data_split'].append('train')
    
    print(f"\n✅ Feedback recorded: {rating}/5")
    
    # Update model periodically
    if (idx + 1) % RLHF_CONFIG['update_frequency'] == 0:
        print("\n" + "*" * 80)
        print("🔄 UPDATING MODEL WITH COLLECTED FEEDBACK")
        print("*" * 80)
        
        loss = rlhf_trainer.update_from_feedback(
            batch_size=RLHF_CONFIG['batch_size']
        )
        
        if loss is not None:
            update_losses.append(loss)
            print(f"✓ Model updated | RLHF Loss: {loss:.4f}")
            
            # Show feedback statistics
            stats = rlhf_trainer.get_feedback_statistics()
            print(f"  Samples collected: {stats['num_samples']}")
            print(f"  Mean rating: {stats['mean_rating']:.2f}")
            print(f"  Positive feedback: {stats['positive_ratio']*100:.1f}%")
        print("*" * 80 + "\n")
    
    # Validation monitoring
    if RLHF_CONFIG['use_val_monitoring'] and (idx + 1) % RLHF_CONFIG['val_check_frequency'] == 0:
        print("\n" + "~" * 80)
        print("📊 VALIDATION CHECK")
        print("~" * 80)
        
        val_metrics = evaluate_on_val_set(model, val_audio, val_descriptions, val_archetypes)
        
        val_history['iteration'].append(idx + 1)
        val_history['val_mse'].append(val_metrics['mse'])
        val_history['val_cosine_sim'].append(val_metrics['cosine_similarity'])
        
        print(f"Val MSE: {val_metrics['mse']:.6f}")
        print(f"Val Cosine Sim: {val_metrics['cosine_similarity']:.6f}")
        
        # Early stopping check
        if val_metrics['mse'] < best_val_mse:
            best_val_mse = val_metrics['mse']
            early_stopping_counter = 0
            print("✓ Val MSE improved!")
        else:
            early_stopping_counter += 1
            print(f"⚠️  Val MSE did not improve ({early_stopping_counter}/{RLHF_CONFIG['early_stopping_patience']})")
            
            if early_stopping_counter >= RLHF_CONFIG['early_stopping_patience']:
                print("\n🛑 Early stopping triggered - val performance degrading")
                print("Stopping feedback collection; consider adjusting the learning rate.")
                break
        
        print("~" * 80 + "\n")

print("\n" + "=" * 80)
print("✅ FEEDBACK COLLECTION COMPLETE")
print("=" * 80)

# Final model update
if len(rlhf_trainer.experience_buffer) > 0:
    print("\n🔄 Final model update...")
    final_loss = rlhf_trainer.update_from_feedback(
        batch_size=RLHF_CONFIG['batch_size']
    )
    if final_loss is not None:
        update_losses.append(final_loss)
        print(f"✓ Final update complete | Loss: {final_loss:.4f}")

# Show final statistics
final_stats = rlhf_trainer.get_feedback_statistics()
print("\n📊 Final Feedback Statistics:")
print(f"  Total samples: {final_stats['num_samples']} (from train set)")
print(f"  Mean rating: {final_stats['mean_rating']:.2f} ± {final_stats['std_rating']:.2f}")
print(f"  Positive feedback: {final_stats['positive_ratio']*100:.1f}%")
print(f"  Rating distribution: {final_stats['rating_distribution']}")

STARTING INTERACTIVE FEEDBACK COLLECTION (FROM TRAIN SET)

Please rate each transformation on a scale of 1-5
Press Enter after listening to both audio samples


SAMPLE 1/30 (Train Index: 143)

📝 Description:
Someone is playing a very fat drum with a slightly open hi hat along with other people playing percussion with a lot of snare hits and a bass sound playing on every beat along with the kick. A keyboard is playing string chords. A male voice is singing/shouting with delay and reverb. This song may be playing at a celebration event.

⚙️  Generating transformation...

🎯 Predicted Archetype Weights:
  sine      : ████████ 0.294
  square    : ████████ 0.290
  sawtooth  : ████ 0.162
  triangle  : ███ 0.119
  noise     : ████ 0.134

--------------------------------------------------------------------------------
🔊 AUDIO PLAYBACK
--------------------------------------------------------------------------------

▶️  ORIGINAL AUDIO:



▶️  TRANSFORMED AUDIO (Model Output):



--------------------------------------------------------------------------------
📊 RATING
--------------------------------------------------------------------------------
How well does the transformation match the description?
  5 = Perfect match
  4 = Good match
  3 = Acceptable match
  2 = Poor match
  1 = Very poor match

✅ Feedback recorded: 3.0/5

SAMPLE 2/30 (Train Index: 287)

📝 Description:
This music is a lively fiddle instrumental. The tempo is fast with a percussion accompaniment. The music is spirited, vibrant, enthusiastic,cheerful, happy , merry and sunny. This music is Country Folk.

⚙️  Generating transformation...

🎯 Predicted Archetype Weights:
  sine      : ███ 0.113
  square    : ███ 0.103
  sawtooth  : █████████ 0.330
  triangle  : ████████ 0.274
  noise     : █████ 0.180

--------------------------------------------------------------------------------
🔊 AUDIO PLAYBACK
--------------------------------------------------------------------------------

▶️  ORIGINAL AUDIO:


▶️  TRANSFORMED AUDIO (Model Output):



--------------------------------------------------------------------------------
📊 RATING
--------------------------------------------------------------------------------
How well does the transformation match the description?
  5 = Perfect match
  4 = Good match
  3 = Acceptable match
  2 = Poor match
  1 = Very poor match
⚠️  Invalid input. Please enter a number between 1 and 5.
⚠️  Invalid input. Please enter a number between 1 and 5.


KeyboardInterrupt: Interrupted by user
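The loop's early-stopping logic keeps the best validation MSE and counts consecutive checks without improvement. With the current settings (`num_feedback_samples=30`, `val_check_frequency=10`) there are only three checks, at samples 10, 20 and 30, and the first one always improves on `best_val_mse = inf`. At most two non-improving checks can therefore follow, so `early_stopping_patience=3` can never be reached. A sketch of the same rule as a reusable class (the `EarlyStopper` name and the optional `min_delta` margin are assumptions, not part of the notebook):

```python
class EarlyStopper:
    """Track a lower-is-better metric; signal after `patience` non-improving checks."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta       # required improvement margin (assumed)
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, value: float) -> bool:
        """Record one validation value; return True when training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopper(patience=2)
print([stopper.step(v) for v in [0.10, 0.09, 0.095, 0.091]])
```

Here the third and fourth values fail to beat 0.09, so the stopper fires on the fourth check. Lowering the patience to 2, or checking validation more often, would make the notebook's criterion reachable.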

## 9. Visualize Feedback and Training Progress

In [None]:
# Create comprehensive visualization
if RLHF_CONFIG['use_val_monitoring'] and len(val_history['iteration']) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    axes = axes.reshape(1, 2)

# Rating distribution
ratings_array = np.array(feedback_history['ratings'])
axes[0, 0].hist(ratings_array, bins=np.arange(0.5, 6.5, 1),
                color='steelblue', alpha=0.7, edgecolor='black')
axes[0, 0].set_xlabel('Rating', fontsize=12)
axes[0, 0].set_ylabel('Count', fontsize=12)
axes[0, 0].set_title('Distribution of User Ratings (Train Set)', fontsize=14, fontweight='bold')
axes[0, 0].set_xticks([1, 2, 3, 4, 5])
axes[0, 0].grid(axis='y', alpha=0.3)

# RLHF update losses
if len(update_losses) > 0:
    axes[0, 1].plot(range(1, len(update_losses) + 1), update_losses,
                    marker='o', linewidth=2, markersize=8, color='coral')
    axes[0, 1].set_xlabel('Update Step', fontsize=12)
    axes[0, 1].set_ylabel('RLHF Loss', fontsize=12)
    axes[0, 1].set_title('RLHF Training Loss', fontsize=14, fontweight='bold')
    axes[0, 1].grid(alpha=0.3)
else:
    axes[0, 1].text(0.5, 0.5, 'No updates performed',
                    ha='center', va='center', fontsize=14)
    axes[0, 1].set_title('RLHF Training Loss', fontsize=14, fontweight='bold')

# Validation monitoring plots (if enabled)
if RLHF_CONFIG['use_val_monitoring'] and len(val_history['iteration']) > 0:
    # Val MSE over time
    axes[1, 0].plot(val_history['iteration'], val_history['val_mse'], 
                   marker='s', linewidth=2, markersize=8, color='green')
    axes[1, 0].set_xlabel('Samples Collected', fontsize=12)
    axes[1, 0].set_ylabel('Validation MSE', fontsize=12)
    axes[1, 0].set_title('Validation MSE During RLHF', fontsize=14, fontweight='bold')
    axes[1, 0].grid(alpha=0.3)
    
    # Val Cosine Similarity over time
    axes[1, 1].plot(val_history['iteration'], val_history['val_cosine_sim'], 
                   marker='D', linewidth=2, markersize=8, color='purple')
    axes[1, 1].set_xlabel('Samples Collected', fontsize=12)
    axes[1, 1].set_ylabel('Validation Cosine Similarity', fontsize=12)
    axes[1, 1].set_title('Validation Cosine Sim During RLHF', fontsize=14, fontweight='bold')
    axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/rlhf_training_analysis.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Training analysis saved to {OUTPUT_DIR}/rlhf_training_analysis.png")

## 10. POST-RLHF Evaluation on Test Set

**Critical**: This is the fair comparison: the test set was not used for feedback, validation monitoring or any model update during RLHF.

In [None]:
print("=" * 80)
print("POST-RLHF EVALUATION ON TEST SET (HELD-OUT DATA)")
print("=" * 80)

model.eval()

# Metrics storage
postrlhf_results = {
    'description': [],
    'sample_idx': [],
    'predicted_weights': [],
    'target_weights': [],
    'mse': [],
    'cosine_similarity': [],
    'archetype_accuracy': []
}

print(f"\nEvaluating {len(test_descriptions)} test samples...\n")

with torch.no_grad():
    for i in tqdm(range(len(test_descriptions)), desc="Post-RLHF Test"):
        description = test_descriptions[i]
        audio = test_audio[i]
        target_weights = test_archetypes[i]
        
        # Generate transformation
        result = generate_transformation(model, description, audio)
        predicted_weights = result['predicted_weights']
        
        # Compute metrics
        mse = np.mean((predicted_weights - target_weights) ** 2)
        
        # Cosine similarity
        cos_sim = np.dot(predicted_weights, target_weights) / (
            np.linalg.norm(predicted_weights) * np.linalg.norm(target_weights) + 1e-10
        )
        
        # Archetype accuracy
        pred_archetype = np.argmax(predicted_weights)
        target_archetype = np.argmax(target_weights)
        archetype_acc = 1.0 if pred_archetype == target_archetype else 0.0
        
        # Store results
        postrlhf_results['description'].append(description)
        postrlhf_results['sample_idx'].append(i)
        postrlhf_results['predicted_weights'].append(predicted_weights.tolist())
        postrlhf_results['target_weights'].append(target_weights.tolist())
        postrlhf_results['mse'].append(float(mse))
        postrlhf_results['cosine_similarity'].append(float(cos_sim))
        postrlhf_results['archetype_accuracy'].append(float(archetype_acc))

# Compute aggregate metrics
print("\n" + "=" * 80)
print("POST-RLHF TEST RESULTS (HELD-OUT DATA)")
print("=" * 80)
print(f"Mean Squared Error:      {np.mean(postrlhf_results['mse']):.6f}")
print(f"Cosine Similarity:       {np.mean(postrlhf_results['cosine_similarity']):.6f}")
print(f"Archetype Accuracy:      {np.mean(postrlhf_results['archetype_accuracy'])*100:.2f}%")

# Save results
results_df = pd.DataFrame(postrlhf_results)
results_df.to_csv(f"{OUTPUT_DIR}/postrlhf_test_results.csv", index=False)
print(f"\n✓ Test results saved to {OUTPUT_DIR}/postrlhf_test_results.csv")
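The per-sample loop above can also be written in vectorized form, which puts the three metric definitions side by side. This is a sketch that assumes the stored weights are equal-length lists or arrays of shape (N, 5). `weight_metrics` is a hypothetical helper and is not part of `lstmabar_model`.

```python
import numpy as np


def weight_metrics(pred, target, eps=1e-10):
    """Per-sample MSE, cosine similarity and argmax accuracy for (N, K) weight arrays."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    # Mean squared error over the K archetype weights of each sample
    mse = np.mean((pred - target) ** 2, axis=1)
    # Cosine similarity per row; eps guards against a zero-norm vector
    cos = np.sum(pred * target, axis=1) / (
        np.linalg.norm(pred, axis=1) * np.linalg.norm(target, axis=1) + eps
    )
    # 1.0 when the dominant archetype matches (ties resolve to the first index, like np.argmax)
    acc = (np.argmax(pred, axis=1) == np.argmax(target, axis=1)).astype(float)
    return mse, cos, acc
```

Calling `weight_metrics(postrlhf_results['predicted_weights'], postrlhf_results['target_weights'])` should reproduce the per-sample lists built in the loop, up to floating-point rounding.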

## 11. Compare Pre-RLHF vs Post-RLHF Performance

In [None]:
print("=" * 80)
print("COMPARISON: PRE-RLHF vs POST-RLHF (ON HELD-OUT TEST SET)")
print("=" * 80)

comparison = {
    'Metric': ['MSE', 'Cosine Similarity', 'Archetype Accuracy (%)'],
    'Pre-RLHF': [
        np.mean(baseline_results['mse']),
        np.mean(baseline_results['cosine_similarity']),
        np.mean(baseline_results['archetype_accuracy']) * 100
    ],
    'Post-RLHF': [
        np.mean(postrlhf_results['mse']),
        np.mean(postrlhf_results['cosine_similarity']),
        np.mean(postrlhf_results['archetype_accuracy']) * 100
    ]
}

comparison['Improvement'] = [
    comparison['Pre-RLHF'][0] - comparison['Post-RLHF'][0],  # MSE (lower is better)
    comparison['Post-RLHF'][1] - comparison['Pre-RLHF'][1],  # Cosine (higher is better)
    comparison['Post-RLHF'][2] - comparison['Pre-RLHF'][2]   # Accuracy (higher is better)
]

comparison_df = pd.DataFrame(comparison)
print("\n")
print(comparison_df.to_string(index=False))
print("\n" + "=" * 80)

# Interpretation (accuracy deltas are percentage points, not relative %)
print("\n📈 Interpretation:")
mse_delta, cos_delta, acc_delta = comparison['Improvement']

if mse_delta > 0:
    print(f"  ✓ MSE improved by {mse_delta:.6f} (lower is better)")
elif mse_delta < 0:
    print(f"  ⚠️  MSE worsened by {abs(mse_delta):.6f}")
else:
    print("  • MSE unchanged")

if cos_delta > 0:
    print(f"  ✓ Cosine similarity improved by {cos_delta:.6f}")
elif cos_delta < 0:
    print(f"  ⚠️  Cosine similarity worsened by {abs(cos_delta):.6f}")
else:
    print("  • Cosine similarity unchanged")

if acc_delta > 0:
    print(f"  ✓ Accuracy improved by {acc_delta:.2f} percentage points")
elif acc_delta < 0:
    print(f"  ⚠️  Accuracy worsened by {abs(acc_delta):.2f} percentage points")
else:
    print("  • Accuracy unchanged")

# Save comparison
comparison_df.to_csv(f"{OUTPUT_DIR}/pre_post_rlhf_comparison.csv", index=False)
print(f"\n✓ Comparison saved to {OUTPUT_DIR}/pre_post_rlhf_comparison.csv")
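Both evaluations run on the same test samples, so the pre/post scores are paired. A single mean difference does not show how much it could move under a different draw of test clips. The following is a minimal paired-bootstrap sketch. It is an optional addition that this notebook does not compute, and it assumes `baseline_results` and `postrlhf_results` list the test samples in the same order. `paired_bootstrap_ci` is a hypothetical helper.

```python
import numpy as np


def paired_bootstrap_ci(pre, post, n_boot=10_000, alpha=0.05, seed=0):
    """Bootstrap confidence interval for mean(post - pre) over paired per-sample scores."""
    pre = np.asarray(pre, dtype=float)
    post = np.asarray(post, dtype=float)
    if pre.shape != post.shape:
        raise ValueError("pre and post must be aligned per test sample")
    diffs = post - pre
    rng = np.random.default_rng(seed)
    # Resample test samples with replacement, keeping each pre/post pair together
    idx = rng.integers(0, len(diffs), size=(n_boot, len(diffs)))
    boot_means = diffs[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return diffs.mean(), lo, hi
```

For example, `paired_bootstrap_ci(baseline_results['cosine_similarity'], postrlhf_results['cosine_similarity'])` returns the mean change and a 95% interval. For MSE, lower is better, so a negative `mean(post - pre)` is an improvement. If the interval straddles 0, the observed change may be within resampling noise at this test-set size.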

## 12. Visualize Test Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# MSE comparison
mse_data = [baseline_results['mse'], postrlhf_results['mse']]
bp1 = axes[0, 0].boxplot(mse_data, labels=['Pre-RLHF', 'Post-RLHF'], patch_artist=True)
for patch, color in zip(bp1['boxes'], ['lightblue', 'lightcoral']):
    patch.set_facecolor(color)
axes[0, 0].set_ylabel('MSE', fontsize=11)
axes[0, 0].set_title('MSE: Pre vs Post RLHF', fontsize=12, fontweight='bold')
axes[0, 0].grid(alpha=0.3)

# Cosine similarity comparison
cos_data = [baseline_results['cosine_similarity'], postrlhf_results['cosine_similarity']]
bp2 = axes[0, 1].boxplot(cos_data, labels=['Pre-RLHF', 'Post-RLHF'], patch_artist=True)
for patch, color in zip(bp2['boxes'], ['lightblue', 'lightcoral']):
    patch.set_facecolor(color)
axes[0, 1].set_ylabel('Cosine Similarity', fontsize=11)
axes[0, 1].set_title('Cosine Similarity: Pre vs Post RLHF', fontsize=12, fontweight='bold')
axes[0, 1].grid(alpha=0.3)

# Archetype accuracy comparison
acc_pre = np.mean(baseline_results['archetype_accuracy']) * 100
acc_post = np.mean(postrlhf_results['archetype_accuracy']) * 100
axes[1, 0].bar(['Pre-RLHF', 'Post-RLHF'], [acc_pre, acc_post], 
              color=['steelblue', 'coral'], alpha=0.7)
axes[1, 0].set_ylabel('Accuracy (%)', fontsize=11)
axes[1, 0].set_title('Archetype Accuracy: Pre vs Post RLHF', fontsize=12, fontweight='bold')
axes[1, 0].set_ylim([0, 100])
axes[1, 0].grid(axis='y', alpha=0.3)
# Add value labels
for i, v in enumerate([acc_pre, acc_post]):
    axes[1, 0].text(i, v + 2, f'{v:.1f}%', ha='center', fontweight='bold')

# Confusion matrix (post-RLHF)
predicted_archetypes = [np.argmax(w) for w in postrlhf_results['predicted_weights']]
target_archetypes = [np.argmax(w) for w in postrlhf_results['target_weights']]

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(target_archetypes, predicted_archetypes, labels=[0, 1, 2, 3, 4])

im = axes[1, 1].imshow(cm, cmap='Blues', aspect='auto')
axes[1, 1].set_xticks([0, 1, 2, 3, 4])
axes[1, 1].set_yticks([0, 1, 2, 3, 4])
axes[1, 1].set_xticklabels(['Sin', 'Sqr', 'Saw', 'Tri', 'Noi'], fontsize=9)
axes[1, 1].set_yticklabels(['Sin', 'Sqr', 'Saw', 'Tri', 'Noi'], fontsize=9)
axes[1, 1].set_xlabel('Predicted Archetype', fontsize=11)
axes[1, 1].set_ylabel('Target Archetype', fontsize=11)
axes[1, 1].set_title('Post-RLHF Confusion Matrix', fontsize=12, fontweight='bold')

# Add text annotations
for i in range(5):
    for j in range(5):
        text = axes[1, 1].text(j, i, str(cm[i, j]),
                              ha="center", va="center", color="black", fontsize=10)

plt.colorbar(im, ax=axes[1, 1])
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/test_results_comparison.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Test comparison saved to {OUTPUT_DIR}/test_results_comparison.png")
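sklearn's `confusion_matrix(y_true, y_pred, labels=...)` puts true labels on rows and predictions on columns. That is why the heatmap above labels its y-axis "Target Archetype". The sketch below is a dependency-free version of the same matrix, plus per-class recall (the diagonal divided by each row's total), which shows which archetypes RLHF helps or hurts. `confusion_and_recall` and `ARCHETYPES` are hypothetical names, and the class order is taken from the tick labels above (Sin, Sqr, Saw, Tri, Noi).

```python
import numpy as np

# Order assumed from the heatmap tick labels: Sin, Sqr, Saw, Tri, Noi
ARCHETYPES = ["sine", "square", "sawtooth", "triangle", "noise"]


def confusion_and_recall(target_idx, pred_idx, n_classes=5):
    """Rows = target archetype, columns = predicted archetype (same convention as sklearn)."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # Increment cm[target, pred] once per sample (np.add.at handles repeated index pairs)
    np.add.at(cm, (np.asarray(target_idx), np.asarray(pred_idx)), 1)
    support = cm.sum(axis=1)
    # Per-class recall; classes with no test samples get NaN instead of a 0/0 warning
    recall = np.divide(
        np.diag(cm), support, out=np.full(n_classes, np.nan), where=support > 0
    )
    return cm, recall
```

With the lists built above, `cm, recall = confusion_and_recall(target_archetypes, predicted_archetypes)` should match the plotted `cm`. Pairing `recall` with `ARCHETYPES` then gives a per-class summary.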

## 13. Save RLHF-Tuned Model and Results

In [None]:
print("=" * 80)
print("SAVING RLHF-TUNED MODEL AND RESULTS")
print("=" * 80)

# Save model checkpoint
rlhf_model_path = f"{OUTPUT_DIR}/rlhf_tuned_model.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'rlhf_config': RLHF_CONFIG,
    'feedback_stats': final_stats,
    'baseline_metrics': {
        'mse': float(np.mean(baseline_results['mse'])),
        'cosine_similarity': float(np.mean(baseline_results['cosine_similarity'])),
        'archetype_accuracy': float(np.mean(baseline_results['archetype_accuracy'])),
    },
    'postrlhf_metrics': {
        'mse': float(np.mean(postrlhf_results['mse'])),
        'cosine_similarity': float(np.mean(postrlhf_results['cosine_similarity'])),
        'archetype_accuracy': float(np.mean(postrlhf_results['archetype_accuracy'])),
    },
    'improvement': {
        'mse': float(comparison['Improvement'][0]),
        'cosine_similarity': float(comparison['Improvement'][1]),
        'archetype_accuracy': float(comparison['Improvement'][2])
    },
    'timestamp': datetime.now().isoformat()
}, rlhf_model_path)

print(f"✓ Model saved to {rlhf_model_path}")

# Save feedback history
feedback_path = f"{OUTPUT_DIR}/feedback_history.json"
with open(feedback_path, 'w') as f:
    json.dump(feedback_history, f, indent=2)

print(f"✓ Feedback history saved to {feedback_path}")

# Save validation history
if RLHF_CONFIG['use_val_monitoring'] and len(val_history['iteration']) > 0:
    val_path = f"{OUTPUT_DIR}/validation_history.json"
    with open(val_path, 'w') as f:
        json.dump(val_history, f, indent=2)
    print(f"✓ Validation history saved to {val_path}")

# Save comprehensive report
report = {
    'experiment': 'RLHF Fine-tuning v2 (Proper Train/Val/Test Split)',
    'timestamp': datetime.now().isoformat(),
    'data_methodology': {
        'feedback_source': 'Training set',
        'monitoring_source': 'Validation set' if RLHF_CONFIG['use_val_monitoring'] else None,
        'evaluation_source': 'Test set (held-out)',
        'note': 'Test set never seen during RLHF - proper ML methodology'
    },
    'config': RLHF_CONFIG,
    'model_config': MODEL_CONFIG,
    'feedback_statistics': final_stats,
    'baseline_test_metrics': {
        'mean_mse': float(np.mean(baseline_results['mse'])),
        'std_mse': float(np.std(baseline_results['mse'])),
        'mean_cosine_similarity': float(np.mean(baseline_results['cosine_similarity'])),
        'std_cosine_similarity': float(np.std(baseline_results['cosine_similarity'])),
        'archetype_accuracy': float(np.mean(baseline_results['archetype_accuracy'])),
    },
    'postrlhf_test_metrics': {
        'mean_mse': float(np.mean(postrlhf_results['mse'])),
        'std_mse': float(np.std(postrlhf_results['mse'])),
        'mean_cosine_similarity': float(np.mean(postrlhf_results['cosine_similarity'])),
        'std_cosine_similarity': float(np.std(postrlhf_results['cosine_similarity'])),
        'archetype_accuracy': float(np.mean(postrlhf_results['archetype_accuracy'])),
    },
    'improvement': {
        'mse': float(comparison['Improvement'][0]),
        'cosine_similarity': float(comparison['Improvement'][1]),
        'archetype_accuracy_pct': float(comparison['Improvement'][2])
    },
    'rlhf_training': {
        'num_updates': len(update_losses),
        'final_loss': float(update_losses[-1]) if len(update_losses) > 0 else None,
        'mean_loss': float(np.mean(update_losses)) if len(update_losses) > 0 else None
    },
    'validation_monitoring': val_history if RLHF_CONFIG['use_val_monitoring'] else None
}

report_path = f"{OUTPUT_DIR}/rlhf_report_v2.json"
with open(report_path, 'w') as f:
    json.dump(report, f, indent=2)

print(f"✓ Comprehensive report saved to {report_path}")

print("\n✅ All files saved successfully!")
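`json.dump` only accepts built-in Python types. `np.float64` happens to subclass `float` and serializes fine, but NumPy integers, `np.float32` values, `np.bool_` and arrays raise `TypeError`. This section doesn't show whether `feedback_history` or `val_history` hold such values. If they do, a small recursive converter such as the hypothetical `to_jsonable` below is one way to sanitize them before dumping.

```python
import json

import numpy as np


def to_jsonable(obj):
    """Recursively convert NumPy scalars/arrays and tuples into JSON-native Python types."""
    if isinstance(obj, dict):
        # JSON object keys must be strings; this also covers NumPy integer keys
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return to_jsonable(obj.tolist())
    if isinstance(obj, np.generic):
        # np.int64, np.float32, np.bool_, ... -> int, float, bool
        return obj.item()
    return obj


print(json.dumps(to_jsonable({"a": np.float32(0.5), "b": np.arange(3), "c": (np.int64(1), 2)})))
```

Usage in the cell above would be, for example, `json.dump(to_jsonable(val_history), f, indent=2)`.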

## 14. Summary

In [None]:
print("\n" + "=" * 80)
print("RLHF PIPELINE COMPLETE (v2 - Proper Methodology)")
print("=" * 80)

print("\n📊 Experiment Summary:")
print(f"  • RLHF feedback collected from: TRAINING SET ({len(feedback_history['ratings'])} samples)")
if RLHF_CONFIG['use_val_monitoring']:
    print("  • Validation monitoring: ENABLED")
print(f"  • Final evaluation on: TEST SET (held-out)")
print(f"  • Mean user rating: {np.mean(feedback_history['ratings']):.2f}/5.0")
print(f"  • Model updates: {len(update_losses)}")

print("\n🎯 Performance on Test Set (Held-Out):")
print("\nPRE-RLHF (Baseline):")
print(f"  • MSE: {np.mean(baseline_results['mse']):.6f}")
print(f"  • Cosine Similarity: {np.mean(baseline_results['cosine_similarity']):.6f}")
print(f"  • Archetype Accuracy: {np.mean(baseline_results['archetype_accuracy'])*100:.2f}%")

print("\nPOST-RLHF:")
print(f"  • MSE: {np.mean(postrlhf_results['mse']):.6f}")
print(f"  • Cosine Similarity: {np.mean(postrlhf_results['cosine_similarity']):.6f}")
print(f"  • Archetype Accuracy: {np.mean(postrlhf_results['archetype_accuracy'])*100:.2f}%")

print("\n📈 Improvement:")
print(f"  • MSE: {comparison['Improvement'][0]:.6f} {'✓' if comparison['Improvement'][0] > 0 else '⚠️'}")
print(f"  • Cosine Sim: {comparison['Improvement'][1]:.6f} {'✓' if comparison['Improvement'][1] > 0 else '⚠️'}")
print(f"  • Accuracy: {comparison['Improvement'][2]:.2f} percentage points {'✓' if comparison['Improvement'][2] > 0 else '⚠️'}")

print("\n💾 Saved Files:")
print(f"  • RLHF-tuned model: {OUTPUT_DIR}/rlhf_tuned_model.pth")
print(f"  • Test results: {OUTPUT_DIR}/postrlhf_test_results.csv")
print(f"  • Comparison: {OUTPUT_DIR}/pre_post_rlhf_comparison.csv")
print(f"  • Feedback history: {OUTPUT_DIR}/feedback_history.json")
if RLHF_CONFIG['use_val_monitoring']:
    print(f"  • Validation history: {OUTPUT_DIR}/validation_history.json")
print(f"  • Comprehensive report: {OUTPUT_DIR}/rlhf_report_v2.json")
print(f"  • Visualizations: {OUTPUT_DIR}/*.png")

print("\n✅ Key Methodology Points:")
print("  ✓ RLHF feedback collected from TRAINING SET")
if RLHF_CONFIG['use_val_monitoring']:
    print("  ✓ Validation monitoring for overfitting detection")
print("  ✓ Final evaluation on HELD-OUT TEST SET")
print("  ✓ No data leakage - proper ML methodology")
print("  ✓ Test metrics reflect performance on data unseen during RLHF")

print("\n" + "=" * 80)
print("Thank you for your feedback! 🙏")
print("=" * 80)