In [None]:
# install required packages
!pip install torch torchaudio transformers librosa jiwer nltk rouge-score bert-score

In [None]:
# import packages
import torch
import torchaudio
import librosa
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from torch.utils.data import DataLoader, Dataset
import json
from jiwer import wer
import os
import difflib
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score

# Directory holding the fine-tuned checkpoint (weights + processor files).
# NOTE(review): the active path names a "whisper" directory while the loaders
# below are Wav2Vec2 classes -- confirm which checkpoint is intended.
# Alternative path kept for reference:
#   "/kaggle/input/med-wav2vec-asr-finetune/model/wav2vec"
MODEL_PATH = "/kaggle/input/med-asr-whisper-finetune/model/whisper"

# Load the fine-tuned model and its processor. On failure, report which path
# was used and re-raise so the script stops instead of running without a model.
try:
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
except Exception as load_error:
    print(f"Error loading model or processor from {MODEL_PATH}: {load_error}")
    raise

# Run on the GPU when one is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
model.to(device)
# Switch the model to evaluation mode for inference.
model.eval()

def transcribe_audio(audio_path: str) -> str:
    """Transcribe one audio file with the fine-tuned Wav2Vec2 model.

    The file is loaded with torchaudio, downmixed to a single channel,
    resampled to 16 kHz when its sampling rate differs, and decoded greedily
    (argmax over the model's logits).

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The decoded transcription, or an empty string if any step fails
        (the error is printed rather than raised).
    """
    target_sr = 16000  # sampling rate passed to the processor
    try:
        waveform, sampling_rate = torchaudio.load(audio_path)

        # Average over the channel axis to get a 1-D mono signal. A plain
        # squeeze() only drops the channel axis for mono files; stereo audio
        # would stay 2-D and be interpreted as a batch of separate signals
        # instead of one utterance.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        speech_array = waveform.numpy()

        # Resample only when the file is not already at the target rate.
        if sampling_rate != target_sr:
            speech_array = librosa.resample(
                speech_array, orig_sr=sampling_rate, target_sr=target_sr
            )

        # Turn the raw samples into model inputs.
        inputs = processor(
            speech_array, sampling_rate=target_sr, return_tensors="pt", padding=True
        )
        input_values = inputs.input_values.to(device)

        # Inference without gradient tracking.
        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy decoding: most likely token per frame, then ids -> text.
        pred_ids = torch.argmax(logits, dim=-1)
        return processor.batch_decode(pred_ids)[0]

    except Exception as e:
        # Best-effort by design: report the failure and return an empty result.
        print(f"Error transcribing {audio_path}: {e}")
        return ""

def transcribe_and_refine(audio_path: str) -> str:
    """Transcribe an audio file and normalize whitespace in the result.

    Args:
        audio_path: Path of the audio file, forwarded to ``transcribe_audio``.

    Returns:
        The transcription with leading/trailing whitespace removed and runs of
        whitespace collapsed to single spaces. An empty transcription is
        returned unchanged.
    """
    raw_text = transcribe_audio(audio_path)
    if not raw_text:
        # Nothing to tidy up.
        return raw_text
    # split() with no arguments drops surrounding whitespace and splits on
    # any whitespace run, so re-joining yields single-space separation.
    words = raw_text.split()
    return " ".join(words)

# Character-level normalization applied by clean_text: separators and
# punctuation become a space, bullet and asterisk markers are removed.
_CLEAN_TEXT_TABLE = str.maketrans({
    "-": " ",
    "\u2022": None,
    "*": None,
    "\n": " ",
    ",": " ",
    ".": " ",
    ";": " ",
    ":": " ",
})


def clean_text(text: str) -> str:
    """Normalize text before comparing a prediction with its reference.

    The text is lower-cased; hyphens, newlines, commas, periods, semicolons
    and colons are replaced by a space; bullet characters and asterisks are
    dropped; finally surrounding whitespace is stripped. Interior runs of
    spaces are left as they are.

    Args:
        text: The raw text to normalize.

    Returns:
        The normalized text.
    """
    lowered = text.lower()
    return lowered.translate(_CLEAN_TEXT_TABLE).strip()

