In [1]:
import re
from collections import Counter
from dataclasses import dataclass
from typing import Optional, List

import librosa
import numpy as np
import pandas as pd
import torch
from scipy.signal import butter, filtfilt
from transformers import WhisperProcessor, WhisperForConditionalGeneration

"""
# Parent-Infant Interaction Audio Analysis System

## Overview
This system consists of two main components:
1. Audio Transcription Pipeline using Whisper ASR
2. Transcription Comparison and Validation Tool

## Component 1: Audio Transcription
### Features
- Uses OpenAI's Whisper ASR model (large-v2)
- Specialized for parent-infant interactions
- Includes optional keyword-based speaker labeling
- Handles both verbal and non-verbal vocalizations

### Key Functions
1. `transcribe_audio(audio_path, speaker_hints)`
   - Input: Audio file path
   - Output: DataFrame with transcribed segments
   - Features:
     * Segment-level start/end timestamps (Whisper itself is run with word timestamps enabled)
     * Duration tracking
     * Optional speaker detection
     * Parent-infant interaction specific prompting

2. `detect_speaker(text, speaker_hints)`
   - Categorizes speech segments into:
     * [Baby Crying]
     * [Baby Laughing]
     * [Infant Vocalization]
     * [Mother]
     * [Father]
     * [Unspecified]

## Component 2: Transcription Comparison
### Features
- Compares Whisper ASR output with human transcriptions
- Provides multiple accuracy metrics
- Generates detailed segment-by-segment analysis
- Exports comparison results in spreadsheet format

### Key Metrics
1. Text Similarity
   - SequenceMatcher score (Ratcliff/Obershelp algorithm)
   - Formula: ratio = 2.0 * M / T
     * M = sum of lengths of matched subsequences
     * T = total length of both strings combined

2. Word-level Analysis
   - Word count comparison
   - Correct words identification
   - Mismatch detection
   - Accuracy percentage calculation

### Output Format
CSV file with columns:
- Time Range
- Whisper Transcription
- Human Transcription
- SequenceMatcher Score
- Word Count (Human)
- Word Count (Whisper)
- Correct Words
- Mismatches
- Accuracy (%)
- Comments/Notes

## Usage Workflow
1. Transcribe Audio:
```python
df, result = transcribe_audio(
    "path/to/audio.wav",
    speaker_hints=["[Mother]", "[Father]", "[Baby]", ...]
)
```

2. Compare with Human Transcription:
```python
results_df, stats = compare_and_export(
    whisper_df,
    human_df,
    "comparison_results.csv"
)
```

## Technical Notes
- Time tolerance: 0.5 seconds for segment matching
- Similarity threshold: 0.3 for considering matches
- Text cleaning includes:
  * Case normalization
  * Punctuation removal
  * Common transcription variant standardization
- High quality matches: Accuracy > 80%
- Low quality matches: Accuracy < 60%

## File Paths
- Audio Input: '/path/to/media/12mon/IW_ID_MONTH_TL.wav'
- Whisper Output: '/path/to/transcription/12mon/ID_MONTH_whisper_results_without_speaker.csv'
- Comparison Output: '/path/to/transcription/12mon/ID_MONTH_transcription_comparison.csv'
"""




In [2]:
!pip install numpy



In [7]:
import whisper
import pandas as pd
from datetime import timedelta
import numpy as np

# Segments at least this long (after whitespace/case normalization) that
# reproduce part of the initial prompt verbatim are treated as Whisper echoing
# the prompt rather than as speech from the recording.
_MIN_PROMPT_ECHO_CHARS = 20


def _is_prompt_echo(text, normalized_prompt):
    """Return True when `text` is a long verbatim excerpt of the prompt."""
    normalized_text = " ".join(text.split()).lower()
    return (
        len(normalized_text) >= _MIN_PROMPT_ECHO_CHARS
        and normalized_text in normalized_prompt
    )


