In [None]:
import json
import os
import shutil
import subprocess
import sys
import time
import warnings
from pathlib import Path

import pandas as pd
import requests
from jiwer import wer
from tqdm import tqdm


## MP3 to WAV Conversion
- Mainly for GCP input constraints, but also for consistency with os implementation

In [None]:
def resample_normalise_audio(in_file, out_file, sample_rate=16000):
    """Convert an audio file to mono 16-bit PCM WAV at ``sample_rate`` with ffmpeg.

    Args:
        in_file: Path of the source audio file; it must exist.
        out_file: Destination path. When ``None`` the output is written next to
            ``in_file`` as ``<name>_<sample_rate>.wav``.
        sample_rate: Target sample rate in Hz.

    Returns:
        ``out_file`` (the derived path when ``None`` was passed).

    Raises:
        ValueError: if ``in_file`` does not exist.

    A failing or missing ffmpeg does not raise; a warning is emitted instead so
    that callers looping over many files keep going.
    """
    if not os.path.exists(in_file):
        raise ValueError(f"{in_file} not found")
    if out_file is None:
        # Swap only the trailing extension. str.replace would also rewrite any
        # earlier occurrence of the extension text, and would insert the new
        # suffix between every character when the file has no extension.
        out_file = f"{os.path.splitext(in_file)[0]}_{sample_rate}.wav"

    # Argument list (no shell): paths containing spaces or shell metacharacters
    # are passed to ffmpeg verbatim.
    command = [
        "ffmpeg",
        "-y",  # overwrite an existing output file without prompting
        "-i",
        str(in_file),
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-af",
        "aresample=resampler=soxr",
        "-ar",
        str(sample_rate),
        str(out_file),
    ]
    try:
        return_code = subprocess.run(command, check=False).returncode
    except OSError as exc:
        # e.g. ffmpeg is not installed; report instead of raising.
        warnings.warn(f"Could not launch ffmpeg for {in_file}: {exc}")
    else:
        if return_code != 0:
            warnings.warn(f"ffmpeg exited with status {return_code} for {in_file}")
    return out_file


# Load the manifest and re-point both path columns at the local output tree,
# keeping only the file name of each original path.
transcript_manifest = pd.read_csv(
    "../output/radio_national_podcasts/manifest.csv"
).assign(
    audio_path=lambda frame: frame.audio_path.apply(
        lambda original: str(
            Path("../output/radio_national_podcasts/audio/mp3") / Path(original).name
        )
    ),
    transcript_path=lambda frame: frame.transcript_path.apply(
        lambda original: str(
            Path("../output/radio_national_podcasts/transcripts/ground_truth")
            / Path(original).name
        )
    ),
)

# Start from an empty WAV directory on every run.
output_dir = Path("../output/radio_national_podcasts/audio/wav")
if output_dir.exists():
    shutil.rmtree(str(output_dir))
output_dir.mkdir(parents=True, exist_ok=True)

# Each MP3 is converted into the sibling "wav" directory under the same stem.
for mp3_location in transcript_manifest.audio_path:
    input_path = Path(mp3_location)
    output_path = input_path.parent.parent / "wav" / f"{input_path.stem}.wav"
    resample_normalise_audio(str(input_path), str(output_path))


## OS

In [None]:
# Add the parent directory to the import path so that the `asr` import below
# can be resolved from there (presumably a local module one level above this
# notebook — TODO confirm).
sys.path.append("..")
from asr import transcribe_mono_audio


In [None]:
# Start from an empty directory for the "os" provider transcripts.
os_output_dir = Path("../output/radio_national_podcasts/transcripts/os")
if os_output_dir.exists():
    shutil.rmtree(str(os_output_dir))
os_output_dir.mkdir(parents=True, exist_ok=True)


