In [None]:
import sys
from tqdm import tqdm
import subprocess

def install_deps(package_file):
    """Upgrade every requirement listed in ``package_file`` with pip, one package per pip call.

    Blank lines and pip-style comments are ignored: a ``#`` at the start of a line (after
    optional whitespace) or a ``#`` preceded by whitespace starts a comment. A ``#`` glued
    to other text (e.g. ``...#egg=name`` in a VCS URL) is kept. A missing file is reported
    and skipped. Packages that fail to install are collected and reported at the end
    instead of being silently ignored; install failures never raise.

    Args:
        package_file: path to a requirements-style text file.
    """
    import re

    try:
        with open(package_file, 'r') as f:
            raw_lines = f.readlines()
    except FileNotFoundError:
        print(f"Requirements file not found: {package_file} - skipping dependency update")
        return

    packages = []
    for raw in raw_lines:
        # Strip pip-style comments before deciding whether the line holds a requirement.
        spec = re.sub(r'(^|\s)#.*$', '', raw).strip()
        if spec:
            packages.append(spec)

    if not packages:
        print("No packages to install")
        return

    failed = []
    for package in tqdm(packages, desc="üì• –û–±–Ω–æ–≤–ª–µ–Ω–∏–µ –ø–∞–∫–µ—Ç–æ–≤"):
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", package, "-q"], check=True, capture_output=True)
        except subprocess.CalledProcessError as exc:
            # Keep the last stderr line (usually pip's actual error) for the summary.
            stderr = (exc.stderr or b"").decode(errors="replace").strip()
            reason = stderr.splitlines()[-1] if stderr else f"exit code {exc.returncode}"
            failed.append((package, reason))

    if failed:
        print(f"Dependencies updated with {len(failed)} failure(s):")
        for package, reason in failed:
            print(f"  - {package}: {reason}")
    else:
        print("‚úÖ Dependencies updated")

install_deps("requirements.txt")

In [None]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import json
import os
import gc
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.amp import autocast, GradScaler
from transformers import AutoConfig, AutoTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Model
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import jiwer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from huggingface_hub import notebook_login, login
from sklearn.model_selection import train_test_split
import pandas as pd
from datetime import datetime
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
from IPython.display import Audio, display
import zipfile
import io
import wandb
import glob
import random
import bitsandbytes as bnb

