In [46]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperModel, WhisperModel, WhisperForCausalLM
import numpy as np
import librosa

class WhisperNextWordPredictor:
    def __init__(self, model_name="openai/whisper-base"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = WhisperProcessor.from_pretrained(model_name)
        self.model =  WhisperForConditionalGeneration.from_pretrained(model_name).to(self.device)
        
    def get_next_word_logits(self, audio_path, timestamp_words):
        """
        Extract next-word prediction logits at each timestamp in the transcript.
        
        Args:
            audio_path: Path to audio file
            timestamp_words: List of (timestamp, word) tuples
        
        Returns:
            List of (word, timestamp, next_word_logits) tuples
        """
        # Load and process audio
        audio_features = self.processor.feature_extractor(
            self._load_audio(audio_path), 
            return_tensors="pt"
        ).input_features.to(self.device)
        
        # Calculate frames per second for the model
        # Whisper uses a 25ms frame stride
        frames_per_second = 1 / 0.025  # 40 frames per second
        
        results = []
        for i, (timestamp, word) in enumerate(timestamp_words):
            # Create attention mask for audio up to current timestamp
            # Convert timestamp to frame index
            current_frame = int(timestamp * frames_per_second)
            attention_mask = torch.ones_like(audio_features)
            attention_mask[:, :, current_frame:] = 0  # Mask future frames
            
            # Get encoder hidden states with masked audio
            encoder_outputs = self.model.get_encoder()(
                audio_features,
                attention_mask=attention_mask
            )

            print (attention_mask.shape)
            
            # Convert previous words to input ids
            previous_words = [w for _, w in timestamp_words[:i+1]]

            tokens = self.processor.tokenizer(
                " ".join(previous_words),
                return_tensors="pt",
                device=self.device
            ) #.input_ids.to(self.device)
            

            # Get decoder outputs
            decoder_outputs = self.model.get_decoder()(
                input_ids=tokens['input_ids'],
                encoder_hidden_states=encoder_outputs[0],
                attention_mask=tokens['attention_mask']  # Remove channel dimension
            )
            
            # Get logits for next token prediction
            next_token_logits = self.model.proj_out(decoder_outputs[0][:, -1, :])
            
            results.append({
                'word': word,
                'timestamp': timestamp,
                'next_word_logits': next_token_logits.detach().cpu().numpy()
            })
            
        return results
    
    def get_top_k_predictions(self, logits, k=5):
        """
        Get top k predicted next words from logits
        
        Args:
            logits: Logits array from get_next_word_logits
            k: Number of top predictions to return
            
        Returns:
            List of (word, probability) tuples
        """
        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, k)
        
        predictions = []
        for prob, idx in zip(top_k_probs[0], top_k_indices[0]):
            word = self.processor.tokenizer.decode([idx])
            predictions.append((word, prob.item()))
            
        return predictions
    
    def _load_audio(self, audio_path):
        """
        Load and preprocess audio file to match Whisper's expected input format.
        
        Args:
            audio_path: Path to the audio file
            
        Returns:
            numpy.ndarray: Audio waveform resampled to 16kHz
        """
        # Load audio file
        # Whisper expects 16kHz mono audio
        try:
            waveform, sample_rate = librosa.load(
                audio_path,
                sr=16000,  # Whisper expects 16kHz
                mono=True
            )
        except Exception as e:
            raise Exception(f"Error loading audio file: {str(e)}")
            
        # Check if audio is too short
        if len(waveform) == 0:
            raise ValueError("Audio file is empty")
            
        # Normalize audio to float32 range [-1, 1]
        if not np.isfinite(waveform).all():
            raise ValueError("Audio file contains invalid values (inf or nan)")
            
        waveform = librosa.util.normalize(waveform)
        
        return waveform

# Complete example usage:
"""
predictor = WhisperNextWordPredictor()

# Example timestamp_words list
timestamp_words = [
    (0.0, "hello"),
    (0.5, "how"),
    (1.0, "are"),
    (1.5, "you")
]

# Get logits for each word
results = predictor.get_next_word_logits("path/to/audio.wav", timestamp_words)

# Get top 5 predictions for each word
for result in results:
    print(f"\nPredictions after '{result['word']}' at {result['timestamp']}s:")
    predictions = predictor.get_top_k_predictions(result['next_word_logits'])
    for word, prob in predictions:
        print(f"{word}: {prob:.3f}")
"""

