In [19]:
import warnings
import numpy as np
import os
import pandas as pd
import itertools
import json
from pathlib import Path
from jiwer import cer
import torch
import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, GPT2Tokenizer, GPT2Model
from sklearn.metrics.pairwise import cosine_similarity
import random 

warnings.filterwarnings("ignore")

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained tokenizer and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h').to(device)

# GPT2
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2Model.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
num_of_samples = 1024
batch_size = 4
random_seed = 53

In [22]:
audio_paths = []
actual_transcripts = []

# Base directories
text_base_dir = Path("C:\\Users\\a_has\\Desktop\\DS_10283_3443\\txt")
audio_base_dir = Path("C:\\Users\\a_has\\Desktop\\DS_10283_3443\\wav48_silence_trimmed")

# Collect all text files across all subfolders
all_text_files = []
for subfolder in text_base_dir.iterdir():
    if subfolder.is_dir():
        text_files = list(subfolder.glob("*.txt"))
        for text_file in text_files:
            all_text_files.append((subfolder.name, text_file))

random.seed(random_seed)
selected_files = random.sample(all_text_files, min(num_of_samples, len(all_text_files)))

# Process the selected files
for speaker_id, text_file in selected_files:
    with open(text_file, "r", encoding="utf-8") as f:
        content = f.read().strip().lower().replace(",", "").replace(".", "").replace("?", "")

    # Construct the corresponding audio file path
    audio_filename = text_file.stem + "_mic1.flac"
    audio_file = audio_base_dir / speaker_id / audio_filename

    # Append to lists if the audio file exists
    if audio_file.exists():
        audio_paths.append(audio_file.resolve())
        actual_transcripts.append(content)
    else:
        print(f"Missing audio file for: {text_file}")

Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_350.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_184.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_068.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_007.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_381.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_355.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_036.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_311.txt
Missing audio file for: C:\Users\a_has\Desktop\DS_10283_3443\txt\p362\p362_085.txt


In [23]:
results = []

for i in tqdm(range(0, len(audio_paths), batch_size)):
    batch_files = audio_paths[i : i + batch_size]
    batch_actual_transcripts = actual_transcripts[i : i + batch_size]
    
    waveforms = []
    max_length = 0  # Track max waveform length

    # Load and resample audio files
    for audio_file in batch_files:
        waveform, sample_rate = torchaudio.load(audio_file)
        if sample_rate != 16000:
            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = transform(waveform)

        waveform = waveform.squeeze(0)  # Remove extra dimension
        waveforms.append(waveform)
        max_length = max(max_length, waveform.shape[0])  # Find max length in batch

    # Pad all waveforms to max_length
    padded_waveforms = []
    for waveform in waveforms:
        pad_length = max_length - waveform.shape[0]
        padded_waveform = torch.nn.functional.pad(waveform, (4000, pad_length)) # right_pad + fix left_pad
        padded_waveforms.append(padded_waveform)

    # Convert list of tensors to a batch tensor
    batch_waveforms = torch.stack(padded_waveforms)

    # Convert list of waveforms to a batch tensor with padding
    input_values = processor(batch_waveforms.numpy(), return_tensors="pt", padding=True, sampling_rate=16000)["input_values"]

    # Step 3: Perform model inference in batch
    with torch.no_grad():
        logits = model(input_values).logits

    # Step 4: Decode predictions in batch
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = processor.batch_decode(predicted_ids)

    # Step 5: Compute similarity metrics
    for j, transcription in enumerate(transcriptions):
        gen_text = transcription.strip().lower()
        actual_text = batch_actual_transcripts[j]

        def get_gpt2_embedding(text):
            if not text.strip():  # Check if the text is empty
                return torch.zeros(1, 768)  # Return a zero vector with GPT-2 embedding size (768)
            inputs = gpt2_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            with torch.no_grad():
                outputs = gpt2_model(**inputs)
            return outputs.last_hidden_state.mean(dim=1)  # Get mean-pooled embedding
        
        actual_embedding = get_gpt2_embedding(actual_text)
        gen_embedding = get_gpt2_embedding(gen_text)

        cosine_sim = cosine_similarity(actual_embedding.numpy(), gen_embedding.numpy())[0][0]
        cer_score = cer(actual_text, gen_text)

        results.append({
            "audio_file": batch_files[j],
            "actual_transcript": actual_text,
            "generated_transcript": gen_text,
            "CER": cer_score,
            "GPT-2 Cosine Similarity": cosine_sim
        })

100%|██████████| 254/254 [15:10<00:00,  3.59s/it]


In [34]:
df_results = pd.DataFrame(results)
pd.set_option("display.max_colwidth", None)
df_results.sample(5)

Unnamed: 0,audio_file,actual_transcript,generated_transcript,CER,GPT-2 Cosine Similarity
899,C:\Users\a_has\Desktop\DS_10283_3443\wav48_silence_trimmed\p314\p314_125_mic1.flac,but it will backfire,but it will back fire,0.05,0.999678
410,C:\Users\a_has\Desktop\DS_10283_3443\wav48_silence_trimmed\p236\p236_261_mic1.flac,he also launched a new strategy for the agency,he also launched new stratagy for the agency,0.065217,0.997225
251,C:\Users\a_has\Desktop\DS_10283_3443\wav48_silence_trimmed\p279\p279_092_mic1.flac,they lived for their children,they lived for their children,0.0,1.0
346,C:\Users\a_has\Desktop\DS_10283_3443\wav48_silence_trimmed\p240\p240_164_mic1.flac,however the french government has a major dilemma on its hands,however the french government has a major dilemma on its hands,0.0,1.0
160,C:\Users\a_has\Desktop\DS_10283_3443\wav48_silence_trimmed\p286\p286_051_mic1.flac,he also presented you bet!,he also presented yeu bet,0.076923,0.998713


In [31]:
# Character Error Rate (CER): Measures character-level errors
mean_values = df_results[["CER", "GPT-2 Cosine Similarity"]].mean()
display(mean_values)

CER                        0.032010
GPT-2 Cosine Similarity    0.999332
dtype: float64

In [26]:
# for index, row in random_frames.iterrows():
#     episode, t1, t2 = row["video"].split("-")
#     audio_file = Path(f"C:\\Users\\a_has\\Desktop\\Friends\\friends_mmc\\face_track_videos\\face_track_videos\\{episode}\\{t1}-{t2}\\0.wav")

#     waveform, sample_rate = torchaudio.load(audio_file)
#     if sample_rate != 16000:
#         transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
#         waveform = transform(waveform)

#     input_values = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000)["input_values"]
#     with torch.no_grad():
#         logits = model(input_values).logits
#     predicted_ids = torch.argmax(logits, dim=-1)
#     transcription = processor.batch_decode(predicted_ids)

#     print(transcription)