In [2]:
# Real-Time Streaming Text-to-Speech System - PART 1
# Foundation, Data Loading, and Text Processing
# Built from scratch for LibriSpeech train-clean-100

# ==============================================================================
# CELL 1: Installation and Imports
# ==============================================================================

"""
Installation commands (run if needed):
!pip install torch torchaudio librosa soundfile numpy matplotlib jupyter ipython tqdm
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import librosa
import re
import json
import os
from pathlib import Path
import threading
import queue
import time
from typing import List, Tuple, Optional, Generator
import soundfile as sf
from collections import defaultdict
import math
import matplotlib.pyplot as plt
import IPython.display as ipd
from IPython.display import Audio, display, HTML
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm

# Startup banner followed by a short report of the PyTorch / CUDA environment.
print("\n".join([
    "=" * 60,
    "🎉 REAL-TIME STREAMING TTS SYSTEM - PART 1",
    "   Foundation, Data Loading, and Text Processing",
    "=" * 60,
    "✅ All imports successful",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))
if torch.cuda.is_available():
    print("\n".join([
        f"CUDA device: {torch.cuda.get_device_name()}",
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB",
    ]))

# ==============================================================================
# CELL 2: System Configuration
# ==============================================================================

class TTSConfig:
    """Class-level settings for the LibriSpeech TTS pipeline.

    Every setting is a class attribute, so it can be read from the class or
    from an instance. Code elsewhere in this file overwrites ``VOCAB_SIZE`` on
    the module-level ``config`` instance once a vocabulary has been built.
    """

    # --- Dataset selection -------------------------------------------------
    LIBRISPEECH_PATH = "/content/drive/MyDrive/NAC/train-clean-100"  # UPDATE THIS PATH!
    MAX_SAMPLES = None       # None = use every sample; an int limits to a subset (e.g. 5000 for testing)
    MAX_AUDIO_LENGTH = 15.0  # seconds
    MIN_AUDIO_LENGTH = 1.0   # seconds
    TRAIN_RATIO = 0.95       # fraction of samples used for training (rest is validation)

    # --- Audio / STFT parameters ------------------------------------------
    SAMPLE_RATE = 16000  # LibriSpeech native sample rate
    HOP_LENGTH = 200     # chosen for 16 kHz audio
    N_MELS = 80          # keep equal to MEL_DIM below
    N_FFT = 1024
    WIN_LENGTH = 800
    F_MIN = 0
    F_MAX = 8000         # SAMPLE_RATE / 2 (Nyquist)

    # --- Model architecture -----------------------------------------------
    VOCAB_SIZE = 2000    # placeholder; replaced once the vocabulary is built
    HIDDEN_DIM = 512
    NUM_HEADS = 8
    NUM_LAYERS = 6
    MEL_DIM = 80
    MAX_SEQ_LEN = 500    # longest text sequence (in tokens)

    # --- Optimisation -----------------------------------------------------
    BATCH_SIZE = 8       # tune to available GPU memory
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 100
    WARMUP_STEPS = 4000
    GRAD_CLIP_NORM = 1.0

    # --- Streaming --------------------------------------------------------
    BUFFER_SIZE = 4      # words collected before a chunk is processed
    OVERLAP_FRAMES = 2   # frames shared between consecutive chunks

    # --- Hardware ---------------------------------------------------------
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    USE_MIXED_PRECISION = torch.cuda.is_available()

    # --- Loss weighting ---------------------------------------------------
    MEL_LOSS_WEIGHT = 1.0
    DURATION_LOSS_WEIGHT = 0.1
    ALIGNMENT_LOSS_WEIGHT = 0.05

    @classmethod
    def print_config(cls):
        """Print a short summary of the most commonly tuned settings."""
        summary = (
            ("LibriSpeech Path", f"{cls.LIBRISPEECH_PATH}"),
            ("Sample Rate", f"{cls.SAMPLE_RATE}Hz"),
            ("Batch Size", f"{cls.BATCH_SIZE}"),
            ("Hidden Dim", f"{cls.HIDDEN_DIM}"),
            ("Device", f"{cls.DEVICE}"),
            ("Mixed Precision", f"{cls.USE_MIXED_PRECISION}"),
            ("Buffer Size", f"{cls.BUFFER_SIZE} words"),
        )
        print("🔧 TTS System Configuration:")
        for label, value in summary:
            print(f"   {label}: {value}")

# Shared configuration instance used throughout the rest of the notebook.
config = TTSConfig()
config.print_config()

# Confirm the dataset root exists before later cells try to read from it.
librispeech_path = Path(config.LIBRISPEECH_PATH)
if not librispeech_path.exists():
    print(f"❌ LibriSpeech not found at: {librispeech_path}")
    print("   Please update config.LIBRISPEECH_PATH to your LibriSpeech directory")
    print("   You can download it from: https://www.openslr.org/12/")
else:
    print(f"✅ LibriSpeech found at: {librispeech_path}")
    subdirs = [entry.name for entry in librispeech_path.iterdir() if entry.is_dir()][:5]
    print(f"   Sample directories: {subdirs}")

# ==============================================================================
# CELL 3: LibriSpeech Data Discovery and Loading
# ==============================================================================

class LibriSpeechDatasetLoader:
    """Find ``.flac`` files under a root directory and pair them with transcripts.

    Transcripts come from ``*.trans.txt`` files whose lines are
    ``<file_id> <text>``. Each ``.flac`` whose filename stem equals a
    ``file_id`` becomes one sample. The speaker and chapter IDs are taken from
    the first two ``-``-separated fields of that ``file_id``.
    """

    def __init__(self, root_path: str):
        """Remember the dataset root; no filesystem access happens until ``discover_files``."""
        self.root_path = Path(root_path)
        # One dict per matched sample with keys: audio_path, transcript,
        # file_id, speaker_id, chapter_id.
        self.audio_files = []
        self.transcripts = {}   # file_id -> upper-cased transcript text
        self.speakers = set()
        self.chapters = set()

    def discover_files(self):
        """Scan ``root_path`` and populate ``audio_files``, ``speakers`` and ``chapters``.

        Results are rebuilt from scratch on every call and are ordered by audio
        path, so subsets taken with ``get_samples`` are reproducible.

        Raises:
            FileNotFoundError: if ``root_path`` does not exist.
        """
        print(f"🔍 Discovering LibriSpeech files in {self.root_path}")

        if not self.root_path.exists():
            raise FileNotFoundError(f"LibriSpeech path not found: {self.root_path}")

        # Reset state so a second call does not append duplicate samples.
        self.audio_files = []
        self.transcripts = {}
        self.speakers = set()
        self.chapters = set()

        print("   Scanning for audio and transcript files...")
        # rglob yields files in filesystem-dependent order; sort so sample order
        # (and any subset or train/val split derived from it) is deterministic.
        flac_files = sorted(self.root_path.rglob("*.flac"))
        trans_files = sorted(self.root_path.rglob("*.trans.txt"))

        print(f"   Found {len(flac_files)} audio files (.flac)")
        print(f"   Found {len(trans_files)} transcript files (.trans.txt)")

        if not flac_files:
            print("❌ No .flac files found! Check your LibriSpeech path.")
            return

        if not trans_files:
            print("❌ No .trans.txt files found! Check your LibriSpeech path.")
            return

        self._load_transcripts(trans_files)
        self._match_audio_transcripts(flac_files)
        self._generate_statistics()

        print(f"✅ Successfully processed {len(self.audio_files)} audio-transcript pairs")

    def _load_transcripts(self, trans_files):
        """Read every transcript file into ``self.transcripts`` (upper-cased text).

        Lines that do not contain both an ID and text are reported and skipped;
        files that cannot be read or decoded are reported and skipped.
        """
        print("📖 Loading transcript files...")

        total_transcripts = 0
        failed_files = 0

        for trans_file in tqdm(trans_files, desc="Loading transcripts"):
            try:
                with open(trans_file, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        line = line.strip()
                        if not line:
                            continue
                        # Split on the first run of whitespace so tab- or
                        # multi-space-separated lines still parse cleanly.
                        parts = line.split(maxsplit=1)
                        if len(parts) == 2:
                            file_id, transcript = parts
                            # Normalise to upper case (LibriSpeech's own convention).
                            self.transcripts[file_id] = transcript.upper()
                            total_transcripts += 1
                        else:
                            print(f"⚠️ Malformed line {line_num} in {trans_file}")

            except (OSError, UnicodeDecodeError) as e:
                failed_files += 1
                print(f"❌ Error reading {trans_file}: {e}")

        print(f"📝 Loaded {total_transcripts} transcripts from {len(trans_files)} files")
        if failed_files > 0:
            print(f"⚠️ Failed to read {failed_files} transcript files")

    def _match_audio_transcripts(self, flac_files):
        """Append one sample dict to ``audio_files`` per ``.flac`` with a known transcript."""
        print("🔗 Matching audio files with transcripts...")

        matched = 0
        missing_transcripts = 0

        for flac_file in tqdm(flac_files, desc="Matching files"):
            file_id = flac_file.stem  # filename without the .flac extension

            if file_id not in self.transcripts:
                missing_transcripts += 1
                continue

            # The first two '-'-separated fields are the speaker and chapter IDs.
            id_parts = file_id.split('-')
            if len(id_parts) < 2:
                print(f"⚠️ Unexpected file ID format: {file_id}")
                continue

            speaker_id, chapter_id = id_parts[0], id_parts[1]
            self.audio_files.append({
                'audio_path': str(flac_file),
                'transcript': self.transcripts[file_id],
                'file_id': file_id,
                'speaker_id': speaker_id,
                'chapter_id': chapter_id
            })
            self.speakers.add(speaker_id)
            self.chapters.add(chapter_id)
            matched += 1

        print(f"✅ Matched {matched} audio files with transcripts")
        if missing_transcripts > 0:
            print(f"⚠️ {missing_transcripts} audio files had no matching transcripts")

    def _generate_statistics(self):
        """Print sample/speaker/chapter counts and the mean transcript length.

        The mean length is estimated from at most the first 1000 samples.
        """
        if not self.audio_files:
            return

        print("\n📊 Dataset Statistics:")
        print(f"   Total samples: {len(self.audio_files)}")
        print(f"   Unique speakers: {len(self.speakers)}")
        print(f"   Unique chapters: {len(self.chapters)}")

        transcript_lengths = [len(sample['transcript']) for sample in self.audio_files[:1000]]
        if transcript_lengths:
            avg_length = np.mean(transcript_lengths)
            print(f"   Avg transcript length: {avg_length:.1f} characters")

    def get_samples(self, max_samples=None):
        """Return the discovered samples, optionally truncated to the first ``max_samples``.

        Args:
            max_samples: ``None`` returns every sample (the ``audio_files`` list
                itself); a non-negative int returns a prefix of that length
                (``0`` returns an empty list).

        Raises:
            ValueError: if ``max_samples`` is negative.
        """
        if max_samples is None:
            return self.audio_files
        if max_samples < 0:
            raise ValueError(f"max_samples must be None or a non-negative integer, got {max_samples}")

        samples = self.audio_files[:max_samples]
        if max_samples < len(self.audio_files):
            print(f"📋 Using subset: {len(samples)} out of {len(self.audio_files)} total samples")

        return samples

    def get_speaker_info(self):
        """Return sorted speaker/chapter ID lists (string order) and their counts."""
        return {
            'speakers': sorted(self.speakers),
            'chapters': sorted(self.chapters),
            'total_speakers': len(self.speakers),
            'total_chapters': len(self.chapters)
        }

# Module-level results consumed by later cells; they stay empty if loading fails.
librispeech_samples = []
dataset_info = {}

if not librispeech_path.exists():
    print("❌ Cannot load LibriSpeech - path not found")
    print("   Please download LibriSpeech train-clean-100 and update the path in config")
else:
    print("\n🚀 Loading LibriSpeech dataset...")
    try:
        loader = LibriSpeechDatasetLoader(config.LIBRISPEECH_PATH)
        loader.discover_files()

        # Optionally restrict to a subset (config.MAX_SAMPLES) for quick experiments.
        librispeech_samples = loader.get_samples(max_samples=config.MAX_SAMPLES)
        dataset_info = loader.get_speaker_info()

        # Preview the first few transcripts, truncated to 80 characters.
        print(f"\n📝 Sample transcripts:")
        for i, sample in enumerate(librispeech_samples[:5]):
            transcript_preview = (
                sample['transcript']
                if len(sample['transcript']) <= 80
                else sample['transcript'][:80] + "..."
            )
            print(f"   {i+1}. {sample['file_id']}: '{transcript_preview}'")
            print(f"      Speaker: {sample['speaker_id']}, Chapter: {sample['chapter_id']}")

        print(f"\n✅ LibriSpeech loaded: {len(librispeech_samples)} samples ready")

    except Exception as e:
        print(f"❌ Failed to load LibriSpeech: {e}")
        librispeech_samples = []

# ==============================================================================
# CELL 4: Advanced Text Processing System
# ==============================================================================

class LibriSpeechTextProcessor:
    """Convert transcripts to integer token IDs and back.

    Supports character-level (default) or word-level tokenization. IDs 0-4 are
    the special tokens defined in ``__init__``. Content tokens are appended by
    ``build_vocabulary`` in descending corpus frequency. ``build_vocabulary``,
    ``save_vocabulary`` and ``load_vocabulary`` read or write the module-level
    ``config`` object.
    """

    # Compiled once and reused by _tokenize_text instead of per call.
    _DISALLOWED_CHARS_RE = re.compile(r'[^\w\s\'\-]')
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, use_char_level=True):
        """Create a processor whose vocabulary holds only the special tokens.

        Args:
            use_char_level: True for one token per character, False for one
                token per word (as split by ``_tokenize_text``).
        """
        # Special tokens for sequence modeling. '<SPACE>' is reserved but never
        # produced by this class: in character mode a space is looked up as the
        # literal ' ' character like any other character.
        self.special_tokens = {
            '<PAD>': 0,    # Padding token
            '<UNK>': 1,    # Unknown token
            '<SOS>': 2,    # Start of sequence
            '<EOS>': 3,    # End of sequence
            '<SPACE>': 4   # Reserved space token (unused, see above)
        }

        self.vocab = self.special_tokens.copy()      # token -> id
        self.vocab_size = len(self.special_tokens)   # token count; also the next id to assign
        self.id_to_token = {v: k for k, v in self.vocab.items()}

        self.use_char_level = use_char_level

        # Filled by the _build_*_vocabulary methods or by load_vocabulary.
        self.vocab_stats = {}

    def build_vocabulary(self, samples: List[dict], min_freq: int = 5, max_vocab_size: int = 5000):
        """Add tokens from ``sample['transcript']`` strings and set ``config.VOCAB_SIZE``.

        Args:
            samples: Dicts with a ``'transcript'`` key.
            min_freq: Minimum count for a word to be kept (word-level only).
            max_vocab_size: Upper bound on total vocabulary size, special tokens included.
        """
        print(f"📝 Building vocabulary from {len(samples)} samples...")
        print(f"   Tokenization: {'Character-level' if self.use_char_level else 'Word-level'}")
        print(f"   Min frequency: {min_freq}")
        print(f"   Max vocab size: {max_vocab_size}")

        if self.use_char_level:
            self._build_char_vocabulary(samples, max_vocab_size)
        else:
            self._build_word_vocabulary(samples, min_freq, max_vocab_size)

        # Keep the model configuration in sync with the real vocabulary size.
        config.VOCAB_SIZE = self.vocab_size

        print(f"✅ Vocabulary built: {self.vocab_size} tokens")
        self._print_vocab_stats()

    def _build_char_vocabulary(self, samples, max_vocab_size):
        """Add printable characters, most frequent first, until ``max_vocab_size`` is reached."""
        print("🔤 Building character-level vocabulary...")

        char_freq = defaultdict(int)
        total_chars = 0

        for sample in tqdm(samples, desc="Analyzing characters"):
            for char in sample['transcript']:
                char_freq[char] += 1
                total_chars += 1

        # Stable sort: ties keep first-seen order.
        sorted_chars = sorted(char_freq.items(), key=lambda x: x[1], reverse=True)
        added_chars = 0

        for char, _ in sorted_chars:
            if char.isprintable() and char not in self.vocab:
                if self.vocab_size >= max_vocab_size:
                    break
                self.vocab[char] = self.vocab_size
                self.vocab_size += 1
                added_chars += 1

        self.id_to_token = {v: k for k, v in self.vocab.items()}

        self.vocab_stats = {
            'total_chars_seen': len(char_freq),
            'total_char_instances': total_chars,
            'chars_added': added_chars,
            'coverage': added_chars / len(char_freq) if char_freq else 0
        }

        print(f"📊 Character analysis:")
        print(f"   Unique characters found: {len(char_freq)}")
        print(f"   Characters added to vocab: {added_chars}")
        print(f"   Coverage: {self.vocab_stats['coverage']:.1%}")

    def _build_word_vocabulary(self, samples, min_freq, max_vocab_size):
        """Add words seen at least ``min_freq`` times, most frequent first, up to ``max_vocab_size``."""
        print(f"📚 Building word-level vocabulary...")

        word_freq = defaultdict(int)
        total_words = 0

        for sample in tqdm(samples, desc="Analyzing words"):
            for word in self._tokenize_text(sample['transcript']):
                word_freq[word] += 1
                total_words += 1

        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        added_words = 0

        for word, freq in sorted_words:
            if freq >= min_freq and word not in self.vocab:
                if self.vocab_size >= max_vocab_size:
                    break
                self.vocab[word] = self.vocab_size
                self.vocab_size += 1
                added_words += 1

        self.id_to_token = {v: k for k, v in self.vocab.items()}

        words_above_threshold = sum(1 for f in word_freq.values() if f >= min_freq)
        self.vocab_stats = {
            'total_words_seen': len(word_freq),
            'total_word_instances': total_words,
            'words_above_threshold': words_above_threshold,
            'words_added': added_words,
            'coverage': added_words / words_above_threshold if words_above_threshold else 0
        }

        print(f"📊 Word analysis:")
        print(f"   Unique words found: {len(word_freq)}")
        print(f"   Words above threshold ({min_freq}): {words_above_threshold}")
        print(f"   Words added to vocab: {added_words}")
        print(f"   Coverage: {self.vocab_stats['coverage']:.1%}")

    def _tokenize_text(self, text: str) -> List[str]:
        """Split *text* into words after dropping characters other than word chars, whitespace, ``'`` and ``-``."""
        text = self._DISALLOWED_CHARS_RE.sub('', text)
        text = self._WHITESPACE_RE.sub(' ', text.strip())
        return text.split()

    def _lookup(self, token: str) -> int:
        """Return the ID of *token*, trying its upper-case form before falling back to '<UNK>'.

        Transcripts loaded by this notebook are upper-cased, so without the
        fallback ordinary lower/mixed-case input (e.g. "hello") would encode
        entirely as '<UNK>'. Tokens already in the vocabulary map exactly as before.
        """
        token_id = self.vocab.get(token)
        if token_id is None:
            token_id = self.vocab.get(token.upper(), self.vocab['<UNK>'])
        return token_id

    def text_to_ids(self, text: str) -> List[int]:
        """Encode *text* as ``[<SOS>, token ids..., <EOS>]``."""
        tokens = list(text) if self.use_char_level else self._tokenize_text(text)
        ids = [self._lookup(token) for token in tokens]
        return [self.vocab['<SOS>']] + ids + [self.vocab['<EOS>']]

    def ids_to_text(self, ids: List[int]) -> str:
        """Decode IDs to text, dropping <SOS>/<EOS>/<PAD>; unknown IDs render as '<UNK>'."""
        skipped = {'<SOS>', '<EOS>', '<PAD>'}
        tokens = [
            token for token in (self.id_to_token.get(id_val, '<UNK>') for id_val in ids)
            if token not in skipped
        ]
        # Characters are concatenated directly; words are space-separated.
        return ''.join(tokens) if self.use_char_level else ' '.join(tokens)

    def _print_vocab_stats(self):
        """Print vocabulary size and the statistics gathered during the last build."""
        print(f"\n📈 Vocabulary Statistics:")
        print(f"   Final vocabulary size: {self.vocab_size}")
        print(f"   Special tokens: {len(self.special_tokens)}")
        print(f"   Content tokens: {self.vocab_size - len(self.special_tokens)}")

        for key, value in self.vocab_stats.items():
            label = key.replace('_', ' ').title()
            if isinstance(value, float):
                print(f"   {label}: {value:.3f}")
            else:
                print(f"   {label}: {value:,}")

    def get_vocab_sample(self, n=20):
        """Print all special tokens and the first *n* content tokens.

        Relies on the special tokens occupying the first entries of ``self.vocab``
        (true for vocabularies built or loaded by this class).
        """
        vocab_items = list(self.vocab.items())
        special_count = len(self.special_tokens)

        print(f"\n🔍 Vocabulary Sample:")
        print("Special tokens:")
        for token, id_val in vocab_items[:special_count]:
            print(f"   {id_val:3d}: '{token}'")

        print("Content tokens (sample):")
        for token, id_val in vocab_items[special_count:special_count + n]:
            # repr() makes whitespace characters visible in char mode.
            display_token = repr(token) if self.use_char_level else token
            print(f"   {id_val:3d}: {display_token}")

        if len(vocab_items) > special_count + n:
            print(f"   ... and {len(vocab_items) - special_count - n} more")

    def save_vocabulary(self, filepath: str):
        """Write the vocabulary, its settings and a config snapshot to *filepath* as UTF-8 JSON.

        Note that JSON turns the integer keys of ``id_to_token`` into strings;
        ``load_vocabulary`` converts them back.
        """
        vocab_data = {
            'vocab': self.vocab,
            'vocab_size': self.vocab_size,
            'id_to_token': self.id_to_token,
            'use_char_level': self.use_char_level,
            'special_tokens': self.special_tokens,
            'vocab_stats': self.vocab_stats,
            'config_snapshot': {
                'sample_rate': config.SAMPLE_RATE,
                'max_seq_len': config.MAX_SEQ_LEN,
                'mel_dim': config.MEL_DIM
            }
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(vocab_data, f, indent=2, ensure_ascii=False)

        print(f"💾 Vocabulary saved to {filepath}")
        print(f"   Size: {os.path.getsize(filepath) / 1024:.1f} KB")

    def load_vocabulary(self, filepath: str):
        """Load a vocabulary written by ``save_vocabulary``.

        The file is fully parsed before any attribute is replaced, so a missing
        or malformed file leaves the current vocabulary untouched.

        Returns:
            True on success; False (after printing the error) if the file
            cannot be read or does not have the expected structure.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                vocab_data = json.load(f)

            vocab = vocab_data['vocab']
            vocab_size = vocab_data['vocab_size']
            # JSON object keys are always strings; restore the integer IDs.
            id_to_token = {int(k): v for k, v in vocab_data['id_to_token'].items()}
            use_char_level = vocab_data.get('use_char_level', True)
            special_tokens = vocab_data.get('special_tokens', self.special_tokens)
            vocab_stats = vocab_data.get('vocab_stats', {})
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
            print(f"❌ Failed to load vocabulary: {e}")
            return False

        # Commit only after everything parsed successfully.
        self.vocab = vocab
        self.vocab_size = vocab_size
        self.id_to_token = id_to_token
        self.use_char_level = use_char_level
        self.special_tokens = special_tokens
        self.vocab_stats = vocab_stats

        config.VOCAB_SIZE = self.vocab_size

        print(f"📖 Vocabulary loaded from {filepath}")
        print(f"   Size: {self.vocab_size}")
        print(f"   Type: {'Character-level' if self.use_char_level else 'Word-level'}")
        return True

# Build the tokenizer from the loaded transcripts; stays None when nothing was loaded.
text_processor = None

if not librispeech_samples:
    print("❌ Cannot build text processor - no LibriSpeech samples available")
else:
    print("\n🏗️ Building text processor for LibriSpeech...")

    use_char_level = True  # Set to False for word-level tokenization

    text_processor = LibriSpeechTextProcessor(use_char_level=use_char_level)
    # Frequency filtering only matters for words; every character is kept.
    text_processor.build_vocabulary(
        librispeech_samples,
        min_freq=1 if use_char_level else 5,
        max_vocab_size=5000,
    )

    # Round-trip the start of the first transcript as a quick smoke test
    # (the enclosing branch guarantees at least one sample exists).
    test_text = librispeech_samples[0]['transcript'][:100]
    test_ids = text_processor.text_to_ids(test_text)
    reconstructed = text_processor.ids_to_text(test_ids)

    print(f"\n🧪 Text Processor Test:")
    print(f"   Original: '{test_text}'")
    print(f"   Token IDs: {test_ids[:15]}... (showing first 15)")
    print(f"   Reconstructed: '{reconstructed}'")
    print(f"   Tokens match: {test_text.upper() == reconstructed.upper()}")

    text_processor.get_vocab_sample(n=15)
    text_processor.save_vocabulary("librispeech_vocab.json")

# ==============================================================================
# CELL 5: Utility Functions for Audio and Visualization
# ==============================================================================

def plot_mel_spectrogram(mel_spec, title="Mel-Spectrogram", figsize=(12, 4), save_path=None, n_mels=None):
    """Plot a mel-spectrogram as an image with mel bins on the y-axis.

    Args:
        mel_spec: 2-D array or torch tensor in ``[time, mel_dim]`` or
            ``[mel_dim, time]`` layout. Leading size-1 axes (e.g. a batch of
            one) are removed; tensors are detached and moved to CPU.
        title: Figure title.
        figsize: Matplotlib figure size.
        save_path: If given, the figure is also saved there at 150 dpi.
        n_mels: Number of mel bins, used to detect the layout. Defaults to
            ``config.MEL_DIM``.

    Raises:
        ValueError: if the input cannot be reduced to two dimensions.
    """
    if isinstance(mel_spec, torch.Tensor):
        mel_spec = mel_spec.detach().cpu().numpy()
    mel = np.asarray(mel_spec)

    # Drop leading singleton axes such as a batch dimension of size 1.
    while mel.ndim > 2 and mel.shape[0] == 1:
        mel = mel[0]
    if mel.ndim != 2:
        raise ValueError(f"Expected a 2-D mel-spectrogram, got shape {mel.shape}")

    if n_mels is None:
        n_mels = config.MEL_DIM
    # Input is treated as [time, mel_dim] and transposed, unless its first axis
    # is clearly the mel axis (matches n_mels while the second does not).
    # Ambiguous square inputs keep the [time, mel_dim] assumption.
    if mel.shape[0] == n_mels and mel.shape[1] != n_mels:
        image = mel
    else:
        image = mel.T

    plt.figure(figsize=figsize)
    plt.imshow(image, aspect='auto', origin='lower', cmap='viridis')
    # Colorbar labels assume log-scaled (dB) values — TODO confirm for callers.
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.ylabel('Mel Frequency Bin')
    plt.xlabel('Time Frame')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"📊 Plot saved to {save_path}")

    plt.show()

def play_audio(audio, sample_rate=16000, title="Audio", autoplay=False):
    """Display an inline notebook audio player for *audio*, peak-scaled to 0.9.

    Args:
        audio: 1-D waveform or 2-D ``[channels, samples]`` array. Lists and
            torch tensors are accepted; tensors are moved to CPU.
        sample_rate: Sample rate in Hz.
        title: Heading printed above the player.
        autoplay: Start playback immediately.

    Empty input prints a warning and shows no player.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    # Convert to float before abs(): np.abs on int16 overflows at -32768.
    audio = np.asarray(audio, dtype=np.float64)
    # The last axis holds samples (len() would count channels for 2-D input).
    num_samples = audio.shape[-1] if audio.ndim > 0 else 0

    print(f"🔊 {title}")
    print(f"   Duration: {num_samples/sample_rate:.2f}s, Samples: {num_samples:,}")

    if num_samples == 0:
        print("   ⚠️ Empty audio - nothing to play")
        return

    # Scale the peak to 0.9 to leave headroom. normalize=False stops IPython
    # from rescaling back to full scale (which would undo the headroom and
    # breaks on all-silent input).
    peak = np.max(np.abs(audio))
    audio_normalized = audio / peak * 0.9 if peak > 0 else audio

    display(Audio(audio_normalized, rate=sample_rate, autoplay=autoplay, normalize=False))

def plot_audio_waveform(audio, sample_rate=16000, title="Audio Waveform", figsize=(15, 4)):
    """Plot a 1-D waveform against time in seconds.

    Args:
        audio: 1-D sequence of samples.
        sample_rate: Sample rate in Hz; sample ``k`` is drawn at ``k / sample_rate`` s.
        title: Plot title.
        figsize: Matplotlib figure size.
    """
    plt.figure(figsize=figsize)
    # Sample k occurs at k / sample_rate. np.linspace(0, N/sr, N) would put
    # the last sample at N/sr instead of (N-1)/sr and stretch the axis.
    time_axis = np.arange(len(audio)) / sample_rate
    plt.plot(time_axis, audio)
    plt.title(title)
    plt.xlabel("Time (seconds)")
    plt.ylabel("Amplitude")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def show_training_progress(losses, val_losses=None, title="Training Progress", figsize=(12, 5)):
    """Plot training (and optionally validation) loss curves.

    With validation losses, two panels are drawn. The left panel shows both
    loss curves. Once more than 5 training losses exist, the right panel
    shows the mean of the last 5 train/val losses and their ratio; a ratio
    above 1.2 is flagged as possible overfitting. *title* is used as the
    figure title. Without validation losses, a single training-loss panel
    titled *title* is drawn.

    Args:
        losses: Sequence (list or 1-D array) of per-epoch training losses.
        val_losses: Optional sequence of validation losses; None or empty
            means "no validation data".
        title: Figure title.
        figsize: Matplotlib figure size.
    """
    # Explicit length test: plain truthiness raises for NumPy arrays.
    has_val = val_losses is not None and len(val_losses) > 0
    fig, axes = plt.subplots(1, 2 if has_val else 1, figsize=figsize)

    if has_val:
        loss_ax, status_ax = axes
        fig.suptitle(title)

        loss_ax.plot(losses, label='Training Loss', linewidth=2)
        loss_ax.plot(val_losses, label='Validation Loss', linewidth=2)
        loss_ax.set_title('Loss Progress')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True, alpha=0.3)

        if len(losses) > 5:
            recent_train = np.mean(losses[-5:])
            recent_val = np.mean(val_losses[-5:])
            # Ratio of recent validation to training loss.
            overfitting_indicator = recent_val / recent_train if recent_train > 0 else 1

            status_ax.text(0.1, 0.8, f"Recent Train Loss: {recent_train:.4f}", transform=status_ax.transAxes)
            status_ax.text(0.1, 0.7, f"Recent Val Loss: {recent_val:.4f}", transform=status_ax.transAxes)
            status_ax.text(0.1, 0.6, f"Overfitting Ratio: {overfitting_indicator:.3f}", transform=status_ax.transAxes)

            if overfitting_indicator > 1.2:
                status_ax.text(0.1, 0.5, "⚠️ Possible overfitting", transform=status_ax.transAxes, color='red')
            else:
                status_ax.text(0.1, 0.5, "✅ Training looks good", transform=status_ax.transAxes, color='green')

        status_ax.set_title('Training Status')
        status_ax.axis('off')
    else:
        ax = axes[0] if isinstance(axes, np.ndarray) else axes
        ax.plot(losses, label='Training Loss', linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def demo_dataset_samples(dataset, num_samples=3):
    """Print, play and plot the first few items of a TTS dataset.

    Expects ``dataset[i]`` to unpack into ``(text_tensor, mel_tensor,
    duration_tensor, text_len)``. Expects ``dataset.get_sample_info(i)`` to
    return a dict with ``file_id``, ``speaker_id``, ``chapter_id``,
    ``duration``, ``transcript`` and ``audio_path``. The first sample also
    gets a mel-spectrogram plot and a plot of its first 3 seconds of
    waveform. Errors for a sample are printed and that sample is skipped.
    """
    if dataset is None:
        print("❌ No dataset available for demo")
        return

    num_to_show = min(num_samples, len(dataset))
    print(f"🎵 LibriSpeech Dataset Sample Demo")
    print("=" * 60)
    print(f"Showing {num_to_show} samples from dataset")

    for i in range(num_to_show):
        print(f"\n📝 Sample {i+1}:")
        print("-" * 30)

        try:
            text_tensor, mel_tensor, duration_tensor, text_len = dataset[i]
            info = dataset.get_sample_info(i)

            print(f"File ID: {info['file_id']}")
            print(f"Speaker: {info['speaker_id']}")
            print(f"Chapter: {info['chapter_id']}")
            print(f"Duration: {info['duration']:.2f}s")
            print(f"Text: '{info['transcript']}'")
            print(f"Text length: {len(info['transcript'])} chars")
            print(f"Tokenized length: {text_len} tokens")
            print(f"Mel shape: {mel_tensor.shape}")
            print(f"Duration tensor shape: {duration_tensor.shape}")

            try:
                waveform, sr = torchaudio.load(info['audio_path'])
                # torchaudio.load returns [channels, samples]; average the
                # channels (exactly a no-op for mono) so playback and plotting
                # always receive a 1-D waveform, even for multi-channel files.
                waveform = waveform.mean(dim=0).numpy()

                if sr != config.SAMPLE_RATE:
                    waveform = librosa.resample(
                        waveform,
                        orig_sr=sr,
                        target_sr=config.SAMPLE_RATE
                    )

                print(f"Audio samples: {len(waveform):,}")
                play_audio(waveform, config.SAMPLE_RATE, f"Sample {i+1} - Original LibriSpeech Audio")

                # Visualisations for the first sample only.
                if i == 0:
                    print(f"\n📊 Mel-spectrogram visualization:")
                    plot_mel_spectrogram(
                        mel_tensor.numpy(),
                        f"Sample {i+1} - Mel-Spectrogram ({mel_tensor.shape[0]} frames)"
                    )
                    plot_audio_waveform(
                        waveform[:config.SAMPLE_RATE*3],  # First 3 seconds
                        config.SAMPLE_RATE,
                        f"Sample {i+1} - Audio Waveform (first 3s)"
                    )

            except Exception as audio_error:
                print(f"❌ Could not load/process audio: {audio_error}")

        except Exception as e:
            print(f"❌ Error processing sample {i+1}: {e}")

    print(f"\n✅ Dataset demo completed")

def test_text_processor(text_processor, test_texts=None):
    """Comprehensive test of text processor functionality"""
    if text_processor is None:
        print("❌ No text processor available for testing")
        return

    print(f"🧪 Text Processor Testing")
    print("=" * 40)

    # Default test texts if none provided
    if test_texts is None:
        test_texts = [
            "HELLO WORLD",
            "THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG",
            "TESTING ONE TWO THREE",
            "MACHINE LEARNING IS AMAZING"
        ]

    for i, text in enumerate(test_texts, 1):
        print(f"\nTest {i}: '{text}'")
        print("-" * 20)

        try:
            # Convert to IDs and back
            token_ids = text_processor.text_to_ids(text)
            reconstructed = text_processor.ids_to_text(token_ids)

            print(f"Original: '{text}'")
            print(f"Token IDs: {token_ids}")
            print(f"ID count: {len(token_ids)}")
            print(f"Reconstructed: '{reconstructed}'")

            # Check reconstruction quality
            if text.upper() == reconstructed.upper():
                print("✅ Perfect reconstruction")
            else:
                print("⚠️ Reconstruction differs from original")

        except Exception as e:
            print(f"❌ Error processing text: {e}")

    # Vocabulary statistics
    print(f"\n📊 Vocabulary Info:")
    print(f"Size: {text_processor.vocab_size}")
    print(f"Type: {'Character-level' if text_processor.use_char_level else 'Word-level'}")

    # Show some vocabulary samples
    text_processor.get_vocab_sample(n=10)

def validate_system_setup():
    """Print a pass/fail checklist of the pipeline's prerequisites.

    Reads the module-level ``librispeech_samples`` and ``text_processor``.
    CUDA and audio-library checks are informational only.

    Returns:
        True when both critical checks pass (LibriSpeech data loaded and a
        text processor with more than the special tokens), else False.
    """
    print(f"🔍 System Setup Validation")
    print("=" * 50)

    data_ok = bool(librispeech_samples)
    data_check = (
        ("LibriSpeech Data", True, f"{len(librispeech_samples)} samples loaded")
        if data_ok
        else ("LibriSpeech Data", False, "No samples loaded - check path")
    )

    processor_ok = bool(text_processor and text_processor.vocab_size > len(text_processor.special_tokens))
    processor_check = (
        ("Text Processor", True, f"Vocabulary size: {text_processor.vocab_size}")
        if processor_ok
        else ("Text Processor", False, "Not properly initialized")
    )

    if not torch.cuda.is_available():
        cuda_check = ("CUDA Support", False, "GPU not available - will use CPU")
    else:
        gpu_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        cuda_check = ("CUDA Support", True, f"GPU available: {gpu_gb:.1f}GB")

    # Re-import to confirm the audio stack is importable in this environment.
    try:
        import librosa
        import soundfile
        audio_check = ("Audio Libraries", True, "librosa and soundfile available")
    except ImportError as e:
        audio_check = ("Audio Libraries", False, f"Missing libraries: {e}")

    print("System Component Status:")
    for name, passed, detail in (data_check, processor_check, cuda_check, audio_check):
        print(f"   {'✅' if passed else '❌'} {name}: {detail}")

    ready = data_ok and processor_ok
    if not ready:
        print(f"\n⚠️ System has issues that need to be resolved before proceeding")
    else:
        print(f"\n🎉 System ready for TTS training and inference!")

    return ready

def get_system_info():
    """Return a flat dict summarising the runtime environment and key config values.

    Reads torch/CUDA state and the module-level ``config``,
    ``librispeech_samples`` and ``text_processor`` globals. When the samples or
    the text processor are falsy, their counts fall back to 0 and
    ``tokenization_type`` to ``'word'``. If CUDA is available, ``gpu_name`` and
    ``gpu_memory`` (total memory of device 0, in GB) are appended.
    """
    samples = librispeech_samples
    tp = text_processor

    info = dict(
        pytorch_version=torch.__version__,
        cuda_available=torch.cuda.is_available(),
        device=str(config.DEVICE),
        librispeech_samples=len(samples) if samples else 0,
        vocab_size=tp.vocab_size if tp else 0,
        tokenization_type='char' if tp and tp.use_char_level else 'word',
        sample_rate=config.SAMPLE_RATE,
        batch_size=config.BATCH_SIZE,
        hidden_dim=config.HIDDEN_DIM,
    )

    if not torch.cuda.is_available():
        return info

    info['gpu_name'] = torch.cuda.get_device_name()
    info['gpu_memory'] = torch.cuda.get_device_properties(0).total_memory / 1e9
    return info

# ==============================================================================
# CELL 6: Audio Processing and Dataset Implementation
# ==============================================================================

class LibriSpeechTTSDataset(Dataset):
    """LibriSpeech dataset yielding ``(text_ids, mel, durations, text_length)`` for TTS training.

    Each input sample is a dict. The keys read here are ``audio_path``,
    ``transcript``, ``speaker_id``, ``chapter_id`` and ``file_id``. At
    construction time every sample is probed with ``torchaudio.info``. A sample
    is kept only if its duration lies in ``[min_audio_length, max_audio_length]``
    seconds and its transcript has at most ``2 * config.MAX_SEQ_LEN``
    characters. Kept samples are shallow copies annotated with a ``duration``
    key, so the caller's dicts are not modified. Audio and feature parameters
    come from the module-level ``config`` object.
    """

    def __init__(self, samples: List[dict], text_processor,
                 max_audio_length: float = 15.0, min_audio_length: float = 1.0):
        self.samples = samples
        self.text_processor = text_processor
        self.max_audio_length = max_audio_length
        self.min_audio_length = min_audio_length

        # Audio processing parameters (all read from the module-level config)
        self.sample_rate = config.SAMPLE_RATE
        self.hop_length = config.HOP_LENGTH
        self.n_mels = config.MEL_DIM
        self.n_fft = config.N_FFT
        self.win_length = config.WIN_LENGTH
        self.f_min = config.F_MIN
        self.f_max = config.F_MAX
        self.max_seq_len = config.MAX_SEQ_LEN

        # Filter and process samples
        self.filtered_samples = self._filter_samples()

        # Statistics
        self.stats = self._compute_dataset_stats()

        print(f"📊 Dataset initialized:")
        print(f"   Total samples: {len(self.filtered_samples)}")
        print(f"   Filtered from: {len(samples)} original samples")
        # Percentage of input samples that were kept; guarded so an empty
        # input list does not raise ZeroDivisionError.
        print(f"   Filter rate: {self._keep_rate(len(self.filtered_samples), len(samples)):.1f}%")
        self._print_stats()

    @staticmethod
    def _keep_rate(kept: int, total: int) -> float:
        """Return ``kept / total`` as a percentage, or 0.0 when ``total`` is 0."""
        return kept / total * 100 if total else 0.0

    @staticmethod
    def _to_mono(waveform: torch.Tensor) -> torch.Tensor:
        """Collapse a ``[channels, time]`` waveform to ``[time]`` by averaging channels.

        For single-channel input this equals ``squeeze(0)``. Multi-channel files
        would otherwise stay 2-D and break mel-spectrogram extraction.
        """
        return waveform.mean(dim=0) if waveform.dim() > 1 else waveform

    def _filter_samples(self):
        """Return annotated copies of the samples that pass the duration and text-length filters.

        Any exception while probing a sample (unreadable file, missing key, ...)
        counts as a failed load. The first five such errors are printed.
        """
        print("🔍 Filtering samples by duration and text length...")

        filtered = []
        failed_loads = 0
        duration_filtered = 0
        text_filtered = 0

        for sample in tqdm(self.samples, desc="Filtering samples"):
            try:
                # Quick duration check from the file header, without decoding audio
                audio_info = torchaudio.info(sample['audio_path'])
                duration = audio_info.num_frames / audio_info.sample_rate

                # Duration filter
                if not (self.min_audio_length <= duration <= self.max_audio_length):
                    duration_filtered += 1
                    continue

                # Text length filter: character count is a rough proxy for token count
                text_len = len(sample['transcript'])
                if text_len > self.max_seq_len * 2:
                    text_filtered += 1
                    continue

                # Annotate a copy so the caller's sample dicts are left untouched.
                kept = dict(sample)
                kept['duration'] = duration
                filtered.append(kept)

            except Exception as e:
                failed_loads += 1
                if failed_loads <= 5:  # Show first few errors
                    print(f"⚠️ Failed to load {sample['audio_path']}: {e}")

        # Print filtering summary
        print(f"📋 Filtering Summary:")
        print(f"   Duration filtered: {duration_filtered}")
        print(f"   Text length filtered: {text_filtered}")
        print(f"   Failed loads: {failed_loads}")
        print(f"   Kept: {len(filtered)}")

        return filtered

    def _compute_dataset_stats(self):
        """Return summary statistics of the filtered samples, or ``{}`` when none were kept."""
        if not self.filtered_samples:
            return {}

        durations = [s['duration'] for s in self.filtered_samples]
        text_lengths = [len(s['transcript']) for s in self.filtered_samples]
        speakers = set(s['speaker_id'] for s in self.filtered_samples)

        stats = {
            'num_samples': len(self.filtered_samples),
            'num_speakers': len(speakers),
            'duration_mean': np.mean(durations),
            'duration_std': np.std(durations),
            'duration_min': np.min(durations),
            'duration_max': np.max(durations),
            'text_length_mean': np.mean(text_lengths),
            'text_length_std': np.std(text_lengths),
            'total_duration_hours': sum(durations) / 3600
        }

        return stats

    def _print_stats(self):
        """Print ``self.stats``; does nothing when the stats dict is empty."""
        if not self.stats:
            return

        print(f"\n📈 Dataset Statistics:")
        print(f"   Samples: {self.stats['num_samples']:,}")
        print(f"   Speakers: {self.stats['num_speakers']:,}")
        print(f"   Total duration: {self.stats['total_duration_hours']:.1f} hours")
        print(f"   Avg duration: {self.stats['duration_mean']:.1f}s ± {self.stats['duration_std']:.1f}s")
        print(f"   Duration range: {self.stats['duration_min']:.1f}s - {self.stats['duration_max']:.1f}s")
        print(f"   Avg text length: {self.stats['text_length_mean']:.0f} ± {self.stats['text_length_std']:.0f} chars")

    def __len__(self):
        return len(self.filtered_samples)

    def __getitem__(self, idx):
        """Load, featurize and tokenize sample ``idx``.

        Returns:
            tuple: ``(text_ids LongTensor [T_text], mel FloatTensor [T_mel, n_mels],
            durations FloatTensor [T_text], T_text)``. The durations are
            synthetic: an even split of the mel frames with random jitter,
            rescaled to sum to ``T_mel``. On any processing error, a 3-token /
            100-frame placeholder sample is returned instead.
        """
        sample = self.filtered_samples[idx]

        try:
            # Load, downmix to mono, and resample if necessary
            waveform, sample_rate = torchaudio.load(sample['audio_path'])
            waveform = self._to_mono(waveform).numpy()

            if sample_rate != self.sample_rate:
                waveform = librosa.resample(
                    waveform,
                    orig_sr=sample_rate,
                    target_sr=self.sample_rate
                )

            # Convert to mel-spectrogram [time, mel_dim]
            mel_spec = self._compute_mel_spectrogram(waveform)

            # Process text
            text = sample['transcript']
            text_ids = self.text_processor.text_to_ids(text)

            # Truncate if too long, keeping an EOS token at the end
            if len(text_ids) > self.max_seq_len:
                text_ids = text_ids[:self.max_seq_len-1] + [self.text_processor.vocab['<EOS>']]

            # Compute target durations (distribute mel frames across text tokens)
            mel_length = mel_spec.shape[0]
            text_length = len(text_ids)

            if text_length > 0:
                # Base duration per token
                base_duration = mel_length / text_length
                durations = torch.ones(text_length, dtype=torch.float) * base_duration

                if text_length > 2:
                    # First/last positions are assumed to be SOS/EOS and get shorter durations
                    durations[0] *= 0.5
                    durations[-1] *= 0.5

                    # Random per-token jitter (non-deterministic across epochs)
                    variation = torch.randn(text_length) * 0.3 + 1.0
                    variation = torch.clamp(variation, 0.3, 3.0)
                    durations *= variation

                # Normalize to match total mel length
                durations = durations * (mel_length / durations.sum())
            else:
                durations = torch.tensor([mel_length], dtype=torch.float)

            # Convert to tensors
            text_tensor = torch.tensor(text_ids, dtype=torch.long)
            mel_tensor = torch.tensor(mel_spec, dtype=torch.float)  # [time, mel_dim]

            return text_tensor, mel_tensor, durations, len(text_ids)

        except Exception as e:
            # Return placeholder data if sample processing fails
            print(f"❌ Error processing sample {idx} ({sample.get('file_id', 'unknown')}): {e}")

            vocab = self.text_processor.vocab
            # Accept either SOS spelling and fall back to <UNK>, so the error
            # handler cannot itself raise KeyError.
            sos_id = vocab.get('<SOS>', vocab.get('< SOS >', vocab['<UNK>']))
            dummy_frames = 100
            dummy_text = torch.tensor([sos_id, vocab['<UNK>'], vocab['<EOS>']], dtype=torch.long)
            dummy_mel = torch.zeros(dummy_frames, self.n_mels, dtype=torch.float)
            # Split the frames evenly so the durations sum to the mel length.
            dummy_duration = torch.full((3,), dummy_frames / 3, dtype=torch.float)

            return dummy_text, dummy_mel, dummy_duration, 3

    def _compute_mel_spectrogram(self, waveform: np.ndarray) -> np.ndarray:
        """Return a log-mel spectrogram of shape ``[time, n_mels]`` scaled to [-1, 1].

        Power is converted to dB relative to the utterance maximum, then min-max
        normalized per utterance. A constant spectrogram maps to all zeros.
        """
        mel_spec = librosa.feature.melspectrogram(
            y=waveform,
            sr=self.sample_rate,
            hop_length=self.hop_length,
            n_fft=self.n_fft,
            win_length=self.win_length,
            n_mels=self.n_mels,
            fmin=self.f_min,
            fmax=self.f_max
        )

        # Convert to log scale (dB)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # Per-utterance min-max normalization to [-1, 1] for training stability
        mel_min, mel_max = mel_spec.min(), mel_spec.max()
        if mel_max > mel_min:
            mel_spec = (mel_spec - mel_min) / (mel_max - mel_min) * 2 - 1
        else:
            mel_spec = np.zeros_like(mel_spec)

        return mel_spec.T  # [time, mel_dim]

    def get_sample_info(self, idx):
        """Return metadata for filtered sample ``idx``, or None when ``idx >= len(self)``."""
        if idx >= len(self.filtered_samples):
            return None

        sample = self.filtered_samples[idx]
        return {
            'file_id': sample['file_id'],
            'transcript': sample['transcript'],
            'audio_path': sample['audio_path'],
            'duration': sample.get('duration', 'unknown'),
            'speaker_id': sample['speaker_id'],
            'chapter_id': sample['chapter_id']
        }

    def get_random_samples(self, n=5):
        """Return metadata for up to ``n`` distinct randomly chosen samples."""
        if len(self.filtered_samples) == 0:
            return []

        indices = np.random.choice(len(self.filtered_samples), min(n, len(self.filtered_samples)), replace=False)
        return [self.get_sample_info(i) for i in indices]

def collate_fn(batch):
    """Merge dataset items into right-padded batch tensors.

    Each item is ``(text_ids, mel, durations, text_length)``. Text ids are
    padded with the ``<PAD>`` id of the module-level ``text_processor`` (0 when
    it is falsy). Durations are padded with 0, and mels are padded with 0 along
    their first (time) axis.

    Returns:
        tuple: ``(texts [batch, max_text_len], mels [batch, max_mel_len, mel_dim],
        durations [batch, max_text_len], text_lengths [batch])``
    """
    texts, mels, durations, lengths = zip(*batch)

    pad_id = text_processor.vocab['<PAD>'] if text_processor else 0
    right_pad = torch.nn.utils.rnn.pad_sequence

    return (
        right_pad(list(texts), batch_first=True, padding_value=pad_id),
        right_pad(list(mels), batch_first=True, padding_value=0),
        right_pad(list(durations), batch_first=True, padding_value=0),
        torch.tensor(lengths),
    )

def create_data_loaders(samples, text_processor, train_ratio=0.95, batch_size=None, num_workers=2):
    """Build a ``LibriSpeechTTSDataset`` and split it into train/validation DataLoaders.

    Args:
        samples: Sample dicts forwarded to ``LibriSpeechTTSDataset``.
        text_processor: Tokenizer forwarded to the dataset.
        train_ratio: Fraction of the filtered dataset used for training; the
            remainder (after ``int`` truncation) is used for validation.
        batch_size: Batch size for both loaders; defaults to ``config.BATCH_SIZE``.
        num_workers: DataLoader worker processes; workers persist across
            epochs when this is > 0.

    Returns:
        ``(train_loader, val_loader, dataset)``, or ``(None, None, None)`` when
        the filtered dataset is empty.

    The split uses a fixed generator seed (42), so it is reproducible. Training
    drops its last partial batch. Validation keeps it, so a validation split
    smaller than ``batch_size`` still produces batches instead of silently
    producing none.
    """
    if batch_size is None:
        batch_size = config.BATCH_SIZE

    print(f"📦 Creating data loaders...")
    print(f"   Batch size: {batch_size}")
    print(f"   Train ratio: {train_ratio}")
    print(f"   Num workers: {num_workers}")

    # Create dataset
    dataset = LibriSpeechTTSDataset(
        samples,
        text_processor,
        max_audio_length=config.MAX_AUDIO_LENGTH,
        min_audio_length=config.MIN_AUDIO_LENGTH
    )

    if len(dataset) == 0:
        print("❌ Dataset is empty - cannot create data loaders")
        return None, None, None

    # Split dataset into train and validation
    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = total_size - train_size

    print(f"📊 Dataset split:")
    print(f"   Total: {total_size}")
    print(f"   Training: {train_size}")
    print(f"   Validation: {val_size}")

    # Use fixed seed for reproducible splits
    generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,  # Consistent batch sizes for training
        persistent_workers=num_workers > 0
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        # Keep the final partial batch: with drop_last=True a validation split
        # smaller than batch_size would yield zero batches.
        drop_last=False,
        persistent_workers=num_workers > 0
    )

    print(f"✅ Data loaders created:")
    print(f"   Training batches: {len(train_loader)}")
    print(f"   Validation batches: {len(val_loader)}")
    print(f"   Samples per epoch: {len(train_loader) * batch_size}")
    if len(train_loader) == 0:
        print(f"⚠️ Training split ({train_size}) is smaller than batch size ({batch_size}); "
              f"no training batches will be produced")

    return train_loader, val_loader, dataset

