In [None]:
import getpass
import os
import re
from collections import defaultdict

import torch
from datasets import Audio, load_dataset
from huggingface_hub import login
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from english import EnglishTextNormalizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log into the Hugging Face Hub without ever writing the token into the notebook.
# The token is taken from the environment: HF_TOKEN and HUGGING_FACE_HUB_TOKEN are
# the names huggingface_hub reads; HUGGINGFACE_HUB_TOKEN is the name this notebook
# used before. If none is set, the token is prompted for without echo.
hf_token = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or getpass.getpass("Hugging Face Hub token: ")
)
# Kept for backward compatibility with anything that reads this variable.
os.environ['HUGGINGFACE_HUB_TOKEN'] = hf_token
login(token=hf_token)

In [None]:
# Whisper feature extractor + tokenizer, configured for English transcription.
# NOTE(review): this comes from the base "openai/whisper-large" checkpoint while the
# model below is a fine-tuned one — confirm the fine-tune used the same processor.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-large",
    language="English",
    task="transcribe",
)

# EdACC test split; token=True makes `datasets` use the Hub login from the setup cell.
edacc_test = load_dataset("sage-bergerson/edacc_whisper", split="test", token=True)
# Decode/resample the audio column to 16 kHz, the rate passed to the processor later.
edacc_test = edacc_test.cast_column("audio", Audio(sampling_rate=16000))

# Fine-tuned Whisper checkpoint, moved onto the selected device.
model = WhisperForConditionalGeneration.from_pretrained(
    "sage-bergerson/whisper-large-edacc"
).to(device)
# Pin generation to English transcription, and clear the legacy forced_decoder_ids
# so they do not override the language/task settings.
model.generation_config.language = "english"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

In [None]:
# Function: Generate transcript predictions
def map_to_pred(batch):
    audio = batch['audio']
    input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features.to(device)
    batch["reference"] = processor.tokenizer.normalize(batch['transcript'])

    with torch.no_grad():
        predicted_ids = model.generate(input_features)[0]
    transcription = processor.decode(predicted_ids)
    batch["prediction"] = processor.tokenizer.normalize(transcription)
    return batch

# Run the model over every test example; each row gains "reference" and "prediction".
result = edacc_test.map(function=map_to_pred)

In [None]:
# Text normalizer applied to references and predictions after tag removal.
norm = EnglishTextNormalizer()

# Annotation tags (non-speech events, overlaps, etc.) deleted from both transcripts
# before scoring.
special_tokens = [
    '<overlap>',
    '<laugh>',
    '<foreign>',
    '<no-speech>',
    '<lipsmack>',
    '<breath>',
    '<click>',
    '<cough>',
    '<dtmf>',
    '<ring>',
]

# Function: Clean transcripts and flag items whose reference is empty
def clean(item, special_tokens, normalizer=None):
    """Strip special tokens from an item's transcripts, normalize them, and report whether to keep it.

    Mutates ``item`` in place: ``item['reference']`` and ``item['prediction']`` are
    replaced by their cleaned, normalized text.

    Args:
        item: mapping with ``'reference'`` and ``'prediction'`` string fields.
        special_tokens: literal tokens (e.g. ``'<laugh>'``) deleted before normalizing.
        normalizer: callable applied to each transcript after token removal;
            defaults to the module-level ``norm``.

    Returns:
        True if the cleaned reference is non-empty, False otherwise. This lets the
        function be used directly as a filter predicate that drops empty references.
    """
    if normalizer is None:
        normalizer = norm

    # Escape each token so any regex metacharacters in it are matched literally.
    pattern = '|'.join(map(re.escape, special_tokens))
    for field in ('reference', 'prediction'):
        text = item[field]
        if pattern:  # an empty pattern would be a pointless match-everywhere no-op
            text = re.sub(pattern, '', text)
        item[field] = normalizer(text)

    # Previously this always returned True, so empty references were never filtered.
    return bool(item['reference'].strip())

# Filter and clean the transcripts.
# Iterating a Dataset decodes its Audio column for every row, and those waveforms
# would then stay alive inside this list. Downstream cells here only read
# code/reference/prediction, so drop the audio first. The check keeps the cell
# re-runnable after `result` has already become a list.
if hasattr(result, "remove_columns") and "audio" in result.column_names:
    result = result.remove_columns("audio")
result = [item for item in result if clean(item, special_tokens)]

In [None]:
# The first 12 characters of a segment code identify its conversation;
# SCLITE is run on whole conversations.
CONVERSATION_ID_LEN = 12

# Concatenate segment transcripts per conversation, in first-seen order
# (defaultdict, like dict, preserves insertion order).
grouped = defaultdict(lambda: {"code": "", "reference": "", "prediction": ""})

for item in result:
    group_key = item["code"][:CONVERSATION_ID_LEN]
    group = grouped[group_key]
    group["code"] = group_key
    for field in ("reference", "prediction"):
        # Skip empty segment texts so they don't leave runs of spaces in the joined text.
        if item[field]:
            group[field] += " " + item[field]

# Format conversation-level results (strip the leading separator space).
final_result = [
    {
        "code": value["code"],
        "reference": value["reference"].strip(),
        "prediction": value["prediction"].strip(),
    }
    for value in grouped.values()
]

# Write SCLITE "trn" files: one line per conversation, "<text> (<conversation code>)".
reference_filename = 'ft_references.txt'
prediction_filename = 'ft_predictions.txt'

# newline='' disables newline translation, so each line ends with exactly "\n".
with open(reference_filename, 'w', newline='', encoding='utf-8') as ref_file:
    with open(prediction_filename, 'w', newline='', encoding='utf-8') as pred_file:
        for item in final_result:
            ref_file.write(f"{item['reference']} ({item['code']})\n")
            pred_file.write(f"{item['prediction']} ({item['code']})\n")