In [None]:
import sys
from tqdm import tqdm
import subprocess

def install_with_progress(package_file):
    """Установка и обновление пакетов с прогресс-баром"""
    print(f"📦 Обновляем зависимости из {package_file}...")
    
    # Читаем список пакетов
    try:
        with open(package_file, 'r') as f:
            packages = [line.strip() for line in f if line.strip() and not line.startswith('#')]
    except FileNotFoundError:
        print(f"❌ Файл {package_file} не найден. Пропускаем...")
        return
    
    print(f"🔍 Найдено {len(packages)} пакетов для обновления")
    
    # Обновляем пакеты с прогресс-баром
    for package in tqdm(packages, desc="📥 Обновление пакетов"):
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", package, "-q"], 
                         check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"⚠️ Ошибка при обновлении {package}: {e}")
    
    print("✅ Обновление завершено!")

# Install / upgrade everything listed in requirements.txt before the heavy imports.
install_with_progress(package_file="requirements.txt")

In [None]:
#!pip install torch==2.1.1 torchaudio==2.1.1 --force-reinstall --no-cache-dir

In [None]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import json
import os
import numpy as np
import soundfile as sf
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler
from transformers import (
    AutoTokenizer,
    AutoConfig,
    BitsAndBytesConfig,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model
)
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
from huggingface_hub import login
import jiwer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sklearn.model_selection import train_test_split
from huggingface_hub import notebook_login
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import zipfile
import io

In [None]:
# Interactive login is only needed when no token is supplied through the HF_TOKEN
# environment variable; that token is used for an explicit login in the configuration cell.
if not os.getenv("HF_TOKEN"):
    notebook_login()

In [None]:
def to_bfloat16(tensor):
    """Return ``tensor`` cast to ``torch.bfloat16``, logging its original dtype and shape."""
    source_dtype, source_shape = tensor.dtype, tensor.shape
    print(f"Converting tensor from {source_dtype} to bfloat16, shape: {source_shape}")
    return tensor.to(dtype=torch.bfloat16)

def to_fp32(tensor):
    """Return ``tensor`` cast to ``torch.float32``, logging its original dtype and shape."""
    source_dtype, source_shape = tensor.dtype, tensor.shape
    print(f"Converting tensor from {source_dtype} to fp32, shape: {source_shape}")
    return tensor.to(dtype=torch.float32)

def sync_model_dtype(model, target_dtype):
    """Cast every parameter of ``model`` to ``target_dtype`` in place and return the model.

    Only ``model.parameters()`` are converted (buffers are left untouched); the same
    Parameter objects are kept, only their ``.data`` is replaced.
    """
    print(f"Syncing model to {target_dtype}")
    all_params = list(model.parameters())
    for p in all_params:
        p.data = p.data.to(target_dtype)
    return model

# 4-bit bitsandbytes loading (quantization_config below) and the CUDA GradScaler used for
# training both expect a CUDA GPU: fail here with a clear message instead of deep inside
# model loading.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA GPU is required to run this notebook (4-bit Gemma + CUDA AMP).")
device = torch.device("cuda")
model_id = "google/gemma-3-4b-pt"
audio_model_name = "facebook/wav2vec2-base"

# Seed torch's global RNG: projector initialisation and DataLoader shuffling draw from it.
seed = 42
torch.manual_seed(seed)

batch_size = 4
num_epochs = 3
learning_rate = 1e-4
weight_decay = 0.01
max_grad_norm = 1.0  # Gradient clipping
warmup_steps = 100

# Projector dimensions: input_dim must equal the audio encoder's hidden size and
# output_dim the LLM's hidden size (the projector output is used as a token embedding).
input_dim = 768
output_dim = 2560

experiment_name = f"audio_projector_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
checkpoint_dir = f"./checkpoints/{experiment_name}"
os.makedirs(checkpoint_dir, exist_ok=True)

plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Non-interactive authentication: use HF_TOKEN from the environment when it is set.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)

print(f"🚀 Эксперимент: {experiment_name}")
print(f"📁 Чекпоинты: {checkpoint_dir}")
print(f"🖥️  Устройство: {device}")
print(f"⚙️  Конфигурация:")
print(f"   - Seed: {seed}")
print(f"   - Batch size: {batch_size}")
print(f"   - Epochs: {num_epochs}")
print(f"   - Learning rate: {learning_rate}")
print(f"   - Weight decay: {weight_decay}")
print(f"   - Gradient clipping: {max_grad_norm}")
print(f"   - Warmup steps: {warmup_steps}")
print(f"🎵 Audio model: {audio_model_name}")
print(f"🤖 LLM: {model_id}")
print(f"🔗 Projector: {input_dim} -> {output_dim}")