'\npredictor = WhisperNextWordPredictor()\n\n# Example timestamp_words list\ntimestamp_words = [\n    (0.0, "hello"),\n    (0.5, "how"),\n    (1.0, "are"),\n    (1.5, "you")\n]\n\n# Get logits for each word\nresults = predictor.get_next_word_logits("path/to/audio.wav", timestamp_words)\n\n# Get top 5 predictions for each word\nfor result in results:\n    print(f"\nPredictions after \'{result[\'word\']}\' at {result[\'timestamp\']}s:")\n    predictions = predictor.get_top_k_predictions(result[\'next_word_logits\'])\n    for word, prob in predictions:\n        print(f"{word}: {prob:.3f}")\n'

In [18]:
import os, sys
import glob
import pandas as pd
import numpy as np

sys.path.append('../utils/')

from config import *
from tommy_utils import nlp

In [40]:
task = 'wheretheressmoke'

# load the preprocessed file --> this has next-word-candidates selected
stim_preprocessed_fn = os.path.join(BASE_DIR, 'stimuli/preprocessed', task, f'{task}_transcript-preprocessed.csv')
df_preproc = pd.read_csv(stim_preprocessed_fn)

# remap for our functions
df_preproc = df_preproc.rename(columns={'Word_Written': 'word', 'Punctuation': 'punctuation'})

# create a list of indices that we will iterate through to sample the transcript
segments = nlp.get_segment_indices(n_words=len(df_preproc), window_size=25)[:-1]

In [47]:
audio_fn = os.path.join(BASE_DIR, f'stimuli/audio/{task}.wav')

timestamp_words = [tuple(df_preproc.loc[i, ['Onset', 'word']]) for i in range(10)]

predictor = WhisperNextWordPredictor()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [48]:
results = predictor.get_next_word_logits(audio_fn, timestamp_words)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.
Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])
torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])


Keyword arguments {'device': 'cpu'} not recognized.
Keyword arguments {'device': 'cpu'} not recognized.


torch.Size([1, 80, 3000])
torch.Size([1, 80, 3000])


In [49]:
for result in results:
    print(f"\nPredictions after '{result['word']}' at {result['timestamp']}s:")
    predictions = predictor.get_top_k_predictions(result['next_word_logits'])
    for word, prob in predictions:
        print(f"{word}: {prob:.3f}")


Predictions after 'I' at 0.0124716553288s:
<|endoftext|>: 0.518
.: 0.271
!: 0.023
?: 0.017
 (: 0.017

Predictions after 'reached' at 0.1277805392196221s:
<|endoftext|>: 0.422
.: 0.124
?: 0.043
,: 0.025
 (: 0.017

Predictions after 'over' at 0.4938471714535156s:
<|endoftext|>: 0.374
?: 0.141
.: 0.063
': 0.015
;: 0.014

Predictions after 'and' at 1.53900226757s:
<|transcribe|>: 0.838
<|translate|>: 0.065
<|endoftext|>: 0.054
 I: 0.018
 and: 0.005

Predictions after 'secretly' at 1.6649148365141804s:
 my: 0.061
<|endoftext|>: 0.043
?: 0.037
.: 0.033
': 0.022

Predictions after 'undid' at 2.41700680272s:
 my: 0.208
<|endoftext|>: 0.062
 and: 0.048
 I: 0.040
 of: 0.037

Predictions after 'my' at 2.9010308719513387s:
 and: 0.127
<|endoftext|>: 0.084
 —: 0.036
 I: 0.029
 of: 0.028

Predictions after 'seatbelt' at 3.091175727917867s:
 and: 0.271
<|endoftext|>: 0.128
 when: 0.042
 And: 0.038
—: 0.033

Predictions after 'and' at 4.53740830352817s:
 when: 0.197
 and: 0.162
<|endoftext|>: 0.081
 

In [15]:

# # Example timestamp_words list
# timestamp_words = [
#     (0.0, "hello"),
#     (0.5, "how"),
#     (1.0, "are"),
#     (1.5, "you")
# ]

# # Get logits for each word
# results = predictor.get_next_word_logits("path/to/audio.wav", timestamp_words)




preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

In [10]:

# # we don't need to get the last word
# for i, segment in enumerate(segments):

#     ground_truth_index = segment[-1] + 1
#     ground_truth_word = df_preproc.loc[ground_truth_index, 'word']
    
#     # also keep track of the current ground truth word
#     inputs = nlp.transcript_to_input(df_preproc, segment, add_punctuation=True)		