In [26]:
import warnings
import numpy as np
import os
import pandas as pd
import itertools
import json
from pathlib import Path
from jiwer import cer
import torch
import torchaudio
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, GPT2Tokenizer, GPT2Model
from sklearn.metrics.pairwise import cosine_similarity

warnings.filterwarnings("ignore")

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained tokenizer and model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h').to(device)

# GPT2
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2Model.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
num_of_samples = 128
batch_size = 16

In [29]:
annotations_file = r"C:\Users\a_has\Desktop\Friends\friends_mmc\5_turns\test-metadata.json"
with open(annotations_file, "r") as f:
    annotations = json.load(f)
    annotations = list(itertools.chain.from_iterable(annotations))
    
annotations = pd.DataFrame(annotations)
annotations.head()

Unnamed: 0,frame,speaker,content,faces,video
0,s03e20-000128,chandler,"Wait a minute, wait. You’re telling me this ac...","[[[800, 257, 870, 342], joey], [[543, 259, 597...",s03e20-000063-000194
1,s03e20-000218,joey,Yeah! Oh my God! (to Chandler) Is this what it...,"[[[797, 242, 863, 326], joey], [[532, 250, 593...",s03e20-000194-000320
2,s03e20-000538,joey,"Oh, you have no idea. And-and when we’re on st...","[[[567, 79, 787, 358], joey]]",s03e20-000403-000639
3,s03e20-000680,phoebe,to see you feeling like this!,"[[[428, 74, 657, 353], phoebe]]",s03e20-000646-000715
4,s03e20-000832,ross,"Monica, uh Dad called this morning and ah, Aun...","[[[472, 396, 545, 491], phoebe], [[133, 134, 2...",s03e20-000820-000957


In [30]:
random_frames = annotations.sample(num_of_samples, random_state=57)
random_frames.head(5)

Unnamed: 0,frame,speaker,content,faces,video
7160,s03e16-007283,phoebe,Now what is Fabutec?,"[[[815, 69, 907, 202], phoebe], [[441, 108, 51...",s03e16-007258-007295
2225,s03e18-020687,pete,Hang on a second. (to the employees) I’ll-I’ll...,[],s03e18-020628-020694
6883,s03e03-023600,phoebe,"Look, he gave me his night vision goggles and ...","[[[664, 82, 810, 297], phoebe]]",s03e03-023561-023639
5663,s03e23-027540,pete,"Yeah. Monica, I want you there in the front ro...","[[[326, 57, 593, 376], pete], [[672, 108, 884,...",s03e23-027382-027623
9976,s03e25-019235,bonnie,"Oh, the water was sooo great! We jumped off th...","[[[554, 42, 727, 293], bonnie]]",s03e25-019182-019288


In [31]:
audio_paths = []
actual_transcripts = []

for index, row in random_frames.iterrows():
    episode, t1, t2 = row["video"].split("-")
    audio_file = Path(
        f"C:\\Users\\a_has\\Desktop\\Friends\\friends_mmc\\face_track_videos\\face_track_videos\\{episode}\\{t1}-{t2}\\0.wav"
    ).resolve()
    audio_paths.append(audio_file)
    actual_transcripts.append(
        row["content"]
        .strip()
        .lower()
        .replace(",", "")
        .replace(".", "")
        .replace("?", "")
    )

In [None]:
results = []