In [None]:
for audio_path in tqdm(transcript_manifest.audio_path):
    # Time only the transcription call itself.
    before = time.time()
    transcript = transcribe_mono_audio(audio_path)
    after = time.time()

    # One JSON record per audio file, named after the audio file's stem.
    os_transcript_record = dict(
        hypothesis=" ".join(transcript.transcript.tolist()),
        elapsed_time=after - before,
        provider="os",
    )
    record_path = os_output_dir / f"{Path(audio_path).stem}.json"
    record_path.write_text(json.dumps(os_transcript_record))


## GCP
- Huge chunks of the transcript missing when using async methods?
- Probably use telephony model instead
- Consider using streams instead of batch

In [None]:
from google.cloud import speech, storage

project = "hobby-358221"
bucket_name = "blog-os-asr"

# List every object in the bucket and express each one as a gs:// URI.
storage_client = storage.Client(project=project)
bucket = storage_client.get_bucket(bucket_name)
blobs = bucket.list_blobs()

gcp_uris = []
for blob in blobs:
    gcp_uris.append(f"gs://{bucket_name}/{blob.name}")


In [None]:
# One shared client, used by every call to transcribe_gcs.
speech_client = speech.SpeechClient()


def transcribe_gcs(gcs_uri):
    """Transcribe the audio at ``gcs_uri`` with a long-running recognize request.

    The request is configured for LINEAR16 audio at 16 kHz in en-US. The call
    blocks on the operation's result and returns the first alternative of each
    result, stripped and joined with single spaces.
    """
    recognition_config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="en-US",
    )
    recognition_audio = speech.RecognitionAudio(uri=gcs_uri)

    operation = speech_client.long_running_recognize(
        config=recognition_config, audio=recognition_audio
    )
    response = operation.result()

    pieces = []
    for result in response.results:
        pieces.append(result.alternatives[0].transcript.strip())
    return " ".join(pieces).strip()


In [None]:
# Recreate the output directory so transcripts from earlier runs are removed.
gcp_output_dir = Path("../output/radio_national_podcasts/transcripts/gcp")
if gcp_output_dir.exists():
    shutil.rmtree(str(gcp_output_dir))
gcp_output_dir.mkdir(parents=True, exist_ok=True)

for gcp_uri in tqdm(gcp_uris):
    try:
        print(f"Transcribing {gcp_uri}...")
        before = time.time()
        gcp_res = transcribe_gcs(gcp_uri)
        after = time.time()
        gcp_transcript_record = {
            "hypothesis": gcp_res,
            "elapsed_time": after - before,
            "provider": "gcp",
        }
        (gcp_output_dir / f"{Path(gcp_uri).stem}.json").write_text(
            json.dumps(gcp_transcript_record)
        )
    except Exception as exc:
        # Best-effort: carry on with the remaining files, but report why this
        # one failed instead of discarding the error.
        print(f"Unable to transcribe: {gcp_uri} ({exc!r})")


In [None]:
# potentially use CLI?
# NOTE: the lines below are shell commands, not Python. They are kept as
# comments so this cell does not raise a SyntaxError when the notebook is run
# top to bottom; run them in a terminal to try them out.
#
# gcloud ml speech recognize-long-running \
#     'gs://blog-os-asr/test.wav' \
#      --language-code='en-US' --async
#
# poll result
# gcloud ml speech operations describe 1558607248830316847

## AWS

In [None]:
import s3fs

# List the entries of the "blog-os-asr" bucket. Each entry is later prefixed
# with "s3://" to form the media URI of a transcription job.
fs = s3fs.S3FileSystem()
aws_uris = fs.ls("blog-os-asr")

In [None]:
import logging

import boto3
import requests
from botocore.exceptions import ClientError

from transcribe_util import CustomWaiter, WaitState

# Module-level logger used by the transcription-job helper functions below.
logger = logging.getLogger(__name__)