def set_seed(seed=42):
    """Seed every RNG the notebook draws from so that runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU, plus all CUDA
    devices when CUDA is available). It also puts cuDNN into deterministic,
    non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
def collate_fn(batch):
    """Right-pad a list of dataset items into a single batch dictionary.

    Each item must provide 1-D tensors under "input_values", "input_ids" and
    "attention_mask". Padding values are:
      - input_values   -> 0.0
      - input_ids      -> -100 (the label ignore index, so padding is excluded from the loss)
      - attention_mask -> 0
    """
    def padded(key, fill_value):
        # Gather one field across the batch and pad it to the longest entry.
        column = [sample[key] for sample in batch]
        return pad_sequence(column, batch_first=True, padding_value=fill_value)

    return {
        'input_values': padded('input_values', 0.0),
        'input_ids': padded('input_ids', -100),
        'attention_mask': padded('attention_mask', 0),
    }

def process_batch(batch, model, projector, wav2vec2, tokenizer, prefix_embeds, device, compression_rate_k):
    """Run one audio->text forward pass; return the LM outputs and the prompt embeddings.

    Pipeline:
      1. wav2vec2 encodes the padded waveforms.
      2. Frames are stacked in groups of ``compression_rate_k`` (compress_audio_features).
      3. The projector maps them into the LM embedding space.
      4. ``prefix_embeds`` is prepended to form the prompt.
      5. Target-token embeddings are appended after the prompt.
    Prompt positions get label -100 so that only target tokens contribute to the loss.

    Args:
        batch: dict as produced by ``collate_fn``; "input_ids" uses -100 as padding.
        model: causal LM exposing ``get_input_embeddings()`` and accepting
            ``inputs_embeds``/``labels``.
        projector: module mapping compressed audio features to LM embeddings.
        wav2vec2: audio encoder whose output exposes ``last_hidden_state``.
        tokenizer: its ``pad_token_id`` (or ``eos_token_id`` if there is no pad token)
            replaces -100 for the embedding lookup.
        prefix_embeds: prompt-prefix embeddings with a leading dim that ``expand``
            broadcasts to the batch; presumably (1, prefix_len, hidden). TODO confirm.
        device: device the batch tensors are moved to.
        compression_rate_k: number of consecutive encoder frames merged per audio token.

    Returns:
        (outputs, prompt_embeds): the model outputs (``outputs.loss`` is the LM loss)
        and the concatenated [prefix; audio] embeddings, which callers reuse for generation.
    """
    input_values = batch["input_values"].to(device, dtype=torch.bfloat16)
    input_ids = batch["input_ids"].to(device)

    # Pick the autocast backend from the device the tensors live on. autocast's CUDA default
    # dtype is float16, but everything fed to the LM here is bfloat16, so use bfloat16 whenever
    # the GPU supports it.
    amp_device = 'cuda' if torch.device(device).type == 'cuda' else 'cpu'
    amp_dtype = torch.bfloat16
    if amp_device == 'cuda' and not torch.cuda.is_bf16_supported():
        amp_dtype = torch.float16

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # Any valid id works: these positions only feed the embedding lookup and keep label -100.
        pad_token_id = tokenizer.eos_token_id

    with autocast(amp_device, dtype=amp_dtype):
        audio_embeds = wav2vec2(input_values).last_hidden_state
        compressed_audio = compress_audio_features(audio_embeds, compression_rate_k)
        del audio_embeds  # release encoder activations early

        projected_audio = projector(compressed_audio).to(dtype=torch.bfloat16)
        del compressed_audio

        batch_size = projected_audio.size(0)
        prompt_embeds = torch.cat([prefix_embeds.expand(batch_size, -1, -1), projected_audio], dim=1)
        del projected_audio

        # -100 marks padding in input_ids (see collate_fn); swap it for a real id before the lookup.
        embedding_input_ids = input_ids.masked_fill(input_ids == -100, pad_token_id)
        target_embeds = model.get_input_embeddings()(embedding_input_ids)
        del embedding_input_ids

        inputs_embeds = torch.cat([prompt_embeds, target_embeds], dim=1)
        del target_embeds

        # Prompt positions are ignored by the loss; padded targets already carry -100.
        prompt_labels = torch.full((batch_size, prompt_embeds.shape[1]), -100, device=device, dtype=torch.long)
        labels = torch.cat([prompt_labels, input_ids], dim=1)
        del prompt_labels

        outputs = model(inputs_embeds=inputs_embeds, labels=labels)
        del inputs_embeds, labels

    return outputs, prompt_embeds

In [None]:
def force_gpu_cleanup():
    """Delete the notebook's large training globals and release cached GPU memory.

    Each name in ``variables_to_delete`` that exists in this module's globals is moved to
    the CPU on a best-effort basis, then deleted. Deletion always happens, even if the
    move fails. Afterwards garbage collection runs repeatedly and, when CUDA is
    available, the CUDA caching allocator is emptied and its statistics are reset.

    Returns:
        (allocated_gb, reserved_gb) measured after cleanup when CUDA is available,
        otherwise (0, 0).
    """
    # 1. Delete all known training globals.
    variables_to_delete = [
        'gemma_model', 'wav2vec2', 'projector', 'train_loader', 'val_loader',
        'optimizer', 'scheduler', 'scaler', 'train_dataset', 'val_dataset',
        'prefix_embeds', 'tokenizer', 'feature_extractor', 'logger',
        'quantization_config', 'lora_config'
    ]

    deleted_count = 0
    for var_name in variables_to_delete:
        if var_name not in globals():
            continue
        obj = globals()[var_name]
        # Moving modules to CPU frees their GPU storage even if other references survive.
        # This is best-effort only: some objects raise from .to() (reportedly e.g.
        # bitsandbytes-quantized models), and a failure must not block the deletion below.
        try:
            if hasattr(obj, 'to'):
                obj.to('cpu')
            elif hasattr(obj, 'cpu'):
                obj.cpu()
        except Exception as exc:
            print(f"   could not move '{var_name}' to CPU before deletion: {exc}")
        del obj
        del globals()[var_name]
        deleted_count += 1

    cuda_available = torch.cuda.is_available()

    # 2. Drop cuBLAS workspaces. This is a private API, so only call it when CUDA is usable and it exists.
    if cuda_available:
        clear_cublas_workspaces = getattr(torch._C, '_cuda_clearCublasWorkspaces', None)
        if clear_cublas_workspaces is not None:
            clear_cublas_workspaces()

    # 3. Forced garbage collection, repeated to break reference cycles.
    for _ in range(5):
        gc.collect()
        if cuda_available:
            torch.cuda.empty_cache()

    if not cuda_available:
        print(f"üßπ CPU –æ—á–∏—Å—Ç–∫–∞ –∑–∞–≤–µ—Ä—à–µ–Ω–∞ (—É–¥–∞–ª–µ–Ω–æ –ø–µ—Ä–µ–º–µ–Ω–Ω—ã—Ö: {deleted_count})")
        return 0, 0

    # 4. GPU cleanup: wait for outstanding kernels, empty caches, reset statistics.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    gc.collect()

    # 5. Report what is still allocated / reserved, in GiB.
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3

    print(f"üßπ –ü—Ä–∏–Ω—É–¥–∏—Ç–µ–ª—å–Ω–∞—è –æ—á–∏—Å—Ç–∫–∞ GPU –ø–∞–º—è—Ç–∏:")
    print(f"   üìä –£–¥–∞–ª–µ–Ω–æ –ø–µ—Ä–µ–º–µ–Ω–Ω—ã—Ö: {deleted_count}")
    print(f"   üíæ –í—ã–¥–µ–ª–µ–Ω–æ —Å–µ–π—á–∞—Å: {allocated:.2f} GB")
    print(f"   üîí –ó–∞—Ä–µ–∑–µ—Ä–≤–∏—Ä–æ–≤–∞–Ω–æ: {reserved:.2f} GB")

    return allocated, reserved

# Run the forced cleanup.
force_gpu_cleanup()

In [None]:
def _estimate_steps_per_epoch(batch_size):
    """Estimate batches per epoch from the global ``train_data`` (2000 when it is not defined)."""
    if 'train_data' in globals():
        # At least 1, so a dataset smaller than one batch cannot drive the clamp below to -1.
        return max(1, len(train_data) // batch_size)
    return 2000


def _clamp_batch_idx(batch_idx, steps_per_epoch):
    """Clamp ``batch_idx`` into the valid in-epoch range ``[0, steps_per_epoch - 1]``."""
    return min(max(batch_idx, 0), steps_per_epoch - 1)


def load_checkpoint(path, projector, gemma_model, optimizer, scheduler, device, batch_size):
    """Restore training state from a checkpoint and compute where training should resume.

    What is restored:
      - Projector weights. If they are incompatible, the projector keeps its current weights.
      - LoRA weights, loaded non-strictly.
      - LR-scheduler state. Errors here are printed, not raised.
    The optimizer state is deliberately NOT restored, to save memory; ``optimizer`` is
    accepted only for signature compatibility. Sets the module-level ``best_val_loss``.

    Args:
        path: checkpoint file passed to ``torch.load``.
        projector: module receiving ``checkpoint['projector_state_dict']``.
        gemma_model: module receiving ``checkpoint['lora_state_dict']`` (if present).
        optimizer: unused (see above).
        scheduler: receives ``checkpoint['scheduler_state_dict']``.
        device: ``map_location`` for ``torch.load``.
        batch_size: batch size of the resumed run. If it differs from the checkpoint's
            ``config['batch_size']``, the step and batch counters are rescaled so that
            the same number of samples counts as already seen.

    Returns:
        (start_epoch, global_step, batch_idx). ``global_step`` and ``batch_idx`` are
        expressed in batches of the new ``batch_size``.

    Raises:
        KeyError: if the checkpoint lacks 'projector_state_dict', 'epoch' or 'step'.
    """
    global best_val_loss
    # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device)

    try:
        projector.load_state_dict(checkpoint['projector_state_dict'])
    except RuntimeError as e:
        # Incompatible shapes (e.g. compression_rate_k changed): keep the projector's current weights.
        print(f"‚ö†Ô∏è –ù–µ—Å–æ–≤–º–µ—Å—Ç–∏–º–æ—Å—Ç—å –ø—Ä–æ–µ–∫—Ç–æ—Ä–∞: {e}")
        print("üîÑ –ò–Ω–∏—Ü–∏–∞–ª–∏–∑–∞—Ü–∏—è –ø—Ä–æ–µ–∫—Ç–æ—Ä–∞ –Ω–æ–≤—ã–º–∏ –≤–µ—Å–∞–º–∏, —Ç.–∫. –∫–æ–Ω—Ñ–∏–≥—É—Ä–∞—Ü–∏—è (–Ω–∞–ø—Ä. compression_rate_k) –∏–∑–º–µ–Ω–∏–ª–∞—Å—å.")
        wandb.log({"checkpoint/projector_reinitialized": True})

    if 'lora_state_dict' in checkpoint:
        # strict=False: presumably only the LoRA adapter weights are stored, not the full model.
        gemma_model.load_state_dict(checkpoint['lora_state_dict'], strict=False)

    # The optimizer state is intentionally skipped (memory); the optimizer restarts fresh.
    print("üîÑ –ü—Ä–æ–ø—É—Å–∫–∞–µ–º –∑–∞–≥—Ä—É–∑–∫—É —Å–æ—Å—Ç–æ—è–Ω–∏—è –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä–∞ –¥–ª—è —ç–∫–æ–Ω–æ–º–∏–∏ –ø–∞–º—è—Ç–∏. –û–Ω –±—É–¥–µ—Ç –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω –∑–∞–Ω–æ–≤–æ.")
    wandb.log({"checkpoint/optimizer_reset_manual": True})

    try:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    except Exception as e:
        print(f"‚ö†Ô∏è Scheduler error: {e}")

    start_epoch = checkpoint['epoch']
    saved_step = checkpoint['step']
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))

    # Backward compatibility: older checkpoints carry no 'config'.
    config = checkpoint.get('config', {})
    prev_batch_size = config.get('batch_size', batch_size)

    if 'batch_idx' in checkpoint:
        # New-format checkpoint: the in-epoch position was saved explicitly (in prev_batch_size units).
        batch_idx = checkpoint['batch_idx']
        checkpoint_version = "new"
    else:
        # Legacy checkpoint: reconstruct the in-epoch position.
        checkpoint_version = "legacy"

        import re
        filename = os.path.basename(path)
        # Method 1: parse the file name, e.g. latest_checkpoint_bs4_epoch_4_step_5500.pt
        match = re.search(r'epoch_(\d+)_step_(\d+)', filename)
        steps_per_epoch_estimate = _estimate_steps_per_epoch(batch_size)

        if match:
            file_epoch = int(match.group(1))
            file_step = int(match.group(2))

            # Epochs in file names count from 1, so file_epoch - 1 epochs are complete.
            completed_epochs = file_epoch - 1
            steps_in_completed_epochs = completed_epochs * steps_per_epoch_estimate
            batch_idx = _clamp_batch_idx(file_step - steps_in_completed_epochs, steps_per_epoch_estimate)

            print(f"üì¶ Legacy —á–µ–∫–ø–æ–∏–Ω—Ç: –∏–∑–≤–ª–µ—á–µ–Ω–æ –∏–∑ –∏–º–µ–Ω–∏ —Ñ–∞–π–ª–∞ epoch={file_epoch}, step={file_step}")
            print(f"üìä –†–∞—Å—á–µ—Ç: {completed_epochs} –ø–æ–ª–Ω—ã—Ö —ç–ø–æ—Ö √ó {steps_per_epoch_estimate} = {steps_in_completed_epochs} —à–∞–≥–æ–≤")
            print(f"üìä –ü–æ–∑–∏—Ü–∏—è –≤ —ç–ø–æ—Ö–µ {file_epoch}: batch_idx = {file_step} - {steps_in_completed_epochs} = {batch_idx}")
        else:
            # Method 2: treat start_epoch as the number of completed epochs and use saved_step.
            completed_epochs = start_epoch
            steps_in_completed_epochs = completed_epochs * steps_per_epoch_estimate
            batch_idx = _clamp_batch_idx(saved_step - steps_in_completed_epochs, steps_per_epoch_estimate)

            print(f"üì¶ Legacy —á–µ–∫–ø–æ–∏–Ω—Ç: –Ω–µ —É–¥–∞–ª–æ—Å—å –∏–∑–≤–ª–µ—á—å –∏–∑ –∏–º–µ–Ω–∏, –∏—Å–ø–æ–ª—å–∑—É–µ–º saved_step={saved_step}")
            print(f"üìä –†–∞—Å—á–µ—Ç: {completed_epochs} –ø–æ–ª–Ω—ã—Ö —ç–ø–æ—Ö √ó {steps_per_epoch_estimate} = {steps_in_completed_epochs} —à–∞–≥–æ–≤")
            print(f"üìä –ü–æ–∑–∏—Ü–∏—è –≤ —ç–ø–æ—Ö–µ {start_epoch + 1}: batch_idx = {saved_step} - {steps_in_completed_epochs} = {batch_idx}")

        wandb.log({
            "checkpoint/legacy_batch_idx_calculated": True,
            "checkpoint/calculated_batch_idx": batch_idx,
            "checkpoint/filename": filename
        })

    if prev_batch_size != batch_size:
        total_samples_seen = saved_step * prev_batch_size
        adjusted_step = total_samples_seen // batch_size

        if checkpoint_version == "legacy":
            # Recompute the in-epoch position from the rescaled global step.
            steps_per_epoch_estimate = _estimate_steps_per_epoch(batch_size)
            completed_epochs = start_epoch
            steps_in_completed_epochs = completed_epochs * steps_per_epoch_estimate
            batch_idx = _clamp_batch_idx(adjusted_step - steps_in_completed_epochs, steps_per_epoch_estimate)
        else:
            # The saved batch_idx counts batches of prev_batch_size. Convert it to batches of
            # the new size so that resuming skips the same number of samples.
            batch_idx = (batch_idx * prev_batch_size) // batch_size

        wandb.log({
            "checkpoint/batch_size_mismatch": True,
            "checkpoint/prev_batch_size": prev_batch_size,
            "checkpoint/new_batch_size": batch_size,
            "checkpoint/samples_seen": total_samples_seen,
            "checkpoint/adjusted_step": adjusted_step,
            "checkpoint/adjusted_batch_idx": batch_idx
        })

        global_step = adjusted_step
    else:
        global_step = saved_step

    wandb.log({
        "checkpoint/loaded": True,
        "checkpoint/version": checkpoint_version,
        "checkpoint/start_epoch": start_epoch,
        "checkpoint/global_step": global_step,
        "checkpoint/best_val_loss": best_val_loss,
        "checkpoint/batch_idx": batch_idx
    })

    if checkpoint_version == "legacy":
        print(f"üì¶ Legacy —á–µ–∫–ø–æ–∏–Ω—Ç: –∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–∏ –≤—ã—á–∏—Å–ª–µ–Ω batch_idx={batch_idx} –¥–ª—è —ç–ø–æ—Ö–∏ {start_epoch}")

    return start_epoch, global_step, batch_idx

In [None]:
# Lowest validation loss observed so far; starts at +infinity so any real loss improves on it.
best_val_loss = float("inf")

In [None]:
class AudioProjector(nn.Module):
    """Map (compressed) audio-encoder features into the language model's embedding space.

    Layer stack (held in ``self.proj``, so state-dict keys are ``proj.<index>.*``):
        LayerNorm(input_dim) -> Linear(input_dim, hidden_dim) -> GELU
        -> Linear(hidden_dim, output_dim) -> LayerNorm(output_dim)

    Both Linear layers use Xavier-uniform weights and zero biases. ``forward`` upcasts its
    input to float32 and returns bfloat16.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=2048):
        super().__init__()

        stages = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim),
        ]
        self.proj = nn.Sequential(*stages)

        # Re-initialise only the Linear layers, in stack order.
        linear_stages = (stage for stage in self.proj if isinstance(stage, nn.Linear))
        for linear in linear_stages:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Project ``x`` (last dimension == input_dim) and return the result as bfloat16."""
        projected = self.proj(x.to(torch.float32))
        return projected.to(torch.bfloat16)

    def get_l2_norm(self):
        """Return the global L2 norm over all parameters as a Python float."""
        squared_norms = (param.data.norm(2).item() ** 2 for param in self.parameters())
        return sum(squared_norms, 0.0) ** 0.5

In [None]:
class DatasetBlender:
    """Gradually replace a primary dataset with a secondary one over a window of epochs.

    For each epoch, a blend ratio r in [0, 1] is computed: the share of secondary samples.
      - Before ``transition_start_epoch``: r = 0.
      - From ``transition_end_epoch`` on: r = 1.
      - In between, with progress p in [0, 1):
          * "linear":      r = p
          * "cosine":      r = (1 - cos(pi * p)) / 2   (smooth S-curve)
          * "exponential": r = p ** 2                  (actually quadratic: slow start, fast finish)
    An unknown schedule raises ValueError, but only once an epoch inside the transition
    window is queried.
    """

    def __init__(self, primary_data, secondary_data, transition_start_epoch, transition_end_epoch, blend_schedule="linear"):
        self.primary_data = primary_data
        self.secondary_data = secondary_data
        self.transition_start_epoch = transition_start_epoch
        self.transition_end_epoch = transition_end_epoch
        self.blend_schedule = blend_schedule

        print(f"üîÑ DatasetBlender –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω:")
        print(f"   üìä –û—Å–Ω–æ–≤–Ω–æ–π –¥–∞—Ç–∞—Å–µ—Ç: {len(primary_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
        print(f"   üìä –í—Ç–æ—Ä–æ–π –¥–∞—Ç–∞—Å–µ—Ç: {len(secondary_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
        print(f"   üïê –ü–µ—Ä–µ—Ö–æ–¥: —ç–ø–æ—Ö–∏ {transition_start_epoch}-{transition_end_epoch}")
        print(f"   üìà –°—Ç—Ä–∞—Ç–µ–≥–∏—è: {blend_schedule}")

    def get_blend_ratio(self, current_epoch):
        """Return the share of the secondary dataset (0.0 - 1.0) for ``current_epoch``."""
        if current_epoch < self.transition_start_epoch:
            return 0.0
        elif current_epoch >= self.transition_end_epoch:
            return 1.0

        # Normalised position inside the transition window (0.0 - 1.0).
        progress = (current_epoch - self.transition_start_epoch) / (self.transition_end_epoch - self.transition_start_epoch)

        if self.blend_schedule == "linear":
            return progress
        elif self.blend_schedule == "cosine":
            return (1 - np.cos(progress * np.pi)) / 2
        elif self.blend_schedule == "exponential":
            return progress ** 2
        else:
            raise ValueError(f"–ù–µ–∏–∑–≤–µ—Å—Ç–Ω–∞—è —Å—Ç—Ä–∞—Ç–µ–≥–∏—è —Å–º–µ—à–∏–≤–∞–Ω–∏—è: {self.blend_schedule}")

    def create_blended_dataset(self, current_epoch, random_seed=42):
        """Build the shuffled training list for ``current_epoch``.

        The target size is ``len(primary_data)``, split according to ``get_blend_ratio``.
        Each part is drawn without replacement and capped at its source's size. A
        secondary dataset smaller than its requested share therefore yields a shorter
        blend, and the printed counts report what was actually drawn. With a fixed
        ``random_seed`` (default 42), the same draw repeats for the same ratio; pass a
        per-epoch seed to vary it.

        Returns:
            (blended_data, blend_ratio)
        """
        blend_ratio = self.get_blend_ratio(current_epoch)

        # Requested counts: keep the overall size equal to the primary dataset.
        total_size = len(self.primary_data)
        secondary_count = int(total_size * blend_ratio)
        primary_count = total_size - secondary_count

        # Deterministic sampling.
        random_state = random.Random(random_seed)

        # random.sample raises when asked for more items than exist, hence the min() caps.
        selected_primary = random_state.sample(self.primary_data, min(primary_count, len(self.primary_data))) if primary_count > 0 else []
        selected_secondary = random_state.sample(self.secondary_data, min(secondary_count, len(self.secondary_data))) if secondary_count > 0 else []

        blended_data = selected_primary + selected_secondary
        random_state.shuffle(blended_data)

        # Report the counts actually drawn, not the requested ones.
        primary_count = len(selected_primary)
        secondary_count = len(selected_secondary)

        print(f"üîÑ –≠–ø–æ—Ö–∞ {current_epoch+1}: –°–º–µ—à–∏–≤–∞–Ω–∏–µ {primary_count} –æ—Å–Ω–æ–≤–Ω—ã—Ö + {secondary_count} –≤—Ç–æ—Ä–∏—á–Ω—ã—Ö ({blend_ratio*100:.1f}% –≤—Ç–æ—Ä–æ–≥–æ)")

        return blended_data, blend_ratio


In [None]:
class TrainingLogger:
    """Send training/validation metrics to Weights & Biases and keep the validation history.

    The validation history lives in ``self.logs``: parallel lists keyed by 'step' and
    'val_<metric>'. ``save_logs`` exports it as ``<save_dir>/validation_logs.csv``.

    Args:
        experiment_name: stored for reference.
        save_dir: output directory for ``save_logs``; created on demand.
    """

    # Metric keys expected in the ``val_metrics`` dict passed to ``log_validation``.
    _VALIDATION_METRICS = ('loss', 'perplexity', 'wer', 'bleu', 'rouge_l')

    def __init__(self, experiment_name, save_dir):
        self.experiment_name = experiment_name
        self.save_dir = save_dir
        self.logs = {
            'step': [],
            'val_loss': [],
            'val_perplexity': [],
            'val_wer': [],
            'val_bleu': [],
            'val_rouge_l': []
        }

    def log_step(self, step, train_loss, lr_list, grad_norm=None, projector_l2_norm=None, gpu_memory_gb=None, gpu_memory_reserved_gb=None, gpu_memory_total_gb=None, memory_breakdown_mb=None):
        """Log one training step to wandb.

        Args:
            step: global step.
            train_loss: scalar loss.
            lr_list: learning rates per param group ([0] = projector, [1] = LoRA); may be
                empty, in which case missing rates are logged as 0.0.
            grad_norm: optional gradient norm (logged as 0.0 when None).
            projector_l2_norm: optional projector weight norm.
            gpu_memory_gb / gpu_memory_reserved_gb / gpu_memory_total_gb: optional GPU
                memory figures in GB. Logged only when ``gpu_memory_gb`` is given.
            memory_breakdown_mb: optional dict. The gradient keys
                ('grad_norm_before_clip', 'grad_norm_after_clip', 'clipping_ratio',
                'was_clipped') go under 'gradient/'; every other key is logged as
                'memory/<key>_mb'.
        """
        projector_lr = float(lr_list[0]) if len(lr_list) > 0 else 0.0
        log_data = {
            'train/loss': float(train_loss),
            'train/projector_lr': projector_lr,
            'train/lora_lr': float(lr_list[1]) if len(lr_list) > 1 else 0.0,
            # Legacy alias of the first group's LR; shares the empty-list guard above.
            'train/learning_rate': projector_lr,
            'train/grad_norm': float(grad_norm) if grad_norm is not None else 0.0,
            'step': int(step)
        }

        if projector_l2_norm is not None:
            log_data['projector/l2_norm'] = float(projector_l2_norm)

        if gpu_memory_gb is not None:
            log_data['gpu/memory_used_gb'] = float(gpu_memory_gb)
            log_data['gpu/memory_reserved_gb'] = float(gpu_memory_reserved_gb) if gpu_memory_reserved_gb is not None else 0.0
            log_data['gpu/memory_total_gb'] = float(gpu_memory_total_gb) if gpu_memory_total_gb is not None else 0.0
            # Utilisation is only meaningful with a positive total.
            if gpu_memory_total_gb and gpu_memory_total_gb > 0:
                log_data['gpu/memory_utilization_pct'] = float(gpu_memory_gb / gpu_memory_total_gb * 100)
            else:
                log_data['gpu/memory_utilization_pct'] = 0.0

        if memory_breakdown_mb:
            for k, v in memory_breakdown_mb.items():
                if k in ['grad_norm_before_clip', 'grad_norm_after_clip', 'clipping_ratio']:
                    log_data[f'gradient/{k}'] = float(v)
                elif k == 'was_clipped':
                    log_data[f'gradient/{k}'] = bool(v)
                else:
                    log_data[f"memory/{k}_mb"] = float(v)

        wandb.log(log_data)

    def log_validation(self, step, val_metrics):
        """Record validation metrics in the history and send them to wandb.

        Raises:
            KeyError: if ``val_metrics`` lacks any of loss/perplexity/wer/bleu/rouge_l.
                The history is left unchanged in that case.
        """
        # Convert everything first so a missing key cannot leave the history lists ragged.
        values = {name: float(val_metrics[name]) for name in self._VALIDATION_METRICS}
        step = int(step)

        self.logs['step'].append(step)
        for name, value in values.items():
            self.logs[f'val_{name}'].append(value)

        payload = {f'val/{name}': value for name, value in values.items()}
        payload['step'] = step
        wandb.log(payload)

    def save_logs(self):
        """Write the validation history to CSV and return it as a DataFrame (None if empty)."""
        if len(self.logs['step']) == 0:
            return None
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
        df = pd.DataFrame(self.logs)
        csv_path = os.path.join(self.save_dir, 'validation_logs.csv')
        df.to_csv(csv_path, index=False)
        return df

In [None]:
class AudioTextDataset(Dataset):
    """Dataset of (audio clip, transcript) pairs, read from a ZIP archive or from disk.

    Each record in ``data`` is a dict with "audio_path" and "speaker_text". When
    ``zip_path`` exists, audio is read from that archive: "audio_path" must match a member
    name exactly, and only .flac/.wav/.mp3 members are indexed. Otherwise "audio_path" is
    loaded from the filesystem.

    The ZIP handle belongs to a single process. A handle inherited through ``fork`` shares
    its OS file offset with the parent and sibling DataLoader workers, so concurrent
    seek/read calls could interleave. Each process therefore opens its own handle on
    first use, and the handle is dropped when the dataset is pickled (spawn-based workers).
    """

    def __init__(self, data, tokenizer, feature_extractor, zip_path=None):
        self.data = data
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.zip_path = None
        self.zip_file = None
        self.zip_manifest = None
        self._zip_owner_pid = None

        if zip_path and os.path.exists(zip_path):
            try:
                self.zip_file = zipfile.ZipFile(zip_path, 'r')
                self._zip_owner_pid = os.getpid()
                self.zip_path = zip_path

                self.zip_manifest = {
                    p: p
                    for p in self.zip_file.namelist()
                    if p.lower().endswith(('.flac', '.wav', '.mp3'))
                }

                wandb.log({
                    "dataset/zip_loaded": 1.0,  # numeric rather than bool for wandb charts
                    "dataset/zip_audio_files": int(len(self.zip_manifest)),
                    "dataset/total_records": int(len(self.data))
                })

            except Exception as e:
                # Fall back to filesystem loading; close whatever was opened so the handle does not leak.
                if self.zip_file is not None:
                    self.zip_file.close()
                self.zip_file = None
                self.zip_path = None
                wandb.log({"dataset/zip_error": str(e)})
        else:
            wandb.log({"dataset/zip_loaded": 0.0, "dataset/total_records": int(len(self.data))})

    def _get_zip_file(self):
        """Return a ZipFile handle owned by the current process, or None when not reading from a ZIP."""
        if self.zip_path is None:
            return None
        pid = os.getpid()
        if self.zip_file is None or self._zip_owner_pid != pid:
            self.zip_file = zipfile.ZipFile(self.zip_path, 'r')
            self._zip_owner_pid = pid
        return self.zip_file

    def __getstate__(self):
        """Pickle without the open ZipFile (not picklable); the receiving process reopens it lazily."""
        state = self.__dict__.copy()
        state['zip_file'] = None
        state['_zip_owner_pid'] = None
        return state

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resample, mono-mix and standardise one clip, and tokenize its transcript.

        A clip that cannot be loaded is replaced by one second of silence at 16 kHz, and a
        warning is printed.

        Returns:
            dict with "input_values" (feature-extractor output), "input_ids" and
            "attention_mask" (tokenizer output, truncated to 512 tokens). The leading
            batch dimension is squeezed away from each.
        """
        item = self.data[idx]
        audio_path = item["audio_path"]
        speaker_text = item["speaker_text"]

        try:
            zip_file = self._get_zip_file()
            if zip_file is not None and self.zip_manifest is not None:
                found_path = self.zip_manifest.get(audio_path)
                if found_path:
                    with zip_file.open(found_path) as audio_file:
                        audio_data = audio_file.read()
                    waveform, sr = torchaudio.load(io.BytesIO(audio_data))
                else:
                    raise FileNotFoundError(f"–§–∞–π–ª '{audio_path}' –Ω–µ –Ω–∞–π–¥–µ–Ω –≤ –º–∞–Ω–∏—Ñ–µ—Å—Ç–µ ZIP.")
            else:
                waveform, sr = torchaudio.load(audio_path)

        except Exception as e:
            print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –∑–∞–≥—Ä—É–∑–∫–∏ {audio_path}: {e}")
            waveform = torch.zeros(1, 16000)
            sr = 16000

        if sr != self.feature_extractor.sampling_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.feature_extractor.sampling_rate)

        # Down-mix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Zero-mean / unit-variance standardisation; skipped for (near-)silent clips.
        waveform_np = waveform.squeeze().numpy()
        waveform_mean = np.mean(waveform_np)
        waveform_std = np.std(waveform_np)

        if waveform_std > 1e-8:
            waveform_np = (waveform_np - waveform_mean) / waveform_std

        inputs = self.feature_extractor(
            waveform_np,
            sampling_rate=self.feature_extractor.sampling_rate,
            return_tensors="pt"
        )

        tokens = self.tokenizer(
            speaker_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        return {
            "input_values": inputs.input_values.squeeze(0),
            "input_ids": tokens.input_ids.squeeze(0),
            "attention_mask": tokens.attention_mask.squeeze(0)
        }

    def __del__(self):
        zip_file = getattr(self, 'zip_file', None)
        if zip_file:
            zip_file.close()

In [None]:
def compress_audio_features(audio_features, compression_rate_k):
    """Shorten an encoder sequence by concatenating every ``compression_rate_k`` consecutive frames.

    Args:
        audio_features: tensor of shape (batch, seq_len, hidden_dim).
        compression_rate_k: number of consecutive frames merged into one vector (>= 1).

    Returns:
        Tensor of shape (batch, seq_len // k, k * hidden_dim). Trailing frames that do not
        fill a complete group are dropped, so the sequence dimension is 0 when seq_len < k.

    Raises:
        ValueError: if ``compression_rate_k`` is smaller than 1.
    """
    if compression_rate_k < 1:
        raise ValueError(f"compression_rate_k must be >= 1, got {compression_rate_k}")

    batch_size, seq_len, hidden_dim = audio_features.shape
    num_groups = seq_len // compression_rate_k
    trimmed = audio_features[:, :num_groups * compression_rate_k, :]

    # reshape (not view) also accepts inputs whose time axis is not contiguous, e.g. after a transpose;
    # for contiguous inputs it still returns a view without copying.
    return trimmed.reshape(batch_size, num_groups, compression_rate_k * hidden_dim)

In [None]:
def _score_transcript_pair(ref_text, pred_text, smooth, rouge_scorer_obj):
    """Return (wer, bleu, rouge_l_f1) for one non-empty reference and a prediction.

    An empty prediction is scored as the worst case (WER 1.0 - every reference
    word is deleted - BLEU 0.0, ROUGE-L 0.0) instead of being skipped; skipping
    it would let a model that generates nothing report a perfect WER of 0.0.
    """
    if not pred_text:
        return 1.0, 0.0, 0.0
    wer = jiwer.wer(ref_text, pred_text)
    bleu = sentence_bleu([ref_text.split()], pred_text.split(), smoothing_function=smooth)
    rouge_l = rouge_scorer_obj.score(ref_text, pred_text)['rougeL'].fmeasure
    return wer, bleu, rouge_l


def evaluate_with_metrics(model, projector, wav2vec2, dataloader, tokenizer, prefix_embeds, device, max_new_tokens, compression_rate_k, beam_width, temperature, top_k, top_p, repetition_penalty):
    """Validate: teacher-forced loss plus generation-based WER / BLEU / ROUGE-L.

    For each batch, ``process_batch`` (defined elsewhere in the notebook) returns
    the model outputs (with ``.loss``) and the prompt embeddings; ``model.generate``
    then decodes a transcript from those embeddings, which is compared against the
    reference decoded from ``batch["input_ids"]``. The first three scored pairs
    are printed.

    Args:
        model, projector, wav2vec2: the LM, audio projector and audio encoder.
        dataloader: validation batches (dicts with at least "input_ids").
        tokenizer: used to decode predictions and references.
        prefix_embeds, device, compression_rate_k: forwarded to process_batch.
        max_new_tokens, beam_width, temperature, top_k, top_p,
        repetition_penalty: generation settings. Note that generation uses
            ``do_sample=True`` together with ``num_beams=beam_width`` (beam
            sampling), so the text metrics vary between runs.

    Returns:
        dict with 'loss', 'perplexity', 'wer', 'bleu', 'rouge_l'. Loss and
        perplexity are NaN for an empty dataloader (so they never register as a
        new best); text metrics are 0.0 when no pair could be scored.

    The train/eval mode of the three modules is restored on exit, so calling this
    in the middle of a training epoch does not leave e.g. the projector in eval
    mode (dropout disabled) for the rest of the epoch.
    """
    modules = (model, projector, wav2vec2)
    previous_modes = [module.training for module in modules]
    for module in modules:
        module.eval()

    total_loss = 0.0
    num_batches = 0
    total_wer = 0.0
    total_bleu = 0.0
    total_rouge_l = 0.0
    count = 0
    examples_shown = 0
    smooth = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Validation"):
                # Reference ids are only decoded, never fed to the model.
                input_ids = batch["input_ids"]

                outputs, prompt_embeds = process_batch(
                    batch, model, projector, wav2vec2, tokenizer, prefix_embeds, device, compression_rate_k
                )
                total_loss += outputs.loss.item()
                num_batches += 1

                # Keep bfloat16 for generation consistency
                prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)

                generated_ids = model.generate(
                    inputs_embeds=prompt_embeds,
                    max_new_tokens=max_new_tokens,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    num_beams=beam_width,
                    temperature=temperature,
                    do_sample=True,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    early_stopping=True
                )

                for j in range(generated_ids.size(0)):
                    pred_text = tokenizer.decode(generated_ids[j], skip_special_tokens=True).strip()
                    ref_text_ids = input_ids[j]
                    ref_text_ids = ref_text_ids[ref_text_ids != -100]
                    ref_text = tokenizer.decode(ref_text_ids, skip_special_tokens=True).strip()

                    if not ref_text:
                        # Nothing to score against.
                        continue

                    current_wer, current_bleu, current_rouge_l = _score_transcript_pair(
                        ref_text, pred_text, smooth, rouge_scorer_obj
                    )
                    total_wer += current_wer
                    total_bleu += current_bleu
                    total_rouge_l += current_rouge_l
                    count += 1

                    if examples_shown < 3:
                        print(f"\n–ü—Ä–∏–º–µ—Ä {examples_shown + 1}:")
                        print(f"–≠—Ç–∞–ª–æ–Ω: '{ref_text}'")
                        print(f"–ì–µ–Ω–µ—Ä–∞—Ü–∏—è: '{pred_text}'")
                        print(f"WER: {current_wer:.3f}, BLEU: {current_bleu:.3f}")
                        examples_shown += 1
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    avg_loss = total_loss / num_batches if num_batches > 0 else float('nan')
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_wer = total_wer / count if count > 0 else 0.0
    avg_bleu = total_bleu / count if count > 0 else 0.0
    avg_rouge_l = total_rouge_l / count if count > 0 else 0.0

    print(f"\n–í–∞–ª–∏–¥–∞—Ü–∏—è ({count} –ø—Ä–∏–º–µ—Ä–æ–≤):")
    print(f"Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}, BLEU: {avg_bleu:.4f}")

    return {
        'loss': avg_loss, 'perplexity': perplexity,
        'wer': avg_wer, 'bleu': avg_bleu, 'rouge_l': avg_rouge_l
    }

