In [None]:
# Install the third-party packages used by this notebook: jiwer (its `wer`
# function is imported in the next cell) and wandb (experiment tracking).
!pip install jiwer
!pip install wandb

import wandb
# Authenticate this session with Weights & Biases.
# NOTE(review): WANDB_API_KEY is not defined anywhere in the visible notebook;
# presumably it comes from a removed cell or a secrets store. On a fresh kernel
# this line raises NameError unless the name is defined first -- confirm.
wandb.login(key=WANDB_API_KEY)

Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading click-8.1.8-py3-none-any.whl (98 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.2/98.2 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, click, jiwer
  Attempting uninstall: click
    Found existing installation: click 8.1.7
    Uninstalling click-8.1.7:
      Successfully uninstalled click-8.1.7
Successfully in

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrabiaedayilmaz[0m ([33mheisenbugtachi[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import os
import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torchaudio
import librosa
import numpy as np
import re
from jiwer import wer
import wandb
import json

# Common parent of the three split folders; each split is a direct sub-folder.
_DATASET_ROOT = "/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish"
TRAIN_DIR, DEV_DIR, TEST_DIR = (
    f"{_DATASET_ROOT}/{split}" for split in ("train", "dev", "test")
)

# Punctuation and stray symbols removed from transcripts during normalisation.
chars_to_ignore = r'[,\?\.\!\-\;:"“%‘”�]'
# Turkish-specific lower-case letters folded to an ASCII counterpart.
chars_to_mapping = {"ğ": "g", "ı": "i", "ö": "o", "ü": "u", "ş": "s", "ç": "c"}

def normalize_text(text):
    """Lower-case, ASCII-fold and de-punctuate a transcript.

    Steps, in order: lower-case and trim, repair the dotted capital I,
    apply ``chars_to_mapping``, delete every character matched by
    ``chars_to_ignore``, then collapse whitespace runs to a single space.

    Args:
        text: The transcript to normalise.

    Returns:
        The normalised string, or ``""`` (after printing a warning) when
        ``text`` is ``None`` or not a ``str``.
    """
    if text is None or not isinstance(text, str):
        print(f"Warning: normalize_text received invalid input: {text}")
        return ""
    text = text.lower().strip()
    # Python lower-cases "İ" (U+0130) to "i" followed by U+0307 (combining dot
    # above). Without this repair "İstanbul" would normalise to "i̇stanbul"
    # and never compare equal to "istanbul".
    text = text.replace("i\u0307", "i")
    for src, dst in chars_to_mapping.items():
        text = text.replace(src, dst)
    text = re.sub(chars_to_ignore, '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

class AudioTextDataset(TorchDataset):
    """Dataset of (audio, transcript) pairs stored side by side in one folder.

    A recording ``<stem>.wav`` is paired with the transcript ``<stem>.txt``.
    Pairing is done by file stem, so a missing or extra file cannot shift the
    alignment of the remaining pairs.
    """

    def __init__(self, folder, max_samples=100):
        """Index the folder and keep at most ``max_samples`` matched pairs.

        Args:
            folder: Directory containing the ``.wav`` and ``.txt`` files.
            max_samples: Upper bound on the number of pairs kept (pairs are
                ordered by audio path). ``None`` keeps every pair.

        Raises:
            AssertionError: If the folder has no ``.wav`` files, no ``.txt``
                files, or no ``.wav``/``.txt`` pair sharing a stem survives
                the ``max_samples`` cut.
        """
        # List the directory once and index both file kinds by stem.
        wav_by_stem = {}
        txt_by_stem = {}
        for name in os.listdir(folder):
            if name.endswith('.wav'):
                wav_by_stem[name[:-len('.wav')]] = os.path.join(folder, name)
            elif name.endswith('.txt'):
                txt_by_stem[name[:-len('.txt')]] = os.path.join(folder, name)

        all_audio_files = sorted(wav_by_stem.values())
        all_text_files = sorted(txt_by_stem.values())

        # Only stems present on both sides form a usable sample.
        paired_stems = sorted(wav_by_stem.keys() & txt_by_stem.keys(), key=lambda s: wav_by_stem[s])
        orphan_stems = sorted(wav_by_stem.keys() ^ txt_by_stem.keys())
        if orphan_stems:
            print(f"Warning: {len(orphan_stems)} file(s) without a matching .wav/.txt partner are skipped: {orphan_stems[:5]}")

        selected_stems = paired_stems[:max_samples]
        self.audio_files = [wav_by_stem[stem] for stem in selected_stems]
        self.text_files = [txt_by_stem[stem] for stem in selected_stems]

        print(f"Folder: {folder}")
        print(f"Found {len(all_audio_files)} audio files, using {len(self.audio_files)}: {self.audio_files[:5]}")
        print(f"Found {len(all_text_files)} text files, using {len(self.text_files)}: {self.text_files[:5]}")
        assert len(all_audio_files) > 0, f"No .wav files in {folder}"
        assert len(all_text_files) > 0, f"No .txt files in {folder}"
        assert len(self.audio_files) > 0, "Mismatch between audio and text files"

    def __len__(self):
        """Return the number of (audio, transcript) pairs kept."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with keys ``speech`` (1-D numpy array resampled to 16 kHz),
            ``sentence`` (normalised transcript), ``audio_file`` and
            ``text_file``. If either file cannot be read, the error is printed
            and ``speech`` is an empty array with ``sentence`` set to ``""``,
            so the caller can detect and drop the sample.
        """
        audio_path = self.audio_files[idx]
        text_path = self.text_files[idx]
        try:
            waveform, sampling_rate = torchaudio.load(audio_path)
            # torchaudio yields (channels, samples) by default. squeeze() alone
            # would leave a stereo clip 2-D, so down-mix to mono first.
            if waveform.dim() > 1 and waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0)
            speech_array = librosa.resample(waveform.squeeze().numpy(), orig_sr=sampling_rate, target_sr=16000)
        except Exception as e:
            print(f"Error loading audio {audio_path}: {e}")
            return {"speech": np.array([]), "sentence": "", "audio_file": audio_path, "text_file": text_path}

        try:
            with open(text_path, 'r', encoding='utf-8') as f:
                transcription = normalize_text(f.read().strip())
        except Exception as e:
            print(f"Error loading text {text_path}: {e}")
            return {"speech": np.array([]), "sentence": "", "audio_file": audio_path, "text_file": text_path}

        return {"speech": speech_array, "sentence": transcription, "audio_file": audio_path, "text_file": text_path}

# Per-split sample caps. Defined once so the datasets that are actually built
# and the configuration logged to W&B cannot drift apart.
MAX_SAMPLES_TRAIN = 5
MAX_SAMPLES_DEV = 2
MAX_SAMPLES_TEST = 2

# load datasets
print("Loading datasets...")
train_dataset = AudioTextDataset(TRAIN_DIR, max_samples=MAX_SAMPLES_TRAIN)
dev_dataset = AudioTextDataset(DEV_DIR, max_samples=MAX_SAMPLES_DEV)
test_dataset = AudioTextDataset(TEST_DIR, max_samples=MAX_SAMPLES_TEST)

# init W&B run; the config records the settings of this run
wandb.init(project="turkish-asr-whisper", config={
    "model": "openai/whisper-small",
    "num_epochs": 10,
    "learning_rate": 1e-5,
    "batch_size": 8,
    "gradient_accumulation_steps": 2,
    "max_samples_train": MAX_SAMPLES_TRAIN,
    "max_samples_dev": MAX_SAMPLES_DEV,
    "max_samples_test": MAX_SAMPLES_TEST
})

# init model and processor
print("Initializing Whisper model and processor...")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Load the processor with language/task set so the tokenizer prefixes label
# sequences with the Turkish/transcribe prompt tokens. Without this, training
# labels carry no language/task tokens while generation is forced to emit them.
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="tr", task="transcribe")

# set lang to Tr and task to transcribe
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="tr", task="transcribe")

# use gpu when one is available, otherwise fall back to cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def data_collator(batch):
    """Turn a list of dataset items into one model-ready batch.

    Items whose ``speech`` array is empty (the dataset's load-failure marker)
    or whose ``sentence`` is not a string are dropped.

    Args:
        batch: List of dicts with keys ``speech``, ``sentence``,
            ``audio_file`` and ``text_file``.

    Returns:
        A dict with ``input_features`` and ``labels`` tensors (moved to the
        module-level ``device``) plus the ``audio_files`` / ``text_files``
        path lists of the kept items. Returns ``None`` when no item is valid
        or when feature extraction / tokenisation raises, so callers must
        skip ``None`` batches.
    """
    valid_items = [item for item in batch if item["speech"].size > 0 and isinstance(item["sentence"], str)]
    if not valid_items:
        print("Warning: No valid items in batch, skipping...")
        return None

    try:
        # audio into mel spectrograms
        # The feature extractor's default padding mode pads/truncates the raw
        # audio to 30 s before computing log-mel features. That yields the
        # fixed-length input Whisper expects and keeps padded frames on the
        # same scale as real silence, unlike zero-filling the feature tensor.
        input_features = processor.feature_extractor(
            [item["speech"] for item in valid_items],
            sampling_rate=16000,
            return_tensors="pt",
        ).input_features

        tokenized = processor.tokenizer(
            [item["sentence"] for item in valid_items],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=448
        )

        # replace padding with -100 for loss calculation
        # Use the attention mask, not the token id: Whisper's pad token is the
        # same id as its end-of-text token, so masking by id would also hide
        # the real EOS and the model would never be trained to stop.
        labels = tokenized.input_ids.masked_fill(tokenized.attention_mask.ne(1), -100)

        # The model prepends the decoder start token itself when it shifts the
        # labels right; drop it here to avoid a duplicated start token.
        start_token_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        if (labels[:, 0] == start_token_id).all().item():
            labels = labels[:, 1:]

        # NOTE: tensors are moved to `device` inside the collate function, which
        # only works with the DataLoader's default in-process loading.
        return {
            "input_features": input_features.to(device),
            "labels": labels.to(device),
            "audio_files": [item["audio_file"] for item in valid_items],
            "text_files": [item["text_file"] for item in valid_items]
        }
    except Exception as e:
        print(f"Error in data_collator: {e}")
        print(f"Sentences passed to tokenizer: {[item['sentence'] for item in valid_items]}")
        return None

# Batch iterators: only the training split is shuffled.
batch_size = 8
train_loader, eval_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle, collate_fn=data_collator)
    for split, reshuffle in (
        (train_dataset, True),
        (dev_dataset, False),
        (test_dataset, False),
    )
)

# Optimisation settings (same values as the entries logged in the W&B config).
num_epochs, gradient_accumulation_steps = 10, 2
learning_rate = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# eval
def evaluate(loader, model, processor, device, dataset_name="eval"):
    """Transcribe every batch of ``loader`` and return the mean per-batch WER.

    Predictions and references are both passed through ``normalize_text``
    before scoring. The references were normalised when the dataset was
    built, so scoring raw model output (cased, punctuated, with Turkish
    letters) against them would count formatting differences as word errors.

    Args:
        loader: Iterable of collated batches; ``None`` batches are skipped.
        model: Model exposing ``generate``.
        processor: Processor used to decode token ids to text.
        device: Device the input features are moved to.
        dataset_name: Label used in the printed messages.

    Returns:
        Mean of the per-batch WER values, or ``float('inf')`` when no batch
        was scored. Note this is an average over batches, not a corpus-level
        WER.
    """
    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the model into training mode.
    was_training = model.training
    model.eval()
    total_wer = 0
    num_batches = 0
    pad_token_id = processor.tokenizer.pad_token_id
    with torch.no_grad():
        for batch in loader:
            if batch is None:
                print(f"Skipping {dataset_name} batch due to invalid items")
                continue
            input_features = batch["input_features"].to(device)
            outputs = model.generate(
                input_features,
                forced_decoder_ids=processor.get_decoder_prompt_ids(language="tr", task="transcribe"),
                max_length=448
            )
            pred_str = [normalize_text(p) for p in processor.batch_decode(outputs, skip_special_tokens=True)]

            # -100 marks ignored label positions and is not a vocabulary id;
            # map it back to the pad token before decoding.
            label_ids = batch["labels"].masked_fill(batch["labels"] == -100, pad_token_id)
            label_str = [normalize_text(t) for t in processor.batch_decode(label_ids, skip_special_tokens=True)]

            print(f"{dataset_name} Batch - Predicted: {pred_str}")
            print(f"{dataset_name} Batch - Ground Truth: {label_str}")

            batch_wer = wer(label_str, pred_str)
            total_wer += batch_wer
            num_batches += 1
    avg_wer = total_wer / num_batches if num_batches > 0 else float('inf')
    if was_training:
        model.train()
    return avg_wer

# train loop
print("Starting fine-tuning...")
model.train()


def _apply_accumulated_gradients():
    """Clip the accumulated gradients once, take one optimizer step, then reset them."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()


global_step = 0  # optimizer updates performed so far, across all epochs
for epoch in range(num_epochs):
    total_loss = 0.0
    processed_batches = 0      # batches that actually produced a loss this epoch
    pending_micro_batches = 0  # backward passes since the last optimizer step
    batch_loss = None
    optimizer.zero_grad()

    for i, batch in enumerate(train_loader):
        if batch is None:
            print(f"Skipping batch {i} due to invalid items")
            continue

        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_features, labels=labels)
        batch_loss = outputs.loss.item()
        total_loss += batch_loss
        processed_batches += 1

        # Scale so the gradients summed over one accumulation window average out.
        (outputs.loss / gradient_accumulation_steps).backward()
        pending_micro_batches += 1

        # Count real backward passes rather than the loader index, so skipped
        # batches do not shift the accumulation window.
        if pending_micro_batches == gradient_accumulation_steps:
            _apply_accumulated_gradients()
            pending_micro_batches = 0
            global_step += 1
            wandb.log({"train/loss": batch_loss, "train/step": global_step})

    # Flush a trailing partial window. Without this, an epoch with fewer
    # usable batches than gradient_accumulation_steps never updates the
    # weights, and leftover gradients are discarded at the next epoch.
    if pending_micro_batches > 0:
        _apply_accumulated_gradients()
        global_step += 1
        wandb.log({"train/loss": batch_loss, "train/step": global_step})

    # Mean unscaled loss over the batches that were actually processed.
    avg_loss = total_loss / processed_batches if processed_batches > 0 else float('nan')
    print(f"Epoch {epoch+1}/{num_epochs} completed, Average Loss: {avg_loss:.4f}")

    # eval dev set
    dev_wer = evaluate(eval_loader, model, processor, device, "dev")
    print(f"Epoch {epoch+1}/{num_epochs}, Dev WER: {dev_wer:.4f}")

    # eval test set
    test_wer = evaluate(test_loader, model, processor, device, "test")
    print(f"Epoch {epoch+1}/{num_epochs}, Test WER: {test_wer:.4f}")

    # log metrics to W&B
    wandb.log({
        "train/avg_loss": avg_loss,
        "dev/avg_wer": dev_wer,
        "test/avg_wer": test_wer,
        "epoch": epoch + 1
    })

# final eval on dev set
print("Final evaluation on dev set...")
dev_wer = evaluate(eval_loader, model, processor, device, "dev")
print(f"Final Average WER on dev set: {dev_wer:.4f}")
wandb.log({"dev/final_avg_wer": dev_wer})

# final eval on test set: same scoring as above, but also records one row per
# sample for the W&B table and the JSON report
print("Final evaluation on test set...")
model.eval()
total_wer = 0
num_batches = 0
test_results = []
test_table = wandb.Table(columns=["audio_file", "text_file", "ground_truth", "prediction", "wer"])
pad_token_id = processor.tokenizer.pad_token_id

with torch.no_grad():
    for batch in test_loader:
        if batch is None:
            print("Skipping test batch due to invalid items")
            continue

        input_features = batch["input_features"].to(device)
        audio_files = batch["audio_files"]
        text_files = batch["text_files"]

        outputs = model.generate(
            input_features,
            forced_decoder_ids=processor.get_decoder_prompt_ids(language="tr", task="transcribe"),
            max_length=448
        )
        # Normalise predictions the same way the references were normalised;
        # otherwise casing, punctuation and Turkish letters count as errors.
        pred_str = [normalize_text(p) for p in processor.batch_decode(outputs, skip_special_tokens=True)]

        # -100 marks ignored label positions and is not a vocabulary id;
        # map it back to the pad token before decoding.
        label_ids = batch["labels"].masked_fill(batch["labels"] == -100, pad_token_id)
        label_str = [normalize_text(t) for t in processor.batch_decode(label_ids, skip_special_tokens=True)]

        print(f"Test Batch - Predicted: {pred_str}")
        print(f"Test Batch - Ground Truth: {label_str}")

        batch_wer = wer(label_str, pred_str)
        total_wer += batch_wer
        num_batches += 1

        for i in range(len(pred_str)):
            # With a single sample the batch WER already is the sample WER.
            sample_wer = batch_wer if len(pred_str) == 1 else wer([label_str[i]], [pred_str[i]])
            test_results.append({
                "audio_file": audio_files[i],
                "text_file": text_files[i],
                "ground_truth": label_str[i],
                "prediction": pred_str[i],
                "wer": sample_wer
            })
            test_table.add_data(
                audio_files[i],
                text_files[i],
                label_str[i],
                pred_str[i],
                sample_wer
            )

        wandb.log({"test/batch_wer": batch_wer})

# Mean of per-batch WER values (inf when no batch could be scored).
avg_wer = total_wer / num_batches if num_batches > 0 else float('inf')
print(f"Final Average WER on test set: {avg_wer:.4f}")
wandb.log({"test/final_avg_wer": avg_wer})

# Attach the per-sample prediction table to the W&B run.
wandb.log({"test/predictions": test_table})

# Write the per-sample rows and the summary numbers to a JSON report.
results_dict = dict(
    test_results=test_results,
    average_wer=avg_wer,
    num_samples=len(test_results),
    num_batches=num_batches,
)
with open("test_results.json", "w", encoding='utf-8') as report_file:
    json.dump(results_dict, report_file, ensure_ascii=False, indent=4)
print("Test results saved to 'test_results.json'")

# Store the fine-tuned model and its processor in the same directory.
print("Saving fine-tuned model and processor...")
output_dir = "./model/whisper"
for artifact in (model, processor):
    artifact.save_pretrained(output_dir)

# Close the W&B run.
wandb.finish()
print("Fine-tuning complete!")

Loading datasets...
Folder: /kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train
Found 27 audio files, using 5: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00001.wav', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00002.wav', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00003.wav', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00004.wav', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00005.wav']
Found 27 text files, using 5: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00001.txt', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00002.txt', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00003.txt', '/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/0000

[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250331_115842-9yafpdl0[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mfiery-forest-6[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/heisenbugtachi/turkish-asr-whisper[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/heisenbugtachi/turkish-asr-whisper/runs/9yafpdl0[0m


Initializing Whisper model and processor...


config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.87k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Starting fine-tuning...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch 1/10 completed, Average Loss: 4.1768


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


dev Batch - Predicted: [' Hastada, Mehmet Yılmaz, yaş 62 cinsiyet erkek. Ameliyat tarihi 15 Mart 2022. Ameliyat notu.', ' Hasta adı Murat Yılmaz. Yaşı 45, cinsiyeti erkek. Ameliyat Tarihi 15 Ocak 2022. Ameliyat notu, Murat Yılmaz isimli 45 yaşında erkek hasta, 15 Ocak 2022 tarihinde, Saadizinde oluşan kronik ağrı ve hareket kısıtlılığı şikayetleri üzerine klinimize başvurmuştur. Muayene ve radyolojik değerlendirmeler sonucunda hasta, Saadizinde ilerlemiş osteoartrit teşhisiyle, ortopedi klinimize yatırıldı.']
dev Batch - Ground Truth: ["hasta adi mehmet yilmaz yas 62 cinsiyet erkek ameliyat tarihi 15 mart 2022 ameliyat notu hasta mehmet yilmaz akut koroner sendrom tanisiyla 15 mart 2022 tarihinde kardiyoloji klinigimize basvurmustur ekg bulgusu ve yapilan kan testleri sonucunda akut miyokard enfarktusu tanisi konulmustur hasta acil olarak koroner anjiyografiye alindi sag koroner arterde 90 darlik tespit edilerek primer perkutan koroner girisim (pci) gerceklestirildi basarili bir sekild

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:        dev/avg_wer ▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  dev/final_avg_wer ▁
[34m[1mwandb[0m:              epoch ▁▂▃▃▄▅▆▆▇█
[34m[1mwandb[0m:       test/avg_wer ▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:     test/batch_wer ▁
[34m[1mwandb[0m: test/final_avg_wer ▁
[34m[1mwandb[0m:     train/avg_loss ▆▃▆▃▁▁█▃▆▆
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:        dev/avg_wer 0.94044
[34m[1mwandb[0m:  dev/final_avg_wer 0.94044
[34m[1mwandb[0m:              epoch 10
[34m[1mwandb[0m:       test/avg_wer 0.9172
[34m[1mwandb[0m:     test/batch_wer 0.9172
[34m[1mwandb[0m: test/final_avg_wer 0.9172
[34m[1mwandb[0m:     train/avg_loss 4.17681
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mfiery-forest-6[0m at: [34m[4mhttps://wandb.ai/heisenbugtachi/turkish-asr-whisper/runs/9y

Fine-tuning complete!