class TranscribeCompleteWaiter(CustomWaiter):
    """Waiter that maps a transcription job's status onto success or failure.

    The base class receives the waiter name, the "GetTranscriptionJob"
    operation name, the status path to inspect and the status-to-outcome
    mapping; how it polls is defined by ``CustomWaiter`` (not shown here).
    """

    def __init__(self, client):
        """Configure the waiter for ``client``."""
        status_outcomes = {
            "COMPLETED": WaitState.SUCCESS,
            "FAILED": WaitState.FAILURE,
        }
        super().__init__(
            "TranscribeComplete",
            "GetTranscriptionJob",
            "TranscriptionJob.TranscriptionJobStatus",
            status_outcomes,
            client,
        )

    def wait(self, job_name):
        """Wait on the job called ``job_name`` by delegating to ``_wait``."""
        self._wait(TranscriptionJobName=job_name)


def start_job(
    job_name,
    media_uri,
    media_format,
    language_code,
    transcribe_client,
    vocabulary_name=None,
):
    """Start a transcription job and return its ``TranscriptionJob`` entry.

    Args:
        job_name: Name given to the new job.
        media_uri: URI of the media file to transcribe.
        media_format: Format of the media file.
        language_code: Language code of the audio.
        transcribe_client: Client exposing ``start_transcription_job``.
        vocabulary_name: Optional custom vocabulary, sent under ``Settings``.

    Raises:
        ClientError: logged with its traceback, then re-raised.
    """
    request = dict(
        TranscriptionJobName=job_name,
        Media={"MediaFileUri": media_uri},
        MediaFormat=media_format,
        LanguageCode=language_code,
    )
    if vocabulary_name is not None:
        request["Settings"] = {"VocabularyName": vocabulary_name}

    try:
        started = transcribe_client.start_transcription_job(**request)[
            "TranscriptionJob"
        ]
        logger.info("Started transcription job %s.", job_name)
    except ClientError:
        logger.exception("Couldn't start transcription job %s.", job_name)
        raise
    return started


def get_job(job_name, transcribe_client):
    """Fetch the ``TranscriptionJob`` entry of the job called ``job_name``.

    Raises:
        ClientError: logged with its traceback, then re-raised.
    """
    try:
        fetched = transcribe_client.get_transcription_job(
            TranscriptionJobName=job_name
        )["TranscriptionJob"]
        logger.info("Got job %s.", fetched["TranscriptionJobName"])
    except ClientError:
        logger.exception("Couldn't get job %s.", job_name)
        raise
    return fetched


def list_jobs(job_filter, transcribe_client):
    """Return the summaries of all jobs whose name contains ``job_filter``.

    Follows ``NextToken`` until the service stops returning one, collecting
    the ``TranscriptionJobSummaries`` of every page.

    Raises:
        ClientError: logged with its traceback, then re-raised.
    """
    try:
        page = transcribe_client.list_transcription_jobs(JobNameContains=job_filter)
        summaries = page["TranscriptionJobSummaries"]
        token = page.get("NextToken")
        while token is not None:
            page = transcribe_client.list_transcription_jobs(
                JobNameContains=job_filter, NextToken=token
            )
            summaries.extend(page["TranscriptionJobSummaries"])
            token = page.get("NextToken")
        logger.info("Got %s jobs with filter %s.", len(summaries), job_filter)
    except ClientError:
        logger.exception("Couldn't get jobs with filter %s.", job_filter)
        raise
    return summaries


def delete_job(job_name, transcribe_client):
    """Delete the transcription job called ``job_name``.

    Raises:
        ClientError: logged with its traceback, then re-raised.
    """
    try:
        transcribe_client.delete_transcription_job(TranscriptionJobName=job_name)
    except ClientError:
        logger.exception("Couldn't delete job %s.", job_name)
        raise
    else:
        logger.info("Deleted job %s.", job_name)


In [None]:
transcribe_client = boto3.client("transcribe")

# Recreate the output directory so transcripts from earlier runs are removed.
aws_output_dir = Path("../output/radio_national_podcasts/transcripts/aws")
if aws_output_dir.exists():
    shutil.rmtree(str(aws_output_dir))
