In [18]:
"""
Training Pipeline for LSTMABAR
Includes data loading, training loop, and evaluation
"""

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import librosa
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt

from lstmabar_model import LSTMABAR, LSTMABARTrainer
from musiccaps_loader import MusicCapsLoader

In [2]:
class MusicCapsDataset(Dataset):
    """
    PyTorch Dataset for MusicCaps with archetype annotations.

    Each item pairs a fixed-length waveform with its text description and
    archetype weight vector. Rows whose audio file is not present on disk are
    skipped: ``valid_indices`` maps dataset positions to rows of the arrays
    loaded from the .npz file.
    """
    
    def __init__(
        self,
        training_data_path: str,
        sample_rate: int = 44100,
        audio_duration: float = 2.0,
        augment: bool = False
    ):
        """
        Args:
            training_data_path: Path to .npz file with training data. The file
                must contain the arrays 'archetype_vectors', 'descriptions',
                'audio_paths' and 'archetype_order'.
            sample_rate: Audio sample rate
            audio_duration: Duration (seconds) to load from each audio file
            augment: Whether to apply data augmentation
        """
        self.sample_rate = sample_rate
        self.audio_duration = audio_duration
        self.augment = augment
        # Every returned waveform is padded/trimmed to exactly this many samples
        self.target_samples = int(sample_rate * audio_duration)
        
        # Load training data
        print(f"Loading training data from {training_data_path}")
        # NOTE(security): allow_pickle=True lets a crafted .npz execute arbitrary
        # code while loading; only open files produced by this project.
        # The context manager closes the underlying archive once the arrays
        # have been copied out (NpzFile otherwise keeps the file handle open).
        with np.load(training_data_path, allow_pickle=True) as data:
            self.archetype_vectors = torch.from_numpy(data['archetype_vectors']).float()
            self.descriptions = data['descriptions'].tolist()
            self.audio_paths = data['audio_paths'].tolist()
            self.archetype_order = data['archetype_order'].tolist()
        
        print(f"Loaded {len(self.descriptions)} training examples")
        
        # Filter out samples with missing audio files
        self.valid_indices = self._find_valid_samples()
        print(f"Found {len(self.valid_indices)} samples with valid audio files")
    
    def _find_valid_samples(self) -> List[int]:
        """
        Find indices whose audio path points at an existing regular file.

        Empty or None paths are rejected explicitly: ``Path("")`` resolves to
        the current directory, which exists but is not an audio file.
        """
        return [
            i
            for i, audio_path in enumerate(self.audio_paths)
            if audio_path and Path(audio_path).is_file()
        ]
    
    def __len__(self) -> int:
        """Number of samples that have an audio file on disk."""
        return len(self.valid_indices)
    
    def __getitem__(self, idx: int) -> Dict:
        """
        Get a single training sample
        
        Args:
            idx: Position within the filtered (valid) samples

        Returns:
            Dict with 'audio' (float tensor of ``target_samples`` samples),
            'description' and 'archetype_weights'
        """
        # Map to valid index
        actual_idx = self.valid_indices[idx]
        
        # Load audio (resampled to self.sample_rate, at most audio_duration long)
        audio_path = self.audio_paths[actual_idx]
        audio, _ = librosa.load(
            audio_path,
            sr=self.sample_rate,
            duration=self.audio_duration
        )
        
        # Pad (zeros at the end) or trim to exact length
        if len(audio) < self.target_samples:
            audio = np.pad(audio, (0, self.target_samples - len(audio)))
        else:
            audio = audio[:self.target_samples]
        
        # Apply augmentation if enabled
        if self.augment:
            audio = self._augment_audio(audio)
        
        # Convert to tensor (.float() also undoes any float64 promotion from augmentation)
        audio_tensor = torch.from_numpy(audio).float()
        
        # Get description and archetype weights
        description = self.descriptions[actual_idx]
        archetype_weights = self.archetype_vectors[actual_idx]
        
        return {
            'audio': audio_tensor,
            'description': description,
            'archetype_weights': archetype_weights
        }
    
    def _augment_audio(self, audio: np.ndarray) -> np.ndarray:
        """
        Apply random audio augmentations.

        Each of the three augmentations is applied independently with
        probability 0.5, drawing from NumPy's global random state. The output
        has the same length as the input.
        """
        # Random gain (±3dB)
        if np.random.random() > 0.5:
            gain_db = np.random.uniform(-3, 3)
            audio = audio * (10 ** (gain_db / 20))
        
        # Random time shift of up to ~1/10 s; np.roll wraps samples around
        if np.random.random() > 0.5:
            shift = np.random.randint(-self.sample_rate // 10, self.sample_rate // 10)
            audio = np.roll(audio, shift)
        
        # Add slight Gaussian noise (standard deviation 0.005)
        if np.random.random() > 0.5:
            noise = np.random.randn(len(audio)) * 0.005
            audio = audio + noise
        
        return audio


def collate_fn(batch: List[Dict]) -> Dict:
    """
    Merge per-sample dicts into a single batch dict for the DataLoader.

    Audio tensors and archetype weights are stacked along a new leading
    dimension; descriptions vary in length, so they stay a plain Python list.
    """
    waveforms, texts, weights = [], [], []
    for sample in batch:
        waveforms.append(sample['audio'])
        texts.append(sample['description'])
        weights.append(sample['archetype_weights'])

    return dict(
        audio=torch.stack(waveforms),
        descriptions=texts,
        archetype_weights=torch.stack(weights),
    )

In [32]:
class TrainingPipeline:
    """
    Complete training pipeline for LSTMABAR.

    Wires a model, the train/validation datasets and an LSTMABARTrainer
    together, runs the epoch loop, writes checkpoints into ``checkpoint_dir``
    and saves a plot of the loss history when training finishes.
    """
    
    def __init__(
        self,
        model: LSTMABAR,
        train_dataset: MusicCapsDataset,
        val_dataset: Optional[MusicCapsDataset] = None,
        batch_size: int = 16,
        learning_rate: float = 1e-4,
        num_epochs: int = 50,
        checkpoint_dir: str = 'checkpoints',
        log_interval: int = 10
    ):
        """
        Args:
            model: Model to train; must expose ``device`` and ``save_checkpoint``
            train_dataset: Dataset used for training
            val_dataset: Optional dataset evaluated after every epoch
            batch_size: Batch size for both loaders
            learning_rate: Learning rate handed to LSTMABARTrainer
            num_epochs: Number of epochs to run
            checkpoint_dir: Directory for checkpoints and the history plot
            log_interval: Stored on the instance; not used by the code in this class
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.checkpoint_dir = Path(checkpoint_dir)
        self.log_interval = log_interval
        
        # Create checkpoint directory (parents=True so nested paths work on a fresh run)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        # Create data loaders
        self.train_loader = self._make_loader(train_dataset, shuffle=True)
        
        if val_dataset is not None:
            self.val_loader = self._make_loader(val_dataset, shuffle=False)
        else:
            self.val_loader = None
        
        # Initialize trainer
        self.trainer = LSTMABARTrainer(
            model,
            learning_rate=learning_rate,
            loss_weights={
                'contrastive': 0.7,
                'archetype_prediction': 0.2,
                'audio_archetype_supervision': 0.1
            }
        )
        
        print(f"Training pipeline initialized:")
        print(f"  Training samples: {len(train_dataset)}")
        print(f"  Validation samples: {len(val_dataset) if val_dataset is not None else 0}")
        print(f"  Batch size: {batch_size}")
        print(f"  Total epochs: {num_epochs}")
        print(f"  Steps per epoch: {len(self.train_loader)}")
    
    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the settings shared by train and validation."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=0,
            # Pinned host memory only speeds up host-to-GPU copies; requesting
            # it on a CPU-only machine just produces a warning.
            pin_memory=torch.cuda.is_available()
        )
    
    def _run_validation(self) -> Optional[Dict]:
        """
        Average the trainer's validation losses over the validation loader.

        Returns:
            Dict of per-key mean losses (always contains 'total'), or None when
            the loader yields no batches so there is nothing to average.
        """
        num_batches = len(self.val_loader)
        if num_batches == 0:
            return None
        
        val_losses = {'total': 0.0}
        for batch in self.val_loader:
            batch_losses = self.trainer.validate(
                batch['descriptions'],
                batch['audio'].to(self.model.device),
                batch['archetype_weights'].to(self.model.device)
            )
            for key in batch_losses:
                val_losses[key] = val_losses.get(key, 0.0) + batch_losses[key]
        
        # Average validation losses
        return {key: value / num_batches for key, value in val_losses.items()}
    
    def _save_checkpoint(self, filename: str, epoch: int) -> Path:
        """Save model and optimizer state under ``checkpoint_dir`` and return the path."""
        save_path = self.checkpoint_dir / filename
        self.model.save_checkpoint(
            str(save_path),
            epoch,
            self.trainer.optimizer.state_dict()
        )
        return save_path
    
    def train(self):
        """
        Run complete training loop.

        Per epoch: train, optionally validate and keep the checkpoint with the
        lowest validation 'total' loss, and write a periodic checkpoint every
        10 epochs. Afterwards the final model is saved and the history plotted.
        """
        best_val_loss = float('inf')
        
        for epoch in range(self.num_epochs):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch+1}/{self.num_epochs}")
            print(f"{'='*60}")
            
            # Training
            train_losses = self.trainer.train_epoch(self.train_loader, epoch)
            print(f"\nTrain Losses: {train_losses}")
            
            # Validation
            if self.val_loader is not None:
                val_losses = self._run_validation()
                
                if val_losses is None:
                    # An empty validation set previously caused a ZeroDivisionError
                    print("Validation skipped: validation loader produced no batches")
                else:
                    print(f"Val Losses: {val_losses}")
                    
                    # Save best model
                    if val_losses['total'] < best_val_loss:
                        best_val_loss = val_losses['total']
                        self._save_checkpoint('best_model.pth', epoch)
                        print(f"✓ Best model saved (val_loss: {best_val_loss:.4f})")
            
            # Save periodic checkpoint
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(f'checkpoint_epoch_{epoch+1}.pth', epoch)
        
        print(f"\n{'='*60}")
        print("Training complete!")
        print(f"{'='*60}")
        
        # Save final model (recorded with the index of the last epoch)
        self._save_checkpoint('final_model.pth', self.num_epochs - 1)
        
        # Plot training history
        self.plot_training_history()
    
    def plot_training_history(self):
        """
        Plot training curves and save them as 'training_history.png'.

        Reads 'train_loss', 'val_loss', 'contrastive_loss' and 'archetype_loss'
        from ``self.trainer.history``.
        NOTE(review): these series are presumably filled in by LSTMABARTrainer;
        this class never writes to the history itself - confirm.
        """
        history = self.trainer.history
        plot_path = self.checkpoint_dir / 'training_history.png'
        
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        # Total loss
        axes[0].plot(history['train_loss'], label='Train Loss')
        if history['val_loss']:
            axes[0].plot(history['val_loss'], label='Val Loss')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Total Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Component losses
        axes[1].plot(history['contrastive_loss'], label='Contrastive Loss')
        axes[1].plot(history['archetype_loss'], label='Archetype Loss')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Loss')
        axes[1].set_title('Component Losses')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(plot_path, dpi=300)
        print(f"Training history plot saved to {plot_path}")
        # Close this specific figure so repeated runs do not accumulate figures
        plt.close(fig)


def _save_split(items: List[Dict], path: str, fallback_order) -> None:
    """Write one split (vectors, descriptions, audio paths, archetype order) to a compressed .npz."""
    # An empty split has no first item to read the order from, so fall back to
    # the order taken from the full dataset.
    archetype_order = items[0]['archetype_order'] if items else fallback_order
    np.savez_compressed(
        path,
        archetype_vectors=np.array([item['archetype_vector'] for item in items]),
        descriptions=[item['description'] for item in items],
        audio_paths=[item['audio_path'] for item in items],
        archetype_order=archetype_order
    )


def prepare_musiccaps_data(
    csv_path: str,
    audio_dir: str = 'musiccaps_audio',
    output_path: str = 'musiccaps_training_data.npz',
    max_downloads: int = 500,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    random_seed: int = 42
) -> Tuple[str, str, str]:
    """
    Prepare MusicCaps dataset for training

    Downloads audio clips, builds the archetype training examples, shuffles
    them with a fixed seed and writes train/val/test splits as .npz files
    named ``<output base>_train.npz``, ``_val.npz`` and ``_test.npz``.

    Args:
        csv_path: Path to MusicCaps CSV
        audio_dir: Directory to save audio
        output_path: Base path for output files (a trailing '.npz' is optional)
        max_downloads: Max clips to download
        train_split: Fraction for training (default 0.7)
        val_split: Fraction for validation (default 0.15)
        test_split: Fraction for testing (default 0.15)
        random_seed: Random seed for reproducibility
    
    Returns:
        Tuple of (train_path, val_path, test_path) .npz files

    Raises:
        AssertionError: If the three split fractions do not sum to 1.0
        ValueError: If the loader produced no training examples
    """

    assert abs(train_split + val_split + test_split - 1.0) < 1e-6, \
        "Splits must sum to 1.0"

    print("=== Preparing MusicCaps Dataset ===")
    
    # Load and process MusicCaps
    loader = MusicCapsLoader(csv_path, audio_dir)
    
    # Download audio clips
    print(f"\nDownloading up to {max_downloads} audio clips...")
    downloaded, failed = loader.download_audio_clips(
        max_clips=max_downloads,
        use_balanced_subset=True
    )
    
    print(f"Successfully downloaded: {len(downloaded)}")
    print(f"Failed: {len(failed)}")
    
    # Create archetype training data
    print("\nCreating archetype training data...")
    training_data = loader.create_archetype_training_data(use_tfidf_weighting=True)
    
    n_total = len(training_data)
    if n_total == 0:
        raise ValueError(
            "No training examples were produced; cannot create train/val/test splits"
        )
    
    # Shuffle data for random splits (seeds NumPy's global RNG, as before)
    np.random.seed(random_seed)
    indices = np.random.permutation(n_total)
    training_data = [training_data[i] for i in indices]
    
    # Calculate split points; the test split takes whatever remains after
    # the truncated train/val counts
    n_train = int(n_total * train_split)
    n_val = int(n_total * val_split)
    
    # Split data
    splits = {
        'train': training_data[:n_train],
        'val': training_data[n_train:n_train + n_val],
        'test': training_data[n_train + n_val:]
    }
    
    print(f"\nDataset split:")
    print(f"  Train: {len(splits['train'])} samples ({train_split*100:.1f}%)")
    print(f"  Val:   {len(splits['val'])} samples ({val_split*100:.1f}%)")
    print(f"  Test:  {len(splits['test'])} samples ({test_split*100:.1f}%)")
    
    # Derive the per-split file names from the base path. Only a trailing
    # '.npz' is stripped: str.replace would leave a suffix-less path unchanged
    # (all three splits would then overwrite one file) and would also rewrite
    # any '.npz' occurring earlier in the path.
    suffix = '.npz'
    base_path = output_path[:-len(suffix)] if output_path.endswith(suffix) else output_path
    split_paths = {name: f"{base_path}_{name}{suffix}" for name in splits}
    
    # Used only for splits that turn out empty (possible on tiny datasets)
    fallback_order = training_data[0]['archetype_order']
    
    print()
    for name, items in splits.items():
        if not items:
            print(f"Warning: {name} split is empty")
        _save_split(items, split_paths[name], fallback_order)
        print(f"✓ Saved {name} set: {split_paths[name]}")
    
    return split_paths['train'], split_paths['val'], split_paths['test']

In [29]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}\n")

# Run configuration: model, audio, optimisation and download settings
config = dict(
    embedding_dim=768,
    audio_architecture='resnet',  # or 'ast'
    sample_rate=44100,
    audio_duration=2.0,
    batch_size=16,
    learning_rate=1e-4,
    num_epochs=20,
    max_downloads=500,
)

Using device: cpu



In [6]:
# Step 0: Download data (run once)

import kagglehub

# Download latest version; `path` is the local directory reported by kagglehub
path = kagglehub.dataset_download("googleai/musiccaps")

print("Path to dataset files:", path)

# Build the CSV path with pathlib rather than a hand-written "/" separator,
# and fail here (not deep inside the loader) if the expected file is missing.
csv_path = str(Path(path) / "musiccaps-public.csv")
if not Path(csv_path).is_file():
    raise FileNotFoundError(f"MusicCaps CSV not found at {csv_path}")

Path to dataset files: /Users/shanthgopalswamy/.cache/kagglehub/datasets/googleai/musiccaps/versions/1


In [33]:
# Step 1: Build the train/val/test .npz files (run once)
print("Step 1: Preparing MusicCaps dataset with train/val/test split...")

# 70/15/15 split with a fixed seed; the download cap comes from `config`
(
    train_data_path,
    val_data_path,
    test_data_path,
) = prepare_musiccaps_data(
    csv_path,
    'musiccaps_audio',
    max_downloads=config['max_downloads'],
    train_split=0.7,
    val_split=0.15,
    test_split=0.15,
    random_seed=42,
)

Step 1: Preparing MusicCaps dataset with train/val/test split...
=== Preparing MusicCaps Dataset ===
Loaded 5521 MusicCaps examples

Downloading up to 500 audio clips...
Using balanced subset: 1000 examples
✓ Already exists: -bgHkxwoliw
✓ Already exists: -kpR93atgd8
✓ Already exists: -wymN80CiYU
✓ Already exists: 07xGXxIHOL4
✓ Already exists: 0PMFAO4TIU4
✓ Already exists: 0TV9zvfwFhs
✓ Already exists: 0fiOM---7QI


ERROR: [youtube] 0i8VM_EooCs: Video unavailable. This video is no longer available due to a copyright claim by Terrabyte Music Limited


Error downloading 0i8VM_EooCs: ERROR: [youtube] 0i8VM_EooCs: Video unavailable. This video is no longer available due to a copyright claim by Terrabyte Music Limited
✗ Failed: 0i8VM_EooCs
✓ Already exists: 0jFQ21A6GRA
✓ Already exists: 1ACn3u5UnBw
✓ Already exists: 1BVSYfNCcv0
✓ Already exists: 1JpeDWbgUO8
✓ Already exists: 1PKxdTlquCA
✓ Already exists: 1Q9DXhXMSFI
✓ Already exists: 1TyOPtg0Yfk
✓ Already exists: 1V7ReAk9k-4
✓ Already exists: 1j4rFfU5XKQ
✓ Already exists: 20Vh6z6Ie0E


ERROR: [youtube] 2G5bSYHcJSM: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading 2G5bSYHcJSM: ERROR: [youtube] 2G5bSYHcJSM: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: 2G5bSYHcJSM
✓ Already exists: 2GWkKVHxGRM
✓ Already exists: 2JnlmS1zzls
✓ Already exists: 2RU4CSDzS-g
✓ Already exists: 2U8Dvh7nwFI
✓ Already exists: 2ZfthfWQowE
✓ Already exists: 2bCuw7U_Rac
✓ Already exists: 2dyEnOo3yJ8
✓ Already exists: 2vQTq4QLP8U
✓ Already exists: 2xGRCsW6-Bk
✓ Already exists: 3JYQgXudiH8
✓ Already exists: 3TQmts_MxyQ
✓ Already exists: 40D4L5Ndi6k
✓ Already exists: 44sbWBFswUY
✓ Already exists: 4i11P4OCRfk
✓ Already exists: 5JQIsqc8HBc
✓ Already exists: 5XXAeSybGK0
✓ Already exists: 5ZpVhmhVYoI
✓ Already exists: 5_orEetudIA
✓ Already exists: 5gyMt0YzPQ0
✓ Already exists: 60OIHit4Q-M
✓ Already exists: 6N1LWG4aztA
✓ Already exists: 6k4lcF9IGUk
✓ Already exists: 7-mNJ4IUY5Q
✓ Already exists: 7WZwlOrRELI
✓ Already exists: 7_q36NyJtQY
✓ Already exists: 8BJljuSm2Aw
✓ Alread

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


✗ Failed: BXo1Tr_oJds
✓ Already exists: Bl-lCgr5hGY
✓ Already exists: BnkDQXlrIX4
✓ Already exists: Byk9p21g51g
✓ Already exists: C5MhO2HM2Wg
✓ Already exists: C6roSYqchkk
✓ Already exists: C8VECv8kicU
✓ Already exists: CJjyrDGmxIY
✓ Already exists: CP3phqztym0
✓ Already exists: CRxIJ7YbcZA
✓ Already exists: CWQvCCRuU6k
✓ Already exists: CZuH43NPynA
✓ Already exists: Cchf2QH63bI
✓ Already exists: ChyayWIp_vU
✓ Already exists: CphwhKgYHaM
✓ Already exists: CzMNiypg1I8


ERROR: [youtube] Czbi1u-gwUU: Video unavailable. This video is no longer available due to a copyright claim by Cheb Hasni


Error downloading Czbi1u-gwUU: ERROR: [youtube] Czbi1u-gwUU: Video unavailable. This video is no longer available due to a copyright claim by Cheb Hasni
✗ Failed: Czbi1u-gwUU
✓ Already exists: D2w3qHmJrdU
✓ Already exists: D3FyfFIKLVc
✓ Already exists: D4ccFYk3bhU


ERROR: [youtube] D8-x1T8M4gk: Video unavailable. This video has been removed by the uploader


Error downloading D8-x1T8M4gk: ERROR: [youtube] D8-x1T8M4gk: Video unavailable. This video has been removed by the uploader
✗ Failed: D8-x1T8M4gk
✓ Already exists: DAPGvg8qOAU
✓ Already exists: DCFrCX4HPO8
✓ Already exists: DG5d4megH8g
✓ Already exists: DGbMEkQerYs
✓ Already exists: DKflAAykh6A
✓ Already exists: DP2vmsftZHY
✓ Already exists: DU5pD63Pv30
✓ Already exists: DaiVfxATCEE
✓ Already exists: DdxW_JziHTA


ERROR: [youtube] DysXetu2I0E: Video unavailable


Error downloading DysXetu2I0E: ERROR: [youtube] DysXetu2I0E: Video unavailable
✗ Failed: DysXetu2I0E
✓ Already exists: EKZvq0dUk50
✓ Already exists: EUNTykrvpok
✓ Already exists: EaGhKzpkNso


ERROR: [youtube] EfUUgsioXyU: Private video. Sign in if you've been granted access to this video. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies


Error downloading EfUUgsioXyU: ERROR: [youtube] EfUUgsioXyU: Private video. Sign in if you've been granted access to this video. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies
✗ Failed: EfUUgsioXyU
✓ Already exists: EmSZKb0LdVM
✓ Already exists: Es9FNjZ-SHI
✓ Already exists: FCzMqo8kh1o
✓ Already exists: FDO5BekX478
✓ Already exists: FENJIDecy5s
✓ Already exists: Fsm-xDmyFKg
✓ Already exists: FsnRM2irjvI
✓ Already exists: FteW_2gNtD4


ERROR: [youtube] Fv9swdLA-lo: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading Fv9swdLA-lo: ERROR: [youtube] Fv9swdLA-lo: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: Fv9swdLA-lo
✓ Already exists: G2uCAwYS6w0
✓ Already exists: GHQlBD-6rkA
✓ Already exists: GJYhDjThTHM
✓ Already exists: GLIXnXZEOxY
✓ Already exists: GPSqrciDLog
✓ Already exists: GQbUpJFArKI
✓ Already exists: GYCfrx0ruz4
✓ Already exists: GbjtSTTEFK4
✓ Already exists: Gc8xf7CJiFY
✓ Already exists: GkB_BkyVyPs
✓ Already exists: Guu30szkA-0
✓ Already exists: H6qzijVEqZQ
✓ Already exists: HFH9tcIK_PM
✓ Already exists: HFVM5pVTwkM
✓ Already exists: HHTgjmgTV6c
✓ Already exists: HNf9eHqDT1A
✓ Already exists: HS_ikHx4LIQ
✓ Already exists: HU7oqkJeItQ
✓ Already exists: HYjSrwSm0T4
✓ Already exists: HfzEa06vDLg
✓ Already exists: Hg4f2xt3oKA
✓ Already exists: HkXSX7Kdhms
✓ Already exists: Hnk45Z0EAxg
✓ Already exists: HzXWXYxXyYA
✓ Already exists: ID4AoAfHMVk
✓ Already exists: IKq2OF8jq1c
✓ Alread

ERROR: [youtube] LRfVQsnaVQE: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading LRfVQsnaVQE: ERROR: [youtube] LRfVQsnaVQE: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: LRfVQsnaVQE
✓ Already exists: L_nC2BvhRdQ
✓ Already exists: LfvdxSBCtFE
✓ Already exists: LjihfG0fit0
✓ Already exists: LybSS4amIS0
✓ Already exists: LzSWdj4izHM
✓ Already exists: MHkfPjW0aRg
✓ Already exists: MIexFfOsuJs
✓ Already exists: MKikHxKeodA
✓ Already exists: MM0seezR2F4
✓ Already exists: MVYSWTF11Nc
✓ Already exists: MY0PsDE3xHs
✓ Already exists: MdYXznF3Eac
✓ Already exists: MpWGx5odhh8
✓ Already exists: MsjeOXuUYG4
✓ Already exists: MvnC1TfNiPY
✓ Already exists: MzUgHy7SyS8
✓ Already exists: N-dzfI3L5ic
✓ Already exists: NHA1l_Czm38
✓ Already exists: N_Wx35sNqdM
✓ Already exists: NlCfScKw_Mk
✓ Already exists: NsYVaRI6rXg
✓ Already exists: Nt0U-CXK6O0
✓ Already exists: NwA9JSlK_lM
✓ Already exists: O1RmrE_HfpE
✓ Already exists: OB7GyVqufwQ
✓ Already exists: OEjgIDubFbg
✓ Alread

ERROR: [youtube] T6iv9GFIVyU: Video unavailable. This video is no longer available due to a copyright claim by Rishad Zahir


Error downloading T6iv9GFIVyU: ERROR: [youtube] T6iv9GFIVyU: Video unavailable. This video is no longer available due to a copyright claim by Rishad Zahir
✗ Failed: T6iv9GFIVyU
✓ Already exists: T7A0RejsZIo
✓ Already exists: T7ZSZhcsfjA
✓ Already exists: TN53jpjqAGI
✓ Already exists: TPYNIc_M1ng
✓ Already exists: Tp8PG2xae8c
✓ Already exists: Tsmx6Pb7CnU
✓ Already exists: TworrkXAPuI
✓ Already exists: TzPuAqjoL80
✓ Already exists: U4UtZeTl2DE
✓ Already exists: UDN11Q90Fa4
✓ Already exists: UFyOGqmITjM
✓ Already exists: UIOnnpaqBy8
✓ Already exists: UNJswfXKJ3s
✓ Already exists: UQKLBsZJsww
✓ Already exists: UcabTrKowlI
✓ Already exists: UnFEqUWTefM
✓ Already exists: UoxHwOl2gN0
✓ Already exists: UrgzGbGVV8I
✓ Already exists: UsdoUjuczY4
✓ Already exists: UtZofZjccBs
✓ Already exists: UvCY9FHpKC8
✓ Already exists: V3Vvp5HS90k
✓ Already exists: V9jIsOTC1lY
✓ Already exists: VCusyLPrFCo
✓ Already exists: VG6-MlmCgzI
✓ Already exists: VHYxygh1STA
✓ Already exists: VL6uF-XeE_A
✓ Already exi

ERROR: [youtube] W0aT3SdtnfY: Private video. Sign in if you've been granted access to this video. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies


Error downloading W0aT3SdtnfY: ERROR: [youtube] W0aT3SdtnfY: Private video. Sign in if you've been granted access to this video. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies
✗ Failed: W0aT3SdtnfY
✓ Already exists: W3lKc2hj4XU
✓ Already exists: W7U-glgu4GM
✓ Already exists: WCifI6rwOoM
✓ Already exists: WEVBqGarEIY
✓ Already exists: WMtztIW1f6k
✓ Already exists: WPguqXCBQCI
✓ Already exists: WTVC7ZI9WtY
✓ Already exists: WT_wvvEvkw4
✓ Already exists: WaddbqEQ1NE


ERROR: [youtube] We0WIPYrtRE: Video unavailable


Error downloading We0WIPYrtRE: ERROR: [youtube] We0WIPYrtRE: Video unavailable
✗ Failed: We0WIPYrtRE
✓ Already exists: WeDA1mDFSCo
✓ Already exists: WgZ8KAnnTb8
✓ Already exists: WsDb16qzA5Q
✓ Already exists: WtN6uiDikRM
✓ Already exists: X96v9LlsjJM
✓ Already exists: XE4NRSDLYG8
✓ Already exists: XUD-9HkQuTE
✓ Already exists: XXBVsNt2Qr8
✓ Already exists: XYOnq7ju7o0
✓ Already exists: XgOA5oRkL2A
✓ Already exists: XjUmXwVlDDo
✓ Already exists: XkBXsaSXDJ0


ERROR: [youtube] XvtL_TTLXHY: Video unavailable


Error downloading XvtL_TTLXHY: ERROR: [youtube] XvtL_TTLXHY: Video unavailable
✗ Failed: XvtL_TTLXHY
✓ Already exists: XwhAoMLNYWQ
✓ Already exists: XykUpCigu4w
✓ Already exists: Y7mTjfgcybQ
✓ Already exists: YZx0_GRtvJk
✓ Already exists: YcWJUHWt-64
✓ Already exists: YrGQKTbiG1g
✓ Already exists: YzpzKyzyL0Y
✓ Already exists: Z31gI08SMzI
✓ Already exists: Z7V7Curou7s
✓ Already exists: Z8L3jychP14
✓ Already exists: ZEuY5HnECuo
✓ Already exists: ZFimyfPWltk
✓ Already exists: ZJHlHb-VyDc
✓ Already exists: ZLXW4ewrVpQ
✓ Already exists: ZMd8mAKe-k8
✓ Already exists: ZNGvyFsCx4g
✓ Already exists: ZUcHBeueBww
✓ Already exists: ZUkh168Nyus
✓ Already exists: ZaUaqnLdg6k
✓ Already exists: Zhurw43-Y1g
✓ Already exists: ZkfKOLp5SxU
✓ Already exists: Zlbo8ygfPSM
✓ Already exists: ZmgkpmzvL6c
✓ Already exists: ZoAfkpmztww
✓ Already exists: ZsmfIMEzrQs
✓ Already exists: Zt8x7tvP9Qs
✓ Already exists: Zz1Bz1a7yPE


ERROR: [youtube] _3OlK_1yQOk: Video unavailable. This video contains content from Storm Labels Inc., who has blocked it on copyright grounds


Error downloading _3OlK_1yQOk: ERROR: [youtube] _3OlK_1yQOk: Video unavailable. This video contains content from Storm Labels Inc., who has blocked it on copyright grounds
✗ Failed: _3OlK_1yQOk
✓ Already exists: _43OOP6UEw0
✓ Already exists: _78P-0zWJtg
✓ Already exists: _9OUh0uwDec
✓ Already exists: _R9Ma9rjEWg
✓ Already exists: _b5n-mny1lM
✓ Already exists: _fKntnlIYTQ
✓ Already exists: _gWEpDgPAho
✓ Already exists: _h2rFVPCSPE
✓ Already exists: _lq8nEXh064
✓ Already exists: _m-N4i-ge28
✓ Already exists: _mQ6KuA2p6k
✓ Already exists: _n3r2inlqBc
✓ Already exists: _n9boKzVRhs
✓ Already exists: _yXtw_z2xf4
✓ Already exists: a2Wuroc8DQU
✓ Already exists: aJHv6TV7JpY
✓ Already exists: aOGNUGgTQ8k
✓ Already exists: aPQTrv2B1sw
✓ Already exists: aUH12rRIVDw
✓ Already exists: aUXKK9AmrPU
✓ Already exists: aUvHaURNgY8
✓ Already exists: aW6greyYuO4
✓ Already exists: aWK9CcvOK9w
✓ Already exists: aY8-pXDdwiw
✓ Already exists: ad6UhYwTXXQ
✓ Already exists: adYFXYPqo2M
✓ Already exists: ajy9PM2S

ERROR: [youtube] cADT8fUucLQ: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading cADT8fUucLQ: ERROR: [youtube] cADT8fUucLQ: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: cADT8fUucLQ
✓ Already exists: cBd0yZ27dtA
✓ Already exists: cGUhG5PZp0A
✓ Already exists: cS2gRhH6it4
✓ Already exists: cXEJWtj2kT8
✓ Already exists: cYRsnYEPIiM
✓ Already exists: cbq6Q2htPRM
✓ Already exists: chw8sAKOM5k
✓ Already exists: clefr8E-iZQ
✓ Already exists: cnvmLwFZr28
✓ Already exists: cp8t27oT_ww
✓ Already exists: cs-zcTX2tRA
✓ Already exists: d1nz5tZckSA
✓ Already exists: dBAeAk7dXnU
✓ Already exists: dMAp3dvs3kE
✓ Already exists: dNOHIxD0j_Q
✓ Already exists: dSJpZQ8u_xY
✓ Already exists: dSs4xfvATjc
✓ Already exists: darQBSIlol8
✓ Already exists: deIj55UAxeo


ERROR: [youtube] doX8FjlNPf8: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading doX8FjlNPf8: ERROR: [youtube] doX8FjlNPf8: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: doX8FjlNPf8
✓ Already exists: dvDSgmqbrM0
✓ Already exists: dwAo0dKCyBI
✓ Already exists: dwFtlQLdbq0
✓ Already exists: dwSj0Rr3vFc
✓ Already exists: dy_yFZ6dL34
✓ Already exists: e1KHGfMekek
✓ Already exists: e2tZmQI8ICw
✓ Already exists: e7WPFeDPFB4
✓ Already exists: e8wnUU5pIWE


ERROR: [youtube] eHeUipPZHIc: Video unavailable. This video contains content from SVG Music, who has blocked it on copyright grounds


Error downloading eHeUipPZHIc: ERROR: [youtube] eHeUipPZHIc: Video unavailable. This video contains content from SVG Music, who has blocked it on copyright grounds
✗ Failed: eHeUipPZHIc
✓ Already exists: eI4PbSh6g_Y
✓ Already exists: eM0PkfqGmIE
✓ Already exists: eOmQbJljnqE
✓ Already exists: eQTK2fo3RoE
✓ Already exists: eSesh6vnek8
✓ Already exists: eStzDzEopDI
✓ Already exists: eW8se7t0s-U
✓ Already exists: eWwWwoQLtVg
✓ Already exists: eXWBC3XfiXY
✓ Already exists: eXrJL1VUQNE
✓ Already exists: eYngZ5It0b8
✓ Already exists: eZE0RmJESFU
✓ Already exists: eZNnuRvrZDU
✓ Already exists: e_W17jp40G4
✓ Already exists: efTVnvwI2PQ
✓ Already exists: eiFyXXqd9Rk
✓ Already exists: eiUjc4UPnSs
✓ Already exists: esIzFH7vYLY
✓ Already exists: euAQCWBX6ns
✓ Already exists: evscfdO-oSY
✓ Already exists: f3l6KnC8930
✓ Already exists: f8nysknTFUo
✓ Already exists: fEfe8jznp5Q
✓ Already exists: fH9CY48sfJY
✓ Already exists: fHNAxa0QaOM
✓ Already exists: fPYeqTFc3IQ
✓ Already exists: fWypK9RHJJI
✓ Al

ERROR: [youtube] fZyq2pM2-dI: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading fZyq2pM2-dI: ERROR: [youtube] fZyq2pM2-dI: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: fZyq2pM2-dI
✓ Already exists: feC0L9MtghM
✓ Already exists: fer_4HvG3aY
✓ Already exists: fgCTFyzKQtk
✓ Already exists: fhWzjWZqzvs
✓ Already exists: fow1TC_MpHs
✓ Already exists: fsTVRca31nI
✓ Already exists: fsXfBoNcLeM
✓ Already exists: ftaHv79hRoY
✓ Already exists: fvw3Bi0GONA
✓ Already exists: g0scnRzoo9M
✓ Already exists: g4xhZgKwiNo


ERROR: [youtube] g8USMvt9np0: We're processing this video. Check back later.


Error downloading g8USMvt9np0: ERROR: [youtube] g8USMvt9np0: We're processing this video. Check back later.
✗ Failed: g8USMvt9np0
✓ Already exists: gAURHUoIK0M
✓ Already exists: gBuLpP4klvI
✓ Already exists: gDm4IphrlYg
✓ Already exists: gDnJoHpSL4M
✓ Already exists: gDzi8N3BYMw
✓ Already exists: gEvCUcZ6w88
✓ Already exists: gFxLnprPgv4
✓ Already exists: gRn6OjQf2ZQ
✓ Already exists: gWRfk8nCcPs
✓ Already exists: gXOyw8a4_Xs
✓ Already exists: g_bgmnJ1b_g


ERROR: [youtube] gdtw54I8soM: Video unavailable


Error downloading gdtw54I8soM: ERROR: [youtube] gdtw54I8soM: Video unavailable
✗ Failed: gdtw54I8soM
✓ Already exists: giPa2vVEyVc
✓ Already exists: gjJWbtCShqo
✓ Already exists: gsBXngKgy-Q
✓ Already exists: guRyU4B5LlA
✓ Already exists: guYWKdxrtIg
✓ Already exists: gxzU5EqNL14
✓ Already exists: h0-6U948u7Y
✓ Already exists: h8JS_FEF_fY
✓ Already exists: hDsA_ky9Hfw
✓ Already exists: hDzmNYd_eaA
✓ Already exists: hFj0KUzofNg
✓ Already exists: hFqZZrj0rnM
✓ Already exists: hQ5OBio4Cy0
✓ Already exists: hRbukCd6N68


ERROR: [youtube] hTAWbHXCJ2A: Sign in to confirm your age. This video may be inappropriate for some users. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies


Error downloading hTAWbHXCJ2A: ERROR: [youtube] hTAWbHXCJ2A: Sign in to confirm your age. This video may be inappropriate for some users. Use --cookies-from-browser or --cookies for the authentication. See  https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp  for how to manually pass cookies. Also see  https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies  for tips on effectively exporting YouTube cookies
✗ Failed: hTAWbHXCJ2A
✓ Already exists: hTNKYJ6suII
✓ Already exists: hUcuXIvDN2E
✓ Already exists: hVPQu1UJ2N8
✓ Already exists: hgitRq_0410
✓ Already exists: hlquKjPgxmY
✓ Already exists: hpiFoinUgvY
✓ Already exists: hqQvatf1RUY
✓ Already exists: hu6sChY-Yps
✓ Already exists: i6WtNBpRll0
✓ Already exists: i6k1yiyO5jQ
✓ Already exists: iBH5X5SKirU
✓ Already exists: iBezxlI_f_c
✓ Already exists: iCIa_pmLDqs
✓ Already exists: iEMTTKA7NxU


ERROR: [youtube] iEQwupwwp0s: The uploader has not made this video available in your country
This video is available in United Arab Emirates, Afghanistan, Albania, Armenia, Angola, Antarctica, Argentina, American Samoa, Australia, Aruba, Åland Islands, Azerbaijan, Bosnia and Herzegovina, Bangladesh, Burkina Faso, Bulgaria, Bahrain, Benin, Bermuda, Brunei Darussalam, Bolivia, Plurinational State of, Bonaire, Sint Eustatius and Saba, Brazil, Bhutan, Bouvet Island, Botswana, Belarus, Belize, Cocos (Keeling) Islands, Congo, the Democratic Republic of the, Central African Republic, Congo, Côte d'Ivoire, Cook Islands, Chile, Cameroon, China, Colombia, Costa Rica, Cape Verde, Curaçao, Christmas Island, Cyprus, Czech Republic, Denmark, Ecuador, Estonia, Western Sahara, Spain, Finland, Fiji, Falkland Islands (Malvinas), Micronesia, Federated States of, Faroe Islands, Gabon, Georgia, French Guiana, Guernsey, Ghana, Gibraltar, Greenland, Gambia, Guinea, Guadeloupe, Equatorial Guinea, Greece, Sout

Error downloading iEQwupwwp0s: ERROR: [youtube] iEQwupwwp0s: The uploader has not made this video available in your country
This video is available in United Arab Emirates, Afghanistan, Albania, Armenia, Angola, Antarctica, Argentina, American Samoa, Australia, Aruba, Åland Islands, Azerbaijan, Bosnia and Herzegovina, Bangladesh, Burkina Faso, Bulgaria, Bahrain, Benin, Bermuda, Brunei Darussalam, Bolivia, Plurinational State of, Bonaire, Sint Eustatius and Saba, Brazil, Bhutan, Bouvet Island, Botswana, Belarus, Belize, Cocos (Keeling) Islands, Congo, the Democratic Republic of the, Central African Republic, Congo, Côte d'Ivoire, Cook Islands, Chile, Cameroon, China, Colombia, Costa Rica, Cape Verde, Curaçao, Christmas Island, Cyprus, Czech Republic, Denmark, Ecuador, Estonia, Western Sahara, Spain, Finland, Fiji, Falkland Islands (Malvinas), Micronesia, Federated States of, Faroe Islands, Gabon, Georgia, French Guiana, Guernsey, Ghana, Gibraltar, Greenland, Gambia, Guinea, Guadeloupe, 

ERROR: [youtube] iXgEQj1Fs7g: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.


Error downloading iXgEQj1Fs7g: ERROR: [youtube] iXgEQj1Fs7g: Video unavailable. This video is no longer available because the YouTube account associated with this video has been terminated.
✗ Failed: iXgEQj1Fs7g
✓ Already exists: iZjIuV_cTe8

=== Download Summary ===
Successfully downloaded: 477
Failed: 23
Successfully downloaded: 477
Failed: 23

Creating archetype training data...

=== Archetype Distribution Statistics ===

SINE:
  Mean weight: 0.431
  Std dev: 0.392
  Max weight: 1.000
  % samples with weight > 0.3: 48.6%

SQUARE:
  Mean weight: 0.138
  Std dev: 0.253
  Max weight: 1.000
  % samples with weight > 0.3: 15.1%

SAWTOOTH:
  Mean weight: 0.198
  Std dev: 0.311
  Max weight: 1.000
  % samples with weight > 0.3: 21.6%

TRIANGLE:
  Mean weight: 0.144
  Std dev: 0.248
  Max weight: 1.000
  % samples with weight > 0.3: 15.1%

NOISE:
  Mean weight: 0.088
  Std dev: 0.187
  Max weight: 1.000
  % samples with weight > 0.3: 8.8%

=== Examples with Dominant Archetypes ===

SINE dom

In [34]:
# Step 2: Build one MusicCapsDataset per split (train / val / test)
print("\nStep 2: Creating PyTorch datasets...")

# The audio settings are the same for every split, so gather them once
# and unpack them into each constructor call below.
dataset_audio_kwargs = {
    'sample_rate': config['sample_rate'],
    'audio_duration': config['audio_duration'],
}

# Only the training split is built with augment=True; validation and test
# are built with augment=False.
train_dataset = MusicCapsDataset(train_data_path, augment=True, **dataset_audio_kwargs)

val_dataset = MusicCapsDataset(val_data_path, augment=False, **dataset_audio_kwargs)

test_dataset = MusicCapsDataset(test_data_path, augment=False, **dataset_audio_kwargs)


Step 2: Creating PyTorch datasets...
Loading training data from musiccaps_training_data_train.npz
Loaded 333 training examples
Found 333 samples with valid audio files
Loading training data from musiccaps_training_data_val.npz
Loaded 71 training examples
Found 71 samples with valid audio files
Loading training data from musiccaps_training_data_test.npz
Loaded 73 training examples
Found 73 samples with valid audio files


In [22]:
# Cell 1: Refresh the project modules so on-disk edits take effect in this kernel
import importlib
import text_tower
import lstmabar_model

# importlib.reload() re-executes the module and returns it; rebinding the
# name keeps the notebook pointing at whatever object reload hands back.
# text_tower is reloaded before lstmabar_model (same order as before).
text_tower = importlib.reload(text_tower)
lstmabar_model = importlib.reload(lstmabar_model)

# Rebind the class names so they refer to the freshly reloaded definitions.
# NOTE(review): objects created before this cell ran still use the old
# classes — re-create them after reloading.
from text_tower import TextEncoder
from lstmabar_model import LSTMABAR, LSTMABARTrainer

In [35]:
# Step 3: Construct the LSTMABAR model from the shared config
print("\nStep 3: Initializing LSTMABAR model...")

# Constructor arguments: architecture settings are read from `config`,
# quantum attention is explicitly disabled for this run, and the model is
# told which `device` to use.
model_kwargs = {
    'embedding_dim': config['embedding_dim'],
    'audio_architecture': config['audio_architecture'],
    'sample_rate': config['sample_rate'],
    'use_quantum_attention': False,
    'device': device,
}
model = LSTMABAR(**model_kwargs)

# Report the total element count summed over everything in model.parameters().
print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")


Step 3: Initializing LSTMABAR model...
Loading text encoder: sentence-transformers/all-MiniLM-L6-v2
Model has 37,203,724 parameters


In [36]:
# Step 4: Create training pipeline
# Bundles the model with the training and validation datasets and the
# optimisation hyperparameters (batch size, learning rate, epoch count), all
# of which are read from `config`. The held-out `test_dataset` is not passed
# to the pipeline here.
print("\nStep 4: Setting up training pipeline...")
pipeline = TrainingPipeline(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=config['batch_size'],
    learning_rate=config['learning_rate'],
    num_epochs=config['num_epochs'],
    # Relative path: checkpoints land under ./checkpoints of the kernel's
    # working directory.
    checkpoint_dir='checkpoints'
)


Step 4: Setting up training pipeline...
Training pipeline initialized:
  Training samples: 333
  Validation samples: 71
  Batch size: 16
  Total epochs: 20
  Steps per epoch: 21


In [37]:
# Step 5: Train!
# Runs the pipeline's training loop. Progress (per-batch losses, per-epoch
# train/val losses, checkpoint notices) is printed by the pipeline itself;
# any return value of train() is discarded here.
print("\nStep 5: Starting training...")
pipeline.train()

# Only reached if train() returns without raising.
print("\n✓ Training complete! Model checkpoints saved to checkpoints/")


Step 5: Starting training...

Epoch 1/20




Epoch 0, Batch 0/21: Loss=1.9718, Contrastive=2.7729, Archetype=0.1024
Epoch 0, Batch 10/21: Loss=1.9569, Contrastive=2.7576, Archetype=0.0877
Epoch 0, Batch 20/21: Loss=1.8003, Contrastive=2.5298, Archetype=0.0973

Train Losses: {'total': np.float64(1.9567998023260207), 'contrastive': np.float64(2.751888933635893), 'archetype': np.float64(0.10551349180085319)}
Val Losses: {'total': 1.8490639925003052, 'contrastive': 2.6003254413604737, 'archetype': 0.09782975912094116}
Checkpoint saved to checkpoints/best_model.pth
✓ Best model saved (val_loss: 1.8491)

Epoch 2/20
Epoch 1, Batch 0/21: Loss=1.8529, Contrastive=2.6029, Archetype=0.1075
Epoch 1, Batch 10/21: Loss=1.8889, Contrastive=2.6589, Archetype=0.0963
Epoch 1, Batch 20/21: Loss=1.7524, Contrastive=2.4618, Archetype=0.0992

Train Losses: {'total': np.float64(1.8320271628243583), 'contrastive': np.float64(2.5745021729242232), 'archetype': np.float64(0.10390859984216236)}
Val Losses: {'total': 1.7285382747650146, 'contrastive': 2.4283