# Reiss Lab Data Analysis Pipeline

## 1. Setup and Imports

Import the necessary libraries for data manipulation, analysis, and visualization.

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Configure plotting style
sns.set_theme(style="whitegrid")

## 2. Data Loading and Preprocessing

Load the raw data files for the vowel, consonant, and CRM experiments. We will also perform initial preprocessing, such as adding column names and mapping identifiers to human-readable labels.

In [2]:
# Prompt user to input the directory path containing the data files
base_path = input("Enter the directory path containing your data files: ").strip()

# Validate that the path exists
if not os.path.isdir(base_path):
    raise FileNotFoundError(f"Directory not found: {base_path}")

# Extract subject ID from the directory name (assumes format like 'CI148')
subject_id = os.path.basename(base_path)
print(f"Processing data for subject: {subject_id}")
print(f"Data directory: {base_path}")

# --- File Paths ---
vowel_bm_path = os.path.join(base_path, f'{subject_id}_vow9_BM_0.txt')
vowel_ci_path = os.path.join(base_path, f'{subject_id}_vow9_CI_0.txt')
consonant_path = os.path.join(base_path, f'{subject_id}_cons_BM_n_0.out')

# For CRM, we'll start by loading one file and create a list for scalability
crm_files = [os.path.join(base_path, f'{subject_id}_crm_{i}.txt') for i in range(11)]

# Show available files in directory
print(f"\nFiles found in directory:")
for f in sorted(os.listdir(base_path)):
    print(f"  {f}")

Enter the directory path containing your data files:  /Users/samipkafle/Downloads/reiss_lab_analysis/data/Data/CI148


Processing data for subject: CI148
Data directory: /Users/samipkafle/Downloads/reiss_lab_analysis/data/Data/CI148

Files found in directory:
  .DS_Store
  CI148_cons_BM_n_0.out
  CI148_consch.all
  CI148_crm_0.txt
  CI148_crm_1.txt
  CI148_crm_10.txt
  CI148_crm_2.txt
  CI148_crm_3.txt
  CI148_crm_4.txt
  CI148_crm_5.txt
  CI148_crm_6.txt
  CI148_crm_7.txt
  CI148_crm_8.txt
  CI148_crm_9.txt
  CI148_vow9_BM_0.txt
  CI148_vow9_CI_0.txt
  CI148_vowch.all
  test_vowch.all


### 2.1 Vowel Data

In [3]:
vowel_cols = ['talker_id', 'vowel_id', 'response_id', 'score', 'rt']

# Load individual vowel files
df_vowel_bm = pd.read_csv(vowel_bm_path, sep='\\s+', header=None, names=vowel_cols)
df_vowel_ci = pd.read_csv(vowel_ci_path, sep='\\s+', header=None, names=vowel_cols)

# Add a 'condition' column to distinguish them
df_vowel_bm['condition'] = 'BM' # Bimodal
df_vowel_ci['condition'] = 'CI' # Cochlear Implant

# Combine into a single DataFrame
df_vowel = pd.concat([df_vowel_bm, df_vowel_ci], ignore_index=True)

# Define vowel labels from documentation
vowel_map = {
    1: 'AE', 2: 'AH', 3: 'AW', 4: 'EH', 5: 'IH',
    6: 'IY', 7: 'OO', 8: 'UH', 9: 'UW'
}

# Map IDs to human-readable labels
df_vowel['vowel_label'] = df_vowel['vowel_id'].map(vowel_map)
df_vowel['response_label'] = df_vowel['response_id'].map(vowel_map)

print("Vowel data loaded and preprocessed:")
df_vowel.info()
df_vowel.head()