# Build the dataset and train/validation loaders when the Part 1 inputs exist.
# The three names below are module-level globals read by later cells; they stay
# None when prerequisites are missing or construction fails.
train_dataloader = None
val_dataloader = None
dataset = None

# librispeech_samples and text_processor come from earlier cells of this notebook.
if librispeech_samples and text_processor:
    print("\n🚀 Creating LibriSpeech TTS dataset and data loaders...")

    try:
        train_dataloader, val_dataloader, dataset = create_data_loaders(
            librispeech_samples,
            text_processor,
            train_ratio=config.TRAIN_RATIO,
            batch_size=config.BATCH_SIZE,
            num_workers=2
        )

        # create_data_loaders returns (None, None, None) for an empty dataset.
        if dataset and len(dataset) > 0:
            print(f"\n📋 Dataset ready for training!")

            # Show sample entries (transcripts truncated to 80 characters)
            print(f"\nSample dataset entries:")
            random_samples = dataset.get_random_samples(3)
            for i, info in enumerate(random_samples, 1):
                transcript_preview = info['transcript'][:80] + "..." if len(info['transcript']) > 80 else info['transcript']
                print(f"   {i}. {info['file_id']}: '{transcript_preview}'")
                # NOTE(review): ':.2f' assumes a numeric duration; get_sample_info may
                # return the string 'unknown', which would raise here and be reported
                # below as a dataset-creation failure.
                print(f"      Speaker: {info['speaker_id']}, Duration: {info['duration']:.2f}s")

    except Exception as e:
        # Best-effort: report the failure with a traceback but keep the notebook running.
        print(f"❌ Failed to create dataset: {e}")
        import traceback
        traceback.print_exc()

