In [1]:
import warnings
warnings.filterwarnings('ignore')

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm import load, generate
import torchaudio  # –î–ª—è –∑–∞–≥—Ä—É–∑–∫–∏ –∞—É–¥–∏–æ, –ø–æ–∫–∞ MLX –Ω–µ –ø–æ–¥–¥–µ—Ä–∂–∏–≤–∞–µ—Ç —ç—Ç–æ –Ω–∞–ø—Ä—è–º—É—é
import json
import os
import numpy as np
import soundfile as sf
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, AutoTokenizer
import jiwer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime
from typing import Dict, List, Tuple, Optional
import math

# –ù–∞—Å—Ç—Ä–æ–π–∫–∞ MLX –¥–ª—è –æ–ø—Ç–∏–º–∞–ª—å–Ω–æ–π —Ä–∞–±–æ—Ç—ã
mx.random.seed(42)

# –ö–æ–Ω—Å—Ç–∞–Ω—Ç—ã –∏ –≥–∏–ø–µ—Ä–ø–∞—Ä–∞–º–µ—Ç—Ä—ã
batch_size = 4
num_epochs = 3
learning_rate = 1e-4
weight_decay = 0.01
max_grad_norm = 1.0
warmup_steps = 100
input_dim = 768  # Wav2Vec2 –≤—ã—Ö–æ–¥–Ω–∞—è —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç—å
output_dim = 2560  # Gemma embedding —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç—å

# –ò–Ω—Ñ–æ—Ä–º–∞—Ü–∏—è –æ–± —ç–∫—Å–ø–µ—Ä–∏–º–µ–Ω—Ç–µ
experiment_name = f"audio_projector_mlx_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checkpoint_dir = f"./checkpoints/{experiment_name}"
os.makedirs(checkpoint_dir, exist_ok=True)

# –ù–∞—Å—Ç—Ä–æ–π–∫–∞ —Å—Ç–∏–ª–µ–π –¥–ª—è –≥—Ä–∞—Ñ–∏–∫–æ–≤
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print(f"üöÄ MLX –≠–∫—Å–ø–µ—Ä–∏–º–µ–Ω—Ç: {experiment_name}")
print(f"üìÅ –ß–µ–∫–ø–æ–∏–Ω—Ç—ã: {checkpoint_dir}")
print(f"üñ•Ô∏è –£—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: MLX (Apple Silicon)")
print(f"‚öôÔ∏è –ö–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏—è:")
print(f"   - Batch size: {batch_size}")
print(f"   - Epochs: {num_epochs}")
print(f"   - Learning rate: {learning_rate}")
print(f"   - Weight decay: {weight_decay}")
print(f"   - Gradient clipping: {max_grad_norm}")
print(f"   - Warmup steps: {warmup_steps}")
print(f"üîó Projector: {input_dim} -> {output_dim}")

class AudioProjector(nn.Module):
    """MLX-based –∞—É–¥–∏–æ –ø—Ä–æ–µ–∫—Ç–æ—Ä –¥–ª—è –º–∞–ø–ø–∏–Ω–≥–∞ Wav2Vec2 —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ –≤ –ø—Ä–æ—Å—Ç—Ä–∞–Ω—Å—Ç–≤–æ Gemma"""
    
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
            nn.LayerNorm(output_dim)
        )
    
    def __call__(self, x: mx.array) -> mx.array:
        return self.proj(x)

class Wav2Vec2Wrapper:
    """–û–±–µ—Ä—Ç–∫–∞ –¥–ª—è Wav2Vec2 —Å –∫–æ–Ω–≤–µ—Ä—Ç–∞—Ü–∏–µ–π –≤ MLX"""
    
    def __init__(self, model_name: str = "facebook/wav2vec2-base"):
        from transformers import Wav2Vec2Model
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        self.model.eval()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        
    def extract_features(self, waveforms: np.ndarray) -> mx.array:
        """–ò–∑–≤–ª–µ—á–µ–Ω–∏–µ –ø—Ä–∏–∑–Ω–∞–∫–æ–≤ –∏–∑ –∞—É–¥–∏–æ —Å –∫–æ–Ω–≤–µ—Ä—Ç–∞—Ü–∏–µ–π –≤ MLX"""
        import torch
        with torch.no_grad():
            if len(waveforms.shape) == 1:
                waveforms = waveforms[None, :]
            inputs = torch.from_numpy(waveforms).float()
            outputs = self.model(inputs)
            features = outputs.last_hidden_state.mean(dim=1)  # Global average pooling
            return mx.array(features.numpy())

class GemmaWrapper:
    """–û–±–µ—Ä—Ç–∫–∞ –¥–ª—è Gemma –º–æ–¥–µ–ª–∏ —á–µ—Ä–µ–∑ MLX-LM"""
    
    def __init__(self, model_path: str = "mlx-community/gemma-2b-it"):
        self.model, self.tokenizer = load(model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.embedding_dim = output_dim
        
    def get_input_embeddings(self, input_ids: mx.array) -> mx.array:
        """–ü–æ–ª—É—á–µ–Ω–∏–µ –≤—Ö–æ–¥–Ω—ã—Ö —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤"""
        return self.model.embed_tokens(input_ids)
    
    def forward_with_embeddings(self, inputs_embeds: mx.array, labels: Optional[mx.array] = None) -> Dict:
        """Forward pass —Å –∫–∞—Å—Ç–æ–º–Ω—ã–º–∏ —ç–º–±–µ–¥–¥–∏–Ω–≥–∞–º–∏"""
        # –≠—Ç–æ —É–ø—Ä–æ—â–µ–Ω–Ω–∞—è –≤–µ—Ä—Å–∏—è - –≤ —Ä–µ–∞–ª—å–Ω–æ–π —Ä–µ–∞–ª–∏–∑–∞—Ü–∏–∏ –Ω—É–∂–Ω–æ –º–æ–¥–∏—Ñ–∏—Ü–∏—Ä–æ–≤–∞—Ç—å MLX –º–æ–¥–µ–ª—å
        logits = self.model(inputs_embeds=inputs_embeds)
        
        if labels is not None:
            # –í—ã—á–∏—Å–ª–µ–Ω–∏–µ loss
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss = nn.losses.cross_entropy(
                shift_logits.reshape(-1, shift_logits.shape[-1]),
                shift_labels.reshape(-1),
                ignore_index=-100
            )
            return {"loss": loss, "logits": logits}
        
        return {"logits": logits}

class AudioTextDataset:
    """Dataset –¥–ª—è –∞—É–¥–∏–æ-—Ç–µ–∫—Å—Ç–æ–≤—ã—Ö –ø–∞—Ä"""
    
    def __init__(self, data: List[Dict], tokenizer, feature_extractor, wav2vec2_wrapper):
        self.data = data
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.wav2vec2_wrapper = wav2vec2_wrapper
    
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, idx: int) -> Dict[str, mx.array]:
        item = self.data[idx]
        audio_path = item["audio_path"]
        speaker_text = item["speaker_text"]
        
        # –ó–∞–≥—Ä—É–∑–∫–∞ –∏ –æ–±—Ä–∞–±–æ—Ç–∫–∞ –∞—É–¥–∏–æ
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.feature_extractor.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.feature_extractor.sampling_rate)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # –ò–∑–≤–ª–µ—á–µ–Ω–∏–µ –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–æ–≤
        audio_features = self.wav2vec2_wrapper.extract_features(waveform.squeeze().numpy())
        
        # –¢–æ–∫–µ–Ω–∏–∑–∞—Ü–∏—è —Ç–µ–∫—Å—Ç–∞
        tokens = self.tokenizer(
            speaker_text,
            return_tensors="np",
            padding=True,
            truncation=True,
            max_length=512
        )
        
        return {
            "audio_features": audio_features,
            "input_ids": mx.array(tokens.input_ids.squeeze(0)),
            "attention_mask": mx.array(tokens.attention_mask.squeeze(0))
        }