Vowel data loaded and preprocessed:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 360 entries, 0 to 359
Data columns (total 8 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   talker_id       360 non-null    int64  
 1   vowel_id        360 non-null    int64  
 2   response_id     360 non-null    int64  
 3   score           360 non-null    int64  
 4   rt              360 non-null    float64
 5   condition       360 non-null    object 
 6   vowel_label     360 non-null    object 
 7   response_label  360 non-null    object 
dtypes: float64(1), int64(4), object(3)
memory usage: 22.6+ KB


Unnamed: 0,talker_id,vowel_id,response_id,score,rt,condition,vowel_label,response_label
0,16,7,7,1,1.8965,BM,OO,OO
1,8,7,7,1,3.5742,BM,OO,OO
2,9,3,2,0,1.7898,BM,AW,AH
3,4,8,8,1,1.6783,BM,UH,UH
4,4,7,7,1,2.1938,BM,OO,OO


### 2.2 Consonant Data

In [4]:
consonant_cols = ['talker_id', 'consonant_id', 'response_id', 'score', 'rt']

# Load the consonant file
df_consonant = pd.read_csv(consonant_path, sep='\\s+', header=None, names=consonant_cols)

# Define consonant labels from documentation
consonant_map = {
    1: '#', 2: '_', 3: 'b', 4: 'd', 5: 'f', 6: 'g', 7: 'k',
    8: 'm', 9: 'n', 10: '%', 11: 'p', 12: 's', 13: 't',
    14: 'v', 15: 'z', 16: '$'
}

# Map IDs to human-readable labels
df_consonant['consonant_label'] = df_consonant['consonant_id'].map(consonant_map)
df_consonant['response_label'] = df_consonant['response_id'].map(consonant_map)

print("Consonant data loaded and preprocessed:")
df_consonant.info()
df_consonant.head()

Consonant data loaded and preprocessed:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 64 entries, 0 to 63
Data columns (total 7 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   talker_id        64 non-null     int64  
 1   consonant_id     64 non-null     int64  
 2   response_id      64 non-null     int64  
 3   score            64 non-null     int64  
 4   rt               64 non-null     float64
 5   consonant_label  64 non-null     object 
 6   response_label   64 non-null     object 
dtypes: float64(1), int64(4), object(2)
memory usage: 3.6+ KB


Unnamed: 0,talker_id,consonant_id,response_id,score,rt,consonant_label,response_label
0,4,10,1,0,20.8202,%,#
1,3,2,2,1,3.4925,_,_
2,4,7,7,1,1.663,k,k
3,3,9,9,1,1.83,n,n
4,2,8,8,1,2.181,m,m


### 2.3 CRM Data

In [5]:
import re
import ast

def parse_crm_header(filepath):
    """Parse CRM file header to extract talker and masker info."""
    with open(filepath, 'r') as f:
        header = f.readline()
    match = re.search(r'Talker (\\d+), Maskers (\\d+) and (\\d+)', header)
    if match:
        talker = int(match.group(1))
        masker1 = int(match.group(2))
        masker2 = int(match.group(3))
        return talker, masker1, masker2
    return None, None, None

def get_gender(talker_id):
    """Talkers 0-3 are male, 4-7 are female."""
    return 'M' if talker_id <= 3 else 'F'

def get_masker_type(talker, masker1, masker2):
    """Determine if maskers are same or different gender from target."""
    target_gender = get_gender(talker)
    masker1_gender = get_gender(masker1)
    masker2_gender = get_gender(masker2)
    if target_gender == masker1_gender == masker2_gender:
        return 'same'
    elif masker1_gender == masker2_gender and masker1_gender != target_gender:
        return 'different'
    else:
        return 'mixed'

def calculate_srt_from_trials(df_run):
    """
    Calculate SRT using adaptive staircase reversal method.
    Assumptions (matching MATLAB code):
    - A trial is correct if BOTH color and number are correct.
    - SRT = mean of reversals 5-14.
    """
    snr_values = df_run['snr'].values
    correct = (df_run['target_color'] == df_run['response_color']) & (df_run['target_number'] == df_run['response_number'])
    
    reversals = []
    if len(correct) > 1:
        prev_correct = correct.iloc[0]
        for i in range(1, len(correct)):
            if correct.iloc[i] != prev_correct:
                reversals.append(snr_values[i])
            prev_correct = correct.iloc[i]

    if len(reversals) >= 14:
        srt_reversals = reversals[4:14]
        srt = np.mean(srt_reversals)
        sd = np.std(srt_reversals, ddof=0)
    elif len(reversals) >= 5:
        srt_reversals = reversals[4:]
        srt = np.mean(srt_reversals)
        sd = np.std(srt_reversals, ddof=0)
    else:
        srt, sd = np.nan, np.nan
    
    return srt, sd, len(reversals)

# --- CRM Data Loading ---
crm_files_found = sorted([f for f in os.listdir(base_path) if '_crm_' in f and f.endswith('.txt')])
print(f"Found {len(crm_files_found)} CRM files:\n")
for f in crm_files_found:
    print(f"  {f}")

print("\nPlease assign a condition (e.g., BM, CI, HA) to the CRM files.")
print("Enter assignments as a dictionary: {'CONDITION': 'file_numbers'}")
print("Example: {'BM': '1,2,3', 'CI': '4,5,6'}")
assignments_str = input("Enter assignments: ").strip()

try:
    assignments = ast.literal_eval(assignments_str)
except (ValueError, SyntaxError):
    print("Invalid dictionary format. Please restart and try again.")
    assignments = {}

file_to_condition = {}
for condition, numbers_str in assignments.items():
    try:
        numbers = [int(n.strip()) for n in numbers_str.split(',')]
        for num in numbers:
            file_to_condition[num] = condition.strip().upper()
    except ValueError:
        print(f"Warning: Could not parse numbers for condition '{condition}'. Please check formatting.")

crm_cols = ['run', 'target_color', 'response_color', 'target_number', 'response_number', 'snr', 'rt']
crm_data_frames = []
crm_summary = []

for filename in crm_files_found:
    filepath = os.path.join(base_path, filename)
    talker, masker1, masker2 = parse_crm_header(filepath)
    masker_type = get_masker_type(talker, masker1, masker2)

    file_num_match = re.search(r'_crm_(\\d+)\\.txt', filename)
    condition = 'UNKNOWN'
    if file_num_match:
        file_num = int(file_num_match.group(1))
        condition = file_to_condition.get(file_num, 'UNKNOWN')

    df_temp = pd.read_csv(filepath, sep='\\s+', header=None, skiprows=2, names=crm_cols, on_bad_lines='skip')
    df_temp = df_temp[pd.to_numeric(df_temp['run'], errors='coerce').notna()].copy()
    df_temp = df_temp.astype(dict.fromkeys(df_temp.columns[:5], int) | dict.fromkeys(df_temp.columns[5:], float))

    srt, sd, n_reversals = calculate_srt_from_trials(df_temp)

    df_temp['filename'] = filename
    df_temp['condition'] = condition
    df_temp['talker'] = talker
    df_temp['masker_type'] = masker_type
    crm_data_frames.append(df_temp)

    crm_summary.append({
        'filename': filename, 'condition': condition, 'talker': talker,
        'talker_gender': get_gender(talker), 'masker_type': masker_type, 'n_trials': len(df_temp),
        'n_reversals': n_reversals, 'srt': srt, 'sd': sd
    })

if crm_data_frames:
    df_crm = pd.concat(crm_data_frames, ignore_index=True)
    df_crm_summary = pd.DataFrame(crm_summary)
    print(f"\n{'='*60}")
    print(f"Total CRM files loaded: {len(crm_data_frames)}")
    print(f"Total CRM trials: {len(df_crm)}")
else:
    print("No CRM files were loaded!")

Found 11 CRM files:

  CI148_crm_0.txt
  CI148_crm_1.txt
  CI148_crm_10.txt
  CI148_crm_2.txt
  CI148_crm_3.txt
  CI148_crm_4.txt
  CI148_crm_5.txt
  CI148_crm_6.txt
  CI148_crm_7.txt
  CI148_crm_8.txt
  CI148_crm_9.txt

Please assign a condition (e.g., BM, CI, HA) to the CRM files.
Enter assignments as a dictionary: {'CONDITION': 'file_numbers'}
Example: {'BM': '1,2,3', 'CI': '4,5,6'}


Enter assignments:  {'BM': '1,2,3,8', 'CI': '4, 6, 10', 'HA': '5,7,9'}


TypeError: '<=' not supported between instances of 'NoneType' and 'int'

## 3. Analysis and Visualization

### 3.1 Vowel Analysis

In [None]:
# Overall Accuracy
vowel_accuracy = df_vowel['score'].mean() * 100
print(f"Overall Vowel Accuracy: {vowel_accuracy:.2f}%\n")

# Confusion Matrix
vowel_labels = [v for k, v in sorted(vowel_map.items())]
vowel_cm_counts = pd.crosstab(df_vowel['vowel_label'], df_vowel['response_label'], rownames=['Target'], colnames=['Response'], margins=False, dropna=False).reindex(index=vowel_labels, columns=vowel_labels, fill_value=0)
vowel_cm_probs = vowel_cm_counts.div(vowel_cm_counts.sum(axis=1), axis=0).fillna(0)
print("Vowel Confusion Matrix (Probabilities):")
print(vowel_cm_probs)

# Per-Vowel Accuracy
per_vowel_accuracy = pd.DataFrame(np.diag(vowel_cm_probs) * 100, index=vowel_cm_probs.index, columns=['Accuracy'])
print("\nPer-Vowel Accuracy:")
print(per_vowel_accuracy)

# By-Talker Accuracy
vowel_talker_accuracy = df_vowel.groupby('talker_id')['score'].mean().reset_index()
print("\nBy-Talker Vowel Accuracy:")
print(vowel_talker_accuracy)

# Plotting
plt.figure(figsize=(10, 8))
sns.heatmap(vowel_cm_probs, annot=True, fmt='.2f', cmap='viridis')
plt.title('Vowel Confusion Matrix')
plt.savefig('vowel_confusion_matrix.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x=per_vowel_accuracy.index, y='Accuracy', data=per_vowel_accuracy)
plt.title('Per-Vowel Accuracy')
plt.savefig('per_vowel_accuracy.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x='talker_id', y='score', data=vowel_talker_accuracy)
plt.title('By-Talker Vowel Accuracy')
plt.savefig('vowel_talker_accuracy.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.histplot(df_vowel['rt'], bins=20, kde=True)
plt.title('Vowel Response Time Distribution')
plt.savefig('vowel_rt_histogram.png')
plt.show()

# Save data
vowel_cm_probs.to_csv('vowel_confusion_matrix.csv')
per_vowel_accuracy.to_csv('per_vowel_accuracy.csv')
vowel_talker_accuracy.to_csv('vowel_talker_accuracy.csv')

### 3.2 Consonant Analysis

In [None]:
# Overall Accuracy
consonant_accuracy = df_consonant['score'].mean() * 100
print(f"Overall Consonant Accuracy: {consonant_accuracy:.2f}%\n")

# Confusion Matrix
consonant_labels = [v for k, v in sorted(consonant_map.items())]
consonant_cm_counts = pd.crosstab(df_consonant['consonant_label'], df_consonant['response_label'], rownames=['Target'], colnames=['Response'], margins=False, dropna=False).reindex(index=consonant_labels, columns=consonant_labels, fill_value=0)
consonant_cm_probs = consonant_cm_counts.div(consonant_cm_counts.sum(axis=1), axis=0).fillna(0)
print("Consonant Confusion Matrix (Probabilities):")
print(consonant_cm_probs)

# Per-Consonant Accuracy
per_consonant_accuracy = pd.DataFrame(np.diag(consonant_cm_probs) * 100, index=consonant_cm_probs.index, columns=['Accuracy'])
print("\nPer-Consonant Accuracy:")
print(per_consonant_accuracy)

# By-Talker Accuracy
consonant_talker_accuracy = df_consonant.groupby('talker_id')['score'].mean().reset_index()
print("\nBy-Talker Consonant Accuracy:")
print(consonant_talker_accuracy)

# Plotting
plt.figure(figsize=(12, 10))
sns.heatmap(consonant_cm_probs, annot=True, fmt='.2f', cmap='viridis')
plt.title('Consonant Confusion Matrix')
plt.savefig('consonant_confusion_matrix.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x=per_consonant_accuracy.index, y='Accuracy', data=per_consonant_accuracy)
plt.title('Per-Consonant Accuracy')
plt.savefig('per_consonant_accuracy.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.barplot(x='talker_id', y='score', data=consonant_talker_accuracy)
plt.title('By-Talker Consonant Accuracy')
plt.savefig('consonant_talker_accuracy.png')
plt.show()

plt.figure(figsize=(10, 6))
sns.histplot(df_consonant['rt'], bins=20, kde=True)
plt.title('Consonant Response Time Distribution')
plt.savefig('consonant_rt_histogram.png')
plt.show()

# Save data
consonant_cm_probs.to_csv('consonant_confusion_matrix.csv')
per_consonant_accuracy.to_csv('per_consonant_accuracy.csv')
consonant_talker_accuracy.to_csv('consonant_talker_accuracy.csv')

### 3.3 CRM Analysis

In [None]:
# Display summary table
print("CRM Summary Table:")
print("="*80)
display_cols = ['filename', 'condition', 'talker_gender', 'masker_type', 'n_trials', 'srt', 'sd']
print(df_crm_summary[display_cols].to_string(index=False))
print()

# Summary statistics by condition
print("\nSRT by Condition:")
print("-"*40)
condition_stats = df_crm_summary.groupby('condition').agg({
    'srt': ['mean', 'std', 'count'],
    'sd': 'mean'
}).round(2)
condition_stats.columns = ['Mean SRT (dB)', 'SRT SD', 'N runs', 'Mean within-run SD']
print(condition_stats)
print()

# Summary statistics by masker type
print("\nSRT by Masker Type:")
print("-"*40)
masker_stats = df_crm_summary.groupby('masker_type').agg({
    'srt': ['mean', 'std', 'count']
}).round(2)
masker_stats.columns = ['Mean SRT (dB)', 'SRT SD', 'N runs']
print(masker_stats)
print()

# Summary by condition AND masker type (for VGRM-like analysis)
print("\nSRT by Condition × Masker Type:")
print("-"*40)
cross_stats = df_crm_summary.groupby(['condition', 'masker_type']).agg({
    'srt': ['mean', 'std', 'count']
}).round(2)
cross_stats.columns = ['Mean SRT (dB)', 'SRT SD', 'N']
print(cross_stats)
print()

# Calculate Voice-Gender Release from Masking (VGRM) if both conditions exist
print("\nVoice-Gender Release from Masking (VGRM):")
print("-"*40)
print("VGRM = SRT(same-gender) - SRT(different-gender)")
print("Positive values indicate benefit from different-gender maskers\n")

for condition in df_crm_summary['condition'].unique():
    cond_data = df_crm_summary[df_crm_summary['condition'] == condition]
    same_srt = cond_data[cond_data['masker_type'] == 'same']['srt'].mean()
    diff_srt = cond_data[cond_data['masker_type'] == 'different']['srt'].mean()
    
    if not np.isnan(same_srt) and not np.isnan(diff_srt):
        vgrm = same_srt - diff_srt
        print(f"{condition}: VGRM = {vgrm:.2f} dB (same: {same_srt:.2f}, diff: {diff_srt:.2f})")
    else:
        print(f"{condition}: Insufficient data for VGRM calculation")

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: SRT by condition
sns.barplot(data=df_crm_summary, x='condition', y='srt', ax=axes[0], errorbar='sd')
axes[0].set_xlabel('Condition')
axes[0].set_ylabel('SRT (dB)')
axes[0].set_title('SRT by Condition')
axes[0].axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Plot 2: SRT by masker type
sns.barplot(data=df_crm_summary, x='masker_type', y='srt', ax=axes[1], errorbar='sd')
axes[1].set_xlabel('Masker Type')
axes[1].set_ylabel('SRT (dB)')
axes[1].set_title('SRT by Masker Gender')
axes[1].axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Plot 3: SRT by condition and masker type
sns.barplot(data=df_crm_summary, x='condition', y='srt', hue='masker_type', ax=axes[2], errorbar='sd')
axes[2].set_xlabel('Condition')
axes[2].set_ylabel('SRT (dB)')
axes[2].set_title('SRT by Condition × Masker Type')
axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
axes[2].legend(title='Masker Type')

plt.tight_layout()
plt.savefig('crm_srt_analysis.png', dpi=150)
plt.show()

# Save summary data
df_crm_summary.to_csv('crm_summary.csv', index=False)
print("\nSummary saved to crm_summary.csv")

## 4. Exploratory Analysis

Additional visualizations and analyses to explore patterns in the data before collecting more subjects.

In [None]:
# 4.1 Trial-by-trial SNR trajectory for each run
print("Trial-by-trial SNR trajectories (adaptive staircase visualization):")
print("="*60)

n_files = len(df_crm['filename'].unique())
n_cols = min(3, n_files)
n_rows = int(np.ceil(n_files / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
axes = axes.flatten()

for idx, (filename, group) in enumerate(df_crm.groupby('filename')):
    ax = axes[idx]
    trials = range(1, len(group) + 1)
    ax.plot(trials, group['snr'].values, 'b-o', markersize=3, alpha=0.7)
    
    # Mark correct vs incorrect trials
    correct = (group['target_color'].values == group['response_color'].values) & \
              (group['target_number'].values == group['response_number'].values)
    ax.scatter(np.array(trials)[correct], group['snr'].values[correct], 
               c='green', s=20, zorder=5, label='Correct')
    ax.scatter(np.array(trials)[~correct], group['snr'].values[~correct], 
               c='red', s=20, zorder=5, label='Incorrect')
    
    # Get SRT for this file
    file_srt = df_crm_summary[df_crm_summary['filename'] == filename]['srt'].values[0]
    ax.axhline(y=file_srt, color='purple', linestyle='--', alpha=0.7, 
               label=f'SRT={file_srt:.1f}')
    
    condition = group['condition'].iloc[0]
    masker_type = group['masker_type'].iloc[0]
    ax.set_title(f"{filename}\n{condition}, {masker_type}-gender", fontsize=9)
    ax.set_xlabel('Trial')
    ax.set_ylabel('SNR (dB)')
    ax.axhline(y=0, color='gray', linestyle=':', alpha=0.5)
    if idx == 0:
        ax.legend(fontsize=7, loc='upper right')

# Hide unused subplots
for idx in range(n_files, len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout()
plt.savefig('crm_staircase_trajectories.png', dpi=150)
plt.show()

In [None]:
# 4.2 Response time analysis
print("Response Time Analysis:")
print("="*60)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# RT distribution overall
sns.histplot(df_crm['rt'], bins=30, kde=True, ax=axes[0])
axes[0].set_xlabel('Response Time (s)')
axes[0].set_title('Overall RT Distribution')
axes[0].axvline(x=df_crm['rt'].median(), color='r', linestyle='--', 
                label=f'Median: {df_crm["rt"].median():.2f}s')
axes[0].legend()

# RT by condition
sns.boxplot(data=df_crm, x='condition', y='rt', ax=axes[1])
axes[1].set_xlabel('Condition')
axes[1].set_ylabel('Response Time (s)')
axes[1].set_title('RT by Condition')

# RT vs SNR (does RT increase at harder SNRs?)
axes[2].scatter(df_crm['snr'], df_crm['rt'], alpha=0.3, s=10)
# Add trend line
z = np.polyfit(df_crm['snr'], df_crm['rt'], 1)
p = np.poly1d(z)
snr_range = np.linspace(df_crm['snr'].min(), df_crm['snr'].max(), 100)
axes[2].plot(snr_range, p(snr_range), 'r-', alpha=0.8, label=f'Trend')
axes[2].set_xlabel('SNR (dB)')
axes[2].set_ylabel('Response Time (s)')
axes[2].set_title('RT vs SNR')
axes[2].legend()

plt.tight_layout()
plt.savefig('crm_response_time_analysis.png', dpi=150)
plt.show()

# RT statistics
print("\nRT Statistics by Condition:")
rt_stats = df_crm.groupby('condition')['rt'].agg(['mean', 'median', 'std']).round(2)
rt_stats.columns = ['Mean RT (s)', 'Median RT (s)', 'RT SD (s)']
print(rt_stats)

In [None]:
# 4.3 Error pattern analysis
print("Error Pattern Analysis:")
print("="*60)

# Calculate error types
df_crm['color_correct'] = df_crm['target_color'] == df_crm['response_color']
df_crm['number_correct'] = df_crm['target_number'] == df_crm['response_number']
df_crm['both_correct'] = df_crm['color_correct'] & df_crm['number_correct']

# Error breakdown
def categorize_error(row):
    if row['both_correct']:
        return 'Correct'
    elif row['color_correct'] and not row['number_correct']:
        return 'Number Error'
    elif not row['color_correct'] and row['number_correct']:
        return 'Color Error'
    else:
        return 'Both Error'

df_crm['error_type'] = df_crm.apply(categorize_error, axis=1)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Overall error breakdown
error_counts = df_crm['error_type'].value_counts()
colors = {'Correct': 'green', 'Number Error': 'orange', 'Color Error': 'blue', 'Both Error': 'red'}
error_counts.plot(kind='bar', ax=axes[0], color=[colors.get(x, 'gray') for x in error_counts.index])
axes[0].set_xlabel('Error Type')
axes[0].set_ylabel('Count')
axes[0].set_title('Overall Error Breakdown')
axes[0].tick_params(axis='x', rotation=45)

# Error rates by condition
error_by_condition = pd.crosstab(df_crm['condition'], df_crm['error_type'], normalize='index') * 100
error_by_condition.plot(kind='bar', ax=axes[1], color=[colors.get(x, 'gray') for x in error_by_condition.columns])
axes[1].set_xlabel('Condition')
axes[1].set_ylabel('Percentage')
axes[1].set_title('Error Types by Condition (%)')
axes[1].tick_params(axis='x', rotation=0)
axes[1].legend(title='Error Type', bbox_to_anchor=(1.05, 1))

plt.tight_layout()
plt.savefig('crm_error_analysis.png', dpi=150)
plt.show()

print("\nError rates by condition:")
print(error_by_condition.round(1))

In [None]:
# 4.4 Individual run comparison plot
print("Individual Run Comparison:")
print("="*60)

# Sort by SRT for visualization
df_sorted = df_crm_summary.sort_values('srt')

fig, ax = plt.subplots(figsize=(12, 6))

# Create color map for conditions
condition_colors = {'BM': 'blue', 'CI': 'red', 'HA': 'green', 'UNKNOWN': 'gray'}
colors = [condition_colors.get(c, 'gray') for c in df_sorted['condition']]

# Create hatching for masker type
bars = ax.bar(range(len(df_sorted)), df_sorted['srt'], color=colors, 
              edgecolor='black', linewidth=1)

# Add error bars for SD
ax.errorbar(range(len(df_sorted)), df_sorted['srt'], yerr=df_sorted['sd'],
            fmt='none', color='black', capsize=3)

# Add hatching for different-gender maskers
for idx, (bar, masker_type) in enumerate(zip(bars, df_sorted['masker_type'])):
    if masker_type == 'different':
        bar.set_hatch('//')

ax.set_xticks(range(len(df_sorted)))
ax.set_xticklabels([f.replace(subject_id + '_crm_', '').replace('.txt', '') 
                    for f in df_sorted['filename']], rotation=45)
ax.set_xlabel('Run Number')
ax.set_ylabel('SRT (dB)')
ax.set_title(f'SRT by Run for {subject_id}\n(Hatched = different-gender maskers)')
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Legend
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=condition_colors[c], label=c) 
                   for c in df_sorted['condition'].unique()]
legend_elements.append(Patch(facecolor='white', edgecolor='black', hatch='//', label='Diff-gender'))
ax.legend(handles=legend_elements, loc='upper left')

plt.tight_layout()
plt.savefig('crm_individual_runs.png', dpi=150)
plt.show()

In [None]:
# 4.5 Summary statistics and data quality checks
print("Data Quality Summary:")
print("="*60)

print(f"\nSubject: {subject_id}")
print(f"Total CRM runs: {len(df_crm_summary)}")
print(f"Total trials: {len(df_crm)}")
print(f"\nConditions tested: {', '.join(df_crm_summary['condition'].unique())}")
print(f"Masker types: {', '.join(df_crm_summary['masker_type'].unique())}")

print("\nRuns per condition:")
print(df_crm_summary['condition'].value_counts())

print("\nRuns per masker type:")
print(df_crm_summary['masker_type'].value_counts())

print("\nReversal count check (should be 14 for valid runs):")
rev_check = df_crm_summary[['filename', 'n_reversals', 'srt']]
print(rev_check.to_string(index=False))

# Flag any runs with fewer than 14 reversals
low_rev = df_crm_summary[df_crm_summary['n_reversals'] < 14]
if len(low_rev) > 0:
    print("\n⚠️  Warning: The following runs have fewer than 14 reversals:")
    print(low_rev[['filename', 'n_reversals']].to_string(index=False))

print("\n" + "="*60)
print("OVERALL SUMMARY")
print("="*60)
print(f"\nMean SRT across all runs: {df_crm_summary['srt'].mean():.2f} ± {df_crm_summary['srt'].std():.2f} dB")
print(f"SRT range: {df_crm_summary['srt'].min():.2f} to {df_crm_summary['srt'].max():.2f} dB")
print(f"\nMean accuracy: {df_crm['both_correct'].mean()*100:.1f}%")
print(f"Median response time: {df_crm['rt'].median():.2f} s")

# Version 1: Enhanced Phonetic & Distribution Analysis
This section digs deeper into *why* errors happen (Feature Analysis) and improves visualization.

In [None]:
print("--- 1.1 Phonetic Feature Analysis ---")

# Feature Map: Maps consonant characters to (Voicing, Place, Manner)
# 1 = Feature Present/High, 0 = Feature Absent/Low
feature_map = {
    'b': (1, 1, 0), 'd': (1, 0, 0), 'g': (1, 0, 0),
    'p': (0, 1, 0), 't': (0, 0, 0), 'k': (0, 0, 0),
    'm': (1, 1, 1), 'n': (1, 0, 1),
    'f': (0, 1, 2), 'v': (1, 1, 2), 's': (0, 0, 2), 'z': (1, 0, 2),
    '#': (0, 0, 2), '_': (1, 0, 2), # Sh, Zh
    '%': (0, 0, 3), '$': (1, 0, 3)  # Ch, J
}

def calculate_information_transfer(df, label_col, resp_col, feat_map):
    # Filter for valid keys
    valid = df[df[label_col].isin(feat_map.keys()) & df[resp_col].isin(feat_map.keys())]
    if len(valid) == 0: return None
    
    features = ['Voicing', 'Place', 'Manner']
    results = {}
    
    for i, feat_name in enumerate(features):
        t_feat = valid[label_col].apply(lambda x: feat_map[x][i])
        r_feat = valid[resp_col].apply(lambda x: feat_map[x][i])
        # Calculate simple percent correct for that feature
        acc = (t_feat == r_feat).mean() * 100
        results[feat_name] = acc
        
    return pd.Series(results)

if 'df_consonant' in locals():
    feat_res = calculate_information_transfer(df_consonant, 'consonant_label', 'response_label', feature_map)
    if feat_res is not None:
        print("Feature Transmission Rates (% Correct):")
        print(feat_res.round(2))
        
        # Plot
        plt.figure(figsize=(6, 4))
        feat_res.plot(kind='bar', color=['#1f77b4', '#ff7f0e', '#2ca02c'])
        plt.title('Phonetic Feature Transmission')
        plt.ylabel('% Correct')
        plt.ylim(0, 100)
        plt.show()

In [None]:
print("--- 1.2 Advanced Heatmaps (Cluster Organized) ---")

def plot_clustered_matrix(df, target_col, resp_col, map_dict, title):
    labels = [v for k, v in sorted(map_dict.items())]
    
    # Generate Counts
    cm = pd.crosstab(df[target_col], df[resp_col])
    # Reindex to ensure all labels exist
    cm = cm.reindex(index=labels, columns=labels, fill_value=0)
    # Normalize row-wise (Probability)
    cm_prob = cm.div(cm.sum(axis=1), axis=0).fillna(0)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_prob, annot=True, fmt='.2f', cmap='rocket_r', 
                mask=(cm_prob==0), linewidths=0.5, linecolor='lightgray')
    plt.title(title, fontsize=15)
    plt.ylabel('Target')
    plt.xlabel('Response')
    plt.show()

if 'df_vowel' in locals():
    plot_clustered_matrix(df_vowel, 'vowel_label', 'response_label', vowel_map, "Vowel Confusion Matrix")

if 'df_consonant' in locals():
    plot_clustered_matrix(df_consonant, 'consonant_label', 'response_label', cons_map, "Consonant Confusion Matrix")

In [None]:
print("--- 1.3 CRM Distributional Analysis (Violin Plot) ---")

if 'df_crm_summary' in locals() and not df_crm_summary.empty:
    plt.figure(figsize=(10, 6))
    
    # Violin plot shows density, Strip plot shows individual run data points
    sns.violinplot(x='condition', y='srt', data=df_crm_summary, inner=None, color='lightgray', linewidth=0)
    sns.stripplot(x='condition', y='srt', data=df_crm_summary, hue='masker_type', 
                  size=10, jitter=True, palette='bright', dodge=True, edgecolor='black', linewidth=1)
    
    plt.axhline(0, color='black', linestyle='--', alpha=0.3, label='0 dB SNR')
    plt.title('Speech Reception Thresholds by Condition and Masker', fontsize=14)
    plt.ylabel('SRT (dB SNR) - Lower is Better')
    plt.legend(title='Masker Gender', bbox_to_anchor=(1.05, 1))
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("CRM Summary data not available for plotting.")

# Version 2: Exploratory & Statistical Analysis
This section looks for temporal trends (fatigue/learning), performs statistical significance testing, and provides interactive exploration tools.

In [None]:
print("--- 2.1 Temporal Analysis (Fatigue / Learning Effects) ---")

def plot_temporal_trend(df, metric_col, title):
    # Reset index to get global trial count
    df_seq = df.reset_index(drop=True).reset_index().rename(columns={'index': 'trial'})
    
    # Calculate rolling average
    df_seq['rolling'] = df_seq[metric_col].rolling(window=20).mean()
    
    plt.figure(figsize=(12, 4))
    plt.scatter(df_seq['trial'], df_seq[metric_col], alpha=0.2, color='gray', s=10, label='Raw Trial')
    plt.plot(df_seq['trial'], df_seq['rolling'], color='red', linewidth=2, label='20-Trial Rolling Avg')
    
    # Simple Linear Regression for Trend
    z = np.polyfit(df_seq['trial'], df_seq[metric_col].fillna(0), 1)
    p = np.poly1d(z)
    plt.plot(df_seq['trial'], p(df_seq['trial']), "b--", alpha=0.8, label=f'Trend (Slope={z[0]:.4f})')
    
    plt.title(f'Temporal Trend: {title}')
    plt.xlabel('Trial Sequence')
    plt.ylabel(metric_col)
    plt.legend()
    plt.show()
    
    if z[0] < -0.001: print("-> Potential Fatigue/Decline detected (Negative Slope)")
    if z[0] > 0.001: print("-> Potential Learning/Improvement detected (Positive Slope)")

if 'df_vowel' in locals():
    plot_temporal_trend(df_vowel, 'score', "Vowel Identification Accuracy")

if 'df_crm' in locals():
    # Filter out ridiculously high SNRs (initial trials)
    crm_clean = df_crm[df_crm['snr'] < 20]
    plot_temporal_trend(crm_clean, 'snr', "CRM SNR Tracking (All Runs)")

In [None]:
print("--- 2.2 Statistical Testing (ANOVA) ---")

if 'df_crm_summary' in locals():
    # Filter valid data
    stat_df = df_crm_summary[df_crm_summary['condition'] != 'Unknown'].dropna()
    
    if len(stat_df['condition'].unique()) > 1:
        print("\nRunning ANOVA on CRM SRTs (Condition + MaskerType)...")
        try:
            model = ols('srt ~ C(condition) + C(masker_type)', data=stat_df).fit()
            anova_table = sm.stats.anova_lm(model, typ=2)
            print(anova_table)
            
            if anova_table['PR(>F)'][0] < 0.05:
                print("\n-> Significant Condition Effect detected. Suggest Post-hoc t-tests.")
            else:
                print("\n-> No significant differences detected between conditions (p > 0.05).")
        except Exception as e:
            print(f"Stats error: {e} (Likely insufficient data points)")
    else:
        print("Not enough conditions for ANOVA.")
        
print("\n--- 2.3 Advanced CRM Error Analysis ---")
# Breakdown errors by type: Color wrong? Number wrong? Both?

if 'df_crm' in locals():
    def classify_error(row):
        c_ok = row['target_color'] == row['response_color']
        n_ok = row['target_number'] == row['response_number']
        if c_ok and n_ok: return 'Correct'
        if c_ok and not n_ok: return 'Number Err'
        if not c_ok and n_ok: return 'Color Err'
        return 'Double Err'

    df_crm['err_type'] = df_crm.apply(classify_error, axis=1)
    
    err_counts = df_crm.groupby(['condition', 'err_type']).size().unstack(fill_value=0)
    # Normalize to percentage
    err_pct = err_counts.div(err_counts.sum(axis=1), axis=0) * 100
    
    err_pct.plot(kind='bar', stacked=True, colormap='viridis', figsize=(10, 5))
    plt.title('CRM Error Type Breakdown')
    plt.ylabel('% of Trials')
    plt.legend(bbox_to_anchor=(1.05, 1))
    plt.show()

In [None]:
from ipywidgets import interact, fixed

print("--- 2.4 Interactive Data Explorer ---")
print("(If visualizations do not appear, ensure all previous cells ran successfully)")

# -------------------------------------------------------
# 1. VOWEL EXPLORER
# -------------------------------------------------------
def plot_vowel_metrics(talker_id, condition):
    # scope check
    if 'df_vowel' not in globals():
        print("Error: df_vowel not loaded. Run Section 2 first.")
        return

    data = df_vowel.copy()
    
    # Filter
    if talker_id != 'All':
        data = data[data['talker_id'] == talker_id]
    if condition != 'All':
        data = data[data['condition'] == condition]
        
    if len(data) == 0:
        print("No data for this selection.")
        return

    # Calculate Metrics
    acc = data['score'].mean() * 100
    mean_rt = data['rt'].mean()
    
    print(f"--- Vowel Statistics (n={len(data)}) ---")
    print(f"Accuracy:      {acc:.2f}%")
    print(f"Mean Reaction: {mean_rt:.2f}s")
    
    # Simple Bar Plot for Accuracy per Vowel
    plt.figure(figsize=(10, 4))
    vowel_acc = data.groupby('vowel_label')['score'].mean() * 100
    sns.barplot(x=vowel_acc.index, y=vowel_acc.values, palette='viridis')
    plt.title('Accuracy by Phoneme')
    plt.ylim(0, 100)
    plt.ylabel('% Correct')
    plt.show()

if 'df_vowel' in locals():
    print("\n>> VOWEL EXPLORER")
    talkers = ['All'] + sorted(list(df_vowel['talker_id'].unique()))
    conds = ['All'] + sorted(list(df_vowel['condition'].unique()))
    
    interact(plot_vowel_metrics, talker_id=talkers, condition=conds)
else:
    print("Skipping Vowel Explorer (Data not loaded)")

# -------------------------------------------------------
# 2. CRM EXPLORER
# -------------------------------------------------------
def plot_crm_track(filename):
    # scope check
    if 'df_crm' not in globals():
        print("Error: df_crm not loaded.")
        return
        
    data = df_crm[df_crm['filename'] == filename].copy()
    if len(data) == 0: return
    
    # Calculate Correctness for color coding
    data['correct'] = (data['target_color'] == data['response_color']) & \
                      (data['target_number'] == data['response_number'])
    
    plt.figure(figsize=(10, 5))
    plt.plot(range(len(data)), data['snr'], 'b-', alpha=0.5)
    
    # Plot Correct as Green, Incorrect as Red
    correct_trials = data[data['correct']]
    incorrect_trials = data[~data['correct']]
    
    plt.scatter(correct_trials.index - data.index[0], correct_trials['snr'], c='green', label='Correct')
    plt.scatter(incorrect_trials.index - data.index[0], incorrect_trials['snr'], c='red', label='Incorrect')
    
    plt.title(f"Adaptive Track: {filename}")
    plt.xlabel("Trial")
    plt.ylabel("SNR (dB)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

if 'df_crm' in locals():
    print("\n>> CRM TRACK EXPLORER")
    files = sorted(list(df_crm['filename'].unique()))
    interact(plot_crm_track, filename=files)
else:
    print("Skipping CRM Explorer (Data not loaded)")