In [1]:
"""
Archetype Prediction Head - Mixture Weight Estimator
Predicts archetype mixture vectors from joint text-audio embeddings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Optional, Tuple, List

In [2]:
class ArchetypePredictionHead(nn.Module):
    """
    MLP-based regression head for predicting archetype mixture weights

    Takes concatenated [text_embedding; audio_embedding] and outputs
    a 5-dimensional probability distribution over archetypes
    """

    def __init__(
        self,
        embedding_dim=768,
        num_archetypes=5,
        hidden_dims=[512, 256],
        dropout=0.3,
        device='cpu'
    ):
        super(ArchetypePredictionHead, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_archetypes = num_archetypes
        self.device = device

        # Build MLP layers
        layers = []
        input_dim = embedding_dim * 2  # Concatenated text + audio

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.BatchNorm1d(hidden_dim)
            ])
            input_dim = hidden_dim

        # Final output layer
        layers.append(nn.Linear(input_dim, num_archetypes))

        self.mlp = nn.Sequential(*layers)

        # Archetype names for reference
        self.archetype_names = ['sine', 'square', 'sawtooth', 'triangle', 'noise']

        self.to(device)

    def forward(
        self,
        text_embeddings: torch.Tensor,
        audio_embeddings: torch.Tensor
    ) -> torch.Tensor:
        """
        Predict archetype mixture weights

        Args:
            text_embeddings: (batch_size, embedding_dim)
            audio_embeddings: (batch_size, embedding_dim)

        Returns:
            Archetype weights (batch_size, num_archetypes) - normalized with softmax
        """
        # Concatenate embeddings
        joint_embedding = torch.cat([text_embeddings, audio_embeddings], dim=1)

        # Pass through MLP
        logits = self.mlp(joint_embedding)

        # Apply softmax to get probability distribution
        weights = F.softmax(logits, dim=1)

        return weights

    def predict_with_names(
        self,
        text_embeddings: torch.Tensor,
        audio_embeddings: torch.Tensor
    ) -> Dict[str, np.ndarray]:
        """
        Predict and return archetype weights as named dictionary

        Returns:
            Dict mapping archetype names to weight arrays
        """
        weights = self.forward(text_embeddings, audio_embeddings)
        weights_np = weights.detach().cpu().numpy()

        return {
            name: weights_np[:, i]
            for i, name in enumerate(self.archetype_names)
        }

In [3]:
class ArchetypeLoss(nn.Module):
    """
    Loss functions for training archetype predictor
    """

    def __init__(self, loss_type='mse'):
        super(ArchetypeLoss, self).__init__()
        self.loss_type = loss_type

    def forward(
        self,
        predicted_weights: torch.Tensor,
        target_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute loss between predicted and target archetype weights

        Args:
            predicted_weights: (batch_size, num_archetypes)
            target_weights: (batch_size, num_archetypes)

        Returns:
            Loss scalar
        """
        if self.loss_type == 'mse':
            # Mean squared error
            loss = F.mse_loss(predicted_weights, target_weights)

        elif self.loss_type == 'kl':
            # KL divergence (treating as probability distributions)
            loss = F.kl_div(
                torch.log(predicted_weights + 1e-10),
                target_weights,
                reduction='batchmean'
            )

        elif self.loss_type == 'cosine':
            # Cosine distance
            loss = 1 - F.cosine_similarity(predicted_weights, target_weights).mean()

        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        return loss