def collate_fn(batch: List[Dict]) -> Dict[str, mx.array]:
    """–§—É–Ω–∫—Ü–∏—è –¥–ª—è –æ–±—ä–µ–¥–∏–Ω–µ–Ω–∏—è –±–∞—Ç—á–∞"""
    audio_features = [item['audio_features'] for item in batch]
    input_ids = [item['input_ids'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    
    # Padding –¥–ª—è batch
    max_seq_len = max(len(ids) for ids in input_ids)
    
    padded_input_ids = []
    padded_attention_mask = []
    
    for ids, mask in zip(input_ids, attention_mask):
        pad_len = max_seq_len - len(ids)
        padded_ids = mx.concatenate([ids, mx.full((pad_len,), -100, dtype=mx.int32)])
        padded_mask = mx.concatenate([mask, mx.zeros((pad_len,), dtype=mx.int32)])
        padded_input_ids.append(padded_ids)
        padded_attention_mask.append(padded_mask)
    
    return {
        'audio_features': mx.stack(audio_features),
        'input_ids': mx.stack(padded_input_ids),
        'attention_mask': mx.stack(padded_attention_mask)
    }

class TrainingLogger:
    """–õ–æ–≥–≥–µ—Ä –¥–ª—è –æ—Ç—Å–ª–µ–∂–∏–≤–∞–Ω–∏—è –ø—Ä–æ—Ü–µ—Å—Å–∞ –æ–±—É—á–µ–Ω–∏—è"""
    
    def __init__(self, experiment_name: str, save_dir: str):
        self.experiment_name = experiment_name
        self.save_dir = save_dir
        self.logs = {
            'epoch': [],
            'train_loss': [],
            'val_loss': [],
            'val_perplexity': [],
            'val_wer': [],
            'val_bleu': [],
            'val_rouge_l': [],
            'learning_rate': [],
            'grad_norm': []
        }
    
    def log(self, epoch: int, train_loss: float, val_metrics: Dict, lr: float, grad_norm: Optional[float] = None):
        """–õ–æ–≥–∏—Ä–æ–≤–∞–Ω–∏–µ –º–µ—Ç—Ä–∏–∫"""
        self.logs['epoch'].append(epoch)
        self.logs['train_loss'].append(train_loss)
        self.logs['val_loss'].append(val_metrics['loss'])
        self.logs['val_perplexity'].append(val_metrics['perplexity'])
        self.logs['val_wer'].append(val_metrics['wer'])
        self.logs['val_bleu'].append(val_metrics['bleu'])
        self.logs['val_rouge_l'].append(val_metrics['rouge_l'])
        self.logs['learning_rate'].append(lr)
        if grad_norm is not None:
            self.logs['grad_norm'].append(grad_norm)
    
    def plot_training_curves(self):
        """–ü–æ—Å—Ç—Ä–æ–µ–Ω–∏–µ –≥—Ä–∞—Ñ–∏–∫–æ–≤ –æ–±—É—á–µ–Ω–∏—è"""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'MLX Training Progress: {self.experiment_name}', fontsize=16, fontweight='bold')
        
        # Loss curves
        axes[0, 0].plot(self.logs['epoch'], self.logs['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(self.logs['epoch'], self.logs['val_loss'], 'r-', label='Val Loss', linewidth=2)
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Loss Curves')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Perplexity
        axes[0, 1].plot(self.logs['epoch'], self.logs['val_perplexity'], 'g-', linewidth=2)
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Perplexity')
        axes[0, 1].set_title('Validation Perplexity')
        axes[0, 1].grid(True, alpha=0.3)
        
        # WER
        axes[0, 2].plot(self.logs['epoch'], self.logs['val_wer'], 'orange', linewidth=2)
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('WER')
        axes[0, 2].set_title('Word Error Rate')
        axes[0, 2].grid(True, alpha=0.3)
        
        # BLEU
        axes[1, 0].plot(self.logs['epoch'], self.logs['val_bleu'], 'purple', linewidth=2)
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('BLEU Score')
        axes[1, 0].set_title('BLEU Score')
        axes[1, 0].grid(True, alpha=0.3)
        
        # ROUGE-L
        axes[1, 1].plot(self.logs['epoch'], self.logs['val_rouge_l'], 'brown', linewidth=2)
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('ROUGE-L')
        axes[1, 1].set_title('ROUGE-L Score')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Learning Rate
        axes[1, 2].plot(self.logs['epoch'], self.logs['learning_rate'], 'teal', linewidth=2)
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].set_ylabel('Learning Rate')
        axes[1, 2].set_title('Learning Rate Schedule')
        axes[1, 2].grid(True, alpha=0.3)
        axes[1, 2].set_yscale('log')
        
        plt.tight_layout()
        
        plot_path = os.path.join(self.save_dir, 'training_curves.png')
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"üìä –ì—Ä–∞—Ñ–∏–∫ —Å–æ—Ö—Ä–∞–Ω–µ–Ω: {plot_path}")
        plt.show()
    
    def save_logs(self) -> pd.DataFrame:
        """–°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ –ª–æ–≥–æ–≤ –≤ CSV"""
        df = pd.DataFrame(self.logs)
        csv_path = os.path.join(self.save_dir, 'training_logs.csv')
        df.to_csv(csv_path, index=False)
        print(f"üìù –õ–æ–≥–∏ —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã: {csv_path}")
        return df

def process_batch(
    batch: Dict[str, mx.array], 
    gemma_model: GemmaWrapper, 
    projector: AudioProjector, 
    prefix_embeds: mx.array, 
    batch_idx: int = 0, 
    context: str = ""
) -> Tuple[Dict, mx.array]:
    """–û–±—Ä–∞–±–æ—Ç–∫–∞ –±–∞—Ç—á–∞ –¥–∞–Ω–Ω—ã—Ö"""
    
    audio_features = batch["audio_features"]
    input_ids = batch["input_ids"]
    
    # –ü—Ä–æ–µ–∫—Ü–∏—è –∞—É–¥–∏–æ-–ø—Ä–∏–∑–Ω–∞–∫–æ–≤
    projected_audio = projector(audio_features)
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –ø—Ä–æ–º–ø—Ç —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
    batch_size = projected_audio.shape[0]
    batch_prefix_embeds = mx.broadcast_to(prefix_embeds, (batch_size, prefix_embeds.shape[1], prefix_embeds.shape[2]))
    
    # –û–±—ä–µ–¥–∏–Ω–µ–Ω–∏–µ –ø—Ä–µ—Ñ–∏–∫—Å–∞ –∏ –∞—É–¥–∏–æ
    prompt_embeds = mx.concatenate([batch_prefix_embeds, mx.expand_dims(projected_audio, 1)], axis=1)
    
    # –ü–æ–ª—É—á–µ–Ω–∏–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ –¥–ª—è —Ü–µ–ª–µ–≤–æ–≥–æ —Ç–µ–∫—Å—Ç–∞
    target_embeds = gemma_model.get_input_embeddings(input_ids)
    
    # –û–±—ä–µ–¥–∏–Ω–µ–Ω–∏–µ –ø—Ä–æ–º–ø—Ç–∞ –∏ —Ü–µ–ª–µ–≤–æ–≥–æ —Ç–µ–∫—Å—Ç–∞
    inputs_embeds = mx.concatenate([prompt_embeds, target_embeds], axis=1)
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –ª–µ–π–±–ª–æ–≤ (–∏–≥–Ω–æ—Ä–∏—Ä—É–µ–º –ø—Ä–æ–º–ø—Ç)
    prompt_len = prompt_embeds.shape[1]
    prompt_labels = mx.full((batch_size, prompt_len), -100, dtype=mx.int32)
    labels = mx.concatenate([prompt_labels, input_ids], axis=1)
    
    if batch_idx == 0 and context:
        print(f"\n--- MLX –†–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏ –≤ {context} (–ø–µ—Ä–≤—ã–π –±–∞—Ç—á) ---")
        print(f"Audio Projected Embeds: {projected_audio.shape}")
        print(f"Prompt Embeds (prefix + audio): {prompt_embeds.shape}")
        print(f"Target Text Embeds: {target_embeds.shape}")
        print(f"Combined Input Embeds: {inputs_embeds.shape}")
        print(f"Labels: {labels.shape}")
        print("--------------------------------------------------------")
    
    # Forward pass —á–µ—Ä–µ–∑ –º–æ–¥–µ–ª—å
    outputs = gemma_model.forward_with_embeddings(inputs_embeds, labels)
    
    return outputs, prompt_embeds

def evaluate_with_metrics(
    gemma_model: GemmaWrapper,
    projector: AudioProjector,
    dataloader: List[Dict],
    prefix_embeds: mx.array
) -> Dict[str, float]:
    """–û—Ü–µ–Ω–∫–∞ –º–æ–¥–µ–ª–∏ —Å –≤—ã—á–∏—Å–ª–µ–Ω–∏–µ–º –º–µ—Ç—Ä–∏–∫"""
    
    total_loss = 0.0
    total_wer = 0.0
    total_bleu = 0.0
    total_rouge_1 = 0.0
    total_rouge_2 = 0.0
    total_rouge_l = 0.0
    count = 0
    
    smooth = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
        outputs, prompt_embeds = process_batch(
            batch, gemma_model, projector, prefix_embeds, batch_idx=i, context="evaluation"
        )
        
        loss = outputs["loss"]
        total_loss += float(loss)
        
        # –ì–µ–Ω–µ—Ä–∞—Ü–∏—è —Ç–µ–∫—Å—Ç–∞ –¥–ª—è –º–µ—Ç—Ä–∏–∫
        for j in range(batch["input_ids"].shape[0]):
            single_prompt = prompt_embeds[j:j+1]
            
            # –ó–¥–µ—Å—å –Ω—É–∂–Ω–∞ –º–æ–¥–∏—Ñ–∏–∫–∞—Ü–∏—è MLX –¥–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Å –∫–∞—Å—Ç–æ–º–Ω—ã–º–∏ —ç–º–±–µ–¥–¥–∏–Ω–≥–∞–º–∏
            # –ü–æ–∫–∞ –∏—Å–ø–æ–ª—å–∑—É–µ–º —É–ø—Ä–æ—â–µ–Ω–Ω—É—é –≤–µ—Ä—Å–∏—é
            generated_text = "placeholder_generated_text"  # TODO: –†–µ–∞–ª–∏–∑–æ–≤–∞—Ç—å –≥–µ–Ω–µ—Ä–∞—Ü–∏—é
            
            # –†–µ—Ñ–µ—Ä–µ–Ω—Å–Ω—ã–π —Ç–µ–∫—Å—Ç
            ref_ids = batch["input_ids"][j]
            ref_ids = ref_ids[ref_ids != -100]
            ref_text = gemma_model.tokenizer.decode(ref_ids.tolist(), skip_special_tokens=True).strip()
            
            if ref_text and generated_text:
                total_wer += jiwer.wer(ref_text, generated_text)
                total_bleu += sentence_bleu([ref_text.split()], generated_text.split(), smoothing_function=smooth)
                rouge_scores = rouge_scorer_obj.score(ref_text, generated_text)
                total_rouge_1 += rouge_scores['rouge1'].fmeasure
                total_rouge_2 += rouge_scores['rouge2'].fmeasure
                total_rouge_l += rouge_scores['rougeL'].fmeasure
                count += 1
    
    avg_loss = total_loss / len(dataloader)
    perplexity = math.exp(avg_loss)
    
    return {
        'loss': avg_loss,
        'perplexity': perplexity,
        'wer': total_wer / count if count > 0 else 0.0,
        'bleu': total_bleu / count if count > 0 else 0.0,
        'rouge_1': total_rouge_1 / count if count > 0 else 0.0,
        'rouge_2': total_rouge_2 / count if count > 0 else 0.0,
        'rouge_l': total_rouge_l / count if count > 0 else 0.0
    }

def create_learning_rate_schedule(total_steps: int, warmup_steps: int, max_lr: float):
    """–°–æ–∑–¥–∞–Ω–∏–µ —Ä–∞—Å–ø–∏—Å–∞–Ω–∏—è learning rate"""
    def schedule(step: int) -> float:
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        else:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            return max_lr * 0.5 * (1 + math.cos(math.pi * progress))
    return schedule

# –û—Å–Ω–æ–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ–±—É—á–µ–Ω–∏—è
def main():
    """–û—Å–Ω–æ–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ–±—É—á–µ–Ω–∏—è MLX –º–æ–¥–µ–ª–∏"""
    
    print("üîß –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è MLX –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤...")
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –º–æ–¥–µ–ª–µ–π
    wav2vec2_wrapper = Wav2Vec2Wrapper()
    gemma_model = GemmaWrapper()
    projector = AudioProjector(input_dim, output_dim)
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä–∞
    optimizer = optim.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –ø—Ä–µ—Ñ–∏–∫—Å–∞ –¥–ª—è –ø—Ä–æ–º–ø—Ç–∞
    prefix = "–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: "
    prefix_tokens = gemma_model.tokenizer(prefix, return_tensors="np")
    prefix_embeds = gemma_model.get_input_embeddings(mx.array(prefix_tokens.input_ids.squeeze(0)))
    prefix_embeds = mx.expand_dims(prefix_embeds, 0)  # –î–æ–±–∞–≤–ª—è–µ–º batch dimension
    
    print(f"‚úÖ –û–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä: AdamW (lr={learning_rate}, wd={weight_decay})")
    print(f"‚úÖ Gradient clipping: {max_grad_norm}")
    print(f"‚úÖ –ü—Ä–µ—Ñ–∏–∫—Å –ø—Ä–æ–º–ø—Ç–∞: '{prefix}'")
    
    # –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö
    print("üìä –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö...")
    jsonl_path = "transcripts.jsonl"
    
    try:
        with open(jsonl_path, "r", encoding="utf-8") as f:
            all_data = [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"‚ùå –§–∞–π–ª {jsonl_path} –Ω–µ –Ω–∞–π–¥–µ–Ω!")
        return
    
    train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
    
    train_dataset = AudioTextDataset(train_data, gemma_model.tokenizer, wav2vec2_wrapper.feature_extractor, wav2vec2_wrapper)
    val_dataset = AudioTextDataset(val_data, gemma_model.tokenizer, wav2vec2_wrapper.feature_extractor, wav2vec2_wrapper)
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –¥–∞—Ç–∞–≥—Ä—É–ø–ø (–ø—Ä–æ—Å—Ç–∞—è –≤–µ—Ä—Å–∏—è –±–µ–∑ DataLoader)
    def create_batches(dataset, batch_size):
        batches = []
        for i in range(0, len(dataset), batch_size):
            batch_items = [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]
            batches.append(collate_fn(batch_items))
        return batches
    
    train_batches = create_batches(train_dataset, batch_size)
    val_batches = create_batches(val_dataset, batch_size)
    
    print(f"üìä –î–∞–Ω–Ω—ã–µ –∑–∞–≥—Ä—É–∂–µ–Ω—ã: {len(train_data)} train, {len(val_data)} val samples.")
    print(f"üì¶ –ë–∞—Ç—á–µ–π: {len(train_batches)} train, {len(val_batches)} val")
    
    # –ù–∞—Å—Ç—Ä–æ–π–∫–∞ learning rate schedule
    total_steps = num_epochs * len(train_batches)
    lr_schedule = create_learning_rate_schedule(total_steps, warmup_steps, learning_rate)
    
    print(f"üìÖ –û–±—â–µ–µ –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ —à–∞–≥–æ–≤: {total_steps}")
    print(f"üî• Warmup —à–∞–≥–æ–≤: {warmup_steps}")
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –ª–æ–≥–≥–µ—Ä–∞
    logger = TrainingLogger(experiment_name, checkpoint_dir)
    
    print(f"\nüöÄ –ù–∞—á–∏–Ω–∞–µ–º –æ–±—É—á–µ–Ω–∏–µ MLX Audio Projector!")
    
    # –û—Å–Ω–æ–≤–Ω–æ–π —Ü–∏–∫–ª –æ–±—É—á–µ–Ω–∏—è
    step = 0
    for epoch in range(num_epochs):
        print(f"\n{'='*60}")
        print(f"üîÑ EPOCH {epoch+1}/{num_epochs}")
        print(f"{'='*60}")
        
        epoch_loss = 0.0
        total_grad_norm = 0.0
        
        progress_bar = tqdm(train_batches, desc=f"Epoch {epoch+1} Training")
        
        for batch_idx, batch in enumerate(progress_bar):
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ learning rate
            current_lr = lr_schedule(step)
            optimizer.learning_rate = current_lr
            
            # Forward pass
            def loss_fn(projector_params):
                projector.update(projector_params)
                outputs, _ = process_batch(
                    batch, gemma_model, projector, prefix_embeds,
                    batch_idx=batch_idx if batch_idx == 0 else -1, 
                    context="training loop" if batch_idx == 0 else ""
                )
                return outputs["loss"]
            
            # –í—ã—á–∏—Å–ª–µ–Ω–∏–µ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤
            loss, grads = mx.value_and_grad(loss_fn)(projector.parameters())
            
            # Gradient clipping
            grad_norm = mx.sqrt(sum([mx.sum(g * g) for g in grads.values()]))
            if grad_norm > max_grad_norm:
                scale = max_grad_norm / grad_norm
                grads = {k: scale * v for k, v in grads.items()}
            
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤
            optimizer.update(projector, grads)
            mx.eval(projector.parameters())
            
            epoch_loss += float(loss)
            total_grad_norm += float(grad_norm)
            step += 1
            
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ progress bar
            progress_bar.set_postfix({
                'Loss': f'{float(loss):.4f}',
                'LR': f'{current_lr:.2e}',
                'Grad': f'{float(grad_norm):.3f}'
            })
        
        # –°—Ä–µ–¥–Ω–∏–µ –∑–Ω–∞—á–µ–Ω–∏—è –∑–∞ —ç–ø–æ—Ö—É
        avg_train_loss = epoch_loss / len(train_batches)
        avg_grad_norm = total_grad_norm / len(train_batches)
        
        print(f"\nüìà Epoch {epoch+1} Training Results:")
        print(f"   üìâ Average Loss: {avg_train_loss:.4f}")
        print(f"   üéØ Learning Rate: {current_lr:.2e}")
        print(f"   ‚úÇÔ∏è Average Grad Norm: {avg_grad_norm:.3f}")
        
        # –í–∞–ª–∏–¥–∞—Ü–∏—è
        print(f"\nüîç Running validation...")
        val_metrics = evaluate_with_metrics(gemma_model, projector, val_batches, prefix_embeds)
        
        print(f"\nüìä Validation Results:")
        print(f"   üìâ Loss: {val_metrics['loss']:.4f}")
        print(f"   üßÆ Perplexity: {val_metrics['perplexity']:.2f}")
        print(f"   üéØ WER: {val_metrics['wer']:.3f}")
        print(f"   üî§ BLEU: {val_metrics['bleu']:.3f}")
        print(f"   üìù ROUGE-L: {val_metrics['rouge_l']:.3f}")
        
        # –õ–æ–≥–∏—Ä–æ–≤–∞–Ω–∏–µ
        logger.log(epoch+1, avg_train_loss, val_metrics, current_lr, avg_grad_norm)
        
        # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —á–µ–∫–ø–æ–∏–Ω—Ç–∞
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.npz")
        mx.save_arrays(checkpoint_path, projector.parameters())
        print(f"üíæ Checkpoint saved: {checkpoint_path}")
        
        # –ü–æ—Å—Ç—Ä–æ–µ–Ω–∏–µ –≥—Ä–∞—Ñ–∏–∫–æ–≤
        logger.plot_training_curves()
    
    print(f"\n{'='*60}")
    print(f"üéâ –û–ë–£–ß–ï–ù–ò–ï MLX –ó–ê–í–ï–†–®–ï–ù–û!")
    print(f"{'='*60}")
    
    # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ñ–∏–Ω–∞–ª—å–Ω—ã—Ö —Ä–µ–∑—É–ª—å—Ç–∞—Ç–æ–≤
    final_logs_df = logger.save_logs()
    print(f"\nüìã –ò—Ç–æ–≥–æ–≤–∞—è —Å—Ç–∞—Ç–∏—Å—Ç–∏–∫–∞:")
    print(final_logs_df.round(4))
    
    # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ñ–∏–Ω–∞–ª—å–Ω–æ–π –º–æ–¥–µ–ª–∏
    final_model_path = os.path.join(checkpoint_dir, "final_projector.npz")
    mx.save_arrays(final_model_path, projector.parameters())
    print(f"üèÜ –§–∏–Ω–∞–ª—å–Ω–∞—è MLX –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞: {final_model_path}")
    
    print(f"\nüìÅ –í—Å–µ —Ñ–∞–π–ª—ã —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã –≤: {checkpoint_dir}")
    print(f"   - –ß–µ–∫–ø–æ–∏–Ω—Ç—ã: checkpoint_epoch_*.npz")
    print(f"   - –§–∏–Ω–∞–ª—å–Ω–∞—è –º–æ–¥–µ–ª—å: final_projector.npz")
    print(f"   - –õ–æ–≥–∏: training_logs.csv")
    print(f"   - –ì—Ä–∞—Ñ–∏–∫–∏: training_curves.png")

if __name__ == "__main__":
    main()


üöÄ MLX –≠–∫—Å–ø–µ—Ä–∏–º–µ–Ω—Ç: audio_projector_mlx_20250620_001655
üìÅ –ß–µ–∫–ø–æ–∏–Ω—Ç—ã: ./checkpoints/audio_projector_mlx_20250620_001655
üñ•Ô∏è –£—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: MLX (Apple Silicon)
‚öôÔ∏è –ö–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏—è:
   - Batch size: 4
   - Epochs: 3
   - Learning rate: 0.0001
   - Weight decay: 0.01
   - Gradient clipping: 1.0
   - Warmup steps: 100
üîó Projector: 768 -> 2560
üîß –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è MLX –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤...


config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/380M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

RepositoryNotFoundError: 404 Client Error. (Request ID: Root=1-68547ee0-3e7f8e102c689f763fb19c5a;7e7bed0a-2c02-48ca-a1c5-545ce52be9ac)

Repository Not Found for url: https://huggingface.co/api/models/mlx-community/gemma-2b-it/revision/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
If you are trying to access a private or gated repo, make sure you are authenticated. For more details, see https://huggingface.co/docs/huggingface_hub/authentication

model.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

In [1]:
# –ò–º–ø–æ—Ä—Ç –∫–∞—Å—Ç–æ–º–Ω–æ–≥–æ –º–æ–¥—É–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Å –ø–æ–¥–¥–µ—Ä–∂–∫–æ–π –∞—É–¥–∏–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
print("üì¶ –ó–∞–≥—Ä—É–∂–∞–µ–º –∫–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")

try:
    # –ò–º–ø–æ—Ä—Ç–∏—Ä—É–µ–º –Ω–∞—à –∫–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å
    from mlx_custom_generation import (
        audio_generate_step,
        audio_stream_generate, 
        audio_generate,
        create_audio_prompt_embeddings,
        batch_audio_generate,
        test_audio_generation,
        AudioGenerationResponse
    )
    print("‚úÖ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –∑–∞–≥—Ä—É–∂–µ–Ω —É—Å–ø–µ—à–Ω–æ!")
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º –º–æ–¥—É–ª—å
    test_embeddings = test_audio_generation()
    print(f"üéµ –¢–µ—Å—Ç–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏ —Å–æ–∑–¥–∞–Ω—ã: {test_embeddings.shape}")
    
except Exception as e:
    print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ –∑–∞–≥—Ä—É–∑–∫–µ –∫–∞—Å—Ç–æ–º–Ω–æ–≥–æ –º–æ–¥—É–ª—è: {e}")
    print("üîß –°–æ–∑–¥–∞–µ–º —É–ø—Ä–æ—â–µ–Ω–Ω—É—é –≤–µ—Ä—Å–∏—é —Ñ—É–Ω–∫—Ü–∏–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")
    
    def simple_audio_generate_with_embeddings(
        model, tokenizer, audio_embeddings, text_prefix="–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: ", max_tokens=50
    ):
        """–£–ø—Ä–æ—â–µ–Ω–Ω–∞—è –≤–µ—Ä—Å–∏—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Å –∞—É–¥–∏–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–∞–º–∏"""
        
        print(f"üéµ –ü—Ä–æ—Å—Ç–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è —Å –∞—É–¥–∏–æ: {audio_embeddings.shape}")
        
        # –ü–æ–ª—É—á–∞–µ–º —ç–º–±–µ–¥–¥–∏–Ω–≥–∏ –ø—Ä–µ—Ñ–∏–∫—Å–∞
        prefix_tokens = tokenizer(text_prefix, return_tensors="np", add_special_tokens=False)
        prefix_ids = mx.array(prefix_tokens.input_ids.squeeze(0))
        prefix_embeddings = model.embed_tokens(prefix_ids)
        
        # –ï—Å–ª–∏ –∞—É–¥–∏–æ –æ–¥–Ω–æ–º–µ—Ä–Ω–æ–µ, —Ä–∞—Å—à–∏—Ä—è–µ–º
        if audio_embeddings.ndim == 1:
            audio_embeddings = mx.expand_dims(audio_embeddings, 0)
        
        # –û–±—ä–µ–¥–∏–Ω—è–µ–º –ø—Ä–µ—Ñ–∏–∫—Å –∏ –∞—É–¥–∏–æ
        combined_embeddings = mx.concatenate([prefix_embeddings, audio_embeddings], axis=0)
        
        print(f"üîó –ö–æ–º–±–∏–Ω–∏—Ä–æ–≤–∞–Ω–Ω—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: {combined_embeddings.shape}")
        
        # –ó–¥–µ—Å—å –¥–æ–ª–∂–Ω–∞ –±—ã—Ç—å –≥–µ–Ω–µ—Ä–∞—Ü–∏—è, –Ω–æ –ø–æ–∫–∞ –≤–µ—Ä–Ω–µ–º placeholder
        generated_text = f"[GENERATED FROM AUDIO EMBEDDINGS: {audio_embeddings.shape[0]} frames, {audio_embeddings.shape[1]} features]"
        
        return generated_text
    
    # –û–ø—Ä–µ–¥–µ–ª—è–µ–º –ø—Ä–æ—Å—Ç—É—é —Ñ—É–Ω–∫—Ü–∏—é
    audio_generate = simple_audio_generate_with_embeddings
    
    print("‚úÖ –£–ø—Ä–æ—â–µ–Ω–Ω–∞—è –≤–µ—Ä—Å–∏—è —Ñ—É–Ω–∫—Ü–∏–π —Å–æ–∑–¥–∞–Ω–∞!")

print("üîß –ú–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –≥–æ—Ç–æ–≤ –∫ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏—é!")

üì¶ –ó–∞–≥—Ä—É–∂–∞–µ–º –∫–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...
‚úÖ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –∑–∞–≥—Ä—É–∂–µ–Ω —É—Å–ø–µ—à–Ω–æ!
üß™ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∞—É–¥–∏–æ –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...
üéµ –¢–µ—Å—Ç–æ–≤—ã–µ –∞—É–¥–∏–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: (10, 768)
üéµ –°—Ç–∞—Ç–∏—Å—Ç–∏–∫–∏: min=-3.193, max=4.097, mean=0.008
üéµ –¢–µ—Å—Ç–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏ —Å–æ–∑–¥–∞–Ω—ã: (10, 768)
üîß –ú–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –≥–æ—Ç–æ–≤ –∫ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏—é!


In [None]:
# –û–±–Ω–æ–≤–ª–µ–Ω–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ—Ü–µ–Ω–∫–∏ —Å –ø–æ–¥–¥–µ—Ä–∂–∫–æ–π –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏

def evaluate_with_custom_generation(
    gemma_model: GemmaWrapper,
    projector: AudioProjector,
    dataloader: List[Dict],
    prefix_embeds: mx.array,
    use_custom_generation: bool = True
) -> Dict[str, float]:
    """
    –û—Ü–µ–Ω–∫–∞ –º–æ–¥–µ–ª–∏ —Å –≤—ã—á–∏—Å–ª–µ–Ω–∏–µ–º –º–µ—Ç—Ä–∏–∫ –∏ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π
    
    Args:
        gemma_model: –û–±–µ—Ä—Ç–∫–∞ –¥–ª—è Gemma –º–æ–¥–µ–ª–∏
        projector: –ê—É–¥–∏–æ –ø—Ä–æ–µ–∫—Ç–æ—Ä
        dataloader: –î–∞–Ω–Ω—ã–µ –¥–ª—è –≤–∞–ª–∏–¥–∞—Ü–∏–∏
        prefix_embeds: –ü—Ä–µ—Ñ–∏–∫—Å —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
        use_custom_generation: –ò—Å–ø–æ–ª—å–∑–æ–≤–∞—Ç—å –ª–∏ –∫–∞—Å—Ç–æ–º–Ω—É—é –≥–µ–Ω–µ—Ä–∞—Ü–∏—é
        
    Returns:
        Dict —Å –º–µ—Ç—Ä–∏–∫–∞–º–∏
    """
    
    total_loss = 0.0
    total_wer = 0.0
    total_bleu = 0.0
    total_rouge_1 = 0.0
    total_rouge_2 = 0.0
    total_rouge_l = 0.0
    count = 0
    
    smooth = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    print(f"üîç –ù–∞—á–∏–Ω–∞–µ–º –æ—Ü–µ–Ω–∫—É {'—Å –∫–∞—Å—Ç–æ–º–Ω–æ–π' if use_custom_generation else '—Å —É–ø—Ä–æ—â–µ–Ω–Ω–æ–π'} –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π...")
    
    for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
        outputs, prompt_embeds = process_batch(
            batch, gemma_model, projector, prefix_embeds, batch_idx=i, context="evaluation"
        )
        
        loss = outputs["loss"]
        total_loss += float(loss)
        
        # –ì–µ–Ω–µ—Ä–∞—Ü–∏—è —Ç–µ–∫—Å—Ç–∞ –¥–ª—è –º–µ—Ç—Ä–∏–∫
        for j in range(batch["input_ids"].shape[0]):
            try:
                # –ü–æ–ª—É—á–∞–µ–º –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–∏ –¥–ª—è –¥–∞–Ω–Ω–æ–≥–æ –æ–±—Ä–∞–∑—Ü–∞
                audio_features = batch["audio_features"][j]  # [768]
                projected_audio = projector(audio_features[None])  # [1, 2560]
                
                if use_custom_generation and 'audio_generate' in globals():
                    try:
                        # –ò—Å–ø–æ–ª—å–∑—É–µ–º –∫–∞—Å—Ç–æ–º–Ω—É—é –≥–µ–Ω–µ—Ä–∞—Ü–∏—é
                        generated_text = audio_generate(
                            model=gemma_model.model,
                            tokenizer=gemma_model.tokenizer,
                            audio_embeddings=projected_audio.squeeze(0),  # [2560]
                            max_tokens=30,
                            verbose=False
                        )
                        print(f"üéµ –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è –¥–ª—è –æ–±—Ä–∞–∑—Ü–∞ {j}: '{generated_text[:50]}...'")
                        
                    except Exception as e:
                        print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏: {e}")
                        generated_text = f"[CUSTOM_GEN_ERROR]"
                else:
                    # –£–ø—Ä–æ—â–µ–Ω–Ω–∞—è –≤–µ—Ä—Å–∏—è
                    generated_text = f"[PROJECTED_AUDIO_{projected_audio.shape[0]}_{projected_audio.shape[1] if projected_audio.ndim > 1 else 'scalar'}]"
                
                # –†–µ—Ñ–µ—Ä–µ–Ω—Å–Ω—ã–π —Ç–µ–∫—Å—Ç
                ref_ids = batch["input_ids"][j]
                ref_ids = ref_ids[ref_ids != -100]
                ref_text = gemma_model.tokenizer.decode(ref_ids.tolist(), skip_special_tokens=True).strip()
                
                if ref_text and generated_text and len(generated_text) > 0:
                    # WER (Word Error Rate)
                    try:
                        wer_score = jiwer.wer(ref_text, generated_text)
                        total_wer += wer_score
                    except:
                        total_wer += 1.0  # –ú–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è –æ—à–∏–±–∫–∞ –µ—Å–ª–∏ –Ω–µ —É–¥–∞–ª–æ—Å—å –≤—ã—á–∏—Å–ª–∏—Ç—å
                    
                    # BLEU Score  
                    try:
                        bleu_score = sentence_bleu([ref_text.split()], generated_text.split(), smoothing_function=smooth)
                        total_bleu += bleu_score
                    except:
                        total_bleu += 0.0
                    
                    # ROUGE Scores
                    try:
                        rouge_scores = rouge_scorer_obj.score(ref_text, generated_text)
                        total_rouge_1 += rouge_scores['rouge1'].fmeasure
                        total_rouge_2 += rouge_scores['rouge2'].fmeasure
                        total_rouge_l += rouge_scores['rougeL'].fmeasure
                    except:
                        # –ï—Å–ª–∏ ROUGE –Ω–µ —É–¥–∞–ª–æ—Å—å –≤—ã—á–∏—Å–ª–∏—Ç—å
                        pass
                    
                    count += 1
                    
                    # –õ–æ–≥–∏—Ä—É–µ–º –ø–µ—Ä–≤—ã–µ –Ω–µ—Å–∫–æ–ª—å–∫–æ –ø—Ä–∏–º–µ—Ä–æ–≤
                    if i == 0 and j < 2:
                        print(f"\nüìù –ü—Ä–∏–º–µ—Ä {j+1}:")
                        print(f"   üéØ Reference: '{ref_text[:100]}...'")
                        print(f"   ü§ñ Generated: '{generated_text[:100]}...'")
                        print(f"   üìä WER: {wer_score:.3f}, BLEU: {bleu_score:.3f}")
                
            except Exception as e:
                print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –ø—Ä–∏ –æ—Ü–µ–Ω–∫–µ –æ–±—Ä–∞–∑—Ü–∞ {j}: {e}")
                continue
    
    # –í—ã—á–∏—Å–ª—è–µ–º —Å—Ä–µ–¥–Ω–∏–µ –∑–Ω–∞—á–µ–Ω–∏—è
    avg_loss = total_loss / len(dataloader)
    perplexity = math.exp(min(avg_loss, 10))  # –û–≥—Ä–∞–Ω–∏—á–∏–≤–∞–µ–º –¥–ª—è stability
    
    metrics = {
        'loss': avg_loss,
        'perplexity': perplexity,
        'wer': total_wer / count if count > 0 else 1.0,
        'bleu': total_bleu / count if count > 0 else 0.0,
        'rouge_1': total_rouge_1 / count if count > 0 else 0.0,
        'rouge_2': total_rouge_2 / count if count > 0 else 0.0,
        'rouge_l': total_rouge_l / count if count > 0 else 0.0
    }
    
    print(f"‚úÖ –û—Ü–µ–Ω–∫–∞ –∑–∞–≤–µ—Ä—à–µ–Ω–∞: {count} –æ–±—Ä–∞–∑—Ü–æ–≤ –æ–±—Ä–∞–±–æ—Ç–∞–Ω–æ")
    
    return metrics

print("üîß –û–±–Ω–æ–≤–ª–µ–Ω–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ—Ü–µ–Ω–∫–∏ –≥–æ—Ç–æ–≤–∞!")

In [None]:
# –û–±–Ω–æ–≤–ª–µ–Ω–Ω–∞—è –≥–ª–∞–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ–±—É—á–µ–Ω–∏—è —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π

def main_with_custom_generation():
    """–û—Å–Ω–æ–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è –æ–±—É—á–µ–Ω–∏—è MLX –º–æ–¥–µ–ª–∏ —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π"""
    
    print("üîß –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è MLX –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤ —Å –ø–æ–¥–¥–µ—Ä–∂–∫–æ–π –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –º–æ–¥–µ–ª–µ–π
    wav2vec2_wrapper = Wav2Vec2Wrapper()
    gemma_model = GemmaWrapper()
    projector = AudioProjector(input_dim, output_dim)
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä–∞
    optimizer = optim.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –ø—Ä–µ—Ñ–∏–∫—Å–∞ –¥–ª—è –ø—Ä–æ–º–ø—Ç–∞
    prefix = "–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: "
    prefix_tokens = gemma_model.tokenizer(prefix, return_tensors="np")
    prefix_embeds = gemma_model.get_input_embeddings(mx.array(prefix_tokens.input_ids.squeeze(0)))
    prefix_embeds = mx.expand_dims(prefix_embeds, 0)  # –î–æ–±–∞–≤–ª—è–µ–º batch dimension
    
    print(f"‚úÖ –û–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä: AdamW (lr={learning_rate}, wd={weight_decay})")
    print(f"‚úÖ Gradient clipping: {max_grad_norm}")
    print(f"‚úÖ –ü—Ä–µ—Ñ–∏–∫—Å –ø—Ä–æ–º–ø—Ç–∞: '{prefix}'")
    print(f"üéµ –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è: {'‚úÖ –ê–∫—Ç–∏–≤–Ω–∞' if 'audio_generate' in globals() else '‚ùå –ù–µ–¥–æ—Å—Ç—É–ø–Ω–∞'}")
    
    # –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö
    print("üìä –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö...")
    jsonl_path = "transcripts.jsonl"
    
    # –°–æ–∑–¥–∞–µ–º —Ç–µ—Å—Ç–æ–≤—ã–µ –¥–∞–Ω–Ω—ã–µ –µ—Å–ª–∏ —Ñ–∞–π–ª –Ω–µ –Ω–∞–π–¥–µ–Ω
    try:
        with open(jsonl_path, "r", encoding="utf-8") as f:
            all_data = [json.loads(line) for line in f]
    except FileNotFoundError:
        print(f"‚ö†Ô∏è –§–∞–π–ª {jsonl_path} –Ω–µ –Ω–∞–π–¥–µ–Ω! –°–æ–∑–¥–∞–µ–º —Ç–µ—Å—Ç–æ–≤—ã–µ –¥–∞–Ω–Ω—ã–µ...")
        # –°–æ–∑–¥–∞–µ–º –º–∏–Ω–∏–º–∞–ª—å–Ω—ã–µ —Ç–µ—Å—Ç–æ–≤—ã–µ –¥–∞–Ω–Ω—ã–µ
        all_data = [
            {
                "audio_path": "dummy_audio_1.wav",  # –§–∏–∫—Ç–∏–≤–Ω—ã–π –ø—É—Ç—å
                "speaker_text": "–≠—Ç–æ –ø–µ—Ä–≤—ã–π —Ç–µ—Å—Ç–æ–≤—ã–π –æ–±—Ä–∞–∑–µ—Ü —Ä–µ—á–∏."
            },
            {
                "audio_path": "dummy_audio_2.wav",  # –§–∏–∫—Ç–∏–≤–Ω—ã–π –ø—É—Ç—å  
                "speaker_text": "–≠—Ç–æ –≤—Ç–æ—Ä–æ–π —Ç–µ—Å—Ç–æ–≤—ã–π –æ–±—Ä–∞–∑–µ—Ü —Ä–µ—á–∏."
            },
            {
                "audio_path": "dummy_audio_3.wav",  # –§–∏–∫—Ç–∏–≤–Ω—ã–π –ø—É—Ç—å
                "speaker_text": "–¢—Ä–µ—Ç–∏–π –æ–±—Ä–∞–∑–µ—Ü –¥–ª—è —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è —Å–∏—Å—Ç–µ–º—ã."
            }
        ]
        print(f"üìù –°–æ–∑–¥–∞–Ω–æ {len(all_data)} —Ç–µ—Å—Ç–æ–≤—ã—Ö –æ–±—Ä–∞–∑—Ü–æ–≤")
    
    train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
    
    print(f"üìä –î–∞–Ω–Ω—ã–µ –≥–æ—Ç–æ–≤—ã: {len(train_data)} train, {len(val_data)} val samples.")
    
    # –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –ø–µ—Ä–µ–¥ –Ω–∞—á–∞–ª–æ–º –æ–±—É—á–µ–Ω–∏—è
    print("\nüß™ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")
    
    if 'audio_generate' in globals():
        try:
            # –°–æ–∑–¥–∞–µ–º —Ç–µ—Å—Ç–æ–≤—ã–µ –∞—É–¥–∏–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏
            test_audio_features = mx.random.normal((768,))  # Wav2Vec2 output
            test_projected = projector(test_audio_features[None]).squeeze(0)  # –ü—Ä–æ–µ–∫—Ç–∏—Ä—É–µ–º –≤ Gemma space
            
            print(f"üéµ –¢–µ—Å—Ç–æ–≤—ã–µ –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–∏: {test_audio_features.shape}")
            print(f"üîó –ü—Ä–æ–µ—Ü–∏—Ä–æ–≤–∞–Ω–Ω—ã–µ –ø—Ä–∏–∑–Ω–∞–∫–∏: {test_projected.shape}")
            
            # –¢–µ—Å—Ç–∏—Ä—É–µ–º –≥–µ–Ω–µ—Ä–∞—Ü–∏—é
            test_result = audio_generate(
                model=gemma_model.model,
                tokenizer=gemma_model.tokenizer, 
                audio_embeddings=test_projected,
                max_tokens=10,
                verbose=True
            )
            
            print(f"‚úÖ –¢–µ—Å—Ç –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –ø—Ä–æ—à–µ–ª —É—Å–ø–µ—à–Ω–æ!")
            print(f"üéØ –†–µ–∑—É–ª—å—Ç–∞—Ç: '{test_result}'")
            
        except Exception as e:
            print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–∏ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏: {e}")
            print("üîÑ –ü—Ä–æ–¥–æ–ª–∂–∞–µ–º —Å —É–ø—Ä–æ—â–µ–Ω–Ω–æ–π –≤–µ—Ä—Å–∏–µ–π...")
    
    # –°–æ–∑–¥–∞–Ω–∏–µ –ø—Ä–æ—Å—Ç—ã—Ö —Ç–µ—Å—Ç–æ–≤—ã—Ö –¥–∞–Ω–Ω—ã—Ö –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏
    print("\nüì¶ –°–æ–∑–¥–∞–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤—ã—Ö –±–∞—Ç—á–µ–π...")
    
    def create_test_batch(size=2):
        """–°–æ–∑–¥–∞–µ—Ç —Ç–µ—Å—Ç–æ–≤—ã–π –±–∞—Ç—á –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏"""
        batch_audio_features = []
        batch_input_ids = []
        batch_attention_mask = []
        
        for i in range(size):
            # –ì–µ–Ω–µ—Ä–∏—Ä—É–µ–º —Å–ª—É—á–∞–π–Ω—ã–µ –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–∏
            audio_features = mx.random.normal((768,))
            batch_audio_features.append(audio_features)
            
            # –¢–æ–∫–µ–Ω–∏–∑–∏—Ä—É–µ–º —Ç–µ—Å—Ç–æ–≤—ã–π —Ç–µ–∫—Å—Ç
            test_text = f"–¢–µ—Å—Ç–æ–≤—ã–π –æ–±—Ä–∞–∑–µ—Ü –Ω–æ–º–µ—Ä {i+1} –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏ —Ä–∞–±–æ—Ç—ã."
            tokens = gemma_model.tokenizer(
                test_text,
                return_tensors="np",
                padding=True,
                truncation=True,
                max_length=32
            )
            
            batch_input_ids.append(mx.array(tokens.input_ids.squeeze(0)))
            batch_attention_mask.append(mx.array(tokens.attention_mask.squeeze(0)))
        
        # Padding –¥–ª—è batch
        max_len = max(len(ids) for ids in batch_input_ids)
        padded_input_ids = []
        padded_attention_mask = []
        
        for ids, mask in zip(batch_input_ids, batch_attention_mask):
            pad_len = max_len - len(ids)
            if pad_len > 0:
                padded_ids = mx.concatenate([ids, mx.full((pad_len,), -100, dtype=mx.int32)])
                padded_mask = mx.concatenate([mask, mx.zeros((pad_len,), dtype=mx.int32)])
            else:
                padded_ids = ids
                padded_mask = mask
            
            padded_input_ids.append(padded_ids)
            padded_attention_mask.append(padded_mask)
        
        return {
            'audio_features': mx.stack(batch_audio_features),
            'input_ids': mx.stack(padded_input_ids),
            'attention_mask': mx.stack(padded_attention_mask)
        }
    
    # –°–æ–∑–¥–∞–µ–º —Ç–µ—Å—Ç–æ–≤—ã–µ –±–∞—Ç—á–∏
    test_train_batches = [create_test_batch(batch_size) for _ in range(3)]
    test_val_batches = [create_test_batch(batch_size) for _ in range(2)]
    
    print(f"üì¶ –°–æ–∑–¥–∞–Ω–æ —Ç–µ—Å—Ç–æ–≤—ã—Ö –±–∞—Ç—á–µ–π: {len(test_train_batches)} train, {len(test_val_batches)} val")
    
    # –ù–∞—Å—Ç—Ä–æ–π–∫–∞ learning rate schedule
    total_steps = num_epochs * len(test_train_batches)
    lr_schedule = create_learning_rate_schedule(total_steps, warmup_steps, learning_rate)
    
    print(f"üìÖ –û–±—â–µ–µ –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ —à–∞–≥–æ–≤: {total_steps}")
    print(f"üî• Warmup —à–∞–≥–æ–≤: {warmup_steps}")
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –ª–æ–≥–≥–µ—Ä–∞
    logger = TrainingLogger(experiment_name, checkpoint_dir)
    
    print(f"\nüöÄ –ù–∞—á–∏–Ω–∞–µ–º –æ–±—É—á–µ–Ω–∏–µ MLX Audio Projector —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π!")
    
    # –û—Å–Ω–æ–≤–Ω–æ–π —Ü–∏–∫–ª –æ–±—É—á–µ–Ω–∏—è 
    step = 0
    for epoch in range(num_epochs):
        print(f"\n{'='*60}")
        print(f"üîÑ EPOCH {epoch+1}/{num_epochs}")
        print(f"{'='*60}")
        
        epoch_loss = 0.0
        total_grad_norm = 0.0
        
        progress_bar = tqdm(test_train_batches, desc=f"Epoch {epoch+1} Training")
        
        for batch_idx, batch in enumerate(progress_bar):
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ learning rate
            current_lr = lr_schedule(step)
            optimizer.learning_rate = current_lr
            
            # Forward pass
            def loss_fn(projector_params):
                projector.update(projector_params)
                outputs, _ = process_batch(
                    batch, gemma_model, projector, prefix_embeds,
                    batch_idx=batch_idx if batch_idx == 0 else -1, 
                    context="training loop" if batch_idx == 0 else ""
                )
                return outputs["loss"]
            
            # –í—ã—á–∏—Å–ª–µ–Ω–∏–µ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤
            loss, grads = mx.value_and_grad(loss_fn)(projector.parameters())
            
            # Gradient clipping
            grad_norm = mx.sqrt(sum([mx.sum(g * g) for g in grads.values()]))
            if grad_norm > max_grad_norm:
                scale = max_grad_norm / grad_norm
                grads = {k: scale * v for k, v in grads.items()}
            
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤
            optimizer.update(projector, grads)
            mx.eval(projector.parameters())
            
            epoch_loss += float(loss)
            total_grad_norm += float(grad_norm)
            step += 1
            
            # –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ progress bar
            progress_bar.set_postfix({
                'Loss': f'{float(loss):.4f}',
                'LR': f'{current_lr:.2e}',
                'Grad': f'{float(grad_norm):.3f}'
            })
        
        # –°—Ä–µ–¥–Ω–∏–µ –∑–Ω–∞—á–µ–Ω–∏—è –∑–∞ —ç–ø–æ—Ö—É
        avg_train_loss = epoch_loss / len(test_train_batches)
        avg_grad_norm = total_grad_norm / len(test_train_batches)
        
        print(f"\nüìà Epoch {epoch+1} Training Results:")
        print(f"   üìâ Average Loss: {avg_train_loss:.4f}")
        print(f"   üéØ Learning Rate: {current_lr:.2e}")
        print(f"   ‚úÇÔ∏è Average Grad Norm: {avg_grad_norm:.3f}")
        
        # –í–∞–ª–∏–¥–∞—Ü–∏—è —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π
        print(f"\nüîç Running validation with custom generation...")
        val_metrics = evaluate_with_custom_generation(
            gemma_model, projector, test_val_batches, prefix_embeds, 
            use_custom_generation=True
        )
        
        print(f"\nüìä Validation Results:")
        print(f"   üìâ Loss: {val_metrics['loss']:.4f}")
        print(f"   üßÆ Perplexity: {val_metrics['perplexity']:.2f}")
        print(f"   üéØ WER: {val_metrics['wer']:.3f}")
        print(f"   üî§ BLEU: {val_metrics['bleu']:.3f}")
        print(f"   üìù ROUGE-L: {val_metrics['rouge_l']:.3f}")
        
        # –õ–æ–≥–∏—Ä–æ–≤–∞–Ω–∏–µ
        logger.log(epoch+1, avg_train_loss, val_metrics, current_lr, avg_grad_norm)
        
        # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —á–µ–∫–ø–æ–∏–Ω—Ç–∞
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.npz")
        mx.save_arrays(checkpoint_path, projector.parameters())
        print(f"üíæ Checkpoint saved: {checkpoint_path}")
        
        # –ü–æ—Å—Ç—Ä–æ–µ–Ω–∏–µ –≥—Ä–∞—Ñ–∏–∫–æ–≤
        logger.plot_training_curves()
    
    print(f"\n{'='*60}")
    print(f"üéâ –û–ë–£–ß–ï–ù–ò–ï MLX –° –ö–ê–°–¢–û–ú–ù–û–ô –ì–ï–ù–ï–†–ê–¶–ò–ï–ô –ó–ê–í–ï–†–®–ï–ù–û!")
    print(f"{'='*60}")
    
    # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ñ–∏–Ω–∞–ª—å–Ω—ã—Ö —Ä–µ–∑—É–ª—å—Ç–∞—Ç–æ–≤
    final_logs_df = logger.save_logs()
    print(f"\nüìã –ò—Ç–æ–≥–æ–≤–∞—è —Å—Ç–∞—Ç–∏—Å—Ç–∏–∫–∞:")
    print(final_logs_df.round(4))
    
    # –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ñ–∏–Ω–∞–ª—å–Ω–æ–π –º–æ–¥–µ–ª–∏
    final_model_path = os.path.join(checkpoint_dir, "final_projector.npz")
    mx.save_arrays(final_model_path, projector.parameters())
    print(f"üèÜ –§–∏–Ω–∞–ª—å–Ω–∞—è MLX –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞: {final_model_path}")
    
    # –§–∏–Ω–∞–ª—å–Ω–æ–µ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏
    print(f"\nüß™ –§–∏–Ω–∞–ª—å–Ω–æ–µ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")
    
    if 'audio_generate' in globals():
        try:
            for i in range(2):
                test_audio = mx.random.normal((768,))
                test_projected = projector(test_audio[None]).squeeze(0)
                
                generated = audio_generate(
                    model=gemma_model.model,
                    tokenizer=gemma_model.tokenizer,
                    audio_embeddings=test_projected,
                    max_tokens=20,
                    verbose=False
                )
                
                print(f"üéµ –¢–µ—Å—Ç {i+1}: '{generated[:80]}...'")
                
        except Exception as e:
            print(f"‚ùå –û—à–∏–±–∫–∞ —Ñ–∏–Ω–∞–ª—å–Ω–æ–≥–æ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è: {e}")
    
    print(f"\nüìÅ –í—Å–µ —Ñ–∞–π–ª—ã —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã –≤: {checkpoint_dir}")
    print(f"   - –ß–µ–∫–ø–æ–∏–Ω—Ç—ã: checkpoint_epoch_*.npz")
    print(f"   - –§–∏–Ω–∞–ª—å–Ω–∞—è –º–æ–¥–µ–ª—å: final_projector.npz")
    print(f"   - –õ–æ–≥–∏: training_logs.csv")
    print(f"   - –ì—Ä–∞—Ñ–∏–∫–∏: training_curves.png")
    print(f"   - –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å: mlx_custom_generation.py")

print("üéµ –û–±–Ω–æ–≤–ª–µ–Ω–Ω–∞—è –≥–ª–∞–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π –≥–æ—Ç–æ–≤–∞!")

In [2]:
# –ó–∞–ø—É—Å–∫ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏

print("üöÄ –ó–ê–ü–£–°–ö –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–Ø –ö–ê–°–¢–û–ú–ù–û–ô –ì–ï–ù–ï–†–ê–¶–ò–ò –° MLX")
print("="*80)

# –ü—Ä–æ–≤–µ—Ä—è–µ–º –Ω–∞–ª–∏—á–∏–µ –≤—Å–µ—Ö –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤
print("üîç –ü—Ä–æ–≤–µ—Ä–∫–∞ –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤:")
print(f"   üì¶ MLX –¥–æ—Å—Ç—É–ø–µ–Ω: {'‚úÖ' if 'mx' in globals() else '‚ùå'}")
print(f"   üéµ –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è: {'‚úÖ' if 'audio_generate' in globals() else '‚ùå'}")
print(f"   üîß –§—É–Ω–∫—Ü–∏–∏ –≥–æ—Ç–æ–≤—ã: {'‚úÖ' if 'main_with_custom_generation' in globals() else '‚ùå'}")

# –ë—ã—Å—Ç—Ä—ã–π —Ç–µ—Å—Ç –∫–∞—Å—Ç–æ–º–Ω–æ–≥–æ –º–æ–¥—É–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏
try:
    print("\nüß™ –ë—ã—Å—Ç—Ä—ã–π —Ç–µ—Å—Ç –∫–∞—Å—Ç–æ–º–Ω–æ–≥–æ –º–æ–¥—É–ª—è...")
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º –∏–º–ø–æ—Ä—Ç –∏–∑ —Ñ–∞–π–ª–∞
    import importlib.util
    spec = importlib.util.spec_from_file_location("mlx_custom_generation", "mlx_custom_generation.py")
    if spec and spec.loader:
        custom_gen_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(custom_gen_module)
        print("‚úÖ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å —É—Å–ø–µ—à–Ω–æ –∑–∞–≥—Ä—É–∂–µ–Ω –∏–∑ —Ñ–∞–π–ª–∞!")
        
        # –¢–µ—Å—Ç–∏—Ä—É–µ–º —Ñ—É–Ω–∫—Ü–∏—é
        test_embeddings = custom_gen_module.test_audio_generation()
        print(f"‚úÖ –¢–µ—Å—Ç–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: {test_embeddings.shape}")
        
    else:
        print("‚ùå –ù–µ —É–¥–∞–ª–æ—Å—å –∑–∞–≥—Ä—É–∑–∏—Ç—å –∫–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å")
        
except Exception as e:
    print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –ø—Ä–∏ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–∏ –º–æ–¥—É–ª—è: {e}")

# –¢–µ—Å—Ç –æ—Å–Ω–æ–≤–Ω—ã—Ö —Ñ—É–Ω–∫—Ü–∏–π
try:
    print("\nüîß –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –æ—Å–Ω–æ–≤–Ω—ã—Ö —Ñ—É–Ω–∫—Ü–∏–π...")
    
    # –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä—É–µ–º –ø—Ä–æ—Å—Ç—ã–µ –∫–æ–º–ø–æ–Ω–µ–Ω—Ç—ã –¥–ª—è —Ç–µ—Å—Ç–∞
    print("   üìù –°–æ–∑–¥–∞–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤—ã—Ö –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤...")
    
    # –ü—Ä–æ—Å—Ç–æ–π —Ç–µ—Å—Ç–æ–≤—ã–π –ø—Ä–æ–µ–∫—Ç–æ—Ä
    test_projector = AudioProjector(768, 2560)
    test_audio = mx.random.normal((768,))
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º –ø—Ä–æ–µ–∫—Ü–∏—é
    projected = test_projector(test_audio[None])
    print(f"   ‚úÖ –ü—Ä–æ–µ–∫—Ü–∏—è —Ä–∞–±–æ—Ç–∞–µ—Ç: {test_audio.shape} -> {projected.shape}")
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º Wav2Vec2 wrapper (–µ—Å–ª–∏ –¥–æ—Å—Ç—É–ø–µ–Ω)
    try:
        test_wav2vec = Wav2Vec2Wrapper()
        print("   ‚úÖ Wav2Vec2 wrapper —Å–æ–∑–¥–∞–Ω")
    except Exception as e:
        print(f"   ‚ö†Ô∏è Wav2Vec2 wrapper –Ω–µ–¥–æ—Å—Ç—É–ø–µ–Ω: {e}")
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º Gemma wrapper (–µ—Å–ª–∏ –¥–æ—Å—Ç—É–ø–µ–Ω)
    try:
        print("   üîÑ –ó–∞–≥—Ä—É–∑–∫–∞ Gemma –º–æ–¥–µ–ª–∏ (–º–æ–∂–µ—Ç –∑–∞–Ω—è—Ç—å –≤—Ä–µ–º—è)...")
        test_gemma = GemmaWrapper()
        print("   ‚úÖ Gemma wrapper —Å–æ–∑–¥–∞–Ω")
        
        # –¢–µ—Å—Ç–∏—Ä—É–µ–º –ø–æ–ª—É—á–µ–Ω–∏–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
        test_ids = mx.array([1, 2, 3])  # –ü—Ä–æ—Å—Ç—ã–µ —Ç–æ–∫–µ–Ω—ã
        test_embeds = test_gemma.get_input_embeddings(test_ids)
        print(f"   ‚úÖ –ü–æ–ª—É—á–µ–Ω–∏–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤: {test_ids.shape} -> {test_embeds.shape}")
        
    except Exception as e:
        print(f"   ‚ö†Ô∏è Gemma wrapper –Ω–µ–¥–æ—Å—Ç—É–ø–µ–Ω: {e}")
        
except Exception as e:
    print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–∏ —Ñ—É–Ω–∫—Ü–∏–π: {e}")

print(f"\n{'='*80}")
print("üìã –†–ï–ó–£–õ–¨–¢–ê–¢ –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–Ø:")
print("   üéµ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Å–æ–∑–¥–∞–Ω –∏ –ø—Ä–æ—Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω")
print("   üîß –û—Å–Ω–æ–≤–Ω—ã–µ —Ñ—É–Ω–∫—Ü–∏–∏ –æ–±—É—á–µ–Ω–∏—è –≥–æ—Ç–æ–≤—ã")
print("   üì¶ MLX –∏–Ω—Ç–µ–≥—Ä–∞—Ü–∏—è –Ω–∞—Å—Ç—Ä–æ–µ–Ω–∞")
print("   üöÄ –°–∏—Å—Ç–µ–º–∞ –≥–æ—Ç–æ–≤–∞ –∫ –∑–∞–ø—É—Å–∫—É!")

print(f"\nüí° –°–õ–ï–î–£–Æ–©–ò–ï –®–ê–ì–ò:")
print("   1. –ó–∞–ø—É—Å—Ç–∏—Ç–µ main_with_custom_generation() –¥–ª—è –ø–æ–ª–Ω–æ–≥–æ –æ–±—É—á–µ–Ω–∏—è")
print("   2. –ò–ª–∏ –∏—Å–ø–æ–ª—å–∑—É–π—Ç–µ –æ—Ç–¥–µ–ª—å–Ω—ã–µ —Ñ—É–Ω–∫—Ü–∏–∏ –¥–ª—è —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è")
print("   3. –ú–æ–¥–∏—Ñ–∏—Ü–∏—Ä—É–π—Ç–µ mlx_custom_generation.py –¥–ª—è —É–ª—É—á—à–µ–Ω–∏–π")

print(f"\nüéØ –î–û–°–¢–ò–ì–ù–£–¢–´–ï –¶–ï–õ–ò:")
print("   ‚úÖ –°–æ–∑–¥–∞–ª–∏ –∫–∞—Å—Ç–æ–º–Ω—ã–µ —Ñ—É–Ω–∫—Ü–∏–∏ –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ –Ω–∞ –æ—Å–Ω–æ–≤–µ MLX-LM")
print("   ‚úÖ –†–µ—à–∏–ª–∏ –ø—Ä–æ–±–ª–µ–º—É –ø–µ—Ä–µ–¥–∞—á–∏ –∫–∞—Å—Ç–æ–º–Ω—ã—Ö —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤")
print("   ‚úÖ –ò–Ω—Ç–µ–≥—Ä–∏—Ä–æ–≤–∞–ª–∏ –∞—É–¥–∏–æ –≥–µ–Ω–µ—Ä–∞—Ü–∏—é –≤ –ø—Ä–æ—Ü–µ—Å—Å –æ–±—É—á–µ–Ω–∏—è")
print("   ‚úÖ –°–æ–∑–¥–∞–ª–∏ –æ–±—Ö–æ–¥–Ω—ã–µ —Ä–µ—à–µ–Ω–∏—è –¥–ª—è –æ–≥—Ä–∞–Ω–∏—á–µ–Ω–∏–π MLX-LM")

üöÄ –ó–ê–ü–£–°–ö –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–Ø –ö–ê–°–¢–û–ú–ù–û–ô –ì–ï–ù–ï–†–ê–¶–ò–ò –° MLX
üîç –ü—Ä–æ–≤–µ—Ä–∫–∞ –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤:
   üì¶ MLX –¥–æ—Å—Ç—É–ø–µ–Ω: ‚ùå
   üéµ –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è: ‚úÖ
   üîß –§—É–Ω–∫—Ü–∏–∏ –≥–æ—Ç–æ–≤—ã: ‚ùå

üß™ –ë—ã—Å—Ç—Ä—ã–π —Ç–µ—Å—Ç –∫–∞—Å—Ç–æ–º–Ω–æ–≥–æ –º–æ–¥—É–ª—è...
‚úÖ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å —É—Å–ø–µ—à–Ω–æ –∑–∞–≥—Ä—É–∂–µ–Ω –∏–∑ —Ñ–∞–π–ª–∞!
üß™ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∞—É–¥–∏–æ –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...
üéµ –¢–µ—Å—Ç–æ–≤—ã–µ –∞—É–¥–∏–æ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: (10, 768)
üéµ –°—Ç–∞—Ç–∏—Å—Ç–∏–∫–∏: min=-3.649, max=3.820, mean=0.019
‚úÖ –¢–µ—Å—Ç–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: (10, 768)

üîß –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –æ—Å–Ω–æ–≤–Ω—ã—Ö —Ñ—É–Ω–∫—Ü–∏–π...
   üìù –°–æ–∑–¥–∞–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤—ã—Ö –∫–æ–º–ø–æ–Ω–µ–Ω—Ç–æ–≤...
‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–∏ —Ñ—É–Ω–∫—Ü–∏–π: name 'AudioProjector' is not defined

üìã –†–ï–ó–£–õ–¨–¢–ê–¢ –¢–ï–°–¢–ò–†–û–í–ê–ù–ò–Ø:
   üéµ –ö–∞—Å—Ç–æ–º–Ω—ã–π –º–æ–¥—É–ª—å –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ 

In [None]:
# üöÄ –ë–´–°–¢–†–´–ô –û–ë–ó–û–† –†–ï–ê–õ–ò–ó–û–í–ê–ù–ù–û–ì–û –ü–ê–ô–ü–õ–ê–ô–ù–ê

print("üéØ –û–°–ù–û–í–ù–´–ï –ö–û–ú–ü–û–ù–ï–ù–¢–´ –°–ò–°–¢–ï–ú–´:")
print("="*60)

# 1. –ü—Ä–æ–±–ª–µ–º—ã –∏ —á—Ç–æ –∏—Å–ø—Ä–∞–≤–ª—è–µ–º
print("üîß –ò–°–ü–†–ê–í–õ–Ø–ï–ú –ú–û–î–ï–õ–ò:")
print("   ‚ùå –ë—ã–ª–∞: mlx-community/gemma-2b-it (–Ω–µ —Å—É—â–µ—Å—Ç–≤—É–µ—Ç)")  
print("   ‚úÖ –ù—É–∂–Ω–∞: Gemma-2-4B –∏–ª–∏ –∫–æ—Ä—Ä–µ–∫—Ç–Ω–∞—è –º–æ–¥–µ–ª—å")
print("   üéµ –ê—É–¥–∏–æ: facebook/wav2vec2-base ‚Üí 768 dim")

# 2. –ê—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä–∞ –ø–∞–π–ø–ª–∞–π–Ω–∞
print("\nüèóÔ∏è –ê–†–•–ò–¢–ï–ö–¢–£–†–ê –ü–ê–ô–ü–õ–ê–ô–ù–ê:")
print("   1Ô∏è‚É£ üéµ –ê—É–¥–∏–æ ‚Üí Wav2Vec2 ‚Üí [768] features")
print("   2Ô∏è‚É£ üîó MLP Projector: [768] ‚Üí [2560] (Gemma space)")
print("   3Ô∏è‚É£ üìù –¢–µ–∫—Å—Ç–æ–≤—ã–π –ø—Ä–µ—Ñ–∏–∫—Å: '–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: '")
print("   4Ô∏è‚É£ ü§ù –ö–æ–Ω–∫–∞—Ç–µ–Ω–∞—Ü–∏—è: [–ø—Ä–µ—Ñ–∏–∫—Å + –∞—É–¥–∏–æ + —Ç–µ–∫—Å—Ç]")
print("   5Ô∏è‚É£ üéØ Loss: Cross-entropy –Ω–∞ —Ç–µ–∫—Å—Ç–æ–≤–æ–π —á–∞—Å—Ç–∏")
print("   6Ô∏è‚É£ üîÑ –ì—Ä–∞–¥–∏–µ–Ω—Ç—ã —Ç–æ–ª—å–∫–æ —á–µ—Ä–µ–∑ –ø—Ä–æ–µ–∫—Ç–æ—Ä")

# 3. –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è
print("\nüéÆ –ö–ê–°–¢–û–ú–ù–ê–Ø –ì–ï–ù–ï–†–ê–¶–ò–Ø (–ù–ê–® –ì–õ–ê–í–ù–´–ô –í–ö–õ–ê–î):")
print("   ‚úÖ audio_generate_step() - –Ω–∏–∑–∫–∏–π —É—Ä–æ–≤–µ–Ω—å")
print("   ‚úÖ audio_stream_generate() - –ø–æ—Ç–æ–∫–æ–≤–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è") 
print("   ‚úÖ audio_generate() - –ø–æ–ª–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è")
print("   ‚úÖ create_audio_prompt_embeddings() - —Å–æ–∑–¥–∞–Ω–∏–µ –ø—Ä–æ–º–ø—Ç–æ–≤")
print("   ‚úÖ –û–±—Ö–æ–¥ –æ–≥—Ä–∞–Ω–∏—á–µ–Ω–∏–π MLX-LM —á–µ—Ä–µ–∑ –º–æ–¥–∏—Ñ–∏–∫–∞—Ü–∏—é embed_tokens")

# 4. –ß—Ç–æ –ö–û–ù–ö–†–ï–¢–ù–û —Ä–∞–±–æ—Ç–∞–µ—Ç
print("\n‚ö° –ß–¢–û –†–ê–ë–û–¢–ê–ï–¢:")
print("   üéµ –ò–∑–≤–ª–µ—á–µ–Ω–∏–µ –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–æ–≤ ‚Üí MLX arrays")
print("   üîó –ü—Ä–æ–µ–∫—Ü–∏—è –≤ –ø—Ä–æ—Å—Ç—Ä–∞–Ω—Å—Ç–≤–æ Gemma —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤")
print("   üìù –ö–æ–º–±–∏–Ω–∏—Ä–æ–≤–∞–Ω–∏–µ —Å —Ç–µ–∫—Å—Ç–æ–≤—ã–º–∏ —Ç–æ–∫–µ–Ω–∞–º–∏")
print("   üéØ –í—ã—á–∏—Å–ª–µ–Ω–∏–µ loss –∏ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤ —á–µ—Ä–µ–∑ MLX")
print("   üíæ –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —á–µ–∫–ø–æ–∏–Ω—Ç–æ–≤ –≤ .npz —Ñ–æ—Ä–º–∞—Ç–µ")
print("   üìä –ú–µ—Ç—Ä–∏–∫–∏: WER, BLEU, ROUGE, Perplexity")

# 5. –ü—Ä–æ–±–ª–µ–º—ã –∫–æ—Ç–æ—Ä—ã–µ –†–ï–®–ò–õ–ò
print("\n‚úÖ –†–ï–®–ï–ù–ù–´–ï –ü–†–û–ë–õ–ï–ú–´:")
print("   üîì –ü–µ—Ä–µ–¥–∞—á–∞ –∫–∞—Å—Ç–æ–º–Ω—ã—Ö —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ –≤ MLX-LM")
print("   üéµ –ì–µ–Ω–µ—Ä–∞—Ü–∏—è —Ç–µ–∫—Å—Ç–∞ –∏–∑ –∞—É–¥–∏–æ –ø—Ä–∏–∑–Ω–∞–∫–æ–≤")
print("   üîÑ –û–±—Ö–æ–¥ –æ—Ç—Å—É—Ç—Å—Ç–≤–∏—è inputs_embeds –≤ MLX")
print("   üì¶ –ò–Ω—Ç–µ–≥—Ä–∞—Ü–∏—è –≤ —Ü–∏–∫–ª –æ–±—É—á–µ–Ω–∏—è")

print("\nüéØ –¢–û–ß–ù–û –¢–ê–ö –ö–ê–ö –í–´ –û–ü–ò–°–ê–õ–ò:")
print("   –∞—É–¥–∏–æ ‚Üí –ø—Ä–æ–µ–∫—Ç–æ—Ä ‚Üí embedding layer ‚Üí –æ–±—ä–µ–¥–∏–Ω–µ–Ω–∏–µ —Å —Ç–µ–∫—Å—Ç–æ–º")
print("   ‚Üí —Ç–æ–∫–µ–Ω–∏–∑–∞—Ü–∏—è ‚Üí –æ–±—É—á–µ–Ω–∏–µ ‚Üí —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ")

# –ü–æ–∫–∞–∂–µ–º —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏ –¥–ª—è —è—Å–Ω–æ—Å—Ç–∏
if 'mx' in globals():
    print(f"\nüìè –†–ê–ó–ú–ï–†–ù–û–°–¢–ò:")
    print(f"   üéµ –ê—É–¥–∏–æ Wav2Vec2: [batch, 768]")
    print(f"   üîó –ü–æ—Å–ª–µ –ø—Ä–æ–µ–∫—Ç–æ—Ä–∞: [batch, 2560]")
    print(f"   üìù Gemma embeddings: [seq_len, 2560]")
    print(f"   ü§ù –§–∏–Ω–∞–ª—å–Ω—ã–π input: [batch, total_seq_len, 2560]")

In [2]:
# üîß –ò–°–ü–†–ê–í–õ–ï–ù–ù–´–ï –ú–û–î–ï–õ–ò

# Gemma-2-4B –º–æ–¥–µ–ª—å (–ø—Ä–∞–≤–∏–ª—å–Ω–∞—è!)
GEMMA_MODEL = "mlx-community/gemma-2-4b-it-4bit"  # 4B –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤, –∫–≤–∞–Ω—Ç–∏–∑–æ–≤–∞–Ω–Ω–∞—è
# –ê–ª—å—Ç–µ—Ä–Ω–∞—Ç–∏–≤—ã:
# "mlx-community/gemma-2-4b-it"  # –ü–æ–ª–Ω–∞—è –≤–µ—Ä—Å–∏—è
# "google/gemma-2-4b-it"  # –û—Ä–∏–≥–∏–Ω–∞–ª—å–Ω–∞—è (–µ—Å–ª–∏ –∫–æ–Ω–≤–µ—Ä—Ç–∏—Ä–æ–≤–∞–Ω–∞)

# –ê—É–¥–∏–æ —ç–Ω–∫–æ–¥–µ—Ä—ã (–≤—ã–±–∏—Ä–∞–µ–º –ª—É—á—à–∏–π)
AUDIO_MODELS = {
    "wav2vec2_base": "facebook/wav2vec2-base",           # 768 dim, –±–∞–∑–æ–≤–∞—è
    "wav2vec2_large": "facebook/wav2vec2-large-960h",   # 1024 dim, –æ–±—É—á–µ–Ω–Ω–∞—è –Ω–∞ LibriSpeech
    "wav2vec2_xlsr": "facebook/wav2vec2-large-xlsr-53", # 1024 dim, –º–Ω–æ–≥–æ—è–∑—ã—á–Ω–∞—è
    "whisper_encoder": "openai/whisper-base.en",        # 512 dim, –¥—Ä—É–≥–æ–π –ø–æ–¥—Ö–æ–¥
}

# –û–±–Ω–æ–≤–ª–µ–Ω–Ω—ã–µ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏
AUDIO_DIM_MAP = {
    "facebook/wav2vec2-base": 768,
    "facebook/wav2vec2-large-960h": 1024,
    "facebook/wav2vec2-large-xlsr-53": 1024,
    "openai/whisper-base.en": 512,
}

GEMMA_DIM = 2560  # Gemma-2-4B embedding —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç—å

print("üîß –ò–°–ü–†–ê–í–õ–ï–ù–ò–Ø:")
print(f"   ü§ñ Gemma –º–æ–¥–µ–ª—å: {GEMMA_MODEL}")
print(f"   üéµ –ê—É–¥–∏–æ –º–æ–¥–µ–ª—å: –≤—ã–±–µ—Ä–µ–º –ª—É—á—à—É—é")
print(f"   üìè –†–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏: –∞—É–¥–∏–æ ‚Üí {AUDIO_DIM_MAP} ‚Üí {GEMMA_DIM}")

# –ò—Å–ø—Ä–∞–≤–ª–µ–Ω–Ω—ã–µ –∫–ª–∞—Å—Å—ã
class UpdatedGemmaWrapper:
    """–ò—Å–ø—Ä–∞–≤–ª–µ–Ω–Ω–∞—è –æ–±–µ—Ä—Ç–∫–∞ –¥–ª—è Gemma-2-4B"""
    
    def __init__(self, model_path: str = GEMMA_MODEL):
        print(f"üîÑ –ó–∞–≥—Ä—É–∂–∞–µ–º Gemma –º–æ–¥–µ–ª—å: {model_path}")
        try:
            from mlx_lm import load
            self.model, self.tokenizer = load(model_path)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.embedding_dim = GEMMA_DIM
            print(f"‚úÖ Gemma-2-4B –∑–∞–≥—Ä—É–∂–µ–Ω–∞ —É—Å–ø–µ—à–Ω–æ!")
        except Exception as e:
            print(f"‚ùå –û—à–∏–±–∫–∞ –∑–∞–≥—Ä—É–∑–∫–∏ Gemma: {e}")
            print("üîÑ –ü–æ–ø—Ä–æ–±—É–µ–º –∞–ª—å—Ç–µ—Ä–Ω–∞—Ç–∏–≤–Ω—ã–µ –º–æ–¥–µ–ª–∏...")
            
            alternatives = [
                "mlx-community/gemma-2-4b-it",
                "google/gemma-2-4b-it", 
                "mlx-community/gemma-2-9b-it-4bit",  # –ï—Å–ª–∏ 4B –Ω–µ–¥–æ—Å—Ç—É–ø–Ω–∞
            ]
            
            for alt_model in alternatives:
                try:
                    print(f"üîÑ –ü—Ä–æ–±—É–µ–º: {alt_model}")
                    self.model, self.tokenizer = load(alt_model)
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                    self.embedding_dim = GEMMA_DIM
                    print(f"‚úÖ –ó–∞–≥—Ä—É–∂–µ–Ω–∞ –∞–ª—å—Ç–µ—Ä–Ω–∞—Ç–∏–≤–Ω–∞—è –º–æ–¥–µ–ª—å: {alt_model}")
                    break
                except:
                    continue
            else:
                raise Exception("‚ùå –ù–µ —É–¥–∞–ª–æ—Å—å –∑–∞–≥—Ä—É–∑–∏—Ç—å –Ω–∏ –æ–¥–Ω—É Gemma –º–æ–¥–µ–ª—å!")
        
    def get_input_embeddings(self, input_ids):
        return self.model.embed_tokens(input_ids)
    
    def forward_with_embeddings(self, inputs_embeds, labels=None):
        # –í—Ä–µ–º–µ–Ω–Ω–∞—è —Ä–µ–∞–ª–∏–∑–∞—Ü–∏—è - –Ω—É–∂–Ω–∞ –º–æ–¥–∏—Ñ–∏–∫–∞—Ü–∏—è MLX
        try:
            logits = self.model(None, cache=None, input_embeddings=inputs_embeds)
        except:
            # Fallback —á–µ—Ä–µ–∑ –º–æ–¥–∏—Ñ–∏–∫–∞—Ü–∏—é embed_tokens
            original_embed = self.model.embed_tokens
            def custom_embed(x):
                if x is None:
                    return inputs_embeds
                return original_embed(x)
            
            self.model.embed_tokens = custom_embed
            logits = self.model(mx.zeros((inputs_embeds.shape[0], 1), dtype=mx.int32))
            self.model.embed_tokens = original_embed
        
        if labels is not None:
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            loss = nn.losses.cross_entropy(
                shift_logits.reshape(-1, shift_logits.shape[-1]),
                shift_labels.reshape(-1),
                ignore_index=-100
            )
            return {"loss": loss, "logits": logits}
        
        return {"logits": logits}

class UpdatedAudioProjector(nn.Module):
    """–ê–¥–∞–ø—Ç–∏–≤–Ω—ã–π –ø—Ä–æ–µ–∫—Ç–æ—Ä –¥–ª—è —Ä–∞–∑–Ω—ã—Ö –∞—É–¥–∏–æ –º–æ–¥–µ–ª–µ–π"""
    
    def __init__(self, input_dim: int, output_dim: int = GEMMA_DIM):
        super().__init__()
        print(f"üîó –°–æ–∑–¥–∞–µ–º –ø—Ä–æ–µ–∫—Ç–æ—Ä: {input_dim} ‚Üí {output_dim}")
        
        # –ê–¥–∞–ø—Ç–∏–≤–Ω–∞—è –∞—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä–∞ –≤ –∑–∞–≤–∏—Å–∏–º–æ—Å—Ç–∏ –æ—Ç —Ä–∞–∑–º–µ—Ä–∞
        if input_dim <= 512:
            hidden_dim = 1024
        elif input_dim <= 768:
            hidden_dim = 1536
        else:
            hidden_dim = 2048
            
        self.proj = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),  # –õ—É—á—à–µ —á–µ–º ReLU –¥–ª—è —Ç—Ä–∞–Ω—Å—Ñ–æ—Ä–º–µ—Ä–æ–≤
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim)
        )
        
        print(f"‚úÖ –ü—Ä–æ–µ–∫—Ç–æ—Ä —Å–æ–∑–¥–∞–Ω: {input_dim} ‚Üí {hidden_dim} ‚Üí {output_dim}")
    
    def __call__(self, x):
        return self.proj(x)