def _score_samples(samples):
    """Attach text-similarity metrics to the raw per-sample records.

    Args:
        samples: List of dicts with the keys ``audio_file``, ``text_file``,
            ``ground_truth``, ``prediction`` and ``wer`` (unrounded).

    Returns:
        A list of result dicts (same order as ``samples``) with rounded WER,
        SequenceMatcher ratio, BLEU, ROUGE-L F-measure and BERTScore F1.
    """
    # Scorer objects do not depend on the sample, so build them once.
    smoothie = SmoothingFunction().method4
    rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    cleaned = [
        (clean_text(sample["ground_truth"]), clean_text(sample["prediction"]))
        for sample in samples
    ]

    # Score every pair in a single BERTScore call instead of one call per
    # sample, so the scoring model is set up once for the whole test set.
    _, _, f1 = bert_score(
        [pred_clean for _, pred_clean in cleaned],
        [gt_clean for gt_clean, _ in cleaned],
        lang="en",
        rescale_with_baseline=True,
    )
    bert_f1_values = f1.tolist()

    results = []
    for sample, (gt_clean, pred_clean), bert_f1 in zip(samples, cleaned, bert_f1_values):
        # Character-level similarity of the cleaned strings.
        sequence_score = difflib.SequenceMatcher(None, gt_clean, pred_clean).ratio()
        # Word-level BLEU with smoothing.
        bleu = sentence_bleu(
            [gt_clean.split()], pred_clean.split(), smoothing_function=smoothie
        )
        rouge_score_l = rouge.score(gt_clean, pred_clean)['rougeL'].fmeasure

        results.append({
            "audio_file": sample["audio_file"],
            "text_file": sample["text_file"],
            "ground_truth": sample["ground_truth"],
            "prediction": sample["prediction"],
            "wer": round(sample["wer"], 4),
            "similarity_score": round(sequence_score, 4),
            "bleu": round(bleu, 4),
            "rougeL": round(rouge_score_l, 4),
            "bert_score_f1": round(bert_f1, 4)
        })
    return results


def evaluate_test_set(test_loader, model, processor, device):
    """Evaluate the model on the test set and save per-sample metrics.

    For each batch the model output is decoded greedily and compared with the
    decoded labels. Per-sample WER, SequenceMatcher ratio, BLEU, ROUGE-L and
    BERTScore F1 are collected and written to ``test_results.json`` in the
    current working directory.

    Args:
        test_loader: Iterable of batches. A batch is either ``None`` (skipped)
            or a dict with the keys ``input_values``, ``labels``,
            ``audio_files`` and ``text_files``.
        model: Model called as ``model(input_values)``; its output must expose
            ``.logits``.
        processor: Object providing ``batch_decode`` for prediction and label ids.
        device: Device the input values are moved to before inference.

    Returns:
        The dictionary written to ``test_results.json``. ``average_wer`` is
        the mean of the per-batch WER values (not a corpus-level WER) and is
        ``None`` when no batch could be evaluated.
    """
    print("Final evaluation on test set...")
    model.eval()
    total_wer = 0.0
    num_batches = 0
    samples = []  # raw per-sample records; similarity metrics are added after inference

    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                print("Skipping test batch due to invalid items")
                continue

            input_values = batch["input_values"].to(device)

            # inference
            outputs = model(input_values)
            pred_ids = torch.argmax(outputs.logits, dim=-1)
            pred_str = processor.batch_decode(pred_ids)
            # Labels are only decoded, never fed to the model, so they do not
            # need to be moved to the device.
            label_str = processor.batch_decode(batch["labels"], group_tokens=False)

            # calculate WER for the batch
            batch_wer = wer(label_str, pred_str)
            total_wer += batch_wer
            num_batches += 1

            single_item = len(pred_str) == 1
            for audio_file, text_file, reference, prediction in zip(
                batch["audio_files"], batch["text_files"], label_str, pred_str
            ):
                # With a single item the batch WER already is the sample WER.
                sample_wer = batch_wer if single_item else wer([reference], [prediction])
                samples.append({
                    "audio_file": audio_file,
                    "text_file": text_file,
                    "ground_truth": reference,
                    "prediction": prediction,
                    "wer": sample_wer,
                })

            print(f"Test Batch - Predicted: {pred_str}")
            print(f"Test Batch - Ground Truth: {label_str}")
            print(f"Test Batch - WER: {batch_wer:.4f}")

    # Only run the scorers when there is something to score.
    test_results = _score_samples(samples) if samples else []

    # average WER; None (JSON null) rather than infinity when nothing was
    # evaluated, because json.dump would emit the non-standard token Infinity.
    if num_batches > 0:
        avg_wer = total_wer / num_batches
        print(f"Final Average WER on test set: {avg_wer:.4f}")
    else:
        avg_wer = None
        print("Final Average WER on test set: n/a (no valid batches)")

    # save results
    results_dict = {
        "test_results": test_results,
        "average_wer": round(avg_wer, 4) if avg_wer is not None else None,
        "num_samples": len(test_results),
        "num_batches": num_batches
    }
    with open("test_results.json", "w", encoding='utf-8') as f:
        json.dump(results_dict, f, ensure_ascii=False, indent=4)
    print("Test results saved to 'test_results.json'")
    return results_dict

