### Setup

Import libraries, set up model, and read input data.

In [None]:
from jiwer import wer
import os
import pandas as pd
import whisper
import zeno_client
import dotenv
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import pandas as pd
import requests
from io import BytesIO
import wave
import struct
from tqdm import tqdm

# Register tqdm's pandas integration; later cells call `progress_apply`.
tqdm.pandas()
# Load variables from a .env file into the environment (ZENO_API_KEY is read
# from the environment when the Zeno client is created). `override=True`
# presumably lets .env values replace variables that are already set —
# confirm against the python-dotenv documentation.
dotenv.load_dotenv(override=True)

In [None]:
# Use the first GPU in half precision when CUDA is available; otherwise fall
# back to the CPU in single precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

In [None]:
# Load the dataset table from the working directory; later cells read its
# "id" and "label" columns.
df = pd.read_csv("speech_accent_archive.csv")

In [None]:
# Build each row's audio URL by appending its "id" to the S3 prefix
# (assumes "id" holds string file names — TODO confirm against the CSV).
df["data"] = "https://zenoml.s3.amazonaws.com/accents/" + df["id"]

In [None]:
# Define the function to get amplitude and length
def _wav_peak_and_duration(wav_bytes):
    """Return ``(peak, duration)`` for PCM WAV data held in memory.

    ``peak`` is the largest absolute sample excursion divided by full scale
    (128 for 8-bit, 32768 for 16-bit), so it lies in ``[0.0, 1.0]``.
    ``duration`` is the frame count declared in the header divided by the
    frame rate, in seconds.

    Raises:
        ValueError: if the sample width is not 1 or 2 bytes.
        wave.Error: if the data is not a WAV file the ``wave`` module can read.
    """
    with wave.open(BytesIO(wav_bytes), 'rb') as wav_file:
        frame_rate = wav_file.getframerate()
        n_frames = wav_file.getnframes()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(n_frames)

    duration = n_frames / frame_rate

    if sample_width == 1:
        # 8-bit WAV samples are unsigned and centred on 128 (silence == 128).
        type_code, midpoint = 'B', 128
    elif sample_width == 2:
        # 16-bit WAV samples are signed and centred on 0.
        type_code, midpoint = 'h', 0
    else:
        raise ValueError("Only supports up to 16-bit audio.")

    # Count the samples actually present instead of trusting the header, so a
    # truncated file does not make struct.unpack fail on a size mismatch.
    n_samples = len(frames) // sample_width
    if n_samples == 0:
        # No audio data at all: report silence rather than crash in max().
        return 0.0, duration

    # '<' forces little-endian, which is what WAV uses regardless of the host.
    samples = struct.unpack(
        f'<{n_samples}{type_code}', frames[:n_samples * sample_width]
    )
    # Peak excursion in either direction; negative swings count as well.
    peak = max(max(samples) - midpoint, midpoint - min(samples))
    full_scale = 2 ** (8 * sample_width - 1)
    return peak / full_scale, duration


def get_amplitude_and_length_from_url(url, timeout=30):
    """Download a WAV file and return ``(max_amplitude_normalized, duration)``.

    Args:
        url: Location of the WAV file.
        timeout: Seconds to wait for the server before giving up, so that one
            stalled download cannot hang a whole batch run.

    Returns:
        ``(peak, duration_seconds)``; ``(None, None)`` when the download
        fails, in which case the error is printed.

    Raises:
        ValueError: if the audio is not 8-bit or 16-bit.
        wave.Error: if the downloaded content is not a readable WAV file.
    """
    # Download the WAV file content from the URL
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # raises an HTTPError on an unsuccessful status code
        wav_bytes = response.content
    except requests.RequestException as e:
        print(f"Request failed: {e}")
        return None, None

    return _wav_peak_and_duration(wav_bytes)

# Row-wise adapter so the URL helper can be used with DataFrame.apply
def apply_get_amplitude_and_length(row):
    """Return the 'amplitude' and 'length' of the audio behind ``row['data']``.

    The URL is read from the row's 'data' column and passed to
    ``get_amplitude_and_length_from_url``; its two results are returned as a
    pandas Series keyed 'amplitude' and 'length'.
    """
    peak, seconds = get_amplitude_and_length_from_url(row['data'])
    stats = {'amplitude': peak, 'length': seconds}
    return pd.Series(stats)