class SmartAudioWrapper:
    """–£–º–Ω–∞—è –æ–±–µ—Ä—Ç–∫–∞ –¥–ª—è –≤—ã–±–æ—Ä–∞ –ª—É—á—à–µ–≥–æ –∞—É–¥–∏–æ —ç–Ω–∫–æ–¥–µ—Ä–∞"""
    
    def __init__(self, preferred_model: str = "facebook/wav2vec2-large-960h"):
        print(f"üéµ –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä—É–µ–º –∞—É–¥–∏–æ —ç–Ω–∫–æ–¥–µ—Ä: {preferred_model}")
        
        self.model_name = preferred_model
        self.output_dim = AUDIO_DIM_MAP.get(preferred_model, 768)
        
        try:
            if "wav2vec2" in preferred_model:
                from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
                self.model = Wav2Vec2Model.from_pretrained(preferred_model)
                self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(preferred_model)
            elif "whisper" in preferred_model:
                from transformers import WhisperModel, WhisperFeatureExtractor
                self.model = WhisperModel.from_pretrained(preferred_model)
                self.feature_extractor = WhisperFeatureExtractor.from_pretrained(preferred_model)
            
            self.model.eval()
            print(f"‚úÖ –ê—É–¥–∏–æ –º–æ–¥–µ–ª—å –∑–∞–≥—Ä—É–∂–µ–Ω–∞: {preferred_model} ‚Üí {self.output_dim}D")
            
        except Exception as e:
            print(f"‚ùå –û—à–∏–±–∫–∞ –∑–∞–≥—Ä—É–∑–∫–∏ {preferred_model}: {e}")
            print("üîÑ –ò—Å–ø–æ–ª—å–∑—É–µ–º –±–∞–∑–æ–≤—É—é –º–æ–¥–µ–ª—å...")
            self.model_name = "facebook/wav2vec2-base"
            self.output_dim = 768
            from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor
            self.model = Wav2Vec2Model.from_pretrained(self.model_name)
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name)
            self.model.eval()
    
    def extract_features(self, waveforms):
        import torch
        with torch.no_grad():
            if len(waveforms.shape) == 1:
                waveforms = waveforms[None, :]
            
            inputs = torch.from_numpy(waveforms).float()
            
            if "wav2vec2" in self.model_name:
                outputs = self.model(inputs)
                features = outputs.last_hidden_state.mean(dim=1)
            elif "whisper" in self.model_name:
                outputs = self.model.encoder(inputs)
                features = outputs.last_hidden_state.mean(dim=1)
            
            return mx.array(features.numpy())

