In [None]:
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from sklearn.metrics import confusion_matrix, cohen_kappa_score

# Resolve the project root as the parent of the current working directory.
# NOTE(review): assumes the notebook runs from a direct subdirectory of the
# repository — confirm.
project_root = Path(os.getcwd()).parent
print(f"Project root set to: {project_root}")
# Make project-local packages (e.g. `configs`) importable. Guard against
# prepending the same entry again every time this cell is re-executed.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

try:
    from configs.config import DATA_DIR, DERIVATIVES_DIR, FIGS_DIR
except ImportError:
    # Fallback paths are relative to the current working directory, not to
    # project_root. NOTE(review): confirm these defaults are intended.
    print("Warning: Could not import configs.config, using defaults")
    DATA_DIR = Path("./data")
    DERIVATIVES_DIR = Path("./derivatives")
    FIGS_DIR = Path("./figs")

# Known rater identifiers; each has its own sub-folder under raters_dir.
RATERS = {'JC', 'AS', 'GL', 'KG'}

# Folder containing one sub-directory of .xlsx rating files per rater.
raters_dir = DATA_DIR / "rec/svf_ratings/"

def load_ratings(file_path):
    """Read one rating spreadsheet into a DataFrame.

    The first sheet of the .xlsx file is read with the openpyxl engine. When a
    'category' column exists, blank category cells are forward-filled so each
    row carries the most recent category label above it.

    Args:
        file_path: Path to the .xlsx file.

    Returns:
        pandas DataFrame with the sheet contents.
    """
    ratings = pd.read_excel(file_path, engine='openpyxl')
    if 'category' not in ratings.columns:
        return ratings
    # The category label is only written on the first row of its block;
    # propagate it downward to the rows that follow.
    ratings['category'] = ratings['category'].ffill()
    return ratings

def load_ratings_by_rater(rater: str) -> dict[str, pd.DataFrame]:
    """Load every .xlsx rating file for one rater.

    Files are read from ``raters_dir/<rater>``. The session name is the part
    of the file name before ``'_task-svf'`` (the whole file name if that
    marker is absent).

    Args:
        rater: Rater identifier; must be a member of ``RATERS``.

    Returns:
        Mapping of session name -> ratings DataFrame (see ``load_ratings``).
        If two files map to the same session name, the later one (in sorted
        file-name order) wins.

    Raises:
        ValueError: If ``rater`` is not in ``RATERS``.
        FileNotFoundError: If the rater's directory does not exist
            (``load_all_ratings`` catches this to skip missing raters).
    """
    if rater not in RATERS:
        raise ValueError(f"Rater '{rater}' is not recognized. Valid raters are: {RATERS}")
    rater_dir = os.path.join(raters_dir, rater)
    # os.listdir raises FileNotFoundError for a missing directory, which callers
    # rely on. Sort for a deterministic load order, and skip Excel lock files
    # ("~$name.xlsx") that exist while a workbook is open: they end in .xlsx
    # but are not readable spreadsheets.
    xlsx_files = sorted(
        f for f in os.listdir(rater_dir)
        if f.endswith('.xlsx') and not f.startswith('~$')
    )
    ratings_session = {}
    for file in xlsx_files:
        session_name = file.split('_task-svf')[0]
        ratings_session[session_name] = load_ratings(os.path.join(rater_dir, file))
    return ratings_session

def load_all_ratings() -> dict[str, dict[str, pd.DataFrame]]:
    """Load ratings for every rater in ``RATERS``.

    Raters are visited in sorted order. ``RATERS`` is a set, and the iteration
    order of a set of strings can change between interpreter runs (hash
    randomization). The insertion order of the returned dict determines the
    order of the per-rater columns produced by ``align_ratings_by_word`` and
    hence the rater-pair order in every report, so sorting keeps results
    reproducible.

    Returns:
        Mapping of rater -> (session name -> ratings DataFrame). Raters whose
        directory does not exist are omitted after printing a warning.
    """
    all_ratings = {}
    for rater in sorted(RATERS):
        try:
            all_ratings[rater] = load_ratings_by_rater(rater)
        except FileNotFoundError:
            print(f"Warning: Directory for rater '{rater}' not found.")
    return all_ratings

def get_all_sessions(all_ratings: dict) -> set:
    """Return every session name rated by at least one rater.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings).

    Returns:
        The union of all raters' session names (empty if there are no raters).
    """
    return set().union(*(per_rater.keys() for per_rater in all_ratings.values()))

def get_shared_sessions(all_ratings: dict, raters: list = None) -> set:
    """Return the sessions rated by every one of the given raters.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings).
        raters: Raters to intersect over. Defaults to every rater in
            ``all_ratings``. Names absent from ``all_ratings`` are ignored.

    Returns:
        Intersection of the selected raters' session names, or an empty set
        when none of the requested raters are present.
    """
    selected = list(all_ratings.keys()) if raters is None else raters
    shared = None
    for name in selected:
        if name not in all_ratings:
            continue
        names_for_rater = set(all_ratings[name].keys())
        shared = names_for_rater if shared is None else shared & names_for_rater
    return set() if shared is None else shared

def filter_for_analysis(df: pd.DataFrame) -> pd.DataFrame:
    """
    Drop rows that should not be scored.

    Removed rows:
    - rows whose 'word' is the 'next' marker;
    - the first row of every run of identical 'category' values, i.e. any row
      whose category differs from the row directly above it (the first row
      of the frame always counts as such).
    The input frame is not modified; a boolean-filtered frame is returned.
    """
    category = df['category']
    # The category-start test runs on the unfiltered rows, so 'next' rows
    # still take part in deciding where a category begins.
    starts_category = category.ne(category.shift(1))
    keep = df['word'].ne('next') & ~starts_category
    return df[keep]

def align_ratings_by_word(all_ratings: dict, session: str) -> pd.DataFrame:
    """
    Place every rater's switch_flag for one session side by side.

    Each rater's sheet is passed through ``filter_for_analysis``. It is then
    reduced to the key columns plus that rater's 'switch_flag', which is renamed
    to ``switch_flag_<rater>``. The sheets are outer-joined on
    ``(category, word, start)``. An item missing for a rater therefore gets NaN
    in that rater's column.

    NOTE(review): if one rater's sheet contains duplicate
    (category, word, start) keys, the merge multiplies the matching rows.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings DataFrame).
        session: Session name to align.

    Returns:
        DataFrame with 'category', 'word', 'start' and one
        ``switch_flag_<rater>`` column per rater who rated ``session``.

    Raises:
        ValueError: If fewer than two raters rated ``session``.
    """
    key_cols = ['category', 'word', 'start']
    per_rater = []
    for rater_id, rater_sessions in all_ratings.items():
        if session not in rater_sessions:
            continue
        flag_col = f'switch_flag_{rater_id}'
        trimmed = (
            filter_for_analysis(rater_sessions[session].copy())
            .rename(columns={'switch_flag': flag_col})
        )
        per_rater.append(trimmed[key_cols + [flag_col]])

    if len(per_rater) < 2:
        raise ValueError(f"Need at least 2 raters for session '{session}', found {len(per_rater)}")

    merged, *others = per_rater
    for other in others:
        merged = merged.merge(other, on=key_cols, how='outer')
    return merged

def align_ratings_all_sessions(all_ratings: dict) -> pd.DataFrame:
    """
    Align ratings from all raters across all sessions.

    Sessions rated by fewer than two raters are skipped. Each remaining session
    is aligned with ``align_ratings_by_word`` and tagged with a 'session'
    column. The per-session frames are then concatenated. Sessions are
    processed in sorted order, so the row order is reproducible. This also
    fixes the order of the per-rater columns, because concat unions columns
    in order of first appearance.

    Returns:
        Single DataFrame with 'category', 'word', 'start', one
        ``switch_flag_<rater>`` column per rater, and 'session'. The index is
        a fresh RangeIndex.

    Raises:
        ValueError: If no session has at least two raters.
    """
    all_aligned = []

    # get_all_sessions returns a set whose iteration order is not stable.
    for session in sorted(get_all_sessions(all_ratings)):
        raters_with_session = [r for r in all_ratings if session in all_ratings[r]]
        if len(raters_with_session) < 2:
            continue

        try:
            aligned = align_ratings_by_word(all_ratings, session)
        except ValueError as e:
            print(f"Skipping session '{session}': {e}")
            continue

        aligned['session'] = session
        all_aligned.append(aligned)

    if not all_aligned:
        raise ValueError("No sessions with at least 2 raters found.")

    return pd.concat(all_aligned, ignore_index=True)

def compute_pairwise_stats_all_sessions(all_ratings: dict) -> pd.DataFrame:
    """
    Compute inter-rater reliability statistics for all rater pairs across ALL sessions.

    Items from every session are pooled via ``align_ratings_all_sessions``.
    Each pair is scored only on items that both raters rated. Raters are
    sorted, so the pair order (and which rater is rater_1) is reproducible.
    This matches the order used for the confusion-matrix plot.

    Returns:
        DataFrame with one row per rater pair and columns 'rater_1',
        'rater_2', 'n_items', 'percent_agreement' (0-100, 2 d.p.) and
        'cohen_kappa' (3 d.p.). Pairs with no overlapping items are omitted.
        The columns exist even when no pair qualifies.
    """
    columns = ['rater_1', 'rater_2', 'n_items', 'percent_agreement', 'cohen_kappa']
    aligned = align_ratings_all_sessions(all_ratings)

    raters = sorted(
        col.replace('switch_flag_', '')
        for col in aligned.columns if col.startswith('switch_flag_')
    )

    results = []
    for r1, r2 in combinations(raters, 2):
        col1 = f'switch_flag_{r1}'
        col2 = f'switch_flag_{r2}'

        # Only items rated by both raters can be compared.
        valid = aligned[[col1, col2]].dropna()
        if len(valid) == 0:
            continue

        y1 = valid[col1].astype(int)
        y2 = valid[col2].astype(int)

        kappa = cohen_kappa_score(y1, y2)
        agreement = (y1 == y2).mean() * 100

        results.append({
            'rater_1': r1,
            'rater_2': r2,
            'n_items': len(valid),
            'percent_agreement': round(agreement, 2),
            'cohen_kappa': round(kappa, 3)
        })

    return pd.DataFrame(results, columns=columns)

def compute_confusion_matrix_all_sessions(all_ratings: dict, rater1: str, rater2: str,
                                          labels: list = None) -> tuple[np.ndarray, list]:
    """
    Compute aggregated confusion matrix for two raters across all sessions.

    Rows correspond to ``rater1``'s labels and columns to ``rater2``'s labels
    (first argument of sklearn's ``confusion_matrix`` = rows). Each cell is
    the percentage of ALL jointly rated items, not row-normalized. It is
    rounded to an int, so the cells sum to roughly 100.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings DataFrame).
        rater1: Rater shown on the rows.
        rater2: Rater shown on the columns.
        labels: Optional explicit label order. Defaults to the sorted set of
            labels observed for either rater. Passing a fixed list
            (e.g. [0, 1, 2]) gives same-shaped matrices for every pair.

    Returns:
        (cm, labels): integer percentage matrix and the label order used.

    Raises:
        KeyError: If either rater has no column in the aligned data.
        ValueError: If the two raters share no rated items.
    """
    aligned = align_ratings_all_sessions(all_ratings)

    col1 = f'switch_flag_{rater1}'
    col2 = f'switch_flag_{rater2}'

    valid = aligned[[col1, col2]].dropna()
    if valid.empty:
        raise ValueError(f"Raters '{rater1}' and '{rater2}' have no rated items in common.")
    y1 = valid[col1].astype(int)
    y2 = valid[col2].astype(int)

    if labels is None:
        labels = sorted(set(y1) | set(y2))
    counts = confusion_matrix(y1, y2, labels=labels)

    # Convert counts to percentages of all jointly rated items. The total can
    # only be 0 when caller-supplied labels exclude every observed value.
    total = counts.sum()
    cm = counts / total * 100 if total else np.zeros(counts.shape)
    cm = np.round(cm).astype(int)  # Round to integers for display

    return cm, labels

def plot_confusion_matrices_all_sessions(all_ratings: dict):
    """
    Plot confusion matrices for all rater pairs aggregated across all sessions.

    One heatmap is drawn per rater pair (raters sorted), at most three per
    row. Cells show the integer percentages from
    ``compute_confusion_matrix_all_sessions``. Tick labels are derived from the
    labels actually present in each matrix, so a pair that never used some
    class still gets correctly named axes.

    Returns:
        The matplotlib Figure.
    """
    class_names = {0: 'unsure', 1: 'cluster', 2: 'switch'}
    aligned = align_ratings_all_sessions(all_ratings)

    rater_cols = [col for col in aligned.columns if col.startswith('switch_flag_')]
    raters = sorted([col.replace('switch_flag_', '') for col in rater_cols])
    pairs = list(combinations(raters, 2))

    n_pairs = len(pairs)
    ncols = min(3, n_pairs)
    nrows = (n_pairs + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows))
    axes = np.atleast_2d(axes).flatten()

    for ax, (r1, r2) in zip(axes, pairs):
        cm, labels = compute_confusion_matrix_all_sessions(all_ratings, r1, r2)
        # Name exactly the labels present in this matrix. A fixed 3-name list
        # would not match a 2x2 matrix, and matplotlib rejects a tick-label
        # count that differs from the tick count.
        tick_names = [class_names.get(label, str(label)) for label in labels]

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=tick_names, yticklabels=tick_names, ax=ax,
                    cbar_kws={'label': 'Percentage (%)'})
        ax.set_xlabel(f"{r2}'s rating")
        ax.set_ylabel(f"{r1}'s rating")

    # Hide grid cells that have no rater pair.
    for ax in axes[n_pairs:]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle('Inter-rater Confusion Matrices (All Sessions)', y=1.02)
    return fig

def compute_stats_by_session(all_ratings: dict) -> pd.DataFrame:
    """
    Compute inter-rater stats for each session separately.
    Useful for identifying problematic sessions.

    Sessions with fewer than two raters are skipped. Within a session, each
    rater pair is scored on the items both raters rated. Pairs with no overlap
    are skipped. Sessions and raters are sorted, so the row order is
    reproducible.

    Returns:
        DataFrame with columns 'session', 'rater_1', 'rater_2', 'n_items',
        'percent_agreement' (0-100, 2 d.p.) and 'cohen_kappa' (3 d.p.). The
        columns exist even when there are no rows, so callers can safely group
        by 'session'.
    """
    columns = ['session', 'rater_1', 'rater_2', 'n_items', 'percent_agreement', 'cohen_kappa']
    all_results = []

    # get_all_sessions returns a set whose iteration order is not stable.
    for session in sorted(get_all_sessions(all_ratings)):
        raters_with_session = [r for r in all_ratings if session in all_ratings[r]]
        if len(raters_with_session) < 2:
            continue

        try:
            aligned = align_ratings_by_word(all_ratings, session)
        except ValueError:
            continue

        raters = sorted(
            col.replace('switch_flag_', '')
            for col in aligned.columns if col.startswith('switch_flag_')
        )

        for r1, r2 in combinations(raters, 2):
            col1 = f'switch_flag_{r1}'
            col2 = f'switch_flag_{r2}'

            # Only items rated by both raters can be compared.
            valid = aligned[[col1, col2]].dropna()
            if len(valid) == 0:
                continue

            y1 = valid[col1].astype(int)
            y2 = valid[col2].astype(int)

            kappa = cohen_kappa_score(y1, y2)
            agreement = (y1 == y2).mean() * 100

            all_results.append({
                'session': session,
                'rater_1': r1,
                'rater_2': r2,
                'n_items': len(valid),
                'percent_agreement': round(agreement, 2),
                'cohen_kappa': round(kappa, 3)
            })

    return pd.DataFrame(all_results, columns=columns)

def summarize_inter_rater_reliability(all_ratings: dict, kappa_threshold: float = 0.4):
    """
    Print a comprehensive summary of inter-rater reliability.

    Prints the pooled pairwise statistics and their averages, then a
    per-session summary. For each session the summary shows the largest
    pairwise item overlap and the mean agreement and mean kappa over rater
    pairs. Finally it lists sessions whose mean kappa is below
    ``kappa_threshold``.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings DataFrame).
        kappa_threshold: Sessions with mean Cohen's kappa below this value are
            flagged as low agreement (default 0.4).

    Returns:
        (overall_stats, session_stats): the DataFrames from
        ``compute_pairwise_stats_all_sessions`` and ``compute_stats_by_session``.
    """
    print("=" * 60)
    print("INTER-RATER RELIABILITY SUMMARY (ALL SESSIONS)")
    print("=" * 60)

    overall_stats = compute_pairwise_stats_all_sessions(all_ratings)
    print("\n>>> Overall Pairwise Statistics:\n")
    if overall_stats.empty:
        print("No rater pair has any rated items in common.")
    else:
        print(overall_stats.to_string(index=False))
        avg_kappa = overall_stats['cohen_kappa'].mean()
        avg_agreement = overall_stats['percent_agreement'].mean()
        print(f"\n>>> Average Cohen's Kappa: {avg_kappa:.3f}")
        print(f">>> Average Percent Agreement: {avg_agreement:.2f}%")

    session_stats = compute_stats_by_session(all_ratings)
    print("\n>>> Per-Session Statistics:\n")
    if session_stats.empty:
        print("No session has two raters with overlapping items.")
        return overall_stats, session_stats

    session_summary = session_stats.groupby('session').agg({
        # Rater pairs within a session can overlap on different numbers of
        # items. Report the largest overlap; 'first' would depend on pair order.
        'n_items': 'max',
        'percent_agreement': 'mean',
        'cohen_kappa': 'mean'
    }).round(3)
    print(session_summary.to_string())

    low_kappa_sessions = session_summary[session_summary['cohen_kappa'] < kappa_threshold]
    if not low_kappa_sessions.empty:
        print(f"\n>>> Sessions with low agreement (Kappa < {kappa_threshold}):")
        print(low_kappa_sessions.to_string())

    return overall_stats, session_stats


# =============================================================================
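# Usage examples (NOTE(review): plot_rating_agreement_distribution,
# plot_rating_agreement_summary and print_rating_agreement_table are not
# defined in this file section — confirm they exist before uncommenting):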
# =============================================================================
# all_ratings = load_all_ratings()
#
# # Option 1: Four separate plots
# fig = plot_rating_agreement_distribution(all_ratings)
# plt.show()
#
# # Option 2: Side-by-side grouped bar chart (more compact)
# fig = plot_rating_agreement_summary(all_ratings)
# plt.show()
#
# # Option 3: Print summary table
# print_rating_agreement_table(all_ratings)

# Example usage:
# all_ratings = load_all_ratings()
# summarize_inter_rater_reliability(all_ratings)

Project root set to: /home/zli230/projects/stateswitch


In [None]:
def compute_class_specific_reliability(all_ratings: dict) -> pd.DataFrame:
    """
    Per-class agreement for every rater pair, pooled over all sessions.

    For each rater pair (items rated by both) and each class (0=unsure,
    1=cluster, 2=switch) the function records:
      - n_r1 / n_r2: how often each rater used the class;
      - n_both_agree: items where both raters chose the class;
      - n_either: items where at least one rater chose the class;
      - agreement_pct: n_both_agree / n_either * 100 rounded to 1 d.p., or
        NaN when neither rater used the class.
    Pairs with no overlapping items are skipped.

    Returns:
        DataFrame with columns rater_pair, label, label_name, n_r1, n_r2,
        n_both_agree, n_either, agreement_pct.
    """
    merged = align_ratings_all_sessions(all_ratings)

    prefix = 'switch_flag_'
    rater_ids = [c.replace(prefix, '') for c in merged.columns if c.startswith(prefix)]
    names = {0: 'unsure', 1: 'cluster', 2: 'switch'}

    rows = []
    for first, second in combinations(rater_ids, 2):
        pair = merged[[prefix + first, prefix + second]].dropna()
        if pair.empty:
            continue
        a = pair[prefix + first].astype(int)
        b = pair[prefix + second].astype(int)

        for code, name in names.items():
            a_hit = a == code
            b_hit = b == code
            n_either = (a_hit | b_hit).sum()
            n_both = (a_hit & b_hit).sum()
            pct = round(n_both / n_either * 100, 1) if n_either > 0 else np.nan
            rows.append({
                'rater_pair': f'{first}-{second}',
                'label': code,
                'label_name': name,
                'n_r1': a_hit.sum(),
                'n_r2': b_hit.sum(),
                'n_both_agree': n_both,
                'n_either': n_either,
                'agreement_pct': pct,
            })

    return pd.DataFrame(rows)


def compute_class_confusion_summary(all_ratings: dict) -> dict:
    """
    Count disagreement types across all rater pairs, pooled over sessions.

    For every rater pair, each jointly rated item where the two labels differ
    adds one to the count of its unordered label pair.

    Returns:
        dict mapping a sorted label tuple (e.g. (1, 2) for cluster vs switch)
        to the number of such disagreements summed over all rater pairs.
    """
    merged = align_ratings_all_sessions(all_ratings)
    flag_cols = [c for c in merged.columns if c.startswith('switch_flag_')]
    rater_ids = [c.replace('switch_flag_', '') for c in flag_cols]

    tally = {}
    for first, second in combinations(rater_ids, 2):
        left_col = f'switch_flag_{first}'
        right_col = f'switch_flag_{second}'
        both = merged[[left_col, right_col]].dropna()
        a = both[left_col].astype(int)
        b = both[right_col].astype(int)

        mismatch = a != b
        for u, v in zip(a[mismatch], b[mismatch]):
            u, v = int(u), int(v)
            pattern = (min(u, v), max(u, v))
            tally[pattern] = tally.get(pattern, 0) + 1

    return tally


def print_class_reliability_summary(all_ratings: dict):
    """
    Print a focused summary of inter-rater reliability by class.

    Three sections are printed:
      1. For cluster (1) and switch (2): per rater pair, how often each rater
         used the label and how often both did. An AVERAGE row follows, with
         the counts and agreement averaged over rater pairs.
      2. Disagreement patterns pooled over all rater pairs.
      3. Each rater's label distribution over the pooled, aligned items.
    """
    label_names = {0: 'unsure', 1: 'cluster', 2: 'switch'}

    print("=" * 70)
    print("CLASS-SPECIFIC INTER-RATER RELIABILITY SUMMARY")
    print("=" * 70)
    print("\nLabels: 0=unsure, 1=cluster, 2=switch")

    class_stats = compute_class_specific_reliability(all_ratings)

    # Summarize by class across all rater pairs
    print("\n" + "-" * 70)
    print("PER-CLASS AGREEMENT (when either rater assigns the label)")
    print("-" * 70)

    if class_stats.empty:
        print("\nNo rater pair has any rated items in common.")
    else:
        for label in [1, 2]:  # Focus on cluster and switch
            label_name = label_names[label].upper()
            label_data = class_stats[class_stats['label'] == label]

            print(f"\n>>> {label_name} (label={label}):")
            print(f"    {'Rater Pair':<12} {'R1 count':>10} {'R2 count':>10} {'Both agree':>12} {'Agreement %':>12}")
            print(f"    {'-'*56}")

            for _, row in label_data.iterrows():
                print(f"    {row['rater_pair']:<12} {row['n_r1']:>10} {row['n_r2']:>10} {row['n_both_agree']:>12} {row['agreement_pct']:>11.1f}%")

            # Average over the actual number of rater pairs instead of assuming
            # three (3 raters); four raters give six pairs.
            n_pairs = max(len(label_data), 1)
            avg_agreement = label_data['agreement_pct'].mean()
            avg_r1 = label_data['n_r1'].sum() // n_pairs
            avg_r2 = label_data['n_r2'].sum() // n_pairs
            avg_both = label_data['n_both_agree'].sum() // n_pairs
            print(f"    {'-'*56}")
            print(f"    {'AVERAGE':<12} {avg_r1:>10} {avg_r2:>10} {avg_both:>12} {avg_agreement:>11.1f}%")

    # Confusion analysis
    print("\n" + "-" * 70)
    print("DISAGREEMENT PATTERNS (what gets confused?)")
    print("-" * 70)

    confusion_counts = compute_class_confusion_summary(all_ratings)
    total_disagreements = sum(confusion_counts.values())
    print(f"\nTotal disagreements across all rater pairs: {total_disagreements}")
    print("\nBreakdown of confusion types:")

    # Every stored count is >= 1, so total_disagreements > 0 inside this loop.
    for (v1, v2), count in sorted(confusion_counts.items(), key=lambda x: -x[1]):
        pct = count / total_disagreements * 100
        print(f"    {label_names.get(v1, str(v1)):>7} <-> {label_names.get(v2, str(v2)):<7}: {count:>5} ({pct:>5.1f}%)")

    # Key insight: cluster vs switch confusion
    cluster_switch_confusion = confusion_counts.get((1, 2), 0)
    if total_disagreements:
        print(f"\n>>> KEY METRIC: Cluster vs Switch confusion rate: {cluster_switch_confusion/total_disagreements*100:.1f}% of all disagreements")
    else:
        print("\n>>> KEY METRIC: no disagreements, so no Cluster vs Switch confusion.")

    # Per-rater label distributions
    print("\n" + "-" * 70)
    print("PER-RATER LABEL DISTRIBUTIONS")
    print("-" * 70)

    aligned = align_ratings_all_sessions(all_ratings)
    rater_cols = [col for col in aligned.columns if col.startswith('switch_flag_')]

    print(f"\n    {'Rater':<8} {'Unsure (0)':>12} {'Cluster (1)':>12} {'Switch (2)':>12} {'Total':>10}")
    print(f"    {'-'*54}")

    for col in sorted(rater_cols):
        rater = col.replace('switch_flag_', '')
        vals = aligned[col].dropna().astype(int)
        total = len(vals)
        if total == 0:
            # A rater column can be all-NaN, e.g. when every one of that
            # rater's rows was removed by filtering.
            print(f"    {rater:<8} {'(no rated items)':>38} {total:>10}")
            continue
        n0 = (vals == 0).sum()
        n1 = (vals == 1).sum()
        n2 = (vals == 2).sum()
        print(f"    {rater:<8} {n0:>5} ({n0/total*100:>4.1f}%) {n1:>5} ({n1/total*100:>4.1f}%) {n2:>5} ({n2/total*100:>4.1f}%) {total:>10}")


# Run analysis
# Load every rater's spreadsheets once and reuse them for both reports below.
all_ratings = load_all_ratings()

# Overall summary (existing): pooled pairwise kappa/agreement plus per-session stats
summarize_inter_rater_reliability(all_ratings)

# Class-specific summary (new)
print("\n\n")
print_class_reliability_summary(all_ratings)

# Per-rater label distributions recorded from an earlier run.
# Presumably percentages of items labelled 0 (unsure) / 1 (cluster) / 2 (switch).
#JC: 0-1-2: 15/59/26
#AS: 0-1-2: 22/55/23
#GL: 0-1-2: 10/63/27

In [None]:
def compute_agreement_table(all_ratings: dict) -> tuple[pd.DataFrame, int]:
    """
    Compute agreement table showing how many words have N raters agreeing
    on clustering vs switching labels.

    Only words rated by every rater column present in the pooled alignment are
    counted. For each k in 1..n_raters, the table counts the words where
    exactly k raters chose cluster (1). It separately counts the words where
    exactly k raters chose switch (2).

    Returns:
        (table, total_words):
          - table: DataFrame with columns 'n_raters_agree', 'clustering_n',
            'clustering_pct', 'switching_n', 'switching_pct'. Percentages are
            relative to total_words (1 d.p.) and NaN when total_words is 0.
          - total_words: number of words rated by all raters.
    """
    aligned = align_ratings_all_sessions(all_ratings)

    rater_cols = [col for col in aligned.columns if col.startswith('switch_flag_')]
    n_raters = len(rater_cols)
    votes = aligned[rater_cols]

    # Per word: number of raters voting clustering (1) and switching (2).
    n_cluster = (votes == 1).sum(axis=1)
    n_switch = (votes == 2).sum(axis=1)

    # Only include words where all raters provided ratings.
    complete = votes.notna().sum(axis=1) == n_raters
    n_cluster = n_cluster[complete]
    n_switch = n_switch[complete]
    total_words = int(complete.sum())

    print(f"Total words with all {n_raters} raters: {total_words}")

    def _pct(count):
        # Percent of fully rated words; NaN instead of dividing by zero.
        return round(count / total_words * 100, 1) if total_words else np.nan

    results = []
    for n_agree in range(1, n_raters + 1):
        cluster_count = (n_cluster == n_agree).sum()
        switch_count = (n_switch == n_agree).sum()
        results.append({
            'n_raters_agree': n_agree,
            'clustering_n': cluster_count,
            'clustering_pct': _pct(cluster_count),
            'switching_n': switch_count,
            'switching_pct': _pct(switch_count),
        })

    return pd.DataFrame(results), total_words


def print_agreement_table(all_ratings: dict, consensus_threshold: int = 3):
    """
    Print a formatted agreement table for clustering and switching labels.

    After the table, prints how many words reach consensus, meaning at least
    ``consensus_threshold`` raters chose the same label, for clustering and
    for switching. It also prints how many words reach consensus for neither.

    Args:
        all_ratings: Mapping of rater -> (session name -> ratings DataFrame).
        consensus_threshold: Minimum number of agreeing raters counted as
            consensus (default 3). NOTE: if 2 * consensus_threshold <= number
            of raters, one word could reach consensus for both labels and the
            "No consensus" count would be too low.

    Returns:
        The agreement table DataFrame from ``compute_agreement_table``.
    """
    df, total_words = compute_agreement_table(all_ratings)

    def _pct(count):
        # Guard against an empty selection (no fully rated words).
        return count / total_words * 100 if total_words else float('nan')

    print("=" * 70)
    print("RATER AGREEMENT TABLE FOR CLUSTERING AND SWITCHING LABELS")
    print("=" * 70)
    print(f"\nTotal words analyzed: {total_words}")
    print("\n" + "-" * 70)
    print(f"{'# Raters':<12} {'Clustering':^25} {'Switching':^25}")
    print(f"{'Agreeing':<12} {'Count':>10} {'Proportion':>12} {'Count':>10} {'Proportion':>12}")
    print("-" * 70)

    # itertuples keeps each column's dtype. iterrows would upcast the integer
    # count columns to float (printing e.g. "12.0") because each row mixes
    # ints and floats.
    for row in df.itertuples(index=False):
        print(f"{row.n_raters_agree:<12} {row.clustering_n:>10} {row.clustering_pct:>10.1f}% {row.switching_n:>10} {row.switching_pct:>10.1f}%")

    print("-" * 70)

    is_consensus = df['n_raters_agree'] >= consensus_threshold
    consensus_cluster = df.loc[is_consensus, 'clustering_n'].sum()
    consensus_switch = df.loc[is_consensus, 'switching_n'].sum()

    print(f"\nConsensus (>={consensus_threshold} raters agree):")
    print(f"  Clustering: {consensus_cluster} words ({_pct(consensus_cluster):.1f}%)")
    print(f"  Switching:  {consensus_switch} words ({_pct(consensus_switch):.1f}%)")

    no_consensus = total_words - consensus_cluster - consensus_switch
    print(f"  No consensus: {no_consensus} words ({_pct(no_consensus):.1f}%)")

    return df


# Load data and run the agreement table analysis.
# Ratings are re-read from disk here; the printed table is also kept in agreement_df.
all_ratings = load_all_ratings()
agreement_df = print_agreement_table(all_ratings)