def transcribe_audio(audio_path, speaker_hints=None, model_name="large-v2", device="cpu"):
    """Transcribe an audio file with Whisper and tabulate its segments.

    Args:
        audio_path: Path of the audio file handed to ``model.transcribe``.
        speaker_hints: When truthy, a ``speaker`` column is added by calling
            ``detect_speaker(text, speaker_hints)`` on every kept segment.
        model_name: Whisper checkpoint name passed to ``whisper.load_model``.
        device: Device string passed to ``whisper.load_model``.

    Returns:
        ``(df, result)`` where ``df`` has one row per kept segment with the
        columns ``start``, ``end``, ``start_time``, ``end_time``, ``text``,
        ``duration`` (plus ``speaker`` when ``speaker_hints`` is given) and
        ``result`` is the unmodified object returned by ``model.transcribe``.
        ``df`` keeps these columns even when no segment survives filtering.

    Notes:
        * ``start_time`` / ``end_time`` are truncated to whole seconds.
        * Segments that are empty, or that merely repeat the initial prompt
          back, are left out of ``df``; they remain available in ``result``.
    """
    print("Loading model...")
    model = whisper.load_model(model_name, device=device)

    # Keep the parent-infant context in the prompt but make it more transcription-focused
    initial_prompt = """This is a parent-infant interaction recording. 
    Please transcribe all speech, including infant vocalizations, parent speech, and any notable sounds.
    Include both verbal and non-verbal vocalizations."""
    # Normalized once so every segment can be checked against it cheaply.
    normalized_prompt = " ".join(initial_prompt.split()).lower()

    print("Transcribing...")
    result = model.transcribe(
        audio_path,
        word_timestamps=True,
        verbose=True,
        initial_prompt=initial_prompt,
        language="en",
        task="transcribe"
    )

    # Declared up front so an empty transcription still yields a usable frame.
    columns = ['start', 'end', 'start_time', 'end_time', 'text', 'duration']
    if speaker_hints:
        columns.append('speaker')

    segments = []
    for segment in result["segments"]:
        text = segment["text"].strip()

        # Skip empty segments and segments that only parrot the prompt
        if not text or _is_prompt_echo(text, normalized_prompt):
            continue

        start = segment["start"]
        end = segment["end"]
        segment_info = {
            'start': start,
            'end': end,
            'start_time': str(timedelta(seconds=int(start))),
            'end_time': str(timedelta(seconds=int(end))),
            'text': text,
            'duration': end - start
        }

        # Optional: add a speaker label without modifying the original text
        if speaker_hints:
            segment_info['speaker'] = detect_speaker(text, speaker_hints)

        segments.append(segment_info)

    df = pd.DataFrame(segments, columns=columns)

    return df, result

# (label, pattern) pairs checked in priority order against the lower-cased
# text. Patterns are anchored on word boundaries so that ordinary words that
# merely contain a keyword ("good", "cool", "moment", "slaughter") no longer
# trigger a label.
_SPEAKER_PATTERNS = (
    ('[Baby Crying]', re.compile(r"\b(?:crying|cries|wa{2,}h*)\b")),
    ('[Baby Laughing]', re.compile(r"\b(?:laugh|giggl)\w*")),
    ('[Infant Vocalization]', re.compile(r"\b(?:go{2,}|gah+|bah+|co{2,})\b")),
    ('[Mother]', re.compile(r"\[mother\]|\bmom(?:my|ma|s)?\b")),
    ('[Father]', re.compile(r"\[father\]|\bdad(?:dy|a|s)?\b")),
)


def detect_speaker(text, speaker_hints):
    """Label a transcribed segment from keywords found in its text.

    The text itself is never modified. Labels are tried in a fixed priority
    order (crying, laughing, infant vocalization, mother, father) and the
    first whose keyword appears as a whole word wins.

    Args:
        text: Transcribed segment text.
        speaker_hints: Accepted for interface compatibility; the keyword rules
            below do not consult it.

    Returns:
        One of '[Baby Crying]', '[Baby Laughing]', '[Infant Vocalization]',
        '[Mother]', '[Father]' or '[Unspecified]' when nothing matches.
    """
    text_lower = text.lower()

    for label, pattern in _SPEAKER_PATTERNS:
        if pattern.search(text_lower):
            return label
    return '[Unspecified]'

# Example usage: transcribe one recording, save the table, print a short report
if __name__ == "__main__":
    speaker_hints = [
        "[Mother]",
        "[Father]",
        "[Baby]",
        "[Infant Vocalization]",
        "[Baby Crying]",
        "[Baby Laughing]",
    ]

    try:
        df, result = transcribe_audio(
            "/Users/yueyan/Documents/project/wearable/media/12mon/IW_002_12_TL.wav",
            speaker_hints=speaker_hints,
        )

        # Persist the per-segment table before reporting on it
        df.to_csv(
            "/Users/yueyan/Documents/project/wearable/transcription/12mon/002_12_whisper_results_without_speaker.csv",
            index=False,
        )

        # Headline numbers
        print("\nTranscription Statistics:")
        print(f"Total segments: {df.shape[0]}")

        # The speaker column only exists when speaker labelling was requested
        if 'speaker' in df.columns:
            print("\nSpeaker distribution:")
            print(df['speaker'].value_counts())

        print("\nFirst few transcriptions:")
        print(df.loc[:, ['start_time', 'end_time', 'text']].head())

        # Same preview again, this time including the speaker label
        if 'speaker' in df.columns:
            print("\nFirst few transcriptions with speakers:")
            print(df.loc[:, ['start_time', 'end_time', 'speaker', 'text']].head())

    except Exception as exc:
        # Report the failure and its traceback instead of aborting the cell
        import traceback
        print(f"Error occurred: {exc}")
        traceback.print_exc()