aws_output_dir.mkdir(parents=True, exist_ok=True)

for aws_uri in tqdm(aws_uris):
    try:
        print(f"Transcribing {aws_uri}...")
        before = time.time()
        # The job is named after the file, so the same name recurs across runs.
        job_name_simple = Path(aws_uri).name

        # ensure a job doesn't already exist
        for job in list_jobs(job_name_simple, transcribe_client):
            delete_job(job["TranscriptionJobName"], transcribe_client)

        print(f"Starting transcription job {job_name_simple}.")
        start_job(job_name_simple, f"s3://{aws_uri}", "wav", "en-US", transcribe_client)
        transcribe_waiter = TranscribeCompleteWaiter(transcribe_client)
        transcribe_waiter.wait(job_name_simple)
        job_simple = get_job(job_name_simple, transcribe_client)
        # The finished job points at a JSON document holding the transcript.
        transcript_simple = requests.get(
            job_simple["Transcript"]["TranscriptFileUri"]
        ).json()
        after = time.time()
        aws_transcript_record = {
            "hypothesis": transcript_simple["results"]["transcripts"][0]["transcript"],
            "elapsed_time": after - before,
            "provider": "aws",
        }
        (aws_output_dir / f"{Path(aws_uri).stem}.json").write_text(
            json.dumps(aws_transcript_record)
        )
        # clean-up jobs
        for job in list_jobs(job_name_simple, transcribe_client):
            delete_job(job["TranscriptionJobName"], transcribe_client)

    except Exception as exc:
        # Best-effort: carry on with the remaining files, but report why this
        # one failed instead of discarding the error.
        print(f"Unable to transcribe: {aws_uri} ({exc!r})")


## Azure

- Batch transcription instructions: https://github.com/Azure-Samples/cognitive-services-speech-sdk/tree/master/samples/batch/python
- Generate swagger, download python package, install python package
- Example code via: https://github.com/Azure-Samples/cognitive-services-speech-sdk/blob/master/samples/batch/python/python-client/main.py

In [None]:
# All Azure settings come from the environment; a missing variable raises KeyError.
env = os.environ

# Speech service
speech_key = env["azure_asr_key"]
service_region = env["azure_asr_region"]
endpoint = env["azure_asr_endpoint"]

# Blob storage holding the audio
blob_key = env["azure_blob_key"]
blob_connection_string = env["azure_blob_connection_string"]
blob_container_name = env["azure_blob_container_name"]
storage_account_name = env["azure_storage_account_name"]


In [None]:
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
import swagger_client as cris_client

logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format="%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p %Z",
)

# Your subscription key and region for the speech service
SUBSCRIPTION_KEY = speech_key
SERVICE_REGION = service_region

NAME = "Simple transcription"
DESCRIPTION = "Simple transcription description"

LOCALE = "en-US"

# URI of a container with audio files, used to transcribe all of them with a
# single request. At least 'read' and 'list' (rl) permissions are required.
# A SAS token is a credential, so the full URI (including its query string) is
# read from the environment instead of being committed in the notebook; without
# the variable the bare container URL is used.
RECORDINGS_CONTAINER_URI = os.environ.get(
    "azure_recordings_container_uri",
    "https://blogosasr.blob.core.windows.net/blog-os-asr",
)

# Set model information when doing transcription with custom models
MODEL_REFERENCE = None  # guid of a custom model


def transcribe_from_single_blob(uri, properties):
    """
    Build the transcription definition for one audio file.

    `uri` becomes the only entry of the content URL list and `properties` is
    passed through unchanged; name, description and locale come from the
    module-level NAME, DESCRIPTION and LOCALE constants. No model is set.
    """
    definition = cris_client.Transcription(
        locale=LOCALE,
        display_name=NAME,
        description=DESCRIPTION,
        properties=properties,
        content_urls=[uri],
    )
    return definition