# The Gemma checkpoint is gated: pass the same token as for the config and weights.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("🔧 Применяем патч для загрузки gemma-3-4b-pt как текстовой модели...")
# Load only the text part of the multimodal Gemma-3 checkpoint as a causal LM.
multi_cfg = AutoConfig.from_pretrained(model_id, token=hf_token)
text_cfg_dict = multi_cfg.text_config.to_dict()

# NOTE(review): hard-coded vocabulary size override; presumably chosen to match the
# checkpoint's embedding matrix — confirm before changing model_id.
text_cfg_dict["vocab_size"] = 262208

text_cfg_dict.update({"bos_token_id": tokenizer.bos_token_id,
                      "eos_token_id": tokenizer.eos_token_id,
                      "pad_token_id": tokenizer.pad_token_id})

text_cfg = Gemma3TextConfig(**text_cfg_dict)
# Projected audio is concatenated with Gemma token embeddings, so the widths must agree.
if text_cfg.hidden_size != output_dim:
    raise ValueError(f"output_dim={output_dim} does not match Gemma hidden_size={text_cfg.hidden_size}")
gemma_model = Gemma3ForCausalLM.from_pretrained(
    model_id,
    config=text_cfg,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto",
    token=hf_token
)
print("✅ Патч успешно применен, модель загружена.")

# Gemma stays frozen; only the projector is trained.
gemma_model.eval()
gemma_model.requires_grad_(False)

# Frozen audio encoder, moved to the GPU in bfloat16.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(audio_model_name)
wav2vec2 = Wav2Vec2Model.from_pretrained(audio_model_name).to(device=device, dtype=torch.bfloat16)
if wav2vec2.config.hidden_size != input_dim:
    raise ValueError(f"input_dim={input_dim} does not match wav2vec2 hidden_size={wav2vec2.config.hidden_size}")
wav2vec2.eval()
wav2vec2.requires_grad_(False)

# NOTE(review): for the 4-bit model numel() counts the packed storage, not logical weights.
print(f"Gemma parameters: {sum(p.numel() for p in gemma_model.parameters()):,}")
print(f"Wav2vec2 parameters: {sum(p.numel() for p in wav2vec2.parameters()):,}")

