In [None]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import numpy as np
from dataclasses import dataclass
from typing import Optional, List
import pandas as pd
from scipy.signal import butter, filtfilt

In [None]:
import whisper
import pandas as pd
from datetime import timedelta
import numpy as np

### Preprocessing + prompts
def transcribe_audio(audio_path, speaker_hints=None, model_name="large-v2", device="cpu"):
    """Transcribe an audio file with Whisper and attach heuristic speaker labels.

    Whisper does not diarize, so each segment's speaker is guessed from, in order:
    an explicit marker from ``speaker_hints`` found in the text, infant-vocalization
    keywords, parent-phrase keywords, and finally continuity with the previous
    labelled segment. Leftover ``"[Unknown]"`` labels are smoothed by a majority
    vote over neighbouring segments.

    Args:
        audio_path: Audio file path passed to ``model.transcribe``.
        speaker_hints: Marker strings (e.g. ``"[Mother]"``) searched for
            case-insensitively in each segment's text; the first one found is used
            as the label. Defaults to Mother/Father/Baby markers.
        model_name: Whisper checkpoint name passed to ``whisper.load_model``.
        device: Device string passed to ``whisper.load_model``.

    Returns:
        ``(df, result)``. ``df`` has one row per segment lasting at least 0.3 s,
        with columns start, end, start_time, end_time, speaker, text, duration,
        words_per_second, confidence; it is an empty frame with those columns
        when no segment survives. ``result`` is the raw dict returned by
        ``model.transcribe``.
    """
    import re  # local: this cell runs before the later cell that imports `re`

    print("Loading model...")
    model = whisper.load_model(model_name, device=device)

    # NOTE(review): Whisper uses initial_prompt as preceding-context text for the
    # first window, not as an instruction, so it is unlikely to make the model
    # emit "[Speaker]" labels by itself — confirm on real output.
    initial_prompt = """This is a parent-infant interaction transcript. 
    Speakers include Mother, Father, and Baby. 
    Format: [Speaker] followed by speech content."""

    print("Transcribing...")
    result = model.transcribe(
        audio_path,
        word_timestamps=True,
        verbose=True,
        initial_prompt=initial_prompt,
        language="en",
        task="transcribe"
    )

    if speaker_hints is None:
        speaker_hints = [
            "[Mother]", "[Father]", "[Baby]",
            "[Infant Vocalization]", "[Baby Crying]", "[Baby Laughing]"
        ]

    baby_patterns = ['goo', 'gah', 'bah', 'mama', 'dada', 'babbling', 'crying', 'laughing']
    parent_patterns = [
        'good job', 'look at', 'here we go', 'that\'s right',
        'can you', 'let\'s', 'sweetie', 'honey', 'baby'
    ]
    # Match keywords as whole words/phrases. Plain substring tests let 'goo'
    # fire inside "good", so "good job" was labelled [Baby] and the parent
    # pattern 'good job' could never match. The trailing '+' on the baby regex
    # keeps run-together babble ("googoo", "mamamama") matching.
    baby_re = re.compile(r'\b(?:' + '|'.join(map(re.escape, baby_patterns)) + r')+\b')
    parent_re = re.compile(r'\b(?:' + '|'.join(map(re.escape, parent_patterns)) + r')\b')

    def detect_speaker(text, previous_speaker=None):
        """Return a speaker label for one segment's text (see outer docstring)."""
        lowered = text.lower()

        # Explicit speaker markers win.
        for hint in speaker_hints:
            if hint.lower() in lowered:
                return hint

        if baby_re.search(lowered):
            return '[Baby]'

        if parent_re.search(lowered):
            # Keep the current parent if one is active; otherwise default to Mother.
            return previous_speaker if previous_speaker in ['[Mother]', '[Father]'] else '[Mother]'

        # Otherwise continue the previous speaker (short or long utterance alike).
        return previous_speaker or "[Unknown]"

    columns = [
        'start', 'end', 'start_time', 'end_time', 'speaker', 'text',
        'duration', 'words_per_second', 'confidence',
    ]
    segments = []
    previous_speaker = None
    min_segment_duration = 0.3  # seconds; shorter segments are treated as noise

    for segment in result["segments"]:
        text = segment["text"].strip()
        start = segment["start"]
        end = segment["end"]
        duration = end - start

        if duration < min_segment_duration:
            continue

        detected_speaker = detect_speaker(text, previous_speaker)
        if detected_speaker != "[Unknown]":
            previous_speaker = detected_speaker

        segments.append({
            'start': start,
            'end': end,
            'start_time': str(timedelta(seconds=int(start))),
            'end_time': str(timedelta(seconds=int(end))),
            'speaker': detected_speaker,
            'text': text,
            'duration': duration,
            'words_per_second': len(text.split()) / duration if duration > 0 else 0,
            # Computed before smoothing, so smoothed rows stay 'low'.
            'confidence': 'high' if detected_speaker != "[Unknown]" else 'low',
        })

    # Explicit columns so an empty result still has 'speaker', 'text', etc.
    df = pd.DataFrame(segments, columns=columns)

    # Each [Unknown] row takes the most frequent label among up to
    # 2*window_size + 1 rows centred on it, if that label is not [Unknown].
    # Rows are updated in place, so earlier fills feed later windows.
    # Once any label is assigned, detect_speaker never returns [Unknown] again,
    # so in practice only leading rows are affected.
    window_size = 3
    for i in range(len(df)):
        if df.at[i, 'speaker'] == "[Unknown]":
            window = df.iloc[max(0, i - window_size):min(len(df), i + window_size + 1)]
            speaker_counts = window['speaker'].value_counts()
            if len(speaker_counts) > 0 and speaker_counts.index[0] != "[Unknown]":
                df.at[i, 'speaker'] = speaker_counts.index[0]

    return df, result