for i in tqdm(range(0, len(audio_paths), batch_size)):
    batch_files = audio_paths[i : i + batch_size]
    batch_actual_transcripts = actual_transcripts[i : i + batch_size]
    
    waveforms = []
    max_length = 0  # Track max waveform length

    # Load and resample audio files
    for audio_file in batch_files:
        waveform, sample_rate = torchaudio.load(audio_file)
        if sample_rate != 16000:
            transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = transform(waveform)

        waveform = waveform.squeeze(0)  # Remove extra dimension
        waveforms.append(waveform)
        max_length = max(max_length, waveform.shape[0])  # Find max length in batch

    # Pad all waveforms to max_length
    padded_waveforms = []
    for waveform in waveforms:
        pad_length = max_length - waveform.shape[0]
        padded_waveform = torch.nn.functional.pad(waveform, (4000, pad_length)) # right_pad + fix left_pad
        padded_waveforms.append(padded_waveform)

    # Convert list of tensors to a batch tensor
    batch_waveforms = torch.stack(padded_waveforms)

    # Convert list of waveforms to a batch tensor with padding
    input_values = processor(batch_waveforms.numpy(), return_tensors="pt", padding=True, sampling_rate=16000)["input_values"]

    # Step 3: Perform model inference in batch
    with torch.no_grad():
        logits = model(input_values).logits

    # Step 4: Decode predictions in batch
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = processor.batch_decode(predicted_ids)

    # Step 5: Compute similarity metrics
    for j, transcription in enumerate(transcriptions):
        gen_text = transcription.strip().lower()
        actual_text = batch_actual_transcripts[j]

        def get_gpt2_embedding(text):
            if not text.strip():  # Check if the text is empty
                return torch.zeros(1, 768)  # Return a zero vector with GPT-2 embedding size (768)
            inputs = gpt2_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            with torch.no_grad():
                outputs = gpt2_model(**inputs)
            return outputs.last_hidden_state.mean(dim=1)  # Get mean-pooled embedding
        
        actual_embedding = get_gpt2_embedding(actual_text)
        gen_embedding = get_gpt2_embedding(gen_text)

        cosine_sim = cosine_similarity(actual_embedding.numpy(), gen_embedding.numpy())[0][0]
        cer_score = cer(actual_text, gen_text)

        results.append({
            "audio_file": batch_files[j],
            "actual_transcript": actual_text,
            "generated_transcript": gen_text,
            "CER": cer_score,
            "GPT-2 Cosine Similarity": cosine_sim
        })

In [38]:
df_results = pd.DataFrame(results)
df_results.sample(5)

Unnamed: 0,audio_file,actual_transcript,generated_transcript,CER,GPT-2 Cosine Similarity
49,C:\Users\a_has\Desktop\Friends\friends_mmc\fac...,you really really need to get some sleep honey,im in to get them slave i know,0.695652,0.994691
3,C:\Users\a_has\Desktop\Friends\friends_mmc\fac...,yeah monica i want you there in the front row ...,monica i want ye there in the front roll when ...,0.119658,0.99942
24,C:\Users\a_has\Desktop\Friends\friends_mmc\fac...,dude i don’t know,to that don't kno,0.529412,0.991003
78,C:\Users\a_has\Desktop\Friends\friends_mmc\fac...,(looking at the timer) thirty seconds left on ...,if i can sup on the time o,0.660714,0.992553
52,C:\Users\a_has\Desktop\Friends\friends_mmc\fac...,you know we don’t really take advantage of liv...,no we really don't take advanage liieg in the ...,0.393443,0.997531


In [34]:
# Character Error Rate (CER): Measures character-level errors
mean_values = df_results[["CER", "GPT-2 Cosine Similarity"]].mean()
display(mean_values)

CER                        0.537633
GPT-2 Cosine Similarity    0.880790
dtype: float64

In [16]:
# for index, row in random_frames.iterrows():
#     episode, t1, t2 = row["video"].split("-")
#     audio_file = Path(f"C:\\Users\\a_has\\Desktop\\Friends\\friends_mmc\\face_track_videos\\face_track_videos\\{episode}\\{t1}-{t2}\\0.wav")

#     waveform, sample_rate = torchaudio.load(audio_file)
#     if sample_rate != 16000:
#         transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
#         waveform = transform(waveform)

#     input_values = processor(waveform.numpy(), return_tensors="pt", sampling_rate=16000)["input_values"]
#     with torch.no_grad():
#         logits = model(input_values).logits
#     predicted_ids = torch.argmax(logits, dim=-1)
#     transcription = processor.batch_decode(predicted_ids)

#     print(transcription)