# Usage with apply on the DataFrame
# This will create two new columns 'amplitude' and 'length' in the DataFrame.
# Every row triggers a download of its audio file, so this step needs network
# access; `progress_apply` is the tqdm-wrapped variant of `apply`.
df[['amplitude', 'length']] = df.progress_apply(apply_get_amplitude_and_length, axis=1)

### Zeno Project

We create a Zeno project with a WER metric and upload our base data.

In [None]:
# Authenticate with the key stored in the ZENO_API_KEY environment variable.
zeno_api_key = os.environ.get("ZENO_API_KEY")
client = zeno_client.ZenoClient(zeno_api_key)

# The project's only metric: named "avg wer", declared as type "mean" over
# the "wer" column.
wer_metric = zeno_client.ZenoMetric(name="avg wer", type="mean", columns=["wer"])

project = client.create_project(
    name="Transcription Whisper Distil",
    view="audio-transcription",
    description="Test of audio transcription",
    metrics=[wer_metric],
)

In [None]:
# Upload the base table to the project: "id" identifies each instance, "data"
# holds the audio URL, and "label" is passed as the label column (it is the
# reference text that WER is computed against in the inference step).
project.upload_dataset(df, id_column="id", data_column="data", label_column="label")

### Run Inference

We now run inference with the base Whisper models and the Distil-Whisper models, caching each model's output.

In [None]:
# Checkpoints to evaluate. Names containing "distil" are loaded through
# transformers with the "distil-whisper/" prefix; all others are loaded with
# `whisper.load_model`.
models = ["medium.en", "large-v1", "large-v2", "large-v3", "distil-medium.en", "distil-large-v2"]

In [None]:
# Transcribe the dataset with every model in `models`, caching each model's
# results as cache/<model_name>.parquet so that re-running the notebook skips
# inference that has already been done.
os.makedirs("cache", exist_ok=True)

df_systems = []
for model_name in models:
    cache_path = f"cache/{model_name}.parquet"

    try:
        df_system = pd.read_parquet(cache_path)
    except Exception as err:
        # `Exception` rather than a bare `except`, so Ctrl-C (KeyboardInterrupt)
        # is not mistaken for a cache miss. A missing file is the normal
        # first-run case; anything else is reported before recomputing.
        if os.path.exists(cache_path):
            print(f"Ignoring unreadable cache {cache_path}: {err}")
        df_system = None

    # Inference runs outside the `except` block so that its own errors are
    # not chained onto the cache-miss exception.
    if df_system is None:
        df_system = df[["id", "data", "label"]].copy()

        if "distil" in model_name:
            # Distil-Whisper checkpoints are run through a transformers pipeline.
            model_id = "distil-whisper/" + model_name
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
            )
            model.to(device)

            processor = AutoProcessor.from_pretrained(model_id)
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                max_new_tokens=128,
                chunk_length_s=15,
                batch_size=16,
                torch_dtype=torch_dtype,
                device=device,
            )
            df_system["output"] = df_system["data"].progress_apply(lambda x: pipe(x)['text'])
            # Drop the references so this model is not kept alive while the
            # next one is being loaded.
            del pipe, processor, model
        else:
            # Base Whisper checkpoints are run through the `whisper` package.
            whisper_model = whisper.load_model(model_name)
            df_system["output"] = df_system["data"].progress_apply(lambda x: whisper_model.transcribe(x)["text"])
            del whisper_model

        if torch.cuda.is_available():
            # Hand the freed blocks back to the GPU before the next model loads.
            torch.cuda.empty_cache()

        # NOTE(review): WER is computed on the raw strings, with no text
        # normalisation applied in this cell — confirm that is intended.
        df_system["wer"] = df_system.progress_apply(lambda x: wer(x["label"], x["output"]), axis=1)
        df_system.to_parquet(cache_path, index=False)

    df_systems.append(df_system)

### Upload Results

Lastly, we upload our final results.

In [None]:
# Upload one system per model: its transcriptions ("output") and per-row "wer",
# keyed by "id" and named after the model that produced them.
for system_name, df_system in zip(models, df_systems):
    project.upload_system(
        df_system[["id", "output", "wer"]],
        name=system_name,
        id_column="id",
        output_column="output",
    )