In [None]:
# Install the notebook's dependencies from requirements.txt into the running
# kernel's environment (%pip, unlike !pip, targets the kernel's interpreter);
# -q keeps pip's output quiet.
%pip install -r requirements.txt -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import json
import os
import numpy as np
import soundfile as sf
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Wav2Vec2FeatureExtractor, Wav2Vec2Model
from huggingface_hub import login
import jiwer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sklearn.model_selection import train_test_split

In [None]:
def to_bfloat16(tensor):
    """Return a bfloat16 copy of ``tensor``, logging its original dtype and shape first."""
    source_dtype, source_shape = tensor.dtype, tensor.shape
    print(f"Converting tensor from {source_dtype} to bfloat16, shape: {source_shape}")
    converted = tensor.to(dtype=torch.bfloat16)
    return converted

def to_fp32(tensor):
    """Return a float32 copy of ``tensor``, logging its original dtype and shape first."""
    source_dtype, source_shape = tensor.dtype, tensor.shape
    print(f"Converting tensor from {source_dtype} to fp32, shape: {source_shape}")
    converted = tensor.to(dtype=torch.float32)
    return converted

def sync_model_dtype(model, target_dtype):
    """Cast every parameter of ``model`` to ``target_dtype`` in place.

    Only tensors yielded by ``model.parameters()`` are converted; buffers keep
    their dtype. The same ``model`` object is returned for chaining.
    """
    print(f"Syncing model to {target_dtype}")
    for weight in model.parameters():
        converted = weight.data.to(target_dtype)
        weight.data = converted
    return model

# --- Run configuration -------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_id = "google/gemma-3-4b-pt"
audio_model_name = "facebook/wav2vec2-base"
batch_size = 4
num_epochs = 3
learning_rate = 1e-4
seed = 42

# Projector widths. input_dim must equal the audio encoder's hidden size
# (768 for wav2vec2-base). output_dim must equal the LLM's token-embedding
# width, because the projected audio vector is concatenated with the prompt's
# token embeddings. gemma-3-4b uses 2560, so the previous 3072 made that
# concatenation fail.
input_dim = 768
output_dim = 2560

# Seed torch's generators (all devices) so projector initialisation and
# DataLoader shuffling are reproducible across runs.
torch.manual_seed(seed)

# 4-bit NF4 quantization with bf16 compute for the frozen LLM.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Credentials come from the environment, never from the notebook itself.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)

print(f"Using device: {device}")
print(f"Config: batch_size={batch_size}, epochs={num_epochs}, lr={learning_rate}")
print(f"Audio model: {audio_model_name}, LLM: {model_id}")
print(f"Projector: {input_dim} -> {output_dim}")

In [None]:
# LLM tokenizer, loaded with the same credentials the model load below uses.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
if tokenizer.pad_token_id is None:
    # Batching needs a pad id; reuse EOS when the tokenizer defines none.
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Frozen LLM: quantized per `quantization_config`, placement chosen by device_map="auto".
gemma_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto",
    token=hf_token
)
gemma_model.eval()
for param in gemma_model.parameters():
    param.requires_grad = False

# Frozen speech encoder in bf16 on the working device; only the projector is trained.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(audio_model_name)
wav2vec2 = Wav2Vec2Model.from_pretrained(audio_model_name)
wav2vec2 = wav2vec2.to(torch.bfloat16).to(device)
wav2vec2.eval()
for param in wav2vec2.parameters():
    param.requires_grad = False

# NOTE(review): 4-bit bitsandbytes weights are stored packed, so numel() likely
# under-counts Gemma's true parameter count — confirm before quoting this figure.
print(f"Gemma parameters: {sum(p.numel() for p in gemma_model.parameters()):,}")
print(f"Wav2vec2 parameters: {sum(p.numel() for p in wav2vec2.parameters()):,}")

In [None]:
class AudioProjector(nn.Module):
    """MLP mapping audio-encoder features into the LLM embedding space.

    Layout: LayerNorm -> Linear -> ReLU -> Linear -> LayerNorm.

    The projection always runs in float32. Autocast is disabled inside
    ``forward``, so an enclosing bf16/fp16 autocast region cannot downcast the
    Linear layers. The result is cast back to the input's dtype.

    Args:
        input_dim: Width of the incoming features (last dimension of ``x``).
        output_dim: Width of the produced embeddings.
        hidden_dim: Width of the intermediate layer (default 1024).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=1024):
        super().__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim)
        )

    def forward(self, x):
        """Project ``x`` in float32 and return the result in ``x``'s original dtype."""
        original_dtype = x.dtype
        if next(self.proj.parameters()).dtype != torch.float32:
            # Module.float() converts the submodule in place; no reassignment needed.
            self.proj.float()
        # Without this, an outer autocast would still run the Linear layers in
        # reduced precision even though the input was upcast to float32.
        with torch.autocast(device_type=x.device.type, enabled=False):
            output_fp32 = self.proj(x.to(torch.float32))
        return output_fp32.to(original_dtype)