In [None]:
class AudioProjector(nn.Module):
    """Map pooled audio-encoder features into the LLM embedding space.

    Architecture: LayerNorm -> Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim) -> LayerNorm. The projection is always computed in
    float32 and the result is cast back to the dtype of the input.

    Args:
        input_dim: size of the last dimension of the input features.
        output_dim: size of the last dimension of the produced embeddings.
        hidden_dim: width of the hidden layer. The default 1024 is the previously hard-coded
            value, so existing checkpoints (same state_dict keys/shapes) still load.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim)
        )

    def forward(self, x):
        """Project ``x`` (..., input_dim) -> (..., output_dim); compute in FP32, return in ``x.dtype``."""
        original_dtype = x.dtype
        # Self-heal if something cast the module to a lower precision.
        if next(self.proj.parameters()).dtype != torch.float32:
            self.proj = self.proj.float()
        # The training/eval code calls this inside a bf16 autocast region. There, nn.Linear
        # would be silently executed in bf16 despite the float32 cast of the input, so
        # autocast is disabled locally to make the FP32 computation real.
        with torch.autocast(device_type=x.device.type, enabled=False):
            output_fp32 = self.proj(x.to(torch.float32))
        return output_fp32.to(original_dtype)

print(
    "Класс AudioProjector определен с правильной обработкой типов данных",
    "Входные данные: любой тип -> Вычисления: FP32 -> Выход: исходный тип",
    sep="\n",
)

In [None]:
# The projector is the only trainable module; it is created and kept in float32.
projector = AudioProjector(input_dim, output_dim).to(device).float()

optimizer = optim.AdamW(
    params=projector.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=weight_decay,
)

# The LR schedule needs len(train_loader), so it is built later, next to the data loaders.
scheduler = None

scaler = torch.amp.GradScaler('cuda')
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

prefix = "Транскрипция аудио: "
# The text prompt is constant: embed it once, outside autograd.
with torch.no_grad():
    prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(device)
    prefix_embeds = gemma_model.get_input_embeddings()(prefix_ids).to(dtype=torch.bfloat16)

print(
    f"✅ Оптимизатор: AdamW (lr={learning_rate}, wd={weight_decay})",
    f"✅ Gradient clipping: {max_grad_norm}",
    "✅ Mixed precision: включен",
    f"✅ Префикс промпта: '{prefix}'",
    sep="\n",
)

In [None]:
class TrainingLogger:
    """Collect per-epoch training/validation metrics, plot them, and export them to CSV.

    Args:
        experiment_name: used as the figure title.
        save_dir: existing directory where ``training_curves.png`` and
            ``training_logs.csv`` are written.
    """

    # Single-series panels: ((row, col), log key, y-axis label, title, line colour).
    _METRIC_PANELS = (
        ((0, 1), 'val_perplexity', 'Perplexity', 'Validation Perplexity', 'g'),
        ((0, 2), 'val_wer', 'WER', 'Word Error Rate', 'orange'),
        ((1, 0), 'val_bleu', 'BLEU Score', 'BLEU Score', 'purple'),
        ((1, 1), 'val_rouge_l', 'ROUGE-L', 'ROUGE-L Score', 'brown'),
        ((1, 2), 'learning_rate', 'Learning Rate', 'Learning Rate Schedule', 'teal'),
    )

    def __init__(self, experiment_name, save_dir):
        self.experiment_name = experiment_name
        self.save_dir = save_dir
        self.logs = {key: [] for key in (
            'epoch', 'train_loss', 'val_loss', 'val_perplexity', 'val_wer',
            'val_bleu', 'val_rouge_l', 'learning_rate', 'grad_norm',
        )}

    def log(self, epoch, train_loss, val_metrics, lr, grad_norm=None):
        """Append one epoch of metrics.

        ``val_metrics`` must contain 'loss', 'perplexity', 'wer', 'bleu' and 'rouge_l'.
        A missing ``grad_norm`` is stored as NaN, so every column keeps the same length.
        """
        self.logs['epoch'].append(epoch)
        self.logs['train_loss'].append(train_loss)
        self.logs['val_loss'].append(val_metrics['loss'])
        self.logs['val_perplexity'].append(val_metrics['perplexity'])
        self.logs['val_wer'].append(val_metrics['wer'])
        self.logs['val_bleu'].append(val_metrics['bleu'])
        self.logs['val_rouge_l'].append(val_metrics['rouge_l'])
        self.logs['learning_rate'].append(lr)
        # Previously skipped when None, which left the column short and made
        # pd.DataFrame(self.logs) in save_logs() raise on ragged columns.
        self.logs['grad_norm'].append(float('nan') if grad_norm is None else grad_norm)

    def plot_training_curves(self):
        """Draw a 2x3 grid of all logged curves, save it as PNG in save_dir and show it."""
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'Training Progress: {self.experiment_name}', fontsize=16, fontweight='bold')
        epochs = self.logs['epoch']

        loss_ax = axes[0, 0]
        loss_ax.plot(epochs, self.logs['train_loss'], 'b-', label='Train Loss', linewidth=2)
        loss_ax.plot(epochs, self.logs['val_loss'], 'r-', label='Val Loss', linewidth=2)
        loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Loss Curves')
        loss_ax.legend()
        loss_ax.grid(True, alpha=0.3)

        for (row, col), key, ylabel, title, color in self._METRIC_PANELS:
            ax = axes[row, col]
            ax.plot(epochs, self.logs[key], color=color, linewidth=2)
            ax.set(xlabel='Epoch', ylabel=ylabel, title=title)
            ax.grid(True, alpha=0.3)
        axes[1, 2].set_yscale('log')  # learning rate spans orders of magnitude

        fig.tight_layout()

        plot_path = os.path.join(self.save_dir, 'training_curves.png')
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"📊 График сохранен: {plot_path}")
        plt.show()
        # Called once per epoch: close the figure so figures do not accumulate in memory.
        plt.close(fig)

    def save_logs(self):
        """Write all logged metrics to ``training_logs.csv`` and return them as a DataFrame."""
        df = pd.DataFrame(self.logs)
        csv_path = os.path.join(self.save_dir, 'training_logs.csv')
        df.to_csv(csv_path, index=False)
        print(f"📝 Логи сохранены: {csv_path}")
        return df

def process_batch(batch, model, projector, wav2vec2, tokenizer, prefix_embeds, device, batch_idx=0, context=""):
    """Build the LLM input for one batch and run the model forward with labels.

    The sequence fed to ``model`` is ``[prefix prompt | one projected audio vector | target text]``.
    Prompt positions are labelled -100 so only the target text contributes to the loss.

    Args:
        batch: dict from ``collate_fn`` with 'input_values' and 'input_ids' (-100 = padding).
        model: the LLM; must expose ``get_input_embeddings()`` and accept ``inputs_embeds``/``labels``.
        projector: maps pooled audio features to the LLM embedding width.
        wav2vec2: audio encoder returning an object with ``last_hidden_state``.
        tokenizer: only ``pad_token_id`` is used, to replace -100 before the embedding lookup.
        prefix_embeds: prompt-prefix embeddings with a leading batch dimension of 1.
        device: target device for the batch tensors.
        batch_idx, context: when ``context`` is non-empty and ``batch_idx == 0``, tensor shapes are printed.

    Returns:
        (outputs, prompt_embeds): the model output and the prefix+audio embeddings, which
        evaluation reuses as the generation prompt.
    """
    audio_input = batch["input_values"].to(device, dtype=torch.bfloat16)
    target_ids = batch["input_ids"].to(device)
    amp_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.amp.autocast(device_type=amp_device, dtype=torch.bfloat16):
        # NOTE(review): mean-pools over every frame, including the zero-padded tail added by
        # collate_fn, and compresses the whole utterance into a single vector.
        pooled_audio = wav2vec2(audio_input).last_hidden_state.mean(dim=1)
        audio_token = projector(pooled_audio)
        n = audio_token.size(0)

        prompt_embeds = torch.cat(
            [prefix_embeds.expand(n, -1, -1), audio_token.unsqueeze(1)], dim=1
        )

        # -100 marks label padding; swap it for a real token id before the embedding lookup.
        lookup_ids = target_ids.masked_fill(target_ids == -100, tokenizer.pad_token_id)
        target_embeds = model.get_input_embeddings()(lookup_ids)
        inputs_embeds = torch.cat([prompt_embeds, target_embeds], dim=1)

        ignored = torch.full((n, prompt_embeds.shape[1]), -100, device=device, dtype=torch.long)
        labels = torch.cat([ignored, target_ids], dim=1)

        if context and batch_idx == 0:
            print(f"\n--- Размерности в {context} (первый батч) ---")
            print(f"Audio Projected Embeds: {audio_token.shape}")
            print(f"Prompt Embeds (prefix + audio): {prompt_embeds.shape}")
            print(f"Target Text Embeds: {target_embeds.shape}")
            print(f"Combined Input Embeds (prompt + target): {inputs_embeds.shape}")
            print(f"Prompt Labels (игнорируются): {ignored.shape}")
            print(f"Target Text Labels (input_ids): {target_ids.shape}")
            print(f"Combined Labels (ignored prompt + target): {labels.shape}")
            print("--------------------------------------------------------")

        outputs = model(inputs_embeds=inputs_embeds, labels=labels)

    return outputs, prompt_embeds

def evaluate_with_metrics(model, projector, wav2vec2, dataloader, tokenizer, prefix_embeds, device,
                          max_new_tokens=100):
    """Evaluate the audio->text pipeline: teacher-forced loss plus greedy-decoding metrics.

    For each batch, ``process_batch`` yields the teacher-forced loss and the prompt embeddings
    (text prefix + projected audio). ``model.generate`` then decodes greedily from that
    prompt, and each prediction is scored against its reference transcript.

    Args:
        model, projector, wav2vec2: LLM, projector and audio encoder; all are put in eval mode
            and left there.
        dataloader: yields ``collate_fn`` batches (``input_ids`` padded with -100).
        tokenizer: decodes predictions and references.
        prefix_embeds: prompt-prefix embeddings with a leading batch dimension of 1.
        device: device the batch tensors are moved to.
        max_new_tokens: generation budget per sample (default 100, the previous fixed value).

    Returns:
        Dict with 'loss' (mean batch loss), 'perplexity' (exp of that loss), and the mean
        per-sample 'wer', 'bleu', 'rouge_1', 'rouge_2', 'rouge_l'.
        - Samples with an empty reference are skipped.
        - An empty prediction scores WER 1.0 and 0 on BLEU/ROUGE.
        - With no batches, loss/perplexity are NaN; with no scored samples, text metrics are 0.0.
    """
    model.eval()
    projector.eval()
    wav2vec2.eval()
    total_loss, num_batches = 0.0, 0
    total_wer, total_bleu = 0.0, 0.0
    total_rouge_1, total_rouge_2, total_rouge_l = 0.0, 0.0, 0.0
    count = 0
    smooth = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    autocast_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
            input_ids = batch["input_ids"].to(device)

            outputs, prompt_embeds = process_batch(
                batch, model, projector, wav2vec2, tokenizer, prefix_embeds, device,
                batch_idx=i, context="evaluate_with_metrics"
            )
            total_loss += outputs.loss.item()
            num_batches += 1

            with torch.amp.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                generated_ids = model.generate(
                    inputs_embeds=prompt_embeds,
                    max_new_tokens=max_new_tokens,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    do_sample=False
                )

            # A decoder-only generate() called with inputs_embeds and no input_ids returns only
            # the newly generated tokens. The previous unconditional [:, prompt_len:] slice
            # therefore threw away the first predicted tokens. The output can only exceed the
            # budget if it also contains the prompt, so strip the prompt only in that case.
            if generated_ids.shape[1] > max_new_tokens:
                generated_ids = generated_ids[:, prompt_embeds.shape[1]:]

            for j in range(generated_ids.size(0)):
                pred_text = tokenizer.decode(generated_ids[j], skip_special_tokens=True).strip()
                ref_ids = input_ids[j]
                ref_text = tokenizer.decode(ref_ids[ref_ids != -100], skip_special_tokens=True).strip()
                if not ref_text:
                    continue  # nothing to score against

                if pred_text:
                    total_wer += jiwer.wer(ref_text, pred_text)
                    total_bleu += sentence_bleu([ref_text.split()], pred_text.split(), smoothing_function=smooth)
                    rouge_scores = rouge_scorer_obj.score(ref_text, pred_text)
                    total_rouge_1 += rouge_scores['rouge1'].fmeasure
                    total_rouge_2 += rouge_scores['rouge2'].fmeasure
                    total_rouge_l += rouge_scores['rougeL'].fmeasure
                else:
                    # Empty prediction: every reference word is a deletion (WER 1.0) and there is
                    # no overlap (BLEU/ROUGE add 0). Dropping such samples would flatter the metrics.
                    total_wer += 1.0
                count += 1

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_wer = total_wer / count if count > 0 else 0.0
    avg_bleu = total_bleu / count if count > 0 else 0.0
    avg_rouge_1 = total_rouge_1 / count if count > 0 else 0.0
    avg_rouge_2 = total_rouge_2 / count if count > 0 else 0.0
    avg_rouge_l = total_rouge_l / count if count > 0 else 0.0
    return {
        'loss': avg_loss, 'perplexity': perplexity,
        'wer': avg_wer, 'bleu': avg_bleu,
        'rouge_1': avg_rouge_1, 'rouge_2': avg_rouge_2, 'rouge_l': avg_rouge_l
    }

In [None]:
def collate_fn(batch):
    """Right-pad a list of dataset items into batch tensors.

    - ``input_values`` are padded with 0.0.
    - ``input_ids`` are padded with -100, so padded positions are ignored by the loss.
    - ``attention_mask`` is padded with 0.
    """
    def _pad(key, fill):
        return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=fill)

    return {
        'input_values': _pad('input_values', 0.0),
        'input_ids': _pad('input_ids', -100),
        'attention_mask': _pad('attention_mask', 0),
    }

class AudioTextDataset(Dataset):
    """Dataset of (audio, transcript) pairs for training the audio projector.

    Each record in ``data`` is a dict with at least ``"audio_path"`` and ``"speaker_text"``.
    Depending on what is available, audio is read from:
      * a ZIP archive opened once in ``__init__`` (``use_zip_files``);
      * a folder the archive was extracted to when opening it failed (``use_extracted_files``);
      * plain files on disk otherwise.

    NOTE(review): the open ZipFile is shared by all reads; with DataLoader ``num_workers > 0``
    each worker should presumably open its own handle.

    Args:
        data: list of record dicts. It is not modified; a copy is kept when paths are rewritten.
        tokenizer: HF tokenizer for the transcript.
        feature_extractor: HF feature extractor; its ``sampling_rate`` is the target rate.
        zip_path: optional path to a ZIP archive containing the audio.
        max_text_tokens: maximum transcript length in tokens, including the appended EOS.
    """

    def __init__(self, data, tokenizer, feature_extractor, zip_path=None, max_text_tokens=512):
        self.data = data
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.zip_path = zip_path
        self.max_text_tokens = max_text_tokens
        self.zip_file = None
        self.extracted_folder = None
        self.use_extracted_files = False
        self.use_zip_files = False
        # Archive lookup tables, built once so __getitem__ does not rescan namelist() per item.
        self._zip_members = set()
        self._zip_by_basename = {}

        # Try to open the ZIP archive; if that fails, fall back to extracting it.
        if self.zip_path and os.path.exists(self.zip_path):
            try:
                self.zip_file = zipfile.ZipFile(self.zip_path, 'r')
                print(f"📦 Открыт ZIP-файл: {self.zip_path}")

                self.use_zip_files = True
                self.use_extracted_files = False

            except Exception as e:
                print(f"⚠️ Ошибка открытия ZIP: {e}")
                # Extraction folder is named after the archive.
                zip_name = os.path.splitext(os.path.basename(self.zip_path))[0]
                self.extracted_folder = f"./{zip_name}"

                # NOTE(review): extraction re-opens the same archive, so it presumably fails for
                # the same reason the open above did.
                self._extract_zip_file()
                self.use_extracted_files = True
                self.use_zip_files = False
                print(f"✅ ZIP извлечен в папку: {self.extracted_folder}")

                # Point the records at the extracted files.
                self._update_paths_for_extracted_files()
            else:
                self._index_zip_members()
        else:
            print(f"⚠️ ZIP файл не найден: {self.zip_path}")
            self.use_zip_files = False
            self.use_extracted_files = False

        print(f"📊 Загружено записей в датасете: {len(self.data)}")

    def _extract_zip_file(self):
        """Extract the whole archive into ``self.extracted_folder``."""
        os.makedirs(self.extracted_folder, exist_ok=True)
        with zipfile.ZipFile(self.zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.extracted_folder)

    def _update_paths_for_extracted_files(self):
        """Prefix every audio path with the extraction folder (on copies of the records)."""
        self.data = [
            {**item, "audio_path": os.path.join(self.extracted_folder, item["audio_path"])}
            for item in self.data
        ]

    def _index_zip_members(self):
        """Build the member-name set and a basename -> first matching member mapping."""
        names = self.zip_file.namelist()
        self._zip_members = set(names)
        self._zip_by_basename = {}
        for name in names:
            self._zip_by_basename.setdefault(os.path.basename(name), name)

    def _find_zip_member(self, audio_path):
        """Return the archive member for ``audio_path``, or None if it cannot be found.

        The path is tried as-is, with a ``LibriSpeech/`` prefix, and with a doubled
        ``LibriSpeech/LibriSpeech/`` folder. As a last resort the first member with exactly
        the same file name is used. The previous suffix test (``endswith``) could return a
        different file, e.g. ``11-2-3.flac`` for ``1-2-3.flac``.
        """
        candidates = (
            audio_path,
            f"LibriSpeech/{audio_path}",
            audio_path.replace("LibriSpeech/", "LibriSpeech/LibriSpeech/"),
        )
        for candidate in candidates:
            if candidate in self._zip_members:
                return candidate
        filename = os.path.basename(audio_path)
        return self._zip_by_basename.get(filename) if filename else None

    def _load_waveform(self, audio_path):
        """Load ``(waveform, sample_rate)`` from the extracted folder, the archive, or disk."""
        if self.use_extracted_files:
            return torchaudio.load(audio_path)
        if self.use_zip_files and self.zip_file:
            member = self._find_zip_member(audio_path)
            if member is None:
                raise FileNotFoundError(f"Не найден файл {audio_path} в ZIP архиве")
            with self.zip_file.open(member) as audio_file:
                audio_bytes = audio_file.read()
            return torchaudio.load(io.BytesIO(audio_bytes))
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Файл не найден на диске: {audio_path}")
        return torchaudio.load(audio_path)

    def _encode_text(self, text):
        """Tokenize a transcript as training targets: no special tokens, EOS appended.

        The training prompt is tokenized with the tokenizer's default special tokens, so it
        already starts the sequence (for Gemma, presumably with <bos>). A second BOS in the
        middle of the sequence is therefore avoided. EOS is appended so the model learns
        where a transcript ends; without it greedy generation always runs to the token budget.
        """
        eos_id = self.tokenizer.eos_token_id
        budget = self.max_text_tokens - (1 if eos_id is not None else 0)
        token_ids = list(self.tokenizer(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        ).input_ids)
        if eos_id is not None:
            token_ids.append(eos_id)
        return torch.tensor(token_ids, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return 1-D ``input_values``, ``input_ids`` and ``attention_mask`` for one record."""
        item = self.data[idx]
        waveform, sr = self._load_waveform(item["audio_path"])
        target_sr = self.feature_extractor.sampling_rate

        # Down-mix to mono first so resampling only has to process one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)

        inputs = self.feature_extractor(
            waveform.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt"
        )

        input_ids = self._encode_text(item["speaker_text"])
        return {
            "input_values": inputs.input_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
        }

    def __del__(self):
        # Close the archive handle when the dataset is garbage-collected.
        if hasattr(self, 'zip_file') and self.zip_file:
            self.zip_file.close()