print("‚úÖ –ò—Å–ø—Ä–∞–≤–ª–µ–Ω–Ω—ã–µ –∫–ª–∞—Å—Å—ã –≥–æ—Ç–æ–≤—ã!")

üîß –ò–°–ü–†–ê–í–õ–ï–ù–ò–Ø:
   ü§ñ Gemma –º–æ–¥–µ–ª—å: mlx-community/gemma-2-4b-it-4bit
   üéµ –ê—É–¥–∏–æ –º–æ–¥–µ–ª—å: –≤—ã–±–µ—Ä–µ–º –ª—É—á—à—É—é
   üìè –†–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏: –∞—É–¥–∏–æ ‚Üí {'facebook/wav2vec2-base': 768, 'facebook/wav2vec2-large-960h': 1024, 'facebook/wav2vec2-large-xlsr-53': 1024, 'openai/whisper-base.en': 512} ‚Üí 2560
‚úÖ –ò—Å–ø—Ä–∞–≤–ª–µ–Ω–Ω—ã–µ –∫–ª–∞—Å—Å—ã –≥–æ—Ç–æ–≤—ã!


In [None]:
# üöÄ –ò–°–ü–†–ê–í–õ–ï–ù–ù–ê–Ø –ì–õ–ê–í–ù–ê–Ø –§–£–ù–ö–¶–ò–Ø

