In [1]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-3.1.0-py3-none-any.whl.metadata (2.6 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.1.8-py3-none-any.whl.metadata (2.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-3.1.0-py3-none-any.whl (22 kB)
Downloading click-8.1.8-py3-none-any.whl (98 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.2/98.2 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.12.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m45.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, click, jiwer
  Attempting uninstall: click
    Found existing installation: click 8.1.7
    Uninstalling click-8.1.7:
      Successfully uninstalled click-8.1.7
Successfully in

In [2]:
!pip install wandb



In [None]:
import wandb
wandb.login(key=WANDB_API_KEY)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrabiaedayilmaz[0m ([33mheisenbugtachi[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import os
import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import torchaudio
import librosa
import numpy as np
import re
from jiwer import wer
import json
import wandb

TRAIN_DIR = "/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train"
DEV_DIR = "/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/dev"
TEST_DIR = "/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/test"

chars_to_ignore = r'[,\?\.\!\-\;:"“%‘”�]'
chars_to_mapping = {"ğ": "g", "ı": "i", "ö": "o", "ü": "u", "ş": "s", "ç": "c"}

def normalize_text(text):
    if text is None or not isinstance(text, str):
        print(f"Warning: normalize_text received invalid input: {text}")
        return ""
    text = text.lower().strip()
    for src, dst in chars_to_mapping.items():
        text = text.replace(src, dst)
    text = re.sub(chars_to_ignore, '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

class AudioTextDataset(TorchDataset):
    def __init__(self, folder, max_samples=100):
        all_audio_files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.wav')]
        all_text_files = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.txt')]
        all_audio_files.sort()
        all_text_files.sort()
        
        self.audio_files = all_audio_files[:max_samples]
        self.text_files = all_text_files[:max_samples]
        
        print(f"Folder: {folder}")
        print(f"Found {len(all_audio_files)} audio files, using {len(self.audio_files)}: {self.audio_files[:5]}")
        print(f"Found {len(all_text_files)} text files, using {len(self.text_files)}: {self.text_files[:5]}")
        assert len(self.audio_files) > 0, f"No .wav files in {folder}"
        assert len(self.text_files) > 0, f"No .txt files in {folder}"
        assert len(self.audio_files) == len(self.text_files), "Mismatch between audio and text files"
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        text_path = self.text_files[idx]
        try:
            speech_array, sampling_rate = torchaudio.load(audio_path)
            speech_array = librosa.resample(speech_array.squeeze().numpy(), orig_sr=sampling_rate, target_sr=16000)
        except Exception as e:
            print(f"Error loading audio {audio_path}: {e}")
            return {"speech": np.array([]), "sentence": "", "audio_file": audio_path, "text_file": text_path}
        
        try:
            with open(text_path, 'r', encoding='utf-8') as f:
                text = f.read().strip()
                transcription = normalize_text(text)
        except Exception as e:
            print(f"Error loading text {text_path}: {e}")
            return {"speech": np.array([]), "sentence": "", "audio_file": audio_path, "text_file": text_path}
        
        return {"speech": speech_array, "sentence": transcription, "audio_file": audio_path, "text_file": text_path}

# load datasets
print("Loading datasets...")
train_dataset = AudioTextDataset(TRAIN_DIR, max_samples=1)
dev_dataset = AudioTextDataset(DEV_DIR, max_samples=1)
test_dataset = AudioTextDataset(TEST_DIR, max_samples=1)

# create vocab
print("Creating vocabulary...")
vocab = set()
for i in range(len(train_dataset)):
    item = train_dataset[i]
    if "sentence" in item and item["sentence"]:
        vocab.update(item["sentence"])
vocab_dict = {v: k for k, v in enumerate(sorted(vocab))}
vocab_dict["|"] = len(vocab_dict)
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
with open("vocab.json", "w") as vocab_file:
    json.dump(vocab_dict, vocab_file)

# init W&B run
wandb.init(project="turkish-asr", config={
    "model": "facebook/wav2vec2-large-xlsr-53",
    "num_epochs": 50,  # Increased for better training
    "learning_rate": 3e-4,  # Increased back for faster adaptation
    "batch_size": 8,
    "gradient_accumulation_steps": 2,
    "max_samples_train": 2000,
    "max_samples_dev": 500,
    "max_samples_test": 500
})

# init model and processor
print("Initializing Wav2Vec2 model and processor...")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|"
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# update config and reinit lm_head
model.config.vocab_size = len(vocab_dict)
model.lm_head = torch.nn.Linear(model.config.hidden_size, len(vocab_dict))
torch.nn.init.xavier_uniform_(model.lm_head.weight)

# freeze lower layers to focus training on lm_head and upper layers
for param in model.wav2vec2.feature_extractor.parameters():
    param.requires_grad = False  # Freeze feature extractor
for i, layer in enumerate(model.wav2vec2.encoder.layers):
    if i < 20:  # Freeze first 20 of 24 layers
        for param in layer.parameters():
            param.requires_grad = False

# use gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def data_collator(batch):
    valid_items = [item for item in batch if item["speech"].size > 0 and isinstance(item["sentence"], str)]
    if not valid_items:
        print("Warning: No valid items in batch, skipping...")
        return None
    
    try:
        input_values = processor(
            [item["speech"] for item in valid_items],
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        ).input_values
        
        labels = processor.tokenizer(
            [item["sentence"] for item in valid_items],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).input_ids
        
        print(f"Labels shape: {labels.shape}, dtype: {labels.dtype}")
        if not torch.is_tensor(labels) or labels.dtype != torch.int64:
            raise ValueError("Labels are not a proper tensor of integers")
        
        return {
            "input_values": input_values,
            "labels": labels,
            "audio_files": [item["audio_file"] for item in valid_items],
            "text_files": [item["text_file"] for item in valid_items]
        }
    except Exception as e:
        print(f"Error in data_collator: {e}")
        print(f"Sentences passed to tokenizer: {[item['sentence'] for item in valid_items]}")
        return None

# ataLoaders
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
eval_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

# train params
num_epochs = 50  
learning_rate = 3e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
gradient_accumulation_steps = 2

# eval func
def evaluate(loader, model, processor, device, dataset_name="eval"):
    model.eval()
    total_wer = 0
    num_batches = 0
    with torch.no_grad():
        for batch in loader:
            if batch is None:
                print(f"Skipping {dataset_name} batch due to invalid items")
                continue
            input_values = batch["input_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_values)
            pred_ids = torch.argmax(outputs.logits, dim=-1)
            pred_str = processor.batch_decode(pred_ids)
            label_str = processor.batch_decode(labels, group_tokens=False)
            
            print(f"{dataset_name} Batch - Raw pred_ids: {pred_ids.tolist()[:2]}")
            print(f"{dataset_name} Batch - Logits max/min: {outputs.logits.max().item():.2f}/{outputs.logits.min().item():.2f}")
            print(f"{dataset_name} Batch - Predicted: {pred_str}")
            print(f"{dataset_name} Batch - Ground Truth: {label_str}")
            
            batch_wer = wer(label_str, pred_str)
            total_wer += batch_wer
            num_batches += 1
    avg_wer = total_wer / num_batches if num_batches > 0 else float('inf')
    model.train()
    return avg_wer

# training loop
print("Starting fine-tuning...")
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    steps = 0
    optimizer.zero_grad()
    
    for i, batch in enumerate(train_loader):
        if batch is None:
            print(f"Skipping batch {i} due to invalid items")
            continue
        
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = model(input_values, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        if (i + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            steps += 1
            wandb.log({"train/loss": loss.item(), "train/step": steps + epoch * len(train_loader)})
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} completed, Average Loss: {avg_loss:.4f}")
    
    # Evaluate on dev set after each epoch
    dev_wer = evaluate(eval_loader, model, processor, device, "dev")
    print(f"Epoch {epoch+1}/{num_epochs}, Dev WER: {dev_wer:.4f}")
    
    # Evaluate on test set after each epoch
    test_wer = evaluate(test_loader, model, processor, device, "test")
    print(f"Epoch {epoch+1}/{num_epochs}, Test WER: {test_wer:.4f}")
    
    # Log metrics to W&B
    wandb.log({
        "train/avg_loss": avg_loss,
        "dev/avg_wer": dev_wer,
        "test/avg_wer": test_wer,
        "epoch": epoch + 1
    })

# Final evaluation on dev set
print("Final evaluation on dev set...")
dev_wer = evaluate(eval_loader, model, processor, device, "dev")
print(f"Final Average WER on dev set: {dev_wer:.4f}")
wandb.log({"dev/final_avg_wer": dev_wer})

# Final evaluation on test set with results saving and W&B logging
print("Final evaluation on test set...")
model.eval()
total_wer = 0
num_batches = 0
test_results = []
test_table = wandb.Table(columns=["audio_file", "text_file", "ground_truth", "prediction", "wer", "pred_ids"])

with torch.no_grad():
    for batch in test_loader:
        if batch is None:
            print("Skipping test batch due to invalid items")
            continue
        
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        audio_files = batch["audio_files"]
        text_files = batch["text_files"]
        
        outputs = model(input_values)
        pred_ids = torch.argmax(outputs.logits, dim=-1)
        pred_str = processor.batch_decode(pred_ids)
        label_str = processor.batch_decode(labels, group_tokens=False)
        
        # Debugging
        print(f"Test Batch - Raw pred_ids: {pred_ids.tolist()[:2]}")
        print(f"Test Batch - Logits max/min: {outputs.logits.max().item():.2f}/{outputs.logits.min().item():.2f}")
        print(f"Test Batch - Predicted: {pred_str}")
        print(f"Test Batch - Ground Truth: {label_str}")
        
        batch_wer = wer(label_str, pred_str)
        total_wer += batch_wer
        num_batches += 1
        
        for i in range(len(pred_str)):
            sample_wer = batch_wer if len(pred_str) == 1 else wer([label_str[i]], [pred_str[i]])
            test_results.append({
                "audio_file": audio_files[i],
                "text_file": text_files[i],
                "ground_truth": label_str[i],
                "prediction": pred_str[i],
                "wer": sample_wer
            })
            test_table.add_data(
                audio_files[i],
                text_files[i],
                label_str[i],
                pred_str[i],
                sample_wer,
                str(pred_ids[i].tolist())
            )
        
        wandb.log({"test/batch_wer": batch_wer})

avg_wer = total_wer / num_batches if num_batches > 0 else float('inf')
print(f"Final Average WER on test set: {avg_wer:.4f}")
wandb.log({"test/final_avg_wer": avg_wer})

# Log test results table to W&B
wandb.log({"test/predictions": test_table})

# Save test results to JSON
results_dict = {
    "test_results": test_results,
    "average_wer": avg_wer,
    "num_samples": len(test_results),
    "num_batches": num_batches
}
with open("test_results.json", "w", encoding='utf-8') as f:
    json.dump(results_dict, f, ensure_ascii=False, indent=4)
print("Test results saved to 'test_results.json'")

# Save the fine-tuned model and processor
print("Saving fine-tuned model and processor...")
model.save_pretrained("./model/wav2vec")
processor.save_pretrained("./model/wav2vec")

# Finish W&B run
wandb.finish()
print("Fine-tuning complete!")

Loading datasets...
Folder: /kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train
Found 27 audio files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00001.wav']
Found 27 text files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/train/00001.txt']
Folder: /kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/dev
Found 27 audio files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/dev/00066.wav']
Found 27 text files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/dev/00066.txt']
Folder: /kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/test
Found 27 audio files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_turkish/test/00131.wav']
Found 27 text files, using 1: ['/kaggle/input/medical-speech-turkish/dataset/artificial_generated_tur

[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250331_123709-48u6hza5[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mpolar-sponge-19[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/heisenbugtachi/turkish-asr[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/heisenbugtachi/turkish-asr/runs/48u6hza5[0m


Initializing Wav2Vec2 model and processor...


config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

Starting fine-tuning...
Labels shape: torch.Size([1, 512]), dtype: torch.int64
Epoch 1/50 completed, Average Loss: 12726.8936
Labels shape: torch.Size([1, 512]), dtype: torch.int64
dev Batch - Raw pred_ids: [[11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 8, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 32, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 

[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:        dev/avg_wer ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:  dev/final_avg_wer ▁
[34m[1mwandb[0m:              epoch ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
[34m[1mwandb[0m:       test/avg_wer ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:     test/batch_wer ▁
[34m[1mwandb[0m: test/final_avg_wer ▁
[34m[1mwandb[0m:     train/avg_loss ▆█▁▅▆▅▅▅▄▅▅▅▅█▅▅▅▅▅▅▄▃▃▅▂▅▅▇▅▅▅▅▅▃▆▅▅▅▅▅
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:        dev/avg_wer 1
[34m[1mwandb[0m:  dev/final_avg_wer 1
[34m[1mwandb[0m:              epoch 50
[34m[1mwandb[0m:       test/avg_wer 1
[34m[1mwandb[0m:     test/batch_wer 1
[34m[1mwandb[0m: test/final_avg_wer 1
[34m[1mwandb[0m:     train/avg_loss 12622.25098
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run

Fine-tuning complete!