In [None]:
# --- Hardware and base models ---
device = torch.device("cuda")
model_id = "google/gemma-3-4b-pt"
audio_model_name = "facebook/wav2vec2-xls-r-300m"

# --- Optimization and checkpointing ---
batch_size = 4
num_epochs = 10
projector_learning_rate = 2e-3
lora_learning_rate = 2e-4
weight_decay = 0.1  # Increased to stabilize training
max_grad_norm = 10.0
gradient_accumulation_steps = 4
save_every_steps = 2000
save_latest_every_steps = 50

# --- Generation / validation ---
max_new_tokens = 70
compression_rate_k = 2
beam_width = 15
temperature = 0.6
top_k = 50
top_p = 0.9
repetition_penalty = 1.2
val_subset_size = 15
use_8bit_optimizer = True

# Parameters for a smooth transition between the two datasets
enable_dataset_blending = True  # Enable dataset blending
transition_start_epoch = 8      # Epoch at which the transition starts (from the 8th epoch)
transition_end_epoch = 10       # Epoch at which the transition ends (10th epoch = 100% secondary dataset)
blend_schedule = "linear"       # Transition type: "linear", "cosine", "exponential"

# Dataset paths
primary_jsonl_path = "transcripts.jsonl"     # Primary dataset
primary_zip_path = "LibriSpeech.zip"
secondary_jsonl_path = "transcripts_v2.jsonl"  # Secondary dataset (can be replaced)
secondary_zip_path = "LibriSpeech_v2.zip"      # Secondary ZIP (can be replaced)