else:
    print("❌ Cannot create dataset - missing LibriSpeech samples or text processor")
    print("   Please ensure LibriSpeech is loaded and text processor is initialized")

# ==============================================================================
# PART 1 SUMMARY AND VALIDATION
# ==============================================================================
# Final Part 1 report: run the validation helper, print the configuration
# snapshot, optionally demo the dataset and text processor, and list the
# Part 1 deliverables. `demo_dataset_samples` and `test_text_processor` are
# presumably defined earlier in this file (not visible in this chunk).

print("\n" + "="*80)
print("🎉 PART 1 COMPLETION SUMMARY")
print("="*80)

# Validate all components; True only when the data and text-processor checks pass
system_ready = validate_system_setup()

# Print system information (keys rendered as Title Case labels)
system_info = get_system_info()
print(f"\n📊 System Configuration:")
for key, value in system_info.items():
    print(f"   {key.replace('_', ' ').title()}: {value}")

# Demo dataset if available
if dataset and len(dataset) > 0:
    print(f"\n🎵 Running dataset demo...")
    demo_dataset_samples(dataset, num_samples=2)

# Test text processor if available
if text_processor:
    print(f"\n🧪 Testing text processor...")
    test_text_processor(text_processor)

# Next steps
# NOTE(review): the ✅ lines below print unconditionally, even when system_ready is False.
print(f"\n🚀 PART 1 COMPLETE - Ready for Part 2!")
print(f"Part 1 provided:")
print(f"   ✅ Complete LibriSpeech data loading ({len(librispeech_samples) if librispeech_samples else 0} samples)")
print(f"   ✅ Advanced text processing (vocab size: {text_processor.vocab_size if text_processor else 0})")
print(f"   ✅ Dataset implementation with preprocessing")
print(f"   ✅ Data loaders for training")
print(f"   ✅ Comprehensive utilities and validation")
print(f"\nNext: Part 2 will implement the neural network architecture")
print(f"      (Streaming components, attention mechanisms, duration prediction)")

