In [1]:
from dataclasses import dataclass
from typing import List, Optional
import soundfile as sf
import librosa

import torch
import torch.nn.functional as F
import whisper
from tqdm import tqdm
from whisper.audio import N_FRAMES, N_MELS, log_mel_spectrogram, pad_or_trim
from whisper.model import Whisper
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, Tokenizer, get_tokenizer

import IPython.display as ipd

In [2]:
def readAudioFile(audio_path):
    audio, sample_rate = sf.read(audio_path)
    if audio.ndim > 1:
        audio = audio[:, 0]
    if sample_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
    return audio.astype("float32")

In [3]:
audio=readAudioFile("/mnt/e/Download/ja_test_0/common_voice_ja_19485593.mp3")
ipd.Audio(audio, rate=16000)

In [4]:
@torch.no_grad()
def calculate_audio_features(audio_path: Optional[str], model: Whisper) -> torch.Tensor:
    if audio_path is None:
        segment = torch.zeros((N_MELS, N_FRAMES), dtype=torch.float32).to(model.device)
    else:
        # wav or mp3 to 16KHz mono audio ndarray
        audio = readAudioFile(audio_path)
        mel = log_mel_spectrogram(audio)
        segment = pad_or_trim(mel, N_FRAMES).to(model.device)
    return model.embed_audio(segment.unsqueeze(0))

In [5]:
@torch.no_grad()
def calculate_average_logprobs(
    model: Whisper,
    audio_features: torch.Tensor,
    class_names: List[str],
    tokenizer: Tokenizer,
) -> torch.Tensor:
    initial_tokens = (
        torch.tensor(tokenizer.sot_sequence_including_notimestamps).unsqueeze(0).to(model.device)
    )
    eot_token = torch.tensor([tokenizer.eot]).unsqueeze(0).to(model.device)

    average_logprobs = torch.zeros(len(class_names))
    for i, class_name in enumerate(class_names):
        class_name_tokens = (
            torch.tensor(tokenizer.encode(" " + class_name)).unsqueeze(0).to(model.device)
        )
        input_tokens = torch.cat([initial_tokens, class_name_tokens, eot_token], dim=1)

        logits = model.logits(input_tokens, audio_features)  # (1, T, V)
        logprobs = F.log_softmax(logits, dim=-1).squeeze(0)  # (T, V)
        logprobs = logprobs[len(tokenizer.sot_sequence_including_notimestamps) - 1 : -1]  # (T', V)
        logprobs = torch.gather(logprobs, dim=-1, index=class_name_tokens.view(-1, 1))  # (T', 1)
        average_logprob = logprobs.mean().item()
        average_logprobs[i] = average_logprob

    return average_logprobs

In [6]:
def classify(
    model: Whisper,
    audio_path: str,
    class_names: List[str],
    tokenizer: Tokenizer,
    internal_lm_average_logprobs: Optional[torch.Tensor],
    verbose: bool = False,
) -> str:
    audio_features = calculate_audio_features(audio_path, model)

    average_logprobs = calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    if internal_lm_average_logprobs is not None:
        average_logprobs -= internal_lm_average_logprobs

    sorted_indices = sorted(
        range(len(class_names)), key=lambda i: average_logprobs[i], reverse=True
    )
    if verbose:
        tqdm.write("  Average log probabilities for each class:")
        for i in sorted_indices:
            tqdm.write(f"    {class_names[i]}: {average_logprobs[i]:.3f}")

    return class_names[sorted_indices[0]]