# Projector dimensions. NOTE(review): presumably the wav2vec2 hidden size (1024)
# and the Gemma text hidden size (2560) - confirm if either model changes.
input_dim = 1024
output_dim = 2560

experiment_name = f"audio_projector_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checkpoint_dir = "/home/jovyan/persistent_volume/"
resume_training = True
checkpoint_path = "/home/jovyan/persistent_volume/latest_checkpoint_bs4_epoch_1_step_2000.pt"
skip_validation = False
interactive_mode = True

wandb.init(
    project="audio-projector",
    name=experiment_name,
    config={
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "projector_learning_rate": projector_learning_rate,
        "lora_learning_rate": lora_learning_rate,
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "max_new_tokens": max_new_tokens,
        "compression_rate_k": compression_rate_k,
        "beam_width": beam_width,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "val_subset_size": val_subset_size,
        "input_dim": input_dim,
        "output_dim": output_dim,
        "model_id": model_id,
        "audio_model_name": audio_model_name,
        "resume_training": resume_training,
        "z_normalization": True,
        "projector_hidden_dim": 2048,
        "activation": "GELU",
        "scheduler_type": "CosineAnnealingWarmRestarts",
        "use_8bit_optimizer": use_8bit_optimizer,
        "optimizer_type": "AdamW8bit" if use_8bit_optimizer else "AdamW",
        "mse_loss_removed": "MSE loss disabled due to alignment issues",
        "enable_dataset_blending": enable_dataset_blending,
        "transition_start_epoch": transition_start_epoch,
        "transition_end_epoch": transition_end_epoch,
        "blend_schedule": blend_schedule,
        "primary_jsonl_path": primary_jsonl_path,
        "secondary_jsonl_path": secondary_jsonl_path,
        "lora_config": {
            "r": 64,
            "lora_alpha": 128,
            "target_modules": ["k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"],
            # Must mirror the LoraConfig actually applied to the model (lora_dropout=0.2);
            # the previously logged 0.05 misreported the run.
            "lora_dropout": 0.2
        }
    }
)

# Checkpoint bookkeeping, updated later by the training loop and save helpers.
best_val_loss = float('inf')
best_checkpoint_path = latest_checkpoint_path = None

# 4-bit NF4 weight quantization with double quantization of the quantization
# constants; matmuls are computed in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Non-interactive Hugging Face Hub login when HF_TOKEN is set and non-empty.
hf_token = os.environ.get('HF_TOKEN')
if hf_token:
    login(token=hf_token)


In [None]:
# Interactive Hub login is only needed when no HF_TOKEN was found in the
# environment; otherwise the previous cell already called login(token=hf_token).
if not hf_token:
    notebook_login()

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Build a text-only config from the multimodal Gemma-3 config so that only the
# causal-LM decoder is instantiated.
multi_cfg = AutoConfig.from_pretrained(model_id, token=hf_token)

text_cfg_dict = multi_cfg.text_config.to_dict()
# NOTE(review): hard-coded vocabulary size; presumably it must equal the number of
# embedding rows in the checkpoint - confirm against the model card.
text_cfg_dict["vocab_size"] = 262208
# Keep the config's special-token ids consistent with the tokenizer.
text_cfg_dict.update({"bos_token_id": tokenizer.bos_token_id, "eos_token_id": tokenizer.eos_token_id, "pad_token_id": tokenizer.pad_token_id})

text_cfg = Gemma3TextConfig(**text_cfg_dict)
gemma_model = Gemma3ForCausalLM.from_pretrained(
    model_id,
    config=text_cfg,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="cuda",
    token=hf_token
)

gemma_model.gradient_checkpointing_enable()
gemma_model = prepare_model_for_kbit_training(gemma_model)

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"],
    lora_dropout=0.2,
    bias="none",
    # PEFT's default (True) zero-initializes LoRA B, so the adapted model starts out
    # identical to the pretrained one. False randomizes both A and B - documented by
    # PEFT as a debugging option - and perturbs every targeted projection before the
    # first update. (Presumably LoRA weights restored by load_checkpoint overwrite
    # either initialization when resuming.)
    init_lora_weights=True,
    task_type="CAUSAL_LM"
)

gemma_model = get_peft_model(gemma_model, lora_config)

# Cast lm_head to bfloat16 for consistency with the rest of the model.
# NOTE(review): on a PeftModel, hasattr() likely resolves lm_head through the wrapped
# model, and the assignment then registers it as an extra child of the wrapper
# (adding keys to gemma_model.state_dict()). Left unchanged so the state-dict keys
# of existing checkpoints still match.
if hasattr(gemma_model, 'lm_head'):
    gemma_model.lm_head = gemma_model.lm_head.to(torch.bfloat16)
elif hasattr(gemma_model, 'base_model') and hasattr(gemma_model.base_model, 'lm_head'):
    gemma_model.base_model.lm_head = gemma_model.base_model.lm_head.to(torch.bfloat16)

# Trailing ';' suppresses the (very large) module repr as the cell output.
gemma_model.eval();

In [None]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(audio_model_name)

# Frozen audio encoder: bfloat16 weights on the training device, no gradients,
# eval mode (dropout off). The encoder itself is never updated.
wav2vec2 = Wav2Vec2Model.from_pretrained(audio_model_name).to(torch.bfloat16).to(device)
wav2vec2.requires_grad_(False)
wav2vec2.eval();

In [None]:
# The projector consumes k stacked audio frames (compress_audio_features merges
# compression_rate_k frames), hence input_dim * compression_rate_k input features.
projector_input_dim = input_dim * compression_rate_k
projector_hidden_dim = 2048
projector = AudioProjector(
    projector_input_dim, output_dim, hidden_dim=projector_hidden_dim
).to(device).float()

# Two parameter groups with separate learning rates.
# NOTE(review): the second group receives *all* Gemma parameters, presumably including
# the frozen base weights (optimizers skip params without gradients). Filtering them
# would change the optimizer state_dict layout that existing checkpoints rely on.
params_to_optimize = [
    {"params": projector.parameters(), "lr": projector_learning_rate},
    {"params": gemma_model.parameters(), "lr": lora_learning_rate}
]