In [4]:
class RLHFTrainer:
    """
    Reinforcement Learning from Human Feedback trainer
    Fine-tunes archetype predictor based on user ratings
    """

    def __init__(
        self,
        predictor: ArchetypePredictionHead,
        learning_rate=1e-4,
        reward_threshold=3.0  # Minimum Likert scale rating to be positive
    ):
        self.predictor = predictor
        self.reward_threshold = reward_threshold
        self.optimizer = torch.optim.Adam(predictor.parameters(), lr=learning_rate)

        # Experience buffer for RLHF
        self.experience_buffer = []

    def add_feedback(
        self,
        text_embedding: torch.Tensor,
        audio_embedding: torch.Tensor,
        predicted_weights: torch.Tensor,
        rating: float  # Likert scale 1-5
    ):
        """
        Add human feedback to experience buffer

        Args:
            text_embedding: Text embedding that generated prediction
            audio_embedding: Audio embedding that generated prediction
            predicted_weights: Archetype weights that were predicted
            rating: User rating (1-5 scale)
        """
        # Convert rating to reward (-1 for bad, +1 for good)
        reward = 1.0 if rating >= self.reward_threshold else -1.0

        self.experience_buffer.append({
            'text_embedding': text_embedding.detach(),
            'audio_embedding': audio_embedding.detach(),
            'predicted_weights': predicted_weights.detach(),
            'reward': reward,
            'rating': rating
        })

    def collect_feedback_with_audio(
        self,
        description: str,
        original_audio: np.ndarray,
        transformed_audio: np.ndarray,
        predicted_weights: np.ndarray,
        text_embedding: torch.Tensor,
        audio_embedding: torch.Tensor,
        sample_rate: int = 44100,
        auto_play: bool = True
    ) -> float:
        """
        Interactive feedback collection with audio playback

        Args:
            description: Text description used
            original_audio: Original input audio
            transformed_audio: Model's transformed audio
            predicted_weights: Predicted archetype weights
            text_embedding: Text embedding used
            audio_embedding: Audio embedding used
            sample_rate: Audio sample rate
            auto_play: Whether to auto-play audio (for Jupyter)

        Returns:
            User rating (1-5)
        """
        import IPython.display as ipd
        from IPython.display import display, HTML

        print("\n" + "="*60)
        print("HUMAN FEEDBACK COLLECTION")
        print("="*60)
        print(f"\nDescription: '{description}'")
        print(f"\nPredicted Archetype Weights:")
        archetype_names = ['sine', 'square', 'sawtooth', 'triangle', 'noise']
        for name, weight in zip(archetype_names, predicted_weights):
            bar = "█" * int(weight * 20)
            print(f"  {name:10s}: {bar} {weight:.3f}")

        print("\n" + "-"*60)
        print("AUDIO PLAYBACK")
        print("-"*60)

        # Display original audio
        print("\n▶️  ORIGINAL AUDIO:")
        display(ipd.Audio(original_audio, rate=sample_rate, autoplay=False))

        # Display transformed audio
        print("\n▶️  TRANSFORMED AUDIO:")
        display(ipd.Audio(transformed_audio, rate=sample_rate, autoplay=False))

        # Get user rating
        print("\n" + "-"*60)
        print("RATING INSTRUCTIONS")
        print("-"*60)
        print("Rate how well the transformation matches the description:")
        print("  5 = Perfect match")
        print("  4 = Good match")
        print("  3 = Acceptable match")
        print("  2 = Poor match")
        print("  1 = Very poor match")

        while True:
            try:
                rating_input = input("\nYour rating (1-5): ")
                rating = float(rating_input)
                if 1 <= rating <= 5:
                    break
                else:
                    print("Please enter a number between 1 and 5.")
            except ValueError:
                print("Invalid input. Please enter a number between 1 and 5.")

        # Add feedback to buffer
        predicted_weights_tensor = torch.from_numpy(predicted_weights).float()
        self.add_feedback(
            text_embedding,
            audio_embedding,
            predicted_weights_tensor,
            rating
        )

        print(f"\n✓ Feedback recorded: {rating}/5")
        print("="*60 + "\n")

        return rating

    def batch_collect_feedback(
        self,
        samples: List[Dict],
        sample_rate: int = 44100,
        max_samples: Optional[int] = None
    ) -> List[float]:
        """
        Collect feedback for multiple samples

        Args:
            samples: List of dicts with 'description', 'original_audio',
                    'transformed_audio', 'predicted_weights', 'text_emb', 'audio_emb'
            sample_rate: Audio sample rate
            max_samples: Maximum number of samples to collect (None = all)

        Returns:
            List of ratings
        """
        ratings = []
        n_samples = min(len(samples), max_samples) if max_samples else len(samples)

        print(f"\n{'='*60}")
        print(f"BATCH FEEDBACK COLLECTION: {n_samples} samples")
        print(f"{'='*60}\n")

        for i, sample in enumerate(samples[:n_samples]):
            print(f"\n>>> Sample {i+1}/{n_samples}")

            rating = self.collect_feedback_with_audio(
                description=sample['description'],
                original_audio=sample['original_audio'],
                transformed_audio=sample['transformed_audio'],
                predicted_weights=sample['predicted_weights'],
                text_embedding=sample['text_embedding'],
                audio_embedding=sample['audio_embedding'],
                sample_rate=sample_rate
            )

            ratings.append(rating)

            # Option to stop early
            if i < n_samples - 1:
                continue_input = input("Continue to next sample? (y/n): ").lower()
                if continue_input != 'y':
                    print("Feedback collection stopped.")
                    break

        return ratings

    def save_audio_comparison(
        self,
        original_audio: np.ndarray,
        transformed_audio: np.ndarray,
        description: str,
        predicted_weights: np.ndarray,
        output_path: str,
        sample_rate: int = 44100
    ):
        """
        Save audio files for offline feedback collection

        Args:
            original_audio: Original audio
            transformed_audio: Transformed audio
            description: Text description
            predicted_weights: Predicted weights
            output_path: Base path for saving (without extension)
            sample_rate: Sample rate
        """
        import soundfile as sf
        import json

        # Save audio files
        sf.write(f"{output_path}_original.wav", original_audio, sample_rate)
        sf.write(f"{output_path}_transformed.wav", transformed_audio, sample_rate)

        # Save metadata
        metadata = {
            'description': description,
            'predicted_weights': {
                'sine': float(predicted_weights[0]),
                'square': float(predicted_weights[1]),
                'sawtooth': float(predicted_weights[2]),
                'triangle': float(predicted_weights[3]),
                'noise': float(predicted_weights[4])
            }
        }

        with open(f"{output_path}_metadata.json", 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Audio comparison saved to {output_path}_*.wav")

    def update_from_feedback(self, batch_size=16):
        """
        Update predictor using collected feedback
        Uses policy gradient-style update
        """
        if len(self.experience_buffer) < batch_size:
            return None

        # Sample batch from buffer
        indices = np.random.choice(
            len(self.experience_buffer),
            size=min(batch_size, len(self.experience_buffer)),
            replace=False
        )

        batch_loss = 0
        for idx in indices:
            experience = self.experience_buffer[idx]

            # Re-predict with current model
            text_emb = experience['text_embedding'].unsqueeze(0)
            audio_emb = experience['audio_embedding'].unsqueeze(0)
            new_weights = self.predictor(text_emb, audio_emb)

            # Compute log probability of action (predicted weights)
            log_prob = torch.log(new_weights + 1e-10).sum()

            # Policy gradient: log_prob * reward
            reward = experience['reward']
            loss = -log_prob * reward

            batch_loss += loss

        # Average loss
        batch_loss = batch_loss / len(indices)

        # Update
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

        return batch_loss.item()

    def clear_buffer(self):
        """Clear experience buffer"""
        self.experience_buffer = []

    def get_feedback_statistics(self) -> Dict:
        """Get statistics about collected feedback"""
        if not self.experience_buffer:
            return {}

        ratings = [exp['rating'] for exp in self.experience_buffer]
        rewards = [exp['reward'] for exp in self.experience_buffer]

        return {
            'num_samples': len(self.experience_buffer),
            'mean_rating': np.mean(ratings),
            'std_rating': np.std(ratings),
            'positive_ratio': (np.array(rewards) > 0).mean(),
            'rating_distribution': np.bincount(np.array(ratings).astype(int), minlength=6)[1:]
        }

In [5]:
class ArchetypeEnsemble(nn.Module):
    """
    Ensemble of multiple archetype predictors for robust predictions
    """

    def __init__(
        self,
        embedding_dim=768,
        num_archetypes=5,
        num_predictors=3,
        device='cpu'
    ):
        super(ArchetypeEnsemble, self).__init__()

        self.num_predictors = num_predictors

        # Create ensemble of predictors
        self.predictors = nn.ModuleList([
            ArchetypePredictionHead(
                embedding_dim=embedding_dim,
                num_archetypes=num_archetypes,
                device=device
            )
            for _ in range(num_predictors)
        ])

        self.to(device)

    def forward(
        self,
        text_embeddings: torch.Tensor,
        audio_embeddings: torch.Tensor,
        return_individual=False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Ensemble prediction

        Args:
            text_embeddings: Text embeddings
            audio_embeddings: Audio embeddings
            return_individual: Whether to return individual predictions

        Returns:
            - Average prediction across ensemble
            - Optional individual predictions
        """
        predictions = []

        for predictor in self.predictors:
            pred = predictor(text_embeddings, audio_embeddings)
            predictions.append(pred)

        # Stack predictions
        predictions = torch.stack(predictions)  # (num_predictors, batch, archetypes)

        # Average ensemble prediction
        ensemble_pred = predictions.mean(dim=0)

        if return_individual:
            return ensemble_pred, predictions
        else:
            return ensemble_pred, None

    def uncertainty_estimation(
        self,
        text_embeddings: torch.Tensor,
        audio_embeddings: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Estimate prediction uncertainty via ensemble variance

        Returns:
            - Mean prediction
            - Standard deviation across ensemble
        """
        ensemble_pred, individual_preds = self.forward(
            text_embeddings,
            audio_embeddings,
            return_individual=True
        )

        # Compute std across ensemble
        uncertainty = individual_preds.std(dim=0)

        return ensemble_pred, uncertainty

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize archetype predictor
predictor = ArchetypePredictionHead(
    embedding_dim=768,
    num_archetypes=5,
    device=device
)

In [7]:
# Create dummy embeddings
batch_size = 8
text_emb = torch.randn(batch_size, 768).to(device)
audio_emb = torch.randn(batch_size, 768).to(device)

# Predict archetype weights
print("=== Testing Archetype Prediction ===")
weights = predictor(text_emb, audio_emb)
print(f"Predicted weights shape: {weights.shape}")
print(f"Sample prediction: {weights[0]}")
print(f"Sum of weights: {weights[0].sum().item():.6f} (should be ~1.0)")

# Test with names
named_weights = predictor.predict_with_names(text_emb, audio_emb)
print(f"\nNamed predictions for sample 0:")
for name, weight in named_weights.items():
    print(f"  {name}: {weight[0]:.4f}")

# Test loss
print("\n=== Testing Archetype Loss ===")
target_weights = torch.rand(batch_size, 5).to(device)
target_weights = target_weights / target_weights.sum(dim=1, keepdim=True)

loss_fn = ArchetypeLoss(loss_type='mse')
loss = loss_fn(weights, target_weights)
print(f"MSE loss: {loss.item():.4f}")

=== Testing Archetype Prediction ===
Predicted weights shape: torch.Size([8, 5])
Sample prediction: tensor([0.1609, 0.1870, 0.2255, 0.1701, 0.2566], grad_fn=<SelectBackward0>)
Sum of weights: 1.000000 (should be ~1.0)

Named predictions for sample 0:
  sine: 0.2511
  square: 0.4593
  sawtooth: 0.0599
  triangle: 0.1269
  noise: 0.1028

=== Testing Archetype Loss ===
MSE loss: 0.0278


In [8]:
# Test RLHF trainer with interactive audio playback
print("\n=== Testing RLHF Trainer with Audio Playback ===")
rlhf_trainer = RLHFTrainer(predictor, learning_rate=1e-4)

# Example: Interactive feedback collection
# Uncomment to test interactively in Jupyter:

# Generate some example audio
import numpy as np
sample_rate = 44100
duration = 2.0
t = np.linspace(0, duration, int(sample_rate * duration))

# Original audio (sine wave)
original = np.sin(2 * np.pi * 440 * t) * 0.5

# Transformed audio (sawtooth-ish)
transformed = 2 * (t * 440 - np.floor(0.5 + t * 440)) * 0.5

# Collect feedback
rating = rlhf_trainer.collect_feedback_with_audio(
    description="bright and cutting guitar tone",
    original_audio=original,
    transformed_audio=transformed,
    predicted_weights=weights[0].detach().cpu().numpy(),
    text_embedding=text_emb[0],
    audio_embedding=audio_emb[0],
    sample_rate=sample_rate
)


=== Testing RLHF Trainer with Audio Playback ===

HUMAN FEEDBACK COLLECTION

Description: 'bright and cutting guitar tone'

Predicted Archetype Weights:
  sine      : ███ 0.161
  square    : ███ 0.187
  sawtooth  : ████ 0.225
  triangle  : ███ 0.170
  noise     : █████ 0.257

------------------------------------------------------------
AUDIO PLAYBACK
------------------------------------------------------------

▶️  ORIGINAL AUDIO:



▶️  TRANSFORMED AUDIO:



------------------------------------------------------------
RATING INSTRUCTIONS
------------------------------------------------------------
Rate how well the transformation matches the description:
  5 = Perfect match
  4 = Good match
  3 = Acceptable match
  2 = Poor match
  1 = Very poor match
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5.
Invalid input. Please enter a number between 1 and 5

In [None]:
print("\n=== Interactive RLHF Usage Example ===")
print("""
# In Jupyter notebook, use this pattern:

# 1. Generate or load audio samples
original_audio = librosa.load('input.wav')[0]
transformed_audio = model.transform(original_audio, description)

# 2. Get embeddings and predictions
text_emb = text_encoder([description])
audio_emb = audio_encoder(torch.from_numpy(original_audio))
predicted_weights = predictor(text_emb, audio_emb)

# 3. Collect interactive feedback with audio playback
rating = rlhf_trainer.collect_feedback_with_audio(
description="bright and crunchy",
original_audio=original_audio,
transformed_audio=transformed_audio,
predicted_weights=predicted_weights[0].cpu().numpy(),
text_embedding=text_emb[0],
audio_embedding=audio_emb[0],
sample_rate=44100
)

# 4. Update model from feedback
loss = rlhf_trainer.update_from_feedback(batch_size=16)
print(f"Updated model with loss: {loss}")

# 5. View feedback statistics
stats = rlhf_trainer.get_feedback_statistics()
print(f"Collected {stats['num_samples']} ratings, mean: {stats['mean_rating']:.2f}")
""")