In [1]:
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
from transformers import AutoModel, AutoTokenizer, Wav2Vec2FeatureExtractor, GemmaForCausalLM, GemmaConfig, QuantoConfig

In [None]:
@dataclass
class TrainingConfig:
    # –ú–æ–¥–µ–ª–∏
    GEMMA_MODEL_ID: str = "google/gemma-3-4b-pt"
    XLSR_MODEL_ID: str = "facebook/wav2vec2-xls-r-300m"
    
    # –¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞
    EPOCHS: int = 50
    BATCH_SIZE: int = 4
    LEARNING_RATE: float = 1e-4
    GRADIENT_CLIP: float = 1.0
    
    # –î–∞–Ω–Ω—ã–µ
    DATASET_PATH: str = "transcripts.jsonl"
    MAX_AUDIO_LENGTH: int = 16000 * 30  # 30 —Å–µ–∫—É–Ω–¥
    MAX_TEXT_LENGTH: int = 512
    
    # –°–∏—Å—Ç–µ–º–∞
    DEVICE: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    SAVE_EVERY: int = 10  # –°–æ—Ö—Ä–∞–Ω—è—Ç—å –∫–∞–∂–¥—ã–µ N —ç–ø–æ—Ö
    
    # –ü—Ä–µ—Ñ–∏–∫—Å –¥–ª—è —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏
    TEXT_PREFIX: str = "–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: "

In [None]:
class AudioProjector(nn.Module):
    def __init__(self, audio_hidden_size: int, llm_hidden_size: int):
        super().__init__()
        # –£–ª—É—á—à–µ–Ω–Ω–∞—è –∞—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä–∞ —Å LayerNorm –¥–ª—è —Å—Ç–∞–±–∏–ª—å–Ω–æ—Å—Ç–∏
        self.proj = nn.Sequential(
            nn.LayerNorm(audio_hidden_size),
            nn.Linear(audio_hidden_size, llm_hidden_size * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(llm_hidden_size * 2, llm_hidden_size),
            nn.LayerNorm(llm_hidden_size)
        )

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        return self.proj(audio_embeds)

In [4]:
def create_gemma_config(vocab_size, pad_token_id):
    return GemmaConfig(
        vocab_size=vocab_size,
        pad_token_id=pad_token_id,
        hidden_size=2560,
        intermediate_size=10240,
        num_hidden_layers=34,
        num_attention_heads=20,
        num_key_value_heads=20,
        head_dim=128,
        model_type="gemma"
    )

In [5]:
class AudioGemmaModel(nn.Module):
    def __init__(self, config: TrainingConfig):
        super().__init__()
        
        self.tokenizer = AutoTokenizer.from_pretrained(config.GEMMA_MODEL_ID)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        gemma_config = create_gemma_config(self.tokenizer.vocab_size, self.tokenizer.pad_token_id)
        
        self.gemma = GemmaForCausalLM.from_pretrained(
            config.GEMMA_MODEL_ID, 
            config=gemma_config,
            quantization_config=QuantoConfig(weights="int4"),
            device_map={"": config.DEVICE},
            torch_dtype=torch.bfloat16
        )
        self.gemma.resize_token_embeddings(len(self.tokenizer))
        
        self.audio_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config.XLSR_MODEL_ID)
        self.audio_encoder = AutoModel.from_pretrained(config.XLSR_MODEL_ID).to(config.DEVICE)
        self.projector = AudioProjector(self.audio_encoder.config.hidden_size, self.gemma.config.hidden_size).to(config.DEVICE)
        
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        for param in self.gemma.parameters():
            param.requires_grad = False

In [6]:
def forward(self, audio_values, input_ids, attention_mask):
    audio_embeds = self.audio_encoder(audio_values).last_hidden_state
    projected_audio = self.projector(audio_embeds)
    text_embeds = self.gemma.get_input_embeddings()(input_ids)
    
    combined_embeds = torch.cat([projected_audio, text_embeds], dim=1)
    combined_embeds = combined_embeds.to(self.gemma.device).to(self.gemma.dtype)
    audio_mask = torch.ones(projected_audio.shape[:2], dtype=torch.long, device=projected_audio.device)
    combined_mask = torch.cat([audio_mask, attention_mask], dim=1)
    
    return self.gemma(inputs_embeds=combined_embeds, attention_mask=combined_mask).logits

AudioGemmaModel.forward = forward