In [None]:
def gemma_layer_backward_hook(module, grad_input, grad_output, layer_index, total_layers):
    """Print backward-pass progress for one Gemma decoder layer.

    Registered per layer through ``register_full_backward_hook``; ``module``, ``grad_input``
    and ``grad_output`` are the standard hook arguments and are not used. Backward visits the
    layers in reverse, so ``total_layers - layer_index`` layers have completed at this point.
    """
    progress = total_layers - layer_index
    print(f"<-- Gemma backward pass: [{progress}/{total_layers}] completed for layer index {layer_index}.")

# Remove hooks left by an earlier run of a hook-registration cell. Otherwise every re-run
# stacks another hook on each layer, and each message is printed several times per backward pass.
for _stale_handle in globals().get("gemma_hook_handles", []):
    _stale_handle.remove()
gemma_hook_handles = []

try:
    gemma_layers = gemma_model.model.layers
    num_layers = len(gemma_layers)
    print(f"Found {num_layers} layers in Gemma model. Registering backward hooks...")

    for i, layer in enumerate(gemma_layers):
        # Default arguments bind the current layer index (avoids late binding of `i`).
        gemma_hook_handles.append(layer.register_full_backward_hook(
            lambda module, grad_input, grad_output, index=i, total=num_layers: gemma_layer_backward_hook(module, grad_input, grad_output, index, total)
        ))
    print(f"Successfully registered backward hooks for all Gemma layers.")
    print(f"During training, you will see progress messages from 1/{num_layers} to {num_layers}/{num_layers}.")