# Shared AdamW hyperparameters; the 8-bit variant (bitsandbytes) keeps optimizer
# state in 8 bits to save memory.
adamw_kwargs = dict(weight_decay=weight_decay, betas=(0.9, 0.999), eps=1e-8)
if not use_8bit_optimizer:
    optimizer = optim.AdamW(params_to_optimize, **adamw_kwargs)
    print("‚ö†Ô∏è –ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è –æ–±—ã—á–Ω—ã–π AdamW –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä")
else:
    optimizer = bnb.optim.AdamW8bit(params_to_optimize, **adamw_kwargs)
    print("‚úÖ –ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è 8-–±–∏—Ç–Ω—ã–π AdamW –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä (—ç–∫–æ–Ω–æ–º–∏—è –ø–∞–º—è—Ç–∏ ~75%)")

scaler = GradScaler()
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

# Fixed text instruction; embedded once with Gemma's input embedding table.
prefix = "Transcribe speech to text."
prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(device)
embedding_layer = gemma_model.get_input_embeddings()
with torch.no_grad():
    prefix_embeds = embedding_layer(prefix_ids).to(dtype=torch.bfloat16)

In [None]:
# üîÑ –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö —Å –ø–æ–¥–¥–µ—Ä–∂–∫–æ–π –ø–ª–∞–≤–Ω–æ–≥–æ –ø–µ—Ä–µ—Ö–æ–¥–∞ –º–µ–∂–¥—É –¥–∞—Ç–∞—Å–µ—Ç–∞–º–∏
def load_dataset_data(jsonl_path, zip_path=None):
    """Load a JSONL manifest and normalize every record to the dataset schema.

    Each non-blank line must be a JSON object. Keys are mapped as
    ``audio_filepath -> audio_path`` and ``text -> speaker_text`` (both default
    to ""), while ``language`` and ``source`` are kept with defaults "en" and
    "unknown". Blank or whitespace-only lines (e.g. an extra empty line at the
    end of the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        jsonl_path: path to the UTF-8 JSONL manifest.
        zip_path: not used by this function; accepted so call sites can pass the
            matching audio archive alongside the manifest.

    Returns:
        list of normalized dicts, or an empty list if ``jsonl_path`` does not exist.
    """
    try:
        with open(jsonl_path, "r", encoding="utf-8") as f:
            raw_data = [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        print(f"‚ö†Ô∏è –§–∞–π–ª {jsonl_path} –Ω–µ –Ω–∞–π–¥–µ–Ω, –≤–æ–∑–≤—Ä–∞—â–∞–µ–º –ø—É—Å—Ç–æ–π —Å–ø–∏—Å–æ–∫")
        return []

    normalized_data = [
        {
            "audio_path": item.get("audio_filepath", ""),
            "speaker_text": item.get("text", ""),
            "language": item.get("language", "en"),
            "source": item.get("source", "unknown"),
        }
        for item in raw_data
    ]

    print(f"üìä –ó–∞–≥—Ä—É–∂–µ–Ω–æ {len(normalized_data)} –ø—Ä–∏–º–µ—Ä–æ–≤ –∏–∑ {jsonl_path}")
    return normalized_data

# Load the primary dataset.
primary_data = load_dataset_data(primary_jsonl_path, primary_zip_path)

# Load the secondary dataset (only when blending is enabled).
secondary_data = []
if enable_dataset_blending:
    secondary_data = load_dataset_data(secondary_jsonl_path, secondary_zip_path)
    if len(secondary_data) == 0:
        print(f"‚ö†Ô∏è –í—Ç–æ—Ä–æ–π –¥–∞—Ç–∞—Å–µ—Ç –ø—É—Å—Ç, –æ—Ç–∫–ª—é—á–∞–µ–º —Å–º–µ—à–∏–≤–∞–Ω–∏–µ")
        enable_dataset_blending = False

# Pool of all available records (reporting / total_records).
if enable_dataset_blending:
    all_data = primary_data + secondary_data
    print(f"üìä –û–±—ä–µ–¥–∏–Ω–µ–Ω–æ {len(primary_data)} + {len(secondary_data)} = {len(all_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
else:
    all_data = primary_data
    print(f"üìä –ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —Ç–æ–ª—å–∫–æ –æ—Å–Ω–æ–≤–Ω–æ–π –¥–∞—Ç–∞—Å–µ—Ç: {len(all_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")

total_records = len(all_data)

# Train/val split: each source is split exactly ONCE and both train and val are
# taken from that same split, so no validation record can also be trained on.
# (Previously val_data came from a split of the concatenated pool while the
# blender's train data came from separate per-source splits; those partitions do
# not coincide, so validation records leaked into training.)
primary_train_data, primary_val_data = train_test_split(primary_data, test_size=0.1, random_state=42)

# Initialize the DatasetBlender when blending is enabled.
dataset_blender = None
if enable_dataset_blending:
    secondary_train_data, secondary_val_data = train_test_split(secondary_data, test_size=0.1, random_state=42)
    val_data = primary_val_data + secondary_val_data

    dataset_blender = DatasetBlender(
        primary_data=primary_train_data,
        secondary_data=secondary_train_data,
        transition_start_epoch=transition_start_epoch,
        transition_end_epoch=transition_end_epoch,
        blend_schedule=blend_schedule
    )
    train_data = primary_train_data  # Start with the primary dataset
else:
    # Regular mode without blending: all_data is primary_data, so this is the same
    # partition as splitting all_data.
    train_data, val_data = primary_train_data, primary_val_data

# Seeded sampler: the validation subset must be identical across runs and resumes,
# otherwise a best_val_loss carried over in a checkpoint is compared with a loss
# measured on different examples.
val_subset_data = random.Random(42).sample(val_data, min(val_subset_size, len(val_data)))

print(f"üìä Data: {len(train_data)} train, {len(val_subset_data)} val")

In [None]:
# Resuming skips samples at the data-list level (not by skipping batches); the
# original note says this value is set after the checkpoint is loaded.
# NOTE(review): it is not reassigned in the visible resume cell.
skip_samples_from_checkpoint = 0

train_dataset = AudioTextDataset(train_data, tokenizer, feature_extractor, zip_path=primary_zip_path)
# NOTE(review): with blending enabled, val_subset_data may contain secondary-dataset
# records, yet they are read from primary_zip_path here - confirm AudioTextDataset
# can resolve them.
val_dataset = AudioTextDataset(val_subset_data, tokenizer, feature_extractor, zip_path=primary_zip_path)

# Options shared by both loaders; single-process loading, no pinned memory.
loader_options = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=0, pin_memory=False)
# Training loader shuffles; the resume cell replaces it with an unshuffled one.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f"üîß DataLoaders: {len(train_loader)} train, {len(val_loader)} val batches")


In [None]:
# Place every component on the training device (moved in this order).
wav2vec2, projector, gemma_model = (
    module.to(device) for module in (wav2vec2, projector, gemma_model)
)

# Optimizer steps over the whole run: one step per gradient_accumulation_steps batches.
total_steps = (num_epochs * len(train_loader)) // gradient_accumulation_steps

def calculate_restart_period(base_examples, batch_size, grad_accum_steps):
    """Convert a restart interval given in training examples into optimizer steps.

    One optimizer step consumes ``batch_size * grad_accum_steps`` examples.

    Args:
        base_examples: number of examples between LR restarts.
        batch_size: physical (per-forward) batch size.
        grad_accum_steps: gradient-accumulation steps per optimizer step.

    Returns:
        int: optimizer steps between restarts, never less than 1.
    """
    examples_per_optimizer_step = batch_size * grad_accum_steps
    return max(base_examples // examples_per_optimizer_step, 1)

# Restart the cosine LR cycle after roughly `base_examples` training examples,
# converted into optimizer steps for the current effective batch size.
base_examples = 30000
adaptive_restart_period = calculate_restart_period(base_examples, batch_size, gradient_accumulation_steps)

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=adaptive_restart_period, T_mult=1, eta_min=1e-6)

print(f"üîß Training: {total_steps} steps, LR({projector_learning_rate}/{lora_learning_rate}), GradAcc({gradient_accumulation_steps})")
print(f"üîÑ Scheduler: CosineAnnealingWarmRestarts —Å –ø–µ—Ä–∏–æ–¥–æ–º {adaptive_restart_period} —à–∞–≥–æ–≤")

# Make sure the checkpoint directory exists; TrainingLogger receives it as well.
os.makedirs(checkpoint_dir, exist_ok=True)
logger = TrainingLogger(experiment_name, checkpoint_dir)

In [None]:
def get_gpu_memory_stats(device):
    """Return (allocated, reserved, total) GPU memory for ``device`` in GiB.

    Returns ``(None, None, None)`` when CUDA is unavailable or when any of the
    queries raises (the error is printed, not re-raised).

    NOTE(review): this function is defined again in the next cell; that later
    definition shadows this one once both cells have run.
    """
    if torch.cuda.is_available():
        gib = 1024**3
        try:
            allocated_gb = torch.cuda.memory_allocated(device) / gib
            reserved_gb = torch.cuda.memory_reserved(device) / gib
            total_gb = torch.cuda.get_device_properties(device).total_memory / gib
        except Exception as err:
            print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –ø–æ–ª—É—á–µ–Ω–∏—è GPU —Å—Ç–∞—Ç–∏—Å—Ç–∏–∫–∏: {err}")
            return None, None, None
        return allocated_gb, reserved_gb, total_gb
    return None, None, None

def light_gpu_cleanup(device):
    """Free cached GPU memory without deleting any live Python variables.

    With CUDA it does four things: waits for queued work on ``device``, runs
    Python's garbage collector, releases cached allocator blocks and CUDA IPC
    memory, and resets the peak/accumulated memory statistics of ``device``.
    Without CUDA it only runs the garbage collector.

    NOTE(review): resetting the peak statistics affects later
    ``torch.cuda.max_memory_allocated`` readings. This function is also defined
    again in the next cell, which shadows this definition.
    """
    import gc
    import torch

    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        # Let in-flight kernels finish before releasing memory.
        torch.cuda.synchronize(device)
    else:
        print("üßπ GPU –Ω–µ–¥–æ—Å—Ç—É–ø–µ–Ω, –≤—ã–ø–æ–ª–Ω—è–µ—Ç—Å—è —Ç–æ–ª—å–∫–æ —Å–±–æ—Ä–∫–∞ –º—É—Å–æ—Ä–∞.")

    gc.collect()
    if not cuda_ok:
        return

    torch.cuda.empty_cache()   # return cached blocks to the driver
    torch.cuda.ipc_collect()   # release memory held for CUDA IPC
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)

    print(f"üßπ –õ–µ–≥–∫–∞—è –æ—á–∏—Å—Ç–∫–∞ GPU –∑–∞–≤–µ—Ä—à–µ–Ω–∞.")

In [None]:
def get_gpu_memory_stats(device):
    """Report GPU memory usage for ``device``.

    Returns:
        (allocated, reserved, total) in GiB, or (None, None, None) when CUDA is
        unavailable or a query raises (the exception is printed, not re-raised).
    """
    unavailable = (None, None, None)
    if not torch.cuda.is_available():
        return unavailable

    bytes_per_gib = 1024**3
    try:
        raw_bytes = (
            torch.cuda.memory_allocated(device),
            torch.cuda.memory_reserved(device),
            torch.cuda.get_device_properties(device).total_memory,
        )
    except Exception as exc:
        print(f"‚ö†Ô∏è –û—à–∏–±–∫–∞ –ø–æ–ª—É—á–µ–Ω–∏—è GPU —Å—Ç–∞—Ç–∏—Å—Ç–∏–∫–∏: {exc}")
        return unavailable
    return tuple(value / bytes_per_gib for value in raw_bytes)