In [7]:
config = TrainingConfig()
model = AudioGemmaModel(config)
model.eval()

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of GemmaForCausalLM were not initialized from the model checkpoint at google/gemma-3-4b-pt and are newly initialized: ['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.

AudioGemmaModel(
  (gemma): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(262145, 2560, padding_idx=0)
      (layers): ModuleList(
        (0-33): 34 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): QLinear(in_features=2560, out_features=2560, bias=False)
            (k_proj): QLinear(in_features=2560, out_features=2560, bias=False)
            (v_proj): QLinear(in_features=2560, out_features=2560, bias=False)
            (o_proj): QLinear(in_features=2560, out_features=2560, bias=False)
          )
          (mlp): GemmaMLP(
            (gate_proj): QLinear(in_features=2560, out_features=10240, bias=False)
            (up_proj): QLinear(in_features=2560, out_features=10240, bias=False)
            (down_proj): QLinear(in_features=10240, out_features=2560, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm((2560,), eps=1e-06)
          (post_attention_layernorm): GemmaRM

In [8]:
dummy_audio = [np.random.randn(32000).astype(np.float32) for _ in range(config.BATCH_SIZE)]
audio_processed = model.audio_extractor(dummy_audio, return_tensors="pt", sampling_rate=16000, padding=True)
audio_values = audio_processed.input_values.to(config.DEVICE)
dummy_texts = ["Test text"] * config.BATCH_SIZE
text_processed = model.tokenizer(dummy_texts, return_tensors="pt", padding=True, max_length=32)
input_ids = text_processed.input_ids.to(config.DEVICE)
attention_mask = text_processed.attention_mask.to(config.DEVICE)
print(f"Audio shape: {audio_values.shape}")
print(f"Text shape: {input_ids.shape}")

Audio shape: torch.Size([4, 32000])
Text shape: torch.Size([4, 3])




In [9]:
print(f"–ò—Å–ø–æ–ª—å–∑—É–µ–º–æ–µ —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {config.DEVICE}")

raw_audio_sr = 16000
dummy_audio_waveforms = [np.random.randn(raw_audio_sr * 2).astype(np.float32) for _ in range(config.BATCH_SIZE)]
audio_processed = model.audio_extractor(dummy_audio_waveforms, return_tensors="pt", sampling_rate=raw_audio_sr, padding=True)
audio_input_values = audio_processed.input_values.to(config.DEVICE)
print(f"–§–æ—Ä–º–∞ audio_input_values: {audio_input_values.shape}, —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {audio_input_values.device}")

dummy_texts = ["–≠—Ç–æ –ø—Ä–∏–º–µ—Ä —Ç–µ–∫—Å—Ç–∞ –¥–ª—è –º–æ–¥–µ–ª–∏ Gemma." for _ in range(config.BATCH_SIZE)]
text_tokenized = model.tokenizer(dummy_texts, return_tensors="pt", padding=True, truncation=True, max_length=32)
input_ids = text_tokenized.input_ids.to(config.DEVICE)
attention_mask = text_tokenized.attention_mask.to(config.DEVICE)
print(f"–§–æ—Ä–º–∞ input_ids: {input_ids.shape}, —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {input_ids.device}")
print(f"–§–æ—Ä–º–∞ attention_mask: {attention_mask.shape}, —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {attention_mask.device}")

print("\n–í—ã–ø–æ–ª–Ω–µ–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤–æ–≥–æ –ø—Ä–æ–≥–æ–Ω–∞ –º–æ–¥–µ–ª–∏ (forward pass)...")
try:
    with torch.no_grad():
        logits = model(audio_input_values, input_ids, attention_mask)
    print(f"Success! Logits shape: {logits.shape}")
except Exception as e:
    print(f"–ö–†–ò–¢–ò–ß–ï–°–ö–ê–Ø –û–®–ò–ë–ö–ê –≤–æ –≤—Ä–µ–º—è forward pass: {e}")
    import traceback
    traceback.print_exc()
print("\n--- –¢–µ—Å—Ç–æ–≤—ã–π –∑–∞–ø—É—Å–∫ –∑–∞–≤–µ—Ä—à—ë–Ω ---")

–ò—Å–ø–æ–ª—å–∑—É–µ–º–æ–µ —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: mps
–§–æ—Ä–º–∞ audio_input_values: torch.Size([4, 32000]), —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: mps:0
–§–æ—Ä–º–∞ input_ids: torch.Size([4, 8]), —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: mps:0
–§–æ—Ä–º–∞ attention_mask: torch.Size([4, 8]), —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: mps:0

–í—ã–ø–æ–ª–Ω–µ–Ω–∏–µ —Ç–µ—Å—Ç–æ–≤–æ–≥–æ –ø—Ä–æ–≥–æ–Ω–∞ –º–æ–¥–µ–ª–∏ (forward pass)...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Success! Logits shape: torch.Size([4, 107, 262145])

--- –¢–µ—Å—Ç–æ–≤—ã–π –∑–∞–ø—É—Å–∫ –∑–∞–≤–µ—Ä—à—ë–Ω ---


In [11]:
# Sampling from logits to generate varied outputs
import torch.nn.functional as F
batch_size, seq_len, vocab_size = logits.shape
sampled_ids = torch.zeros(batch_size, seq_len, dtype=torch.long, device=logits.device)
for t in range(seq_len):
    probs_t = F.softmax(logits[:, t, :], dim=-1)
    sampled_ids[:, t] = torch.multinomial(probs_t, num_samples=1).squeeze(-1)
sampled_texts = [model.tokenizer.decode(ids, skip_special_tokens=True) for ids in sampled_ids]
print('Sampled texts:', sampled_texts)



In [None]:
import json
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os

class AudioTextDataset(Dataset):
    def __init__(self, jsonl_path: str, config: TrainingConfig, audio_extractor, tokenizer):
        self.config = config
        self.audio_extractor = audio_extractor
        self.tokenizer = tokenizer
        
        # –ó–∞–≥—Ä—É–∂–∞–µ–º –¥–∞—Ç–∞—Å–µ—Ç
        self.data = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                item = json.loads(line)
                if os.path.exists(item["audio_path"]):
                    self.data.append({
                        "audio_path": item["audio_path"],
                        "text": item["speaker_text"]
                    })
        
        print(f"–ó–∞–≥—Ä—É–∂–µ–Ω–æ {len(self.data)} –ø—Ä–∏–º–µ—Ä–æ–≤")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        try:
            # –ó–∞–≥—Ä—É–∂–∞–µ–º –∞—É–¥–∏–æ
            waveform, sample_rate = torchaudio.load(item["audio_path"])
            waveform = waveform.mean(dim=0, keepdim=True)  # –ú–æ–Ω–æ
            
            # –†–µ—Å–µ–º–ø–ª–∏–Ω–≥
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                waveform = resampler(waveform)
            
            # –û–±—Ä–µ–∑–∞–µ–º –∏–ª–∏ –¥–æ–ø–æ–ª–Ω—è–µ–º
            if waveform.shape[1] > self.config.MAX_AUDIO_LENGTH:
                waveform = waveform[:, :self.config.MAX_AUDIO_LENGTH]
            
            # –û–±—Ä–∞–±–∞—Ç—ã–≤–∞–µ–º –∞—É–¥–∏–æ
            audio_input = self.audio_extractor(
                waveform.squeeze(0).numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding="max_length",
                max_length=self.config.MAX_AUDIO_LENGTH
            )
            
            # –¢–æ–∫–µ–Ω–∏–∑–∏—Ä—É–µ–º —Ç–µ–∫—Å—Ç —Å –ø—Ä–µ—Ñ–∏–∫—Å–æ–º
            full_text = self.config.TEXT_PREFIX + item["text"]
            text_input = self.tokenizer(
                full_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.config.MAX_TEXT_LENGTH
            )
            
            return {
                "audio": audio_input.input_values.squeeze(0),
                "input_ids": text_input.input_ids.squeeze(0),
                "attention_mask": text_input.attention_mask.squeeze(0),
                "text": item["text"]
            }
            
        except Exception as e:
            print(f"–û—à–∏–±–∫–∞ –ø—Ä–∏ –∑–∞–≥—Ä—É–∑–∫–µ {item['audio_path']}: {e}")
            # –í–æ–∑–≤—Ä–∞—â–∞–µ–º –ø—É—Å—Ç–æ–π –ø—Ä–∏–º–µ—Ä
            return self.__getitem__((idx + 1) % len(self.data))

def collate_fn(batch):
    """–§—É–Ω–∫—Ü–∏—è –¥–ª—è –æ–±—ä–µ–¥–∏–Ω–µ–Ω–∏—è –ø—Ä–∏–º–µ—Ä–æ–≤ –≤ batch"""
    audio_batch = torch.stack([item["audio"] for item in batch])
    input_ids_batch = torch.stack([item["input_ids"] for item in batch])
    attention_mask_batch = torch.stack([item["attention_mask"] for item in batch])
    texts = [item["text"] for item in batch]
    
    return {
        "audio": audio_batch,
        "input_ids": input_ids_batch, 
        "attention_mask": attention_mask_batch,
        "texts": texts
    }

In [None]:
def train_model(model: AudioGemmaModel, config: TrainingConfig):
    """–û—Å–Ω–æ–≤–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏"""
    
    # –°–æ–∑–¥–∞–µ–º –¥–∞—Ç–∞—Å–µ—Ç –∏ DataLoader
    dataset = AudioTextDataset(
        config.DATASET_PATH, 
        config, 
        model.audio_extractor, 
        model.tokenizer
    )
    
    dataloader = DataLoader(
        dataset, 
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2 if config.DEVICE == "cuda" else 0
    )
    
    # –û–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä –∏ loss
    optimizer = torch.optim.AdamW(
        model.projector.parameters(), 
        lr=config.LEARNING_RATE,
        weight_decay=0.01
    )
    
    # –ü–ª–∞–Ω–∏—Ä–æ–≤—â–∏–∫ learning rate
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 
        T_max=config.EPOCHS
    )
    
    loss_fn = nn.CrossEntropyLoss(ignore_index=model.tokenizer.pad_token_id)
    
    # –ü—Ä–µ—Ñ–∏–∫—Å –¥–ª—è –≤—ã—á–∏—Å–ª–µ–Ω–∏—è loss
    prefix_ids = model.tokenizer(
        config.TEXT_PREFIX, 
        return_tensors="pt"
    ).input_ids.to(config.DEVICE)
    prefix_len = prefix_ids.shape[1]
    
    # –¢—Ä–µ–Ω–∏—Ä–æ–≤–æ—á–Ω—ã–π —Ü–∏–∫–ª
    for epoch in range(config.EPOCHS):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.EPOCHS}")
        
        for batch in progress_bar:
            try:
                # –ü–µ—Ä–µ–º–µ—â–∞–µ–º –¥–∞–Ω–Ω—ã–µ –Ω–∞ —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ
                audio = batch["audio"].to(config.DEVICE)
                input_ids = batch["input_ids"].to(config.DEVICE)
                attention_mask = batch["attention_mask"].to(config.DEVICE)
                
                # Forward pass
                optimizer.zero_grad()
                
                # –ü–æ–ª—É—á–∞–µ–º audio embeddings
                with torch.no_grad():
                    audio_embeds = model.audio_encoder(audio).last_hidden_state
                
                # –ü—Ä–æ–µ—Ü–∏—Ä—É–µ–º –∞—É–¥–∏–æ
                projected_audio = model.projector(audio_embeds)
                
                # –ü–æ–ª—É—á–∞–µ–º text embeddings
                text_embeds = model.gemma.get_input_embeddings()(input_ids)
                
                # –û–±—ä–µ–¥–∏–Ω—è–µ–º embeddings
                combined_embeds = torch.cat([projected_audio, text_embeds], dim=1)
                combined_embeds = combined_embeds.to(model.gemma.dtype)
                
                # –°–æ–∑–¥–∞–µ–º –º–∞—Å–∫–∏
                audio_mask = torch.ones(
                    projected_audio.shape[:2], 
                    dtype=torch.long, 
                    device=config.DEVICE
                )
                combined_mask = torch.cat([audio_mask, attention_mask], dim=1)
                
                # Forward —á–µ—Ä–µ–∑ Gemma
                outputs = model.gemma(
                    inputs_embeds=combined_embeds,
                    attention_mask=combined_mask
                )
                logits = outputs.logits
                
                # –í—ã—á–∏—Å–ª—è–µ–º loss —Ç–æ–ª—å–∫–æ –¥–ª—è —Ç–µ–∫—Å—Ç–æ–≤–æ–π —á–∞—Å—Ç–∏
                audio_seq_len = projected_audio.shape[1]
                text_logits = logits[:, audio_seq_len:-1, :].contiguous()
                text_labels = input_ids[:, prefix_len:].contiguous()
                
                loss = loss_fn(
                    text_logits.view(-1, text_logits.size(-1)),
                    text_labels.view(-1)
                )
                
                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.projector.parameters(), 
                    config.GRADIENT_CLIP
                )
                optimizer.step()
                
                epoch_loss += loss.item()
                num_batches += 1
                
                # –û–±–Ω–æ–≤–ª—è–µ–º progress bar
                progress_bar.set_postfix({
                    "Loss": f"{loss.item():.4f}",
                    "Avg Loss": f"{epoch_loss/num_batches:.4f}",
                    "LR": f"{scheduler.get_last_lr()[0]:.2e}"
                })
                
            except Exception as e:
                print(f"–û—à–∏–±–∫–∞ –≤ batch: {e}")
                continue
        
        # –û–±–Ω–æ–≤–ª—è–µ–º learning rate
        scheduler.step()
        
        # –õ–æ–≥–∏—Ä—É–µ–º —Ä–µ–∑—É–ª—å—Ç–∞—Ç—ã —ç–ø–æ—Ö–∏
        avg_loss = epoch_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch+1} –∑–∞–≤–µ—Ä—à–µ–Ω–∞. Average Loss: {avg_loss:.4f}")
        
        # –°–æ—Ö—Ä–∞–Ω—è–µ–º —á–µ–∫–ø–æ–∏–Ω—Ç
        if (epoch + 1) % config.SAVE_EVERY == 0:
            checkpoint_path = f"projector_epoch_{epoch+1}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.projector.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss,
                'config': config
            }, checkpoint_path)
            print(f"–ß–µ–∫–ø–æ–∏–Ω—Ç —Å–æ—Ö—Ä–∞–Ω–µ–Ω: {checkpoint_path}")
    
    # –§–∏–Ω–∞–ª—å–Ω–æ–µ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ
    final_path = "audio_projector_final.pth"
    torch.save(model.projector.state_dict(), final_path)
    print(f"–§–∏–Ω–∞–ª—å–Ω–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞: {final_path}")

# –ó–∞–ø—É—Å–∫ —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏
if __name__ == "__main__":
    config = TrainingConfig()
    model = AudioGemmaModel(config)
    
    print("–ù–∞—á–∏–Ω–∞–µ–º —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫—É...")
    train_model(model, config)

# –§—É–Ω–∫—Ü–∏–∏ –¥–ª—è –∏–Ω—Ñ–µ—Ä–µ–Ω—Å–∞ –∏ —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è

–ü–æ—Å–ª–µ —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏ –º–æ–∂–Ω–æ –∏—Å–ø–æ–ª—å–∑–æ–≤–∞—Ç—å –º–æ–¥–µ–ª—å –¥–ª—è –≥–µ–Ω–µ—Ä–∞—Ü–∏–∏ —Ç—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏–π –Ω–æ–≤—ã—Ö –∞—É–¥–∏–æ —Ñ–∞–π–ª–æ–≤.

In [None]:
def transcribe_audio(model: AudioGemmaModel, audio_path: str, config: TrainingConfig, max_length: int = 256):
    """–¢—Ä–∞–Ω—Å–∫—Ä–∏–±–∏—Ä—É–µ—Ç –∞—É–¥–∏–æ —Ñ–∞–π–ª –∏—Å–ø–æ–ª—å–∑—É—è –æ–±—É—á–µ–Ω–Ω—É—é –º–æ–¥–µ–ª—å"""
    
    model.eval()
    
    try:
        # –ó–∞–≥—Ä—É–∂–∞–µ–º –∞—É–¥–∏–æ
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=True)  # –ú–æ–Ω–æ
        
        # –†–µ—Å–µ–º–ø–ª–∏–Ω–≥
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        
        # –û–±—Ä–µ–∑–∞–µ–º –µ—Å–ª–∏ —Å–ª–∏—à–∫–æ–º –¥–ª–∏–Ω–Ω–æ–µ
        if waveform.shape[1] > config.MAX_AUDIO_LENGTH:
            waveform = waveform[:, :config.MAX_AUDIO_LENGTH]
        
        # –û–±—Ä–∞–±–∞—Ç—ã–≤–∞–µ–º –∞—É–¥–∏–æ
        audio_input = model.audio_extractor(
            waveform.squeeze(0).numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        )
        audio_values = audio_input.input_values.to(config.DEVICE)
        
        # –ù–∞—á–∞–ª—å–Ω—ã–π –ø—Ä–µ—Ñ–∏–∫—Å
        prefix_text = config.TEXT_PREFIX
        input_ids = model.tokenizer(
            prefix_text,
            return_tensors="pt"
        ).input_ids.to(config.DEVICE)
        
        with torch.no_grad():
            # –ü–æ–ª—É—á–∞–µ–º audio embeddings
            audio_embeds = model.audio_encoder(audio_values).last_hidden_state
            projected_audio = model.projector(audio_embeds)
            
            # –ù–∞—á–∞–ª—å–Ω—ã–µ text embeddings
            text_embeds = model.gemma.get_input_embeddings()(input_ids)
            
            # –û–±—ä–µ–¥–∏–Ω—è–µ–º
            combined_embeds = torch.cat([projected_audio, text_embeds], dim=1)
            combined_embeds = combined_embeds.to(model.gemma.dtype)
            
            # –ì–µ–Ω–µ—Ä–∏—Ä—É–µ–º —Ç–µ–∫—Å—Ç
            generated_ids = input_ids.clone()
            
            for _ in range(max_length):
                # –ü–æ–ª—É—á–∞–µ–º embeddings –¥–ª—è —Ç–µ–∫—É—â–µ–π –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
                current_text_embeds = model.gemma.get_input_embeddings()(generated_ids)
                current_combined = torch.cat([projected_audio, current_text_embeds], dim=1)
                current_combined = current_combined.to(model.gemma.dtype)
                
                # Forward pass
                outputs = model.gemma(inputs_embeds=current_combined)
                logits = outputs.logits
                
                # –ë–µ—Ä–µ–º –ø–æ—Å–ª–µ–¥–Ω–∏–π —Ç–æ–∫–µ–Ω
                next_token_logits = logits[0, -1, :]
                next_token_id = torch.argmax(next_token_logits, dim=-1)
                
                # –î–æ–±–∞–≤–ª—è–µ–º –∫ –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
                generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0).unsqueeze(0)], dim=1)
                
                # –ü—Ä–æ–≤–µ—Ä—è–µ–º –Ω–∞ –∫–æ–Ω–µ—Ü –ø–æ—Å–ª–µ–¥–æ–≤–∞—Ç–µ–ª—å–Ω–æ—Å—Ç–∏
                if next_token_id == model.tokenizer.eos_token_id:
                    break
            
            # –î–µ–∫–æ–¥–∏—Ä—É–µ–º —Ä–µ–∑—É–ª—å—Ç–∞—Ç
            generated_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            
            # –£–±–∏—Ä–∞–µ–º –ø—Ä–µ—Ñ–∏–∫—Å
            if generated_text.startswith(prefix_text):
                transcription = generated_text[len(prefix_text):].strip()
            else:
                transcription = generated_text.strip()
                
            return transcription
            
    except Exception as e:
        print(f"–û—à–∏–±–∫–∞ –ø—Ä–∏ —Ç—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏–∏ {audio_path}: {e}")
        return None