except AttributeError:
    print("Could not find 'gemma_model.model.layers'. Unable to register detailed backward hooks for Gemma.")
    print("The training will proceed without them.")

In [None]:
def gemma_layer_backward_hook(module, grad_input, grad_output, layer_index, total_layers):
    """A hook that prints progress during the backward pass through Gemma layers.

    ``module``, ``grad_input`` and ``grad_output`` are the standard full-backward-hook
    arguments and are not used; ``layer_index``/``total_layers`` are bound per layer.
    """
    layers_processed = total_layers - layer_index
    print(f"<-- Backward pass: Reached Gemma Layer {layer_index} (from 0 to {total_layers-1}). Progress: {layers_processed}/{total_layers} layers.")

# Drop hooks registered by earlier runs (of this or another hook cell) so that each layer
# carries exactly one progress hook instead of accumulating duplicates.
for _stale_handle in globals().get("gemma_hook_handles", []):
    _stale_handle.remove()
gemma_hook_handles = []

if hasattr(gemma_model, 'model') and hasattr(gemma_model.model, 'layers'):
    gemma_layers = gemma_model.model.layers
    num_layers = len(gemma_layers)
    print(f"\nRegistering backward hooks for {num_layers} Gemma layers to monitor progress...")

    for i, layer in enumerate(gemma_layers):
        # Default arguments bind the current layer index (avoids late binding of `i`).
        gemma_hook_handles.append(layer.register_full_backward_hook(
            lambda module, grad_input, grad_output, index=i, total=num_layers: gemma_layer_backward_hook(module, grad_input, grad_output, index, total)
        ))

    print(f"Gemma hooks registered for layers 0 to {num_layers-1}.")
    print(f"During training, you will see progress messages as the backward pass moves through the layers in reverse order (e.g., starting with layer {num_layers-1}, then {num_layers-2}, etc.).")