def get_model_memory_footprint(model: nn.Module, trainable_only: bool = False) -> float:
    """Size of a model's parameters in MiB (elements * bytes per element).

    Args:
        model: module whose ``parameters()`` are counted (buffers are not).
        trainable_only: if True, count only parameters with ``requires_grad``.

    Returns:
        Parameter storage in MiB (bytes / 1024**2).
    """
    counted = (
        param for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
    size_in_bytes = sum(param.nelement() * param.element_size() for param in counted)
    return size_in_bytes / 1024**2

def light_gpu_cleanup(device):
    """Free cached GPU memory without deleting any live Python variables.

    Without CUDA only Python's garbage collector runs. With CUDA, the steps
    below run in order, and the peak/accumulated memory statistics of
    ``device`` are reset afterwards.
    """
    import gc
    import torch

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)    # wait for queued kernels on `device`
        gc.collect()                      # drop unreachable objects that hold tensors
        torch.cuda.empty_cache()          # return cached allocator blocks
        torch.cuda.ipc_collect()          # release memory held for CUDA IPC
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.reset_accumulated_memory_stats(device)
        print(f"üßπ –õ–µ–≥–∫–∞—è –æ—á–∏—Å—Ç–∫–∞ GPU –∑–∞–≤–µ—Ä—à–µ–Ω–∞.")
    else:
        print("üßπ GPU –Ω–µ–¥–æ—Å—Ç—É–ø–µ–Ω, –≤—ã–ø–æ–ª–Ω—è–µ—Ç—Å—è —Ç–æ–ª—å–∫–æ —Å–±–æ—Ä–∫–∞ –º—É—Å–æ—Ä–∞.")
        gc.collect()


In [None]:
def find_latest_checkpoint(checkpoint_dir):
    """Return the most recently written "latest" checkpoint in ``checkpoint_dir``.

    Looks for files named ``latest_checkpoint_bs*_epoch*_step*.pt`` directly in
    ``checkpoint_dir`` and returns the one with the newest modification time, or
    None if there are none. The modification time (``getmtime``) is used instead
    of ``getctime``: on Unix, ctime is the last *metadata* change (chmod, rename,
    ...), not the time the file was written.
    """
    pattern = os.path.join(checkpoint_dir, "latest_checkpoint_bs*_epoch*_step*.pt")
    checkpoints = glob.glob(pattern)
    return max(checkpoints, key=os.path.getmtime) if checkpoints else None

def find_best_checkpoint(checkpoint_dir):
    """Return the most recently written best checkpoint in ``checkpoint_dir``.

    Matches ``best_checkpoint_bs*_step*.pt`` directly in ``checkpoint_dir``.
    Several matches can exist (e.g. leftovers from earlier runs); the one with the
    newest modification time is returned, instead of whatever ``glob`` happens to
    list first (glob order is arbitrary). Returns None if there is no match.
    """
    pattern = os.path.join(checkpoint_dir, "best_checkpoint_bs*_step*.pt")
    checkpoints = glob.glob(pattern)
    return max(checkpoints, key=os.path.getmtime) if checkpoints else None

def _build_checkpoint_data(step, epoch, batch_idx):
    """Snapshot everything needed to resume training into one dict for torch.save.

    Reads module-level state: projector, gemma_model, optimizer, scheduler,
    best_val_loss and the hyperparameter globals recorded under 'config'.
    """
    return {
        'step': step,
        'epoch': epoch,
        'batch_idx': batch_idx,
        'projector_state_dict': projector.state_dict(),
        'lora_state_dict': gemma_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'config': {
            'projector_learning_rate': projector_learning_rate,
            'lora_learning_rate': lora_learning_rate,
            'weight_decay': weight_decay,
            'max_grad_norm': max_grad_norm,
            'batch_size': batch_size,
            'compression_rate_k': compression_rate_k,
            'input_dim': input_dim,
            'output_dim': output_dim,
            'experiment_name': experiment_name,
            'lora_config': {
                'r': 64,
                'lora_alpha': 128,
                'target_modules': ["k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"],
                # Must mirror the LoraConfig actually applied to gemma_model
                # (lora_dropout=0.2); 0.05 was previously recorded by mistake.
                'lora_dropout': 0.2
            }
        }
    }


def _atomic_torch_save(obj, path):
    """torch.save ``obj`` to ``path`` without ever leaving a partial file there.

    The data is written to ``path + '.tmp'`` first and then renamed over ``path``
    with ``os.replace`` (atomic within one filesystem), so an interrupted save or
    a full disk leaves any previous file at ``path`` intact. The temporary file
    is removed if saving fails.
    """
    tmp_path = path + ".tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def save_checkpoint(step, epoch, batch_idx=0, is_best=False):
    """Save a periodic checkpoint, or with ``is_best=True`` the new best checkpoint.

    Best checkpoints go to ``best_checkpoint_bs{batch_size}_step_{step}.pt``. The
    previous best (tracked in the global ``best_checkpoint_path``) is deleted only
    AFTER the new file is safely written, so a failed save never loses the last
    good best checkpoint. Periodic checkpoints go to
    ``checkpoint_bs{batch_size}_step_{step}.pt`` and are never deleted here.
    """
    global best_checkpoint_path

    checkpoint_data = _build_checkpoint_data(step, epoch, batch_idx)

    if is_best:
        previous_best = best_checkpoint_path
        new_best = os.path.join(checkpoint_dir, f"best_checkpoint_bs{batch_size}_step_{step}.pt")
        _atomic_torch_save(checkpoint_data, new_best)
        best_checkpoint_path = new_best
        if previous_best and previous_best != new_best and os.path.exists(previous_best):
            os.remove(previous_best)
        print_ephemeral(f"üèÜ –õ—É—á—à–∏–π —á–µ–∫–ø–æ–∏–Ω—Ç —Å–æ—Ö—Ä–∞–Ω–µ–Ω: best_checkpoint_bs{batch_size}_step_{step}.pt")
    else:
        periodic_path = os.path.join(checkpoint_dir, f"checkpoint_bs{batch_size}_step_{step}.pt")
        _atomic_torch_save(checkpoint_data, periodic_path)
        print_ephemeral(f"üíæ –ß–µ–∫–ø–æ–∏–Ω—Ç —Å–æ—Ö—Ä–∞–Ω–µ–Ω: checkpoint_bs{batch_size}_step_{step}.pt")


def save_latest_checkpoint(step, epoch, batch_idx=0):
    """Write the rolling "latest" checkpoint and then delete the previous one.

    The file is ``latest_checkpoint_bs{batch_size}_epoch_{epoch}_step_{step}.pt``.
    The previous latest checkpoint (global ``latest_checkpoint_path``) is removed
    only after the new one has been written, so an interrupted save never leaves
    the run without a resumable checkpoint. Note that both files exist briefly,
    so the disk needs room for two checkpoints.
    """
    global latest_checkpoint_path

    checkpoint_data = _build_checkpoint_data(step, epoch, batch_idx)

    previous_latest = latest_checkpoint_path
    new_latest = os.path.join(checkpoint_dir, f"latest_checkpoint_bs{batch_size}_epoch_{epoch}_step_{step}.pt")
    _atomic_torch_save(checkpoint_data, new_latest)
    latest_checkpoint_path = new_latest
    if previous_latest and previous_latest != new_latest and os.path.exists(previous_latest):
        os.remove(previous_latest)

    print_ephemeral(f"üìÑ –ü–æ—Å–ª–µ–¥–Ω–∏–π —á–µ–∫–ø–æ–∏–Ω—Ç: bs{batch_size}_epoch_{epoch}_step_{step}")

def print_ephemeral(message):
    """Overwrite the current console line with ``message`` (no trailing newline).

    The line is first blanked with 120 spaces, so a shorter message fully covers
    a longer previous one; output is flushed immediately.
    """
    blank_line = " " * 120
    print(f"\r{blank_line}\r{message}", end="", flush=True)

def print_vanishing(message):
    """Print ``message`` after a carriage return with no newline, so the next
    carriage-return print overwrites it; leftover characters are not blanked."""
    print("\r", message, sep="", end="", flush=True)

def check_user_input():
    """Non-blocking poll of stdin for the interactive 's' (skip validation) command.

    If a line is already waiting on stdin and it equals "s" after stripping and
    lower-casing, the global ``skip_validation`` flag is set and a notice is
    printed. Polling is best-effort: ordinary errors (e.g. a Jupyter kernel whose
    stdin has no file descriptor, or a platform where ``select`` does not support
    stdin) are ignored. KeyboardInterrupt and SystemExit are NOT swallowed any
    more (the former bare ``except:`` caught them), so training stays
    interruptible.
    """
    global skip_validation

    try:
        import sys
        import select
        if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
            user_input = sys.stdin.readline().strip().lower()

            if user_input == 's':
                skip_validation = True
                print("\r–ü—Ä–æ–ø—É—Å–∫ –≤–∞–ª–∏–¥–∞—Ü–∏–∏ –Ω–∞ —Å–ª–µ–¥—É—é—â–µ–º —à–∞–≥–µ", end="", flush=True)
    except Exception:
        # Best-effort: stdin polling is unsupported in many environments.
        pass

In [None]:
# Resume state; overwritten below when a checkpoint is loaded.
start_epoch = 0
global_step = 0
batch_idx = 0


def _epoch_train_data(epoch):
    """Rebuild the ordered training records and ZIP path the training loop uses for ``epoch``.

    Mirrors the epoch setup in the training loop. When blending is active it uses
    the DatasetBlender output with ``random_seed=epoch * 12345`` and the ZIP of
    the dominant dataset. Otherwise it uses ``train_data`` shuffled with
    ``random.Random(epoch * 12345)`` and the primary ZIP. This lets a resumed run
    skip exactly the samples it already consumed.

    Returns:
        (list of training records in epoch order, zip_path)
    """
    if dataset_blender is not None:
        blended_data, blend_ratio = dataset_blender.create_blended_dataset(
            current_epoch=epoch,
            random_seed=epoch * 12345
        )
        zip_path = secondary_zip_path if blend_ratio > 0.5 else primary_zip_path
        return blended_data, zip_path

    epoch_rng = random.Random(epoch * 12345)  # fixed per-epoch seed, as in the training loop
    order = list(range(len(train_data)))
    epoch_rng.shuffle(order)
    return [train_data[i] for i in order], primary_zip_path


def _make_train_loader(records, zip_path):
    """Return (dataset, non-shuffling DataLoader) over already-ordered ``records``."""
    dataset = AudioTextDataset(records, tokenizer, feature_extractor, zip_path=zip_path)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # records are already in the deterministic epoch order
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False
    )
    return dataset, loader