# –§—É–Ω–∫—Ü–∏—è –¥–ª—è –∑–∞–≥—Ä—É–∑–∫–∏ –æ–±—É—á–µ–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏
def load_trained_model(checkpoint_path: str, config: TrainingConfig):
    """–ó–∞–≥—Ä—É–∂–∞–µ—Ç –æ–±—É—á–µ–Ω–Ω—É—é –º–æ–¥–µ–ª—å –∏–∑ —á–µ–∫–ø–æ–∏–Ω—Ç–∞"""
    
    model = AudioGemmaModel(config)
    
    if checkpoint_path.endswith('_final.pth'):
        # –ü—Ä–æ—Å—Ç–æ–µ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ç–æ–ª—å–∫–æ projector
        model.projector.load_state_dict(torch.load(checkpoint_path, map_location=config.DEVICE))
    else:
        # –ü–æ–ª–Ω—ã–π —á–µ–∫–ø–æ–∏–Ω—Ç
        checkpoint = torch.load(checkpoint_path, map_location=config.DEVICE)
        model.projector.load_state_dict(checkpoint['model_state_dict'])
        print(f"–ó–∞–≥—Ä—É–∂–µ–Ω–∞ –º–æ–¥–µ–ª—å —Å —ç–ø–æ—Ö–∏ {checkpoint['epoch']}, loss: {checkpoint['loss']:.4f}")
    
    model.eval()
    return model

# –ü—Ä–∏–º–µ—Ä –∏—Å–ø–æ–ª—å–∑–æ–≤–∞–Ω–∏—è
def test_transcription():
    """–¢–µ—Å—Ç–∏—Ä—É–µ–º —Ç—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—é –Ω–∞ –ø—Ä–∏–º–µ—Ä–µ"""
    config = TrainingConfig()
    
    # –ó–∞–≥—Ä—É–∂–∞–µ–º –æ–±—É—á–µ–Ω–Ω—É—é –º–æ–¥–µ–ª—å
    model = load_trained_model("audio_projector_final.pth", config)
    
    # –¢–µ—Å—Ç–∏—Ä—É–µ–º –Ω–∞ —Ñ–∞–π–ª–µ
    test_audio_path = "test_audio.wav"  # –ó–∞–º–µ–Ω–∏—Ç–µ –Ω–∞ –≤–∞—à —Ñ–∞–π–ª
    
    if os.path.exists(test_audio_path):
        transcription = transcribe_audio(model, test_audio_path, config)
        print(f"–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è: {transcription}")
    else:
        print(f"–§–∞–π–ª {test_audio_path} –Ω–µ –Ω–∞–π–¥–µ–Ω")

# test_transcription()  # –†–∞—Å–∫–æ–º–º–µ–Ω—Ç–∏—Ä—É–π—Ç–µ –¥–ª—è —Ç–µ—Å—Ç–∏—Ä–æ–≤–∞–Ω–∏—è

# üî¢ –ö–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è –≤ Deep Learning: –¢–µ–æ—Ä–∏—è –∏ –ü—Ä–∞–∫—Ç–∏–∫–∞

## –ß—Ç–æ —Ç–∞–∫–æ–µ –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è?

**–ö–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è** - —ç—Ç–æ –ø—Ä–æ—Ü–µ—Å—Å —É–º–µ–Ω—å—à–µ–Ω–∏—è —Ç–æ—á–Ω–æ—Å—Ç–∏ —á–∏—Å–µ–ª —Å –ø–ª–∞–≤–∞—é—â–µ–π —Ç–æ—á–∫–æ–π –¥–ª—è —ç–∫–æ–Ω–æ–º–∏–∏ –ø–∞–º—è—Ç–∏ –∏ —É—Å–∫–æ—Ä–µ–Ω–∏—è –≤—ã—á–∏—Å–ª–µ–Ω–∏–π.

### –¢–∏–ø—ã –¥–∞–Ω–Ω—ã—Ö –∏ –∏—Ö —Ä–∞–∑–º–µ—Ä—ã:
- **FP32** (float32): 32 –±–∏—Ç–∞, ~7 –∑–Ω–∞—á–∞—â–∏—Ö —Ü–∏—Ñ—Ä
- **FP16** (float16): 16 –±–∏—Ç, ~3-4 –∑–Ω–∞—á–∞—â–∏—Ö —Ü–∏—Ñ—Ä—ã  
- **BF16** (bfloat16): 16 –±–∏—Ç, –±–æ–ª—å—à–∏–π –¥–∏–∞–ø–∞–∑–æ–Ω —á–µ–º FP16
- **INT8**: 8 –±–∏—Ç, —Ç–æ–ª—å–∫–æ —Ü–µ–ª—ã–µ —á–∏—Å–ª–∞
- **INT4**: 4 –±–∏—Ç–∞, –æ—á–µ–Ω—å –æ–≥—Ä–∞–Ω–∏—á–µ–Ω–Ω—ã–π –¥–∏–∞–ø–∞–∑–æ–Ω