def transcribe_with_custom_model(api, uri, properties):
    """
    Build the transcription definition for a single audio file located at `uri`,
    using the settings specified in `properties` and the custom model identified
    by the module-level MODEL_REFERENCE.

    The model is looked up with `api.get_model`. When MODEL_REFERENCE is None an
    error is logged and the interpreter exits with a non-zero status (SystemExit).
    """
    if MODEL_REFERENCE is None:
        logging.error("Custom model ids must be set when using custom models")
        # A bare sys.exit() would report success; this is an error path.
        sys.exit(1)
    model = api.get_model(MODEL_REFERENCE)
    return cris_client.Transcription(
        display_name=NAME,
        description=DESCRIPTION,
        locale=LOCALE,
        content_urls=[uri],
        model=model,
        properties=properties,
    )


def transcribe_from_container(uri, properties):
    """
    Build the transcription definition for a whole container.

    `uri` is passed as the content container URL and `properties` is passed
    through unchanged; name, description and locale come from the module-level
    NAME, DESCRIPTION and LOCALE constants. No model is set.
    """
    definition = cris_client.Transcription(
        locale=LOCALE,
        display_name=NAME,
        description=DESCRIPTION,
        properties=properties,
        content_container_url=uri,
    )
    return definition


def _paginate(api, paginated_object):  # sourcery skip: raise-specific-error
    """
    The autogenerated client does not support pagination. This function returns a generator over
    all items of the array that the paginated object `paginated_object` is part of.
    """
    yield from paginated_object.values
    typename = type(paginated_object).__name__
    auth_settings = ["apiKeyHeader", "apiKeyQuery"]
    while paginated_object.next_link:
        link = paginated_object.next_link[len(api.api_client.configuration.host) :]
        paginated_object, status, headers = api.api_client.call_api(
            link, "GET", response_type=typename, auth_settings=auth_settings
        )

        if status == 200:
            yield from paginated_object.values
        else:
            raise Exception(f"could not receive paginated data: status {status}")


def delete_all_transcriptions(api):
    """
    Request deletion of every transcription returned by `api.get_transcriptions()`.

    A failed deletion is logged as an error and does not stop the loop.
    """
    logging.info("Deleting all existing completed transcriptions.")

    # Collect the full listing up front, before any deletion is issued.
    existing = list(_paginate(api, api.get_transcriptions()))

    for entry in existing:
        # The id is the last path segment of the transcription's self link.
        transcription_id = entry._self.rsplit("/", 1)[-1]
        logging.debug(f"Deleting transcription with id {transcription_id}")
        try:
            api.delete_transcription(transcription_id)
        except cris_client.rest.ApiException as exc:
            logging.error(f"Could not delete transcription {transcription_id}: {exc}")