In [7]:
def calculate_internal_lm_average_logprobs(
    model: Whisper,
    class_names: List[str],
    tokenizer: Tokenizer,
    verbose: bool = False,
) -> torch.Tensor:
    audio_features_from_empty_input = calculate_audio_features(None, model)
    average_logprobs = calculate_average_logprobs(
        model=model,
        audio_features=audio_features_from_empty_input,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    if verbose:
        print("Internal LM average log probabilities for each class:")
        for i, class_name in enumerate(class_names):
            print(f"  {class_name}: {average_logprobs[i]:.3f}")
    return average_logprobs

In [8]:
@dataclass
class AudioData:
    audio_path: str
    category: Optional[str] = None

In [9]:
import os
def getAudioData(dir, category=None):
    audio_data = []
    for file in os.listdir(dir):
        audio_data.append(AudioData(os.path.join(dir, file), category))
    return audio_data

In [10]:
tokenizer = get_tokenizer(multilingual=True, language="ja")
model = whisper.load_model("base", "cuda:0")

In [11]:
results = []
records = getAudioData("/mnt/e/Download/ja_test_0","speech")[10:20]+getAudioData("/mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised","laughter")[10:20]
class_names='''[cat]
[keyboard_typing]
[sneezing]
[laughing]
[breathing]
'''
class_names = [c.strip() for c in class_names.split("\n")]

internal_lm_average_logprobs = calculate_internal_lm_average_logprobs(
    model=model,
    class_names=class_names,
    tokenizer=tokenizer,
    verbose=True,
)
print(internal_lm_average_logprobs)


Internal LM average log probabilities for each class:
  [cat]: -7.852
  [keyboard_typing]: -5.483
  [sneezing]: -4.417
  [laughing]: -6.766
  [breathing]: -5.664
  : -6.777
tensor([-7.8522, -5.4827, -4.4168, -6.7656, -5.6637, -6.7773])


In [12]:

for record in tqdm(records):
    tqdm.write(f"processing {record.audio_path} (class: {record.category})")
    result = classify(
        model=model,
        audio_path=record.audio_path,
        class_names=class_names,
        tokenizer=tokenizer,
        internal_lm_average_logprobs=internal_lm_average_logprobs,
        verbose=False
    )
    results.append(result)
    tqdm.write(f"  predicted: {result}")

  0%|          | 0/20 [00:00<?, ?it/s]

processing /mnt/e/Download/ja_test_0/common_voice_ja_19485602.mp3 (class: speech)


  5%|▌         | 1/20 [00:00<00:05,  3.19it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485618.mp3 (class: speech)


 10%|█         | 2/20 [00:00<00:06,  2.77it/s]

  predicted: 
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485621.mp3 (class: speech)


 15%|█▌        | 3/20 [00:00<00:05,  3.38it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485622.mp3 (class: speech)


 20%|██        | 4/20 [00:01<00:04,  3.69it/s]

  predicted: 
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485627.mp3 (class: speech)


 25%|██▌       | 5/20 [00:01<00:03,  3.81it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485634.mp3 (class: speech)


 30%|███       | 6/20 [00:01<00:03,  3.88it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485636.mp3 (class: speech)


 35%|███▌      | 7/20 [00:01<00:03,  4.13it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485654.mp3 (class: speech)


 40%|████      | 8/20 [00:02<00:03,  3.88it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485872.mp3 (class: speech)


 45%|████▌     | 9/20 [00:02<00:02,  3.72it/s]

  predicted: [keyboard_typing]
processing /mnt/e/Download/ja_test_0/common_voice_ja_19485873.mp3 (class: speech)


 55%|█████▌    | 11/20 [00:02<00:01,  4.59it/s]

  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt011.wav (class: laughter)
  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt012.wav (class: laughter)


 65%|██████▌   | 13/20 [00:03<00:01,  5.19it/s]

  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt013.wav (class: laughter)
  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt014.wav (class: laughter)


 75%|███████▌  | 15/20 [00:03<00:00,  5.98it/s]

  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt015.wav (class: laughter)
  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr001-utt016.wav (class: laughter)


 85%|████████▌ | 17/20 [00:03<00:00,  7.08it/s]

  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr002-utt001.wav (class: laughter)
  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr002-utt002.wav (class: laughter)


 95%|█████████▌| 19/20 [00:03<00:00,  7.40it/s]

  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr002-utt003.wav (class: laughter)
  predicted: [laughing]
processing /mnt/e/Download/laughterscape_ver1.0/ver1.0/denoised/spkr002-utt005.wav (class: laughter)


100%|██████████| 20/20 [00:04<00:00,  4.96it/s]

  predicted: [laughing]