### –ü—Ä–æ–±–ª–µ–º—ã —Å FP16:
- **–ü–µ—Ä–µ–ø–æ–ª–Ω–µ–Ω–∏–µ** (overflow): —á–∏—Å–ª–∞ —Å—Ç–∞–Ω–æ–≤—è—Ç—Å—è `inf`
- **–ò—Å—á–µ–∑–Ω–æ–≤–µ–Ω–∏–µ** (underflow): –æ—á–µ–Ω—å –º–∞–ª–µ–Ω—å–∫–∏–µ —á–∏—Å–ª–∞ —Å—Ç–∞–Ω–æ–≤—è—Ç—Å—è `0`
- **–ü–æ—Ç–µ—Ä—è —Ç–æ—á–Ω–æ—Å—Ç–∏**: –Ω–∞–∫–æ–ø–ª–µ–Ω–∏–µ –æ—à–∏–±–æ–∫ –æ–∫—Ä—É–≥–ª–µ–Ω–∏—è

In [None]:
# –î–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏—è –ø—Ä–æ–±–ª–µ–º —Å FP16
import torch

print("=== –ü—Ä–æ–±–ª–µ–º—ã —Å FP16 ===")

# 1. –ü–µ—Ä–µ–ø–æ–ª–Ω–µ–Ω–∏–µ (Overflow)
large_number = torch.tensor([65000.0], dtype=torch.float32)
print(f"FP32: {large_number}")
print(f"FP16: {large_number.half()}")  # –ú–æ–∂–µ—Ç —Å—Ç–∞—Ç—å inf

# 2. –ò—Å—á–µ–∑–Ω–æ–≤–µ–Ω–∏–µ (Underflow) 
small_number = torch.tensor([1e-8], dtype=torch.float32)
print(f"FP32: {small_number}")
print(f"FP16: {small_number.half()}")  # –°—Ç–∞–Ω–µ—Ç 0

# 3. –ü–æ—Ç–µ—Ä—è —Ç–æ—á–Ω–æ—Å—Ç–∏ –≤ –≥—Ä–∞–¥–∏–µ–Ω—Ç–∞—Ö
gradient = torch.tensor([1e-6], dtype=torch.float32)
print(f"Gradient FP32: {gradient}")
print(f"Gradient FP16: {gradient.half()}")

# 4. –°—Ä–∞–≤–Ω–µ–Ω–∏–µ –¥–∏–∞–ø–∞–∑–æ–Ω–æ–≤
print(f"\nFP16 range: {torch.finfo(torch.float16).min} to {torch.finfo(torch.float16).max}")
print(f"FP32 range: {torch.finfo(torch.float32).min} to {torch.finfo(torch.float32).max}")
print(f"BF16 range: {torch.finfo(torch.bfloat16).min} to {torch.finfo(torch.bfloat16).max}")

## üéØ –¢–∏–ø—ã –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏–∏

### 1. **Post-Training Quantization (PTQ)**
- –ö–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è **–ü–û–°–õ–ï** —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏
- –ë—ã—Å—Ç—Ä–æ, –Ω–æ –º–æ–∂–µ—Ç –ø–æ—Ç–µ—Ä—è—Ç—å –∫–∞—á–µ—Å—Ç–≤–æ
- –ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è –¥–ª—è –∏–Ω—Ñ–µ—Ä–µ–Ω—Å–∞

### 2. **Quantization-Aware Training (QAT)**  
- –ö–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è **–í–û –í–†–ï–ú–Ø** —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏
- –ú–æ–¥–µ–ª—å —É—á–∏—Ç—Å—è —Ä–∞–±–æ—Ç–∞—Ç—å —Å –∫–≤–∞–Ω—Ç–∏–∑–æ–≤–∞–Ω–Ω—ã–º–∏ –≤–µ—Å–∞–º–∏
- –õ—É—á—à–µ–µ –∫–∞—á–µ—Å—Ç–≤–æ, –Ω–æ —Å–ª–æ–∂–Ω–µ–µ

### 3. **Mixed Precision Training**
- –ß–∞—Å—Ç—å –æ–ø–µ—Ä–∞—Ü–∏–π –≤ FP16/BF16
- –ö—Ä–∏—Ç–∏—á–µ—Å–∫–∏–µ –æ–ø–µ—Ä–∞—Ü–∏–∏ –≤ FP32
- –ê–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–æ–µ –º–∞—Å—à—Ç–∞–±–∏—Ä–æ–≤–∞–Ω–∏–µ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤

### 4. **Selective Quantization**
- –ö–≤–∞–Ω—Ç–∏–∑–∏—Ä—É–µ–º —Ç–æ–ª—å–∫–æ –Ω–µ–∫–æ—Ç–æ—Ä—ã–µ —Å–ª–æ–∏
- –ù–∞–ø—Ä–∏–º–µ—Ä: –∑–∞–º–æ—Ä–æ–∑–∏–ª–∏ LLM –≤ INT4, —Ç—Ä–µ–Ω–∏—Ä—É–µ–º –∞–¥–∞–ø—Ç–µ—Ä –≤ FP32

## üéØ –ù–∞—à —Å–ª—É—á–∞–π: Audio + Gemma

### –ß—Ç–æ —É –Ω–∞—Å –µ—Å—Ç—å:

```
Audio Encoder (Wav2Vec2) -> Projector -> Gemma (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)
     ‚Üì                         ‚Üì           ‚Üì
 –ó–∞–º–æ—Ä–æ–∂–µ–Ω            –¢—Ä–µ–Ω–∏—Ä—É–µ—Ç—Å—è    –ó–∞–º–æ—Ä–æ–∂–µ–Ω
 FP32/FP16               FP32         INT4
```

### –°—Ç—Ä–∞—Ç–µ–≥–∏—è –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏–∏:

1. **Gemma**: INT4 –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏—è (—É–∂–µ –∑–∞–º–æ—Ä–æ–∂–µ–Ω, —Ç–æ–ª—å–∫–æ –∏–Ω—Ñ–µ—Ä–µ–Ω—Å)
2. **Audio Encoder**: FP16 (–∑–∞–º–æ—Ä–æ–∂–µ–Ω, –º–æ–∂–Ω–æ –∫–≤–∞–Ω—Ç–∏–∑–æ–≤–∞—Ç—å) 
3. **Projector**: FP32 (—Ç—Ä–µ–Ω–∏—Ä—É–µ—Ç—Å—è, –Ω—É–∂–Ω–∞ —Ç–æ—á–Ω–æ—Å—Ç—å)
4. **–ì—Ä–∞–¥–∏–µ–Ω—Ç—ã**: FP32 —Å gradient scaling

### –ü–æ—á–µ–º—É —É –≤–∞—Å –±—ã–ª–∏ NaN —Å FP16:
- –ì—Ä–∞–¥–∏–µ–Ω—Ç—ã projector'–∞ —Å—Ç–∞–ª–∏ —Å–ª–∏—à–∫–æ–º –º–∞–ª–µ–Ω—å–∫–∏–º–∏
- FP16 –Ω–µ —Å–º–æ–≥ –∏—Ö –ø—Ä–µ–¥—Å—Ç–∞–≤–∏—Ç—å ‚Üí 0 ‚Üí NaN –≤ loss

In [None]:
# –ü—Ä–∞–≤–∏–ª—å–Ω–∞—è —Ä–µ–∞–ª–∏–∑–∞—Ü–∏—è —Å Mixed Precision
from torch.amp import GradScaler, autocast
import torch.nn as nn

class OptimizedAudioGemmaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.GEMMA_MODEL_ID)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Gemma –≤ INT4 (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)
        self.gemma = GemmaForCausalLM.from_pretrained(
            config.GEMMA_MODEL_ID,
            quantization_config=QuantoConfig(weights="int4"),  # INT4!
            device_map={"": config.DEVICE},
            torch_dtype=torch.bfloat16  # BF16 –ª—É—á—à–µ —á–µ–º FP16
        )
        
        # Audio encoder –≤ BF16 (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)
        self.audio_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config.XLSR_MODEL_ID)
        self.audio_encoder = AutoModel.from_pretrained(
            config.XLSR_MODEL_ID,
            torch_dtype=torch.bfloat16  # –≠–∫–æ–Ω–æ–º–∏–º –ø–∞–º—è—Ç—å
        ).to(config.DEVICE)
        
        # Projector –≤ FP32 (—Ç—Ä–µ–Ω–∏—Ä—É–µ—Ç—Å—è!)
        self.projector = AudioProjector(
            self.audio_encoder.config.hidden_size,
            self.gemma.config.hidden_size
        ).to(config.DEVICE).to(torch.float32)  # –û–±—è–∑–∞—Ç–µ–ª—å–Ω–æ FP32!
        
        # –ó–∞–º–æ—Ä–∞–∂–∏–≤–∞–µ–º
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        for param in self.gemma.parameters():
            param.requires_grad = False
    
    def forward(self, audio_values, input_ids, attention_mask):
        # Audio processing –≤ BF16
        with autocast(device_type=config.DEVICE.split(':')[0]):
            audio_embeds = self.audio_encoder(audio_values.to(torch.bfloat16)).last_hidden_state
            
            # Projector –≤ FP32 (—Ç–æ—á–Ω–æ—Å—Ç—å –≤–∞–∂–Ω–∞!)
            projected_audio = self.projector(audio_embeds.to(torch.float32))
            
            # Text embeddings
            text_embeds = self.gemma.get_input_embeddings()(input_ids)
            
            # –û–±—ä–µ–¥–∏–Ω—è–µ–º (–ø—Ä–∏–≤–æ–¥–∏–º –∫ BF16 –¥–ª—è Gemma)
            combined_embeds = torch.cat([
                projected_audio.to(torch.bfloat16), 
                text_embeds
            ], dim=1)
            
            # –ú–∞—Å–∫–∏
            audio_mask = torch.ones(projected_audio.shape[:2], dtype=torch.long, device=config.DEVICE)
            combined_mask = torch.cat([audio_mask, attention_mask], dim=1)
            
            # Gemma inference –≤ BF16
            return self.gemma(inputs_embeds=combined_embeds, attention_mask=combined_mask).logits

In [None]:
# –¢—Ä–µ–Ω–∏—Ä–æ–≤–æ—á–Ω–∞—è —Ñ—É–Ω–∫—Ü–∏—è —Å Mixed Precision –∏ GradScaler
def train_with_mixed_precision(model, config):
    """–¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞ —Å –ø—Ä–∞–≤–∏–ª—å–Ω–æ–π –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏–µ–π –∏ mixed precision"""
    
    # GradScaler –¥–ª—è –∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–æ–≥–æ –º–∞—Å—à—Ç–∞–±–∏—Ä–æ–≤–∞–Ω–∏—è –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤
    scaler = GradScaler()
    
    # –û–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä —Ç–æ–ª—å–∫–æ –¥–ª—è projector (–≤ FP32!)
    optimizer = torch.optim.AdamW(
        model.projector.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=0.01
    )
    
    # –°–æ–∑–¥–∞–µ–º –ø—Ä–æ—Å—Ç–æ–π –ø—Ä–∏–º–µ—Ä –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏
    dummy_audio = torch.randn(2, 16000).to(config.DEVICE)
    dummy_text = ["–ü—Ä–∏–≤–µ—Ç –º–∏—Ä", "–¢–µ—Å—Ç —Ç–µ–∫—Å—Ç–∞"]
    
    # –û–±—Ä–∞–±–∞—Ç—ã–≤–∞–µ–º –¥–∞–Ω–Ω—ã–µ
    audio_processed = model.audio_extractor(
        [audio.cpu().numpy() for audio in dummy_audio], 
        return_tensors="pt", 
        sampling_rate=16000,
        padding=True
    )
    audio_values = audio_processed.input_values.to(config.DEVICE)
    
    text_processed = model.tokenizer(
        dummy_text, 
        return_tensors="pt", 
        padding=True, 
        max_length=64
    )
    input_ids = text_processed.input_ids.to(config.DEVICE)
    attention_mask = text_processed.attention_mask.to(config.DEVICE)
    
    print("=== –î–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏—è Mixed Precision Training ===")
    
    for step in range(3):
        optimizer.zero_grad()
        
        # Forward pass —Å –∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–∏–º casting
        with autocast(device_type=config.DEVICE.split(':')[0]):
            # –ü–æ–ª—É—á–∞–µ–º logits
            logits = model(audio_values, input_ids, attention_mask)
            
            # –ü—Ä–æ—Å—Ç–æ–π loss –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏
            # –í —Ä–µ–∞–ª—å–Ω–æ—Å—Ç–∏ —Ç—É—Ç –±—É–¥–µ—Ç –ø—Ä–∞–≤–∏–ª—å–Ω—ã–π —Ä–∞—Å—á–µ—Ç loss –¥–ª—è seq2seq
            target_ids = input_ids[:, 1:]  # –°–¥–≤–∏–≥–∞–µ–º –¥–ª—è next token prediction
            logits_for_loss = logits[:, -target_ids.shape[1]:, :]
            
            loss = nn.CrossEntropyLoss()(
                logits_for_loss.reshape(-1, logits_for_loss.size(-1)),
                target_ids.reshape(-1)
            )
        
        # Backward —Å –º–∞—Å—à—Ç–∞–±–∏—Ä–æ–≤–∞–Ω–∏–µ–º –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤
        scaler.scale(loss).backward()
        
        # –ü—Ä–æ–≤–µ—Ä—è–µ–º –≥—Ä–∞–¥–∏–µ–Ω—Ç—ã –ø–µ—Ä–µ–¥ –æ–±–Ω–æ–≤–ª–µ–Ω–∏–µ–º
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.projector.parameters(), max_norm=1.0)
        
        # –û–±–Ω–æ–≤–ª—è–µ–º –ø–∞—Ä–∞–º–µ—Ç—Ä—ã
        scaler.step(optimizer)
        scaler.update()
        
        print(f"Step {step+1}: Loss = {loss.item():.4f}")
        
        # –ü—Ä–æ–≤–µ—Ä—è–µ–º, —á—Ç–æ –≥—Ä–∞–¥–∏–µ–Ω—Ç—ã –Ω–µ NaN
        for name, param in model.projector.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                has_nan = torch.isnan(param.grad).any().item()
                print(f"  {name}: grad_norm={grad_norm:.6f}, has_nan={has_nan}")
    
    print("\n‚úÖ –¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞ –∑–∞–≤–µ—Ä—à–µ–Ω–∞ –±–µ–∑ NaN!")
    return model

# –¢–µ—Å—Ç–∏—Ä—É–µ–º
config = TrainingConfig()
optimized_model = OptimizedAudioGemmaModel(config)
trained_model = train_with_mixed_precision(optimized_model, config)