def transcribe_azure(blob_sas, poll_interval_seconds=5):
    """
    Transcribe one audio file with the batch transcription API and wait for it.

    Creates a transcription whose single content URL is `blob_sas`, polls its
    status until it is "Succeeded" or "Failed", and on success downloads the
    first result file of kind "Transcription".

    Args:
        blob_sas: URL of the audio to transcribe.
        poll_interval_seconds: Seconds to sleep before each status check.

    Returns:
        A dict with "hypothesis" (the display text of the first combined
        recognized phrase), "provider" ("azure") and "elapsed_time" (seconds
        from the start of this call until the result was downloaded), or
        None when the transcription failed or produced no transcription file.
    """
    logging.info("Starting transcription client...")
    before = time.time()

    # configure API key authorization: subscription_key
    configuration = cris_client.Configuration()
    configuration.api_key["Ocp-Apim-Subscription-Key"] = SUBSCRIPTION_KEY
    configuration.host = (
        f"https://{SERVICE_REGION}.api.cognitive.microsoft.com/speechtotext/v3.0"
    )

    # create the client object and authenticate
    client = cris_client.ApiClient(configuration)

    # create an instance of the transcription api class
    api = cris_client.CustomSpeechTranscriptionsApi(api_client=client)

    properties = {
        # "punctuationMode": "DictatedAndAutomatic",
        # "profanityFilterMode": "Masked",
        # "wordLevelTimestampsEnabled": True,
        # "diarizationEnabled": True,
        # "destinationContainerUrl": "<SAS Uri with at least write (w) permissions for an Azure Storage blob container that results should be written to>",
        # "timeToLive": "PT1H"
    }

    transcription_definition = transcribe_from_single_blob(blob_sas, properties)

    # Uncomment this block to transcribe all files from a container.
    # transcription_definition = transcribe_from_container(
    #     RECORDINGS_CONTAINER_URI, properties)

    _created, _status, headers = api.create_transcription_with_http_info(
        transcription=transcription_definition
    )

    # get the transcription Id from the location URI
    transcription_id = headers["location"].split("/")[-1]

    logging.info(
        f"Created new transcription with id '{transcription_id}' in region {SERVICE_REGION}"
    )

    logging.info("Checking status.")

    while True:
        # wait before refreshing the transcription status
        time.sleep(poll_interval_seconds)

        transcription = api.get_transcription(transcription_id)
        logging.info(f"Transcriptions status: {transcription.status}")

        if transcription.status == "Failed":
            logging.error(
                f"Transcription failed: {transcription.properties.error.message}"
            )
            return None
        if transcription.status != "Succeeded":
            continue

        pag_files = api.get_transcription_files(transcription_id)
        for file_data in _paginate(api, pag_files):
            if file_data.kind != "Transcription":
                continue

            results = requests.get(file_data.links.content_url)
            # Fail with the HTTP error itself rather than a confusing JSON or
            # KeyError further down.
            results.raise_for_status()
            after = time.time()
            return {
                "hypothesis": results.json()["combinedRecognizedPhrases"][0][
                    "display"
                ],
                "provider": "azure",
                "elapsed_time": after - before,
            }

        logging.error(
            f"Transcription {transcription_id} succeeded but returned no transcription file"
        )
        return None


In [None]:
from azure.storage.blob import (
    BlobClient,
    BlobSasPermissions,
    BlobServiceClient,
    ContainerClient,
    __version__,
    generate_blob_sas,
)

# Client for the container that holds the audio blobs.
blob_service_client = BlobServiceClient.from_connection_string(blob_connection_string)
blob_client = blob_service_client.get_container_client(blob_container_name)

# Start from an empty directory for the Azure transcripts.
azure_output_dir = Path("../output/radio_national_podcasts/transcripts/azure")
if azure_output_dir.exists():
    shutil.rmtree(str(azure_output_dir))
azure_output_dir.mkdir(parents=True, exist_ok=True)


In [None]:
from datetime import datetime, timedelta

for blob_file in tqdm(list(blob_client.list_blobs())):
    blob_name = blob_file["name"]
    print(f"Transcribing {blob_name}...")

    # Read-only SAS token for this blob, valid for six hours.
    blob_sas = generate_blob_sas(
        account_name=storage_account_name,
        container_name=blob_container_name,
        blob_name=blob_name,
        account_key=blob_key,
        permission=BlobSasPermissions(read=True),
        expiry=datetime.utcnow() + timedelta(hours=6),
    )

    blob_sas_formatted = f"https://{storage_account_name}.blob.core.windows.net/{blob_container_name}/{blob_name}?{blob_sas}"

    azure_transcript_record = transcribe_azure(blob_sas_formatted)
    if azure_transcript_record is None:
        # transcribe_azure returns None when the transcription did not produce
        # a result. Writing it would store the JSON literal `null`, which
        # load_transcripts cannot turn into a record.
        print(f"Unable to transcribe: {blob_name}")
        continue

    (azure_output_dir / f"{Path(blob_name).stem}.json").write_text(
        json.dumps(azure_transcript_record)
    )


## Consolidate, evaluate

In [None]:
import json

import pandas as pd