if not system_ready:
    print(f"\n⚠️ IMPORTANT: Please resolve the system issues above before proceeding to Part 2")

print("="*80)

🎉 REAL-TIME STREAMING TTS SYSTEM - PART 1
   Foundation, Data Loading, and Text Processing
✅ All imports successful
PyTorch version: 2.6.0+cpu
CUDA available: False


RuntimeError: Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, maia, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: tpu

In [None]:
# Real-Time Streaming Text-to-Speech System - PART 2
# Neural Architecture: Streaming Components, Attention, and Model Implementation
# Built from scratch for LibriSpeech train-clean-100

# Part 2 banner: one print of newline-joined lines gives the same output as four prints.
print("\n".join([
    "=" * 80,
    "🧠 REAL-TIME STREAMING TTS SYSTEM - PART 2",
    "   Neural Architecture: Streaming Components & Attention Mechanisms",
    "=" * 80,
]))

# ==============================================================================
# CELL 7: Streaming Buffer and Core Components
# ==============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import queue
from typing import List, Optional, Tuple
import time

class StreamingBuffer:
    """Word-level sliding buffer that hands text to the model in overlapping chunks.

    Words accumulate in ``text_buffer``. Once at least ``buffer_size`` words are
    present, ``get_processing_chunk`` returns up to
    ``buffer_size + overlap_frames`` words and then advances the buffer by
    ``max(1, buffer_size - overlap_frames)`` words, so consecutive chunks share
    words. Despite its name, ``overlap_frames`` counts words.

    ``audio_buffer`` is a ``queue.Queue`` that this class only creates and
    drains in ``reset``. ``processing_times`` is never appended to within this
    class.
    """

    def __init__(self, buffer_size: int = 4, overlap_frames: int = 2):
        self.buffer_size = buffer_size
        self.overlap_frames = overlap_frames
        self.text_buffer = []
        self.audio_buffer = queue.Queue()
        self.processed_count = 0
        self.total_words_processed = 0

        # Performance tracking.
        # NOTE(review): buffer_states grows by one entry per word and is only cleared by reset().
        self.processing_times = []
        self.buffer_states = []

        for line in (f"🔄 StreamingBuffer initialized:",
                     f"   Buffer size: {buffer_size} words",
                     f"   Overlap frames: {overlap_frames} words"):
            print(line)

    def add_word(self, word: str) -> bool:
        """Append ``word`` (whitespace-stripped) and report whether a chunk is now available."""
        self.text_buffer.append(word.strip())
        is_ready = len(self.text_buffer) >= self.buffer_size

        # The raw, unstripped word is what gets recorded in the history.
        self.buffer_states.append(dict(
            buffer_length=len(self.text_buffer),
            word_added=word,
            ready_for_processing=is_ready,
        ))
        return is_ready

    def get_processing_chunk(self) -> List[str]:
        """Pop the next overlapping chunk, or return ``[]`` if fewer than ``buffer_size`` words are buffered."""
        available = len(self.text_buffer)
        if available < self.buffer_size:
            return []

        take = min(available, self.buffer_size + self.overlap_frames)
        chunk = self.text_buffer[:take]

        # Always advance by at least one word so the buffer cannot stall.
        advance = max(1, self.buffer_size - self.overlap_frames)
        self.text_buffer = self.text_buffer[advance:]

        self.processed_count += 1
        self.total_words_processed += advance
        return chunk

    def get_remaining_words(self) -> List[str]:
        """Return a copy of everything still buffered and empty the buffer."""
        leftover = list(self.text_buffer)
        self.text_buffer.clear()
        return leftover

    def reset(self):
        """Clear text, counters, tracking history and any queued audio."""
        self.text_buffer.clear()
        self.processed_count = 0
        self.total_words_processed = 0
        self.processing_times.clear()
        self.buffer_states.clear()

        # Drain the audio queue without blocking.
        while True:
            try:
                self.audio_buffer.get_nowait()
            except queue.Empty:
                break

    def get_stats(self):
        """Return counters plus mean processing time and words advanced per full-buffer slot."""
        chunks = self.processed_count
        times = self.processing_times
        return {
            'total_chunks_processed': chunks,
            'total_words_processed': self.total_words_processed,
            'current_buffer_size': len(self.text_buffer),
            'avg_processing_time': np.mean(times) if times else 0,
            'buffer_efficiency': self.total_words_processed / max(1, chunks * self.buffer_size)
        }

# ==============================================================================
# CELL 8: Causal Convolution and Attention Components
# ==============================================================================