# Example usage
if __name__ == "__main__":
    import os
    import traceback
    from pathlib import Path

    # Build every path from one overridable project root and a session id
    # instead of hard-coding absolute paths inline. Set WEARABLE_PROJECT_DIR
    # to run this on another machine.
    PROJECT_DIR = Path(os.environ.get("WEARABLE_PROJECT_DIR", "/Users/yueyan/Documents/project/wearable"))
    SESSION_ID = "025_04"
    AUDIO_PATH = PROJECT_DIR / "media" / SESSION_ID / f"IW_{SESSION_ID}_YT.wav"
    WHISPER_CSV = PROJECT_DIR / "transcription" / f"{SESSION_ID}_whisper_results_with_speaker.csv"

    speaker_hints = [
        "[Mother]", "[Father]", "[Baby]",
        "[Infant Vocalization]", "[Baby Crying]", "[Baby Laughing]"
    ]

    try:
        df, result = transcribe_audio(str(AUDIO_PATH), speaker_hints=speaker_hints)

        # Save results (create the folder on a fresh checkout).
        WHISPER_CSV.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(WHISPER_CSV, index=False)

        # Print statistics
        print("\nTranscription Statistics:")
        print(f"Total segments: {len(df)}")
        print("\nSpeaker distribution:")
        print(df['speaker'].value_counts())

        print("\nFirst few transcriptions:")
        print(df[['start_time', 'end_time', 'speaker', 'text']].head())

    except Exception as e:
        print(f"Error occurred: {str(e)}")
        traceback.print_exc()
        # Re-raise so "Run All" stops here. Later cells read WHISPER_CSV and
        # would otherwise silently compare against a stale or missing file.
        raise

In [None]:
import pandas as pd
import numpy as np
from difflib import SequenceMatcher
import matplotlib.pyplot as plt
from tqdm import tqdm
import re
import string