else:
    print("Could not find `gemma_model.model.layers` to attach hooks.")

In [None]:
jsonl_path = "transcripts.jsonl"
zip_path = "dataset.zip"  # path to the ZIP archive with the audio dataset

print(f"📂 Загружаем данные из {jsonl_path}...")

with open(jsonl_path, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline): json.loads("") would raise.
    all_data = [json.loads(line) for line in f if line.strip()]

print(f"📊 Загружено записей: {len(all_data)}")

# Normalize field names to the keys AudioTextDataset expects.
normalized_data = [
    {
        "audio_path": item.get("audio_filepath", ""),
        "speaker_text": item.get("text", ""),
        "language": item.get("language", "en"),
        "source": item.get("source", "unknown"),
    }
    for item in all_data
]
# Records without an audio path or transcript would only fail later, mid-epoch.
num_records = len(normalized_data)
normalized_data = [
    record for record in normalized_data
    if record["audio_path"] and record["speaker_text"] and str(record["speaker_text"]).strip()
]
if len(normalized_data) != num_records:
    print(f"⚠️ Skipped {num_records - len(normalized_data)} records without audio_filepath or text")

train_data, val_data = train_test_split(normalized_data, test_size=0.1, random_state=42)

# Datasets read audio from the ZIP archive when it is available.
train_dataset = AudioTextDataset(train_data, tokenizer, feature_extractor, zip_path=zip_path)
val_dataset = AudioTextDataset(val_data, tokenizer, feature_extractor, zip_path=zip_path)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"📊 Data loaded: {len(train_data)} train, {len(val_data)} val samples.")