def load_transcripts(transcript_dir):
    """Load every ``*.json`` record under ``transcript_dir`` into a DataFrame.

    Args:
        transcript_dir: ``pathlib.Path`` searched recursively for JSON files.

    Returns:
        A DataFrame with one row per JSON object, plus a ``stem`` column
        holding the file name without its extension. Rows follow sorted path
        order so the result does not depend on filesystem listing order.

    Files whose content is not a JSON object (for example ``null``, written
    when a transcription produced no result) are skipped with a warning.
    """
    records = []
    for json_path in sorted(transcript_dir.rglob("*.json")):
        record = json.loads(json_path.read_text())
        if not isinstance(record, dict):
            warnings.warn(f"Skipping {json_path}: expected a JSON object")
            continue
        record["stem"] = json_path.stem
        records.append(record)
    return pd.DataFrame(records)


transcripts_root = Path("../output/radio_national_podcasts/transcripts")

aws = load_transcripts(transcripts_root / "aws")
azure = load_transcripts(transcripts_root / "azure")
gcp = load_transcripts(transcripts_root / "gcp")
# Deliberately not called `os`: that name would shadow the standard-library
# `os` module that other cells rely on (os.path, os.environ).
os_transcripts = load_transcripts(transcripts_root / "os")

ground_truth = transcript_manifest[["transcript", "stem"]]

# One frame per provider: ground truth joined to the hypothesis on the file
# stem, with the word error rate of each pair.
wer_frames = [
    pd.merge(ground_truth, provider, how="inner", on="stem").assign(
        wer=lambda x: x.apply(lambda y: wer(y.transcript, y.hypothesis), axis=1)
    )
    for provider in [aws, azure, gcp, os_transcripts]
]

# Summary statistics per provider, reshaped to one row per (metric, provider).
eval_res = (
    pd.concat(
        [
            e[["elapsed_time", "wer"]].describe().assign(provider=e.iloc[0].provider)
            for e in wer_frames
            # A provider with no matching stems has no first row to read the
            # provider name from.
            if not e.empty
        ]
    )
    .reset_index()
    .pipe(lambda x: x[x["index"].str.contains("mean|min|50%|max")])
    .rename(columns={"index": "metric"})
)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("dark")
# eval_res holds one row per summary statistic (mean/min/50%/max) for each
# provider. Without `hue`, barplot would average those statistics together,
# so split the bars by metric as the WER plot does.
fig, ax = plt.subplots()
sns.barplot(x="provider", y="elapsed_time", hue="metric", data=eval_res, ax=ax)
plt.xticks(rotation=45)
ax.set_ylabel("elapsed_time (s)")
ax.set_title("ASR Provider x Elapsed Time")


In [None]:
# Word error rate per provider, one bar per summary metric.
sns.set_style("dark")
wer_ax = sns.barplot(data=eval_res, x="provider", y="wer", hue="metric")
plt.xticks(rotation=45)
wer_ax.set_title("ASR Provider x WER (Various)")


## Tidy up
- keep two shortest podcasts under 10 minutes

In [None]:
# The two shortest podcasts under ten minutes; the manifest is overwritten
# with just these rows.
under_ten_minutes = transcript_manifest.loc[lambda frame: frame["len_minutes"] < 10]
keep_records = under_ten_minutes.sort_values(by="len_minutes").iloc[:2]
keep_records.to_csv("../output/radio_national_podcasts/manifest.csv", index=False)

In [None]:
removable_suffixes = {".json", ".wav", ".mp3", ".txt"}
keep_stems = set(keep_records.stem.tolist())

# Take a snapshot of the tree first so files are not deleted while it is
# still being walked.
for e in list(Path("../output").rglob("*")):
    if not e.is_file():
        continue
    # pathlib reports an empty suffix for dotfiles such as ".DS_Store" (the
    # whole name is the stem), so that file has to be matched by name.
    removable = e.suffix in removable_suffixes or e.name == ".DS_Store"
    if removable and e.stem not in keep_stems:
        e.unlink()