# dataset and loader
class AudioTextDataset(Dataset):
    """Dataset pairing audio files with their reference transcript files.

    ``audio_files[i]`` is paired with ``text_files[i]``; the two sequences
    must therefore have the same length and a matching order.
    """

    def __init__(self, audio_files, text_files):
        """Store the file lists.

        Args:
            audio_files: Sequence of audio file paths.
            text_files: Sequence of transcript file paths, index-aligned with
                ``audio_files``.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with -O.
        if len(audio_files) != len(text_files):
            raise ValueError("Mismatch between audio and text files")
        self.audio_files = audio_files
        self.text_files = text_files

    def __len__(self):
        """Return the number of audio/text pairs."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load one audio/text pair.

        Returns:
            A dict with ``speech`` (1-D array at 16 kHz), ``text`` (stripped
            transcript), ``audio_files`` and ``text_files`` (the two paths),
            or ``None`` if loading fails (the error is printed).
        """
        audio_path = self.audio_files[idx]
        text_path = self.text_files[idx]
        try:
            waveform, sampling_rate = torchaudio.load(audio_path)
            # Average over the channel axis to get a 1-D mono signal; a plain
            # squeeze() would leave multi-channel audio 2-D.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            speech_array = waveform.numpy()
            # Resample only when the file is not already at 16 kHz.
            if sampling_rate != 16000:
                speech_array = librosa.resample(
                    speech_array, orig_sr=sampling_rate, target_sr=16000
                )
            with open(text_path, 'r', encoding='utf-8') as f:
                text = f.read().strip()
            return {
                "speech": speech_array,
                "text": text,
                "audio_files": audio_path,
                "text_files": text_path
            }
        except Exception as e:
            # Best-effort by design: a broken item is reported and skipped
            # later by the collate function.
            print(f"Error loading {audio_path} or {text_path}: {e}")
            return None

def collate_fn(batch):
    """Collate dataset items into one padded batch.

    Items that are ``None`` (failed loads) are dropped first.

    Args:
        batch: List of dataset items; each is ``None`` or a dict with the
            keys ``speech``, ``text``, ``audio_files`` and ``text_files``.

    Returns:
        ``None`` if no usable item remains, otherwise a dict with
        ``input_values`` (padded audio features), ``labels`` (padded token ids
        of the transcripts), ``audio_files`` and ``text_files`` (path lists).
    """
    # Drop entries that failed to load.
    usable = list(filter(lambda entry: entry is not None, batch))
    if len(usable) == 0:
        return None

    # Audio: let the module-level processor pad the waveforms to a common length.
    waveforms = [entry["speech"] for entry in usable]
    audio_features = processor(
        waveforms,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )

    # Text: tokenize the transcripts, padded and truncated to at most 512 ids.
    transcripts = [entry["text"] for entry in usable]
    tokenized = processor.tokenizer(
        transcripts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    audio_paths = [entry["audio_files"] for entry in usable]
    text_paths = [entry["text_files"] for entry in usable]
    return {
        "input_values": audio_features.input_values,
        "labels": tokenized.input_ids,
        "audio_files": audio_paths,
        "text_files": text_paths,
    }

if __name__ == "__main__":
    # Locations of the test audio and the reference transcripts.
    audio_dir = "/kaggle/input/med-test-ebnchmark/test/test_voice_data"
    text_dir = "/kaggle/input/med-test-ebnchmark/test/test_text_data"

    # Collect the files in sorted order.
    # NOTE(review): audio and text are paired purely by position in the two
    # sorted listings -- confirm the file names sort identically in both
    # directories, otherwise pairs will be misaligned.
    audio_names = sorted(os.listdir(audio_dir))
    test_audio_files = [
        os.path.join(audio_dir, name)
        for name in audio_names
        if name.endswith(('.mp3', '.wav'))
    ]
    text_names = sorted(os.listdir(text_dir))
    test_text_files = [
        os.path.join(text_dir, name)
        for name in text_names
        if name.endswith('.txt')
    ]

    # Build the dataset and a loader that keeps the file order.
    test_dataset = AudioTextDataset(test_audio_files, test_text_files)
    test_loader = DataLoader(
        test_dataset,
        batch_size=2,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # Run the evaluation with the module-level model, processor and device.
    evaluate_test_set(test_loader, model, processor, device)