total_steps = num_epochs * len(train_loader)
if total_steps == 0:
    raise ValueError("Training set is empty: there are no optimisation steps to schedule.")
# Use the configured warmup_steps. Previously a fixed 10% was used, contradicting the
# printed config. Capped at half of training so the cosine annealing phase is never empty.
warmup_fraction = min(warmup_steps / total_steps, 0.5)
scheduler = OneCycleLR(
    optimizer,
    max_lr=learning_rate,
    total_steps=total_steps,
    pct_start=warmup_fraction,
    anneal_strategy='cos'
)

print(f"📅 Total training steps: {total_steps}")
print(f"🔥 Warmup steps: {int(warmup_fraction * total_steps)}")

logger = TrainingLogger(experiment_name, checkpoint_dir)

print(f"🚀 Начинаем обучение модели Audio Projector!")

# Main training loop. Only the projector is optimized; wav2vec2 and Gemma are kept in eval
# mode with requires_grad=False (set when they were loaded), but gradients still flow
# *through* Gemma back to the projector output.
for epoch in range(num_epochs):
    print(f"\n{'='*60}")
    print(f"🔄 EPOCH {epoch+1}/{num_epochs}")
    print(f"{'='*60}")
    
    projector.train()
    wav2vec2.eval()
    gemma_model.eval()
    
    epoch_loss = 0
    total_grad_norm = 0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training")
    
    for batch_idx, batch in enumerate(progress_bar):
        optimizer.zero_grad()
        
        # Shapes are printed only for the first batch of each epoch.
        outputs, _ = process_batch(
            batch, gemma_model, projector, wav2vec2, tokenizer, prefix_embeds, device,
            batch_idx=batch_idx, context="training loop" if batch_idx == 0 else ""
        )
        loss = outputs.loss
        
        # AMP step order matters: scale -> backward -> unscale_ -> clip -> step -> update.
        # NOTE(review): with bf16 autocast a GradScaler is presumably unnecessary; kept as is.
        scaler.scale(loss).backward()
        
        # Unscale first so clipping sees the true gradient magnitudes.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(projector.parameters(), max_grad_norm)
        # NOTE(review): grad_norm can be inf/NaN on steps the scaler skips, which would
        # poison the epoch average below.
        total_grad_norm += grad_norm.item()
        
        # GradScaler skips the optimizer step when it finds inf/NaN gradients.
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # advance the OneCycle LR schedule (once per batch, even on skipped steps)
        
        epoch_loss += loss.item()
        current_lr = scheduler.get_last_lr()[0]
        
        # Show the current loss, learning rate and gradient norm in the progress bar.
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'LR': f'{current_lr:.2e}',
            'Grad': f'{grad_norm.item():.3f}'
        })

    # Per-epoch averages over all training batches.
    avg_train_loss = epoch_loss / len(train_loader)
    avg_grad_norm = total_grad_norm / len(train_loader)
    final_lr = scheduler.get_last_lr()[0]
    
    print(f"\n📈 Epoch {epoch+1} Training Results:")
    print(f"   📉 Average Loss: {avg_train_loss:.4f}")
    print(f"   🎯 Learning Rate: {final_lr:.2e}")
    print(f"   ✂️  Average Grad Norm: {avg_grad_norm:.3f}")
    
    # Validation: teacher-forced loss plus generation metrics.
    print(f"\n🔍 Running validation...")
    val_metrics = evaluate_with_metrics(gemma_model, projector, wav2vec2, val_loader, tokenizer, prefix_embeds, device)
    
    print(f"\n📊 Validation Results:")
    print(f"   📉 Loss: {val_metrics['loss']:.4f}")
    print(f"   🧮 Perplexity: {val_metrics['perplexity']:.2f}")
    print(f"   🎯 WER: {val_metrics['wer']:.3f}")
    print(f"   🔤 BLEU: {val_metrics['bleu']:.3f}")
    print(f"   📝 ROUGE-L: {val_metrics['rouge_l']:.3f}")
    
    logger.log(epoch+1, avg_train_loss, val_metrics, final_lr, avg_grad_norm)
    
    # Per-epoch checkpoint: projector weights plus optimizer/scheduler state for resuming.
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
    torch.save({
        'epoch': epoch+1,
        'projector_state_dict': projector.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': avg_train_loss,
        'val_metrics': val_metrics,
        'config': {
            'learning_rate': learning_rate,
            'weight_decay': weight_decay,
            'max_grad_norm': max_grad_norm,
            'batch_size': batch_size,
            'input_dim': input_dim,
            'output_dim': output_dim
        }
    }, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")
    
    # Redraws the curves for all epochs logged so far.
    logger.plot_training_curves()

# Training finished: export the metric log, save the final projector weights and list outputs.
print("\n" + "=" * 60, "🎉 ОБУЧЕНИЕ ЗАВЕРШЕНО!", "=" * 60, sep="\n")

final_logs_df = logger.save_logs()
print("\n📋 Итоговая статистика:")
print(final_logs_df.round(4))

final_model_path = os.path.join(checkpoint_dir, "final_projector.pt")
torch.save(projector.state_dict(), final_model_path)
print(f"🏆 Финальная модель сохранена: {final_model_path}")

print(
    f"\n📁 Все файлы сохранены в: {checkpoint_dir}",
    "   - Чекпоинты: checkpoint_epoch_*.pt",
    "   - Финальная модель: final_projector.pt",
    "   - Логи: training_logs.csv",
    "   - Графики: training_curves.png",
    sep="\n",
)