# Status messages (restored from mis-encoded text). In English:
# "AudioProjector class defined with correct data-type handling" and
# "Input: any dtype -> compute: FP32 -> output: original dtype".
print("Класс AudioProjector определен с правильной обработкой типов данных")
print("Входные данные: любой тип -> Вычисления: FP32 -> Выход: исходный тип")

In [None]:
# Size the projector from the loaded models so its output can be concatenated
# with the LLM's own token embeddings. A configured width that disagrees with
# the checkpoint would make torch.cat fail later.
llm_embed_layer = gemma_model.get_input_embeddings()
llm_hidden_dim = llm_embed_layer.embedding_dim
audio_hidden_dim = wav2vec2.config.hidden_size
if (input_dim, output_dim) != (audio_hidden_dim, llm_hidden_dim):
    print(f"Configured projector {input_dim} -> {output_dim} does not match the models; "
          f"using {audio_hidden_dim} -> {llm_hidden_dim}")

projector = AudioProjector(audio_hidden_dim, llm_hidden_dim).to(device).float()
optimizer = optim.Adam(projector.parameters(), lr=learning_rate)
# Loss scaling targets fp16. Under bf16 autocast it is not required, but the
# scaler still skips optimizer steps whose gradients contain inf/NaN.
scaler = GradScaler()
# NOTE(review): the training loop relies on the model's internal loss (via
# `labels`); this criterion appears unused — confirm before removing.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

# Text prompt placed before the audio token ("Audio transcription: ").
prefix = "Транскрипция аудио: "
prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids.to(llm_embed_layer.weight.device)
with torch.no_grad():
    # Move to the working device so it concatenates with the projector output.
    prefix_embeds = llm_embed_layer(prefix_ids).to(device=device, dtype=torch.bfloat16)