#### Reliability checks
class TranscriptionComparator:
    """Compare a Whisper transcript against a human reference transcript.

    Both inputs are DataFrames with ``start`` and ``end`` times (seconds) and a
    ``text`` column. Text is normalized with :meth:`clean_text`. A human segment
    and a Whisper segment are paired when their time spans overlap, within a
    tolerance, and the ``difflib.SequenceMatcher`` ratio of their cleaned texts
    exceeds a threshold.
    """

    # Column layout of the frame returned by find_overlapping_segments. It is
    # declared up front so an empty result still has these columns.
    _MATCH_COLUMNS = [
        'start_human', 'end_human', 'text_human_original', 'text_human_cleaned',
        'start_whisper', 'end_whisper', 'text_whisper_original', 'text_whisper_cleaned',
        'text_similarity', 'time_diff',
    ]

    def __init__(self, whisper_df: pd.DataFrame, human_df: pd.DataFrame):
        """Copy and normalize both transcripts and compute their common time range.

        Args:
            whisper_df: Whisper segments with ``start``, ``end`` and ``text``.
            human_df: Human segments with ``start``, ``end`` and ``text``.
                Surrounding whitespace in column names (e.g. ``'end '``) is
                stripped. Rows with missing text are dropped.
        """
        self.whisper_data = whisper_df.copy()

        # The human CSV has stray spaces in its headers ('end '); strip them all.
        human = human_df.rename(columns=lambda c: c.strip() if isinstance(c, str) else c)
        self.human_data = human[human['text'].notna()].copy()

        # Ensure time values are numeric for interval arithmetic.
        self.human_data['start'] = pd.to_numeric(self.human_data['start'])
        self.human_data['end'] = pd.to_numeric(self.human_data['end'])

        # Keep the raw text next to the normalized text for side-by-side review.
        for data in (self.whisper_data, self.human_data):
            data['clean_text'] = data['text'].apply(self.clean_text)
            data['original_text'] = data['text']

        # Common time range: latest first start to earliest last end.
        self.overlap_start = max(
            self.whisper_data['start'].min(),
            self.human_data['start'].min()
        )
        self.overlap_end = min(
            self.whisper_data['end'].max(),
            self.human_data['end'].max()
        )

        print(f"Data loaded and cleaned. Time range: {self.overlap_start:.2f}s - {self.overlap_end:.2f}s")

    @staticmethod
    def clean_text(text: str) -> str:
        """Normalize a transcript string for comparison.

        - Missing values (NaN/None) become ``""``.
        - Text is lowercased.
        - All punctuation and symbols are removed. This includes Unicode ones
          such as curly quotes and ``♪``, not just ASCII ``string.punctuation``.
        - Whitespace is collapsed to single spaces.
        - A few spelling variants of fillers/affirmations are mapped to one form.
        """
        if pd.isna(text):
            return ""

        text = str(text).lower()

        # Drop every character that is neither a word character nor whitespace,
        # plus '_' (which \w keeps). With str patterns \w is Unicode-aware, so
        # "that’s" and "that's" both become "thats".
        text = re.sub(r"[^\w\s]|_", "", text)

        text = " ".join(text.split())

        # Map spelling variants to a standard form (whole words only).
        replacements = {
            'uhm': 'um',
            'uhh': 'uh',
            'hmm': 'hm',
            'mhm': 'mm',
            'yeah': 'yes',
            'yah': 'yes',
            'nah': 'no'
        }
        for old, new in replacements.items():
            text = re.sub(r'\b' + old + r'\b', new, text)

        return text

    def show_text_cleaning_examples(self, n_examples: int = 5, random_state=None):
        """Print random original/cleaned text pairs from both transcripts.

        Args:
            n_examples: Rows to show per transcript. Capped at the number of
                available rows, because ``DataFrame.sample`` raises otherwise.
            random_state: Optional seed forwarded to ``DataFrame.sample`` for
                reproducible output.
        """
        print("\nText Cleaning Examples:")
        print("-" * 50)

        for label, data in (("Whisper", self.whisper_data), ("Human", self.human_data)):
            print(f"\n{label} Transcription Examples:")
            n = max(0, min(n_examples, len(data)))
            samples = data.sample(n=n, random_state=random_state)
            for _, row in samples.iterrows():
                print(f"Original: {row['original_text']}")
                print(f"Cleaned : {row['clean_text']}")
                print("-" * 30)

    def find_overlapping_segments(self, time_tolerance: float = 0.5,
                                  similarity_threshold: float = 0.3):
        """Pair human and Whisper segments that overlap in time and text.

        Args:
            time_tolerance: Seconds by which each human segment's span is widened
                on both sides before testing for overlap.
            similarity_threshold: A pair is kept only when its
                ``SequenceMatcher`` ratio on cleaned text is strictly greater.

        Returns:
            DataFrame with one row per kept pair (columns ``_MATCH_COLUMNS``).
            It is empty, but still has those columns, when nothing matches.
            One human segment may pair with several Whisper segments and vice versa.
        """
        overlapping = []

        # Only segments lying entirely inside the common time range are compared.
        human_segments = self.human_data[
            (self.human_data['start'] >= self.overlap_start) &
            (self.human_data['end'] <= self.overlap_end)
        ]
        whisper_segments = self.whisper_data[
            (self.whisper_data['start'] >= self.overlap_start) &
            (self.whisper_data['end'] <= self.overlap_end)
        ]

        print(f"Processing {len(human_segments)} segments...")

        for _, human_seg in tqdm(human_segments.iterrows(), total=len(human_segments)):
            # A Whisper segment is a candidate when its span intersects the human
            # span widened by time_tolerance. Testing only the Whisper *start*
            # missed long Whisper segments that began earlier yet covered the
            # human utterance.
            window_start = human_seg['start'] - time_tolerance
            window_end = human_seg['end'] + time_tolerance
            potential_matches = whisper_segments[
                (whisper_segments['start'] <= window_end) &
                (whisper_segments['end'] >= window_start)
            ]

            for _, whisper_seg in potential_matches.iterrows():
                # ratio = 2*M / (T1 + T2), where M is the total size of matching
                # blocks and T1, T2 are the lengths of the two strings.
                # 0 means nothing in common; 1 means identical.
                similarity = SequenceMatcher(
                    None,
                    human_seg['clean_text'],
                    whisper_seg['clean_text']
                ).ratio()

                if similarity > similarity_threshold:
                    overlapping.append({
                        'start_human': human_seg['start'],
                        'end_human': human_seg['end'],
                        'text_human_original': human_seg['original_text'],
                        'text_human_cleaned': human_seg['clean_text'],
                        'start_whisper': whisper_seg['start'],
                        'end_whisper': whisper_seg['end'],
                        'text_whisper_original': whisper_seg['original_text'],
                        'text_whisper_cleaned': whisper_seg['clean_text'],
                        'text_similarity': similarity,
                        'time_diff': abs(human_seg['start'] - whisper_seg['start'])
                    })

        return pd.DataFrame(overlapping, columns=self._MATCH_COLUMNS)

    def analyze_overlap(self, time_tolerance: float = 0.5,
                        similarity_threshold: float = 0.3,
                        high_similarity_threshold: float = 0.7):
        """Run find_overlapping_segments and summarize the result.

        Args:
            time_tolerance: Forwarded to ``find_overlapping_segments``.
            similarity_threshold: Forwarded to ``find_overlapping_segments``.
            high_similarity_threshold: Pairs with a similarity strictly above
                this are counted in ``high_similarity_matches``.

        Returns:
            ``(overlapping, stats)``, where ``stats`` holds durations in seconds,
            the number of matched pairs, the mean similarity (0 when there are
            no matches) and the count of high-similarity pairs.
        """
        overlapping = self.find_overlapping_segments(time_tolerance, similarity_threshold)
        similarity = overlapping['text_similarity']

        stats = {
            'overlap_duration': self.overlap_end - self.overlap_start,
            'total_duration_whisper': self.whisper_data['end'].max() - self.whisper_data['start'].min(),
            'total_duration_human': self.human_data['end'].max() - self.human_data['start'].min(),
            'matching_segments': len(overlapping),
            'average_similarity': similarity.mean() if len(overlapping) > 0 else 0,
            # Works on an empty result too (previously raised KeyError).
            'high_similarity_matches': int((similarity > high_similarity_threshold).sum())
        }

        return overlapping, stats