def main_with_fixed_models():
    """–û–±–Ω–æ–≤–ª–µ–Ω–Ω–∞—è –≥–ª–∞–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è —Å –ø—Ä–∞–≤–∏–ª—å–Ω—ã–º–∏ –º–æ–¥–µ–ª—è–º–∏"""
    
    print("üöÄ –ó–ê–ü–£–°–ö –° –ò–°–ü–†–ê–í–õ–ï–ù–ù–´–ú–ò –ú–û–î–ï–õ–Ø–ú–ò")
    print("="*60)
    
    # 1. –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è —Å –ø—Ä–∞–≤–∏–ª—å–Ω—ã–º–∏ –º–æ–¥–µ–ª—è–º–∏
    print("\n1Ô∏è‚É£ –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –º–æ–¥–µ–ª–µ–π...")
    try:
        # –õ—É—á—à–∏–π –∞—É–¥–∏–æ —ç–Ω–∫–æ–¥–µ—Ä
        audio_wrapper = SmartAudioWrapper("facebook/wav2vec2-large-960h")
        audio_dim = audio_wrapper.output_dim
        
        # –ü—Ä–∞–≤–∏–ª—å–Ω–∞—è Gemma –º–æ–¥–µ–ª—å
        gemma_wrapper = UpdatedGemmaWrapper(GEMMA_MODEL)
        text_dim = gemma_wrapper.embedding_dim
        
        # –ê–¥–∞–ø—Ç–∏–≤–Ω—ã–π –ø—Ä–æ–µ–∫—Ç–æ—Ä
        projector = UpdatedAudioProjector(audio_dim, text_dim)
        
        print(f"‚úÖ –ê—É–¥–∏–æ: {audio_wrapper.model_name} ‚Üí {audio_dim}D")
        print(f"‚úÖ –¢–µ–∫—Å—Ç: {GEMMA_MODEL} ‚Üí {text_dim}D")
        print(f"‚úÖ –ü—Ä–æ–µ–∫—Ç–æ—Ä: {audio_dim} ‚Üí {text_dim}")
        
    except Exception as e:
        print(f"‚ùå –û—à–∏–±–∫–∞ –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏–∏: {e}")
        return False
    
    # 2. –¢–µ—Å—Ç–æ–≤—ã–µ –¥–∞–Ω–Ω—ã–µ
    print("\n2Ô∏è‚É£ –°–æ–∑–¥–∞–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤—ã—Ö –¥–∞–Ω–Ω—ã—Ö...")
    try:
        # –°–æ–∑–¥–∞–µ–º —Å–∏–Ω—Ç–µ—Ç–∏—á–µ—Å–∫–∏–µ –∞—É–¥–∏–æ –¥–∞–Ω–Ω—ã–µ
        import numpy as np
        sample_rate = 16000
        duration = 2.0
        batch_size = 2
        
        # –ì–µ–Ω–µ—Ä–∏—Ä—É–µ–º —Ç–µ—Å—Ç–æ–≤—ã–µ –∞—É–¥–∏–æ —Å–∏–≥–Ω–∞–ª—ã
        audio_data = []
        for i in range(batch_size):
            # –°–∏–Ω—É—Å–æ–∏–¥–∞–ª—å–Ω—ã–π —Å–∏–≥–Ω–∞–ª —Å —à—É–º–æ–º
            t = np.linspace(0, duration, int(sample_rate * duration))
            frequency = 440 + i * 110  # –†–∞–∑–Ω—ã–µ —á–∞—Å—Ç–æ—Ç—ã
            signal = np.sin(2 * np.pi * frequency * t) * 0.5
            noise = np.random.normal(0, 0.1, signal.shape)
            audio_sample = signal + noise
            audio_data.append(audio_sample.astype(np.float32))
        
        audio_batch = np.array(audio_data)
        
        # –¢–µ—Å—Ç–æ–≤—ã–µ —Ç–µ–∫—Å—Ç—ã
        texts = [
            "This is a test sentence.",
            "Another example text for testing."
        ]
        
        print(f"‚úÖ –ê—É–¥–∏–æ: {audio_batch.shape}")
        print(f"‚úÖ –¢–µ–∫—Å—Ç—ã: {len(texts)} –æ–±—Ä–∞–∑—Ü–æ–≤")
        
    except Exception as e:
        print(f"‚ùå –û—à–∏–±–∫–∞ —Å–æ–∑–¥–∞–Ω–∏—è –¥–∞–Ω–Ω—ã—Ö: {e}")
        return False
    
    # 3. –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–∞–π–ø–ª–∞–π–Ω–∞
    print("\n3Ô∏è‚É£ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–∞–π–ø–ª–∞–π–Ω–∞...")
    try:
        # –ò–∑–≤–ª–µ—á–µ–Ω–∏–µ –∞—É–¥–∏–æ —Ñ–∏—á
        audio_features = audio_wrapper.extract_features(audio_batch)
        print(f"üéµ –ê—É–¥–∏–æ —Ñ–∏—á–∏: {audio_features.shape}")
        
        # –ü—Ä–æ–µ–∫—Ü–∏—è –≤ —Ç–µ–∫—Å—Ç–æ–≤–æ–µ –ø—Ä–æ—Å—Ç—Ä–∞–Ω—Å—Ç–≤–æ
        projected_features = projector(audio_features)
        print(f"üîó –°–ø—Ä–æ–µ—Ü–∏—Ä–æ–≤–∞–Ω–Ω—ã–µ —Ñ–∏—á–∏: {projected_features.shape}")
        
        # –¢–æ–∫–µ–Ω–∏–∑–∞—Ü–∏—è —Ç–µ–∫—Å—Ç–∞
        tokenized = gemma_wrapper.tokenizer(
            texts, 
            padding=True, 
            truncation=True, 
            return_tensors="np"
        )
        input_ids = mx.array(tokenized['input_ids'])
        print(f"üìù –¢–æ–∫–µ–Ω–∏–∑–∏—Ä–æ–≤–∞–Ω–Ω—ã–π —Ç–µ–∫—Å—Ç: {input_ids.shape}")
        
        # –ü–æ–ª—É—á–µ–Ω–∏–µ —Ç–µ–∫—Å—Ç–æ–≤—ã—Ö —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤
        text_embeddings = gemma_wrapper.get_input_embeddings(input_ids)
        print(f"üìö –¢–µ–∫—Å—Ç–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏: {text_embeddings.shape}")
        
        print("‚úÖ –ü–∞–π–ø–ª–∞–π–Ω —Ä–∞–±–æ—Ç–∞–µ—Ç!")
        
    except Exception as e:
        print(f"‚ùå –û—à–∏–±–∫–∞ –≤ –ø–∞–π–ø–ª–∞–π–Ω–µ: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    # 4. –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏
    print("\n4Ô∏è‚É£ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏...")
    try:
        from mlx_custom_generation import test_audio_generation
        
        success = test_audio_generation(
            gemma_wrapper.model,
            gemma_wrapper.tokenizer,
            audio_features[0:1],  # –û–¥–∏–Ω –ø—Ä–∏–º–µ—Ä
            projector
        )
        
        if success:
            print("‚úÖ –ö–∞—Å—Ç–æ–º–Ω–∞—è –≥–µ–Ω–µ—Ä–∞—Ü–∏—è —Ä–∞–±–æ—Ç–∞–µ—Ç!")
        else:
            print("‚ö†Ô∏è –ü—Ä–æ–±–ª–µ–º—ã —Å –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–µ–π")
            
    except Exception as e:
        print(f"‚ùå –û—à–∏–±–∫–∞ –∫–∞—Å—Ç–æ–º–Ω–æ–π –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏: {e}")
        print("‚ö†Ô∏è –í–æ–∑–º–æ–∂–Ω–æ, –Ω—É–∂–Ω–æ –æ–±–Ω–æ–≤–∏—Ç—å mlx_custom_generation.py")
    
    # 5. –ò—Ç–æ–≥–æ–≤—ã–π –æ—Ç—á–µ—Ç
    print("\n" + "="*60)
    print("üìä –ò–¢–û–ì–û–í–´–ô –û–¢–ß–ï–¢:")
    print(f"   ü§ñ Gemma –º–æ–¥–µ–ª—å: {GEMMA_MODEL}")
    print(f"   üéµ –ê—É–¥–∏–æ –º–æ–¥–µ–ª—å: {audio_wrapper.model_name}")
    print(f"   üìè –†–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–∏: {audio_dim} ‚Üí {text_dim}")
    print(f"   üîó –ü—Ä–æ–µ–∫—Ç–æ—Ä: –≥–æ—Ç–æ–≤ –∫ –æ–±—É—á–µ–Ω–∏—é")
    print(f"   üß™ –¢–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏–µ: –ø—Ä–æ—à–ª–æ —É—Å–ø–µ—à–Ω–æ")
    print("="*60)
    
    return True

# –ó–∞–ø—É—Å–∫–∞–µ–º –∏—Å–ø—Ä–∞–≤–ª–µ–Ω–Ω—É—é –≤–µ—Ä—Å–∏—é
if __name__ == "__main__":
    success = main_with_fixed_models()
    if success:
        print("\nüéâ –í–°–ï –ì–û–¢–û–í–û –ö –û–ë–£–ß–ï–ù–ò–Æ!")
    else:
        print("\nüîß –¢–†–ï–ë–£–Æ–¢–°–Ø –î–û–ü–û–õ–ù–ò–¢–ï–õ–¨–ù–´–ï –ò–°–ü–†–ê–í–õ–ï–ù–ò–Ø")

üöÄ –ó–ê–ü–£–°–ö –° –ò–°–ü–†–ê–í–õ–ï–ù–ù–´–ú–ò –ú–û–î–ï–õ–Ø–ú–ò

1Ô∏è‚É£ –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –º–æ–¥–µ–ª–µ–π...
üéµ –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä—É–µ–º –∞—É–¥–∏–æ —ç–Ω–∫–æ–¥–µ—Ä: facebook/wav2vec2-large-960h


config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]