In [None]:
class AudioTextDataset(Dataset):
    """Map-style dataset of (audio file, transcript) pairs.

    Each record in ``data`` is read via its ``"audio_path"`` and
    ``"speaker_text"`` keys.

    Audio is loaded with torchaudio, resampled to
    ``feature_extractor.sampling_rate`` when needed, averaged to mono, and
    converted to wav2vec2 ``input_values``.

    The transcript is tokenized without special tokens, because it is appended
    after the prompt prefix and the audio token rather than starting a
    sequence. It is then terminated with EOS (when the tokenizer defines one),
    so the model learns where a transcript ends.

    Args:
        data: Indexable sequence of record dicts.
        tokenizer: LLM tokenizer.
        feature_extractor: wav2vec2 feature extractor.
        max_length: Maximum transcript length in tokens, EOS included.
    """

    def __init__(self, data, tokenizer, feature_extractor, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with ``input_values``, ``input_ids`` and ``attention_mask`` tensors."""
        item = self.data[idx]
        speaker_text = item["speaker_text"]
        target_sr = self.feature_extractor.sampling_rate

        waveform, sr = torchaudio.load(item["audio_path"])
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
        if waveform.shape[0] > 1:
            # Average all channels into one.
            waveform = waveform.mean(dim=0, keepdim=True)
        inputs = self.feature_extractor(
            # squeeze(0) (not squeeze()) keeps a 1-D array even for a 1-sample clip.
            waveform.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt"
        )

        eos_id = self.tokenizer.eos_token_id
        # Reserve one slot for EOS so the total never exceeds max_length.
        text_budget = self.max_length - 1 if eos_id is not None else self.max_length
        token_ids = self.tokenizer(
            speaker_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max(text_budget, 1)
        ).input_ids
        if eos_id is not None:
            token_ids = token_ids + [eos_id]
        input_ids = torch.tensor(token_ids, dtype=torch.long)

        return {
            "input_values": inputs.input_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids)
        }

def _eval_pad_id(tokenizer):
    """Return the tokenizer's pad id, or 0 when it defines none."""
    return tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0


def _eval_batch_inputs(model, projector, wav2vec2, batch, prefix_embeds, pad_id, device):
    """Build teacher-forced inputs laid out as [prompt prefix | audio token | transcript].

    Returns ``(full_embeds, full_mask, labels, text_ids, text_mask, context_len)``.
    ``context_len`` is the number of prefix + audio positions that precede the
    transcript. Label positions outside real transcript tokens are -100.
    """
    input_values = batch["input_values"].to(device, dtype=torch.bfloat16)
    input_ids = batch["input_ids"].to(device)
    text_mask = batch.get("attention_mask")
    if text_mask is None:
        # Without an explicit mask, treat -100 entries as padding.
        text_mask = (input_ids != -100).long()
    else:
        text_mask = text_mask.to(device)
    # -100 is a label sentinel, not a vocabulary id; swap it out before embedding.
    text_ids = input_ids.masked_fill(input_ids == -100, pad_id)

    audio_embeds = wav2vec2(input_values).last_hidden_state.mean(dim=1)
    projected_audio = projector(audio_embeds)
    n = projected_audio.size(0)
    embed_layer = model.get_input_embeddings()
    text_embeds = embed_layer(text_ids.to(embed_layer.weight.device))
    target = dict(device=prefix_embeds.device, dtype=prefix_embeds.dtype)
    full_embeds = torch.cat([
        prefix_embeds.expand(n, -1, -1),
        projected_audio.unsqueeze(1).to(**target),
        text_embeds.to(**target),
    ], dim=1)

    context_len = prefix_embeds.size(1) + 1
    full_mask = torch.cat([text_mask.new_ones((n, context_len)), text_mask], dim=1)
    labels = torch.cat([
        torch.full((n, context_len), -100, dtype=torch.long, device=text_ids.device),
        text_ids.masked_fill(text_mask == 0, -100),
    ], dim=1)
    return full_embeds, full_mask.to(full_embeds.device), labels, text_ids, text_mask, context_len


def evaluate_with_metrics(model, projector, wav2vec2, dataloader, tokenizer, prefix_embeds, device):
    """Compute validation loss, perplexity, WER, BLEU and ROUGE-1/2/L.

    The transcript is appended after the prompt prefix and projected audio
    token, so ``labels`` line up with the input length.

    Predictions are the teacher-forced argmax at each transcript position:
    every token is conditioned on the gold history, not free-running
    generation, so text metrics are optimistic.

    Text metrics are averaged over samples with a non-empty reference. An empty
    prediction counts as WER 1.0 and zero BLEU/ROUGE instead of being skipped.

    Returns:
        Dict with keys ``loss``, ``perplexity``, ``wer``, ``bleu``, ``rouge_1``,
        ``rouge_2``, ``rouge_l``. Loss and perplexity are NaN for an empty loader.
    """
    model.eval()
    projector.eval()
    wav2vec2.eval()
    pad_id = _eval_pad_id(tokenizer)
    smooth = SmoothingFunction().method1
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    total_loss, n_batches = 0.0, 0
    total_wer, total_bleu = 0.0, 0.0
    total_rouge_1, total_rouge_2, total_rouge_l = 0.0, 0.0, 0.0
    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            with autocast(dtype=torch.bfloat16):
                full_embeds, full_mask, labels, text_ids, text_mask, context_len = _eval_batch_inputs(
                    model, projector, wav2vec2, batch, prefix_embeds, pad_id, device
                )
                outputs = model(inputs_embeds=full_embeds, attention_mask=full_mask, labels=labels)
            total_loss += outputs.loss.item()
            n_batches += 1

            # Logits at position j predict token j+1, so transcript tokens
            # (positions context_len..end) come from positions context_len-1..end-1.
            text_logits = outputs.logits[:, context_len - 1:-1, :]
            pred_ids = torch.argmax(text_logits, dim=-1).to(text_ids.device)
            for i in range(pred_ids.size(0)):
                valid = text_mask[i].bool()
                ref_text = tokenizer.decode(text_ids[i][valid], skip_special_tokens=True).strip()
                if not ref_text:
                    continue
                pred_text = tokenizer.decode(pred_ids[i][valid], skip_special_tokens=True).strip()
                if pred_text:
                    total_wer += jiwer.wer(ref_text, pred_text)
                    total_bleu += sentence_bleu([ref_text.split()], pred_text.split(), smoothing_function=smooth)
                    rouge_scores = rouge_scorer_obj.score(ref_text, pred_text)
                    total_rouge_1 += rouge_scores['rouge1'].fmeasure
                    total_rouge_2 += rouge_scores['rouge2'].fmeasure
                    total_rouge_l += rouge_scores['rougeL'].fmeasure
                else:
                    # An empty hypothesis deletes every reference word; overlap scores stay 0.
                    total_wer += 1.0
                count += 1

    avg_loss = total_loss / n_batches if n_batches else float("nan")
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    avg_wer = total_wer / count if count > 0 else 0.0
    avg_bleu = total_bleu / count if count > 0 else 0.0
    avg_rouge_1 = total_rouge_1 / count if count > 0 else 0.0
    avg_rouge_2 = total_rouge_2 / count if count > 0 else 0.0
    avg_rouge_l = total_rouge_l / count if count > 0 else 0.0
    return {
        'loss': avg_loss, 'perplexity': perplexity, 'wer': avg_wer, 'bleu': avg_bleu,
        'rouge_1': avg_rouge_1, 'rouge_2': avg_rouge_2, 'rouge_l': avg_rouge_l
    }

In [None]:
def collate_fn(batch):
    """Right-pad a list of AudioTextDataset items into batch tensors.

    ``input_values`` are padded with 0.0 and ``input_ids`` with the tokenizer's
    pad id (0 if none). ``attention_mask`` is padded with 0, so padded
    transcript positions can be excluded from the loss via the mask.
    """
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    return {
        "input_values": pad_sequence([item["input_values"] for item in batch],
                                     batch_first=True, padding_value=0.0),
        "input_ids": pad_sequence([item["input_ids"] for item in batch],
                                  batch_first=True, padding_value=pad_id),
        "attention_mask": pad_sequence([item["attention_mask"] for item in batch],
                                       batch_first=True, padding_value=0),
    }


jsonl_path = "transcripts.jsonl"
with open(jsonl_path, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
    all_data = [json.loads(line) for line in f if line.strip()]

train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
train_dataset = AudioTextDataset(train_data, tokenizer, feature_extractor)
val_dataset = AudioTextDataset(val_data, tokenizer, feature_extractor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"Data loaded: {len(train_data)} train, {len(val_data)} val samples.")

def _teacher_forced_loss(batch):
    """Return the LM loss for one batch laid out as [prompt prefix | audio token | transcript].

    Only real transcript positions carry labels. Prefix, audio and padded
    transcript positions are -100, so the model's internal cross-entropy
    ignores them. Gradients flow only into the projector, since all other
    parameters are frozen.
    """
    input_values = batch["input_values"].to(device, dtype=torch.bfloat16)
    input_ids = batch["input_ids"].to(device)
    text_mask = batch.get("attention_mask")
    text_mask = (input_ids != -100).long() if text_mask is None else text_mask.to(device)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    # -100 is a label sentinel, not a vocabulary id; swap it out before embedding.
    text_ids = input_ids.masked_fill(input_ids == -100, pad_id)

    with autocast(dtype=torch.bfloat16):
        audio_embeds = wav2vec2(input_values).last_hidden_state.mean(dim=1)
        projected_audio = projector(audio_embeds)
        n = projected_audio.size(0)
        embed_layer = gemma_model.get_input_embeddings()
        text_embeds = embed_layer(text_ids.to(embed_layer.weight.device))
        target = dict(device=prefix_embeds.device, dtype=prefix_embeds.dtype)
        full_embeds = torch.cat([
            prefix_embeds.expand(n, -1, -1),
            projected_audio.unsqueeze(1).to(**target),
            text_embeds.to(**target),
        ], dim=1)

        context_len = prefix_embeds.size(1) + 1
        full_mask = torch.cat([text_mask.new_ones((n, context_len)), text_mask], dim=1)
        labels = torch.cat([
            torch.full((n, context_len), -100, dtype=torch.long, device=text_ids.device),
            text_ids.masked_fill(text_mask == 0, -100),
        ], dim=1)
        outputs = gemma_model(
            inputs_embeds=full_embeds,
            attention_mask=full_mask.to(full_embeds.device),
            labels=labels
        )
    return outputs.loss


for epoch in range(num_epochs):
    print(f"\n--- EPOCH {epoch+1}/{num_epochs} ---")
    projector.train()
    wav2vec2.eval()  # Wav2Vec2 is frozen (not trained)
    gemma_model.eval()  # Gemma is frozen (not trained)

    epoch_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        optimizer.zero_grad()
        loss = _teacher_forced_loss(batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when the training split is empty.
    avg_train_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

    val_metrics = evaluate_with_metrics(gemma_model, projector, wav2vec2, val_loader, tokenizer, prefix_embeds, device)
    print(f"Validation - Loss: {val_metrics['loss']:.4f}, Perplexity: {val_metrics['perplexity']:.4f}, WER: {val_metrics['wer']:.4f}")

    checkpoint_path = f"checkpoint_epoch_{epoch+1}.pt"
    torch.save({
        'epoch': epoch+1,
        'projector_state_dict': projector.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Saved so a resumed run continues with the same loss-scale state.
        'scaler_state_dict': scaler.state_dict(),
        'train_loss': avg_train_loss,
        'val_metrics': val_metrics
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")

print("\n🎉 Training completed!")