In [None]:
from jiwer import wer
import os
import pandas as pd
import whisper
import zeno_client
import dotenv
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Load variables from a local .env file into the process environment (the
# Zeno API key is read from os.environ further down); override=True lets
# values from .env replace variables that are already set.
dotenv.load_dotenv(override=True)

In [None]:
# Use the first CUDA GPU with half-precision weights when one is available;
# otherwise fall back to the CPU with full-precision weights.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

In [None]:
# Read the dataset table from a CSV in the working directory; later cells
# rely on its "id" and "label" columns.
df = pd.read_csv("speech_accent_archive.csv")


In [None]:
# Build the audio location for every row by prefixing its "id" with the
# bucket URL; the result is stored in the "data" column.
AUDIO_BASE_URL = "https://zenoml.s3.amazonaws.com/accents/"
df["data"] = AUDIO_BASE_URL + df["id"]

In [None]:
# Create the Zeno client from the API key found in the environment.
client = zeno_client.ZenoClient(os.environ.get("ZENO_API_KEY"))

# Project-level metric: the mean over the "wer" column.
avg_wer_metric = zeno_client.ZenoMetric(name="avg wer", type="mean", columns=["wer"])

# Create the audio-transcription project that the dataset and every
# system's outputs are uploaded to in the cells below.
project = client.create_project(
    name="Transcription Whisper Distil",
    view="audio-transcription",
    description="Test of audio transcription",
    metrics=[avg_wer_metric],
)

In [None]:
# Upload the base dataset, telling Zeno which column plays which role.
project.upload_dataset(
    df,
    id_column="id",
    data_column="data",
    label_column="label",
)

In [None]:
# Systems to evaluate, in upload order. Names containing "distil" are loaded
# from the "distil-whisper/" namespace through transformers; every other
# name is passed to whisper.load_model.
models = [
    "tiny.en",
    "base.en",
    "medium.en",
    "large",
    "distil-medium.en",
    "distil-large-v2",
]

In [None]:
os.makedirs("cache", exist_ok=True)


def _transcribe_with_distil_whisper(model_name, audio_sources):
    """Transcribe every entry of ``audio_sources`` with a Distil-Whisper checkpoint.

    Args:
        model_name: Checkpoint name without the ``distil-whisper/`` prefix,
            e.g. ``"distil-medium.en"``.
        audio_sources: pandas Series whose values are handed one by one to the
            speech-recognition pipeline.

    Returns:
        A pandas Series with the same index holding the ``"text"`` field of
        each pipeline result.
    """
    model_id = "distil-whisper/" + model_name
    # Load the model and processor exactly once per checkpoint.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=torch_dtype,
        device=device,
    )
    return audio_sources.apply(lambda source: pipe(source)["text"])


def _transcribe_with_whisper(model_name, audio_sources):
    """Transcribe every entry of ``audio_sources`` with an OpenAI Whisper model.

    Args:
        model_name: Name passed to ``whisper.load_model``, e.g. ``"tiny.en"``.
        audio_sources: pandas Series whose values are handed one by one to
            ``transcribe``.

    Returns:
        A pandas Series with the same index holding the ``"text"`` field of
        each transcription result.
    """
    whisper_model = whisper.load_model(model_name)
    return audio_sources.apply(lambda source: whisper_model.transcribe(source)["text"])


# One result frame per entry of `models`, in the same order; the upload cell
# pairs them up by position.
df_systems = []
for model_name in models:
    cache_path = f"cache/{model_name}.parquet"

    # Explicit cache check instead of a bare `except:`: only a missing file
    # triggers the expensive transcription. An unreadable cache file or a
    # keyboard interrupt now surfaces instead of silently starting a re-run.
    if os.path.exists(cache_path):
        df_system = pd.read_parquet(cache_path)
    else:
        df_system = df[["id", "data", "label"]].copy()

        # The model lives only inside the helper, so it can be released
        # before the next checkpoint is loaded.
        if "distil" in model_name:
            df_system["output"] = _transcribe_with_distil_whisper(model_name, df_system["data"])
        else:
            df_system["output"] = _transcribe_with_whisper(model_name, df_system["data"])

        # Word error rate per row, computed on the label and output strings
        # exactly as they are (no text normalisation is applied here).
        df_system["wer"] = df_system.apply(lambda row: wer(row["label"], row["output"]), axis=1)
        df_system.to_parquet(cache_path, index=False)

    df_systems.append(df_system)

In [None]:
# Upload each system's outputs; results are matched to their model name by
# position, so `df_systems` must be in the same order as `models`.
for position, system_results in enumerate(df_systems):
    upload_frame = system_results[["id", "output", "wer"]]
    project.upload_system(
        upload_frame,
        name=models[position],
        id_column="id",
        output_column="output",
    )