In [None]:
if __name__ == "__main__":
    import os
    import time
    from pathlib import Path

    # perf_counter is monotonic, so it is the right clock for elapsed time.
    start_time = time.perf_counter()

    # Build input paths from one overridable project root and a session id
    # instead of hard-coding absolute paths. Set WEARABLE_PROJECT_DIR to run
    # this on another machine.
    PROJECT_DIR = Path(os.environ.get("WEARABLE_PROJECT_DIR", "/Users/yueyan/Documents/project/wearable"))
    SESSION_ID = "025_04"
    TRANSCRIPTION_DIR = PROJECT_DIR / "transcription"

    # Load the Whisper output and the human reference transcript.
    whisper_df = pd.read_csv(TRANSCRIPTION_DIR / f"{SESSION_ID}_whisper_results_with_speaker.csv")
    human_df = pd.read_csv(TRANSCRIPTION_DIR / f"{SESSION_ID}_human_transcription.csv")

    comparator = TranscriptionComparator(whisper_df, human_df)

    # Spot-check the text normalization before trusting the similarity scores.
    comparator.show_text_cleaning_examples()

    overlapping, stats = comparator.analyze_overlap()

    print("\nResults:")
    for key, value in stats.items():
        print(f"{key}: {value:.2f}")

    print(f"\nProcessing time: {time.perf_counter() - start_time:.2f} seconds")

    # Save matched pairs with both original and cleaned text.
    overlapping.to_csv("overlap_comparison_with_text.csv", index=False)


# Inspect a few matched pairs side by side. This reuses `overlapping` from the
# run above in this cell. It no longer rebuilds the comparator and repeats the
# whole comparison, which also re-printed a second random set of cleaning
# examples.
print("\nExample Comparisons:")
if overlapping.empty:
    print("No segment pairs passed the similarity threshold.")
for _, row in overlapping.head().iterrows():
    print("\nHuman    :", row['text_human_original'])
    print("Cleaned  :", row['text_human_cleaned'])
    print("Whisper  :", row['text_whisper_original'])
    print("Cleaned  :", row['text_whisper_cleaned'])
    print(f"Similarity: {row['text_similarity']:.2f}")