Loading model...


  checkpoint = torch.load(fp, map_location=device)


Transcribing...




[01:27.180 --> 01:29.980]  Please transcribe all speech, including infant vocalizations, parent speech, and any notable sounds.
[01:53.640 --> 01:57.700]  Please transcribe all speech, including infant vocalizations, parent speech, and any notable sounds.
[02:05.740 --> 02:13.280]  Please transcribe all speech, including infant vocalizations, parent speech, and any notable sounds.
[02:30.560 --> 02:32.580]  What do you have, a ring?
[02:34.060 --> 02:35.080]  Squeak!
[02:38.000 --> 02:38.520]  Squeak!
[02:40.200 --> 02:42.220]  Uh-uh, a ring?
[02:44.060 --> 02:44.840]  Yeah.
[02:45.880 --> 02:46.460]  Ah!
[02:51.000 --> 02:51.100]  Ah!
[02:52.300 --> 02:53.720]  Yeah, it's a ring.
[02:54.660 --> 02:55.980]  Thank you.
[02:58.000 --> 02:58.620]  Thank you.
[03:00.520 --> 03:01.860]  Can I have that?
[03:02.280 --> 03:02.940]  Thank you.
[03:03.200 --> 03:04.180]  You want it back?
[03:04.740 --> 03:05.540]  It squeaks.
[03:07.040 --> 03:08.300]  You want the ring?
[03:09.680 --> 03:10.7

In [14]:
import pandas as pd
import numpy as np
from difflib import SequenceMatcher
from tqdm import tqdm
import string
import re
from typing import Tuple, Dict, List

class EnhancedTranscriptionComparator:
    """Compare Whisper segments with human-transcribed segments.

    Both inputs need ``start``, ``end`` and ``text`` columns (the human frame
    may spell its end column ``'end '``). Human segments are paired with
    Whisper segments that start close to them in time, and each pair whose
    cleaned texts are similar enough becomes one row of the result.
    """

    # Column order of the frame returned by generate_comparison_results().
    # Declared once so that an empty result still carries the full header.
    RESULT_COLUMNS = [
        'Time Range',
        'Whisper Transcription',
        'Human Transcription',
        'SequenceMatcher Score',
        'Word Count (Human)',
        'Word Count (Whisper)',
        'Correct Words',
        'Mismatches',
        'Accuracy (%)',
        'Comments/Notes',
    ]

    def __init__(self, whisper_df: pd.DataFrame, human_df: pd.DataFrame):
        """Copy and normalize both frames and compute their shared time range.

        Args:
            whisper_df: Whisper segments; not modified.
            human_df: Human segments; not modified. Rows with missing text are
                dropped and ``start``/``end`` are converted to numbers.
        """
        self.whisper_data = whisper_df.copy()

        # Drop human rows without text. The explicit copy makes the column
        # assignments below act on an independent frame rather than a slice.
        self.human_data = human_df[human_df['text'].notna()].copy()
        if 'end ' in self.human_data.columns:
            self.human_data = self.human_data.rename(columns={'end ': 'end'})

        # Convert times to numeric
        self.human_data['start'] = pd.to_numeric(self.human_data['start'])
        self.human_data['end'] = pd.to_numeric(self.human_data['end'])

        # Clean text and store originals
        self.whisper_data['clean_text'] = self.whisper_data['text'].apply(self.clean_text)
        self.human_data['clean_text'] = self.human_data['text'].apply(self.clean_text)
        self.whisper_data['original_text'] = self.whisper_data['text']
        self.human_data['original_text'] = self.human_data['text']

        # Time range covered by both transcriptions
        self.overlap_start = max(self.whisper_data['start'].min(), self.human_data['start'].min())
        self.overlap_end = min(self.whisper_data['end'].max(), self.human_data['end'].max())

    @staticmethod
    def clean_text(text: str) -> str:
        """Normalize text for comparison.

        Lower-cases, strips ASCII punctuation, collapses whitespace and maps a
        few filler/affirmation variants onto one spelling (e.g. 'yeah' ->
        'yes'). Missing values become the empty string.
        """
        if pd.isna(text):
            return ""

        text = str(text).lower()
        text = text.translate(str.maketrans("", "", string.punctuation))
        text = " ".join(text.split())

        replacements = {
            'uhm': 'um', 'uhh': 'uh', 'hmm': 'hm',
            'mhm': 'mm', 'yeah': 'yes', 'yah': 'yes',
            'nah': 'no'
        }

        # Whole-word replacement only, so e.g. 'nah' inside a longer word is kept
        for old, new in replacements.items():
            text = re.sub(r'\b' + old + r'\b', new, text)

        return text

    def get_word_metrics(self, human_text: str, whisper_text: str) -> Dict[str, int]:
        """Count words and word-level agreement between two transcriptions.

        Words are compared as multisets after ``clean_text``: a word repeated
        three times by the human and once by Whisper contributes three to the
        human count and one correct word. Word order is ignored.

        Returns:
            Dict with ``word_count_human``, ``word_count_whisper``,
            ``correct_words`` (size of the multiset intersection) and
            ``mismatches`` (the larger word count minus ``correct_words``).
        """
        human_counts = Counter(self.clean_text(human_text).split())
        whisper_counts = Counter(self.clean_text(whisper_text).split())

        # Counter & Counter keeps the minimum count of each shared word
        correct_words = sum((human_counts & whisper_counts).values())
        total_words_human = sum(human_counts.values())
        total_words_whisper = sum(whisper_counts.values())
        mismatches = max(total_words_human, total_words_whisper) - correct_words

        return {
            'word_count_human': total_words_human,
            'word_count_whisper': total_words_whisper,
            'correct_words': correct_words,
            'mismatches': mismatches
        }

    def generate_comparison_results(self, time_tolerance: float = 0.5, 
                                  similarity_threshold: float = 0.3) -> pd.DataFrame:
        """Build the segment-by-segment comparison table.

        Only segments lying fully inside the shared time range are used. For
        each human segment, every Whisper segment whose *start* falls within
        ``[human start - time_tolerance, human end + time_tolerance]`` is a
        candidate, and every candidate whose SequenceMatcher ratio on the
        cleaned texts exceeds ``similarity_threshold`` yields a row, so one
        human segment may appear on several rows or on none.

        Args:
            time_tolerance: Slack in seconds around the human segment.
            similarity_threshold: Minimum (exclusive) similarity to report.

        Returns:
            DataFrame with the columns in ``RESULT_COLUMNS``; empty but with
            those columns when nothing matches.
        """
        results = []

        human_segments = self.human_data[
            (self.human_data['start'] >= self.overlap_start) &
            (self.human_data['end'] <= self.overlap_end)
        ]

        whisper_segments = self.whisper_data[
            (self.whisper_data['start'] >= self.overlap_start) &
            (self.whisper_data['end'] <= self.overlap_end)
        ]

        # iterrows() has no length, so pass the total for a real progress bar
        for _, human_seg in tqdm(human_segments.iterrows(),
                                 total=len(human_segments),
                                 desc="Analyzing segments"):
            potential_matches = whisper_segments[
                (whisper_segments['start'] >= human_seg['start'] - time_tolerance) &
                (whisper_segments['start'] <= human_seg['end'] + time_tolerance)
            ]

            for _, whisper_seg in potential_matches.iterrows():
                similarity = SequenceMatcher(
                    None,
                    human_seg['clean_text'],
                    whisper_seg['clean_text']
                ).ratio()

                if similarity <= similarity_threshold:
                    continue

                # Word-level metrics (texts are re-cleaned inside)
                word_metrics = self.get_word_metrics(
                    human_seg['original_text'],
                    whisper_seg['original_text']
                )

                # Share of the human's words that Whisper also produced
                accuracy = (word_metrics['correct_words'] / word_metrics['word_count_human'] * 100) \
                    if word_metrics['word_count_human'] > 0 else 0

                # Flag rows a reviewer should look at
                comments = []
                if similarity < 0.5:
                    comments.append("Low similarity score")
                if abs(word_metrics['word_count_human'] - word_metrics['word_count_whisper']) > 3:
                    comments.append("Significant word count difference")
                if accuracy < 60:
                    comments.append("Low accuracy")

                results.append({
                    'Time Range': f"{human_seg['start']:.1f}-{human_seg['end']:.1f}",
                    'Whisper Transcription': whisper_seg['original_text'],
                    'Human Transcription': human_seg['original_text'],
                    'SequenceMatcher Score': round(similarity, 3),
                    'Word Count (Human)': word_metrics['word_count_human'],
                    'Word Count (Whisper)': word_metrics['word_count_whisper'],
                    'Correct Words': word_metrics['correct_words'],
                    'Mismatches': word_metrics['mismatches'],
                    'Accuracy (%)': round(accuracy, 1),
                    'Comments/Notes': "; ".join(comments) if comments else "OK"
                })

        return pd.DataFrame(results, columns=self.RESULT_COLUMNS)

def compare_and_export(whisper_file: pd.DataFrame, human_file: pd.DataFrame, 
                      output_path: str = "transcription_comparison.csv") -> Tuple[pd.DataFrame, Dict]:
    """Compare two transcriptions, write the comparison to CSV and summarize it.

    Args:
        whisper_file: DataFrame of Whisper segments (despite the name, a
            DataFrame, not a path).
        human_file: DataFrame of human-transcribed segments.
        output_path: Destination of the CSV written with ``to_csv``.

    Returns:
        ``(results_df, stats)``. ``stats`` has the keys 'Total Segments',
        'Average Accuracy', 'Average Similarity', 'High Quality Matches'
        (accuracy > 80) and 'Low Quality Matches' (accuracy < 60). When no
        segment pair matched, the counts are 0 and the averages are NaN.
    """
    comparator = EnhancedTranscriptionComparator(whisper_file, human_file)
    results_df = comparator.generate_comparison_results()

    if results_df.empty:
        # No matched pairs: the result may not even have the metric columns,
        # so report an explicit "nothing to average" summary instead of
        # indexing into it.
        stats = {
            'Total Segments': 0,
            'Average Accuracy': float('nan'),
            'Average Similarity': float('nan'),
            'High Quality Matches': 0,
            'Low Quality Matches': 0
        }
    else:
        accuracy = results_df['Accuracy (%)']
        stats = {
            'Total Segments': len(results_df),
            'Average Accuracy': accuracy.mean(),
            'Average Similarity': results_df['SequenceMatcher Score'].mean(),
            'High Quality Matches': int((accuracy > 80).sum()),
            'Low Quality Matches': int((accuracy < 60).sum())
        }

    # Export to CSV (written even when empty, so the run leaves a record)
    results_df.to_csv(output_path, index=False)

    # Print summary
    print("\nComparison Summary:")
    print(f"Total segments analyzed: {stats['Total Segments']}")
    print(f"Average accuracy: {stats['Average Accuracy']:.1f}%")
    print(f"Average similarity score: {stats['Average Similarity']:.3f}")
    print(f"High quality matches (>80%): {stats['High Quality Matches']}")
    print(f"Low quality matches (<60%): {stats['Low Quality Matches']}")

    return results_df, stats

# Example usage: run the comparison and preview the first matched rows
if __name__ == "__main__":
    # NOTE(review): whisper_df and human_df must already exist in the kernel;
    # no cell visible here creates them -- confirm where they are loaded.
    results_df, stats = compare_and_export(
        whisper_df,
        human_df,
        output_path="/Users/yueyan/Documents/project/wearable/transcription/12mon/002_12_transcription_comparison.csv",
    )

    # Show a handful of matched pairs with their accuracy
    print("\nSample comparison results:")
    print(
        results_df.loc[
            :,
            ['Time Range', 'Human Transcription', 'Whisper Transcription', 'Accuracy (%)'],
        ].head()
    )

Analyzing segments: 118it [00:00, 2660.46it/s]


Comparison Summary:
Total segments analyzed: 95
Average accuracy: 65.3%
Average similarity score: 0.756
High quality matches (>80%): 39
Low quality matches (<60%): 38

Sample comparison results:
    Time Range Human Transcription Whisper Transcription  Accuracy (%)
0  157.7-158.4             squeak.               Squeak!         100.0
1  164.1-165.3               yeah.                 Yeah.         100.0
2  172.5-174.0  yeah it's a rings.    Yeah, it's a ring.          75.0
3  177.9-178.8          thank you.            Thank you.         100.0
4  182.3-183.0          thank you.            Thank you.         100.0