class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t depends only on inputs at steps <= t.

    The wrapped ``nn.Conv1d`` pads ``dilation * (kernel_size - 1)`` zeros on both
    sides. ``forward`` then crops that many trailing output steps, so the output
    length equals the input length and no future samples leak in. Weights use
    Kaiming-normal (fan_out, ReLU) initialization and biases start at zero.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, dilation: int = 1, groups: int = 1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = dilation * (kernel_size - 1)

        conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            dilation=dilation, padding=self.padding, groups=groups
        )
        nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)
        self.conv = conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` (``[batch, channels, time]``) and drop the look-ahead tail."""
        y = self.conv(x)
        if self.padding <= 0:
            return y
        keep = y.size(2) - self.padding
        return y[:, :, :keep]

class DepthwiseSeparableConv1d(nn.Module):
    """Causal depthwise conv -> 1x1 pointwise conv -> BatchNorm1d -> GELU.

    The depthwise stage is a ``CausalConv1d`` with ``groups=in_channels``, so
    each channel is filtered independently. The pointwise stage then mixes
    channels into ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, dilation: int = 1):
        super().__init__()
        # Submodules are created in this order so parameter initialization
        # consumes the RNG in a fixed sequence.
        self.depthwise = CausalConv1d(
            in_channels, in_channels, kernel_size,
            dilation=dilation, groups=in_channels
        )
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.norm = nn.BatchNorm1d(out_channels)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the four stages in sequence to ``x`` (``[batch, channels, time]``)."""
        for stage in (self.depthwise, self.pointwise, self.norm, self.activation):
            x = stage(x)
        return x

class MultiHeadCausalAttention(nn.Module):
    """Multi-head attention with optional learned relative positions and streaming context.

    Queries come from the current chunk ``x``. Keys and values come from
    ``cat([context, x])`` when a context is supplied. Relative positions are
    measured between each query's absolute position (offset by the context
    length) and every key position, then clipped to
    ``±max_relative_position``. As a result, a chunk processed together with
    its context sees the same position terms it would see inside the full
    sequence.
    """

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.1,
                 use_relative_position: bool = True):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_relative_position = use_relative_position

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        # Learned relative-position tables, one row per clipped offset in [-M, M]
        if use_relative_position:
            self.max_relative_position = 32
            self.relative_position_k = nn.Parameter(
                torch.randn(2 * self.max_relative_position + 1, self.head_dim)
            )
            self.relative_position_v = nn.Parameter(
                torch.randn(2 * self.max_relative_position + 1, self.head_dim)
            )

        # Dropout
        self.dropout = nn.Dropout(dropout)
        self.attn_dropout = nn.Dropout(dropout)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for projections and position tables; zero output bias."""
        for module in [self.q_proj, self.k_proj, self.v_proj]:
            nn.init.xavier_uniform_(module.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

        if self.use_relative_position:
            nn.init.xavier_uniform_(self.relative_position_k)
            nn.init.xavier_uniform_(self.relative_position_v)

    def _get_relative_positions(self, seq_len: int, kv_seq_len: Optional[int] = None,
                                query_offset: int = 0) -> torch.Tensor:
        """Return relative-position table indices of shape ``[seq_len, kv_seq_len]``.

        Entry ``[i, j]`` is ``clamp((i + query_offset) - j, -M, M) + M``, where
        ``M = max_relative_position``. With the defaults (``kv_seq_len=seq_len``,
        ``query_offset=0``) this is the square self-attention table.
        """
        if kv_seq_len is None:
            kv_seq_len = seq_len
        query_positions = torch.arange(seq_len, dtype=torch.long) + query_offset
        key_positions = torch.arange(kv_seq_len, dtype=torch.long)
        relative_positions = query_positions[:, None] - key_positions[None, :]
        relative_positions = torch.clamp(
            relative_positions,
            -self.max_relative_position,
            self.max_relative_position
        ) + self.max_relative_position
        return relative_positions

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass with optional context for streaming.

        Args:
            x: Query input ``[batch, seq_len, hidden_dim]``.
            mask: Boolean mask, True where attention is forbidden. It is either
                ``[seq_len, seq_len]`` over the current chunk (context columns
                are then left unmasked) or already ``[seq_len, kv_seq_len]``.
            context: Optional earlier states ``[batch, context_len, hidden_dim]``
                prepended to the keys/values. Queries are treated as starting at
                absolute position ``context_len``.

        Returns:
            Tensor of shape ``[batch, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.shape

        # Combine with context if provided
        if context is not None:
            kv_input = torch.cat([context, x], dim=1)
            context_len = context.size(1)
        else:
            kv_input = x
            context_len = 0

        kv_seq_len = kv_input.size(1)

        # Project to Q, K, V
        q = self.q_proj(x)  # Only query from current input
        k = self.k_proj(kv_input)  # Key from context + input
        v = self.v_proj(kv_input)  # Value from context + input

        # Reshape for multi-head attention: [batch, heads, len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores: [batch, heads, seq_len, kv_seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        rel_pos_indices = None
        if self.use_relative_position:
            # The table must span every key position (context + chunk). A square
            # [seq_len, seq_len] table would not match the scores whenever
            # context is given.
            rel_pos_indices = self._get_relative_positions(
                seq_len, kv_seq_len, query_offset=context_len
            ).to(x.device)
            rel_pos_k = self.relative_position_k[rel_pos_indices]  # [seq_len, kv_seq_len, head_dim]
            scores = scores + torch.einsum('bhid,ijd->bhij', q, rel_pos_k)

        # Apply mask
        if mask is not None:
            if mask.size(-1) != kv_seq_len:
                # Chunk-local mask: prepend unmasked (False) columns for the context.
                expanded_mask = torch.zeros(seq_len, kv_seq_len, dtype=mask.dtype, device=mask.device)
                expanded_mask[:, context_len:] = mask
                mask = expanded_mask
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Apply attention
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)

        # Relative-position contribution to the values
        if rel_pos_indices is not None:
            rel_pos_v = self.relative_position_v[rel_pos_indices]  # [seq_len, kv_seq_len, head_dim]
            attn_output = attn_output + torch.einsum('bhij,ijd->bhid', attn_weights, rel_pos_v)

        # Merge heads and project output
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.hidden_dim
        )

        output = self.out_proj(attn_output)
        return self.dropout(output)

# ==============================================================================
# CELL 9: Streaming Transformer Layers
# ==============================================================================

class StreamingTransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then a GELU feed-forward.

    Each sub-layer sees a LayerNorm'd input and is added back residually. The
    feed-forward width defaults to ``4 * hidden_dim`` when ``ff_dim`` is falsy.
    """

    def __init__(self, hidden_dim: int, num_heads: int, ff_dim: Optional[int] = None,
                 dropout: float = 0.1, use_relative_position: bool = True):
        super().__init__()

        self.hidden_dim = hidden_dim
        inner_dim = ff_dim or hidden_dim * 4

        # Submodules are registered in a fixed order so initialization consumes the RNG identically.
        self.self_attention = MultiHeadCausalAttention(
            hidden_dim, num_heads, dropout, use_relative_position
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_dim),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Not used in forward(); kept so the module structure stays unchanged.
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for the feed-forward Linear layers."""
        linears = [m for m in self.feed_forward if isinstance(m, nn.Linear)]
        for lin in linears:
            nn.init.xavier_uniform_(lin.weight)
            if lin.bias is not None:
                nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor, causal: bool = True,
                context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run one block over ``x`` of shape ``[batch, seq_len, hidden_dim]``.

        Args:
            x: Input tensor.
            causal: If True, each position may attend only to itself and earlier
                positions of ``x`` (plus any ``context``).
            context: Optional earlier states forwarded to the attention layer.
        """
        future_mask = None
        if causal:
            n = x.size(1)
            # Strictly upper-triangular True entries mark future positions to hide.
            future_mask = torch.triu(torch.ones(n, n), diagonal=1).bool().to(x.device)

        x = x + self.self_attention(self.norm1(x), mask=future_mask, context=context)
        return x + self.feed_forward(self.norm2(x))

class StreamingTransformerEncoder(nn.Module):
    """Multi-layer streaming transformer encoder with context management.

    Runs a stack of ``StreamingTransformerLayer`` blocks followed by a final
    LayerNorm. ``forward`` also returns the last ``context_length`` frames of the
    normalised output (detached) so the caller can feed them back as ``context``
    for the next streaming chunk.

    NOTE(review): the same final-layer context is handed to every layer of the
    next chunk; a per-layer cache would be more faithful but changes the API.
    """

    def __init__(self, num_layers: int, hidden_dim: int, num_heads: int,
                 ff_dim: Optional[int] = None, dropout: float = 0.1,
                 use_relative_position: bool = True, context_length: int = 2):
        """
        Args:
            num_layers: Number of transformer layers.
            hidden_dim: Model width.
            num_heads: Attention heads per layer.
            ff_dim: Feed-forward width, forwarded to each layer.
            dropout: Dropout probability, forwarded to each layer.
            use_relative_position: Forwarded to each layer.
            context_length: Number of trailing frames kept as streaming context
                (must be >= 1; 0 would make ``x[:, -0:]`` select the whole sequence).
        Raises:
            ValueError: If ``context_length`` < 1.
        """
        super().__init__()

        if context_length < 1:
            raise ValueError(f"context_length must be >= 1, got {context_length}")

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.context_length = context_length  # Number of tokens to keep as context

        # Transformer layers
        self.layers = nn.ModuleList([
            StreamingTransformerLayer(
                hidden_dim, num_heads, ff_dim, dropout, use_relative_position
            )
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(hidden_dim)

        print(f"🏗️ StreamingTransformerEncoder:")
        print(f"   Layers: {num_layers}")
        print(f"   Hidden dim: {hidden_dim}")
        print(f"   Attention heads: {num_heads}")
        print(f"   Context length: {self.context_length}")

    def forward(self, x: torch.Tensor, causal: bool = True,
                context: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass through all transformer layers.

        Args:
            x: Input tensor [batch, seq_len, hidden_dim].
            causal: Whether each layer applies causal masking.
            context: Context carried over from the previous chunk, or None.
        Returns:
            output: Transformed (and final-normalised) representation.
            new_context: Detached last ``context_length`` frames of ``output``,
                or None when the chunk is shorter than ``context_length``.
        """
        for layer in self.layers:
            # Every layer receives the same previous-chunk context. Replacing it
            # with this chunk's own tail would hand later tokens of the current
            # chunk to earlier positions through the context path, outside the
            # causal mask.
            x = layer(x, causal=causal, context=context)

        x = self.final_norm(x)

        # Extract new context for the next streaming chunk.
        new_context = None
        if x.size(1) >= self.context_length:
            new_context = x[:, -self.context_length:, :].detach()

        return x, new_context

# ==============================================================================
# CELL 10: Duration Prediction and Alignment
# ==============================================================================

class DurationPredictor(nn.Module):
    """Predict a per-token duration (in mel frames) from encoder hidden states.

    Three 1-D convolution blocks (Conv -> GELU -> BatchNorm [-> Dropout]) with a
    residual connection wherever the channel count is unchanged, followed by a
    small MLP head. The raw score is squashed with a sigmoid into the interval
    ``(min_duration, max_duration)``. The defaults (1..25 frames) were chosen for
    16 kHz audio with a 200-sample hop.
    """

    def __init__(self, hidden_dim: int, filter_size: int = 256, kernel_size: int = 3,
                 dropout: float = 0.1, min_duration: float = 1.0,
                 max_duration: float = 25.0):
        """
        Args:
            hidden_dim: Channel size of the incoming hidden states.
            filter_size: Channel width of the convolution stack.
            kernel_size: Convolution kernel size; must be odd so that
                ``padding=kernel_size // 2`` preserves the sequence length.
            dropout: Dropout probability in the conv stack and the head.
            min_duration: Lower bound of the predicted duration range.
            max_duration: Upper bound of the predicted duration range.
        Raises:
            ValueError: If ``kernel_size`` is even or ``max_duration <= min_duration``.
        """
        super().__init__()

        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}")
        if max_duration <= min_duration:
            raise ValueError(
                f"max_duration ({max_duration}) must exceed min_duration ({min_duration})")

        self.hidden_dim = hidden_dim
        self.filter_size = filter_size
        self.min_duration = float(min_duration)
        self.max_duration = float(max_duration)

        padding = kernel_size // 2

        # Convolutional layers for duration prediction
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(hidden_dim, filter_size, kernel_size, padding=padding),
                nn.GELU(),
                nn.BatchNorm1d(filter_size),
                nn.Dropout(dropout)
            ),
            nn.Sequential(
                nn.Conv1d(filter_size, filter_size, kernel_size, padding=padding),
                nn.GELU(),
                nn.BatchNorm1d(filter_size),
                nn.Dropout(dropout)
            ),
            nn.Sequential(
                nn.Conv1d(filter_size, filter_size // 2, kernel_size, padding=padding),
                nn.GELU(),
                nn.BatchNorm1d(filter_size // 2)
            )
        ])

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(filter_size // 2, filter_size // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(filter_size // 4, 1)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init convolutions, Xavier-init linears, zero all their biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict duration for each token.

        Args:
            x: Hidden representation [batch, seq_len, hidden_dim]
        Returns:
            durations: Predicted durations [batch, seq_len], in
                ``(min_duration, max_duration)``.
        """
        # Conv1d expects [batch, channels, seq_len]
        x = x.transpose(1, 2)

        for conv_layer in self.conv_layers:
            residual = x
            x = conv_layer(x)
            # Residual connection only where the channel count is unchanged
            if residual.size(1) == x.size(1):
                x = x + residual

        # Back to [batch, seq_len, filter_size // 2]
        x = x.transpose(1, 2)

        raw = self.output_projection(x).squeeze(-1)

        # Squash into the configured duration range
        span = self.max_duration - self.min_duration
        return torch.sigmoid(raw) * span + self.min_duration

class LengthRegulator(nn.Module):
    """Expand per-token hidden states into a frame-level sequence using durations.

    Each token's hidden vector is repeated ``round(duration)`` times (clamped to
    ``[1, max_duration]``). Batch items shorter than the longest one are
    zero-padded at the end.
    """

    def __init__(self, max_duration: int = 100, max_output_len: int = 2000):
        """
        Args:
            max_duration: Upper bound on frames produced by a single token.
            max_output_len: Cap on output frames when no ``target_length`` is
                given (guards memory on runaway predictions).
        """
        super().__init__()
        self.max_duration = max_duration
        self.max_output_len = max_output_len

    def forward(self, hidden: torch.Tensor, durations: torch.Tensor,
                target_length: Optional[int] = None) -> torch.Tensor:
        """
        Regulate sequence length based on durations.

        Args:
            hidden: Hidden representations [batch, seq_len, hidden_dim]
            durations: Duration for each token [batch, seq_len]
            target_length: Optional target length for training. When given,
                durations are rescaled to sum to it and the output has exactly
                ``target_length`` frames (truncated or zero-padded to absorb
                rounding error); the ``max_output_len`` cap does not apply.
        Returns:
            expanded: Length-regulated sequence [batch, out_len, hidden_dim]
        """
        batch_size, seq_len, hidden_dim = hidden.shape

        # Clamp durations to reasonable bounds
        durations = torch.clamp(durations, min=0.5, max=self.max_duration)

        # If target length is provided (training), scale durations to sum to it
        if target_length is not None:
            total_duration = durations.sum(dim=1, keepdim=True)
            durations = durations * (target_length / (total_duration + 1e-8))

        # Integer frame counts per token (at least one frame each)
        frames = torch.clamp(torch.round(durations).long(), min=1, max=self.max_duration)

        if target_length is not None:
            # Rounding makes frames.sum() drift from target_length; honour the target
            # exactly so the output lines up with the ground-truth mel.
            out_len = int(target_length)
        elif batch_size == 0 or seq_len == 0:
            out_len = 0
        else:
            out_len = min(int(frames.sum(dim=1).max().item()), self.max_output_len)

        expanded = hidden.new_zeros(batch_size, out_len, hidden_dim)

        for b in range(batch_size):
            # One vectorised repeat per item instead of a per-token Python loop
            # with a host sync on every .item().
            repeated = torch.repeat_interleave(hidden[b], frames[b], dim=0)[:out_len]
            expanded[b, :repeated.size(0)] = repeated

        return expanded

# ==============================================================================
# CELL 11: Mel-Spectrogram Decoder
# ==============================================================================

class StreamingMelDecoder(nn.Module):
    """Decode encoder states into mel-spectrogram frames.

    Pipeline: prenet -> length regulation (durations) -> stack of dilated
    ``CausalConv1d`` layers, each with a residual add and LayerNorm ->
    convolutional postnet with a residual add -> linear projection to
    ``mel_dim`` channels.

    NOTE(review): the postnet convolutions use symmetric padding (kernel 5,
    padding 2), so that stage looks two frames ahead; only the CausalConv1d
    stack is causal.
    """

    def __init__(self, hidden_dim: int, mel_dim: int, prenet_dim: int = 256):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.mel_dim = mel_dim
        self.prenet_dim = prenet_dim

        # Bottleneck prenet: hidden -> prenet -> prenet -> hidden, heavy dropout.
        self.prenet = nn.Sequential(
            nn.Linear(hidden_dim, prenet_dim), nn.GELU(), nn.Dropout(0.5),
            nn.Linear(prenet_dim, prenet_dim), nn.GELU(), nn.Dropout(0.5),
            nn.Linear(prenet_dim, hidden_dim),
        )

        # Exponentially dilated causal convolutions (kernel 5).
        dilations = (1, 2, 4, 8)
        self.causal_convs = nn.ModuleList(
            [CausalConv1d(hidden_dim, hidden_dim, kernel_size=5, dilation=rate)
             for rate in dilations]
        )
        self.conv_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in dilations])

        # Postnet: three Conv -> GELU -> BatchNorm blocks, dropout between them.
        postnet_layers = []
        for block_idx in range(3):
            postnet_layers.extend([
                nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
                nn.GELU(),
                nn.BatchNorm1d(hidden_dim),
            ])
            if block_idx < 2:
                postnet_layers.append(nn.Dropout(0.1))
        self.postnet = nn.Sequential(*postnet_layers)

        self.mel_projection = nn.Linear(hidden_dim, mel_dim)
        self.length_regulator = LengthRegulator()

        self._init_weights()

        print(f"🎼 StreamingMelDecoder:")
        print(f"   Hidden dim: {hidden_dim}")
        print(f"   Mel dim: {mel_dim}")
        print(f"   Prenet dim: {prenet_dim}")
        print(f"   Causal conv layers: {len(self.causal_convs)}")

    def _init_weights(self):
        """Xavier-init linears, Kaiming-init 1-D convolutions, zero their biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            else:
                continue
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, hidden: torch.Tensor, durations: torch.Tensor,
                target_length: Optional[int] = None) -> torch.Tensor:
        """
        Turn token-level hidden states into a mel-spectrogram.

        Args:
            hidden: Text encoder output [batch, seq_len, hidden_dim]
            durations: Per-token durations [batch, seq_len]
            target_length: Optional mel length forwarded to the length regulator
        Returns:
            Mel-spectrogram [batch, mel_len, mel_dim]
        """
        frames = self.length_regulator(self.prenet(hidden), durations, target_length)

        # Channels-first for the convolution stack.
        feats = frames.transpose(1, 2)
        for causal_conv, layer_norm in zip(self.causal_convs, self.conv_norms):
            feats = causal_conv(feats) + feats
            # LayerNorm normalises the last dim, so swap channels there and back.
            feats = layer_norm(feats.transpose(1, 2)).transpose(1, 2)

        feats = feats + self.postnet(feats)

        return self.mel_projection(feats.transpose(1, 2))

# ==============================================================================
# CELL 12: Advanced Vocoder Architecture
# ==============================================================================

class ResidualBlock(nn.Module):
    """Residual block fusing parallel convolutions of several kernel sizes.

    Each branch is Conv1d -> BatchNorm -> LeakyReLU(0.2) with "same" padding.
    Branch outputs are concatenated on the channel axis, fused back to
    ``channels`` with a 1x1 convolution, normalised, added to the input, and
    passed through LeakyReLU(0.2).
    """

    def __init__(self, channels: int, kernel_sizes: Optional[List[int]] = None):
        """
        Args:
            channels: Number of input/output channels.
            kernel_sizes: Odd kernel sizes, one branch each; defaults to [3, 5, 7].
                (A ``None`` default avoids a shared mutable list default.)
        Raises:
            ValueError: If ``kernel_sizes`` is empty or contains an even size
                (``padding=k // 2`` would then change the sequence length and
                break the concatenation / residual add).
        """
        super().__init__()

        kernel_sizes = [3, 5, 7] if kernel_sizes is None else list(kernel_sizes)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        even_sizes = [k for k in kernel_sizes if k % 2 == 0]
        if even_sizes:
            raise ValueError(f"kernel_sizes must all be odd, got even sizes {even_sizes}")

        self.convs = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(channels),
                nn.LeakyReLU(0.2)
            )
            for kernel_size in kernel_sizes
        )

        self.fusion_conv = nn.Conv1d(channels * len(kernel_sizes), channels, 1)
        self.norm = nn.BatchNorm1d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all branches, fuse them, and add the residual. Shape is preserved."""
        fused = torch.cat([conv(x) for conv in self.convs], dim=1)
        fused = self.norm(self.fusion_conv(fused))
        return F.leaky_relu(fused + x, 0.2)

class LibriSpeechVocoder(nn.Module):
    """Mel-spectrogram -> waveform vocoder tuned for 16 kHz LibriSpeech.

    A 7-tap input projection is followed by one stage per upsample rate. Each
    stage is ConvTranspose1d -> BatchNorm -> LeakyReLU plus a ``ResidualBlock``
    with the matching channel count. Then come an averaged multi-receptive-field
    fusion and a small conv head ending in Tanh.

    ``prod(upsample_rates)`` must equal ``hop_length`` so that each mel frame
    yields exactly ``hop_length`` samples.
    """

    def __init__(self, mel_dim: int, hop_length: int = 200,
                 upsample_rates: Optional[List[int]] = None):
        """
        Args:
            mel_dim: Number of mel channels in the input.
            hop_length: Samples per mel frame.
            upsample_rates: Per-stage upsampling factors; defaults to
                [2, 2, 5, 5, 2] (200x) for a 200-sample hop.
        Raises:
            ValueError: If the product of ``upsample_rates`` != ``hop_length``.
        """
        super().__init__()

        self.mel_dim = mel_dim
        self.hop_length = hop_length

        # Default upsampling rates for 16kHz, 200 hop_length
        if upsample_rates is None:
            upsample_rates = [2, 2, 5, 5, 2]  # Total: 200x upsampling
        upsample_rates = list(upsample_rates)

        total_upsampling = int(np.prod(upsample_rates))
        if total_upsampling != hop_length:
            raise ValueError(
                f"prod(upsample_rates)={total_upsampling} must equal hop_length={hop_length}")

        self.upsample_rates = upsample_rates

        # Input projection
        self.input_proj = nn.Conv1d(mel_dim, 512, 7, padding=3)

        # Upsampling stages with decreasing channels; every stage gets its own
        # residual block so that forward() never drops a stage when zipping.
        channels = [512, 256, 128, 64, 32]
        upsample_layers = []
        res_blocks = []
        final_ch = channels[0]
        for i, rate in enumerate(upsample_rates):
            in_ch = channels[i] if i < len(channels) else 32
            out_ch = channels[i + 1] if i + 1 < len(channels) else 32

            upsample_layers.append(nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, rate * 2, rate, rate // 2),
                nn.BatchNorm1d(out_ch),
                nn.LeakyReLU(0.2)
            ))
            res_blocks.append(ResidualBlock(out_ch))
            final_ch = out_ch

        self.upsample_layers = nn.ModuleList(upsample_layers)
        self.res_blocks = nn.ModuleList(res_blocks)

        # Multi-receptive field fusion (on the last stage's channel count)
        self.mrf_blocks = nn.ModuleList([
            nn.Conv1d(final_ch, final_ch, 3, padding=1),
            nn.Conv1d(final_ch, final_ch, 5, padding=2),
            nn.Conv1d(final_ch, final_ch, 7, padding=3),
        ])

        # Output layers
        self.output_conv = nn.Sequential(
            nn.Conv1d(final_ch, 16, 7, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(16, 8, 7, padding=3),
            nn.LeakyReLU(0.2),
            nn.Conv1d(8, 1, 7, padding=3),
            nn.Tanh()
        )

        # Initialize weights
        self._init_weights()

        print(f"🔊 LibriSpeechVocoder:")
        print(f"   Mel dim: {mel_dim}")
        print(f"   Hop length: {hop_length}")
        print(f"   Upsample rates: {upsample_rates}")
        print(f"   Total upsampling: {np.prod(upsample_rates)}x")

    def _init_weights(self):
        """Kaiming-init (leaky_relu gain) all (transposed) convolutions; zero biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Convert mel-spectrogram to waveform.

        Args:
            mel: Mel-spectrogram [batch, mel_len, mel_dim]
        Returns:
            waveform: Generated audio [batch, mel_len * prod(upsample_rates)]
        """
        # [batch, mel_dim, mel_len] for convolution
        x = self.input_proj(mel.transpose(1, 2))

        for rate, upsample, res_block in zip(self.upsample_rates,
                                             self.upsample_layers, self.res_blocks):
            expected_len = x.size(-1) * rate
            # With kernel 2*rate and padding rate//2, an odd rate yields one extra
            # trailing sample; crop so every stage upsamples by exactly `rate`.
            x = upsample(x)[..., :expected_len]
            x = res_block(x)

        # Average fusion of different receptive fields
        mrf_outputs = [mrf_conv(x) for mrf_conv in self.mrf_blocks]
        x = sum(mrf_outputs) / len(mrf_outputs)

        waveform = self.output_conv(x)

        return waveform.squeeze(1)  # [batch, audio_len]

# ==============================================================================
# CELL 13: Complete Streaming TTS Model
# ==============================================================================

class StreamingTTSModel(nn.Module):
    """Complete streaming TTS model: text ids -> mel-spectrogram -> waveform.

    Components: token embedding + sinusoidal positional encoding,
    ``StreamingTransformerEncoder``, ``DurationPredictor``,
    ``StreamingMelDecoder`` and ``LibriSpeechVocoder``. Streaming state (the
    encoder context and a running offset into the positional table) lives on
    the instance between ``forward_streaming`` calls and is cleared by
    ``reset_streaming_state``.
    """

    def __init__(self, vocab_size: int, hidden_dim: int = 512,
                 num_heads: int = 8, num_layers: int = 6,
                 mel_dim: int = 80, max_seq_len: int = 500,
                 dropout: float = 0.1):
        super().__init__()

        # Model configuration
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.mel_dim = mel_dim
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers

        # Text embedding layer (index 0 is padding)
        self.text_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.embedding_dropout = nn.Dropout(dropout)

        # Positional encoding. The registered buffer follows .to(device) and is what
        # the forward passes read; the plain attribute is kept for compatibility.
        self.pos_encoding = self._create_positional_encoding(max_seq_len, hidden_dim)
        self.register_buffer('pos_encoding_buffer', self.pos_encoding)

        # Streaming transformer encoder
        self.encoder = StreamingTransformerEncoder(
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            use_relative_position=True
        )

        # Duration predictor
        self.duration_predictor = DurationPredictor(hidden_dim)

        # Mel-spectrogram decoder
        self.mel_decoder = StreamingMelDecoder(hidden_dim, mel_dim)

        # Vocoder for audio synthesis
        self.vocoder = LibriSpeechVocoder(mel_dim, hop_length=200)  # From config

        # Streaming state management
        self.streaming_context = None
        self.streaming_position = 0

        # Initialize model weights
        self._initialize_weights()

        # Calculate model size
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"🧠 StreamingTTSModel Complete:")
        print(f"   Vocabulary size: {vocab_size:,}")
        print(f"   Hidden dimension: {hidden_dim}")
        print(f"   Transformer layers: {num_layers}")
        print(f"   Attention heads: {num_heads}")
        print(f"   Mel dimensions: {mel_dim}")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")
        print(f"   Model size: {total_params * 4 / 1024**2:.1f} MB")

    def _create_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
        """Create a sinusoidal positional table of shape [1, max_len, d_model]."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        return pe.unsqueeze(0)  # [1, max_len, d_model]

    def _initialize_weights(self):
        """Re-initialise embeddings, Linear biases, and normalisation layers."""
        # Text embedding (padding row kept at zero)
        nn.init.normal_(self.text_embedding.weight, 0, 0.1)
        if hasattr(self.text_embedding, 'padding_idx') and self.text_embedding.padding_idx is not None:
            with torch.no_grad():
                self.text_embedding.weight[self.text_embedding.padding_idx].fill_(0)

        # Apply initialization to all modules
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm1d, nn.LayerNorm)):
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.ones_(module.weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.zeros_(module.bias)

    def reset_streaming_state(self):
        """Reset streaming state for new session"""
        self.streaming_context = None
        self.streaming_position = 0

    def _positional_slice(self, start: int, seq_len: int) -> torch.Tensor:
        """Return ``seq_len`` positional encodings [1, seq_len, hidden] starting near ``start``.

        The start is clamped to ``[0, max_seq_len - seq_len]`` so the window stays in
        the table; if ``seq_len`` exceeds the table, the last row is repeated.
        """
        start = max(0, min(start, self.max_seq_len - seq_len))
        pos_enc = self.pos_encoding_buffer[:, start:start + seq_len, :]
        missing = seq_len - pos_enc.size(1)
        if missing > 0:
            last_pos = self.pos_encoding_buffer[:, -1:, :]
            pos_enc = torch.cat([pos_enc, last_pos.expand(-1, missing, -1)], dim=1)
        return pos_enc

    def forward_streaming(self, text_ids: torch.Tensor,
                          return_alignments: bool = False,
                          target_mel_length: Optional[int] = None) -> torch.Tensor:
        """
        Streaming forward pass for real-time synthesis.

        Args:
            text_ids: Token IDs [batch, seq_len]
            return_alignments: Whether to also return duration predictions
            target_mel_length: Optional mel length forwarded to the decoder's
                length regulator (used for training alignment)
        Returns:
            mel_output, or ``(mel_output, durations)`` if return_alignments=True
        """
        _, seq_len = text_ids.shape

        # Text embedding + positional encoding at the current streaming offset
        text_emb = self.embedding_dropout(self.text_embedding(text_ids))
        text_emb = text_emb + self._positional_slice(self.streaming_position, seq_len)

        # Transformer encoding with streaming context
        hidden, new_context = self.encoder(
            text_emb,
            causal=True,
            context=self.streaming_context
        )

        # Update streaming state
        self.streaming_context = new_context
        self.streaming_position += max(1, seq_len - self.encoder.context_length)

        # Duration prediction
        durations = self.duration_predictor(hidden)

        # Mel-spectrogram generation
        mel_output = self.mel_decoder(hidden, durations, target_mel_length)

        if return_alignments:
            return mel_output, durations
        return mel_output

    def synthesize_streaming(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Convert mel-spectrogram to audio waveform"""
        return self.vocoder(mel_spec)

    def forward(self, text_ids: torch.Tensor,
                target_mel_length: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Standard (non-streaming) forward pass for training.

        Args:
            text_ids: Token IDs [batch, seq_len]
            target_mel_length: Target mel length for training alignment; passed
                through to the decoder so the output length follows it
        Returns:
            mel_output: Generated mel-spectrogram
            durations: Predicted durations
        """
        # Reset streaming state for non-streaming forward pass
        self.reset_streaming_state()
        return self.forward_streaming(text_ids, return_alignments=True,
                                      target_mel_length=target_mel_length)

    def inference(self, text_ids: torch.Tensor,
                  return_intermediate: bool = False) -> dict:
        """
        Complete inference pipeline (no gradients, eval mode).

        The caller's train/eval mode is restored afterwards.

        Args:
            text_ids: Token IDs [batch, seq_len]
            return_intermediate: Whether to return intermediate outputs
        Returns:
            Dictionary with outputs and optional intermediate results
        """
        was_training = self.training
        self.eval()

        try:
            with torch.no_grad():
                # Generate mel-spectrogram
                mel_output, durations = self.forward(text_ids)

                # Generate audio
                audio_output = self.synthesize_streaming(mel_output)

                results = {
                    'audio': audio_output,
                    'mel_spectrogram': mel_output,
                    'durations': durations
                }

                if return_intermediate:
                    text_emb = self.text_embedding(text_ids)
                    # Re-run the encoder the same way forward() just did (fresh state:
                    # positions from 0, no context, causal) so the reported encoder
                    # output is the one that produced the mel-spectrogram.
                    encoder_input = text_emb + self._positional_slice(0, text_ids.size(1))
                    hidden, _ = self.encoder(encoder_input, causal=True)

                    results.update({
                        'text_embedding': text_emb,
                        'encoder_output': hidden,
                        'text_length': text_ids.size(1),
                        'mel_length': mel_output.size(1),
                        'audio_length': audio_output.size(-1)
                    })
        finally:
            # Do not leave a model that was training stuck in eval mode.
            self.train(was_training)

        return results

    def get_model_stats(self) -> dict:
        """Get comprehensive model statistics"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        component_params = {}
        for name, module in self.named_children():
            component_params[name] = sum(p.numel() for p in module.parameters())

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / 1024**2,
            'component_parameters': component_params,
            'vocab_size': self.vocab_size,
            'hidden_dim': self.hidden_dim,
            'num_layers': self.num_layers,
            'mel_dim': self.mel_dim
        }

# ==============================================================================
# CELL 14: Model Testing and Validation
# ==============================================================================

def test_model_components():
    """Smoke-test each model component in isolation and print the results.

    Reads HIDDEN_DIM / MEL_DIM / DEVICE from a ``config`` module when it can be
    imported, otherwise falls back to built-in defaults. Each component test
    catches and reports its own failure so later components still run.
    """
    print("🧪 Testing Model Components")
    print("=" * 50)

    # Test parameters (assuming Part 1 components are available)
    try:
        from config import config  # From Part 1
    except ImportError:
        # Fallback values if config not available. Only a missing module is
        # handled here; a bare except would also swallow KeyboardInterrupt.
        hidden_dim = 512
        mel_dim = 80
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        hidden_dim = getattr(config, 'HIDDEN_DIM', 512)
        mel_dim = getattr(config, 'MEL_DIM', 80)
        device = getattr(config, 'DEVICE', torch.device('cpu'))

    batch_size = 2
    seq_len = 10

    print(f"Using device: {device}")
    print(f"Test parameters: batch={batch_size}, seq_len={seq_len}")

    # Test StreamingBuffer
    print("\n🔄 Testing StreamingBuffer...")
    buffer = StreamingBuffer(buffer_size=4, overlap_frames=2)
    test_words = ["hello", "world", "this", "is", "a", "test"]

    for word in test_words:
        ready = buffer.add_word(word)
        if ready:
            chunk = buffer.get_processing_chunk()
            print(f"   Processed chunk: {chunk}")

    print(f"   Buffer stats: {buffer.get_stats()}")

    # Test MultiHeadCausalAttention
    print("\n🔍 Testing MultiHeadCausalAttention...")
    attention = MultiHeadCausalAttention(hidden_dim, num_heads=8).to(device)

    test_input = torch.randn(batch_size, seq_len, hidden_dim).to(device)

    try:
        attn_output = attention(test_input)
        print(f"   Input shape: {test_input.shape}")
        print(f"   Output shape: {attn_output.shape}")
        print(f"   ✅ Attention test passed")
    except Exception as e:
        print(f"   ❌ Attention test failed: {e}")

    # Test StreamingTransformerLayer
    print("\n🏗️ Testing StreamingTransformerLayer...")
    transformer_layer = StreamingTransformerLayer(hidden_dim, num_heads=8).to(device)

    try:
        layer_output = transformer_layer(test_input)
        print(f"   Input shape: {test_input.shape}")
        print(f"   Output shape: {layer_output.shape}")
        print(f"   ✅ Transformer layer test passed")
    except Exception as e:
        print(f"   ❌ Transformer layer test failed: {e}")

    # Test DurationPredictor
    print("\n⏱️ Testing DurationPredictor...")
    duration_predictor = DurationPredictor(hidden_dim).to(device)

    try:
        duration_output = duration_predictor(test_input)
        print(f"   Input shape: {test_input.shape}")
        print(f"   Duration output shape: {duration_output.shape}")
        print(f"   Duration range: {duration_output.min().item():.2f} - {duration_output.max().item():.2f}")
        print(f"   ✅ Duration predictor test passed")
    except Exception as e:
        print(f"   ❌ Duration predictor test failed: {e}")

    # Test StreamingMelDecoder
    print("\n🎼 Testing StreamingMelDecoder...")
    mel_decoder = StreamingMelDecoder(hidden_dim, mel_dim).to(device)

    try:
        durations = torch.ones(batch_size, seq_len).to(device) * 10  # 10 frames per token
        mel_output = mel_decoder(test_input, durations)
        print(f"   Input shape: {test_input.shape}")
        print(f"   Duration shape: {durations.shape}")
        print(f"   Mel output shape: {mel_output.shape}")
        print(f"   ✅ Mel decoder test passed")
    except Exception as e:
        print(f"   ❌ Mel decoder test failed: {e}")

    # Test LibriSpeechVocoder
    print("\n🔊 Testing LibriSpeechVocoder...")
    vocoder = LibriSpeechVocoder(mel_dim).to(device)

    try:
        # Create test mel-spectrogram
        mel_len = 100
        test_mel = torch.randn(batch_size, mel_len, mel_dim).to(device)

        audio_output = vocoder(test_mel)
        expected_audio_len = mel_len * 200  # hop_length = 200

        print(f"   Mel input shape: {test_mel.shape}")
        print(f"   Audio output shape: {audio_output.shape}")
        print(f"   Expected audio length: ~{expected_audio_len}")
        print(f"   Actual audio length: {audio_output.size(-1)}")
        print(f"   ✅ Vocoder test passed")
    except Exception as e:
        print(f"   ❌ Vocoder test failed: {e}")

def test_complete_model():
    """Smoke-test the complete StreamingTTSModel end to end.

    Sizes the model from ``config.config`` when that module is importable and
    falls back to built-in defaults otherwise. Runs, in order: the standard
    forward pass, a chunked streaming forward pass, vocoder synthesis of the
    forward-pass mel, the full ``inference`` pipeline, and ``get_model_stats``.
    Each stage prints its own success/failure; a failed forward pass aborts the
    remaining stages because the synthesis stage consumes its mel output.
    """
    print("\n🚀 Testing Complete StreamingTTSModel")
    print("=" * 50)

    # Use the same CUDA-when-available default whether or not a config module
    # exists, so the test device does not depend on which fallback was taken.
    default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        from config import config
    except ImportError:
        vocab_size = 1000
        hidden_dim = 512
        device = default_device
    else:
        vocab_size = getattr(config, 'VOCAB_SIZE', 1000)
        hidden_dim = getattr(config, 'HIDDEN_DIM', 512)
        device = getattr(config, 'DEVICE', default_device)

    # Create model
    model = StreamingTTSModel(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        num_heads=8,
        num_layers=6,
        mel_dim=80
    ).to(device)

    # Test data: token ids start at 1 (0 is presumably padding -- TODO confirm)
    batch_size = 2
    seq_len = 15
    test_text_ids = torch.randint(1, vocab_size, (batch_size, seq_len)).to(device)

    print(f"Model device: {next(model.parameters()).device}")
    print(f"Test input shape: {test_text_ids.shape}")

    # Test standard forward pass
    print("\n📝 Testing standard forward pass...")
    try:
        model.eval()
        with torch.no_grad():
            mel_output, durations = model.forward(test_text_ids)

        print("   ✅ Forward pass successful")
        print(f"   Mel output shape: {mel_output.shape}")
        print(f"   Duration shape: {durations.shape}")
        print(f"   Duration stats: min={durations.min():.2f}, max={durations.max():.2f}, mean={durations.mean():.2f}")

    except Exception as e:
        print(f"   ❌ Forward pass failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Test streaming forward pass: feed the same ids in fixed-size chunks after
    # clearing any state left over from a previous stream.
    print("\n🔄 Testing streaming forward pass...")
    try:
        model.reset_streaming_state()

        chunk_size = 5
        all_mels = []

        for i in range(0, seq_len, chunk_size):
            chunk = test_text_ids[:, i:i+chunk_size]
            if chunk.size(1) > 0:
                mel_chunk = model.forward_streaming(chunk)
                all_mels.append(mel_chunk)
                print(f"   Chunk {i//chunk_size + 1}: input {chunk.shape} -> mel {mel_chunk.shape}")

        print("   ✅ Streaming forward pass successful")

    except Exception as e:
        print(f"   ❌ Streaming forward pass failed: {e}")
        import traceback
        traceback.print_exc()

    # Test audio synthesis from the forward-pass mel
    print("\n🔊 Testing audio synthesis...")
    try:
        audio_output = model.synthesize_streaming(mel_output)

        expected_length = mel_output.size(1) * 200  # hop_length
        print("   ✅ Audio synthesis successful")
        print(f"   Audio output shape: {audio_output.shape}")
        print(f"   Expected length: ~{expected_length}")
        print(f"   Actual length: {audio_output.size(-1)}")
        print(f"   Audio stats: min={audio_output.min():.3f}, max={audio_output.max():.3f}")

    except Exception as e:
        print(f"   ❌ Audio synthesis failed: {e}")

    # Test complete inference
    print("\n🎯 Testing complete inference pipeline...")
    try:
        results = model.inference(test_text_ids, return_intermediate=True)

        print("   ✅ Complete inference successful")
        print(f"   Results keys: {list(results.keys())}")
        print(f"   Audio shape: {results['audio'].shape}")
        print(f"   Mel shape: {results['mel_spectrogram'].shape}")
        print(f"   Text length: {results['text_length']}")
        print(f"   Mel length: {results['mel_length']}")
        print(f"   Audio length: {results['audio_length']}")

    except Exception as e:
        print(f"   ❌ Complete inference failed: {e}")

    def _format_stat(value):
        """Render a stat: floats to 2 dp, ints with thousands separators, else str()."""
        if isinstance(value, float):
            return f"{value:.2f}"
        if isinstance(value, int) and not isinstance(value, bool):
            return f"{value:,}"
        # Strings, devices, tensors, etc. cannot take the ',' format spec.
        return str(value)

    # Model statistics
    print("\n📊 Model Statistics:")
    try:
        stats = model.get_model_stats()
    except Exception as e:
        print(f"   ❌ Model statistics failed: {e}")
        return

    for key, value in stats.items():
        if key == 'component_parameters':
            print("   Component parameters:")
            for comp, params in value.items():
                print(f"     {comp}: {_format_stat(params)}")
        else:
            print(f"   {key}: {_format_stat(value)}")

def validate_model_architecture():
    """Check that a small StreamingTTSModel exposes the expected components.

    Builds a minimal model (vocab 100, hidden 64, 2 layers) and probes it for
    the attributes/methods backing each requirement. Each probe runs in its own
    ``try`` so that one failing probe (e.g. an encoder without ``layers``) does
    not cause the unrelated requirements after it to be reported as missing.
    These are presence checks only (``hasattr``); they do not verify behavior.

    Returns:
        bool: True only if every requirement's probe succeeded.
    """
    print("\n✅ Model Architecture Validation")
    print("=" * 50)

    requirements = {
        'streaming_capability': False,
        'causal_attention': False,
        'duration_prediction': False,
        'mel_generation': False,
        'audio_synthesis': False,
        'context_management': False
    }

    # One presence probe per requirement, evaluated independently below.
    probes = {
        'streaming_capability': lambda m: hasattr(m, 'reset_streaming_state') and hasattr(m, 'forward_streaming'),
        # Only checks that the first encoder layer has a self_attention module.
        'causal_attention': lambda m: hasattr(m.encoder.layers[0], 'self_attention'),
        'duration_prediction': lambda m: hasattr(m, 'duration_predictor'),
        'mel_generation': lambda m: hasattr(m, 'mel_decoder'),
        'audio_synthesis': lambda m: hasattr(m, 'vocoder') and hasattr(m, 'synthesize_streaming'),
        'context_management': lambda m: hasattr(m, 'streaming_context'),
    }

    try:
        # Create minimal model for testing
        model = StreamingTTSModel(vocab_size=100, hidden_dim=64, num_layers=2)
    except Exception as e:
        print(f"❌ Architecture validation failed: {e}")
        model = None

    if model is not None:
        for req, probe in probes.items():
            try:
                requirements[req] = bool(probe(model))
            except Exception as e:
                print(f"❌ Architecture check '{req}' failed: {e}")

    print("Architecture Requirements:")
    for req, status in requirements.items():
        status_icon = "✅" if status else "❌"
        print(f"   {status_icon} {req.replace('_', ' ').title()}")

    all_passed = all(requirements.values())
    if all_passed:
        print("\n🎉 All architecture requirements met!")
    else:
        print("\n⚠️ Some requirements not met - check implementation")

    return all_passed

# ==============================================================================
# PART 2 SUMMARY AND NEXT STEPS
# ==============================================================================

print("\n" + "=" * 80)
print("🎉 PART 2 COMPLETION SUMMARY")
print("=" * 80)

# Smoke-test the individual layers, then the assembled model, then confirm the
# expected architectural hooks exist; the last result gates the verdict below.
test_model_components()
test_complete_model()
architecture_valid = validate_model_architecture()

# A single print of newline-joined lines emits exactly the same text as one
# print() per line.
print("\n".join([
    "\n🎯 PART 2 COMPLETE - Neural Architecture Ready!",
    "Part 2 delivered:",
    "   ✅ Streaming buffer with overlap management",
    "   ✅ Causal convolutions for streaming",
    "   ✅ Multi-head causal attention with relative positioning",
    "   ✅ Streaming transformer encoder with context",
    "   ✅ Advanced duration predictor",
    "   ✅ Streaming mel-spectrogram decoder",
    "   ✅ High-quality vocoder for LibriSpeech",
    "   ✅ Complete integrated streaming TTS model",
    "   ✅ Comprehensive testing and validation",
    "\nModel Capabilities:",
    "   🔄 Real-time word-by-word processing",
    "   🧠 Context-aware streaming inference",
    "   ⏱️  Duration-based alignment",
    "   🎼 High-quality mel-spectrogram generation",
    "   🔊 16kHz audio synthesis optimized for LibriSpeech",
]))

if not architecture_valid:
    print("\n⚠️ Architecture validation: ISSUES FOUND")
    print("   Please resolve issues before proceeding to Part 3")
else:
    print("\n✅ Architecture validation: PASSED")
    print("🚀 Ready for Part 3: Training Pipeline & Optimization")

print("\n".join([
    "\nNext: Part 3 will implement:",
    "   - Advanced training pipeline with mixed precision",
    "   - Loss functions and optimization strategies",
    "   - Validation and checkpointing",
    "   - Performance monitoring and debugging tools",
]))

print("=" * 80)

In [None]:
# Real-Time Streaming Text-to-Speech System - PART 3
# Training Pipeline, Optimization, and Performance Monitoring
# Built from scratch for LibriSpeech train-clean-100

# Section banner: rule, title, subtitle, rule -- one line each.
print("\n".join([
    "=" * 80,
    "🚀 REAL-TIME STREAMING TTS SYSTEM - PART 3",
    "   Training Pipeline, Optimization & Performance Monitoring",
    "=" * 80,
]))

# ==============================================================================
# CELL 15: Advanced Loss Functions and Metrics
# ==============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import time
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict, deque
import json
import os
from pathlib import Path
from tqdm import tqdm
import math

class TTSLossComputer:
    """Advanced loss computation for TTS training with multiple objectives.

    ``compute_total_loss`` combines four weighted terms:

    * mel loss         -- 0.7 * L1 + 0.3 * L2 between predicted and target mels
    * duration loss    -- MSE between predicted and target per-token durations
    * alignment loss   -- L1 between summed predicted durations and the target
                          mel length
    * consistency loss -- mean absolute frame-to-frame difference of the
                          predicted mel (a smoothness regulariser)

    Every ``compute_total_loss`` call appends each component's scalar value to
    ``loss_history``.
    """

    def __init__(self, mel_loss_weight: float = 1.0, duration_loss_weight: float = 0.1,
                 alignment_loss_weight: float = 0.05, consistency_loss_weight: float = 0.02):
        self.mel_loss_weight = mel_loss_weight
        self.duration_loss_weight = duration_loss_weight
        self.alignment_loss_weight = alignment_loss_weight
        self.consistency_loss_weight = consistency_loss_weight

        # Loss functions (used on the unmasked paths)
        self.mel_loss_fn = nn.L1Loss(reduction='mean')
        self.mel_loss_mse = nn.MSELoss(reduction='mean')
        self.duration_loss_fn = nn.MSELoss(reduction='mean')
        self.alignment_loss_fn = nn.L1Loss(reduction='mean')

        # Loss history for monitoring: name -> list of Python floats
        self.loss_history = defaultdict(list)

        print(f"🎯 TTSLossComputer initialized:")
        print(f"   Mel loss weight: {mel_loss_weight}")
        print(f"   Duration loss weight: {duration_loss_weight}")
        print(f"   Alignment loss weight: {alignment_loss_weight}")
        print(f"   Consistency loss weight: {consistency_loss_weight}")

    @staticmethod
    def _length_mask(lengths: torch.Tensor, max_len: int, device) -> torch.Tensor:
        """Return a bool mask of shape (batch, max_len), True where index < length.

        ``arange`` is created directly on ``device`` and ``lengths`` is moved
        there, so the comparison never mixes CPU and CUDA tensors.
        """
        positions = torch.arange(max_len, device=device)
        return positions.unsqueeze(0) < lengths.to(device).unsqueeze(1)

    def compute_mel_loss(self, predicted_mel: torch.Tensor, target_mel: torch.Tensor,
                        mel_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the mel-spectrogram loss (0.7 * L1 + 0.3 * L2).

        Both mels are truncated to the shorter length along dim 1 before
        comparison. When ``mel_lengths`` is given, frames at or beyond each
        sample's length are excluded, and the losses are averaged over the
        valid elements only so padding does not dilute the loss.

        Args:
            predicted_mel: indexed as (batch, time, mel_bins).
            target_mel: same layout as ``predicted_mel``.
            mel_lengths: optional per-sample valid frame counts, shape (batch,).

        Returns:
            Scalar loss tensor.
        """
        # Align sequences to same length
        min_len = min(predicted_mel.size(1), target_mel.size(1))
        pred_aligned = predicted_mel[:, :min_len, :]
        target_aligned = target_mel[:, :min_len, :]

        if mel_lengths is None:
            l1_loss = self.mel_loss_fn(pred_aligned, target_aligned)
            l2_loss = self.mel_loss_mse(pred_aligned, target_aligned)
            return 0.7 * l1_loss + 0.3 * l2_loss

        # (batch, time, 1) so the mask broadcasts across the mel-bin axis.
        mask = self._length_mask(mel_lengths, min_len, pred_aligned.device).unsqueeze(-1)
        # Number of valid scalar elements = valid frames * mel bins.
        valid_elements = mask.sum() * pred_aligned.size(-1)
        if valid_elements == 0:
            return torch.tensor(0.0, device=pred_aligned.device)

        diff = (pred_aligned - target_aligned) * mask
        l1_loss = diff.abs().sum() / valid_elements
        l2_loss = diff.pow(2).sum() / valid_elements
        return 0.7 * l1_loss + 0.3 * l2_loss

    def compute_duration_loss(self, predicted_durations: torch.Tensor,
                            target_durations: torch.Tensor,
                            text_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute duration MSE, averaged over valid tokens when lengths are given.

        Args:
            predicted_durations: (batch, tokens) predicted durations.
            target_durations: (batch, tokens) target durations; may be integer.
            text_lengths: optional per-sample valid token counts, shape (batch,).

        Returns:
            Scalar float32 loss tensor (0.0 if no token is valid).
        """
        # Work in float32 so integer targets (or fp16 predictions) do not
        # cause a dtype mismatch inside mse_loss.
        predicted = predicted_durations.float()
        target = target_durations.float()

        if text_lengths is None:
            return self.duration_loss_fn(predicted, target)

        mask = self._length_mask(text_lengths, predicted.size(1), predicted.device)
        valid_elements = mask.sum()
        if valid_elements == 0:
            return torch.tensor(0.0, device=predicted.device)

        # Sum over masked positions, then divide by the valid count only.
        return F.mse_loss(predicted * mask, target * mask, reduction='sum') / valid_elements

    def compute_alignment_loss(self, predicted_durations: torch.Tensor,
                             target_mel_length: torch.Tensor,
                             text_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
        """L1 between each sample's total predicted duration and its mel length.

        Args:
            predicted_durations: (batch, tokens) predicted durations.
            target_mel_length: (batch,) target frame counts.
            text_lengths: optional (batch,) valid token counts; durations of
                padded tokens are excluded from the per-sample sum when given.

        Returns:
            Scalar float32 loss tensor.
        """
        durations = predicted_durations.float()
        if text_lengths is not None:
            durations = durations * self._length_mask(text_lengths, durations.size(1), durations.device)
        predicted_total_length = durations.sum(dim=1)
        target_total_length = target_mel_length.to(durations.device).float()

        return self.alignment_loss_fn(predicted_total_length, target_total_length)

    def compute_consistency_loss(self, predicted_mel: torch.Tensor) -> torch.Tensor:
        """Mean absolute difference between consecutive frames (dim 1)."""
        if predicted_mel.size(1) < 2:
            return torch.tensor(0.0, device=predicted_mel.device)

        # Differences between consecutive frames
        mel_diff = predicted_mel[:, 1:, :] - predicted_mel[:, :-1, :]

        # Penalize large differences (encourage smoothness)
        return torch.mean(torch.abs(mel_diff))

    def compute_total_loss(self, predicted_mel: torch.Tensor, target_mel: torch.Tensor,
                          predicted_durations: torch.Tensor, target_durations: torch.Tensor,
                          text_lengths: Optional[torch.Tensor] = None,
                          mel_lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Compute every loss term and the weighted total.

        The alignment target is ``mel_lengths`` when given, otherwise the
        padded mel length ``target_mel.size(1)`` for every sample.

        Returns:
            Dict with keys ``total_loss``, ``mel_loss``, ``duration_loss``,
            ``alignment_loss`` and ``consistency_loss`` (scalar tensors). Each
            value is also appended to ``loss_history`` as a Python float.
        """
        mel_loss = self.compute_mel_loss(predicted_mel, target_mel, mel_lengths)
        duration_loss = self.compute_duration_loss(predicted_durations, target_durations, text_lengths)

        if mel_lengths is not None:
            target_mel_length = mel_lengths
        else:
            target_mel_length = torch.full((target_mel.size(0),), target_mel.size(1),
                                           dtype=torch.float32, device=target_mel.device)
        alignment_loss = self.compute_alignment_loss(predicted_durations, target_mel_length, text_lengths)

        consistency_loss = self.compute_consistency_loss(predicted_mel)

        # Weighted total loss
        total_loss = (self.mel_loss_weight * mel_loss +
                     self.duration_loss_weight * duration_loss +
                     self.alignment_loss_weight * alignment_loss +
                     self.consistency_loss_weight * consistency_loss)

        losses = {
            'total_loss': total_loss,
            'mel_loss': mel_loss,
            'duration_loss': duration_loss,
            'alignment_loss': alignment_loss,
            'consistency_loss': consistency_loss
        }

        for name, loss in losses.items():
            self.loss_history[name].append(loss.item())

        return losses

    def get_loss_summary(self, last_n: int = 100) -> Dict[str, Dict[str, float]]:
        """Return mean/std/min/max/latest over the last ``last_n`` values per loss."""
        summary = {}
        for name, history in self.loss_history.items():
            if not history:
                continue
            recent = history[-last_n:]
            summary[name] = {
                'mean': np.mean(recent),
                'std': np.std(recent),
                'min': np.min(recent),
                'max': np.max(recent),
                'latest': recent[-1]
            }
        return summary

class PerformanceMetrics:
    """Rolling and per-epoch statistics for scalar training metrics.

    ``metrics`` keeps the most recent ``window_size`` values of each metric for
    live monitoring (see :meth:`get_current_stats`). Independently, a running
    sum and count of every value seen since the previous :meth:`end_epoch` call
    are kept, so ``epoch_metrics`` stores true per-epoch means rather than the
    mean of a sliding window that may straddle epoch boundaries.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.metrics = defaultdict(lambda: deque(maxlen=window_size))
        self.epoch_metrics = defaultdict(list)
        # Running totals for the epoch in progress; cleared by end_epoch().
        self._epoch_sums = defaultdict(float)
        self._epoch_counts = defaultdict(int)

    def update(self, **kwargs):
        """Record one value per keyword (name=value) in the window and epoch totals."""
        for name, value in kwargs.items():
            self.metrics[name].append(value)
            self._epoch_sums[name] += value
            self._epoch_counts[name] += 1

    def get_current_stats(self) -> Dict[str, Dict[str, float]]:
        """Return mean/std/min/max/latest over the current window for each metric."""
        stats = {}
        for name, values in self.metrics.items():
            if values:
                values_list = list(values)
                stats[name] = {
                    'mean': np.mean(values_list),
                    'std': np.std(values_list),
                    'min': np.min(values_list),
                    'max': np.max(values_list),
                    'latest': values_list[-1]
                }
        return stats

    def end_epoch(self):
        """Append each metric's mean over this epoch to ``epoch_metrics`` and reset.

        Only metrics updated at least once since the previous call are recorded.
        """
        for name, count in self._epoch_counts.items():
            if count:
                self.epoch_metrics[name].append(self._epoch_sums[name] / count)
        self._epoch_sums.clear()
        self._epoch_counts.clear()

    def plot_metrics(self, metrics_to_plot: Optional[List[str]] = None,
                     figsize: Tuple[int, int] = (15, 10)):
        """Plot per-epoch values of the given metrics (default: all) in a grid of up to 3 columns."""
        if metrics_to_plot is None:
            metrics_to_plot = list(self.epoch_metrics.keys())

        n_metrics = len(metrics_to_plot)
        if n_metrics == 0:
            print("No metrics to plot")
            return

        n_cols = min(3, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols

        # squeeze=False always yields a 2-D array of Axes, so a single
        # flatten() handles the 1x1, 1xN and MxN layouts alike.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        for ax, metric_name in zip(axes, metrics_to_plot):
            # `in` on the defaultdict does not create an empty entry.
            if metric_name not in self.epoch_metrics:
                continue
            values = self.epoch_metrics[metric_name]
            ax.plot(values, linewidth=2)
            ax.set_title(metric_name.replace('_', ' ').title())
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Value')
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for ax in axes[n_metrics:]:
            ax.set_visible(False)

        plt.tight_layout()
        plt.show()

# ==============================================================================
# CELL 16: Advanced Training Pipeline
# ==============================================================================

class LibriSpeechTTSTrainer:
    """Production-ready training pipeline for LibriSpeech TTS.

    Wraps a model whose ``forward(texts)`` returns ``(mel, durations)`` with:
    an AdamW optimizer, a OneCycle LR schedule sized from
    ``len(train_dataloader) * config.NUM_EPOCHS``, optional CUDA mixed
    precision, gradient clipping, validation, early stopping and checkpointing.

    Batches must unpack as ``(texts, mels, durations, text_lengths)``.
    """

    def __init__(self, model, train_dataloader, val_dataloader, config):
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.config = config
        self.device = config.DEVICE

        # Move model to device
        self.model = self.model.to(self.device)

        # Loss computer
        self.loss_computer = TTSLossComputer(
            mel_loss_weight=config.MEL_LOSS_WEIGHT,
            duration_loss_weight=config.DURATION_LOSS_WEIGHT,
            alignment_loss_weight=config.ALIGNMENT_LOSS_WEIGHT
        )

        # Performance metrics
        self.metrics = PerformanceMetrics(window_size=100)

        # Optimizer with weight decay
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.LEARNING_RATE,
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=1e-6
        )

        # OneCycle schedule over the configured number of epochs. Peak LR is
        # 10x the configured base rate.
        total_steps = len(train_dataloader) * config.NUM_EPOCHS
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config.LEARNING_RATE * 10,
            total_steps=total_steps,
            pct_start=0.1,
            div_factor=25,
            final_div_factor=10000,
            anneal_strategy='cos'
        )

        # Mixed precision training (CUDA only); self.scaler exists only when enabled.
        self.use_amp = config.USE_MIXED_PRECISION and torch.cuda.is_available()
        if self.use_amp:
            self.scaler = GradScaler()
            print("⚡ Using mixed precision training (AMP)")

        # Training state
        self.current_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.patience_counter = 0
        self.early_stopping_patience = 20

        # Training history (one Python float per epoch)
        self.train_loss_history = []
        self.val_loss_history = []

        # Gradient clipping
        self.max_grad_norm = config.GRAD_CLIP_NORM

        print(f"🏃‍♂️ LibriSpeechTTSTrainer initialized:")
        print(f"   Device: {self.device}")
        print(f"   Mixed precision: {self.use_amp}")
        print(f"   Training batches: {len(train_dataloader)}")
        print(f"   Validation batches: {len(val_dataloader) if val_dataloader else 0}")
        print(f"   Total training steps: {total_steps}")
        print(f"   Max learning rate: {config.LEARNING_RATE * 10}")

    def _step_scheduler(self):
        """Advance the LR schedule unless it is already at its final step.

        OneCycleLR raises ValueError when stepped past ``total_steps``, which
        happens if train() runs more epochs than config.NUM_EPOCHS or after a
        resume; in that case the final LR is simply held.
        """
        if self.scheduler.last_epoch < self.scheduler.total_steps:
            self.scheduler.step()

    def train_step(self, batch) -> Dict[str, float]:
        """Run one optimization step (optionally under AMP) and return scalar metrics."""
        texts, mels, durations, text_lengths = batch
        texts = texts.to(self.device)
        mels = mels.to(self.device)
        durations = durations.to(self.device)
        text_lengths = text_lengths.to(self.device)

        self.optimizer.zero_grad()

        if self.use_amp:
            with autocast():
                predicted_mels, predicted_durations = self.model.forward(texts)
                losses = self.loss_computer.compute_total_loss(
                    predicted_mels, mels, predicted_durations, durations, text_lengths
                )
                total_loss = losses['total_loss']

            # Backward pass with scaling
            self.scaler.scale(total_loss).backward()

            # Unscale first so clipping sees true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

            # Optimizer step (skipped internally by the scaler on inf/NaN grads)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            predicted_mels, predicted_durations = self.model.forward(texts)
            losses = self.loss_computer.compute_total_loss(
                predicted_mels, mels, predicted_durations, durations, text_lengths
            )
            total_loss = losses['total_loss']

            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()

        self._step_scheduler()

        step_metrics = {
            'total_loss': losses['total_loss'].item(),
            'mel_loss': losses['mel_loss'].item(),
            'duration_loss': losses['duration_loss'].item(),
            'alignment_loss': losses['alignment_loss'].item(),
            'consistency_loss': losses['consistency_loss'].item(),
            'learning_rate': self.scheduler.get_last_lr()[0],
            'grad_norm': grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        }

        self.global_step += 1
        return step_metrics

    def validation_step(self, batch) -> Dict[str, float]:
        """Compute losses for one batch without gradients; keys are prefixed ``val_``."""
        texts, mels, durations, text_lengths = batch
        texts = texts.to(self.device)
        mels = mels.to(self.device)
        durations = durations.to(self.device)
        text_lengths = text_lengths.to(self.device)

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    predicted_mels, predicted_durations = self.model.forward(texts)
                    losses = self.loss_computer.compute_total_loss(
                        predicted_mels, mels, predicted_durations, durations, text_lengths
                    )
            else:
                predicted_mels, predicted_durations = self.model.forward(texts)
                losses = self.loss_computer.compute_total_loss(
                    predicted_mels, mels, predicted_durations, durations, text_lengths
                )

        return {f"val_{name}": loss.item() for name, loss in losses.items()}

    def train_epoch(self) -> Dict[str, float]:
        """Train over the whole train loader; return per-metric epoch means as floats."""
        self.model.train()
        epoch_metrics = defaultdict(list)

        pbar = tqdm(self.train_dataloader, desc=f"Epoch {self.current_epoch + 1}")

        for batch_idx, batch in enumerate(pbar):
            step_metrics = self.train_step(batch)

            for name, value in step_metrics.items():
                epoch_metrics[name].append(value)

            # Refresh the progress bar every 10 batches
            if batch_idx % 10 == 0:
                current_lr = step_metrics['learning_rate']
                pbar.set_postfix({
                    'Loss': f"{step_metrics['total_loss']:.4f}",
                    'Mel': f"{step_metrics['mel_loss']:.4f}",
                    'Dur': f"{step_metrics['duration_loss']:.4f}",
                    'LR': f"{current_lr:.2e}",
                    'GradNorm': f"{step_metrics['grad_norm']:.2f}"
                })

            self.metrics.update(**step_metrics)

            # Periodically release cached CUDA memory
            if batch_idx % 100 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Plain Python floats (not np.float64) so they can be checkpointed and
        # reloaded with torch.load's weights_only unpickler.
        return {name: float(np.mean(values)) for name, values in epoch_metrics.items()}

    def validate_epoch(self) -> Dict[str, float]:
        """Validate over the whole val loader (if any); return per-metric means as floats."""
        if self.val_dataloader is None:
            return {}

        self.model.eval()
        epoch_metrics = defaultdict(list)

        with torch.no_grad():
            for batch in tqdm(self.val_dataloader, desc="Validation"):
                step_metrics = self.validation_step(batch)

                for name, value in step_metrics.items():
                    epoch_metrics[name].append(value)

        return {name: float(np.mean(values)) for name, values in epoch_metrics.items()}

    def save_checkpoint(self, filepath: str, is_best: bool = False):
        """Save model/optimizer/scheduler(/scaler) state and history to ``filepath``.

        ``filepath`` may be a str or Path. When ``is_best`` is set, the same
        checkpoint is also written as ``best_model.pth`` next to it.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            # Coerce to builtin floats so weights_only loading accepts them.
            'best_val_loss': float(self.best_val_loss),
            'train_loss_history': [float(x) for x in self.train_loss_history],
            'val_loss_history': [float(x) for x in self.val_loss_history],
            'config': {
                'vocab_size': self.config.VOCAB_SIZE,
                'hidden_dim': self.config.HIDDEN_DIM,
                'num_layers': self.config.NUM_LAYERS,
                'sample_rate': self.config.SAMPLE_RATE,
                'mel_dim': self.config.MEL_DIM
            }
        }

        if self.use_amp:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        torch.save(checkpoint, filepath)

        if is_best:
            best_path = str(Path(filepath).parent / 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f"💾 Best model saved: {best_path}")

        print(f"💾 Checkpoint saved: {filepath}")

    def load_checkpoint(self, filepath: str) -> bool:
        """Restore training state from ``filepath``; return True on success.

        NOTE(review): the saved ``epoch`` is the index of the last epoch that was
        started, so train() after a resume re-runs that epoch.
        Failures are printed and reported as False rather than raised.
        """
        try:
            checkpoint = torch.load(filepath, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.current_epoch = checkpoint['epoch']
            self.global_step = checkpoint['global_step']
            self.best_val_loss = checkpoint['best_val_loss']
            self.train_loss_history = checkpoint.get('train_loss_history', [])
            self.val_loss_history = checkpoint.get('val_loss_history', [])

            if self.use_amp and 'scaler_state_dict' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

            print(f"📖 Checkpoint loaded: {filepath}")
            print(f"   Resuming from epoch {self.current_epoch}")
            print(f"   Global step: {self.global_step}")
            print(f"   Best validation loss: {self.best_val_loss:.4f}")

            return True

        except Exception as e:
            print(f"❌ Failed to load checkpoint: {e}")
            return False

    def train(self, num_epochs: int, save_dir: str = "./checkpoints",
              save_freq: int = 5, validate_freq: int = 1) -> Dict[str, List[float]]:
        """Run ``num_epochs`` epochs starting at ``self.current_epoch``.

        Validates every ``validate_freq`` epochs, saves a checkpoint every
        ``save_freq`` epochs or on a new best validation loss, and stops early
        after ``early_stopping_patience`` validations without improvement.
        A final checkpoint is always written, including after interruption or
        an error (errors are printed, not re-raised).

        Raises:
            ValueError: if ``save_freq`` or ``validate_freq`` is < 1.

        Returns:
            ``{'train_loss': [...], 'val_loss': [...]}`` per-epoch histories.
        """
        # Validate up front: a zero frequency would otherwise raise
        # ZeroDivisionError only after a full epoch of training.
        if save_freq < 1 or validate_freq < 1:
            raise ValueError(
                f"save_freq and validate_freq must be >= 1 "
                f"(got save_freq={save_freq}, validate_freq={validate_freq})"
            )

        save_dir = Path(save_dir)
        save_dir.mkdir(exist_ok=True, parents=True)

        print(f"🚀 Starting training for {num_epochs} epochs")
        print(f"   Save directory: {save_dir}")
        print(f"   Save frequency: every {save_freq} epochs")
        print(f"   Validation frequency: every {validate_freq} epochs")

        # Fix the bounds once: self.current_epoch is reassigned every iteration,
        # so deriving the "/N" total from it made the total drift upward.
        start_epoch = self.current_epoch
        end_epoch = start_epoch + num_epochs

        try:
            for epoch in range(start_epoch, end_epoch):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                print(f"\n📈 Epoch {epoch + 1}/{end_epoch}")
                print("-" * 60)

                # Training
                train_metrics = self.train_epoch()
                self.train_loss_history.append(train_metrics['total_loss'])

                # Validation
                val_metrics = {}
                if epoch % validate_freq == 0:
                    val_metrics = self.validate_epoch()
                    if 'val_total_loss' in val_metrics:
                        self.val_loss_history.append(val_metrics['val_total_loss'])

                epoch_time = time.time() - epoch_start_time

                print(f"✅ Epoch {epoch + 1} completed in {epoch_time:.1f}s")
                print(f"   Train Loss: {train_metrics['total_loss']:.4f}")
                if val_metrics:
                    print(f"   Val Loss: {val_metrics.get('val_total_loss', 0):.4f}")
                print(f"   Learning Rate: {train_metrics['learning_rate']:.2e}")

                # Early stopping and best model tracking
                is_best = False
                if val_metrics and 'val_total_loss' in val_metrics:
                    val_loss = val_metrics['val_total_loss']
                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.patience_counter = 0
                        is_best = True
                        print(f"   🏆 New best validation loss: {val_loss:.4f}")
                    else:
                        self.patience_counter += 1
                        if self.patience_counter >= self.early_stopping_patience:
                            print(f"   ⏹️ Early stopping triggered (patience: {self.early_stopping_patience})")
                            break

                # Save checkpoint
                if epoch % save_freq == 0 or is_best:
                    checkpoint_path = save_dir / f"checkpoint_epoch_{epoch + 1}.pth"
                    self.save_checkpoint(checkpoint_path, is_best=is_best)

                self.metrics.end_epoch()

                # Plot progress periodically
                if epoch % (save_freq * 2) == 0 and epoch > 0:
                    self.plot_training_progress()

        except KeyboardInterrupt:
            print(f"\n⏹️ Training interrupted at epoch {self.current_epoch + 1}")
            checkpoint_path = save_dir / "interrupted_checkpoint.pth"
            self.save_checkpoint(checkpoint_path)

        except Exception as e:
            print(f"\n❌ Training failed with error: {e}")
            import traceback
            traceback.print_exc()

        finally:
            final_path = save_dir / "final_checkpoint.pth"
            self.save_checkpoint(final_path)

            print(f"\n🎉 Training completed!")
            print(f"   Total epochs: {self.current_epoch + 1}")
            print(f"   Best validation loss: {self.best_val_loss:.4f}")
            print(f"   Final checkpoint: {final_path}")

        return {
            'train_loss': self.train_loss_history,
            'val_loss': self.val_loss_history
        }

    def plot_training_progress(self):
        """Plot loss curves, per-epoch learning rate and gradient norm.

        NOTE(review): validation losses are plotted against 1..len(history), which
        matches epochs only when validate_freq == 1.
        """
        if len(self.train_loss_history) < 2:
            return

        plt.figure(figsize=(15, 5))

        # Loss plot
        plt.subplot(1, 3, 1)
        epochs = range(1, len(self.train_loss_history) + 1)
        plt.plot(epochs, self.train_loss_history, label='Training Loss', linewidth=2)
        if self.val_loss_history:
            val_epochs = range(1, len(self.val_loss_history) + 1)
            plt.plot(val_epochs, self.val_loss_history, label='Validation Loss', linewidth=2)
        plt.title('Training Progress')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Learning rate plot
        plt.subplot(1, 3, 2)
        if hasattr(self.metrics, 'epoch_metrics') and 'learning_rate' in self.metrics.epoch_metrics:
            lr_history = self.metrics.epoch_metrics['learning_rate']
            plt.plot(range(1, len(lr_history) + 1), lr_history, linewidth=2, color='orange')
            plt.title('Learning Rate Schedule')
            plt.xlabel('Epoch')
            plt.ylabel('Learning Rate')
            plt.yscale('log')
            plt.grid(True, alpha=0.3)

        # Gradient norm plot
        plt.subplot(1, 3, 3)
        if hasattr(self.metrics, 'epoch_metrics') and 'grad_norm' in self.metrics.epoch_metrics:
            grad_history = self.metrics.epoch_metrics['grad_norm']
            plt.plot(range(1, len(grad_history) + 1), grad_history, linewidth=2, color='green')
            plt.title('Gradient Norm')
            plt.xlabel('Epoch')
            plt.ylabel('Gradient Norm')
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

# ==============================================================================
# CELL 17: Model Evaluation and Testing
# ==============================================================================

class ModelEvaluator:
    """Comprehensive model evaluation and testing"""

    def __init__(self, model, text_processor, config):
        self.model = model
        self.text_processor = text_processor
        self.config = config
        self.device = config.DEVICE

        self.model.eval()

    def evaluate_single_sample(self, text: str, return_intermediate: bool = True) -> Dict[str, Any]:
        """Evaluate model on a single text sample"""
        # Tokenize text
        text_ids = self.text_processor.text_to_ids(text)
        text_tensor = torch.tensor([text_ids], dtype=torch.long).to(self.device)

        # Time the inference
        start_time = time.time()

        with torch.no_grad():
            results = self.model.inference(text_tensor, return_intermediate=return_intermediate)

        inference_time = time.time() - start_time

        # Add metadata
        results.update({
            'input_text': text,
            'input_tokens': text_ids,
            'inference_time': inference_time,
            'real_time_factor': results['audio_length'] / self.config.SAMPLE_RATE / inference_time
        })

        return results

    def evaluate_streaming_performance(self, text: str, buffer_size: int = 4) -> Dict[str, Any]:
        """Evaluate streaming performance with latency measurements"""
        words = text.split()

        # Initialize streaming
        self.model.reset_streaming_state()
        buffer = []
        chunk_times = []
        chunk_audio_lengths = []
        total_audio = []