In [None]:
# –ê–Ω–∞–ª–∏–∑ –ø–æ—Ç—Ä–µ–±–ª–µ–Ω–∏—è –ø–∞–º—è—Ç–∏
def analyze_memory_usage():
    """–°—Ä–∞–≤–Ω–∏–≤–∞–µ–º –ø–æ—Ç—Ä–µ–±–ª–µ–Ω–∏–µ –ø–∞–º—è—Ç–∏ —Ä–∞–∑–Ω—ã—Ö –ø–æ–¥—Ö–æ–¥–æ–≤"""
    
    if torch.cuda.is_available():
        device = "cuda"
    else:
        print("CUDA –Ω–µ–¥–æ—Å—Ç—É–ø–Ω–∞, –∏—Å–ø–æ–ª—å–∑—É–µ–º CPU –¥–ª—è –¥–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏–∏")
        device = "cpu"
    
    def get_memory_mb():
        if device == "cuda":
            return torch.cuda.memory_allocated() / 1024 / 1024
        else:
            return 0  # –ù–∞ CPU —Å–ª–æ–∂–Ω–µ–µ –∏–∑–º–µ—Ä–∏—Ç—å
    
    print("=== –ê–Ω–∞–ª–∏–∑ –ø–æ—Ç—Ä–µ–±–ª–µ–Ω–∏—è –ø–∞–º—è—Ç–∏ ===\n")
    
    # –ë–∞–∑–æ–≤–∞—è –ø–∞–º—è—Ç—å
    torch.cuda.empty_cache() if device == "cuda" else None
    base_memory = get_memory_mb()
    print(f"–ë–∞–∑–æ–≤–∞—è –ø–∞–º—è—Ç—å: {base_memory:.1f} MB")
    
    # 1. –í—Å–µ –≤ FP32
    print("\n1. –í—Å–µ –∫–æ–º–ø–æ–Ω–µ–Ω—Ç—ã –≤ FP32:")
    model_fp32 = torch.nn.Linear(1024, 2560).to(device).to(torch.float32)
    memory_fp32 = get_memory_mb() - base_memory
    print(f"   –ü–∞–º—è—Ç—å: {memory_fp32:.1f} MB")
    del model_fp32
    
    # 2. –í—Å–µ –≤ FP16
    print("\n2. –í—Å–µ –∫–æ–º–ø–æ–Ω–µ–Ω—Ç—ã –≤ FP16:")
    model_fp16 = torch.nn.Linear(1024, 2560).to(device).to(torch.float16)
    memory_fp16 = get_memory_mb() - base_memory
    print(f"   –ü–∞–º—è—Ç—å: {memory_fp16:.1f} MB")
    print(f"   –≠–∫–æ–Ω–æ–º–∏—è: {(memory_fp32 - memory_fp16) / memory_fp32 * 100:.1f}%")
    del model_fp16
    
    # 3. Mixed precision (–Ω–∞—à –ø–æ–¥—Ö–æ–¥)
    print("\n3. Mixed Precision (–æ–ø—Ç–∏–º–∞–ª—å–Ω—ã–π):")
    # –ó–∞–º–æ—Ä–æ–∂–µ–Ω–Ω—ã–µ —á–∞—Å—Ç–∏ –≤ FP16/INT4
    frozen_part = torch.nn.Linear(1024, 2560).to(device).to(torch.float16)
    frozen_part.requires_grad_(False)
    
    # –¢—Ä–µ–Ω–∏—Ä—É–µ–º–∞—è —á–∞—Å—Ç—å –≤ FP32
    trainable_part = torch.nn.Linear(1024, 512).to(device).to(torch.float32)
    
    memory_mixed = get_memory_mb() - base_memory
    print(f"   –ü–∞–º—è—Ç—å: {memory_mixed:.1f} MB")
    print(f"   –≠–∫–æ–Ω–æ–º–∏—è vs FP32: {(memory_fp32 - memory_mixed) / memory_fp32 * 100:.1f}%")
    
    # Cleanup
    del frozen_part, trainable_part
    torch.cuda.empty_cache() if device == "cuda" else None
    
    print("\n=== –†–µ–∫–æ–º–µ–Ω–¥–∞—Ü–∏–∏ ===")
    print("‚úÖ –ó–∞–º–æ—Ä–æ–∂–µ–Ω–Ω—ã–µ –º–æ–¥–µ–ª–∏: INT4/FP16")
    print("‚úÖ –¢—Ä–µ–Ω–∏—Ä—É–µ–º—ã–µ —Å–ª–æ–∏: FP32")
    print("‚úÖ –ò—Å–ø–æ–ª—å–∑—É–π—Ç–µ GradScaler")
    print("‚úÖ Gradient checkpointing –¥–ª—è –±–æ–ª—å—à–∏—Ö –º–æ–¥–µ–ª–µ–π")

analyze_memory_usage()

# üéØ –ò—Ç–æ–≥–æ–≤—ã–µ —Ä–µ–∫–æ–º–µ–Ω–¥–∞—Ü–∏–∏ –¥–ª—è –≤–∞—à–µ–≥–æ –ø—Ä–æ–µ–∫—Ç–∞

## –û–ø—Ç–∏–º–∞–ª—å–Ω–∞—è —Å—Ç—Ä–∞—Ç–µ–≥–∏—è –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏–∏:

### 1. **Gemma (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)**
```python
quantization_config=QuantoConfig(weights="int4")
torch_dtype=torch.bfloat16
```
- **INT4** –≤–µ—Å–∞ (—ç–∫–æ–Ω–æ–º–∏—è –ø–∞–º—è—Ç–∏ –≤ 8 —Ä–∞–∑!)
- **BF16** –∞–∫—Ç–∏–≤–∞—Ü–∏–∏ (—Å—Ç–∞–±–∏–ª—å–Ω–µ–µ FP16)

### 2. **Audio Encoder (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)** 
```python
torch_dtype=torch.bfloat16
param.requires_grad = False
```
- **BF16** (—ç–∫–æ–Ω–æ–º–∏—è –ø–∞–º—è—Ç–∏ –≤ 2 —Ä–∞–∑–∞)
- –ë–µ–∑ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤

### 3. **Projector (—Ç—Ä–µ–Ω–∏—Ä—É–µ—Ç—Å—è)**
```python
.to(torch.float32)  # –û–±—è–∑–∞—Ç–µ–ª—å–Ω–æ!
```
- **FP32** –¥–ª—è —Å—Ç–∞–±–∏–ª—å–Ω–æ—Å—Ç–∏
- –≠—Ç–æ –º–∞–ª–µ–Ω—å–∫–∏–π —Å–ª–æ–π, –ø–∞–º—è—Ç—å –Ω–µ –∫—Ä–∏—Ç–∏—á–Ω–∞

### 4. **–¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞**
```python
from torch.amp import GradScaler, autocast
scaler = GradScaler()
with autocast(device_type="cuda"):
    # forward pass
```

## –ü–æ—á–µ–º—É —É –≤–∞—Å –±—ã–ª–∏ NaN —Å FP16:

1. **–ú–∞–ª–µ–Ω—å–∫–∏–µ –≥—Ä–∞–¥–∏–µ–Ω—Ç—ã** ‚Üí FP16 –Ω–µ –º–æ–∂–µ—Ç –∏—Ö –ø—Ä–µ–¥—Å—Ç–∞–≤–∏—Ç—å ‚Üí 0
2. **0 –≥—Ä–∞–¥–∏–µ–Ω—Ç—ã** ‚Üí –¥–µ–ª–µ–Ω–∏–µ –Ω–∞ 0 –≤ Adam ‚Üí NaN  
3. **NaN –≤ loss** ‚Üí –∫—Ä–∞—Ö —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏

## –†–µ—à–µ–Ω–∏–µ:
- ‚úÖ **Projector –≤ FP32** (–≥—Ä–∞–¥–∏–µ–Ω—Ç—ã —Å—Ç–∞–±–∏–ª—å–Ω—ã)
- ‚úÖ **GradScaler** (–∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–æ–µ –º–∞—Å—à—Ç–∞–±–∏—Ä–æ–≤–∞–Ω–∏–µ)
- ‚úÖ **BF16 –≤–º–µ—Å—Ç–æ FP16** (–±–æ–ª—å—à–∏–π –¥–∏–∞–ø–∞–∑–æ–Ω)
- ‚úÖ **Gradient clipping** (–∑–∞—â–∏—Ç–∞ –æ—Ç –≤–∑—Ä—ã–≤–æ–≤)

In [None]:
# üöÄ –§–∏–Ω–∞–ª—å–Ω–∞—è –æ–ø—Ç–∏–º–∏–∑–∏—Ä–æ–≤–∞–Ω–Ω–∞—è —Ä–µ–∞–ª–∏–∑–∞—Ü–∏—è –¥–ª—è –ø—Ä–æ–¥–∞–∫—à–µ–Ω–∞

@dataclass
class OptimizedTrainingConfig:
    # –ú–æ–¥–µ–ª–∏
    GEMMA_MODEL_ID: str = "google/gemma-3-4b-pt"
    XLSR_MODEL_ID: str = "facebook/wav2vec2-xls-r-300m"
    
    # –¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞ —Å mixed precision
    EPOCHS: int = 50
    BATCH_SIZE: int = 8  # –ú–æ–∂–Ω–æ –±–æ–ª—å—à–µ –±–ª–∞–≥–æ–¥–∞—Ä—è –∫–≤–∞–Ω—Ç–∏–∑–∞—Ü–∏–∏
    LEARNING_RATE: float = 1e-4
    GRADIENT_CLIP: float = 1.0
    USE_MIXED_PRECISION: bool = True
    
    # –î–∞–Ω–Ω—ã–µ
    DATASET_PATH: str = "transcripts.jsonl"
    MAX_AUDIO_LENGTH: int = 16000 * 30
    MAX_TEXT_LENGTH: int = 512
    
    # –°–∏—Å—Ç–µ–º–∞
    DEVICE: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    SAVE_EVERY: int = 10
    TEXT_PREFIX: str = "–¢—Ä–∞–Ω—Å–∫—Ä–∏–ø—Ü–∏—è –∞—É–¥–∏–æ: "

class ProductionAudioGemmaModel(nn.Module):
    """–û–ø—Ç–∏–º–∏–∑–∏—Ä–æ–≤–∞–Ω–Ω–∞—è –º–æ–¥–µ–ª—å –¥–ª—è –ø—Ä–æ–¥–∞–∫—à–µ–Ω–∞"""
    
    def __init__(self, config: OptimizedTrainingConfig):
        super().__init__()
        self.config = config
        
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.GEMMA_MODEL_ID)
        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        
        # Gemma –≤ INT4 + BF16 (–º–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è —ç–∫–æ–Ω–æ–º–∏—è –ø–∞–º—è—Ç–∏)
        self.gemma = GemmaForCausalLM.from_pretrained(
            config.GEMMA_MODEL_ID,
            quantization_config=QuantoConfig(weights="int4"),
            device_map={"": config.DEVICE},
            torch_dtype=torch.bfloat16
        )
        
        # Audio encoder –≤ BF16 (–∑–∞–º–æ—Ä–æ–∂–µ–Ω)
        self.audio_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config.XLSR_MODEL_ID)
        self.audio_encoder = AutoModel.from_pretrained(
            config.XLSR_MODEL_ID,
            torch_dtype=torch.bfloat16
        ).to(config.DEVICE)
        
        # Projector –≤ FP32 (—Ç—Ä–µ–Ω–∏—Ä—É–µ—Ç—Å—è) - –ö–†–ò–¢–ò–ß–ù–û –¥–ª—è —Å—Ç–∞–±–∏–ª—å–Ω–æ—Å—Ç–∏!
        self.projector = AudioProjector(
            self.audio_encoder.config.hidden_size,
            self.gemma.config.hidden_size
        ).to(config.DEVICE).to(torch.float32)
        
        # –ó–∞–º–æ—Ä–∞–∂–∏–≤–∞–µ–º –≤—Å–µ –∫—Ä–æ–º–µ projector
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        for param in self.gemma.parameters():
            param.requires_grad = False
            
        print(f"‚úÖ –ú–æ–¥–µ–ª—å –∏–Ω–∏—Ü–∏–∞–ª–∏–∑–∏—Ä–æ–≤–∞–Ω–∞:")
        print(f"   Gemma: INT4 weights + BF16 activations")
        print(f"   Audio Encoder: BF16 (frozen)")
        print(f"   Projector: FP32 (trainable)")
    
    def forward(self, audio_values, input_ids, attention_mask):
        # –ò—Å–ø–æ–ª—å–∑—É–µ–º autocast –¥–ª—è –∞–≤—Ç–æ–º–∞—Ç–∏—á–µ—Å–∫–æ–≥–æ —É–ø—Ä–∞–≤–ª–µ–Ω–∏—è —Ç–∏–ø–∞–º–∏
        with autocast(device_type=self.config.DEVICE.split(':')[0], enabled=self.config.USE_MIXED_PRECISION):
            # Audio processing –≤ BF16
            audio_embeds = self.audio_encoder(audio_values.to(torch.bfloat16)).last_hidden_state
            
            # Projector –≤ FP32 –¥–ª—è —Ç–æ—á–Ω–æ—Å—Ç–∏ –≥—Ä–∞–¥–∏–µ–Ω—Ç–æ–≤
            projected_audio = self.projector(audio_embeds.to(torch.float32))
            
            # Text embeddings
            text_embeds = self.gemma.get_input_embeddings()(input_ids)
            
            # –ü—Ä–∏–≤–æ–¥–∏–º –∫ BF16 –¥–ª—è Gemma
            combined_embeds = torch.cat([
                projected_audio.to(torch.bfloat16),
                text_embeds
            ], dim=1)
            
            # Attention masks
            audio_mask = torch.ones(projected_audio.shape[:2], dtype=torch.long, device=self.config.DEVICE)
            combined_mask = torch.cat([audio_mask, attention_mask], dim=1)
            
            return self.gemma(inputs_embeds=combined_embeds, attention_mask=combined_mask).logits

# –î–µ–º–æ–Ω—Å—Ç—Ä–∞—Ü–∏—è
print("=== –°–æ–∑–¥–∞–Ω–∏–µ –æ–ø—Ç–∏–º–∏–∑–∏—Ä–æ–≤–∞–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏ ===")
opt_config = OptimizedTrainingConfig()
production_model = ProductionAudioGemmaModel(opt_config)

print(f"\nüéØ –ì–æ—Ç–æ–≤–æ! –¢–µ–ø–µ—Ä—å –≤–∞—à–∞ –º–æ–¥–µ–ª—å:")
print(f"   - –ò—Å–ø–æ–ª—å–∑—É–µ—Ç –Ω–∞ ~70% –º–µ–Ω—å—à–µ –ø–∞–º—è—Ç–∏")
print(f"   - –ù–µ –±—É–¥–µ—Ç –¥–∞–≤–∞—Ç—å NaN –≤ –≥—Ä–∞–¥–∏–µ–Ω—Ç–∞—Ö")
print(f"   - –ü–æ–¥–¥–µ—Ä–∂–∏–≤–∞–µ—Ç –±–æ–ª—å—à–∏–µ batch sizes")
print(f"   - –°–æ–≤–º–µ—Å—Ç–∏–º–∞ —Å mixed precision training")