if resume_training:
    if os.path.exists(checkpoint_path):
        checkpoint_epoch, global_step, batch_idx = load_checkpoint(checkpoint_path, projector, gemma_model, optimizer, scheduler, device, batch_size)
        # Clamp: a checkpoint epoch of 0 would otherwise give start_epoch = -1 and an
        # extra "epoch -1" in the training loop.
        start_epoch = max(0, checkpoint_epoch - 1)

        print(f"üîÑ –í–æ–∑–æ–±–Ω–æ–≤–ª–µ–Ω–∏–µ —Å —ç–ø–æ—Ö–∏ {start_epoch + 1}, —à–∞–≥–∞ {global_step}, batch_idx {batch_idx}")

        # Reproduce the exact record order of the interrupted epoch (including the
        # blended mix and ZIP when dataset blending is enabled).
        epoch_data, epoch_zip_path = _epoch_train_data(start_epoch)

        # Position inside the epoch from which to continue.
        batches_to_skip = batch_idx
        samples_to_skip = batches_to_skip * batch_size

        if samples_to_skip < len(epoch_data):
            # Keep the rest of the epoch in its original order (nothing is lost).
            remaining_train_data = epoch_data[samples_to_skip:]

            print(f"‚ö° –î–µ—Ç–µ—Ä–º–∏–Ω–∏—Ä–æ–≤–∞–Ω–Ω–æ –ø–µ—Ä–µ–º–µ—à–∞–Ω–æ {len(epoch_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
            print(f"üìä –ü—Ä–æ–ø—É—Å–∫–∞–µ–º –ø–µ—Ä–≤—ã–µ {samples_to_skip} –∏–Ω–¥–µ–∫—Å–æ–≤, –æ—Å—Ç–∞–ª–æ—Å—å: {len(remaining_train_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")

            train_dataset, train_loader = _make_train_loader(remaining_train_data, epoch_zip_path)
            print(f"üîß –û–±–Ω–æ–≤–ª–µ–Ω DataLoader: {len(train_loader)} –±–∞—Ç—á–µ–π")
        else:
            print(f"‚ö†Ô∏è –ù—É–∂–Ω–æ –ø—Ä–æ–ø—É—Å—Ç–∏—Ç—å {samples_to_skip} –ø—Ä–∏–º–µ—Ä–æ–≤, –Ω–æ –≤ —ç–ø–æ—Ö–µ —Ç–æ–ª—å–∫–æ {len(epoch_data)}")
            print("üîÑ –ü–µ—Ä–µ—Ö–æ–¥–∏–º –∫ —Å–ª–µ–¥—É—é—â–µ–π —ç–ø–æ—Ö–µ")
            start_epoch += 1
            batch_idx = 0
            # The training loop treats start_epoch as the resumed epoch and reuses
            # train_loader without rebuilding it, so build the full loader for the
            # next epoch here (previously the initial shuffle=True loader was reused).
            if start_epoch < num_epochs:
                next_epoch_data, next_zip_path = _epoch_train_data(start_epoch)
                train_dataset, train_loader = _make_train_loader(next_epoch_data, next_zip_path)
                print(f"üîß –û–±–Ω–æ–≤–ª–µ–Ω DataLoader: {len(train_loader)} –±–∞—Ç—á–µ–π")
    else:
        print(f"WARNING: checkpoint not found at {checkpoint_path}; starting a new run")
        resume_training = False
else:
    batch_idx = 0

print(f"üöÄ Audio Projector | {wandb.run.project}/{wandb.run.name} | {'Resume' if resume_training else 'New'}")

In [None]:
for epoch in range(start_epoch, num_epochs):
    epoch_header = f"\n{'='*50}\n"
    epoch_header += f"üîÑ –≠–ü–û–•–ê {epoch+1}/{num_epochs}\n"
    epoch_header += f"{'='*50}"
    print_vanishing(epoch_header)
    
    projector.train()
    wav2vec2.eval()
    gemma_model.eval()
    
    is_resumed_epoch = resume_training and epoch == start_epoch
    
    # üîÑ –°–æ–∑–¥–∞–Ω–∏–µ –¥–∞—Ç–∞—Å–µ—Ç–∞ —Å –ø–æ–¥–¥–µ—Ä–∂–∫–æ–π –ø–ª–∞–≤–Ω–æ–≥–æ –ø–µ—Ä–µ—Ö–æ–¥–∞
    if is_resumed_epoch:
        print(f"üîÑ –≠–ø–æ—Ö–∞ {epoch+1}: –ø—Ä–æ–¥–æ–ª–∂–µ–Ω–∏–µ —Å –ø—Ä–µ–¥–≤–∞—Ä–∏—Ç–µ–ª—å–Ω–æ —Å–æ–∑–¥–∞–Ω–Ω—ã–º DataLoader ({len(train_loader)} –±–∞—Ç—á–µ–π)")
        print(f"üìä –ù–∞—á–∏–Ω–∞–µ–º —Å batch_idx={batch_idx}, global_step={global_step}")
    elif not is_resumed_epoch:
        if dataset_blender is not None:
            # –ò—Å–ø–æ–ª—å–∑—É–µ–º DatasetBlender –¥–ª—è —Å–º–µ—à–∏–≤–∞–Ω–∏—è –¥–∞—Ç–∞—Å–µ—Ç–æ–≤
            current_train_data, blend_ratio = dataset_blender.create_blended_dataset(
                current_epoch=epoch,
                random_seed=epoch * 12345
            )
            
            # –õ–æ–≥–∏—Ä—É–µ–º –º–µ—Ç—Ä–∏–∫–∏ —Å–º–µ—à–∏–≤–∞–Ω–∏—è
            wandb.log({
                "dataset/blend_ratio": float(blend_ratio),
                "dataset/primary_examples": int(len(current_train_data) * (1 - blend_ratio)),
                "dataset/secondary_examples": int(len(current_train_data) * blend_ratio),
                "dataset/total_examples": int(len(current_train_data)),
                "dataset/epoch": int(epoch + 1)
            })
            
            # –í—ã–±–∏—Ä–∞–µ–º –ø—Ä–∞–≤–∏–ª—å–Ω—ã–π ZIP —Ñ–∞–π–ª –≤ –∑–∞–≤–∏—Å–∏–º–æ—Å—Ç–∏ –æ—Ç –ø—Ä–µ–æ–±–ª–∞–¥–∞—é—â–µ–≥–æ –¥–∞—Ç–∞—Å–µ—Ç–∞
            current_zip_path = secondary_zip_path if blend_ratio > 0.5 else primary_zip_path
        else:
            # –û–±—ã—á–Ω–∞—è –ª–æ–≥–∏–∫–∞ –±–µ–∑ —Å–º–µ—à–∏–≤–∞–Ω–∏—è
            random_state = random.Random(epoch * 12345)  # –£–Ω–∏–∫–∞–ª—å–Ω—ã–π seed –¥–ª—è –∫–∞–∂–¥–æ–π —ç–ø–æ—Ö–∏
            shuffled_indices = list(range(len(train_data)))
            random_state.shuffle(shuffled_indices)
            current_train_data = [train_data[i] for i in shuffled_indices]
            current_zip_path = primary_zip_path
            print(f"üîÑ –≠–ø–æ—Ö–∞ {epoch+1}: –¥–µ—Ç–µ—Ä–º–∏–Ω–∏—Ä–æ–≤–∞–Ω–Ω–æ –ø–µ—Ä–µ–º–µ—à–∞–Ω–æ {len(current_train_data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
        
        train_dataset = AudioTextDataset(current_train_data, tokenizer, feature_extractor, zip_path=current_zip_path)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,  # –î–∞–Ω–Ω—ã–µ —É–∂–µ –¥–µ—Ç–µ—Ä–º–∏–Ω–∏—Ä–æ–≤–∞–Ω–Ω–æ –ø–µ—Ä–µ–º–µ—à–∞–Ω—ã
            collate_fn=collate_fn,
            num_workers=0,
            pin_memory=False
        )
    
    first_batch_logged = False
    accumulated_loss = 0.0

    progress_bar = tqdm(
        enumerate(train_loader), 
        total=len(train_loader),
        initial=batch_idx if is_resumed_epoch else 0,
        desc="Epoch " + str(epoch+1),
        leave=True,
        position=0,
        dynamic_ncols=True
    )

    for batch_idx, batch in progress_bar:
        try:
            # –ü—Ä–æ—Å—Ç–æ–π –∏ –ø—Ä–∞–≤–∏–ª—å–Ω—ã–π —Ä–∞—Å—á–µ—Ç: global_step —É–∂–µ —É—á–∏—Ç—ã–≤–∞–µ—Ç –≤—Å–µ –ø—Ä–µ–¥—ã–¥—É—â–∏–µ —à–∞–≥–∏
            # batch_idx –∏–∑ enumerate –∏–¥–µ—Ç 0, 1, 2... –¥–ª—è —Ç–µ–∫—É—â–µ–≥–æ DataLoader
            real_batch_number = global_step + batch_idx
            
            if not first_batch_logged:
                wandb.log({
                    "batch/audio_seq_len": int(batch['input_values'].shape[1]),
                    "batch/audio_batch_size": int(batch['input_values'].shape[0]),
                    "batch/text_seq_len": int(batch['input_ids'].shape[1]),
                    "batch/text_batch_size": int(batch['input_ids'].shape[0]),
                    "batch/grad_accum_steps": int(gradient_accumulation_steps)
                })
                first_batch_logged = True
                    
            current_global_step = real_batch_number
            
            outputs, _ = process_batch(
                batch, gemma_model, projector, wav2vec2, tokenizer, prefix_embeds, device, compression_rate_k
            )
            loss = outputs.loss
            del outputs  # –û—Å–≤–æ–±–æ–∂–¥–∞–µ–º –ø–∞–º—è—Ç—å –æ—Ç –ª–æ–≥–∏—Ç–æ–≤
            
            loss = loss / gradient_accumulation_steps
            accumulated_loss += loss.item()
            
        except torch.cuda.OutOfMemoryError:
            print_ephemeral(f"üî• OOM —à–∞–≥ {current_global_step}, –∞—É–¥–∏–æ: {batch['input_values'].shape}, —Ç–µ–∫—Å—Ç: {batch['input_ids'].shape}")
            force_gpu_cleanup()
            wandb.log({
                "train/oom_skipped_batch": 1,
                "train/oom_audio_shape": str(batch['input_values'].shape),
                "train/oom_text_shape": str(batch['input_ids'].shape),
                "step": current_global_step
            })

        if not hasattr(projector, '_gpu_logged') and torch.cuda.is_available():
            projector._gpu_logged = True
            gpu_memory = torch.cuda.memory_allocated(device) / 1024**3
            gpu_memory_reserved = torch.cuda.memory_reserved(device) / 1024**3
            gpu_memory_max = torch.cuda.max_memory_allocated(device) / 1024**3
            
            # Numeric memory utilization classification  
            memory_util_numeric = 3.0 if gpu_memory > 20 else 2.0 if gpu_memory >= 8 else 1.0
            
            wandb.log({
                "gpu/memory_allocated_gb": float(gpu_memory),
                "gpu/memory_reserved_gb": float(gpu_memory_reserved),
                "gpu/memory_peak_gb": float(gpu_memory_max),
                "gpu/batch_size": int(batch_size),
                "gpu/memory_utilization_level": memory_util_numeric  # 1=low, 2=optimal, 3=high
            })
                
        scaler.scale(loss).backward()
        
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            
            # –í—ã—á–∏—Å–ª—è–µ–º –Ω–æ—Ä–º—É –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤ –î–û clipping'–∞ –¥–ª—è –¥–∏–∞–≥–Ω–æ—Å—Ç–∏–∫–∏
            grad_norm_before_clip = 0.0
            for param in projector.parameters():
                if param.grad is not None:
                    grad_norm_before_clip += param.grad.data.norm(2).item() ** 2
            grad_norm_before_clip = grad_norm_before_clip ** 0.5
            
            # –ü—Ä–∏–º–µ–Ω—è–µ–º clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(projector.parameters(), max_grad_norm)
            
            # –û–ø—Ä–µ–¥–µ–ª—è–µ–º –±—ã–ª –ª–∏ clipping
            was_clipped = grad_norm_before_clip > max_grad_norm
            

            projector_l2_norm = projector.get_l2_norm()
            
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            
            current_lr_list = scheduler.get_last_lr()  # –ü–æ–ª—É—á–∞–µ–º –≤–µ—Å—å —Å–ø–∏—Å–æ–∫ LR
            
            # --- –î–µ—Ç–∞–ª—å–Ω–∞—è –¥–∏–∞–≥–Ω–æ—Å—Ç–∏–∫–∞ –ø–∞–º—è—Ç–∏ ---
            gpu_memory_used, gpu_memory_reserved, gpu_memory_total = get_gpu_memory_stats(device)
            
            memory_breakdown = {}
            memory_breakdown['projector'] = get_model_memory_footprint(projector)
            memory_breakdown['lora_adapter'] = get_model_memory_footprint(gemma_model, trainable_only=True)
            memory_breakdown['wav2vec2'] = get_model_memory_footprint(wav2vec2)

            trainable_params_count = sum(p.nelement() for p in projector.parameters() if p.requires_grad) + sum(p.nelement() for p in gemma_model.parameters() if p.requires_grad)
            memory_breakdown['optimizer_state'] = (trainable_params_count * 2) / 1024**2

            total_allocated_mb = gpu_memory_used * 1024 if gpu_memory_used is not None else 0
            model_related_mb = sum(memory_breakdown.values())
            memory_breakdown['activations_grads_misc'] = max(0, total_allocated_mb - model_related_mb)
            # --- –ö–æ–Ω–µ—Ü –¥–∏–∞–≥–Ω–æ—Å—Ç–∏–∫–∏ ---
            
            # –î–æ–±–∞–≤–ª—è–µ–º –¥–æ–ø–æ–ª–Ω–∏—Ç–µ–ª—å–Ω—ã–µ –º–µ—Ç—Ä–∏–∫–∏ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤ –∫ memory_breakdown
            memory_breakdown['grad_norm_before_clip'] = grad_norm_before_clip
            memory_breakdown['grad_norm_after_clip'] = grad_norm.item()
            memory_breakdown['was_clipped'] = was_clipped
            memory_breakdown['clipping_ratio'] = grad_norm.item() / max(grad_norm_before_clip, 1e-8)
            
            logger.log_step(
                current_global_step, 
                accumulated_loss, 
                current_lr_list,  # –ü–µ—Ä–µ–¥–∞–µ–º –≤–µ—Å—å —Å–ø–∏—Å–æ–∫
                grad_norm.item(),
                projector_l2_norm,
                gpu_memory_used,
                gpu_memory_reserved,
                gpu_memory_total,
                memory_breakdown
            )
            
            clip_info = f"[CLIPPED {grad_norm_before_clip:.2f}‚Üí{grad_norm.item():.2f}]" if was_clipped else ""
            metrics_str = f"Loss={accumulated_loss:.4f}, LR-Proj={current_lr_list[0]:.2e}, LR-LoRA={current_lr_list[1]:.2e}, GN={grad_norm.item():.2f}{clip_info}, L2={projector_l2_norm:.1f}"
            mem_str = f"Mem(MB):Alloc={total_allocated_mb:.0f},Act={memory_breakdown['activations_grads_misc']:.0f}"
            progress_bar.set_postfix_str(f"{metrics_str} | {mem_str}")
            
            if current_global_step % 50 == 0:
                progress_bar.write(f"üìä Step {current_global_step}: {metrics_str} | {mem_str}")
            
            accumulated_loss = 0.0
        
        check_user_input()
        
        if current_global_step % save_latest_every_steps == 0:
            save_latest_checkpoint(current_global_step, epoch + 1, batch_idx)
            # –û—Ç–ª–∞–¥–æ—á–Ω–∞—è –∏–Ω—Ñ–æ—Ä–º–∞—Ü–∏—è –¥–ª—è –ø–æ–Ω–∏–º–∞–Ω–∏—è —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏—è
            if current_global_step % 100 == 0:  # –ö–∞–∂–¥—ã–µ 100 —à–∞–≥–æ–≤
                print_ephemeral(f"üíæ –°–æ—Ö—Ä–∞–Ω–µ–Ω —á–µ–∫–ø–æ–∏–Ω—Ç: epoch={epoch+1}, global_step={current_global_step}, batch_idx={batch_idx}")
        
        if current_global_step % save_every_steps == 0:
            if skip_validation:
                skip_validation = False
            else:
                val_metrics = evaluate_with_metrics(
                    gemma_model, projector, wav2vec2, val_loader, 
                    tokenizer, prefix_embeds, device, max_new_tokens, compression_rate_k,
                    beam_width, temperature, top_k, top_p, repetition_penalty
                )
            
                logger.log_validation(current_global_step, val_metrics)
                
                is_best = val_metrics['loss'] < best_val_loss
                if is_best:
                    best_val_loss = val_metrics['loss']
                    print_ephemeral(f"üèÜ –ù–æ–≤—ã–π –ª—É—á—à–∏–π —Ä–µ–∑—É–ª—å—Ç–∞—Ç! Loss: {best_val_loss:.4f}")
                
                save_checkpoint(current_global_step, epoch + 1, batch_idx, is_best)
                
                del val_metrics
                torch.cuda.empty_cache()
                
                projector.train()
    
    if is_resumed_epoch:
        resume_training = False
        print_vanishing("‚úÖ –≠–ø–æ—Ö–∞ " + str(epoch+1) + " –∑–∞–≤–µ—Ä—à–µ–Ω–∞, –ø–µ—Ä–µ—Ö–æ–¥–∏–º –∫ –æ–±—ã—á–Ω–æ–º—É —Ä–µ–∂–∏–º—É")

In [None]:
# Final evaluation on the full validation split, then persist the trained
# weights, flush the training logs and close the W&B run.
print_ephemeral("üéâ –û–±—É—á–µ–Ω–∏–µ –∑–∞–≤–µ—Ä—à–µ–Ω–æ! –§–∏–Ω–∞–ª—å–Ω–∞—è –≤–∞–ª–∏–¥–∞—Ü–∏—è...")

full_val_dataset = AudioTextDataset(val_data, tokenizer, feature_extractor, zip_path=zip_path)
full_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

final_val_metrics = evaluate_with_metrics(
    gemma_model, projector, wav2vec2, full_val_loader, 
    tokenizer, prefix_embeds, device, max_new_tokens, compression_rate_k,
    beam_width, temperature, top_k, top_p, repetition_penalty
)

# Show the final results ephemerally
print_ephemeral(f"üìä Final: Loss={final_val_metrics['loss']:.4f} PPL={final_val_metrics['perplexity']:.2f} WER={final_val_metrics['wer']:.3f}")

# NOTE(review): the training loop logs at `global_step + batch_idx`; if
# `global_step` is only advanced once per epoch, this final point may land on an
# earlier step than the last training log — confirm against the epoch loop.
logger.log_validation(global_step, final_val_metrics)

del final_val_metrics, full_val_dataset, full_val_loader
torch.cuda.empty_cache()

logger.save_logs()

final_model_path = os.path.join(checkpoint_dir, "final_projector.pt")
torch.save(projector.state_dict(), final_model_path)
print_ephemeral(f"üèÜ –ú–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞: {final_model_path}")

# The optimizer also updates trainable Gemma parameters (the "LR-LoRA" param
# group reported by the training loop). Persist them as well; otherwise the
# final artifact contains only the projector and the adapter updates are lost.
final_trainable_gemma_state = {
    name: param.detach().cpu()
    for name, param in gemma_model.named_parameters()
    if param.requires_grad
}
if final_trainable_gemma_state:
    final_lora_path = os.path.join(checkpoint_dir, "final_lora_trainable.pt")
    torch.save(final_trainable_gemma_state, final_lora_path)
    print_ephemeral(f"Trainable Gemma (LoRA) parameters saved: {final_lora_path}")
del final_trainable_gemma_state

wandb.finish()
torch.cuda.empty_cache()
gc.collect()
print_ephemeral("‚úÖ –ó–∞–≤–µ—Ä—à–µ–Ω–æ")

In [None]:
# This cell repeats the final-evaluation cell above. That cell ends with
# wandb.finish(), which leaves `wandb.run` as None; re-running the full
# validation pass and logging into an already-finished run would be wasted
# work, so only finalize while a W&B run is still active.
if wandb.run is None:
    print_ephemeral("Final validation already completed (no active W&B run) - skipping.")
else:
    print_ephemeral("üéâ –û–±—É—á–µ–Ω–∏–µ –∑–∞–≤–µ—Ä—à–µ–Ω–æ! –§–∏–Ω–∞–ª—å–Ω–∞—è –≤–∞–ª–∏–¥–∞—Ü–∏—è...")

    full_val_dataset = AudioTextDataset(val_data, tokenizer, feature_extractor, zip_path=zip_path)
    full_val_loader = DataLoader(full_val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    final_val_metrics = evaluate_with_metrics(
        gemma_model, projector, wav2vec2, full_val_loader, 
        tokenizer, prefix_embeds, device, max_new_tokens, compression_rate_k,
        beam_width, temperature, top_k, top_p, repetition_penalty
    )

    # Show the final results ephemerally
    print_ephemeral(f"üìä Final: Loss={final_val_metrics['loss']:.4f} PPL={final_val_metrics['perplexity']:.2f} WER={final_val_metrics['wer']:.3f}")

    # NOTE(review): the training loop logs at `global_step + batch_idx`; if
    # `global_step` is only advanced once per epoch, this final point may land
    # on an earlier step than the last training log — confirm.
    logger.log_validation(global_step, final_val_metrics)

    del final_val_metrics, full_val_dataset, full_val_loader
    torch.cuda.empty_cache()

    logger.save_logs()

    final_model_path = os.path.join(checkpoint_dir, "final_projector.pt")
    torch.save(projector.state_dict(), final_model_path)
    print_ephemeral(f"üèÜ –ú–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞: {final_model_path}")

    # Trainable Gemma parameters (the "LR-LoRA" param group) are optimized
    # too; save them so the final artifact is not projector-only.
    final_trainable_gemma_state = {
        name: param.detach().cpu()
        for name, param in gemma_model.named_parameters()
        if param.requires_grad
    }
    if final_trainable_gemma_state:
        final_lora_path = os.path.join(checkpoint_dir, "final_lora_trainable.pt")
        torch.save(final_trainable_gemma_state, final_lora_path)
        print_ephemeral(f"Trainable Gemma (LoRA) parameters saved: {final_lora_path}")
    del final_trainable_gemma_state

    wandb.finish()
    torch.cuda.empty_cache()
    gc.collect()
    print_ephemeral("‚úÖ –ó–∞–≤–µ—Ä—à–µ–Ω–æ")

In [None]:
# 🚀 ОБНОВЛЕННЫЙ КОД ДЛЯ БОРЬБЫ СО СТАГНАЦИЕЙ ОБУЧЕНИЯ

## ✅ Все новые изменения для выхода из локальных минимумов:

### 1. **🚫 MSE Loss убран из-за фундаментальных проблем**
- ❌ **Проблема выравнивания**: Невозможно корректно сопоставить непрерывный аудио-поток с дискретными токенами
- ❌ **Игнорирование LLM**: MSE не учитывает внутреннюю логику Gemma (её базовые веса заморожены, обучаются только LoRA-адаптеры)
- ❌ **Косвенная оптимизация**: L2-близость эмбеддингов не гарантирует высокую вероятность правильного токена
- ✅ **Решение**: Используем только Cross-Entropy loss для end-to-end обучения через LLM

### 2. **🔄 CosineAnnealingWarmRestarts для выхода из локальных минимумов**
- ✅ **Заменен OneCycleLR**: Теперь используется CosineAnnealingWarmRestarts
- ✅ **Частые рестарты**: T_0=250 шагов (~каждые 15 минут обучения)
- ✅ **Константный период**: T_mult=1 (период не увеличивается)
- ✅ **Минимальный LR**: 1e-6 перед каждым рестартом
- ✅ **Выход из минимумов**: Регулярные скачки LR помогают выйти из локальных минимумов

### 3. **🚀 Увеличенный Learning Rate в 3 раза**
- ✅ **С 1e-3 до 3e-3**: Агрессивный подход для преодоления стагнации
- ✅ **Больше свободы**: Проектор делает более крупные шаги обновления весов
- ✅ **Сочетание с рестартами**: В каждом периоде LR плавно снижается до eta_min, что ограничивает риск расхождения при повышенном базовом LR

### 4. **📊 Очищенное логирование и мониторинг**
- ✅ **Убран MSE Loss**: Больше не отслеживается избыточный MSE loss
- ✅ **Все метрики**: Projector L2 norm, weight update ratio, gradient norm
- ✅ **Scheduler type**: Гибкий выбор между "onecycle" и "cosine_restarts"
- ✅ **Restart-параметры**: T_0, T_mult, eta_min в конфигурации

### 5. **🔧 Централизованные гиперпараметры для scheduler**
- ✅ **scheduler_type**: Легкое переключение между подходами
- ✅ **cosine_restart_period**: Настройка периода рестарта
- ✅ **cosine_restart_mult**: Контроль роста периода
- ✅ **cosine_eta_min**: Минимальный LR для рестартов
- 🚫 **mse_loss_weight**: Убран вместе с MSE loss

## 🎯 Механизм борьбы со стагнацией:

**ПРОБЛЕМА**: Train loss быстро падает с 6-7 до ~3.5, затем стагнирует  
**ПРИЧИНА (гипотеза)**: Маленький проектор (7M параметров) быстро находит локальный минимум  

**РЕШЕНИЯ**:
1. **🚫 Убран MSE Loss**: Избегаем проблем с выравниванием и обучаемся только по сигналу (cross-entropy) от LLM
2. **🔄 Warm Restarts**: Периодические "толчки" LR для выхода из локальных минимумов  
3. **🚀 3x Learning Rate**: Более крупные шаги обновления весов
4. **📊 Мониторинг**: Отслеживание всех ключевых метрик для диагностики

## 📈 Ожидаемые улучшения:

1. **📉 Преодоление стагнации**: Loss должен продолжать падать после ~3.5
2. **🎯 Лучший WER**: Чистый end-to-end сигнал от LLM без искажений от MSE
3. **🔄 Стабильное обучение**: Рестарты предотвращают застревание
4. **⚡ Фокус на Cross-Entropy**: Проектор учится "говорить" на языке LLM

## 🚀 Готово к запуску с агрессивными настройками против стагнации!