In [None]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import numpy as np
from dataclasses import dataclass
from typing import Optional, List
import pandas as pd
from scipy.signal import butter, filtfilt

In [None]:
import whisper
import pandas as pd
from datetime import timedelta
import numpy as np

# Preprocessing + prompts: Whisper transcription with heuristic speaker labeling
# Label used when no heuristic can attribute a segment to a speaker.
_UNKNOWN_SPEAKER = "[Unknown]"

# Speaker markers searched for in the text when the caller supplies none.
_DEFAULT_SPEAKER_HINTS = (
    "[Mother]", "[Father]", "[Baby]",
    "[Infant Vocalization]", "[Baby Crying]", "[Baby Laughing]",
)

# Infant cues, matched as whole words so that e.g. "good" no longer triggers
# on "goo". A drawn-out final letter ("gooo", "gahh") is still accepted.
_BABY_PATTERN = r"\b(?:goo+|gah+|bah+|mama|dada|babbling|crying|laughing)\b"

# Phrases treated as parent speech (substring match on the lowercased text).
_PARENT_PHRASES = (
    'good job', 'look at', 'here we go', 'that\'s right',
    'can you', 'let\'s', 'sweetie', 'honey', 'baby',
)

# Column order of the returned DataFrame; also guarantees the columns exist
# when no segment survives filtering.
_SEGMENT_COLUMNS = [
    'start', 'end', 'start_time', 'end_time', 'speaker', 'text',
    'duration', 'words_per_second', 'confidence',
]


def _detect_speaker(text, speaker_hints, previous_speaker=None):
    """Guess the speaker label for one segment from its text and context."""
    import re  # function-scope import, as done elsewhere in this file

    lowered = text.lower()

    # Explicit markers win; empty hints are skipped because "" is a
    # substring of every text.
    for hint in speaker_hints:
        if hint and hint.lower() in lowered:
            return hint

    if re.search(_BABY_PATTERN, lowered):
        return '[Baby]'

    if any(phrase in lowered for phrase in _PARENT_PHRASES):
        # Keep the current parent if there is one, otherwise default to Mother.
        if previous_speaker in ('[Mother]', '[Father]'):
            return previous_speaker
        return '[Mother]'

    # No cue at all: assume the previous speaker is still talking.
    return previous_speaker if previous_speaker else _UNKNOWN_SPEAKER


def _fill_unknown_speakers(labels, window_size=3):
    """Replace unknown labels with the most common known label nearby."""
    from collections import Counter

    filled = list(labels)
    for i, label in enumerate(filled):
        if label != _UNKNOWN_SPEAKER:
            continue
        lo = max(0, i - window_size)
        neighbours = filled[lo:i + window_size + 1]
        # Only known labels vote; otherwise a run of unknowns out-votes
        # its neighbours and is never resolved.
        known = [name for name in neighbours if name != _UNKNOWN_SPEAKER]
        if known:
            # Ties go to the label encountered first in the window.
            filled[i] = Counter(known).most_common(1)[0][0]
    return filled


def transcribe_audio(audio_path, speaker_hints=None, *, model=None,
                     model_name="large-v2", device="cpu"):
    """Transcribe audio using Whisper and attach heuristic speaker labels.

    Args:
        audio_path: Audio input forwarded unchanged to ``model.transcribe``.
        speaker_hints: Iterable of marker strings (e.g. ``"[Mother]"``) that
            identify a speaker when found in the text, case-insensitively.
            Defaults to the Mother/Father/Baby marker set.
        model: Optional pre-loaded model exposing ``transcribe``. Pass one to
            avoid reloading the weights on every call.
        model_name: Whisper model to load when ``model`` is not given.
        device: Device to load the model on when ``model`` is not given.

    Returns:
        ``(df, result)``: ``df`` is a DataFrame with one row per kept segment
        and the columns in ``_SEGMENT_COLUMNS`` (present even when empty);
        ``result`` is the raw value returned by ``model.transcribe``.
    """
    if model is None:
        print("Loading model...")
        model = whisper.load_model(model_name, device=device)

    # Prompt that nudges the model towards speaker-labelled output. Spelled
    # out piecewise so its trailing spaces and indentation stay visible; the
    # resulting string is unchanged.
    initial_prompt = (
        "This is a parent-infant interaction transcript. \n"
        "    Speakers include Mother, Father, and Baby. \n"
        "    Format: [Speaker] followed by speech content."
    )

    print("Transcribing...")
    result = model.transcribe(
        audio_path,
        word_timestamps=True,
        verbose=True,
        initial_prompt=initial_prompt,
        language="en",
        task="transcribe"
    )

    hints = _DEFAULT_SPEAKER_HINTS if speaker_hints is None else speaker_hints

    min_segment_duration = 0.3  # Shorter segments are dropped as likely noise
    segments = []
    previous_speaker = None

    for segment in result["segments"]:
        text = segment["text"].strip()
        start = segment["start"]
        end = segment["end"]
        duration = end - start

        if duration < min_segment_duration:
            continue

        detected_speaker = _detect_speaker(text, hints, previous_speaker)

        # Only a resolved label becomes context for the following segments.
        if detected_speaker != _UNKNOWN_SPEAKER:
            previous_speaker = detected_speaker

        segments.append({
            'start': start,
            'end': end,
            'start_time': str(timedelta(seconds=int(start))),
            'end_time': str(timedelta(seconds=int(end))),
            'speaker': detected_speaker,
            'text': text,
            'duration': duration,
            'words_per_second': len(text.split()) / duration if duration > 0 else 0,
            # Set before smoothing, so labels filled in below stay 'low'.
            'confidence': 'high' if detected_speaker != _UNKNOWN_SPEAKER else 'low',
        })

    # Post-process: resolve leftover unknowns from the surrounding segments.
    smoothed = _fill_unknown_speakers([row['speaker'] for row in segments])
    for row, label in zip(segments, smoothed):
        row['speaker'] = label

    df = pd.DataFrame(segments, columns=_SEGMENT_COLUMNS)
    return df, result

# Example usage
if __name__ == "__main__":
    # Speaker markers to look for in the transcribed text.
    speaker_hints = [
        "[Mother]",
        "[Father]",
        "[Baby]",
        "[Infant Vocalization]",
        "[Baby Crying]",
        "[Baby Laughing]",
    ]

    # Recording to transcribe and the CSV the labelled segments are written to.
    input_wav = "/Users/yueyan/Documents/project/wearable/media/025_04/IW_025_04_YT.wav"
    output_csv = "/Users/yueyan/Documents/project/wearable/transcription/025_04_whisper_results_with_speaker.csv"

    try:
        df, result = transcribe_audio(input_wav, speaker_hints=speaker_hints)

        # Persist the per-segment table without the index column.
        df.to_csv(output_csv, index=False)

        # Summary: segment count and how many segments each speaker got.
        print("\nTranscription Statistics:")
        print(f"Total segments: {len(df)}")
        print("\nSpeaker distribution:")
        speaker_totals = df['speaker'].value_counts()
        print(speaker_totals)

        # Preview of the first rows.
        print("\nFirst few transcriptions:")
        preview = df[['start_time', 'end_time', 'speaker', 'text']].head()
        print(preview)

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback
        # instead of letting the exception propagate.
        print(f"Error occurred: {e!s}")
        import traceback
        traceback.print_exc()